In [None]:
from models.adaptation_model import AdaptationModel
from models.networks import define_G, define_D, make_segmentation_network, make_registration_network
from data import dataset
import torch
import matplotlib.pyplot as plt
import monai
import os
import numpy as np
import sys
from tqdm import tqdm

In [None]:
# Target-domain segmentation network (weights are loaded in the next cell) and the
# labelled target test split: input images plus ground-truth label maps.
target_seg_net = make_segmentation_network()
target_images, target_segs = dataset.get_target_dataset(segmented=True, split='test')

In [None]:
path = "./results/domain-adaptation/latest_net_S_target.pth"
# map_location='cpu' makes the checkpoint loadable regardless of which device it was saved
# from; the network is moved to the GPU explicitly afterwards.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
target_seg_net.load_state_dict(torch.load(path, map_location='cpu'))
target_seg_net = target_seg_net.eval().cuda()

In [None]:
# Per-structure Dice of the adapted target segmentation network on the target test split.
# MONAI's DiceLoss returns 1 - Dice; with include_background=False and reduction='none' it
# yields one value per foreground channel, assumed here to be femur, tibia, fibula in that
# order (labels 1, 2, 3 — TODO confirm against the dataset).
dice_score = monai.losses.DiceLoss(to_onehot_y=True, softmax=True, include_background=False, reduction='none')

dice_scores_femur = []
dice_scores_tibia = []
dice_scores_fibula = []

for i in range(len(target_images)):
    # A list index keeps the batch dimension, so each item has batch size 1.
    image = target_images[[i]]
    seg = target_segs[[i]]

    with torch.no_grad():
        pred_seg = target_seg_net(image.cuda())
        # flatten() drops the trailing singleton spatial dims whatever the input
        # dimensionality; tolist() yields plain floats so the summary prints numbers
        # instead of tensor reprs.
        score = (1 - dice_score(pred_seg, seg.cuda())).flatten().cpu().tolist()

    dice_scores_femur.append(score[0])
    dice_scores_tibia.append(score[1])
    dice_scores_fibula.append(score[2])

_femur_dice = float(np.mean(dice_scores_femur))
_tibia_dice = float(np.mean(dice_scores_tibia))
_fibula_dice = float(np.mean(dice_scores_fibula))

print(f'Dice Score: {np.mean([_femur_dice, _tibia_dice, _fibula_dice])}')
print(f'Dice Score - Femur: {_femur_dice}')
print(f'Dice Score - Tibia: {_tibia_dice}')
print(f'Dice Score - Fibula: {_fibula_dice}')

In [None]:
# Visual sanity check: one random target test image, its ground truth and the prediction.
index = torch.randint(0, len(target_images), (1,)).item()

image = target_images[[index]]
seg = target_segs[[index]]

with torch.no_grad():
    pred_seg = target_seg_net(image.cuda())
pred_seg = torch.argmax(torch.softmax(pred_seg, dim=1), dim=1, keepdim=True)

# Fix the color range of the label maps (labels 0..3: background + three bones) so a given
# label gets the same color in both panels even when a structure is missing from one of them.
label_range = dict(vmin=0, vmax=3)

fig, axes = plt.subplots(1, 3, figsize=(10, 5))
panels = [
    ('Input', image, dict(cmap='gray')),
    ('Ground Truth', seg, label_range),
    ('Prediction', pred_seg, label_range),
]
for ax, (title, tensor, style) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(tensor[0, 0].cpu().numpy(), **style)
    ax.axis('off')
plt.show()

In [None]:
# Source-domain segmentation network (weights are loaded in the next cell) and the
# labelled source test split: input images plus ground-truth label maps.
source_seg_net = make_segmentation_network()
source_images, source_segs = dataset.get_source_dataset(segmented=True, split='test')

In [None]:
path = "./results/segmentation-source/seg_net.pt"
# map_location='cpu' makes the checkpoint loadable regardless of which device it was saved
# from; the network is moved to the GPU explicitly afterwards.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
source_seg_net.load_state_dict(torch.load(path, map_location='cpu'))
source_seg_net = source_seg_net.eval().cuda()

