In [None]:
%load_ext autoreload
%autoreload 2
CUDA_LAUNCH_BLOCKING=1

In [None]:
from __future__ import division, print_function
from typing import Dict, SupportsRound, Tuple, Any
from os import PathLike
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch,gc
from torch.autograd import grad
from torch.autograd import Variable
import torch.fft ############### Pytorch >= 1.8.0
import torch.nn.functional as F
import SimpleITK as sitk
import os, glob
import json
import subprocess
import sys
from PIL import Image
import torch.nn.functional as nnf
from torch.optim.lr_scheduler import CosineAnnealingLR,CosineAnnealingWarmRestarts,StepLR
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch.utils.data import TensorDataset, DataLoader
import torch
from torchvision import datasets, transforms
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
import random
import yaml
from tqdm import tqdm, trange
from numpy import zeros, newaxis
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import json
import pickle
import cv2
import lagomorph as lm

import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Device used by every later cell: the GPU when PyTorch can see one, otherwise the CPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the pre-split MNIST 10% train/test inputs and labels (presumably NumPy arrays).
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)


final_X_train = _load_pickle('./datasets/mnist_10p_train_x.pkl')
final_y_train = _load_pickle('./datasets/mnist_10p_train_y.pkl')
final_X_test = _load_pickle('./datasets/mnist_10p_test_x.pkl')
final_y_test = _load_pickle('./datasets/mnist_10p_test_y.pkl')

In [None]:
################ Parameter loading ################
def read_yaml(path):
    """Load a YAML file into an EasyDict so values can be accessed as attributes.

    Args:
        path: path of the YAML file.

    Returns:
        EasyDict with the parsed contents, or None when the file cannot be opened
        or is not valid YAML (the reason is printed).
    """
    # Only I/O and YAML syntax errors mean "no usable file". A bare `except:` would also
    # swallow KeyboardInterrupt and genuine programming errors and hide the cause.
    try:
        with open(path, 'r') as f:
            # FullLoader kept for compatibility with existing configs; yaml.safe_load is stricter.
            content = yaml.load(f, Loader=yaml.FullLoader)
    except OSError as err:
        print('NO FILE READ!', err)
        return None
    except yaml.YAMLError as err:
        print('NO FILE READ!', err)
        return None
    return edict(content)
    
para = read_yaml('./parameters.yml')
if para is None:
    # Later cells read `para`; fail here with a clear message instead of an
    # AttributeError on 'NoneType' further down.
    raise RuntimeError("Could not load './parameters.yml'; see the message printed above.")

# Image grid size from the config (zDim is not used by the cells shown here).
xDim = para.data.x
yDim = para.data.y
zDim = para.data.z

In [None]:
################ Network optimization ################

# Check initialization
from losses import MSE, Grad
from network_epdiff import Unet
from torch.utils.data import Dataset

# Unet hyper-parameters; the network takes 2 input channels (the stacked source/target pair).
unet_kwargs = dict(
    in_shape=(xDim, yDim),
    infeats=2,
    nb_features=[[16, 32, 32], [32, 32, 32, 16, 16]],
    nb_levels=None,
    max_pool=2,
    feat_mult=1,
    nb_conv_per_level=1,
    half_res=False,
    skip_connection=True,
)
# nn.Module.to returns the module itself, so construction and device move can be chained.
net = Unet(**unet_kwargs).to(dev)

