In [13]:
import numpy as np
import torch
import h5py
import nibabel as nib
from pathlib import Path
import sys
import os
from scipy.spatial.distance import cdist 

In [2]:
def get_predictions(data_path, threshold=0.8):
    """
    Loads the HDF5 file and returns the data
    """

    with h5py.File(data_path, "r") as f:
        pred = f["predictions"][:]
    mask_p_np = np.squeeze(pred, axis=0)
    if mask_p_np.shape == (220, 256, 256):
        # reshape ds from z, x, y to x, y, z
        mask_p_np = np.moveaxis(mask_p_np, 0, -1)
    mask_p_np[mask_p_np > threshold] = 1
    mask_p_np[mask_p_np <= threshold] = 0

    return mask_p_np
    

In [3]:
def get_mask(data_path):
    """
    Loads the nifty file and returns the mask of a case
    """

    mask = nib.load(data_path)
    return mask.get_fdata()

In [4]:
pred_path = os.path.join(os.getcwd(), os.pardir, os.pardir , "data", "predictions")
mask_path = os.path.join(os.getcwd(), os.pardir, os.pardir , "data", "training")

In [54]:
def eval_DiceScore(pred, mask):
    """
    Computes the Dice score for a given prediction and mask
    """

    pred = pred.flatten()
    mask = mask.flatten()

    intersection = np.sum(pred * mask)
    union = np.sum(pred) + np.sum(mask)

    return 2 * intersection / union

def jaccard(pred, mask):
    """
    Computes the Jaccard score for a given prediction and mask
    """

    pred = pred.flatten()
    mask = mask.flatten()

    intersection = np.sum(pred * mask)
    union = np.sum(pred) + np.sum(mask) - intersection

    return intersection / union

def hausdorff(pred, mask):
    """
    Computes the Hausrdoff score for a given prediction and mask
    """
    pass


def average_distance(pred, mask):
    pass

def pearson_correlation(pred, mask):
    cov = pred * mask
    pred_std = np.std(pred)
    mask_std = np.std(mask)

    return cov / (pred_std * mask_std)

In [41]:
iteration = 5
case = 'A123' # 'A120' 'A121' 'A123' 'A124' 'A126' 'A127' 'A129'
threshold = 0.9

pred_file = os.path.join(pred_path, "iteration{}".format(iteration), "{}_predictions.h5".format(case))
mask_file = os.path.join(mask_path, "{}_masks.nii.gz".format(case))
pred = get_predictions(pred_file, threshold=threshold)
mask = get_mask(mask_file)
print("Dice: " + str(eval_DiceScore(pred, mask)))
print("Jaccard: " + str(jaccard(pred, mask)))

Dice: 0.303194025107262
Jaccard: 0.17868514703127927


In [55]:
def batch_eval(pred, mask, iteration, cases, threshold_batch):
    pred_path_batch = os.path.join(os.getcwd(), os.pardir, os.pardir , "data", "predictions", "iteration{}".format(iteration))
    masks_path_batch = os.path.join(os.getcwd(), os.pardir, os.pardir , "data", "training")

    dice_scores = []
    jaccard_scores = []

    for case in cases:
        pred_file = os.path.join(pred_path_batch, "{}_predictions.h5".format(case))
        mask_file = os.path.join(masks_path_batch, "{}_masks.nii.gz".format(case))
        pred = get_predictions(pred_file, threshold=threshold_batch)
        mask = get_mask(mask_file)
        dice_scores.append(eval_DiceScore(pred, mask))
        jaccard_scores.append(jaccard(pred, mask))
    
    return np.mean(np.array(dice_scores)), np.mean(np.array(jaccard_scores))

In [59]:
# Eval metrics over batch
iteration_batch = 5
threshold_batch = 0.95
cases = ['A123', 'A120', 'A121', 'A123', 'A124', 'A126', 'A127', 'A129']

for i in range(1, iteration + 1):
    dice, jaccard_score = batch_eval(pred, mask, i, cases, threshold_batch)
    print("Iteration {}: Dice: {}, Jaccard: {}".format(i, dice, jaccard_score))
    print("-----------------------------------------------------")


Iteration 1: Dice: 0.006975304696811936, Jaccard: 0.0035167022942208725
-----------------------------------------------------
Iteration 2: Dice: 0.17296064469651032, Jaccard: 0.1330205578662148
-----------------------------------------------------
Iteration 3: Dice: 0.0, Jaccard: 0.0
-----------------------------------------------------
Iteration 4: Dice: 0.4081701577679366, Jaccard: 0.296544407140541
-----------------------------------------------------
Iteration 5: Dice: 0.5386860848663539, Jaccard: 0.4300732220450385
-----------------------------------------------------
