# Import

In [11]:
import numpy as np
import os
import cv2
import torch
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim

# Suppress the specific warning
import csv
from datetime import datetime
from tqdm import tqdm

import torch
from torchvision import transforms
from torch import nn, optim
import torch.nn.functional as F
from torch.utils import data

from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from utils.utils0 import *
from utils.utils1 import *
from utils.utils1 import ModelParams, DL_affine_plot, loss_extra

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')

# Stub to warn about opencv version.
if int(cv2.__version__[0]) < 3: # pragma: no cover
  print('Warning: OpenCV 3 is not installed')

image_size = 256


Device: cuda


## Cases, model parameters
- Supervised DL w/ groundtruth affine transformation parameters (MSE params, MSE, NCC images)
    - Synthetic eye
    - Synthetic shape
- Unsupervised DL (MSE, NCC images)
    - Actual eye data
    - Synthetic eye
    - Synthetic shape
- Data
    - only images
    - only heatmaps
    - images & heatmaps
- Loss function
    - MSE affine parameters
    - MSE, NCC images

    

In [12]:
model_params = ModelParams(sup=1, dataset=1, image=1, heatmaps=0, 
                           loss_image=0, num_epochs=20, learning_rate=1e-3)
model_params.print_explanation()

Model name:  dataset1_sup1_image1_heatmaps0_loss_image0
Model code:  11100_0.001_0_20_1
Model params:  {'dataset': 1, 'sup': 1, 'image': 1, 'heatmaps': 0, 'loss_image_case': 0, 'loss_image': MSELoss(), 'loss_affine': <utils.utils1.loss_affine object at 0x7fe53458ac10>, 'learning_rate': 0.001, 'decay_rate': 0.96, 'start_epoch': 0, 'num_epochs': 20, 'batch_size': 1, 'model_name': 'dataset1_sup1_image1_heatmaps0_loss_image0'}

Model name:  dataset1_sup1_image1_heatmaps0_loss_image0
Model code:  11100_0.001_0_20_1
Dataset used:  Synthetic eye
Supervised or unsupervised model:  Supervised
Image type:  Image used
Heatmaps used:  Heatmaps not used
Loss function case:  0
Loss function for image:  MSELoss()
Loss function for affine:  <utils.utils1.loss_affine object at 0x7fe53458ac10>
Learning rate:  0.001
Decay rate:  0.96
Start epoch:  0
Number of epochs:  20
Batch size:  1




# Models
## SuperPoint

## ImgReg Network

In [13]:
from networks import affine_network_simple as an
import torch.nn.functional as F
import numpy as np
from utils.SuperPoint import SuperPointFrontend
from utils.utils0 import *
from utils.utils1 import *
from utils.utils1 import transform_points_DVF

