In [35]:
# import sys
import argparse
import numpy as np
import os
import cv2
import torch
import matplotlib.pyplot as plt
# from skimage.metrics import structural_similarity as ssim
# from skimage.measure import ransac
# from skimage.transform import FundamentalMatrixTransform, AffineTransform

from datetime import datetime

import torch
# from torchvision import transforms
# import torch.nn.functional as F
# from torch.utils import data
# from torchsummary import summary
# from pytorch_model_summary import summary

torch.manual_seed(9793047918980052389)
print('Seed:', torch.seed())

from utils.utils0 import *
from utils.utils1 import *
from utils.utils1 import ModelParams, print_summary
from utils import test

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')

# Stub to warn about opencv version.
if int(cv2.__version__[0]) < 3: # pragma: no cover
  print('Warning: OpenCV 3 is not installed')

image_size = 256

Seed: 13896392224620828981
Device: cuda


In [36]:

# from utils.SuperPoint import SuperPointFrontend
# from utils.utils1 import transform_points_DVF
def test(model_name, model_, model_params, timestamp):
    # model_name: name of the model
    # model: model to be tested
    # model_params: model parameters
    # timestamp: timestamp of the model
    print('Test function input:', model_name, model_, model_params, timestamp)

    test_dataset = datagen(model_params.dataset, False, model_params.sup, batch_size=1) # batch_size need to be 1
    print(model_params.batch_size)

    # if model is a string, load the model
    # if model is a loaded model, use the model
    if isinstance(model_, str):
        model = model_loader(model_name, model_params)
        buffer = io.BytesIO()
        torch.save(model.state_dict(), buffer)
        buffer.seek(0)
        model.load_state_dict(torch.load(model_))
        print(f'Loaded model from {model_}')
    elif isinstance(model_, nn.Module):
        print(f'Using model {model_name}')
        model = model_

    # Set model to training mode
    model.eval()

    # Create output directory
    output_dir = f"output/{model_name}_{model_params.get_model_code()}_{timestamp}_test"
    os.makedirs(output_dir, exist_ok=True)

    # Validate model
    # validation_loss = 0.0

    metrics = []
    # create a csv file to store the metrics
    csv_file = f"{output_dir}/metrics.csv"

    with torch.no_grad():
        testbar = tqdm(test_dataset, desc=f'Testing:')
        for i, data in enumerate(testbar, 0):
            # Get images and affine parameters
            source_image, target_image, affine_params_true, points1, points2, points1_2_true = data

            source_image = source_image.requires_grad_(True).to(device)
            target_image = target_image.requires_grad_(True).to(device)
            # add gradient to the matches
            points1 = points1.requires_grad_(True).to(device)
            points2 = points2.requires_grad_(True).to(device)

            # Forward + backward + optimize
            outputs = model(source_image, target_image, points1)
            # for i in range(len(outputs)):
            #     print(i, outputs[i].shape)
            transformed_source_affine = outputs[0]
            affine_params_predicted = outputs[1]
            points1_2_predicted = outputs[2]

            # try:
            #     points1_2_predicted = points1_2_predicted.reshape(
            #     points1_2_predicted.shape[2], points1_2_predicted.shape[1])
            # except:
            #     pass

            if i < 100:
                plot_ = True
            else:
                plot_ = False

            # print(points1_2_predicted.shape, points2.shape, points1.shape)
            # for loop to plot each image, use the actual batch size from output
            batch = 0

            try: 
                # points1_2_predicted[batch] = points1_2_predicted[batch].reshape(
                #     points1_2_predicted[batch].shape[1], points1_2_predicted[batch].shape[0])
                results = DL_affine_plot(f"test", output_dir,
                    f"{i}", f"{model_params.batch_size}", 
                    source_image[batch, 0, :, :].cpu().numpy(), 
                    target_image[batch, 0, :, :].cpu().numpy(), 
                    transformed_source_affine[batch, 0, :, :].cpu().numpy(),
                    points1[batch].cpu().detach().numpy().T, 
                    points2[batch].cpu().detach().numpy().T, 
                    points1_2_predicted[batch].cpu().detach().numpy().T, None, None, 
                    affine_params_true[batch], affine_params_predicted[batch], 
                    heatmap1=None, heatmap2=None, plot=plot_)

                # calculate metrics
                # matches1_transformed = results[0]
                mse_before = results[1]
                mse12 = results[2]
                tre_before = results[3]
                tre12 = results[4]
                mse12_image_before = results[5]
                mse12_image = results[6]
                ssim12_image_before = results[7]
                ssim12_image = results[8]

                # append metrics to metrics list
                metrics.append([i, mse_before, mse12, tre_before, tre12, \
                                mse12_image_before, mse12_image, ssim12_image_before, ssim12_image, np.max(points1_2_predicted[batch].shape)])
            except:
                # print(f"Error at {i*model_params.batch_size+batch}")
                pass

    metrics_ = []
    with open(csv_file, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["index", "mse_before", "mse12", "tre_before", "tre12", "mse12_image_before", "mse12_image", "ssim12_image_before", "ssim12_image", "num_points"])
        for i in range(len(metrics)):
            writer.writerow(metrics[i])
        # write the average and std of the metrics
        metrics_ = metrics.copy()
        metrics = np.array(metrics)
        nan_mask = np.isnan(metrics).any(axis=1)
        metrics = metrics[~nan_mask]

        print(metrics[40])

        avg = ["average", np.mean(metrics[:, 1]), np.mean(metrics[:, 2]), np.mean(metrics[:, 3]), np.mean(metrics[:, 4]), 
            np.mean(metrics[:, 5]), np.mean(metrics[:, 6]), np.mean(metrics[:, 7]), np.mean(metrics[:, 8])]
        std = ["std", np.std(metrics[:, 1]), np.std(metrics[:, 2]), np.std(metrics[:, 3]), np.std(metrics[:, 4]), 
            np.std(metrics[:, 5]), np.std(metrics[:, 6]), np.std(metrics[:, 7]), np.std(metrics[:, 8])]
        writer.writerow(avg)
        writer.writerow(std)

    print(f"The test results are saved in {csv_file}")

    # metrics_ = []
    # with open(csv_file, 'w', newline='') as file:

    #     writer = csv.writer(file)
    #     writer.writerow(["index", "mse_before", "mse12", "tre_before", "tre12", "mse12_image_before", "mse12_image", 
    #                      "ssim12_image_before", "ssim12_image", "num_points", "votes"])
    #     for i in range(len(metrics)):
    #         writer.writerow(metrics[i])

    #     # print(metrics[40:45])
        
    #     # drop the last column of the array 'metrics'
    #     metrics = [metrics[i][1:-1] for i in range(len(metrics))]
    #     metrics_ = metrics.copy()
    #     metrics = np.array(metrics)

    #     # print(metrics[40:45])

    #     # metrics = metrics[:, :8]
    #     nan_mask = np.isnan(metrics).any(axis=1)
    #     metrics = metrics[~nan_mask]

    #     # print(metrics[40:45])

    #     # avg = ["average", np.mean(metrics[:, 1]), np.mean(metrics[:, 2]), np.mean(metrics[:, 3]), np.mean(metrics[:, 4]), 
    #     #     np.mean(metrics[:, 5]), np.mean(metrics[:, 6]), np.mean(metrics[:, 7]), np.mean(metrics[:, 8])]
    #     avg = ["average", np.mean(metrics[:, 1:9], axis=0)]
    #     std = ["std", np.std(metrics[:, 1:9], axis=0)]
    #     writer.writerow(avg)
    #     writer.writerow(std)

    # print(f"The test results are saved in {csv_file}")

    return csv_file, metrics_


