In [None]:
from monai.utils import set_determinism, first
from monai.transforms import (
    EnsureChannelFirstD,
    Compose,
    LoadImageD,
    RandRotateD,
    RandZoomD,
    ScaleIntensityRanged,
)
import monai
from monai.data import DataLoader, Dataset, CacheDataset
from monai.config import print_config, USE_COMPILED
from monai.networks.nets import GlobalNet, LocalNet, RegUNet
from monai.networks.blocks import Warp
from monai.apps import MedNISTDataset
import torch.nn.functional as F

from glob import glob
import cv2
import torchmetrics

from torch.autograd import Variable

from scipy.spatial.distance import directed_hausdorff
import pandas as pd

import torch.nn as nn

import numpy as np
import torch
from torch.nn import MSELoss
import matplotlib.pyplot as plt
import os
import tempfile
from monai.losses import *
from monai.metrics import *
from piqa import SSIM


# Record library versions for reproducibility, then fix random seeds via MONAI.
print_config()
set_determinism(seed=42)

: 

In [None]:
# Dataset folder (under data/) and the root path derived from it.
dataDir = 'CAMUS_EStoED_A4C'
root_dir = f'data/{dataDir}/'
print(root_dir)

# Evaluation settings.
testBatch = 2      # samples per validation batch
img_size = 512     # frames and masks are resized to img_size x img_size

# Carried over from the training configuration (apparently unused in this
# evaluation notebook).
previousWeight = 512
preTrained = 0
EP = 100

# DataLoader worker processes (0 = load in the main process).
num_workers = 0

: 

In [None]:
print('How many GPUs = ' + str(torch.cuda.device_count()))

# Fail fast before choosing a device: running this on CPU is far too slow.
if not torch.cuda.is_available():
    raise RuntimeError("GPU not available. CPU training will be too slow.")

# This setup prefers the second GPU ("cuda:1"). Fall back to "cuda:0" on
# single-GPU machines, where "cuda:1" would fail with an invalid device ordinal.
PREFERRED_GPU_INDEX = 1
device = torch.device(
    f'cuda:{PREFERRED_GPU_INDEX}'
    if torch.cuda.device_count() > PREFERRED_GPU_INDEX
    else 'cuda:0'
)
print(device)

# Report the GPU that is actually used (previously always printed GPU 0).
print("device name", torch.cuda.get_device_name(device))

: 

In [None]:
class EchoDataset(Dataset):
    """Grayscale echo frames read from disk, resized and scaled to [0, 1].

    Each item is a ``float32`` array of shape ``(1, side, side)``, where
    ``side`` is ``size`` if given, else the notebook-level ``img_size``.

    Args:
        images_path: sequence of image file paths, indexed in order.
        size: optional output side length in pixels. ``None`` (default)
            reads the global ``img_size`` each time an item is loaded.

    Raises:
        FileNotFoundError: from ``__getitem__`` when a file cannot be read.
    """

    # cv2.resize's default interpolation; subclasses may override.
    interpolation = cv2.INTER_LINEAR
    # Divide each image by its own maximum so intensities land in [0, 1].
    normalize = True

    def __init__(self, images_path, size=None):
        self.images_path = images_path
        self.n_samples = len(images_path)
        self.size = size

    def _load(self, index):
        """Read sample ``index`` as a 2-D grayscale array."""
        path = self.images_path[index]
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None;
            # without this check the failure surfaces as an obscure resize error.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image

    def __getitem__(self, index):
        side = img_size if self.size is None else self.size
        image = cv2.resize(self._load(index), (side, side), interpolation=self.interpolation)
        if self.normalize:
            peak = image.max()
            # An all-black frame would otherwise become 0/0 = NaN everywhere.
            image = image / peak if peak > 0 else image
        # Add the leading channel axis -> (1, side, side).
        return np.expand_dims(image, axis=0).astype(np.float32)

    def __len__(self):
        return self.n_samples


class EchoDatasetMask(EchoDataset):
    """Label masks for ``EchoDataset`` frames.

    Same loading path, but resized with nearest-neighbour interpolation, so
    label values are never blended, and returned without normalisation.
    """

    interpolation = cv2.INTER_NEAREST
    normalize = False

: 