class SP_DHR_Net(nn.Module):
    def __init__(self, model_params):
        super(SP_DHR_Net, self).__init__()
        self.superpoint = SuperPointFrontend('utils/superpoint_v1.pth', nms_dist=4,
                          conf_thresh=0.015, nn_thresh=0.7, cuda=True)
        self.affineNet = an.load_network(device)
        self.nn_thresh = 0.7
        self.model_params = model_params
        print("\nRunning new version (not run SP on source image)")

    def forward(self, source_image, target_image):
        # source_image = source_image.to(device)
        # target_image = target_image.to(device)

        # print('source_image: ', source_image.shape)
        # print('target_image: ', target_image.shape)
        points1, desc1, heatmap1 = self.superpoint(source_image[0, 0, :, :].cpu().numpy())
        points2, desc2, heatmap2 = self.superpoint(target_image[0, 0, :, :].cpu().numpy())

        if self.model_params.heatmaps == 0:
            affine_params = self.affineNet(source_image, target_image)
        elif self.model_params.heatmaps == 1:
            print("This part is not yet implemented.")
            # affine_params = self.affineNet(source_image, target_image, heatmap1, heatmap2)

        # transform the source image using the affine parameters
        # using F.affine_grid and F.grid_sample
        transformed_source_affine = tensor_affine_transform(source_image, affine_params)
        points1_2, desc1_2, heatmap1_2 = self.superpoint(transformed_source_affine[0, 0, :, :].detach().cpu().numpy())

        # match the points between the two images
        tracker = PointTracker(5, nn_thresh=0.7)
        try:
            matches = tracker.nn_match_two_way(desc1, desc2, nn_thresh=self.nn_thresh)
        except:
            # print('No matches found')
            # TODO: find a better way to do this
            try:
                while matches.shape[1] < 3 and self.nn_thresh > 0.1:
                    self.nn_thresh = self.nn_thresh - 0.1
                    matches = tracker.nn_match_two_way(desc1, desc2, nn_thresh=self.nn_thresh)
            except:
                return transformed_source_affine, affine_params, [], [], [], [], [], [], []

        # take the elements from points1 and points2 using the matches as indices
        matches1 = np.array(points1[:2, matches[0, :].astype(int)])
        matches2 = np.array(points2[:2, matches[1, :].astype(int)])
        # matches1_2 = np.array(points1_2[:2, matches[0, :].astype(int)])
        # print('matches1', matches1)
        # print('matches2', matches2)
        # print('matches1_2', matches1_2)

        # try:
        #     matches1_2 = points1_2[:2, matches[0, :].astype(int)]
        # except:
        # print(affine_params.cpu().detach().shape, transformed_source_affine.shape)
        matches1_2 = transform_points_DVF(torch.tensor(matches1), 
                        affine_params.cpu().detach(), transformed_source_affine)

        # transform the points using the affine parameters
        # matches1_transformed = transform_points(matches1.T[None, :, :], affine_params.cpu().detach())
        return transformed_source_affine, affine_params, matches1, matches2, matches1_2, \
            desc1_2, desc2, heatmap1_2, heatmap2

## SP ImgReg model

# Load data

In [14]:
from utils.datagen import datagen

In [15]:
# test datagen for all datasets and training and testing
# for dataset in range(4): # don't forget to change this back to 2
#     for is_train in [True, False]:
#         for sup in [False, True]:
#             print(f'dataset: {dataset}, is_train: {is_train}, sup: {sup}')
#             dataloader = datagen(dataset, is_train, sup)
            
#             if sup==1 and dataset==0:
#                 print('skipping')
#                 pass
#             else:
#                 try:
#                     print('index, source_img.shape,       target_img.shape')
#                     for i, (source_img, target_img) in enumerate(dataloader):
#                         print(i, source_img.shape, target_img.shape)
#                         if i == 2:
#                             break
#                 except ValueError:
#                     print('index, source_img.shape,       target_img.shape,            affine_params.shape')
#                     for i, batch in enumerate(dataloader):
#                         print(i, batch[0].shape, batch[1].shape, batch[2].shape)
#                         if i == 5:
#                             break
#             print('\n')

# Training
## Dataset initialization

In [16]:
train_dataset = datagen(model_params.dataset, True, model_params.sup)
test_dataset = datagen(model_params.dataset, False, model_params.sup)

# Get sample batch
print('Train set: ', [x.shape for x in next(iter(train_dataset))])
print('Test set: ', [x.shape for x in next(iter(test_dataset))])

Train set:  [torch.Size([1, 1, 256, 256]), torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 3])]
Test set:  [torch.Size([1, 1, 256, 256]), torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 3])]


## Model initialize

In [17]:
# print case
print(model_params)
model_params.print_explanation()

dataset1_sup1_image1_heatmaps0_loss_image0

Model name:  dataset1_sup1_image1_heatmaps0_loss_image0
Model code:  11100_0.001_0_20_1
Dataset used:  Synthetic eye
Supervised or unsupervised model:  Supervised
Image type:  Image used
Heatmaps used:  Heatmaps not used
Loss function case:  0
Loss function for image:  MSELoss()
Loss function for affine:  <utils.utils1.loss_affine object at 0x7fe53458ac10>
Learning rate:  0.001
Decay rate:  0.96
Start epoch:  0
Number of epochs:  20
Batch size:  1




In [18]:
model = SP_DHR_Net(model_params)
print(model)

parameters = model.parameters()
optimizer = optim.Adam(parameters, model_params.learning_rate, weight_decay=0.0001)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: model_params.decay_rate ** epoch)
# model_path = 'trained_models/10102_0.001_0_20_1_20230930-091532.pth'

