# Import

In [1]:
import numpy as np
import os
import cv2
import torch
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim

# Suppress the specific warning
import csv
from datetime import datetime
from tqdm import tqdm

import torch
from torchvision import transforms
from torch import nn, optim
import torch.nn.functional as F
from torch.utils import data

from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from utils.utils0 import *
from utils.utils1 import *
from utils.utils1 import ModelParams, DL_affine_plot, loss_extra

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')

# Stub to warn about opencv version.
if int(cv2.__version__[0]) < 3: # pragma: no cover
  print('Warning: OpenCV 3 is not installed')

image_size = 256


Device: cuda


## Cases, model parameters
- Supervised DL w/ groundtruth affine transformation parameters (MSE params, MSE, NCC images)
    - Synthetic eye
    - Synthetic shape
- Unsupervised DL (MSE, NCC images)
    - Actual eye data
    - Synthetic eye
    - Synthetic shape
- Data
    - only images
    - only heatmaps
    - images & heatmaps
- Loss function
    - MSE affine parameters
    - MSE, NCC images

    

In [2]:
model_params = ModelParams(sup=0, dataset=1, image=1, heatmaps=0, 
                           loss_image=0, num_epochs=5, learning_rate=1e-3)
model_params.print_explanation()

Model name:  dataset1_sup0_image1_heatmaps0_loss_image0
Model code:  10100_0.001_0_5_1
Model params:  {'dataset': 1, 'sup': 0, 'image': 1, 'heatmaps': 0, 'loss_image_case': 0, 'loss_image': MSELoss(), 'loss_affine': None, 'learning_rate': 0.001, 'decay_rate': 0.96, 'start_epoch': 0, 'num_epochs': 5, 'batch_size': 1, 'model_name': 'dataset1_sup0_image1_heatmaps0_loss_image0'}

Model name:  dataset1_sup0_image1_heatmaps0_loss_image0
Model code:  10100_0.001_0_5_1
Dataset used:  Synthetic eye
Supervised or unsupervised model:  Unsupervised
Image type:  Image used
Heatmaps used:  Heatmaps not used
Loss function case:  0
Loss function for image:  MSELoss()
Loss function for affine:  None
Learning rate:  0.001
Decay rate:  0.96
Start epoch:  0
Number of epochs:  5
Batch size:  1




# Models
## SuperPoint

## ImgReg Network

In [3]:
from utils.SPaffineNet3 import SP_AffineNet3

## SP ImgReg model

# Load data

In [4]:
from utils.datagen import datagen

In [5]:
# test datagen for all datasets and training and testing
# for dataset in range(4): # don't forget to change this back to 2
#     for is_train in [True, False]:
#         for sup in [False, True]:
#             print(f'dataset: {dataset}, is_train: {is_train}, sup: {sup}')
#             dataloader = datagen(dataset, is_train, sup)
            
#             if sup==1 and dataset==0:
#                 print('skipping')
#                 pass
#             else:
#                 try:
#                     print('index, source_img.shape,       target_img.shape')
#                     for i, (source_img, target_img) in enumerate(dataloader):
#                         print(i, source_img.shape, target_img.shape)
#                         if i == 2:
#                             break
#                 except ValueError:
#                     print('index, source_img.shape,       target_img.shape,            affine_params.shape')
#                     for i, batch in enumerate(dataloader):
#                         print(i, batch[0].shape, batch[1].shape, batch[2].shape)
#                         if i == 5:
#                             break
#             print('\n')

# Training
## Dataset initialization

In [6]:
train_dataset = datagen(model_params.dataset, True, model_params.sup)
test_dataset = datagen(model_params.dataset, False, model_params.sup)

# Get sample batch
print('Train set: ', [x.shape for x in next(iter(train_dataset))])
print('Test set: ', [x.shape for x in next(iter(test_dataset))])

Train set:  [torch.Size([1, 1, 256, 256]), torch.Size([1, 1, 256, 256])]
Test set:  [torch.Size([1, 1, 256, 256]), torch.Size([1, 1, 256, 256])]


## Model initialize

In [7]:
# print case
print(model_params)
model_params.print_explanation()

dataset1_sup0_image1_heatmaps0_loss_image0