In [None]:
def _build_loader(dataset, batch_size, num_workers, pin_memory, shuffle):
    """Wrap ``dataset`` in a DataLoader with the given settings."""
    return DataLoader(dataset,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      pin_memory=pin_memory,
                      shuffle=shuffle)


def get_batches(train_dir,
                batch_size,
                num_workers,
                pin_memory,
                shuffle=False):
    """Build a DataLoader over normalised grayscale frames (``EchoDataset``).

    Args:
        train_dir: list of image file paths (used for any split, despite the name).
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        pin_memory: forwarded to DataLoader.
        shuffle: default False keeps loaders over parallel fixed/moving
            folders aligned by index; enable only for an unpaired loader.

    Returns:
        DataLoader yielding float32 batches of shape (B, 1, H, W).
    """
    return _build_loader(EchoDataset(images_path=train_dir),
                         batch_size, num_workers, pin_memory, shuffle)


def get_batches_mask(train_dir,
                     batch_size,
                     num_workers,
                     pin_memory,
                     shuffle=False):
    """Build a DataLoader over label masks (``EchoDatasetMask``).

    Same arguments and pairing caveat as ``get_batches``.
    """
    return _build_loader(EchoDatasetMask(images_path=train_dir),
                         batch_size, num_workers, pin_memory, shuffle)


# Collect each validation folder once, so the printed counts and the loaders
# see exactly the same files. Previously the counts globbed "*.png" while the
# loaders globbed "*", so a stray non-PNG file would be fed to cv2 (and crash)
# without appearing in the count.
val_paths = {
    name: sorted(glob(root_dir + f"val/{name}/*.png"))
    for name in ("fixed_img", "fixed_msk", "moving_img", "moving_msk")
}

for name, paths in val_paths.items():
    print(f'Val Sample numbers ({name}) = {len(paths)}')
print()

# Fixed/moving images and masks are paired purely by sorted position, and the
# evaluation zip()s the loaders (which silently truncates), so the folders
# must line up exactly.
val_counts = {name: len(paths) for name, paths in val_paths.items()}
assert len(set(val_counts.values())) == 1, f"Validation folders differ in size: {val_counts}"
assert val_counts["fixed_img"] > 0, f"No validation PNGs found under {root_dir}val/"

loader_kwargs = dict(batch_size=testBatch, num_workers=num_workers, pin_memory=True)

fixed_val_img = get_batches(train_dir=val_paths["fixed_img"], **loader_kwargs)
fixed_val_msk = get_batches_mask(train_dir=val_paths["fixed_msk"], **loader_kwargs)
moving_val_img = get_batches(train_dir=val_paths["moving_img"], **loader_kwargs)
moving_val_msk = get_batches_mask(train_dir=val_paths["moving_msk"], **loader_kwargs)

print("Val IMG FIXED:", fixed_val_img)
print("Val MSK FIXED:", fixed_val_msk)
print("Val IMG Moving:", moving_val_img)
print("Val MSK Moving:", moving_val_msk)

: 

In [None]:
# Name -> DataLoader lookup for the four paired validation streams.
dataloaders = dict(
    fixed_val_img=fixed_val_img,
    fixed_val_msk=fixed_val_msk,
    moving_val_img=moving_val_img,
    moving_val_msk=moving_val_msk,
)

: 

In [None]:
# Peek at one sample from each validation loader: first batch, first sample,
# channel 0 -> a 2-D tensor.
fixed_val_img_ = first(dataloaders["fixed_val_img"])[0][0]
fixed_val_msk_ = first(dataloaders["fixed_val_msk"])[0][0]
moving_val_img_ = first(dataloaders["moving_val_img"])[0][0]
moving_val_msk_ = first(dataloaders["moving_val_msk"])[0][0]

previews = {
    "fixed_val_img_": fixed_val_img_,
    "fixed_val_msk_": fixed_val_msk_,
    "moving_val_img_": moving_val_img_,
    "moving_val_msk_": moving_val_msk_,
}

for name, sample in previews.items():
    print(f"{name} shape: {sample.shape}")

# Masks also list their distinct label values.
for name, sample in previews.items():
    if name.endswith("_msk_"):
        print(f"{name} range: {sample.max()} {sample.min()} {np.unique(sample)}")
    else:
        print(f"{name} Range: {sample.max()} {sample.min()}")