# if a model is loaded, the training will continue from the epoch it was saved at
try:
    model.load_state_dict(torch.load(model_path))
    model_params.start_epoch = int(model_path.split('/')[-1].split('_')[3])
    print(f'Loaded model from {model_path}\nstarting at epoch {model_params.start_epoch}')
    if model_params.start_epoch >= model_params.num_epochs:
            model_params.num_epochs += model_params.start_epoch
except:
    model_params.start_epoch = 0
    print('No model loaded, starting from scratch')


Running new version (not run SP on source image)
SP_DHR_Net(
  (affineNet): Affine_Network(
    (feature_extractor): Feature_Extractor(
      (input_layer): Sequential(
        (0): Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      )
      (layer_1): Forward_Layer(
        (pool_layer): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(3, 3))
        )
        (layer): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(3, 3))
          (1): GroupNorm(128, 128, eps=1e-05, affine=True)
          (2): PReLU(num_parameters=1)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): GroupNorm(128, 128, eps=1e-05, affine=True)
          (5): PReLU(num_parameters=1)
        )
      )
      (layer_2): Forward_Layer(
        (layer): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): GroupNorm(128, 128,

## Training function

In [19]:
# Define training function
def train(model, model_params, timestamp):
    # Define loss function based on supervised or unsupervised learning
    criterion = model_params.loss_image
    extra = loss_extra()

    if model_params.sup:
        criterion_affine = nn.MSELoss()
        # TODO: add loss for points1_affine and points2, Euclidean distance

    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=model_params.learning_rate)

    # Create empty list to store epoch number, train loss and validation loss
    epoch_loss_list = []
    running_loss_list = []
    
    # Create output directory
    output_dir = f"output/{model_params.get_model_code()}_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    # Train model
    for epoch in range(model_params.start_epoch, model_params.num_epochs):
        # Set model to training mode
        model.train()

        # Zero the parameter gradients
        optimizer.zero_grad()
        
        running_loss = 0.0
        train_bar = tqdm(train_dataset, desc=f'Training Epoch {epoch+1}/{model_params.num_epochs}')
        for i, data in enumerate(train_bar):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
                affine_params_true = None
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward + backward + optimize
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #         print(i, outputs[i].shape)
            # 0 torch.Size([1, 1, 256, 256])
            # 1 torch.Size([1, 2, 3])
            # 2 (2, 4)
            # 3 (2, 4)
            # 4 (1, 4, 2)
            # 5 (256, 9)
            # 6 (256, 16)
            # 7 (256, 256)
            # 8 (256, 256)
            transformed_source_affine = outputs[0] # image
            affine_params_predicted = outputs[1] # affine parameters
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])

            # print(f"affine_params_true: {affine_params_true}")
            # print(f"affine_params_predicted: {affine_params_predicted}\n")

            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            loss = criterion(transformed_source_affine, target_image)
            loss += extra(affine_params_predicted)
            if model_params.sup:
                loss_affine = criterion_affine(affine_params_true.view(1, 2, 3), affine_params_predicted.cpu())
                # TODO: add loss for points1_affine and points2, Euclidean distance
                # loss_points = criterion_points(points1_affine, points2)
                loss += loss_affine
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            # Plot images if i < 5
            if i < 5:
                DL_affine_plot(f"epoch{epoch+1}_train", output_dir,
                    f"{i}", "_", source_image[0, 0, :, :].detach().cpu().numpy(), 
                    target_image[0, 0, :, :].detach().cpu().numpy(), 
                    transformed_source_affine[0, 0, :, :].detach().cpu().numpy(),
                    points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                    affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=True)

            # Print statistics
            running_loss += loss.item()
            running_loss_list.append([epoch+((i+1)/len(train_dataset)), loss.item()])
            train_bar.set_postfix({'loss': running_loss / (i+1)})
        print(f'Training Epoch {epoch+1}/{model_params.num_epochs} loss: {running_loss / len(train_dataset)}')

        # Validate model
        validation_loss = 0.0
        model.eval()
        # with torch.no_grad():
        for i, data in enumerate(test_dataset, 0):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
                affine_params_true = None
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward pass
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #     print(i, outputs[i].shape)
            transformed_source_affine = outputs[0]
            affine_params_predicted = outputs[1]
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])
            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            loss = criterion(transformed_source_affine, target_image)
            loss += extra(affine_params_predicted)
            if model_params.sup:
                loss_affine = criterion_affine(affine_params_true.view(1, 2, 3), affine_params_predicted.cpu())
                # TODO: add loss for points1_affine and points2, Euclidean distance
                # loss_points = criterion_points(points1_affine, points2)
                loss += loss_affine

            # Add to validation loss
            validation_loss += loss.item()

            # Plot images if i < 5
            if i < 5:
                DL_affine_plot(f"epoch{epoch+1}_valid", output_dir,
                    f"{i}", "_", source_image[0, 0, :, :].detach().cpu().numpy(), 
                    target_image[0, 0, :, :].detach().cpu().numpy(), 
                    transformed_source_affine[0, 0, :, :].detach().cpu().numpy(),
                    points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                    affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=True)

        # Print validation statistics
        validation_loss /= len(test_dataset)
        print(f'Validation Epoch {epoch+1}/{model_params.num_epochs} loss: {validation_loss}')

        # Append epoch number, train loss and validation loss to epoch_loss_list
        epoch_loss_list.append([epoch+1, running_loss / len(train_dataset), validation_loss])

        
        # Extract epoch number, train loss and validation loss from epoch_loss_list
        epoch = [x[0] for x in epoch_loss_list]
        train_loss = [x[1] for x in epoch_loss_list]
        val_loss = [x[2] for x in epoch_loss_list]
        step = [x[0] for x in running_loss_list]
        running_train_loss = [x[1] for x in running_loss_list]

        save_plot_name = f"{output_dir}/loss_{model_params.get_model_code()}_epoch{model_params.num_epochs}_{timestamp}.png"

        # Plot train loss and validation loss against epoch number
        plt.figure()
        plt.plot(step, running_train_loss, label='Running Train Loss', alpha=0.3)
        plt.plot(epoch, train_loss, label='Train Loss')
        plt.plot(epoch, val_loss, label='Validation Loss')
        plt.title('Train and Validation Loss')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.yscale('log')
        plt.tight_layout()
        plt.savefig(save_plot_name)
        plt.close()
        
    print('Finished Training')

    # delete all txt files in output_dir
    for file in os.listdir(output_dir):
        if file.endswith(".txt"):
            os.remove(os.path.join(output_dir, file))

    # Return epoch_loss_list
    return epoch_loss_list


