In [2]:
import glob
import os
import random
import monai
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from tqdm.notebook import tqdm
#import wandb
from torchmetrics import F1Score

In [163]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_path = r'/home/jovyan/deeplearning'
test_dict_list = np.load(os.path.join(data_path, 'file_list_test.npy'),allow_pickle=True)
dia_dict_list = []
sys_dict_list = []

for i in range(len(test_dict_list)):
    if '_f1_' in test_dict_list[i]['img']:
        dia_dict_list.append(test_dict_list[i])
    else:
        sys_dict_list.append(test_dict_list[i])

class LoadData(monai.transforms.Transform):
    """
    This custom Monai transform loads the data from the segmentation dataset.
    Defining a custom transform is simple; just overwrite the __init__ function and __call__ function.
    """
    def __init__(self, keys=None):
        pass

    def __call__(self, sample):
        image = np.load(sample['img'])
        mask = np.load(sample['mask'])

        return {'img': image, 'mask': mask, 'img_meta_dict': {'affine': np.eye(2)}, 
                'mask_meta_dict': {'affine': np.eye(2)}}

def get_dataloaders(case):
    
    if case == 'LV':
        train_transform = monai.transforms.Compose(
        [
            LoadData(),
            monai.transforms.ThresholdIntensityd(keys=['mask'],threshold=2.5, above=True, cval=0),
            monai.transforms.AddChanneld(keys=['img', 'mask']), # add 1 dimension because we have only 1 channel
            monai.transforms.ScaleIntensityd(keys=['img', 'mask'],minv=0, maxv=1) # scale intensity to values between 0-1
        ]
        )

    elif case == 'MYO':
        train_transform = monai.transforms.Compose(
        [
            LoadData(),
            monai.transforms.ThresholdIntensityd(keys=['mask'],threshold=2.5, above=False, cval=0),
            monai.transforms.ThresholdIntensityd(keys=['mask'],threshold=1.5, above=True, cval=0),
            monai.transforms.AddChanneld(keys=['img', 'mask']), # add 1 dimension because we have only 1 channel
            monai.transforms.ScaleIntensityd(keys=['img', 'mask'],minv=0, maxv=1) # scale intensity to values between 0-1
        ]
        )
    
    elif case == 'RV':
        train_transform = monai.transforms.Compose(
        [
            LoadData(),
            monai.transforms.ThresholdIntensityd(keys=['mask'],threshold=1.5, above=False, cval=0),
            monai.transforms.AddChanneld(keys=['img', 'mask']), # add 1 dimension because we have only 1 channel
            monai.transforms.ScaleIntensityd(keys=['img', 'mask'],minv=0, maxv=1) # scale intensity to values between 0-1
        ]
        )
    
    else: print("Error: select a dataloader")
    
    test_dataset = monai.data.CacheDataset(test_dict_list, transform=train_transform)
    test_loader = monai.data.DataLoader(test_dataset, batch_size=10)
    test_loader.name = 'ALL'

    dia_dataset = monai.data.CacheDataset(dia_dict_list, transform=train_transform)
    dia_loader = monai.data.DataLoader(dia_dataset, batch_size=10)
    dia_loader.name = 'ED'

    sys_dataset = monai.data.CacheDataset(sys_dict_list, transform=train_transform)
    sys_loader = monai.data.DataLoader(sys_dataset, batch_size=10)
    sys_loader.name = 'ES'
    
    return [test_loader, dia_loader, sys_loader]


In [159]:
model = monai.networks.nets.UNet(
    dimensions=2,
    in_channels=1,
    out_channels=1,
    channels = (48, 52, 104, 208, 416),
    strides=(1, 2, 2, 2),
    num_res_units=2,
).to(device)