Model name:  dataset1_sup0_image1_heatmaps0_loss_image0
Model code:  10100_0.001_0_5_1
Dataset used:  Synthetic eye
Supervised or unsupervised model:  Unsupervised
Image type:  Image used
Heatmaps used:  Heatmaps not used
Loss function case:  0
Loss function for image:  MSELoss()
Loss function for affine:  None
Learning rate:  0.001
Decay rate:  0.96
Start epoch:  0
Number of epochs:  5
Batch size:  1




In [8]:
model = SP_AffineNet3(model_params).to(device)
print(model)

parameters = model.parameters()
optimizer = optim.Adam(parameters, model_params.learning_rate, weight_decay=0.0001)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: model_params.decay_rate ** epoch)
# model_path = 'trained_models/10102_0.001_0_20_1_20230930-091532.pth'

# if a model is loaded, the training will continue from the epoch it was saved at
try:
    model.load_state_dict(torch.load(model_path))
    model_params.start_epoch = int(model_path.split('/')[-1].split('_')[3])
    print(f'Loaded model from {model_path}\nstarting at epoch {model_params.start_epoch}')
    if model_params.start_epoch >= model_params.num_epochs:
            model_params.num_epochs += model_params.start_epoch
except:
    model_params.start_epoch = 0
    print('No model loaded, starting from scratch')


Running new version (not run SP on source image)
-------------------------------------------------------------------------------
           Layer (type)            Input Shape         Param #     Tr. Param #
               Conv2d-1       [1, 1, 256, 256]             640             640
            LeakyReLU-2      [1, 64, 256, 256]               0               0
            GroupNorm-3      [1, 64, 256, 256]             128             128
               Conv2d-4      [1, 64, 256, 256]          16,448          16,448
               Conv2d-5      [1, 64, 128, 128]          73,856          73,856
            GroupNorm-6     [1, 128, 128, 128]             256             256
               Conv2d-7     [1, 128, 128, 128]          65,664          65,664
               Conv2d-8       [1, 128, 64, 64]         295,168         295,168
            GroupNorm-9       [1, 256, 64, 64]             512             512
              Conv2d-10       [1, 256, 64, 64]         262,400         262,400
 

## Training function

