# This notebook contains the code necessary to evaluate and visualise the predictions of a single model. 
# **Instructions**: Run all cells and enter the number corresponding to your model of choice when prompted.

In [1]:
from tqdm import tqdm
import os
import time
from datetime import datetime
from random import randint

import numpy as np
from scipy import stats
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
from sklearn.model_selection import KFold

import nibabel as nib

import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.animation as anim
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec

import seaborn as sns
from skimage.transform import resize
from skimage.util import montage

from IPython.display import Image as show_gif
from IPython.display import clear_output

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

import re
import warnings
from time import sleep

warnings.simplefilter("ignore")


In [2]:
from utils.Meter import dice_coef_metric_per_classes, jaccard_coef_metric_per_classes

from utils.BratsDataset import BratsDataset

from utils.Meter import BCEDiceLoss

from utils.viz_eval_utils import get_dataloaders, compute_scores_per_classes, count_parameters

from models.UNet3d import UNet3d
from models.UNet3dSingleConv import UNet3dSingleConv
from models.UNet3dDropout import UNet3dDropout
from models.SwinUNETR import SwinUNETR
from models.UNet3d_SELU import UNet3d_SELU
from models.UNet3d_atten import UNet3d_atten
from models.ONet3d import ONet3d
from models.ONet3d_v2 import ONet3d_v2
from models.ONet3d_v3 import ONet3d_v3
from models.ONet3d_v3_DoubleConv import ONet3d_v3_DoubleConv
from models.UNet3d_GELU import UNet3d_GELU
from models.ONet3d_v3_GELU import ONet3d_v3_GELU
from models.SphereNet3d import SphereNet3d

In [3]:
# def collect_images():
#     return list(zip([os.path.join("tr_pediatric_modalities/t1c", image) for image in sorted(os.listdir("tr_pediatric_modalities/t1c"))], [os.path.join("tr_pediatric_modalities/t1n", image) for image in sorted(os.listdir("tr_pediatric_modalities/t1n"))], [os.path.join("tr_pediatric_modalities/t2w", image) for image in sorted(os.listdir("tr_pediatric_modalities/t2w"))], [os.path.join("tr_pediatric_modalities/t2f", image) for image in sorted(os.listdir("tr_pediatric_modalities/t2f"))]))

In [4]:
def collect_images():
    return list(zip([os.path.join("val_pediatric_modalities/t1c", image) for image in sorted(os.listdir("val_pediatric_modalities/t1c"))], [os.path.join("val_pediatric_modalities/t1n", image) for image in sorted(os.listdir("val_pediatric_modalities/t1n"))], [os.path.join("val_pediatric_modalities/t2w", image) for image in sorted(os.listdir("val_pediatric_modalities/t2w"))], [os.path.join("val_pediatric_modalities/t2f", image) for image in sorted(os.listdir("val_pediatric_modalities/t2f"))]))

In [5]:
import pandas as pd
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset
import albumentations as A
from albumentations import Compose
import os