In [None]:
# Registration network sized from one single-channel source image; only its inner
# `regis_net` is restored from the checkpoint.
reg_net = make_registration_network(source_images[:1, :1].size(), include_last_step=False, segmentation=True)
# map_location='cpu' makes the checkpoint loadable regardless of which device it was saved
# from. NOTE: torch.load unpickles the file — only load trusted checkpoints.
reg_net.regis_net.load_state_dict(
    torch.load('./results/registration-segmentation/Step_1_final.trch', map_location='cpu')
)
reg_net = reg_net.cuda().eval()

In [None]:
def dice_coeff(a, b, smooth=0.0001):
    """Soft Dice coefficient between two tensors of the same number of elements.

    Args:
        a, b: tensors (typically binary masks); they are flattened before comparison.
        smooth: added to numerator and denominator so two empty masks score 1 instead of 0/0.

    Returns:
        0-d tensor: (2 * sum(a * b) + smooth) / (sum(a) + sum(b) + smooth).
    """
    # reshape (not view) so non-contiguous inputs such as transposed/sliced tensors work too.
    iflat = a.reshape(-1)
    tflat = b.reshape(-1)
    intersection = (iflat * tflat).sum()

    return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)


def dice_coef_multilabel(y_true, y_pred, numLabels):
    """Per-label Dice between two label maps, skipping label 0 (background).

    Args:
        y_true, y_pred: label maps with the same number of elements; values are compared
            with exact equality, so they should hold integer label values.
        numLabels: total number of labels including background; labels 1..numLabels-1 are scored.

    Returns:
        list of numLabels - 1 CPU 0-d tensors, one Dice per foreground label in label order.
    """
    return [
        dice_coeff(1. * (y_true == label), 1. * (y_pred == label)).cpu()
        for label in range(1, numLabels)
    ]

In [None]:
# Cross-domain registration evaluation. For every (target, source) test pair, reg_net is run
# on the two *predicted* segmentations. The resulting deformation fields warp each image's
# ground-truth label map onto the other image, and per-structure Dice is measured against
# that image's ground truth.
NUM_LABELS = 4  # background + femur, tibia, fibula


def _predict_labels(seg_net, images):
    """Return, for each image, the argmax label map predicted by `seg_net` (float, on the GPU).

    A prediction depends on one image only, so it is computed once here instead of once per
    (target, source) pair inside the double loop.
    """
    labels = []
    with torch.no_grad():
        for k in range(len(images)):
            logits = seg_net(images[[k]].cuda())
            labels.append(torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=True).float())
    return labels


def _warp_labels(label_map, vectorfield):
    """Warp an integer-valued label map of shape (N, 1, ...) and return warped labels on the CPU.

    The labels are one-hot encoded, each channel is warped separately, and the result is
    argmax-ed. Warping raw label values is only correct with nearest-neighbour sampling: with
    interpolation, values blend at structure boundaries (e.g. background 0 next to tibia 2
    can yield 1, which the exact `== index` test in dice_coef_multilabel counts as femur,
    while values like 0.5 match no label at all).
    NOTE(review): assumes reg_net.as_function accepts multi-channel images — confirm.
    """
    one_hot = torch.nn.functional.one_hot(label_map[:, 0].round().long(), NUM_LABELS)
    one_hot = one_hot.movedim(-1, 1).float()  # (N, ..., L) -> (N, L, ...)
    warped = reg_net.as_function(one_hot)(vectorfield)
    return torch.argmax(warped, dim=1, keepdim=True).cpu()


dice_femur_source_target, dice_tibia_source_target, dice_fibula_source_target = [], [], []
dice_femur_target_source, dice_tibia_target_source, dice_fibula_target_source = [], [], []

pred_source_segs = _predict_labels(source_seg_net, source_images)
pred_target_segs = _predict_labels(target_seg_net, target_images)