In [37]:

# Access the values of the command-line arguments
dataset = 0
sup = 1
image = 1
heatmaps = 0
loss_image = 0
num_epochs = 1
learning_rate = 1e-3
decay_rate = 0.96
model = 'DHR'
model_path = None
plot = 2

# model_path = 'trained_models/' + args.model_path
model_path = 'DHR_31100_0.001_0_10_100_20240508-120807.pth'
# add 'trained_models/' in front of each element of model_path
model_path = 'trained_models/' + model_path

model_params = ModelParams(dataset=dataset, sup=sup, image=image, loss_image=loss_image, 
                            num_epochs=num_epochs, learning_rate=learning_rate, 
                            decay_rate=decay_rate, 
                            plot=plot)
model_params.print_explanation()

timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

Model name:  dataset0_sup0_image1_points0_loss_image0
Model code:  00100_0.001_0_1_1
Model params:  {'dataset': 0, 'sup': 0, 'image': 1, 'points': 0, 'loss_image_case': 0, 'loss_image': MSELoss(), 'loss_affine': <utils.utils1.loss_affine object at 0x7fba10b37bb0>, 'learning_rate': 0.001, 'decay_rate': 0.96, 'start_epoch': 0, 'num_epochs': 1, 'batch_size': 1, 'model_name': 'dataset0_sup0_image1_points0_loss_image0'}

