# Import

In [44]:
import numpy as np
import os
import cv2
import torch
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim

# Suppress the specific warning
import csv
from datetime import datetime
from tqdm import tqdm

import torch
from torchvision import transforms
from torch import nn, optim
import torch.nn.functional as F
from torch.utils import data

from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from utils.utils0 import *
from utils.utils1 import *
from utils.utils1 import ModelParams, DL_affine_plot, loss_extra

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')

# Stub to warn about opencv version.
if int(cv2.__version__[0]) < 3: # pragma: no cover
  print('Warning: OpenCV 3 is not installed')

image_size = 256


Device: cuda


## Cases, model parameters
- Supervised DL w/ groundtruth affine transformation parameters (MSE params, MSE, NCC images)
    - Synthetic eye
    - Synthetic shape
- Unsupervised DL (MSE, NCC images)
    - Actual eye data
    - Synthetic eye
    - Synthetic shape
- Data
    - only images
    - only heatmaps
    - images & heatmaps
- Loss function
    - MSE affine parameters
    - MSE, NCC images

    

In [45]:
model_params = ModelParams(sup=1, dataset=2, image=1, heatmaps=0, 
                           loss_image=1, num_epochs=10, learning_rate=1e-2)
model_params.print_explanation()

Model name:  dataset2_sup1_image1_heatmaps0_loss_image1
Model code:  21101_0.01_0_10_1
Model params:  {'dataset': 2, 'sup': 1, 'image': 1, 'heatmaps': 0, 'loss_image_case': 1, 'loss_image': NCC(), 'loss_affine': <utils.utils1.loss_affine object at 0x7fc07bbe73d0>, 'learning_rate': 0.01, 'decay_rate': 0.96, 'start_epoch': 0, 'num_epochs': 10, 'batch_size': 1, 'model_name': 'dataset2_sup1_image1_heatmaps0_loss_image1'}

Model name:  dataset2_sup1_image1_heatmaps0_loss_image1
Model code:  21101_0.01_0_10_1
Dataset used:  Synthetic shape
Supervised or unsupervised model:  Supervised
Image type:  Image used
Heatmaps used:  Heatmaps not used
Loss function case:  1
Loss function for image:  NCC()
Loss function for affine:  <utils.utils1.loss_affine object at 0x7fc07bbe73d0>
Learning rate:  0.01
Decay rate:  0.96
Start epoch:  0
Number of epochs:  10
Batch size:  1




# Models
## SuperPoint

## ImgReg Network

In [46]:
from utils.SPaffineNet3 import SP_AffineNet3

## SP ImgReg model

# Load data

In [47]:
from utils.datagen import datagen

In [48]:
# test datagen for all datasets and training and testing
# for dataset in range(4): # don't forget to change this back to 2
#     for is_train in [True, False]:
#         for sup in [False, True]:
#             print(f'dataset: {dataset}, is_train: {is_train}, sup: {sup}')
#             dataloader = datagen(dataset, is_train, sup)
            
#             if sup==1 and dataset==0:
#                 print('skipping')
#                 pass
#             else:
#                 try:
#                     print('index, source_img.shape,       target_img.shape')
#                     for i, (source_img, target_img) in enumerate(dataloader):
#                         print(i, source_img.shape, target_img.shape)
#                         if i == 2:
#                             break
#                 except ValueError:
#                     print('index, source_img.shape,       target_img.shape,            affine_params.shape')
#                     for i, batch in enumerate(dataloader):
#                         print(i, batch[0].shape, batch[1].shape, batch[2].shape)
#                         if i == 5:
#                             break
#             print('\n')

# Training
## Dataset initialization

In [49]:
train_dataset = datagen(model_params.dataset, True, model_params.sup)
test_dataset = datagen(model_params.dataset, False, model_params.sup)

# Get sample batch
print('Train set: ', [x.shape for x in next(iter(train_dataset))])
print('Test set: ', [x.shape for x in next(iter(test_dataset))])

Train set:  [torch.Size([1, 1, 256, 256]), torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 3])]
Test set:  [torch.Size([1, 1, 256, 256]), torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 3])]


## Model initialize

In [50]:
# print case
print(model_params)
model_params.print_explanation()

dataset2_sup1_image1_heatmaps0_loss_image1