## working code

In [20]:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
loss_list = train(model, model_params, timestamp)

  torch.tensor(M).view(1, 2, 3))
Training Epoch 1/20: 100%|██████████| 461/461 [01:30<00:00,  5.07it/s, loss=2.2]  


Training Epoch 1/20 loss: 2.1975847451084993
Validation Epoch 1/20 loss: 7.263564122926205


Training Epoch 2/20: 100%|██████████| 461/461 [01:39<00:00,  4.64it/s, loss=3.89]


Training Epoch 2/20 loss: 3.888145002595062


  image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
  image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)


Validation Epoch 2/20 loss: 9.310233995455121


Training Epoch 3/20: 100%|██████████| 461/461 [01:34<00:00,  4.87it/s, loss=3.93]


Training Epoch 3/20 loss: 3.9253992739571926
Validation Epoch 3/20 loss: 1.4278312768411199


Training Epoch 4/20: 100%|██████████| 461/461 [01:34<00:00,  4.88it/s, loss=0.656]


Training Epoch 4/20 loss: 0.6556272684978302
Validation Epoch 4/20 loss: 0.7295820838814482


Training Epoch 5/20: 100%|██████████| 461/461 [01:30<00:00,  5.08it/s, loss=0.5]  


Training Epoch 5/20 loss: 0.4999029438562869
Validation Epoch 5/20 loss: 0.156081740325744


Training Epoch 6/20: 100%|██████████| 461/461 [01:30<00:00,  5.10it/s, loss=0.109] 


Training Epoch 6/20 loss: 0.10937255937222254
Validation Epoch 6/20 loss: 0.11381915471422563


Training Epoch 7/20: 100%|██████████| 461/461 [01:33<00:00,  4.94it/s, loss=0.135]


Training Epoch 7/20 loss: 0.1354616981205771
Validation Epoch 7/20 loss: 0.04687200412325083


Training Epoch 8/20: 100%|██████████| 461/461 [01:35<00:00,  4.84it/s, loss=0.0955]