class TDataset(Dataset):
    """Map-style dataset pairing each sample in `data` with its label.

    Args:
        data: indexable sequence of samples supporting len().
        labels: optional indexable sequence with one label per sample. When None,
            every item's label is None. NOTE: PyTorch's default DataLoader collate
            cannot batch None, so pass labels (or a custom collate_fn) with a DataLoader.

    Raises:
        ValueError: if `labels` is given and its length differs from `data`.
    """

    def __init__(self, data, labels=None):
        # __len__ is driven by `data`; a length mismatch would otherwise surface as an
        # IndexError mid-epoch (labels too short) or silently unused labels (too long).
        if labels is not None and len(labels) != len(data):
            raise ValueError(
                f'data and labels must have the same length, got {len(data)} and {len(labels)}'
            )
        self.data = data
        self.labels = labels

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the (sample, label) pair at `index`; label is None when no labels were given."""
        label = None if self.labels is None else self.labels[index]
        return self.data[index], label


# Training pairs are served one per step; the test loader uses the configured batch size.
train_set = TDataset(final_X_train, final_y_train)
trainloader = DataLoader(
    train_set,
    batch_size=1,
    shuffle=True,
    num_workers=1,
)

test_set = TDataset(final_X_test, final_y_test)
testloader = DataLoader(
    test_set,
    batch_size=para.solver.batch_size,
    shuffle=True,
    num_workers=1,
)

# Running loss accumulators and hyper-parameters for the training loop.
running_loss = running_loss_val = template_loss = 0
printfreq = 1
sigma = 0.02
repara_trick = 0.0

# Per-epoch loss history, one row per epoch (float32, zero-initialised, on the CPU).
n_epochs = para.solver.epochs
loss_array = torch.zeros(n_epochs, 1, dtype=torch.float32)
loss_array_val = torch.zeros(n_epochs, 1, dtype=torch.float32)


# Loss, optimizer and LR schedule selected from parameters.yml. Unsupported values fail
# here, instead of leaving `criterion` / `optimizer` undefined until a later NameError.
if para.model.loss == 'L2':
    criterion = nn.MSELoss()
elif para.model.loss == 'L1':
    criterion = nn.L1Loss()
else:
    raise ValueError(f"Unsupported para.model.loss: {para.model.loss!r} (expected 'L2' or 'L1')")

if para.model.optimizer == 'Adam':
    optimizer = optim.Adam(net.parameters(), lr=para.solver.lr)
elif para.model.optimizer == 'SGD':
    optimizer = optim.SGD(net.parameters(), lr=para.solver.lr, momentum=0.9)
else:
    raise ValueError(
        f"Unsupported para.model.optimizer: {para.model.optimizer!r} (expected 'Adam' or 'SGD')"
    )

# Only cosine annealing is supported; any other value disables the schedule (scheduler is None).
scheduler = None
if para.model.scheduler == 'CosAn':
    scheduler = CosineAnnealingLR(optimizer, T_max=len(trainloader), eta_min=0)

# Second Adam optimizer / schedule over the same network parameters (used for the template
# updates, judging by the name -- confirm against the training cells).
optimizer_template = optim.Adam(net.parameters(), lr=para.solver.lr)
scheduler_template = CosineAnnealingLR(optimizer_template, T_max=len(trainloader), eta_min=0)

In [None]:
class ConvNet(nn.Module):
    """Small CNN classifier for single-channel images.

    conv(1->56) -> ReLU -> 2x2 max-pool -> conv(56->64) -> ReLU -> 2x2 max-pool
    -> fc(64*16*16 -> 128) -> ReLU -> fc(128 -> num_classes).

    The two 2x2 pools shrink each spatial side by 4, so fc1's 64*16*16 input means the
    pooled map must be 16x16, i.e. 64x64 inputs for square images.

    Args:
        num_classes: number of output logits (default 9).
    """

    def __init__(self, num_classes=9):
        super(ConvNet, self).__init__()
        # Attribute names are kept stable: pickled checkpoints / state_dict keys depend on them.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=56, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)  # 64 channels x 16 x 16 after pooling a 64x64 input
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (batch, 1, H, W) image batch (H = W = 64) to (batch, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 64*16*16), a wrong input size now raises in
        # fc1 instead of silently folding spatial positions into the batch dimension.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

## Classification

In [None]:
times = 3 # if you want to use any other value, please train and save the augmentation model accordingly and perform testing here. We pre-trained the augmentation model for 3times augmentation here.

clf_model_path = './saved_models/mgaug_lddmm_clf_mnist_10p_3t.pth'
classifier = torch.load(clf_model_path)
classifier.eval()

aug_model_path = './saved_models/mgaug_lddmm_mnist_10p.pth'
augment = torch.load(aug_model_path)
augment.eval()

### without test-time augmentation ###
print('------------------------------------------')
print('Without Test Time Augmentation Performance')
print('------------------------------------------')

acc = 0
predictions = []
targets = []
with torch.no_grad():
    classifier.eval()
    augment.eval()
    for j, image_data in enumerate(testloader):
        inputs, batch_labels = image_data
        inputs = inputs.to(dev)
        b, c, w, h = inputs.shape

        src_bch = inputs[:,0,...].reshape(b,1,w,h)
        tar_bch = inputs[:,1,...].reshape(b,1,w,h)

        outputs = classifier(tar_bch)
        _, predicted = torch.max(outputs, 1)

        labels = [int(label) for label in batch_labels]
        labels_tensor = torch.tensor(labels, dtype=torch.long).to(dev)

        predictions.extend(predicted.cpu().numpy())
        targets.extend(labels_tensor.cpu().numpy())

    overall_accuracyw = accuracy_score(targets, predictions)
    print("Testing Accuracy:", overall_accuracyw)

    precisionw = precision_score(targets, predictions, average='macro')
    print("Precision:", precisionw)

    recallw = recall_score(targets, predictions, average='macro')
    print("Recall:", recallw)

    f1w = f1_score(targets, predictions, average='macro')
    print("F1-score:", f1w)

### with test-time augmentation ###
print('------------------------------------------')
print('With Test Time Augmentation Performance')
print('------------------------------------------')

acc = 0
predictions = []
targets = []
with torch.no_grad():
    classifier.eval()
    augment.eval()
    for j, image_data in enumerate(testloader):
        inputs, batch_labels = image_data
        inputs = inputs.to(dev)
        b, c, w, h = inputs.shape

        src_bch = inputs[:,0,...].reshape(b,1,w,h)
        tar_bch = inputs[:,1,...].reshape(b,1,w,h)

        outputs = classifier(tar_bch)
        _, predicted = torch.max(outputs, 1)

        labels = [int(label) for label in batch_labels]
        labels_tensor = torch.tensor(labels, dtype=torch.long).to(dev)

        correct = (predicted == labels_tensor).sum().item()
        accuracy = correct / labels_tensor.size(0) 

        predictions.extend(predicted.cpu().numpy())
        targets.extend(labels_tensor.cpu().numpy())

    for time in range(times):
        for j, image_data in enumerate(testloader):
            inputs, batch_labels = image_data
            inputs = inputs.to(dev)
            b, c, w, h = inputs.shape

            src_bch = inputs[:,0,...].reshape(b,1,w,h)
            tar_bch = inputs[:,1,...].reshape(b,1,w,h)

            x = torch.cat([src_bch, tar_bch], dim=1).to(dev)
            pred = augment(x)

            outputs = classifier(pred[0])
            _, predicted = torch.max(outputs, 1)

            labels = [int(label) for label in batch_labels]
            labels_tensor = torch.tensor(labels, dtype=torch.long).to(dev)

            correct = (predicted == labels_tensor).sum().item()
            accuracy = correct / labels_tensor.size(0) 

            predictions.extend(predicted.cpu().numpy())
            targets.extend(labels_tensor.cpu().numpy())

    overall_accuracy = accuracy_score(targets, predictions)
    print("Testing Accuracy:", overall_accuracy)

    precision = precision_score(targets, predictions, average='macro')
    print("Precision:", precision)

    recall = recall_score(targets, predictions, average='macro')
    print("Recall:", recall)

    f1 = f1_score(targets, predictions, average='macro')
    print("F1-score:", f1)

In [None]:
times = 2 # if you want to use any other value, please train and save the augmentation model accordingly and perform testing here. We pre-trained the augmentation model for 3times augmentation here.

clf_model_path = './saved_models/joint_epdiff_mgaug_mnist_clf_10p_2t.pth'
classifier = torch.load(clf_model_path)
classifier.eval()

aug_model_path = './saved_models/joint_epdiff_mgaug_mnist_10p_2t.pth'
augment = torch.load(aug_model_path)
augment.eval()

### without test-time augmentation ###
print('------------------------------------------')
print('Without Test Time Augmentation Performance')
print('------------------------------------------')

acc = 0
predictions = []
targets = []
with torch.no_grad():
    classifier.eval()
    augment.eval()
    for j, image_data in enumerate(testloader):
        inputs, batch_labels = image_data
        inputs = inputs.to(dev)
        b, c, w, h = inputs.shape

        src_bch = inputs[:,0,...].reshape(b,1,w,h)
        tar_bch = inputs[:,1,...].reshape(b,1,w,h)

        outputs = classifier(tar_bch)
        _, predicted = torch.max(outputs, 1)

        labels = [int(label) for label in batch_labels]
        labels_tensor = torch.tensor(labels, dtype=torch.long).to(dev)

        predictions.extend(predicted.cpu().numpy())
        targets.extend(labels_tensor.cpu().numpy())

    overall_accuracyw = accuracy_score(targets, predictions)
    print("Testing Accuracy:", overall_accuracyw)

    precisionw = precision_score(targets, predictions, average='macro')
    print("Precision:", precisionw)

    recallw = recall_score(targets, predictions, average='macro')
    print("Recall:", recallw)

    f1w = f1_score(targets, predictions, average='macro')
    print("F1-score:", f1w)

### with test-time augmentation ###
print('------------------------------------------')
print('With Test Time Augmentation Performance')
print('------------------------------------------')

acc = 0
predictions = []
targets = []
with torch.no_grad():
    classifier.eval()
    augment.eval()
    for j, image_data in enumerate(testloader):
        inputs, batch_labels = image_data
        inputs = inputs.to(dev)
        b, c, w, h = inputs.shape

        src_bch = inputs[:,0,...].reshape(b,1,w,h)
        tar_bch = inputs[:,1,...].reshape(b,1,w,h)

        outputs = classifier(tar_bch)
        _, predicted = torch.max(outputs, 1)

        labels = [int(label) for label in batch_labels]
        labels_tensor = torch.tensor(labels, dtype=torch.long).to(dev)

        correct = (predicted == labels_tensor).sum().item()
        accuracy = correct / labels_tensor.size(0) 

        predictions.extend(predicted.cpu().numpy())
        targets.extend(labels_tensor.cpu().numpy())

    for time in range(times):
        for j, image_data in enumerate(testloader):
            inputs, batch_labels = image_data
            inputs = inputs.to(dev)
            b, c, w, h = inputs.shape

            src_bch = inputs[:,0,...].reshape(b,1,w,h)
            tar_bch = inputs[:,1,...].reshape(b,1,w,h)

            x = torch.cat([src_bch, tar_bch], dim=1).to(dev)
            pred = augment(x)

            outputs = classifier(pred[0])
            _, predicted = torch.max(outputs, 1)

            labels = [int(label) for label in batch_labels]
            labels_tensor = torch.tensor(labels, dtype=torch.long).to(dev)

            correct = (predicted == labels_tensor).sum().item()
            accuracy = correct / labels_tensor.size(0) 

            predictions.extend(predicted.cpu().numpy())
            targets.extend(labels_tensor.cpu().numpy())

    overall_accuracy = accuracy_score(targets, predictions)
    print("Testing Accuracy:", overall_accuracy)

    precision = precision_score(targets, predictions, average='macro')
    print("Precision:", precision)

    recall = recall_score(targets, predictions, average='macro')
    print("Recall:", recall)

    f1 = f1_score(targets, predictions, average='macro')
    print("F1-score:", f1)

In [None]:
def detJac(np_displacement_field):
    """Compute the Jacobian determinant of a displacement field with SimpleITK.

    Args:
        np_displacement_field: torch tensor holding the displacement field (despite the
            name, ``.permute`` is called on it, so it must be a tensor). Presumably laid out
            as (1, C, H, W) -- TODO confirm -- which permute(2, 3, 1, 0).squeeze() turns
            into the (H, W, C) vector-image array passed to SimpleITK.

    Returns:
        NumPy array with the per-pixel Jacobian determinant (an owned copy).
    """
    # Detach and move to the CPU so GPU / autograd tensors can be converted to NumPy.
    field = np_displacement_field.detach().cpu().permute(2, 3, 1, 0).squeeze().numpy()
    sitk_displacement_field = sitk.GetImageFromArray(field, isVector=True)
    jacobian_det_volume = sitk.DisplacementFieldJacobianDeterminant(sitk_displacement_field)
    # GetArrayFromImage copies the pixels. A GetArrayViewFromImage view would point into
    # jacobian_det_volume's buffer, which is freed once this local image goes away on return.
    return sitk.GetArrayFromImage(jacobian_det_volume)

def save_img(src, tar, deform, field, count):
    """Write the source, target and Jacobian-determinant images of one sample as NIfTI files.

    Files: ./save_img/src/src{count}.nii, ./save_img/tar/tar{count}.nii and
    ./save_img/detJac/dj{count}.nii. Missing output directories are created.

    Args:
        src, tar: image tensors, squeezed before writing.
        deform: deformed image tensor. NOTE(review): accepted but not written to disk --
            confirm whether a separate output for it was intended.
        field: displacement-field tensor; see the permute below for the assumed layout.
        count: index used in the output file names.
    """
    src = src.squeeze().detach().cpu()
    tar = tar.squeeze().detach().cpu()
    field = field.detach().cpu()

    # Presumably (1, C, H, W) -> (H, W, C) for a single-sample batch -- TODO confirm.
    np_displacement_field = field.permute(2, 3, 1, 0).squeeze()
    sitk_displacement_field = sitk.GetImageFromArray(np_displacement_field, isVector=True)
    jacobian_det_volume = sitk.DisplacementFieldJacobianDeterminant(sitk_displacement_field)
    # Owned copy, so the array stays valid independently of jacobian_det_volume's lifetime.
    jacobian_det_np_arr = sitk.GetArrayFromImage(jacobian_det_volume)

    src_img = sitk.GetImageFromArray(src)
    tar_img = sitk.GetImageFromArray(tar)
    detjac = sitk.GetImageFromArray(jacobian_det_np_arr)

    # WriteImage does not create parent directories.
    for sub_dir in ('src', 'tar', 'detJac'):
        os.makedirs(f'./save_img/{sub_dir}', exist_ok=True)

    sitk.WriteImage(src_img, f'./save_img/src/src{count}.nii')
    sitk.WriteImage(tar_img, f'./save_img/tar/tar{count}.nii')
    sitk.WriteImage(detjac, f'./save_img/detJac/dj{count}.nii')

def plot(src, tar, deform, count, augment_no):
    """Show an image next to its augmented (deformed) version in a 1x2 grayscale figure.

    Args:
        src: original image tensor, squeezed before display (panel title 'Img').
        tar: target image tensor; converted like the others but not displayed.
        deform: augmented image tensor (panel title 'Augmented').
        count, augment_no: accepted for call-site compatibility; not used.
    """
    src_img, _tar_img, deform_img = (t.squeeze().detach().cpu() for t in (src, tar, deform))

    fig, axs = plt.subplots(1, 2, figsize=(4, 2), gridspec_kw={'width_ratios': [1, 1]})
    for ax, image, title in zip(axs, (src_img, deform_img), ('Img', 'Augmented')):
        ax.imshow(image, cmap='gray', aspect='auto')
        ax.axis('tight')
        ax.axis('off')
        ax.set_title(title)

    plt.show()

In [None]:
# Visualise augmentations of the training set produced by the jointly trained model.
import inspect

aug_model_path = './saved_models/joint_epdiff_mgaug_mnist_10p_2t.pth'

# The checkpoint stores a whole pickled module: only load trusted files. map_location lets a
# GPU-saved checkpoint load on CPU; PyTorch >= 2.6 needs weights_only=False for full modules
# (older releases have no such keyword).
load_kwargs = {'map_location': dev}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
augment = torch.load(aug_model_path, **load_kwargs)
augment.eval()

count = 0
times_of_augmentation = 1

with torch.no_grad():
    for aug_no in range(times_of_augmentation):
        print("Augment Cycle: ", aug_no)
        for inputs, batch_labels in trainloader:
            inputs = inputs.to(dev)
            b, c, w, h = inputs.shape
            # Channel 0 is the source image, channel 1 the target.
            src_bch = inputs[:, 0, ...].reshape(b, 1, w, h)
            tar_bch = inputs[:, 1, ...].reshape(b, 1, w, h)
            x = torch.cat([src_bch, tar_bch], dim=1)
            pred = augment(x)

            plot(src_bch, tar_bch, pred[0], count, aug_no)
            # To also write NIfTI files: save_img(src_bch, tar_bch, pred[0], pred[1], count)

            count += 1

# Count samples (not batches) so the total stays correct for any trainloader batch size.
print("Total Augmented Images: ", times_of_augmentation * len(trainloader.dataset))