Model name:  dataset0_sup0_image1_points0_loss_image0
Model code:  00100_0.001_0_1_1
Dataset used:  Actual eye
Supervised or unsupervised model:  Unsupervised
Loss image type:  Image used
Points used:  Points not used
Loss function case:  0
Loss function for image:  MSELoss()
Loss function for affine:  <utils.utils1.loss_affine object at 0x7fba10b37bb0>
Learning rate:  0.001
Decay rate:  0.96
Start epoch:  0
Number of epochs:  1
Batch size:  1




In [38]:
model_path

'trained_models/DHR_31100_0.001_0_10_100_20240508-120807.pth'

In [39]:
print(f"\nTesting the trained model: {model} +++++++++++++++++++++++")

csv_file, metrics = test(model, model_path, model_params, timestamp)
print("Test model finished +++++++++++++++++++++++++++++")


Testing the trained model: DHR +++++++++++++++++++++++
Test function input: DHR trained_models/DHR_31100_0.001_0_10_100_20240508-120807.pth dataset0_sup0_image1_points0_loss_image0 20240512-102902
Test eye dataset
Number of testing data:  100
1


Using DHR difference
Loaded model from trained_models/DHR_31100_0.001_0_10_100_20240508-120807.pth


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
Testing:: 100%|██████████| 100/100 [01:49<00:00,  1.09s/it]

[4.00000000e+01 4.08337500e+03 4.14786084e+03 9.03580093e+01
 9.10563507e+01 3.98883931e-02 4.06001955e-02 4.10601044e-01
 4.21419824e-01 8.00000000e+00]
The test results are saved in output/DHR_00100_0.001_0_1_1_20240512-102902_test/metrics.csv
Test model finished +++++++++++++++++++++++++++++





In [40]:
metrics