# One row of four panels (the former 2x4 grid left its top row empty).
fig, axes = plt.subplots(1, len(previews), figsize=(10, 3), num="check")
for ax, (name, sample) in zip(axes, previews.items()):
    ax.set_title(name)
    ax.imshow(sample, cmap="gray")
    ax.axis('off')

plt.show()

: 

In [None]:
# Registration network: takes the moving and fixed frames stacked on the
# channel axis and outputs a 2-channel displacement field (ddf) for Warp.
mod = RegUNet(
    spatial_dims=2,
    in_channels=2,
    num_channel_initial=32,
    depth=5,
    extract_levels=[5],
    out_activation=None,
    out_channels=2,
    out_kernel_initializer="zeros",
    concat_skip=False,
)


import torch, torchinfo
model = mod.to(device)

# NOTE(review): torch.load unpickles the file and can execute arbitrary code;
# only load trusted checkpoints (newer torch supports weights_only=True).
model.load_state_dict(
    torch.load(
        '_MS_Adv_AC_DLIR_CAMUS_EStoED_A4C_512_.pth',
        map_location=device,
    )
)


import os
path = "Result"

# Output folder for exported PNGs; created only when it does not exist yet.
if not os.path.exists(path):
    os.makedirs(path)
    print("The new directory is created!")

Paths = f"{path}/"

# Spatial transformer that resamples an image with a displacement field.
warp_layer = Warp().to(device)

: 

In [None]:
def make_one_hot(labels, C=2):
    '''
    Convert a label map to a one-hot float tensor.

    Parameters
    ----------
    labels : torch.Tensor
        N x 1 x H x W label map. Floating-point input (e.g. a mask that was
        resampled by a warp) is rounded to the nearest integer label first.
    C : int
        Number of classes; every label must lie in ``[0, C)``.

    Returns
    -------
    torch.Tensor
        N x C x H x W float32 tensor on the same device as ``labels``, with
        1.0 in the channel given by each pixel's label and 0.0 elsewhere.

    Raises
    ------
    ValueError
        If any label falls outside ``[0, C)``.
    '''
    if labels.is_floating_point():
        # Interpolated masks can hold values such as 0.99999994 inside a
        # label-1 region; .long() alone truncates those to background (0).
        labels = torch.round(labels)
    index = labels.long()

    if index.numel() and (int(index.min()) < 0 or int(index.max()) >= C):
        raise ValueError(
            f"labels must lie in [0, {C}), got range "
            f"[{int(index.min())}, {int(index.max())}]"
        )

    # Allocate on the labels' own device instead of relying on a global.
    one_hot = torch.zeros(index.size(0), C, index.size(2), index.size(3),
                          dtype=torch.float32, device=index.device)
    # Write 1.0 along the class axis at each pixel's label index.
    return one_hot.scatter_(1, index, 1.0)

: 

In [None]:
model.eval()

metric = monai.metrics.MSEMetric()
NUM_CLASSES = 3  # label classes (incl. background) passed to make_one_hot

before_MSEMetric = []
after_MSEMetric = []

before_compute_meandice = []
after_compute_meandice = []

before_compute_hausdorff_distance = []
after_compute_hausdorff_distance = []


def _to_numpy(tensor):
    """Detach ``tensor`` from autograd and return it as a CPU NumPy array."""
    return tensor.detach().cpu().numpy()


def _mean_and_std(values):
    """Return (mean of per-class means, mean of per-class std devs) over samples."""
    stacked = np.array(values)
    return np.mean(np.mean(stacked, axis=0)), np.mean(np.std(stacked, axis=0))