In [156]:
def compute_metrics(model, test_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    
    step = 0
    HD_all = 0
    cm_all = torch.zeros(1,4).to(device)
    exclsum = 0
    totalsum = 0
    surf_pred_all = torch.empty(0)
    surf_true_all = torch.empty(0)

    for batch_data in test_loader:
        step += 1
        model.eval()
        outputs = model(batch_data['img'].float().to(device))
        outputs = torch.sigmoid(outputs)

        cm = monai.metrics.get_confusion_matrix(torch.round(outputs), batch_data['mask'].to(device))

        surf_pred = torch.zeros(len(cm))
        surf_true = torch.zeros(len(cm))
        for k in range(len(cm)):
            cm_all += cm[k,0]
            surf_pred[k] = torch.round(outputs)[k].sum().item() * 1.25*1.25
            surf_true[k] = batch_data['mask'][k].sum().item() * 1.25*1.25
    
        surf_pred_all = torch.cat((surf_pred_all,surf_pred),0)
        surf_true_all = torch.cat((surf_true_all,surf_true),0)

        HD = monai.metrics.compute_hausdorff_distance(torch.round(outputs), batch_data['mask'].to(device))
        totalsum += HD.size(0)
        HD = HD[~torch.all(HD.isinf(),dim=1)]
        HD = HD[~torch.all(HD.isnan(),dim=1)]
        HD1 = sum(HD)/HD.size(0)
        HD_all += HD1.item() * 1.25
        exclsum += HD.size(0)
    
    

    HD_all = HD_all / step
    dsc_cm = 2*cm_all[0,0] / (2*cm_all[0,0] + cm_all[0,1] + cm_all[0,3])
    cormatrix = np.corrcoef(surf_pred_all.numpy(),surf_true_all.numpy())
    MAE = np.mean(abs(surf_pred_all.numpy()-surf_true_all.numpy()))
    surfbias =  np.mean(surf_pred_all.numpy()-surf_true_all.numpy())
    surfstd = np.std(surf_pred_all.numpy()-surf_true_all.numpy())

    print(f"Metrics for '{model.name}' with data: '{test_loader.name}'")
    print("DSC =", dsc_cm.item())
    print("HD =", HD_all)
    print("surfCorrCoef =", cormatrix[0,1])
    print("surfMAE =", MAE)
    print("surfbias =", surfbias)
    print("surfstd =", surfstd)
    print("Number of images exclueded from the HD: ", totalsum-exclsum, "out of", totalsum)
    print("-----------------------------------")

In [165]:
LVnet = torch.load(os.path.join(data_path, 'trained_nets/LV_runREAL/trainedUNet_LV_runREAL.pt'))
MYOnet = torch.load(os.path.join(data_path, 'trained_nets/MYO_runREAL/trainedUNet_MYO_runREAL.pt'))
RVnet = torch.load(os.path.join(data_path, 'trained_nets/RV_runREAL/trainedUNet_RV_runREAL.pt'))

loaders = get_dataloaders('LV')
model.load_state_dict(LVnet)
model.name = 'LVnet'
for m in range(len(loaders)):
    compute_metrics(model, loaders[m])

loaders = get_dataloaders('MYO')
model.load_state_dict(MYOnet)
model.name = 'MYOnet'
for m in range(len(loaders)):
    compute_metrics(model, loaders[m])

loaders = get_dataloaders('RV')
model.load_state_dict(RVnet)
model.name = 'RVnet'
for m in range(len(loaders)):
    compute_metrics(model, loaders[m])

Loading dataset: 100%|██████████| 476/476 [00:03<00:00, 144.63it/s]
Loading dataset: 100%|██████████| 238/238 [00:01<00:00, 148.25it/s]
Loading dataset: 100%|██████████| 238/238 [00:01<00:00, 145.78it/s]


Metrics for LVnet with data ALL
DSC = 0.9491065144538879
HD = 6.666590874030921
surfCorrCoef = 0.9842838739494284
surfMAE = 40.86791
surfbias = -7.654937
surfstd = 87.63697
Number of images exclueded from the HD:  42 out of 476
-----------------------------------
Metrics for LVnet with data ED
DSC = 0.9616349935531616
HD = 4.8617445517044775
surfCorrCoef = 0.9938480908374921
surfMAE = 36.528362
surfbias = -9.440651
surfstd = 54.158188
Number of images exclueded from the HD:  10 out of 238
-----------------------------------
Metrics for LVnet with data ES
DSC = 0.926520824432373
HD = 8.916456338411763
surfCorrCoef = 0.9653609368910769
surfMAE = 45.20746
surfbias = -5.8692226
surfstd = 111.44949
Number of images exclueded from the HD:  32 out of 238
-----------------------------------


Loading dataset: 100%|██████████| 476/476 [00:02<00:00, 196.16it/s]
Loading dataset: 100%|██████████| 238/238 [00:01<00:00, 148.15it/s]
Loading dataset: 100%|██████████| 238/238 [00:01<00:00, 175.80it/s]


Metrics for MYOnet with data ALL
DSC = 0.8840194940567017
HD = 8.288278381469713
surfCorrCoef = 0.9656435466546518
surfMAE = 73.33902
surfbias = 16.94459
surfstd = 109.58672
Number of images exclueded from the HD:  23 out of 476
-----------------------------------
Metrics for MYOnet with data ED
DSC = 0.8821320533752441
HD = 7.548443751482371
surfCorrCoef = 0.971311099262586
surfMAE = 62.47374
surfbias = 25.420168
surfstd = 87.19859
Number of images exclueded from the HD:  5 out of 238
-----------------------------------
Metrics for MYOnet with data ES
DSC = 0.8856804966926575
HD = 9.196833566608468
surfCorrCoef = 0.9617865014829077
surfMAE = 84.20431
surfbias = 8.469012
surfstd = 127.55875
Number of images exclueded from the HD:  18 out of 238
-----------------------------------


Loading dataset: 100%|██████████| 476/476 [00:01<00:00, 248.51it/s]
Loading dataset: 100%|██████████| 238/238 [00:00<00:00, 270.85it/s]
Loading dataset: 100%|██████████| 238/238 [00:01<00:00, 222.70it/s]


Metrics for RVnet with data ALL
DSC = 0.889615535736084
HD = 7.903172997901467
surfCorrCoef = 0.9470826461151706
surfMAE = 85.23831
surfbias = -10.04136
surfstd = 170.64822
Number of images exclueded from the HD:  120 out of 476
-----------------------------------
Metrics for RVnet with data ED
DSC = 0.9176207780838013
HD = 7.928793527231753
surfCorrCoef = 0.9648092524702943
surfMAE = 74.46822
surfbias = -25.164127
surfstd = 149.65637
Number of images exclueded from the HD:  40 out of 238
-----------------------------------
Metrics for RVnet with data ES
DSC = 0.8371891975402832
HD = 8.248899106614699
surfCorrCoef = 0.9003344559925017
surfMAE = 96.0084
surfbias = 5.0814075
surfstd = 188.11488
Number of images exclueded from the HD:  80 out of 238
-----------------------------------