Training Epoch 8/20 loss: 0.09547796265430215
Validation Epoch 8/20 loss: 0.2183945373110815


Training Epoch 9/20: 100%|██████████| 461/461 [01:30<00:00,  5.11it/s, loss=0.123]


Training Epoch 9/20 loss: 0.12337391245544539
Validation Epoch 9/20 loss: 0.14618844600445632


Training Epoch 10/20: 100%|██████████| 461/461 [01:43<00:00,  4.45it/s, loss=0.0734]


Training Epoch 10/20 loss: 0.07338362073812658
Validation Epoch 10/20 loss: 0.08963110998546311


Training Epoch 11/20: 100%|██████████| 461/461 [01:59<00:00,  3.85it/s, loss=0.0725]


Training Epoch 11/20 loss: 0.07253742514182197
Validation Epoch 11/20 loss: 0.06919823586940765


Training Epoch 12/20: 100%|██████████| 461/461 [01:44<00:00,  4.42it/s, loss=0.0474]


Training Epoch 12/20 loss: 0.04735311784492624
Validation Epoch 12/20 loss: 0.04652125000475197


Training Epoch 13/20: 100%|██████████| 461/461 [01:35<00:00,  4.83it/s, loss=0.0522]


Training Epoch 13/20 loss: 0.05222686697524889
Validation Epoch 13/20 loss: 0.07957459142038581


Training Epoch 14/20: 100%|██████████| 461/461 [01:38<00:00,  4.67it/s, loss=0.0558]


Training Epoch 14/20 loss: 0.05582663718279685
Validation Epoch 14/20 loss: 0.06745984827364804


Training Epoch 15/20: 100%|██████████| 461/461 [01:33<00:00,  4.92it/s, loss=0.0602]


Training Epoch 15/20 loss: 0.0602284325104926
Validation Epoch 15/20 loss: 0.1385147816612633


Training Epoch 16/20: 100%|██████████| 461/461 [01:34<00:00,  4.87it/s, loss=0.158] 


Training Epoch 16/20 loss: 0.15809513810302844
Validation Epoch 16/20 loss: 0.10265482517830822


Training Epoch 17/20: 100%|██████████| 461/461 [01:32<00:00,  4.97it/s, loss=0.0617]


Training Epoch 17/20 loss: 0.06171040538367118
Validation Epoch 17/20 loss: 0.0917477202046355


Training Epoch 18/20: 100%|██████████| 461/461 [01:35<00:00,  4.84it/s, loss=0.0642]


Training Epoch 18/20 loss: 0.06420281573568437
Validation Epoch 18/20 loss: 0.13654189276585885


Training Epoch 19/20: 100%|██████████| 461/461 [01:39<00:00,  4.63it/s, loss=0.228]


Training Epoch 19/20 loss: 0.2283831620528292
Validation Epoch 19/20 loss: 0.6673955567386172


Training Epoch 20/20: 100%|██████████| 461/461 [01:33<00:00,  4.92it/s, loss=1.25] 


Training Epoch 20/20 loss: 1.2478737919359042
Validation Epoch 20/20 loss: 1.3351452131883814
Finished Training


In [21]:
print("Training output:")
for i in range(len(loss_list)):
    print(loss_list[i])

Training output:
[1, 2.1975847451084993, 7.263564122926205]
[2, 3.888145002595062, 9.310233995455121]
[3, 3.9253992739571926, 1.4278312768411199]
[4, 0.6556272684978302, 0.7295820838814482]
[5, 0.4999029438562869, 0.156081740325744]
[6, 0.10937255937222254, 0.11381915471422563]
[7, 0.1354616981205771, 0.04687200412325083]
[8, 0.09547796265430215, 0.2183945373110815]
[9, 0.12337391245544539, 0.14618844600445632]
[10, 0.07338362073812658, 0.08963110998546311]
[11, 0.07253742514182197, 0.06919823586940765]
[12, 0.04735311784492624, 0.04652125000475197]
[13, 0.05222686697524889, 0.07957459142038581]
[14, 0.05582663718279685, 0.06745984827364804]
[15, 0.0602284325104926, 0.1385147816612633]
[16, 0.15809513810302844, 0.10265482517830822]
[17, 0.06171040538367118, 0.0917477202046355]
[18, 0.06420281573568437, 0.13654189276585885]
[19, 0.2283831620528292, 0.6673955567386172]
[20, 1.2478737919359042, 1.3351452131883814]