for i in tqdm(range(len(target_images)), desc='target images'):
    target_seg = target_segs[[i]]
    for j in range(len(source_images)):
        source_seg = source_segs[[j]]

        with torch.no_grad():
            # Source is passed as image A and target as image B; the warped source labels
            # are compared with the target ground truth and vice versa.
            reg_net(pred_source_segs[j], pred_target_segs[i])
            warped_seg_source = _warp_labels(source_seg.cuda(), reg_net.phi_AB_vectorfield)
            warped_seg_target = _warp_labels(target_seg.cuda(), reg_net.phi_BA_vectorfield)

        dice_source_target = dice_coef_multilabel(warped_seg_source, target_seg, NUM_LABELS)
        dice_target_source = dice_coef_multilabel(warped_seg_target, source_seg, NUM_LABELS)

        dice_femur_source_target.append(dice_source_target[0].item())
        dice_tibia_source_target.append(dice_source_target[1].item())
        dice_fibula_source_target.append(dice_source_target[2].item())

        dice_femur_target_source.append(dice_target_source[0].item())
        dice_tibia_target_source.append(dice_target_source[1].item())
        dice_fibula_target_source.append(dice_target_source[2].item())

print(f'Dice Score - Femur Source -> Target: {np.mean(dice_femur_source_target)}')
print(f'Dice Score - Tibia Source -> Target: {np.mean(dice_tibia_source_target)}')
print(f'Dice Score - Fibula Source -> Target: {np.mean(dice_fibula_source_target)}')

print(f'Dice Score - Femur Target -> Source: {np.mean(dice_femur_target_source)}')
print(f'Dice Score - Tibia Target -> Source: {np.mean(dice_tibia_target_source)}')
print(f'Dice Score - Fibula Target -> Source: {np.mean(dice_fibula_target_source)}')

In [None]:
def show_as_grid(phi, linewidth=1, stride=8, data_size=256):
    """Draw a deformation field as grid lines on the current matplotlib axes.

    Args:
        phi: tensor indexed as (C, H, W) with C >= 2. phi[0] is drawn on the vertical axis
            and phi[1] on the horizontal axis. Values are presumably normalized coordinates
            (e.g. in [0, 1]) — TODO confirm.
        linewidth: width of the grid lines. The previous version silently overwrote this
            with 0.5.
        stride: draw every `stride`-th grid line in each direction.
        data_size: factor converting phi's coordinates to pixel coordinates.
    """
    axes = plt.gca()
    plot_phi = data_size * phi[:, ::stride, ::stride].detach().cpu() - 0.5
    # x shows phi[1] (last axis, width) and y shows phi[0] (height); the limits follow
    # the same axes so non-square fields are framed correctly.
    axes.set_xlim([-0.5, phi.size()[-1] - 0.5])
    axes.set_ylim([phi.size()[-2] - 0.5, -0.5])  # inverted y, like an image
    plt.plot(plot_phi[1], plot_phi[0], linewidth=linewidth)
    plt.plot(
        plot_phi[1].transpose(0, 1), plot_phi[0].transpose(0, 1), linewidth=linewidth
    )

In [None]:
# Visual check of the registration on one random (source, target) pair: each image is warped
# with the field obtained by registering the two predicted segmentations.
def _show_row(images, titles):
    """Show batched image tensors side by side in grayscale (element [0, 0] of each)."""
    plt.figure(figsize=(10, 5))
    for position, (img, title) in enumerate(zip(images, titles), start=1):
        plt.subplot(1, len(images), position)
        plt.title(title)
        plt.imshow(img[0, 0].detach().cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()


index_source = torch.randint(0, len(source_images), (1,)).item()
index_target = torch.randint(0, len(target_images), (1,)).item()

source_image, source_seg = source_images[[index_source]], source_segs[[index_source]]
target_image, target_seg = target_images[[index_target]], target_segs[[index_target]]

with torch.no_grad():
    pred_source_seg = source_seg_net(source_image.cuda())
    pred_target_seg = target_seg_net(target_image.cuda())
    pred_source_seg = torch.argmax(torch.softmax(pred_source_seg, dim=1), dim=1, keepdim=True)
    pred_target_seg = torch.argmax(torch.softmax(pred_target_seg, dim=1), dim=1, keepdim=True)

    reg_net(pred_source_seg.float(), pred_target_seg.float())

    warped_image_source = reg_net.as_function(source_image.cuda().float())(reg_net.phi_AB_vectorfield)
    warped_image_target = reg_net.as_function(target_image.cuda().float())(reg_net.phi_BA_vectorfield)

_show_row(
    [source_image, warped_image_source, target_image],
    ['Source Image', 'Warped Source Image', 'Target Image'],
)
_show_row(
    [target_image, warped_image_target, source_image],
    ['Target Image', 'Warped Target Image', 'Source Image'],
)