In [9]:
# Define training function
def train(model, model_params, timestamp):
    # Define loss function based on supervised or unsupervised learning
    criterion = model_params.loss_image
    extra = loss_extra()

    if model_params.sup:
        criterion_affine = nn.MSELoss()
        # TODO: add loss for points1_affine and points2, Euclidean distance

    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=model_params.learning_rate)

    # Create empty list to store epoch number, train loss and validation loss
    epoch_loss_list = []
    running_loss_list = []
    
    # Create output directory
    output_dir = f"output/{model_params.get_model_code()}_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    # Train model
    for epoch in range(model_params.start_epoch, model_params.num_epochs):
        # Set model to training mode
        model.train()

        # Zero the parameter gradients
        optimizer.zero_grad()
        
        running_loss = 0.0
        train_bar = tqdm(train_dataset, desc=f'Training Epoch {epoch+1}/{model_params.num_epochs}')
        for i, data in enumerate(train_bar):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
                affine_params_true = None
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward + backward + optimize
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #         print(i, outputs[i].shape)
            # 0 torch.Size([1, 1, 256, 256])
            # 1 torch.Size([1, 2, 3])
            # 2 (2, 4)
            # 3 (2, 4)
            # 4 (1, 4, 2)
            # 5 (256, 9)
            # 6 (256, 16)
            # 7 (256, 256)
            # 8 (256, 256)
            transformed_source_affine = outputs[0] # image
            affine_params_predicted = outputs[1] # affine parameters
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])

            # print(f"affine_params_true: {affine_params_true}")
            # print(f"affine_params_predicted: {affine_params_predicted}\n")

            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            loss = criterion(transformed_source_affine, target_image)
            loss += extra(affine_params_predicted)
            if model_params.sup:
                loss_affine = criterion_affine(affine_params_true.view(1, 2, 3), affine_params_predicted.cpu())
                # TODO: add loss for points1_affine and points2, Euclidean distance
                # loss_points = criterion_points(points1_affine, points2)
                loss += loss_affine
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            # Plot images if i < 5
            if i % 100 == 0:
                DL_affine_plot(f"epoch{epoch+1}_train", output_dir,
                    f"{i}", "_", source_image[0, 0, :, :].detach().cpu().numpy(), 
                    target_image[0, 0, :, :].detach().cpu().numpy(), 
                    transformed_source_affine[0, 0, :, :].detach().cpu().numpy(),
                    points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                    affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=True)

            # Print statistics
            running_loss += loss.item()
            running_loss_list.append([epoch+((i+1)/len(train_dataset)), loss.item()])
            train_bar.set_postfix({'loss': running_loss / (i+1)})
        print(f'Training Epoch {epoch+1}/{model_params.num_epochs} loss: {running_loss / len(train_dataset)}')

        # Validate model
        validation_loss = 0.0
        model.eval()
        # with torch.no_grad():
        for i, data in enumerate(test_dataset, 0):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
                affine_params_true = None
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward pass
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #     print(i, outputs[i].shape)
            transformed_source_affine = outputs[0]
            affine_params_predicted = outputs[1]
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])
            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            loss = criterion(transformed_source_affine, target_image)
            loss += extra(affine_params_predicted)
            if model_params.sup:
                loss_affine = criterion_affine(affine_params_true.view(1, 2, 3), affine_params_predicted.cpu())
                # TODO: add loss for points1_affine and points2, Euclidean distance
                # loss_points = criterion_points(points1_affine, points2)
                loss += loss_affine

            # Add to validation loss
            validation_loss += loss.item()

            if i % 50 == 0:
                DL_affine_plot(f"epoch{epoch+1}_valid", output_dir,
                    f"{i}", "_", source_image[0, 0, :, :].detach().cpu().numpy(), 
                    target_image[0, 0, :, :].detach().cpu().numpy(), 
                    transformed_source_affine[0, 0, :, :].detach().cpu().numpy(),
                    points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                    affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=True)

        # Print validation statistics
        validation_loss /= len(test_dataset)
        print(f'Validation Epoch {epoch+1}/{model_params.num_epochs} loss: {validation_loss}')

        # Append epoch number, train loss and validation loss to epoch_loss_list
        epoch_loss_list.append([epoch+1, running_loss / len(train_dataset), validation_loss])

        
        # Extract epoch number, train loss and validation loss from epoch_loss_list
        epoch = [x[0] for x in epoch_loss_list]
        train_loss = [x[1] for x in epoch_loss_list]
        val_loss = [x[2] for x in epoch_loss_list]
        step = [x[0] for x in running_loss_list]
        running_train_loss = [x[1] for x in running_loss_list]

        save_plot_name = f"{output_dir}/loss_{model_params.get_model_code()}_epoch{model_params.num_epochs}_{timestamp}.png"

        # Plot train loss and validation loss against epoch number
        plt.figure()
        plt.plot(step, running_train_loss, label='Running Train Loss', alpha=0.3)
        plt.plot(epoch, train_loss, label='Train Loss')
        plt.plot(epoch, val_loss, label='Validation Loss')
        plt.title('Train and Validation Loss')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.yscale('log')
        plt.tight_layout()
        plt.savefig(save_plot_name)
        plt.close()
        
    print('Finished Training')

    # delete all txt files in output_dir
    for file in os.listdir(output_dir):
        if file.endswith(".txt"):
            os.remove(os.path.join(output_dir, file))

    # Return epoch_loss_list
    return epoch_loss_list


## working code

In [10]:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
loss_list = train(model, model_params, timestamp)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
  image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
Training Epoch 1/5: 100%|██████████| 461/461 [01:31<00:00,  5.02it/s, loss=4.24e+3]


Training Epoch 1/5 loss: 4238.731186757789
Validation Epoch 1/5 loss: 5073.507404852351


Training Epoch 2/5: 100%|██████████| 461/461 [01:28<00:00,  5.19it/s, loss=5.74e+3]


Training Epoch 2/5 loss: 5739.813438970935
Validation Epoch 2/5 loss: 6000.112492832569


Training Epoch 3/5: 100%|██████████| 461/461 [01:36<00:00,  4.78it/s, loss=5.73e+3]