# Inference only: no_grad stops every forward pass from building an autograd graph.
with torch.no_grad():
    for fixed_val_img_, fixed_val_msk_, moving_val_img_, moving_val_msk_ in zip(
        fixed_val_img, fixed_val_msk, moving_val_img, moving_val_msk
    ):
        fixed_val_img_ = fixed_val_img_.to(device)
        fixed_val_msk_ = fixed_val_msk_.to(device)
        moving_val_img_ = moving_val_img_.to(device)
        moving_val_msk_ = moving_val_msk_.to(device)

        # Predict a displacement field from (moving, fixed), then warp the
        # moving image and mask with it.
        ddf_val = model(torch.cat((moving_val_img_, fixed_val_img_), dim=1))
        pred_image_val = warp_layer(moving_val_img_, ddf_val)
        pred_mask_val = warp_layer(moving_val_msk_, ddf_val)

        before_MSEMetric.extend(_to_numpy(metric(moving_val_img_, fixed_val_img_)))
        after_MSEMetric.extend(_to_numpy(metric(pred_image_val, fixed_val_img_)))

        # One-hot encode each mask once and share it between both metrics.
        fixed_oh = make_one_hot(fixed_val_msk_, C=NUM_CLASSES)
        moving_oh = make_one_hot(moving_val_msk_, C=NUM_CLASSES)
        pred_oh = make_one_hot(pred_mask_val, C=NUM_CLASSES)

        before_compute_hausdorff_distance.extend(
            _to_numpy(compute_hausdorff_distance(y_pred=moving_oh, y=fixed_oh, directed=True)))
        after_compute_hausdorff_distance.extend(
            _to_numpy(compute_hausdorff_distance(y_pred=pred_oh, y=fixed_oh, directed=True)))

        # NOTE(review): newer MONAI releases rename compute_meandice to
        # compute_dice — confirm against the installed version.
        before_compute_meandice.extend(_to_numpy(compute_meandice(y_pred=moving_oh, y=fixed_oh)))
        after_compute_meandice.extend(_to_numpy(compute_meandice(y_pred=pred_oh, y=fixed_oh)))

print(np.array(before_MSEMetric).shape)
print(np.array(after_MSEMetric).shape)

print(np.array(before_compute_meandice).shape)
print(np.array(after_compute_meandice).shape)

print()

print(f'MSEMetric W/O registration {np.mean(np.array(before_MSEMetric))}')
print(f'MSEMetric W/ registration = {np.mean(np.array(after_MSEMetric))}')

print(f'Mean dice W/O registration = {np.mean(np.array(before_compute_meandice), axis=0)}')
print(f'Mean dice W/ registration = {np.mean(np.array(after_compute_meandice), axis=0)}')

print(f'Hausdorff_distance W/O registration = {np.mean(np.array(before_compute_hausdorff_distance), axis=0)}')
print(f'Hausdorff_distance W/ registration = {np.mean(np.array(after_compute_hausdorff_distance), axis=0)}')

print()

print('----------------For Report Metric +/- Std-----------------------')
report_rows = [
    ('Mean dice W/O registration', before_compute_meandice),
    ('Mean dice W/ registration', after_compute_meandice),
    ('Hausdorff_distance W/O registration', before_compute_hausdorff_distance),
    ('Hausdorff_distance W/ registration', after_compute_hausdorff_distance),
]
for label, values in report_rows:
    mean_value, std_value = _mean_and_std(values)
    print(f'{label} = {mean_value} + {std_value}')

: 

In [None]:
# Export every validation pair as PNGs named <Paths><k>_<i>_<tag>_.png, where
# k numbers batches starting at 1000 and i is the sample index in the batch.
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    for k, (fixed_val_img_, fixed_val_msk_, moving_val_img_, moving_val_msk_) in enumerate(
        zip(fixed_val_img, fixed_val_msk, moving_val_img, moving_val_msk), start=1000
    ):
        fixed_val_img_ = fixed_val_img_.to(device)
        fixed_val_msk_ = fixed_val_msk_.to(device)
        moving_val_img_ = moving_val_img_.to(device)
        moving_val_msk_ = moving_val_msk_.to(device)

        ddf_val = model(torch.cat((moving_val_img_, fixed_val_img_), dim=1))
        pred_image_val = warp_layer(moving_val_img_, ddf_val)
        pred_mask_val = warp_layer(moving_val_msk_, ddf_val)

        # (file tag, tensor, intensity multiplier) in the order files are
        # written: images x255, masks x100 so label values are visible.
        exports = [
            ('fixedImage', fixed_val_img_, 255),
            ('fixedMask', fixed_val_msk_, 100),
            ('movingImage', moving_val_img_, 255),
            ('movingMask', moving_val_msk_, 100),
            ('movedImage', pred_image_val, 255),
            ('movedMask', pred_mask_val, 100),
        ]
        # Channel 0 of each (B, 1, H, W) tensor as a (B, H, W) NumPy array.
        arrays = [(tag, t.detach().cpu().numpy()[:, 0], scale) for tag, t, scale in exports]

        # Use the real batch length: the final batch can be smaller than
        # testBatch, which made range(testBatch) raise IndexError.
        for i in range(fixed_val_img_.shape[0]):
            for tag, array, scale in arrays:
                cv2.imwrite(Paths + f"{k}_{i}_{tag}_.png", scale * array[i])

