In [1]:
import os
import json
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import torchvision.models as models
from torchvision import transforms
import nibabel as nib
import random
from torch.utils.data import DataLoader
import time



final_project_path = '/home/jws2215/e6691-2024spring-project-jwss-jws2215' # vm
data_folder_path = os.path.join(final_project_path, 'BraTS2020')
train_folder_path = os.path.join(data_folder_path, 'train')
valid_folder_path = os.path.join(data_folder_path, 'valid')


if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU name: ", torch.cuda.get_device_name(0))
    allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)  # Convert bytes to gigabytes
    cached_memory = torch.cuda.memory_reserved() / (1024 ** 3)  # Convert bytes to gigabytes
    print(f"Allocated Memory: {allocated_memory:.2f} GB")
    print(f"Cached Memory: {cached_memory:.2f} GB")
    total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # Convert bytes to gigabytes
    print(f"Total GPU Memory: {total_memory:.2f} GB")
else:
    print("CUDA is not available. Cannot print memory usage.")
    device = torch.device('cpu')



GPU name:  Tesla T4
Allocated Memory: 0.00 GB
Cached Memory: 0.00 GB
Total GPU Memory: 14.58 GB


In [2]:
def create_data_dictionary(folder_path):
    data_dict = {}
    subfolders = [f for f in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, f))]
    
    for idx, subfolder in enumerate(subfolders):
        abs_path = os.path.join(folder_path, subfolder)
        data_dict[idx] = {'absolute_path': abs_path, 'folder_name': subfolder}
    
    return data_dict

data_path_dictionary_valid = create_data_dictionary(valid_folder_path)



In [3]:
## all data is stored in (240,240,155) 

class ImageDataset(Dataset):
    def __init__(self, data_path_dictionary):
        self.data_path_dictionary = data_path_dictionary

    def __len__(self):
        # print(len(self.annotations["images"]))
        return len(self.data_path_dictionary)

    def __getitem__(self, idx):
        
        folder_name = self.data_path_dictionary[idx]["folder_name"]
        folder_path = self.data_path_dictionary[idx]["absolute_path"]
        
        seg_path = os.path.join(folder_path, folder_name + '_seg.nii')
        t1_path = os.path.join(folder_path, folder_name + '_t1.nii')
        t1ce_path = os.path.join(folder_path, folder_name + '_t1ce.nii')
        t2_path = os.path.join(folder_path, folder_name + '_t2.nii')
        flair_path = os.path.join(folder_path, folder_name + '_flair.nii')
        
        # Load .nii files as nparrays
        seg_img = nib.load(seg_path).get_fdata()
        
        t1_img = nib.load(t1_path).get_fdata() #combine these ones
        t1ce_img = nib.load(t1ce_path).get_fdata()#combine these ones
        t2_img = nib.load(t2_path).get_fdata()#combine these ones
        flair_img = nib.load(flair_path).get_fdata()#combine these ones
        
        # Combine the MRI scans into a single 4-channel image
        combined_mri = np.stack([t1_img, t1ce_img, t2_img, flair_img], axis=0)  
        
        # Convert combined_mri and seg_img to torch tensors
        combined_mri = torch.tensor(combined_mri, dtype=torch.int32)
        seg_img = torch.tensor(seg_img, dtype=torch.int32)

        
        #convert to binary problem:
        seg_img[seg_img != 0] = 1
        
        return combined_mri, seg_img

val_dataset = ImageDataset(data_path_dictionary_valid)
batch_size = 1 # remember each item in a batch actually contains 155 subbatches
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [4]:
from seg_models import UNet, UNetPaper
    
    
# Define the number of classes
num_output_classes = 1 
num_input_channels = 4

# Custom Models
model = UNet(num_input_channels, num_output_classes)

# Paper Models
# model = UNetPaper(num_input_channels, num_output_classes)

# Load the saved weights
saved_models_path = os.path.join(final_project_path, 'saved_models')


####### Change the name here #########
saved_model_path = os.path.join(saved_models_path, 'train_unet1_bce_lr1e-5_20e')

# Add the file name
saved_weights_path = 'best_unet.pth'
saved_weights_path = os.path.join(saved_model_path, saved_weights_path)
saved_weights = torch.load(saved_weights_path)

# Load the weights into the model
model.load_state_dict(saved_weights)
model = model.to(device)
print("model", model)