Model name:  dataset2_sup1_image1_heatmaps0_loss_image1
Model code:  21101_0.01_0_10_1
Dataset used:  Synthetic shape
Supervised or unsupervised model:  Supervised
Image type:  Image used
Heatmaps used:  Heatmaps not used
Loss function case:  1
Loss function for image:  NCC()
Loss function for affine:  <utils.utils1.loss_affine object at 0x7fc07bbe73d0>
Learning rate:  0.01
Decay rate:  0.96
Start epoch:  0
Number of epochs:  10
Batch size:  1




In [51]:
model = SP_AffineNet3(model_params).to(device)
print(model)

parameters = model.parameters()
optimizer = optim.Adam(parameters, model_params.learning_rate, weight_decay=0.0001)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: model_params.decay_rate ** epoch)
# model_path = 'trained_models/10102_0.001_0_20_1_20230930-091532.pth'

# if a model is loaded, the training will continue from the epoch it was saved at
try:
    model.load_state_dict(torch.load(model_path))
    model_params.start_epoch = int(model_path.split('/')[-1].split('_')[3])
    print(f'Loaded model from {model_path}\nstarting at epoch {model_params.start_epoch}')
    if model_params.start_epoch >= model_params.num_epochs:
            model_params.num_epochs += model_params.start_epoch
except:
    model_params.start_epoch = 0
    print('No model loaded, starting from scratch')


