In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import random
import csv
from PIL import Image
import shutil
from tqdm import tqdm
import torch

from utils.utils0 import tensor_affine_transform, transform_to_displacement_field
from utils.utils1 import transform_points_DVF, ModelParams
from utils.SuperPoint import SuperPointFrontend, PointTracker
from utils.datagen import datagen

nn_thresh = 0.7
superpoint = SuperPointFrontend('utils/superpoint_v1.pth', nms_dist=4,
                          conf_thresh=0.015, nn_thresh=nn_thresh, cuda=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Device: cuda


In [2]:
model_params = ModelParams(sup=1, dataset=2, image=1, heatmaps=0, 
                           loss_image=1, num_epochs=10, learning_rate=1e-4)
model_params.print_explanation()

Model name:  dataset2_sup1_image1_heatmaps0_loss_image1
Model code:  21101_0.0001_0_10_1
Model params:  {'dataset': 2, 'sup': 1, 'image': 1, 'heatmaps': 0, 'loss_image_case': 1, 'loss_image': NCC(), 'loss_affine': <utils.utils1.loss_affine object at 0x7ff2a81e6820>, 'learning_rate': 0.0001, 'decay_rate': 0.96, 'start_epoch': 0, 'num_epochs': 10, 'batch_size': 1, 'model_name': 'dataset2_sup1_image1_heatmaps0_loss_image1'}

Model name:  dataset2_sup1_image1_heatmaps0_loss_image1
Model code:  21101_0.0001_0_10_1
Dataset used:  Synthetic eye medium
Supervised or unsupervised model:  Supervised
Image type:  Image used
Heatmaps used:  Heatmaps not used
Loss function case:  1
Loss function for image:  NCC()
Loss function for affine:  <utils.utils1.loss_affine object at 0x7ff2a81e6820>
Learning rate:  0.0001
Decay rate:  0.96
Start epoch:  0
Number of epochs:  10
Batch size:  1




In [3]:
train_dataset, train_df, train_path = datagen(model_params.dataset, True, model_params.sup)
test_dataset, test_df, test_path = datagen(model_params.dataset, False, model_params.sup)

In [4]:
train_df.head()

Unnamed: 0,source,target,M00,M01,M02,M10,M11,M12,image_path
0,Dataset/synth_eye_medium_train/img_0_original.png,Dataset/synth_eye_medium_train/img_0_transform...,1.160598,0.0,-0.119816,0.0,1.12106,-0.031026,Dataset/Dataset-processed/15-12-2559/2011248/L...
1,Dataset/synth_eye_medium_train/img_1_original.png,Dataset/synth_eye_medium_train/img_1_transform...,0.836795,0.0,0.050553,0.0,1.162089,0.074007,Dataset/Dataset-processed/15-12-2559/2011248/L...
2,Dataset/synth_eye_medium_train/img_2_original.png,Dataset/synth_eye_medium_train/img_2_transform...,1.050389,0.0,0.158718,0.0,0.879272,-0.117546,Dataset/Dataset-processed/15-12-2559/2011248/R...
3,Dataset/synth_eye_medium_train/img_3_original.png,Dataset/synth_eye_medium_train/img_3_transform...,1.135652,0.0,-0.069543,0.0,0.997452,-0.048658,Dataset/Dataset-processed/15-12-2559/2011248/R...
4,Dataset/synth_eye_medium_train/img_4_original.png,Dataset/synth_eye_medium_train/img_4_transform...,0.988166,0.0,-0.14685,0.0,1.081306,0.00638,Dataset/Dataset-processed/15-12-2559/2011248/R...


In [5]:
# add one column to the dataframe to store the path to keypoints file
train_df['keypoints'] = train_df['source'].apply(lambda x: x.replace('_original.png', '_keypoints.csv'))
test_df['keypoints'] = test_df['source'].apply(lambda x: x.replace('_original.png', '_keypoints.csv'))

train_df.head()

Unnamed: 0,source,target,M00,M01,M02,M10,M11,M12,image_path,keypoints
0,Dataset/synth_eye_medium_train/img_0_original.png,Dataset/synth_eye_medium_train/img_0_transform...,1.160598,0.0,-0.119816,0.0,1.12106,-0.031026,Dataset/Dataset-processed/15-12-2559/2011248/L...,Dataset/synth_eye_medium_train/img_0_keypoints...
1,Dataset/synth_eye_medium_train/img_1_original.png,Dataset/synth_eye_medium_train/img_1_transform...,0.836795,0.0,0.050553,0.0,1.162089,0.074007,Dataset/Dataset-processed/15-12-2559/2011248/L...,Dataset/synth_eye_medium_train/img_1_keypoints...
2,Dataset/synth_eye_medium_train/img_2_original.png,Dataset/synth_eye_medium_train/img_2_transform...,1.050389,0.0,0.158718,0.0,0.879272,-0.117546,Dataset/Dataset-processed/15-12-2559/2011248/R...,Dataset/synth_eye_medium_train/img_2_keypoints...
3,Dataset/synth_eye_medium_train/img_3_original.png,Dataset/synth_eye_medium_train/img_3_transform...,1.135652,0.0,-0.069543,0.0,0.997452,-0.048658,Dataset/Dataset-processed/15-12-2559/2011248/R...,Dataset/synth_eye_medium_train/img_3_keypoints...
4,Dataset/synth_eye_medium_train/img_4_original.png,Dataset/synth_eye_medium_train/img_4_transform...,0.988166,0.0,-0.14685,0.0,1.081306,0.00638,Dataset/Dataset-processed/15-12-2559/2011248/R...,Dataset/synth_eye_medium_train/img_4_keypoints...


In [6]:
# save the dataframe to a csv file
train_df.to_csv('Dataset/synth_eye_medium_train.csv', index=False)
test_df.to_csv('Dataset/synth_eye_medium_test.csv', index=False)

In [9]:
train_bar = tqdm(train_dataset, total=len(train_dataset), desc='Train')
for i, data in enumerate(train_bar):

    # Get images and affine parameters
    if model_params.sup:
        source_image, target_image, affine_params_true = data
        affine_params_true = torch.tensor(affine_params_true)
    else:
        source_image, target_image = data
        affine_params_true = None
    source_image = source_image.to(device)
    target_image = target_image.to(device)

    points1, desc1, heatmap1 = superpoint(source_image[0, 0, :, :].cpu().numpy())
    points2, desc2, heatmap2 = superpoint(target_image[0, 0, :, :].cpu().numpy())

    tracker = PointTracker(5, nn_thresh=0.7)
    try:
        matches = tracker.nn_match_two_way(desc1, desc2, nn_thresh=nn_thresh)
    except:
        print('No matches found')
        # TODO: find a better way to do this
        pass

    matches1 = np.array(points1[:2, matches[0, :].astype(int)])
    matches2 = np.array(points2[:2, matches[1, :].astype(int)])
    matches1_2 = transform_points_DVF(torch.tensor(matches1), 
                        affine_params_true, target_image)
    
    # fig, ax = plt.subplots(1, 2)
    # ax[0].imshow(source_image[0, 0, :, :].cpu().detach().numpy())
    # ax[0].plot(matches1[0, :], matches1[1, :], 'r.')
    # ax[1].imshow(target_image[0, 0, :, :].cpu().detach().numpy())
    # ax[1].plot(matches2[0, :], matches2[1, :], 'g.')
    # ax[1].plot(matches1_2[0, :], matches1_2[1, :], 'r.')
    # plt.show()

    # create a dataframe with the matches
    df = pd.DataFrame({'x1': matches1[0, :], 'y1': matches1[1, :],
                       'x2': matches2[0, :], 'y2': matches2[1, :],
                       'x2_': matches1_2[0, :], 'y2_': matches1_2[1, :]})
    save_name = train_df['keypoints'].iloc[i]
    # print(save_name)
    df.to_csv(save_name, index=False)

  affine_params_true = torch.tensor(affine_params_true)
Train:   0%|          | 1/200 [00:00<00:28,  7.01it/s]

Train: 100%|██████████| 200/200 [00:31<00:00,  6.32it/s]


In [10]:
test_bar = tqdm(test_dataset, total=len(test_dataset), desc='Test')
for i, data in enumerate(test_bar):

    # Get images and affine parameters
    if model_params.sup:
        source_image, target_image, affine_params_true = data
        affine_params_true = torch.tensor(affine_params_true)
    else:
        source_image, target_image = data
        affine_params_true = None
    source_image = source_image.to(device)
    target_image = target_image.to(device)

    points1, desc1, heatmap1 = superpoint(source_image[0, 0, :, :].cpu().numpy())
    points2, desc2, heatmap2 = superpoint(target_image[0, 0, :, :].cpu().numpy())

    tracker = PointTracker(5, nn_thresh=0.7)
    try:
        matches = tracker.nn_match_two_way(desc1, desc2, nn_thresh=nn_thresh)
    except:
        print('No matches found')
        # TODO: find a better way to do this
        pass

    matches1 = np.array(points1[:2, matches[0, :].astype(int)])
    matches2 = np.array(points2[:2, matches[1, :].astype(int)])
    matches1_2 = transform_points_DVF(torch.tensor(matches1), 
                        affine_params_true, target_image)
    
    # fig, ax = plt.subplots(1, 2)
    # ax[0].imshow(source_image[0, 0, :, :].cpu().detach().numpy())
    # ax[0].plot(matches1[0, :], matches1[1, :], 'r.')
    # ax[1].imshow(target_image[0, 0, :, :].cpu().detach().numpy())
    # ax[1].plot(matches2[0, :], matches2[1, :], 'g.')
    # ax[1].plot(matches1_2[0, :], matches1_2[1, :], 'r.')
    # plt.show()

    # create a dataframe with the matches
    df = pd.DataFrame({'x1': matches1[0, :], 'y1': matches1[1, :],
                       'x2': matches2[0, :], 'y2': matches2[1, :],
                       'x2_': matches1_2[0, :], 'y2_': matches1_2[1, :]})
    save_name = test_df['keypoints'].iloc[i]
    df.to_csv(save_name, index=False)

  affine_params_true = torch.tensor(affine_params_true)
Test: 100%|██████████| 100/100 [00:16<00:00,  6.17it/s]


# verify that the saved keypoints are correct