[[0,
  952.5,
  866.50684,
  43.490982,
  41.478127,
  0.09467417,
  0.09412369,
  0.3826547475335451,
  0.39697014613346593,
  2],
 [1,
  438.91428,
  466.97562,
  28.807297,
  29.810226,
  0.02356157,
  0.023728952,
  0.4219869167820732,
  0.4288760390206901,
  35],
 [2,
  1.0068493,
  1.2077827,
  1.2436951,
  1.3893102,
  0.0011295322,
  0.001131853,
  0.8906353716834011,
  0.8874733766969876,
  146],
 [3,
  162.27272,
  177.0862,
  17.178314,
  18.004549,
  0.012607125,
  0.0125483135,
  0.6429356175700988,
  0.6566462792198877,
  88],
 [4,
  193.71666,
  210.94843,
  18.64133,
  19.435785,
  0.027227499,
  0.026923459,
  0.3895056947570645,
  0.3972143934073377,
  30],
 [5,
  127.05556,
  132.73228,
  11.901311,
  12.197722,
  0.010623909,
  0.010420578,
  0.5504893353454324,
  0.5684098502578168,
  72],
 [6,
  140.25,
  146.17899,
  16.085976,
  16.403158,
  0.017356073,
  0.01743859,
  0.45740179260941316,
  0.4686344067943819,
  18],
 [7,
  392.42856,
  387.7521,
  27.623726,


In [41]:
np.array(metrics)

array([[0.00000000e+00, 9.52500000e+02, 8.66506836e+02, 4.34909821e+01,
        4.14781265e+01, 9.46741700e-02, 9.41236913e-02, 3.82654748e-01,
        3.96970146e-01, 2.00000000e+00],
       [1.00000000e+00, 4.38914276e+02, 4.66975616e+02, 2.88072968e+01,
        2.98102264e+01, 2.35615708e-02, 2.37289518e-02, 4.21986917e-01,
        4.28876039e-01, 3.50000000e+01],
       [2.00000000e+00, 1.00684929e+00, 1.20778275e+00, 1.24369514e+00,
        1.38931024e+00, 1.12953223e-03, 1.13185297e-03, 8.90635372e-01,
        8.87473377e-01, 1.46000000e+02],
       [3.00000000e+00, 1.62272720e+02, 1.77086197e+02, 1.71783142e+01,
        1.80045490e+01, 1.26071246e-02, 1.25483135e-02, 6.42935618e-01,
        6.56646279e-01, 8.80000000e+01],
       [4.00000000e+00, 1.93716660e+02, 2.10948425e+02, 1.86413307e+01,
        1.94357853e+01, 2.72274986e-02, 2.69234590e-02, 3.89505695e-01,
        3.97214393e-01, 3.00000000e+01],
       [5.00000000e+00, 1.27055557e+02, 1.32732285e+02, 1.19013109e+01,
   

In [42]:
metrics = np.array(metrics)
nan_mask = np.isnan(metrics).any(axis=1)
metrics = metrics[~nan_mask]
np.mean(metrics, axis=0)

array([4.95858586e+01, 5.04985432e+02, 5.04687347e+02, 2.58785015e+01,
       2.59376638e+01, 3.15154436e-02, 3.13606523e-02, 4.50214391e-01,
       4.62652154e-01, 4.45252525e+01])

In [43]:
metrics.shape

(99, 10)

In [44]:
for i in metrics:
    print(i)

[0.00000000e+00 9.52500000e+02 8.66506836e+02 4.34909821e+01
 4.14781265e+01 9.46741700e-02 9.41236913e-02 3.82654748e-01
 3.96970146e-01 2.00000000e+00]
[1.00000000e+00 4.38914276e+02 4.66975616e+02 2.88072968e+01
 2.98102264e+01 2.35615708e-02 2.37289518e-02 4.21986917e-01
 4.28876039e-01 3.50000000e+01]
[2.00000000e+00 1.00684929e+00 1.20778275e+00 1.24369514e+00
 1.38931024e+00 1.12953223e-03 1.13185297e-03 8.90635372e-01
 8.87473377e-01 1.46000000e+02]
[3.00000000e+00 1.62272720e+02 1.77086197e+02 1.71783142e+01
 1.80045490e+01 1.26071246e-02 1.25483135e-02 6.42935618e-01
 6.56646279e-01 8.80000000e+01]
[4.00000000e+00 1.93716660e+02 2.10948425e+02 1.86413307e+01
 1.94357853e+01 2.72274986e-02 2.69234590e-02 3.89505695e-01
 3.97214393e-01 3.00000000e+01]
[5.00000000e+00 1.27055557e+02 1.32732285e+02 1.19013109e+01
 1.21977224e+01 1.06239086e-02 1.04205776e-02 5.50489335e-01
 5.68409850e-01 7.20000000e+01]
[6.00000000e+00 1.40250000e+02 1.46178986e+02 1.60859756e+01
 1.64031582e+01

In [45]:
metrics[41]

array([4.20000000e+01, 4.08125000e+02, 3.60208740e+02, 2.82771091e+01,
       2.64179211e+01, 3.21540907e-02, 3.11431549e-02, 5.18360528e-01,
       5.31678263e-01, 3.60000000e+01])