Running new version (not run SP on source image)
SP_AffineNet3(
  (affineNet): AffineNet3(
    (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv1s): Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
    (conv2s): Conv2d(128, 128, kernel_size=(2, 2), stride=(2, 2))
    (conv3s): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
    (fc1): Linear(in_features=512, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=64, bias=True)
    (fc3): Linear(in_features=64, out_features=6, bias=True)
    (dropout): Dropout(p=0.7, inplace=False)
    (aPooling): AdaptiveAvgPool2d(output_size=(1, 1))
    (ReLU): LeakyReLU(negative_slope=0.01)
    (Act1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (Act2): GroupNorm(64, 128, eps=1e-05, affine=True)
    (Act3): GroupNorm(128, 256, ep

## Training function

In [52]:
# Define training function
def train(model, model_params, timestamp):
    # Define loss function based on supervised or unsupervised learning
    criterion = model_params.loss_image
    extra = loss_extra()

    if model_params.sup:
        criterion_affine = nn.MSELoss()
        # TODO: add loss for points1_affine and points2, Euclidean distance

    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=model_params.learning_rate)

    # Create empty list to store epoch number, train loss and validation loss
    epoch_loss_list = []
    running_loss_list = []
    
    # Create output directory
    output_dir = f"output/{model_params.get_model_code()}_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    # Train model
    for epoch in range(model_params.start_epoch, model_params.num_epochs):
        # Set model to training mode
        model.train()

        # Zero the parameter gradients
        optimizer.zero_grad()
        
        running_loss = 0.0
        train_bar = tqdm(train_dataset, desc=f'Training Epoch {epoch+1}/{model_params.num_epochs}')
        for i, data in enumerate(train_bar):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
                affine_params_true = None
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward + backward + optimize
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #         print(i, outputs[i].shape)
            # 0 torch.Size([1, 1, 256, 256])
            # 1 torch.Size([1, 2, 3])
            # 2 (2, 4)
            # 3 (2, 4)
            # 4 (1, 4, 2)
            # 5 (256, 9)
            # 6 (256, 16)
            # 7 (256, 256)
            # 8 (256, 256)
            transformed_source_affine = outputs[0] # image
            affine_params_predicted = outputs[1] # affine parameters
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])

            # print(f"affine_params_true: {affine_params_true}")
            # print(f"affine_params_predicted: {affine_params_predicted}\n")

            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            loss = criterion(transformed_source_affine, target_image)
            loss += extra(affine_params_predicted)
            if model_params.sup:
                loss_affine = criterion_affine(affine_params_true.view(1, 2, 3), affine_params_predicted.cpu())
                # TODO: add loss for points1_affine and points2, Euclidean distance
                # loss_points = criterion_points(points1_affine, points2)
                loss += loss_affine
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            # Plot images if i < 5
            if i < 5:
                DL_affine_plot(f"epoch{epoch+1}_train", output_dir,
                    f"{i}", "_", source_image[0, 0, :, :].detach().cpu().numpy(), 
                    target_image[0, 0, :, :].detach().cpu().numpy(), 
                    transformed_source_affine[0, 0, :, :].detach().cpu().numpy(),
                    points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                    affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=True)

            # Print statistics
            running_loss += loss.item()
            running_loss_list.append([epoch+((i+1)/len(train_dataset)), loss.item()])
            train_bar.set_postfix({'loss': running_loss / (i+1)})
        print(f'Training Epoch {epoch+1}/{model_params.num_epochs} loss: {running_loss / len(train_dataset)}')

        # Validate model
        validation_loss = 0.0
        model.eval()
        # with torch.no_grad():
        for i, data in enumerate(test_dataset, 0):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
                affine_params_true = None
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward pass
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #     print(i, outputs[i].shape)
            transformed_source_affine = outputs[0]
            affine_params_predicted = outputs[1]
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])
            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            loss = criterion(transformed_source_affine, target_image)
            loss += extra(affine_params_predicted)
            if model_params.sup:
                loss_affine = criterion_affine(affine_params_true.view(1, 2, 3), affine_params_predicted.cpu())
                # TODO: add loss for points1_affine and points2, Euclidean distance
                # loss_points = criterion_points(points1_affine, points2)
                loss += loss_affine

            # Add to validation loss
            validation_loss += loss.item()

            # Plot images if i < 5
            if i < 5:
                DL_affine_plot(f"epoch{epoch+1}_valid", output_dir,
                    f"{i}", "_", source_image[0, 0, :, :].detach().cpu().numpy(), 
                    target_image[0, 0, :, :].detach().cpu().numpy(), 
                    transformed_source_affine[0, 0, :, :].detach().cpu().numpy(),
                    points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                    affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=True)

        # Print validation statistics
        validation_loss /= len(test_dataset)
        print(f'Validation Epoch {epoch+1}/{model_params.num_epochs} loss: {validation_loss}')

        # Append epoch number, train loss and validation loss to epoch_loss_list
        epoch_loss_list.append([epoch+1, running_loss / len(train_dataset), validation_loss])

        
        # Extract epoch number, train loss and validation loss from epoch_loss_list
        epoch = [x[0] for x in epoch_loss_list]
        train_loss = [x[1] for x in epoch_loss_list]
        val_loss = [x[2] for x in epoch_loss_list]
        step = [x[0] for x in running_loss_list]
        running_train_loss = [x[1] for x in running_loss_list]

        save_plot_name = f"{output_dir}/loss_{model_params.get_model_code()}_epoch{model_params.num_epochs}_{timestamp}.png"

        # Plot train loss and validation loss against epoch number
        plt.figure()
        plt.plot(step, running_train_loss, label='Running Train Loss', alpha=0.3)
        plt.plot(epoch, train_loss, label='Train Loss')
        plt.plot(epoch, val_loss, label='Validation Loss')
        plt.title('Train and Validation Loss')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.yscale('log')
        plt.tight_layout()
        plt.savefig(save_plot_name)
        plt.close()
        
    print('Finished Training')

    # delete all txt files in output_dir
    for file in os.listdir(output_dir):
        if file.endswith(".txt"):
            os.remove(os.path.join(output_dir, file))

    # Return epoch_loss_list
    return epoch_loss_list


## working code

In [53]:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
loss_list = train(model, model_params, timestamp)

  torch.tensor(M).view(1, 2, 3))
  image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
  image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
Training Epoch 1/10: 100%|██████████| 360/360 [00:27<00:00, 13.03it/s, loss=1.16e+7]


Training Epoch 1/10 loss: 11598290.223731168
Validation Epoch 1/10 loss: 239009.353515625


Training Epoch 2/10: 100%|██████████| 360/360 [00:27<00:00, 12.92it/s, loss=2.22e+9]


Training Epoch 2/10 loss: 2224853848.2828126
Validation Epoch 2/10 loss: 32503087.55


Training Epoch 3/10: 100%|██████████| 360/360 [00:27<00:00, 13.24it/s, loss=1.94e+10]


Training Epoch 3/10 loss: 19364786733.877777
Validation Epoch 3/10 loss: 195047966.4