: 

In [None]:
def colored(pred_img, true_img):
    """Render an agreement overlay of two label masks as an RGB uint8 image.

    Pixels where both masks carry the same non-zero label are green
    (0, 255, 0). Pixels where the labels differ are yellow (255, 255, 0).
    Background in both stays black. For 0/1 masks this is exactly
    TP = green and FN/FP = yellow.

    The previous version used cv2.bitwise_and on float masks, which ANDs
    the raw IEEE-754 bits (1.0 & 2.0 == 0.0). It also multiplied labels by
    255, which overflows uint8 for labels >= 2.

    Args:
        pred_img: predicted label mask (any numeric dtype). Non-integer
            values, e.g. from interpolation, are rounded to the nearest label.
        true_img: reference label mask with the same shape.

    Returns:
        ``pred_img.shape + (3,)`` uint8 array suitable for ``plt.imshow``.

    Raises:
        ValueError: if the two masks have different shapes.
    """
    pred = np.rint(np.asarray(pred_img, dtype=np.float64))
    true = np.rint(np.asarray(true_img, dtype=np.float64))
    if pred.shape != true.shape:
        raise ValueError(f"mask shapes differ: {pred.shape} vs {true.shape}")

    overlay = np.zeros(true.shape + (3,), dtype=np.uint8)
    overlay[(pred == true) & (true > 0)] = (0, 255, 0)
    overlay[pred != true] = (255, 255, 0)
    return overlay

: 

In [None]:
%matplotlib inline
batch_size = testBatch
plt.subplots(batch_size, 7, figsize=(12, 15))

fixed_val_img_ = fixed_val_img_.detach().cpu().numpy()[:, 0]
fixed_val_msk_ = fixed_val_msk_.detach().cpu().numpy()[:, 0]

moving_val_img_ = moving_val_img_.detach().cpu().numpy()[:, 0]
moving_val_msk_ = moving_val_msk_.detach().cpu().numpy()[:, 0]

pred_image_val = pred_image_val.detach().cpu().numpy()[:, 0]
pred_mask_val = pred_mask_val.detach().cpu().numpy()[:, 0]

for b in range(batch_size):
    # moving image
    plt.subplot(batch_size, 7, b * 7 + 1)
    plt.axis('off')
    plt.title("moving img")
    plt.imshow(moving_val_img_[b], cmap="gray")
    
    # moving label
    plt.subplot(batch_size, 7, b * 7 + 2)
    plt.axis('off')
    plt.title("moving lab")
    plt.imshow(moving_val_msk_[b], cmap="gray")
    
    
    # fixed image
    plt.subplot(batch_size, 7, b * 7 + 3)
    plt.axis('off')
    plt.title("fixed img")
    plt.imshow(fixed_val_img_[b], cmap="gray")
    
    # fixed label
    plt.subplot(batch_size, 7, b * 7 + 4)
    plt.axis('off')
    plt.title("fixed lab")
    plt.imshow(fixed_val_msk_[b], cmap="gray")
    
    
    # warped moving
    plt.subplot(batch_size, 7, b * 7 + 5)
    plt.axis('off')
    plt.title("Pred Img")
    plt.imshow(pred_image_val[b], cmap="gray")


    # warped moving
    plt.subplot(batch_size, 7, b * 7 + 6)
    plt.axis('off')
    plt.title("Before Reg")
    plt.imshow(colored(moving_val_msk_[b], fixed_val_msk_[b]))
    
    
    # warped moving
    plt.subplot(batch_size, 7, b * 7 + 7)
    plt.axis('off')
    plt.title("After Reg")
    plt.imshow(colored(pred_mask_val[b], fixed_val_msk_[b]))
    
plt.axis('off')
plt.show()

: 

: 