class OwnDataset(Dataset):
    def __init__(self, do_resizing = False):
        self.images = collect_images()
        self.do_resizing = do_resizing

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        images = []
        paths = self.images[idx]
        for i, path in enumerate(paths):
            img = self.load_img(path)
            
            if self.do_resizing:
                img = self.resize(img)
            
            img = self.normalize(img)
            images.append(img)
        
        img = np.stack(images)
        
        img = np.moveaxis(img, (0, 1, 2, 3), (0, 3, 2, 1))
        
        return torch.from_numpy(img[None, ...])

    def load_img(self, file_path):
        data = nib.load(file_path)
        data = np.asarray(data.dataobj)
        return data

    def normalize(self, data: np.ndarray):
        data_min = np.min(data)
        return (data - data_min) / (np.max(data) - data_min)
    
    def resize(self, data: np.ndarray):
        # data = resize(data, (224, 224, 128), preserve_range=True)
        data = self.crop_3d_array(data, (224, 224, 128))
        return data

    def crop_3d_array(self, arr, crop_shape):
        """
        Crop a 3D array to the specified shape.
        
        Parameters
        ----------
        arr : numpy.ndarray
            The 3D input array to be cropped.
        crop_shape : tuple
            The shape of the cropped array. Must be a 3-element tuple (depth, height, width).
            
        Returns
        -------
        numpy.ndarray
            The cropped array.
        """

        assert len(crop_shape) == 3, "crop_shape must be a 3-element tuple"
        assert crop_shape[0] <= arr.shape[0], "depth of crop_shape must be <= depth of arr"
        assert crop_shape[1] <= arr.shape[1], "height of crop_shape must be <= height of arr"
        assert crop_shape[2] <= arr.shape[2], "width of crop_shape must be <= width of arr"

        depth_diff = arr.shape[0] - crop_shape[0]
        height_diff = arr.shape[1] - crop_shape[1]
        width_diff = arr.shape[2] - crop_shape[2]

        if depth_diff % 2 == 0:
            depth_crop_start = depth_diff // 2
            depth_crop_end = arr.shape[0] - (depth_diff // 2)
        else:
            depth_crop_start = depth_diff // 2
            depth_crop_end = arr.shape[0] - (depth_diff // 2) - 1

        if height_diff % 2 == 0:
            height_crop_start = height_diff // 2
            height_crop_end = arr.shape[1] - (height_diff // 2)
        else:
            height_crop_start = height_diff // 2
            height_crop_end = arr.shape[1] - (height_diff // 2) - 1

        if width_diff % 2 == 0:
            width_crop_start = width_diff // 2
            width_crop_end = arr.shape[2] - (width_diff // 2)
        else:
            width_crop_start = width_diff // 2
            width_crop_end = arr.shape[2] - (width_diff // 2) - 1

        cropped_arr = arr[depth_crop_start:depth_crop_end,
                          height_crop_start:height_crop_end, width_crop_start:width_crop_end]

        return cropped_arr

In [6]:
# def get_prediction_from_logits(pred):
    
#     # Create an empty mask with the same size as the prediction
#     mask = np.zeros_like(pred[0], dtype=np.uint8)
    
#     # Print the shape of pred
#     print(pred.shape)

#     # Count the number of True values for each class
#     print("Number of True values in class 0:", np.count_nonzero(pred[0]))
#     print("Number of True values in class 1:", np.count_nonzero(pred[1]))
#     print("Number of True values in class 2:", np.count_nonzero(pred[2]))

#     mask[(pred[0] == True) | (pred[1] == True) | (pred[2] == True)] = 1
#     print("class 1 mask")
#     print(np.count_nonzero(mask))

#     # Tumor Core (TC) is where either the 1st or 3rd prediction channels are 1 (or both)
#     mask[(pred[0] == True) | (pred[2] == True)] = 2
#     print(np.count_nonzero(mask))
#     # Enhancing Tumor (ET) is where the 3rd prediction channel is 1
#     mask[pred[2] == True] = 3
    
#     print("Unique values in the resulting mask:", np.unique(mask))

    
#     return mask


In [247]:
from scipy.ndimage import median_filter, binary_fill_holes, binary_dilation
from skimage.measure import label, regionprops
import cc3d
from scipy.ndimage.morphology import binary_dilation, binary_fill_holes, binary_closing
# def postprocess_prediction(pred):
#     # Step 1: Noise Removal
#     pred = median_filter(pred, size=3)
    
#     # Step 2: Dilation
#     pred = binary_dilation(pred, structure=np.ones((3, 3, 3)))
#     pred = cc3d.connected_components(pred)
    
#     # Step 3: Fill in the small holes
#     pred = binary_fill_holes(pred)
    
#     # Step 4: Separate Merged Lesion and threshold based on voxel count
#     pred_labels = label(pred)
#     num_labels = len(np.unique(pred_labels)) - 1  # Subtract 1 for background label
#     for region in regionprops(pred_labels):
#         if region.area < 50:
#             for coords in region.coords:
#                 pred_labels[coords[0], coords[1], coords[2]] = 0
                
#     return pred_labels



# def postprocess_prediction(pred, dilation_size=3, min_size=50):
#     # Dilation
#     pred = binary_dilation(pred, structure=np.ones((dilation_size, dilation_size, dilation_size)))

#     # Connected components
#     pred = cc3d.connected_components(pred)

#     # Size filtering
#     unique, counts = np.unique(pred, return_counts=True)
#     for u, c in zip(unique, counts):
#         if c < min_size:
#             pred[pred == u] = 0

#     # Closing
#     # pred = binary_closing(pred, structure=np.ones((dilation_size, dilation_size, dilation_size)))

#     # Hole filling
#     # pred = binary_fill_holes(pred)

#     return pred



In [248]:
def get_prediction_from_logits(pred):
    # Create an empty mask with the same size as the prediction
    mask = np.zeros_like(pred[0], dtype=np.uint8)
    # print(pred.shape)
    # print(np.count_nonzero(pred))
    # print(type(pred))
    
#     for i in range(pred.shape[0]):
#         pred[i] = postprocess_prediction(pred[i])
    
    
    # Enforce class hierarchy: if a voxel is ET, it's also TC and WT
    # pred[1] = np.logical_or(pred[1], pred[2]) 
    # pred[0] = np.logical_or(pred[0], pred[1]) 

    # Assign the value 2 to the voxels that are in pred[0] but not in pred[1] and pred[2]
    mask[(pred[0] == True) & (pred[1] == False) & (pred[2] == False)] = 2

    # Assign the value 1 to the voxels that are in pred[0] and pred[1] but not in pred[2]
    mask[(pred[0] == True) & (pred[1] == True) & (pred[2] == False)] = 1

    # Assign the value 3 to the voxels that are in pred[2]
    mask[(pred[2] == True)] = 3
    
    
#     print("class 1 mask")
#     print(np.count_nonzero(mask))

#     # Tumor Core (TC) is where either the 1st or 3rd prediction channels are 1 (or both)
    
#     print(np.count_nonzero(mask))
#     # Enhancing Tumor (ET) is where the 3rd prediction channel is 1
    
    print("Unique values in the resulting mask:", np.unique(mask))
    print("count of 1's in pred[0]", np.count_nonzero(pred[0]))
    print("count of 1's in pred[1]", np.count_nonzero(pred[1]))
    print("count of 1's in pred[2]", np.count_nonzero(pred[2]))
    return mask


In [249]:
# import numpy as np
# from scipy.ndimage import binary_dilation, binary_fill_holes
# from skimage.measure import label, regionprops

# import numpy as np
# from scipy.ndimage import binary_dilation, binary_erosion
# from skimage.measure import label, regionprops
# from skimage.morphology import closing, reconstruction

# def post_process_ensemble_prediction(pred, dilation_size=3, min_voxel_volume=50, max_hole_volume=10):
#     # Get unique labels in the prediction mask
#     unique_labels = np.unique(pred)

#     # Initiate an array to hold the post-processed prediction
#     pred_processed = np.zeros_like(pred)

#     for lbl in unique_labels:
#         if lbl == 0:
#             continue  # Skip background

#         # Create a mask for the current label
#         label_mask = (pred == lbl)

#         # Apply size filtering based on voxel volume
#         labeled_mask = label(label_mask)
#         regions = regionprops(labeled_mask)
#         for region in regions:
#             if region.area < min_voxel_volume:  # remove regions smaller than the min_voxel_volume
#                 for coordinates in region.coords:
#                     labeled_mask[coordinates[0], coordinates[1], coordinates[2]] = 0

#         # Apply dilation on the filtered mask
#         # dilated_mask = binary_dilation(labeled_mask, structure=np.ones((dilation_size, dilation_size, dilation_size)))

#         # If current label is 1, perform hole filling
#         if lbl == 1:
#             # Perform closing to fill small holes
#             closed_mask = closing(dilated_mask, selem=np.ones((dilation_size, dilation_size, dilation_size)))

#             # Perform morphological reconstruction to fill all the small holes
#             seed = np.copy(closed_mask)
#             seed[1:-1, 1:-1, 1:-1] = closed_mask.max()
#             reconstructed_mask = reconstruction(seed, closed_mask, method='erosion')

#             # The resulting 'reconstructed_mask' is a float image, due to the reconstruction process. Let's convert it to binary
#             final_mask = (reconstructed_mask > 0.5).astype(int)

#             dilated_mask = final_mask

#         # Add the processed mask for the current label to the final prediction
#         pred_processed[dilated_mask > 0] = lbl

#     return pred_processed








In [308]:
import numpy as np
from scipy.ndimage import binary_fill_holes
from skimage.measure import label, regionprops
from skimage.morphology import closing, opening, reconstruction, dilation

def post_process_ensemble_prediction(pred, closing_size=5, opening_size=3, min_voxel_volume=50, dilation_size=3):
    # Get unique labels in the prediction mask
    unique_labels = np.unique(pred)

    # Initiate an array to hold the post-processed prediction
    pred_processed = np.zeros_like(pred)

    for lbl in unique_labels:
        if lbl == 0:
            continue  # Skip background

        # Create a mask for the current label
        label_mask = (pred == lbl)

        # Apply size filtering based on voxel volume
        labeled_mask = label(label_mask)
        regions = regionprops(labeled_mask)
        for region in regions:
            if region.area < min_voxel_volume:  # remove regions smaller than the min_voxel_volume
                for coordinates in region.coords:
                    labeled_mask[coordinates[0], coordinates[1], coordinates[2]] = 0
                    # If the region is small, apply dilation

        # If current label is 1, perform hole filling
        if lbl == 1:
            # Perform closing to fill small holes
            labeled_mask = closing(labeled_mask, selem=np.ones((closing_size, closing_size, closing_size)))

            # Perform morphological reconstruction to fill all the small holes
        seed = np.copy(labeled_mask)
        seed[1:-1, 1:-1, 1:-1] = labeled_mask.max()
        reconstructed_mask = reconstruction(seed, labeled_mask, method='erosion')

        # The resulting 'reconstructed_mask' is a float image, due to the reconstruction process. Let's convert it to binary
        final_mask = (reconstructed_mask > 0.5).astype(int)

        # Replace labeled_mask with the final mask (which has filled holes) for label 1
        labeled_mask = final_mask

        # Add the processed mask for the current label to the final prediction
        pred_processed[labeled_mask > 0] = lbl

    return pred_processed


In [309]:
def generate_prediction(model1, model2, model3, model4, model5, model_ET, model_TC, model_WT):
    dataset = OwnDataset()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    models = [ model1, model2, model3,model4,model5]

    with torch.no_grad():
        for i in range(len(dataset)):
            data = dataset[i]
            path = dataset.images[i][0]
            data = data.to(device)

            # List to hold all logits for each model
            logits_list = []

            for model in models:
                logits = model(data)
                logits = logits.detach().cpu().numpy()
                logits = logits.squeeze(0)
                logits = np.moveaxis(logits,  (0, 3, 2, 1),(0, 1, 2, 3))
                logits_list.append(logits)

                           # Get averaged logits for each channel
            average_logits = np.zeros_like(logits_list[0])
            # for channel in range(logits_list[0].shape[0]):
            #     summed_logits = np.sum(np.array([logit[channel] for logit in logits_list]), axis=0)
            #     average_logits[channel] = summed_logits / len(models)  # Calculate the average
            # print(average_logits)

            logits_ET = model_ET(data).detach().cpu().numpy()
            logits_ET = logits_ET.squeeze(0)
            logits_ET = np.moveaxis(logits_ET,  (0, 3, 2, 1),(0, 1, 2, 3))

            logits_TC = model_TC(data).detach().cpu().numpy()
            logits_TC = logits_TC.squeeze(0)
            logits_TC = np.moveaxis(logits_TC,  (0, 3, 2, 1),(0, 1, 2, 3))

            logits_WT = model_WT(data).detach().cpu().numpy()
            logits_WT = logits_WT.squeeze(0)
            logits_WT = np.moveaxis(logits_WT,  (0, 3, 2, 1),(0, 1, 2, 3))

            # Threshold the individual model predictions
            logits_ET = (logits_ET > 0.53).astype(int)
            logits_TC = (logits_TC > 0.53).astype(int)
            logits_WT = (logits_WT > 0.53).astype(int)

            # majority_vote = np.zeros_like(average_logits[0])

            for channel in range(logits_list[0].shape[0]):
                summed_logits = np.sum(np.array([logit[channel] for logit in logits_list]), axis=0)
                average_logits[channel] = np.where(summed_logits > 4, 1, 0)
                

                # Threshold the averaged logits
            # average_logits[0] = (average_logits[0] > 0.53).astype(int)
            # average_logits[1] = (average_logits[1] > 0.53).astype(int)
            # average_logits[2] = (average_logits[2] > 0.53).astype(int)


            # print(np.unique(logits_ET),np.unique(logits_WT[0]))

            # Apply condition for ET and WT
            # average_logits[0] = np.where((average_logits[0] == 1) & (logits_WT == 1), 1, 0)
            # average_logits[1] = np.where((average_logits[1] == 1) & (logits_TC == 1), 1, 0)
            # average_logits[2] = np.where((average_logits[2]==1) | (logits_ET == 1), 1, 0)
                
                
                
            predictions = get_prediction_from_logits(average_logits)
            # print(predictions.shape)
            
            

            # predictions = postprocess_prediction(average_logits)
            # print(np.unique(predictions))

            # Apply post-processing techniques
            predictions = post_process_ensemble_prediction(predictions)
            # predictions = postprocess_prediction(predictions)

            
            
            brain_3D_nib = nib.load(path)

            pred_mask_3D_nib = nib.Nifti1Image(predictions, affine = brain_3D_nib.affine, header = brain_3D_nib.header)
            nib.save(pred_mask_3D_nib, os.path.join("outputs_folder/combine_8_models_d1", os.path.basename(path)))


In [310]:
# def generate_prediction(model):
#     dataset = OwnDataset()
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'

#     with torch.no_grad():
#         for i in range(len(dataset)):
#             data = dataset[i]
#             path = dataset.images[i][0]
#             data = data.to(device)
#             print(data.shape)
#             logits = model(data)
#             print(logits.shape)
#             logits = logits.detach().cpu().numpy() > 0.50
#             print("logits shape")
                  
#             # print(logits.shape)
            
#             logits = logits.squeeze(0)
#             logits = np.moveaxis(logits,  (0, 3, 2, 1),(0, 1, 2, 3))
            
#             print(logits.shape)
#             # predictions = np.moveaxis(predictions, (0, 1, 2), (2, 1, 0))
#             predictions = get_prediction_from_logits(logits)
#             print(predictions.shape)
            
#             print(np.unique(predictions))
            
#             brain_3D_nib = nib.load(path)

#             pred_mask_3D_nib = nib.Nifti1Image(predictions, affine = brain_3D_nib.affine, header = brain_3D_nib.header)
#             nib.save(pred_mask_3D_nib, os.path.join("outputs_folder/output_ONet_post", os.path.basename(path)))
            
#             pred_mask_3D_nib = nib.Nifti1Image(predictions[1], affine = brain_3D_nib.affine, header = brain_3D_nib.header)
#             nib.save(pred_mask_3D_nib, "2new__" + os.path.basename(path))
            
#             pred_mask_3D_nib = nib.Nifti1Image(predictions[2], affine = brain_3D_nib.affine, header = brain_3D_nib.header)
#             nib.save(pred_mask_3D_nib, "3new__" + os.path.basename(path))

In [311]:
# def generate_prediction(model1, model2):
#     dataset = OwnDataset()
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     models = [model1, model2]

#     with torch.no_grad():
#         for i in range(len(dataset)):
#             data = dataset[i]
#             path = dataset.images[i][0]
#             data = data.to(device)

#             # List to hold all logits for each model
#             logits_list = []

#             for model in models:
#                 logits = model(data)
#                 logits = logits.detach().cpu().numpy() > 0.53
#                 logits = logits.squeeze(0)
#                 logits = np.moveaxis(logits,  (0, 3, 2, 1),(0, 1, 2, 3))
#                 logits_list.append(logits)

#             # Perform majority voting for each channel
#             majority_vote = np.zeros_like(logits_list[0])

#             for channel in range(logits_list[0].shape[0]):
#                 summed_logits = np.sum(np.array([logit[channel] for logit in logits_list]), axis=0)
#                 majority_vote[channel] = np.where(summed_logits > len(models) / 2, 1, 0)

#             predictions = get_prediction_from_logits(majority_vote)

#             print(np.unique(predictions))

#             brain_3D_nib = nib.load(path)

#             pred_mask_3D_nib = nib.Nifti1Image(predictions, affine = brain_3D_nib.affine, header = brain_3D_nib.header)
#             nib.save(pred_mask_3D_nib, os.path.join("outputs_folder/combine_2_final_tr_Unet_Onet_del", os.path.basename(path)))


## Create a dictionary which maps the model names/directories in the Logs folder to their respective model initialisations.

In [312]:
modelDict = {
    "3DOnet_DoubleConv_Kernel1": ONet3d_v3_DoubleConv(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DOnet_SingleConv_Kernel1": ONet3d_v3(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DOnet_SingleConv_Kernel1_32_Channels": ONet3d_v3(in_channels=4, n_classes=3, n_channels=32).to('cuda'),
    "3DOnet_SingleConv_Kernel1_GELU": ONet3d_v3_GELU(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DOnet_SingleConv_Kernel1_GELU_AdamW": ONet3d_v3_GELU(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DOnet_SingleConv_Kernel3": ONet3d_v2(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DOnet_SingleConv_Kernel5": ONet3d(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DUnet": UNet3d(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DUnet_32_Channels": UNet3d(in_channels=4, n_classes=3, n_channels=32).to('cuda'),
    "3DUnet_Atten": UNet3d_atten(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DUnet_Dropout": UNet3dDropout(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DUnet_GELU": UNet3d_GELU(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DUnet_SELU": UNet3d_SELU(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "3DUnet_SingleConv": UNet3dSingleConv(in_channels=4, n_classes=3, n_channels=24).to('cuda'),
    "SphereNet3D": SphereNet3d(in_channels=4, n_classes=3, n_channels=16).to('cuda'),
    "SwinUNETR": SwinUNETR(in_channels=4, out_channels=3, img_size=(128, 224, 224), depths=(1, 1, 1, 1), num_heads=(2, 4, 8, 16)).to('cuda'),
    "SwinUNETR_AdamW": SwinUNETR(in_channels=4, out_channels=3, img_size=(128, 224, 224), depths=(1, 1, 1, 1), num_heads=(2, 4, 8, 16)).to('cuda'),
    "SwinUNETR_DoubleLayerDepth": SwinUNETR(in_channels=4, out_channels=3, img_size=(128, 224, 224), depths=(2, 2, 2, 2), num_heads=(2, 4, 8, 16)).to('cuda'),
}


In [313]:
df1 = pd.read_csv("data.csv")
df2 = pd.read_csv("test_data.csv")
combined_df = pd.concat([df1, df2], axis=0)

# Write the combined DataFrame to a new CSV file
combined_df.to_csv('combined_file.csv', index=False)

## Define function which takes in a numerical user input and outputs a string representation of the chosen model.

In [314]:
def chooseModel():
    availableActions = {str(i+1): k for (i, k)
                        in zip(range(len(modelDict)), modelDict.keys())}
    nl = '\n'
    # Takes in a dictionary with key/value pair corresponding with control/action
    availableActionsList = [(key, val)
                            for key, val in availableActions.items()]
    print(f"Use number keys to choose one of the models below: \n")
    print(
        f"Available Models: {nl.join(f'[{tup[0]}: {tup[1]}]' for tup in availableActionsList)}")
    sleep(1)
    while True:
        userInput = input("Enter your action: ")
        if userInput not in availableActions:
            print(
                f"{userInput} is an invalid action. Please try again.")
        else:
            break
    return availableActions[userInput]


# chooseModel()


#  Run this cell to Evaluate and visualise predictions for a single model. Model is selected by inputting a number when prompted.

In [315]:
def get_dataloaders(
    dataset: torch.utils.data.Dataset,
    path_to_csv: str,
    # phase: str,
    val_fold: int = 1,  # Choose which fold to be the validation fold
    test_fold: int = 0,
    batch_size: int = 1,
    num_workers: int = 4,
    do_resizing: bool = True,
):
    assert (val_fold != test_fold)

    df = pd.read_csv("combined_file.csv")

    '''Returns: dataloader for the model training'''
    # Data in folds other than 0 are used for training
    train_df = df.loc[~df['fold'].isin(
        [val_fold, test_fold])].reset_index(drop=True)
    # Data in fold 0 is used for validation
    val_df = df.loc[df['fold'] == val_fold].reset_index(drop=True)
    test_df = df.loc[df['fold'] == test_fold].reset_index(drop=True)

    # dataset = dataset(df, phase)
    train_dataset = dataset(train_df, "train", do_resizing=do_resizing)
    val_dataset = dataset(val_df, "val", do_resizing=do_resizing)
    test_dataset = dataset(test_df, "test", do_resizing=do_resizing)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=True,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=True,
    )
    return train_dataloader, val_dataloader, test_dataloader


In [316]:
from torch.utils.data import Dataset
import albumentations as A
from albumentations import Compose
import os


class BratsDataset(Dataset):
    def __init__(self, df: pd.read_csv("combined_file.csv"), phase: str = "test", do_resizing: bool = False):
        # Dataframe containing patient, path and fold mapping information
        self.df = df
        # "train" "valid" or "test". Determines whether to apply preprocessing
        self.phase = phase
        self.augmentations = self.get_augmentations(phase)
        self.data_types = ['-t1c.nii.gz', '-t1n.nii.gz', '-t2w.nii.gz', '-t2f.nii.gz']
        self.do_resizing = do_resizing

    def __len__(self):
        return self.df.shape[0]

    # Makes class accessible by square-bracket notations; determines behaviour upon square-bracket access
    def __getitem__(self, idx):
        id_ = self.df.loc[idx, 'Brats20ID']
        root_path = self.df.loc[self.df['Brats20ID'] == id_]['path'].values[0]
        images = []
        for data_type in self.data_types:
            img_path = os.path.join(root_path, 'BraTS-PED-' + id_.split('_')[-1].zfill(5) + '-000' + data_type)
            img = self.load_img(img_path)
            if self.do_resizing:
                img = self.resize(img)
            img = self.normalize(img)
            images.append(img)
        img = np.stack(images)
        img = np.moveaxis(img, (0, 1, 2, 3), (0, 3, 2, 1))
        if self.phase == 'test':
            return {
                "Id": id_,
                "image": img
            }
    # TO-DO: Implement possible augmentations here? Lower priority for now

    def get_augmentations(self, phase):
        list_transforms = []
        list_trfms = Compose(list_transforms,is_check_shapes=False )
        return list_trfms

    def load_img(self, file_path):
        data = nib.load(file_path)
        data = np.asarray(data.dataobj)
        return data

    def normalize(self, data: np.ndarray):
        data_min = np.min(data)
        return (data - data_min) / (np.max(data) - data_min)

    def resize(self, data: np.ndarray):
        # data = resize(data, (224, 224, 128), preserve_range=True)
        data = self.crop_3d_array(data, (240, 240, 155))
        return data

    def preprocess_mask_labels(self, mask: np.ndarray):
        # In the BraTS challenge, the segmentation performance is evaluated on three partially overlapping sub-regions of tumors,
        # namely, whole tumor (WT), tumor core (TC), and enhancing tumor (ET).
        # The WT is the union of ED, NCR/NET, and ET, while the TC includes NCR/NET and ET.

        mask_WT = mask.copy()
        mask_WT[mask_WT == 1] = 1
        mask_WT[mask_WT == 2] = 1
        mask_WT[mask_WT == 3] = 1

        mask_TC = mask.copy()
        mask_TC[mask_TC == 1] = 1
        mask_TC[mask_TC == 2] = 0
        mask_TC[mask_TC == 3] = 1

        mask_ET = mask.copy()
        mask_ET[mask_ET == 1] = 0
        mask_ET[mask_ET == 2] = 0
        mask_ET[mask_ET == 3] = 1

        mask = np.stack([mask_WT, mask_TC, mask_ET])
        mask = np.moveaxis(mask, (0, 1, 2, 3), (0, 3, 2, 1))

        return mask

    def crop_3d_array(self, arr, crop_shape):
        """
        Crop a 3D array to the specified shape.
        
        Parameters
        ----------
        arr : numpy.ndarray
            The 3D input array to be cropped.
        crop_shape : tuple
            The shape of the cropped array. Must be a 3-element tuple (depth, height, width).
            
        Returns
        -------
        numpy.ndarray
            The cropped array.
        """

        assert len(crop_shape) == 3, "crop_shape must be a 3-element tuple"
        assert crop_shape[0] <= arr.shape[0], "depth of crop_shape must be <= depth of arr"
        assert crop_shape[1] <= arr.shape[1], "height of crop_shape must be <= height of arr"
        assert crop_shape[2] <= arr.shape[2], "width of crop_shape must be <= width of arr"

        depth_diff = arr.shape[0] - crop_shape[0]
        height_diff = arr.shape[1] - crop_shape[1]
        width_diff = arr.shape[2] - crop_shape[2]

        if depth_diff % 2 == 0:
            depth_crop_start = depth_diff // 2
            depth_crop_end = arr.shape[0] - (depth_diff // 2)
        else:
            depth_crop_start = depth_diff // 2
            depth_crop_end = arr.shape[0] - (depth_diff // 2) - 1

        if height_diff % 2 == 0:
            height_crop_start = height_diff // 2
            height_crop_end = arr.shape[1] - (height_diff // 2)
        else:
            height_crop_start = height_diff // 2
            height_crop_end = arr.shape[1] - (height_diff // 2) - 1

        if width_diff % 2 == 0:
            width_crop_start = width_diff // 2
            width_crop_end = arr.shape[2] - (width_diff // 2)
        else:
            width_crop_start = width_diff // 2
            width_crop_end = arr.shape[2] - (width_diff // 2) - 1

        cropped_arr = arr[depth_crop_start:depth_crop_end,
                          height_crop_start:height_crop_end, width_crop_start:width_crop_end]

        return cropped_arr

In [317]:
def evalSingle():
    # model_name = chooseModel()
    # print(f"'{model_name}' selected for evaluation and visualisation")

    model_name = "3DUnet_GELU"
    # model = ONet3d_v3_GELU(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    # model2 = SphereNet3d(in_channels=4, n_classes=3, n_channels=24).to('cuda')
    
    
    
    
    
    
    # model1 = ONet3d_v3_GELU(in_channels=4, n_classes=3, n_channels=64).to('cuda')
    # model2 = ONet3d_v3_GELU(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    # model3 = UNet3d_GELU(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    # model4 = UNet3d_GELU(in_channels=4, n_classes=3, n_channels=64).to('cuda')
    
    model_1 = ONet3d_v3_GELU(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    model_2 = UNet3d_GELU(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    model_3 = ONet3d_v2(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    model_4 = ONet3d_v3_DoubleConv(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    model_5 = ONet3d(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    
    
    
    model_ET = ONet3d_v3_GELU(in_channels=4, n_classes=1, n_channels=32).to('cuda')
    model_WT = ONet3d_v3_GELU(in_channels=4, n_classes=1, n_channels=32).to('cuda')    
    model_TC = ONet3d_v3_GELU(in_channels=4, n_classes=1, n_channels=32).to('cuda')    
    
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint_path = None
    
    
    # checkpoint_path = "Logs/3DUnet_your_modifications/your_last_epoch_model_20230726-122634.pth"
    # checkpoint_path_2 = "Logs/3DUnet_your_modifications/your_best_model_20230727-032820_96_Sphere_24c.pth  
    

    checkpoint_path_1 = "Logs/ONet_32_31_07/your_best_model_20230801-144818_199_Onet_d31_c32.pth"
    checkpoint_path_2  = "Logs/UNet_32_31_07/your_best_model_20230801-114142_199_UNet_d31_c32.pth"
    checkpoint_path_3 = "Logs/ONet3d_v2_32_31_07/your_best_model_20230801-040449_199_Onet_v2_d31_c32.pth"
    checkpoint_path_4 = "Logs/ONet_32_DoubleConv_31_07/your_best_model_20230801-045347_199_Onet_32_DoubleConv_d31_c32.pth"
    checkpoint_path_5 = "Logs/only_ONet_32_31_07/your_best_model_20230801-071616_199_only_Onet_d31_c32.pth"
    
    checkpoint_path_ET = "Logs/ONet_32_31_07_ET/your_best_model_20230801-024903_199_Onet_d31_c32_ET.pth"
    checkpoint_path_TC = "Logs/ONet_32_31_07_TC/your_best_model_20230801-150513_155_Onet_d31_c32_TC.pth"
    checkpoint_path_WT = "Logs/3DUnet_your_modifications/your_last_epoch_model_20230728-020546.pth"
       
      # break
        
# work/pi_gschlaug_umass_edu/Shashi_files/3D_exp_folder/3D_Brain_Tumor_Seg_V2-master/Logs/3DUnet_your_modifications/your_best_model_20230719-221836.pth
    # checkpoint_path = "Logs/3DUnet_your_modifications/your_best_model_20230719-155138.pth"

    try:
        # model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        # model.eval()
        model_1.load_state_dict(torch.load(checkpoint_path_1, map_location='cpu'))
        model_1.eval()
        model_2.load_state_dict(torch.load(checkpoint_path_2, map_location='cpu'))
        model_2.eval()
        model_3.load_state_dict(torch.load(checkpoint_path_3, map_location='cpu'))
        model_3.eval()
        model_4.load_state_dict(torch.load(checkpoint_path_4, map_location='cpu'))
        model_4.eval()
        model_5.load_state_dict(torch.load(checkpoint_path_5, map_location='cpu'))
        model_5.eval()
        model_ET.load_state_dict(torch.load(checkpoint_path_ET, map_location='cpu'))
        model_ET.eval()
        model_TC.load_state_dict(torch.load(checkpoint_path_TC, map_location='cpu'))
        model_TC.eval()
        model_WT.load_state_dict(torch.load(checkpoint_path_WT, map_location='cpu'))
        model_WT.eval()
        
        
        # print(f"{model_name} loaded with chkpt: {checkpoint_path}. parameters: {count_parameters(model)}")
    except Exception as e:
        print(
            f"Error loading {model_name} with chkpt: {checkpoint_path}. parameters: {count_parameters(model)}")
        print(e)
        return

    generate_prediction(model_1, model_2, model_3, model_4, model_5, model_ET, model_TC, model_WT)

In [318]:
evalSingle()

Unique values in the resulting mask: [0 1 2 3]
count of 1's in pred[0] 182720
count of 1's in pred[1] 105687
count of 1's in pred[2] 42949
Unique values in the resulting mask: [0 1 2 3]
count of 1's in pred[0] 11199
count of 1's in pred[1] 8194
count of 1's in pred[2] 59
Unique values in the resulting mask: [0 1 2 3]
count of 1's in pred[0] 79980
count of 1's in pred[1] 78168
count of 1's in pred[2] 22893
Unique values in the resulting mask: [0 1 2 3]
count of 1's in pred[0] 50967
count of 1's in pred[1] 49705
count of 1's in pred[2] 496
Unique values in the resulting mask: [0 1 2 3]
count of 1's in pred[0] 220381
count of 1's in pred[1] 134849
count of 1's in pred[2] 3411
Unique values in the resulting mask: [0 1 2 3]
count of 1's in pred[0] 48886
count of 1's in pred[1] 48673
count of 1's in pred[2] 993
Unique values in the resulting mask: [0 1 2]
count of 1's in pred[0] 23868
count of 1's in pred[1] 23743
count of 1's in pred[2] 0
Unique values in the resulting mask: [0 1 2 3]
count

### Display accuracy for all test samples.

In [121]:
import os
import shutil

# Specify the directory path
dir_path = "tr_pediatric_modalities/t2w/.ipynb_checkpoints"

# Check if the directory exists before trying to remove it
if os.path.exists(dir_path):
    # Remove the directory and all its contents
    shutil.rmtree(dir_path)
    print(f"Directory {dir_path} has been removed successfully")
else:
    print(f"No such directory: {dir_path}")


Directory tr_pediatric_modalities/t2w/.ipynb_checkpoints has been removed successfully


In [None]:
val_metics_df


Unnamed: 0,Ids,WT dice,WT jaccard,TC dice,TC jaccard,ET dice,ET jaccard
0,BraTS20_Training_003,0.880806,0.787001,0.850592,0.740025,0.858111,0.751483
1,BraTS20_Training_004,0.83025,0.709766,0.860727,0.755505,0.531713,0.362131


In [None]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import nibabel as nib

# Assuming that the new data is a DataFrame stored in a CSV file
new_data_path = "test_data.csv"
new_data = pd.read_csv(new_data_path)

# Initialize dataset with new data
new_dataset = BratsDataset(df=new_data, phase="test", do_resizing=False)

# Initialize a DataLoader with the new dataset
# Note: set batch_size according to your memory capacity
new_data_loader = DataLoader(new_dataset, batch_size=1, shuffle=False)

# Define the classes
classes = ["WT", "TC", "ET"]
model = ONet3d_v3_GELU(in_channels=4, n_classes=3, n_channels=32).to('cuda')

# Initialize the model
checkpoint_path = "Logs/3DUnet_your_modifications/your_best_model_20230719-155138.pth"

model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) # replace with your actual model

# Load the model weights
# model_path = "path_to_your_saved_model.pth"
# model.load_state_dict(torch.load(model_path))
# model = model.to(trainer.device)  # Ensure the model is in the same device as your data
model.eval()  # Set the model to evaluation mode

# Now you can use the model to make predictions on new data
with torch.no_grad():
    for i, data in enumerate(new_data_loader):
        imgs = data['image'].to(trainer.device)
        logits = model(imgs)

        # Convert logits to segmented masks
        pred_masks = np.argmax(logits.cpu().numpy(), axis=1)  # change axis accordingly

        # Save segmented mask as .nii file
        # Assuming you have a corresponding affine for each image
        affine = data['affine']
        pred_img = nib.Nifti1Image(pred_masks, affine)
        nib.save(pred_img, f'prediction_{i}.nii.gz')


FileNotFoundError: No such file or no access: 'ASNR-MICCAI-BraTS2023-PED-Challenge-ValidationData/BraTS-PED-00030-000/BraTS-PED-00030-000-seg.nii.gz'

In [None]:
from torch.utils.data import Dataset
import albumentations as A
from albumentations import Compose
import os


class BratsDataset(Dataset):
    def __init__(self, df: pd.read_csv("data.csv"), phase: str = "test", do_resizing: bool = False):
        # Dataframe containing patient, path and fold mapping information
        self.df = df
        # "train" "valid" or "test". Determines whether to apply preprocessing
        self.phase = phase
        self.augmentations = self.get_augmentations(phase)
        self.data_types = ['-t1c.nii.gz', '-t1n.nii.gz', '-t2w.nii.gz', '-t2f.nii.gz']
        self.do_resizing = do_resizing

    def __len__(self):
        return self.df.shape[0]

    # Makes class accessible by square-bracket notations; determines behaviour upon square-bracket access
    def __getitem__(self, idx):
        id_ = self.df.loc[idx, 'Brats20ID']
        root_path = self.df.loc[self.df['Brats20ID'] == id_]['path'].values[0]
        images = []
        for data_type in self.data_types:
            img_path = os.path.join(root_path, 'BraTS-PED-' + id_.split('_')[-1].zfill(5) + '-000' + data_type)
            img = self.load_img(img_path)
            if self.do_resizing:
                img = self.resize(img)
            img = self.normalize(img)
            images.append(img)
        img = np.stack(images)
        img = np.moveaxis(img, (0, 1, 2, 3), (0, 3, 2, 1))
        if self.phase == 'test':
            return {
                "Id": id_,
                "image": img
            }
    # TO-DO: Implement possible augmentations here? Lower priority for now

    def get_augmentations(self, phase):
        list_transforms = []
        list_trfms = Compose(list_transforms,is_check_shapes=False )
        return list_trfms

    def load_img(self, file_path):
        data = nib.load(file_path)
        data = np.asarray(data.dataobj)
        return data

    def normalize(self, data: np.ndarray):
        data_min = np.min(data)
        return (data - data_min) / (np.max(data) - data_min)

    def resize(self, data: np.ndarray):
        # data = resize(data, (224, 224, 128), preserve_range=True)
        data = self.crop_3d_array(data, (240, 240, 155))
        return data

    def preprocess_mask_labels(self, mask: np.ndarray):
        # In the BraTS challenge, the segmentation performance is evaluated on three partially overlapping sub-regions of tumors,
        # namely, whole tumor (WT), tumor core (TC), and enhancing tumor (ET).
        # The WT is the union of ED, NCR/NET, and ET, while the TC includes NCR/NET and ET.

        mask_WT = mask.copy()
        mask_WT[mask_WT == 1] = 1
        mask_WT[mask_WT == 2] = 1
        mask_WT[mask_WT == 3] = 1

        mask_TC = mask.copy()
        mask_TC[mask_TC == 1] = 1
        mask_TC[mask_TC == 2] = 0
        mask_TC[mask_TC == 3] = 1

        mask_ET = mask.copy()
        mask_ET[mask_ET == 1] = 0
        mask_ET[mask_ET == 2] = 0
        mask_ET[mask_ET == 3] = 1

        mask = np.stack([mask_WT, mask_TC, mask_ET])
        mask = np.moveaxis(mask, (0, 1, 2, 3), (0, 3, 2, 1))

        return mask

    def crop_3d_array(self, arr, crop_shape):
        """
        Crop a 3D array to the specified shape.
        
        Parameters
        ----------
        arr : numpy.ndarray
            The 3D input array to be cropped.
        crop_shape : tuple
            The shape of the cropped array. Must be a 3-element tuple (depth, height, width).
            
        Returns
        -------
        numpy.ndarray
            The cropped array.
        """

        assert len(crop_shape) == 3, "crop_shape must be a 3-element tuple"
        assert crop_shape[0] <= arr.shape[0], "depth of crop_shape must be <= depth of arr"
        assert crop_shape[1] <= arr.shape[1], "height of crop_shape must be <= height of arr"
        assert crop_shape[2] <= arr.shape[2], "width of crop_shape must be <= width of arr"

        depth_diff = arr.shape[0] - crop_shape[0]
        height_diff = arr.shape[1] - crop_shape[1]
        width_diff = arr.shape[2] - crop_shape[2]

        if depth_diff % 2 == 0:
            depth_crop_start = depth_diff // 2
            depth_crop_end = arr.shape[0] - (depth_diff // 2)
        else:
            depth_crop_start = depth_diff // 2
            depth_crop_end = arr.shape[0] - (depth_diff // 2) - 1

        if height_diff % 2 == 0:
            height_crop_start = height_diff // 2
            height_crop_end = arr.shape[1] - (height_diff // 2)
        else:
            height_crop_start = height_diff // 2
            height_crop_end = arr.shape[1] - (height_diff // 2) - 1

        if width_diff % 2 == 0:
            width_crop_start = width_diff // 2
            width_crop_end = arr.shape[2] - (width_diff // 2)
        else:
            width_crop_start = width_diff // 2
            width_crop_end = arr.shape[2] - (width_diff // 2) - 1

        cropped_arr = arr[depth_crop_start:depth_crop_end,
                          height_crop_start:height_crop_end, width_crop_start:width_crop_end]

        return cropped_arr

In [None]:
import nibabel as nib
import os

def save_prediction(prediction, name, path='predictions/'):
    os.makedirs(path, exist_ok=True)
    prediction_nii = nib.Nifti1Image(prediction, np.eye(4))
    nib.save(prediction_nii, os.path.join(path, f'{name}.nii.gz'))


In [None]:
def evalSingle(target="BraTS20_Training_004", treshold=0.5):
    model_name = "3DOnet_SingleConv_Kernel1_GELU"
    model = ONet3d_v3_GELU(in_channels=4, n_classes=3, n_channels=32).to('cuda')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint_path = "Logs/3DUnet_your_modifications/your_best_model_20230719-155138.pth"
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    _, _, test_dataloader = get_dataloaders(
      dataset=BratsDataset, path_to_csv="./data.csv", val_fold=0, test_fold=1, batch_size=1, do_resizing=True)
    for i, data in enumerate(test_dataloader):
        name, imgs = data['Id'][0], data['image']
        with torch.no_grad():
            imgs = imgs.to(device)
            logits = model(imgs)
            probs = torch.sigmoid(logits)
            predictions = (probs >= treshold).float()
            predictions = predictions.cpu()
            for j, pred in enumerate(predictions):
                save_prediction(pred.numpy(), name[j])
    print("Predictions saved as .nii.gz files.")


In [None]:
evalSingle()

hello
Predictions saved as .nii.gz files.