Training Epoch 3/5 loss: 5726.798460328923
Validation Epoch 3/5 loss: 6000.112313646789


Training Epoch 4/5: 100%|██████████| 461/461 [01:29<00:00,  5.15it/s, loss=5.75e+3]


Training Epoch 4/5 loss: 5746.321364394626
Validation Epoch 4/5 loss: 5935.892009210149


Training Epoch 5/5: 100%|██████████| 461/461 [01:26<00:00,  5.35it/s, loss=5.77e+3]


Training Epoch 5/5 loss: 5772.350731252234
Validation Epoch 5/5 loss: 6000.113760571961
Finished Training


In [11]:
print("Training output:")
for i in range(len(loss_list)):
    print(loss_list[i])

Training output:
[1, 4238.731186757789, 5073.507404852351]
[2, 5739.813438970935, 6000.112492832569]
[3, 5726.798460328923, 6000.112313646789]
[4, 5746.321364394626, 5935.892009210149]
[5, 5772.350731252234, 6000.113760571961]


## Model saving

In [12]:
model_save_path = "trained_models/"
model_name_to_save = model_save_path + f"model3_{model_params.get_model_code()}_{timestamp}.pth"
print(model_name_to_save)
torch.save(model.state_dict(), model_name_to_save)


trained_models/model3_10100_0.001_0_5_1_20231103-094248.pth


# Test model (loading and inference)

Save results and export metrics to csv

In [13]:
# model = SPmodel = SP_AffineNet().to(device)
# print(model)

# parameters = model.parameters()
# optimizer = optim.Adam(parameters, model_params.learning_rate)
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: model_params.decay_rate ** epoch)

# model.load_state_dict(torch.load(model_name_to_save))

In [14]:
def test(model, model_params, timestamp):
    # Set model to training mode
    model.eval()

    # Create output directory
    output_dir = f"output/{model_params.get_model_code()}_{timestamp}_test"
    os.makedirs(output_dir, exist_ok=True)

    # Validate model
    # validation_loss = 0.0

    # create a csv file to store the metrics
    csv_file = f"{output_dir}/metrics.csv"
    with open(csv_file, 'w', newline='') as file:
        writer = csv.writer(file)
        # matches1_transformed.shape[-1], mse_before, mse12, tre_before, tre12, \
        # mse12_image, ssim12_image, 
        writer.writerow(["index", "mse_before", "mse12", "tre_before", "tre12", "mse12_image_before", "mse12_image", "ssim12_image_before", "ssim12_image"])

    with torch.no_grad():
        testbar = tqdm(test_dataset, desc=f'Testing:')
        for i, data in enumerate(testbar, 0):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward pass
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #     print(i, outputs[i].shape)
            transformed_source_affine = outputs[0]
            affine_params_predicted = outputs[1]
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])
            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            if i < 50:
                plot_ = True
            else:
                plot_ = False

            results = DL_affine_plot(f"{i+1}", output_dir,
                f"{i}", "_", source_image[0, 0, :, :].cpu().numpy(), target_image[0, 0, :, :].cpu().numpy(), \
                transformed_source_affine[0, 0, :, :].cpu().numpy(), \
                points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=plot_)

            # calculate metrics
            # matches1_transformed = results[0]
            mse_before = results[1]
            mse12 = results[2]
            tre_before = results[3]
            tre12 = results[4]
            mse12_image_before = results[5]
            mse12_image = results[6]
            ssim12_image_before = results[7]
            ssim12_image = results[8]

            # write metrics to csv file
            with open(csv_file, 'a', newline='') as file:
                writer = csv.writer(file) # TODO: might need to export true & predicted affine parameters too
                writer.writerow([i, mse_before, mse12, tre_before, tre12, mse12_image_before, mse12_image, ssim12_image_before, ssim12_image])

    # delete all txt files in output_dir
    for file in os.listdir(output_dir):
        if file.endswith(".txt"):
            os.remove(os.path.join(output_dir, file))

# timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
metrics = test(model, model_params, timestamp)

Testing::   0%|          | 0/109 [00:00<?, ?it/s]


UnboundLocalError: local variable 'affine_params_true' referenced before assignment