## Model saving

In [22]:
model_save_path = "trained_models/"
model_name_to_save = model_save_path + f"DHR_{model_params.get_model_code()}_{timestamp}.pth"
print(model_name_to_save)
torch.save(model.state_dict(), model_name_to_save)


trained_models/DHR_11100_0.001_0_20_1_20231030-141020.pth


# Test model (loading and inference)

Save results and export metrics to csv

In [23]:
# model = SPmodel = SP_AffineNet().to(device)
# print(model)

# parameters = model.parameters()
# optimizer = optim.Adam(parameters, model_params.learning_rate)
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: model_params.decay_rate ** epoch)

# model.load_state_dict(torch.load(model_name_to_save))

In [24]:
def test(model, model_params, timestamp):
    # Set model to training mode
    model.eval()

    # Create output directory
    output_dir = f"output/{model_params.get_model_code()}_{timestamp}_test"
    os.makedirs(output_dir, exist_ok=True)

    # Validate model
    # validation_loss = 0.0

    # create a csv file to store the metrics
    csv_file = f"{output_dir}/metrics.csv"
    with open(csv_file, 'w', newline='') as file:
        writer = csv.writer(file)
        # matches1_transformed.shape[-1], mse_before, mse12, tre_before, tre12, \
        # mse12_image, ssim12_image, 
        writer.writerow(["index", "mse_before", "mse12", "tre_before", "tre12", "mse12_image_before", "mse12_image", "ssim12_image_before", "ssim12_image"])

    with torch.no_grad():
        testbar = tqdm(test_dataset, desc=f'Testing:')
        for i, data in enumerate(testbar, 0):
            # Get images and affine parameters
            if model_params.sup:
                source_image, target_image, affine_params_true = data
            else:
                source_image, target_image = data
            source_image = source_image.to(device)
            target_image = target_image.to(device)

            # Forward pass
            outputs = model(source_image, target_image)
            # for i in range(len(outputs)):
            #     print(i, outputs[i].shape)
            transformed_source_affine = outputs[0]
            affine_params_predicted = outputs[1]
            points1 = outputs[2]
            points2 = outputs[3]
            points1_affine = np.array(outputs[4])
            try:
                points1_affine = points1_affine.reshape(points1_affine.shape[2], points1_affine.shape[1])
            except:
                pass
            desc1 = outputs[5]
            desc2 = outputs[6]
            heatmap1 = outputs[7]
            heatmap2 = outputs[8]

            if i < 50:
                plot_ = True
            else:
                plot_ = False

            results = DL_affine_plot(f"{i+1}", output_dir,
                f"{i}", "_", source_image[0, 0, :, :].cpu().numpy(), target_image[0, 0, :, :].cpu().numpy(), \
                transformed_source_affine[0, 0, :, :].cpu().numpy(), \
                points1, points2, points1_affine, desc1, desc2, affine_params_true=affine_params_true,
                affine_params_predict=affine_params_predicted, heatmap1=heatmap1, heatmap2=heatmap2, plot=plot_)

            # calculate metrics
            # matches1_transformed = results[0]
            mse_before = results[1]
            mse12 = results[2]
            tre_before = results[3]
            tre12 = results[4]
            mse12_image_before = results[5]
            mse12_image = results[6]
            ssim12_image_before = results[7]
            ssim12_image = results[8]

            # write metrics to csv file
            with open(csv_file, 'a', newline='') as file:
                writer = csv.writer(file) # TODO: might need to export true & predicted affine parameters too
                writer.writerow([i, mse_before, mse12, tre_before, tre12, mse12_image_before, mse12_image, ssim12_image_before, ssim12_image])

    # delete all txt files in output_dir
    for file in os.listdir(output_dir):
        if file.endswith(".txt"):
            os.remove(os.path.join(output_dir, file))

# timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
metrics = test(model, model_params, timestamp)

Testing::   0%|          | 0/109 [00:00<?, ?it/s]

Testing:: 100%|██████████| 109/109 [01:36<00:00,  1.12it/s]