Training Epoch 4/10: 100%|██████████| 360/360 [00:28<00:00, 12.80it/s, loss=1.02e+10]


Training Epoch 4/10 loss: 10173419981.022223
Validation Epoch 4/10 loss: 135035695.2


Training Epoch 5/10:  35%|███▌      | 126/360 [00:13<00:25,  9.31it/s, loss=6.46e+9]


KeyboardInterrupt: 

In [None]:
print("Training output:")
for i in range(len(loss_list)):
    print(loss_list[i])

Training output:
[1, 0.24539990491337246, 0.22000780366361142]
[2, 0.24089633435424831, 0.21959682665765284]
[3, 0.24925384295897351, 0.21985937096178532]
[4, 0.26845063306391237, 0.21944834664463997]
[5, 0.2517348497692082, 0.21950501017272472]
[6, 0.24793884151925644, 0.21988297253847122]
[7, 0.25682318293386036, 0.21949527338147162]
[8, 0.24465472961051596, 0.21963326036930084]
[9, 0.247489680411915, 0.21973580308258533]
[10, 0.23891730540328557, 0.21962998770177364]


## Model saving

In [None]:
model_save_path = "trained_models/"
model_name_to_save = model_save_path + f"model3_{model_params.get_model_code()}_{timestamp}.pth"
print(model_name_to_save)
torch.save(model.state_dict(), model_name_to_save)


trained_models/model3_21101_0.01_0_10_1_20231018-225701.pth


# Test model (loading and inference)

Save results and export metrics to csv

In [None]:
# model = SPmodel = SP_AffineNet().to(device)
# print(model)

# parameters = model.parameters()
# optimizer = optim.Adam(parameters, model_params.learning_rate)
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: model_params.decay_rate ** epoch)

# model.load_state_dict(torch.load(model_name_to_save))

In [None]:
def test(model, model_params, timestamp):
    # Set model to training mode
    model.eval()

    # Create output directory
    output_dir = f"output/{model_params.get_model_code()}_{timestamp}_test"
    os.makedirs(output_dir, exist_ok=True)

    # Validate model
    # validation_loss = 0.0

    # create a csv file to store the metrics
    csv_file = f"{output_dir}/metrics.csv"
    with open(csv_file, 'w', newline='') as file:
        writer = csv.writer(file)
        # matches1_transformed.shape[-1], mse_before, mse12, tre_before, tre12, \
        # mse12_image, ssim12_image, 
        writer.writerow(["index", "mse_before", "mse12", "tre_before", "tre12", "mse12_image_before", "mse12_image", "ssim12_image_before", "ssim12_image"])

    with torch.no_grad():
        testbar = tqdm(test_dataset, desc=f'Testing:')
        for i, data in enumerate(testbar, 0):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward pass
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #     print(i, outputs[i].shape)
            transformed_source_affine = outputs[0]
            affine_params_predicted = outputs[1]
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])
            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            if i < 50:
                plot_ = True
            else:
                plot_ = False

            results = DL_affine_plot(f"{i+1}", output_dir,
                f"{i}", "_", source_image[0, 0, :, :].cpu().numpy(), target_image[0, 0, :, :].cpu().numpy(), \
                transformed_source_affine[0, 0, :, :].cpu().numpy(), \
                points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=plot_)

            # calculate metrics
            # matches1_transformed = results[0]
            mse_before = results[1]
            mse12 = results[2]
            tre_before = results[3]
            tre12 = results[4]
            mse12_image_before = results[5]
            mse12_image = results[6]
            ssim12_image_before = results[7]
            ssim12_image = results[8]

            # write metrics to csv file
            with open(csv_file, 'a', newline='') as file:
                writer = csv.writer(file) # TODO: might need to export true & predicted affine parameters too
                writer.writerow([i, mse_before, mse12, tre_before, tre12, mse12_image_before, mse12_image, ssim12_image_before, ssim12_image])

    # delete all txt files in output_dir
    for file in os.listdir(output_dir):
        if file.endswith(".txt"):
            os.remove(os.path.join(output_dir, file))

# timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
metrics = test(model, model_params, timestamp)

Testing:: 100%|██████████| 40/40 [00:53<00:00,  1.35s/it]