model UNet(
  (encoder_conv1): Sequential(
    (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (encoder_conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (encoder_conv3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (encoder_conv4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (bottleneck)

In [5]:
# Define functions to calculate Dice coefficient and IoU

def calculate_dice_coefficient(prediction, target):
    # Calculate intersection and union
    intersection = torch.sum(prediction * target)
    union = torch.sum(prediction) + torch.sum(target)
    
    # Calculate Dice coefficient
    dice_coefficient = (2. * intersection) / (union + 1e-6)  # Add a small value to avoid division by zero
    
    return dice_coefficient

def calculate_iou(prediction, target):
    # Calculate intersection and union
    intersection = torch.sum(prediction * target)
    union = torch.sum(prediction) + torch.sum(target) - intersection
    
    # Calculate IoU
    iou = intersection / (union + 1e-6)  # Add a small value to avoid division by zero
    
    return iou

In [11]:
from sklearn.metrics import accuracy_score, recall_score, precision_score
import warnings
warnings.filterwarnings('ignore')  # "error", "ignore", "always", "default", "module" or "once"


model.eval()
dice_coefficient = 0.0
iou_score = 0.0
accuracy = 0.0
sensitivity = 0.0
specificity = 0.0
total_batches = 0

threshold = 0.5

with torch.no_grad():
    for combined_mris, seg_imgs in val_loader:
        print("Batch Num:", total_batches)
        for depth_idx in range(combined_mris.size(4)):
            current_slice = combined_mris[:, :, :, :, depth_idx].to(device)
            current_seg = seg_imgs[:, :, :, depth_idx].long()
            current_seg_one_hot = current_seg.unsqueeze(1).float().to(device)

            # Predict
            outputs = model(current_slice.float())
            binary_prediction = (outputs > threshold).squeeze().cpu().detach().to(torch.int)
            # print("binary_prediction", binary_prediction.shape)

            # Calculate Dice and IoU scores for the current depth slice
            current_seg_np = current_seg.cpu().detach().to(torch.int)
            # print("current_seg_np", current_seg_np.shape)
            unique_values = torch.unique(binary_prediction)
            # print("binary_prediction", unique_values)
            unique_values = torch.unique(current_seg_np)
            # print("current_seg_np", unique_values)
            dice_coefficient += calculate_dice_coefficient(binary_prediction, current_seg_np)
            iou_score += calculate_iou(binary_prediction, current_seg_np)
            
            # Assuming binary_prediction and current_seg_np are both binary tensors
            binary_prediction_np = binary_prediction.cpu().numpy().flatten()
            current_seg_np_flat = current_seg_np.cpu().numpy().flatten()

            # Calculate Accuracy
            accuracy += accuracy_score(current_seg_np_flat, binary_prediction_np)

            # Calculate Sensitivity (Recall)
            sensitivity += recall_score(current_seg_np_flat, binary_prediction_np, average='binary', pos_label=1, zero_division='warn')

            # Calculate Specificity Specificity = TN / (TN + FP)
            true_negatives = ((current_seg_np_flat == 0) & (binary_prediction_np == 0)).sum()
            false_positives = ((current_seg_np_flat == 0) & (binary_prediction_np == 1)).sum()
            # Check if the denominator is zero
            if (true_negatives + false_positives) == 0:
                specificity += 0.0  # Assign a default value when the denominator is zero
            else:
                specificity += true_negatives / (true_negatives + false_positives)
            
            total_batches += 1

# Calculate average scores for the epoch
print("total_batches", total_batches)
avg_dice_coefficient = dice_coefficient / total_batches
avg_iou_score = iou_score / total_batches
avg_accuracy = accuracy / total_batches
avg_sensitivity = sensitivity / total_batches
avg_specificity = specificity / total_batches

print("Average Dice Coefficient:", avg_dice_coefficient)
print("Average IoU Score:", avg_iou_score)
print("Average Accuracy:", avg_accuracy)
print("Average Sensitivity (Recall):", avg_sensitivity)
print("Average Specificity:", avg_specificity)

Batch Num: 0
Batch Num: 155
Batch Num: 310
Batch Num: 465
Batch Num: 620
Batch Num: 775
Batch Num: 930
Batch Num: 1085
Batch Num: 1240
Batch Num: 1395
Batch Num: 1550
Batch Num: 1705
Batch Num: 1860
Batch Num: 2015
Batch Num: 2170
Batch Num: 2325
Batch Num: 2480
Batch Num: 2635
Batch Num: 2790
Batch Num: 2945
total_batches 3100
Average Dice Coefficient: tensor(0.2889)
Average IoU Score: tensor(0.2524)
Average Accuracy: 0.9975052419354828
Average Sensitivity (Recall): 0.29352666664076826
Average Specificity: 0.9985677304300982
