# Model Evaluation
The goal here is to load in a model and evaluate its performance on a random valid 3d MRI scan. The goal is to visualize it.

In [9]:
import os
import json
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import torchvision.models as models
from torchvision import transforms
import nibabel as nib
import random
from torch.utils.data import DataLoader
import time



final_project_path = '/home/jws2215/e6691-2024spring-project-jwss-jws2215' # vm
data_folder_path = os.path.join(final_project_path, 'BraTS2020')
train_folder_path = os.path.join(data_folder_path, 'train')
valid_folder_path = os.path.join(data_folder_path, 'valid')


if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU name: ", torch.cuda.get_device_name(0))
    allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)  # Convert bytes to gigabytes
    cached_memory = torch.cuda.memory_reserved() / (1024 ** 3)  # Convert bytes to gigabytes
    print(f"Allocated Memory: {allocated_memory:.2f} GB")
    print(f"Cached Memory: {cached_memory:.2f} GB")
    total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # Convert bytes to gigabytes
    print(f"Total GPU Memory: {total_memory:.2f} GB")
else:
    print("CUDA is not available. Cannot print memory usage.")
    device = torch.device('cpu')



GPU name:  Tesla T4
Allocated Memory: 0.23 GB
Cached Memory: 0.24 GB
Total GPU Memory: 14.58 GB


In [10]:
def create_data_dictionary(folder_path):
    data_dict = {}
    subfolders = [f for f in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, f))]
    
    for idx, subfolder in enumerate(subfolders):
        abs_path = os.path.join(folder_path, subfolder)
        data_dict[idx] = {'absolute_path': abs_path, 'folder_name': subfolder}
    
    return data_dict

data_path_dictionary_train = create_data_dictionary(train_folder_path)
data_path_dictionary_valid = create_data_dictionary(valid_folder_path)



In [11]:
## all data is stored in (240,240,155) 

class ImageDataset(Dataset):
    def __init__(self, data_path_dictionary):
        self.data_path_dictionary = data_path_dictionary

    def __len__(self):
        # print(len(self.annotations["images"]))
        return len(self.data_path_dictionary)

    def __getitem__(self, idx):
        
        folder_name = self.data_path_dictionary[idx]["folder_name"]
        folder_path = self.data_path_dictionary[idx]["absolute_path"]
        
        seg_path = os.path.join(folder_path, folder_name + '_seg.nii')
        t1_path = os.path.join(folder_path, folder_name + '_t1.nii')
        t1ce_path = os.path.join(folder_path, folder_name + '_t1ce.nii')
        t2_path = os.path.join(folder_path, folder_name + '_t2.nii')
        flair_path = os.path.join(folder_path, folder_name + '_flair.nii')
        
        # Load .nii files as nparrays
        seg_img = nib.load(seg_path).get_fdata()
        
        t1_img = nib.load(t1_path).get_fdata() #combine these ones
        t1ce_img = nib.load(t1ce_path).get_fdata()#combine these ones
        t2_img = nib.load(t2_path).get_fdata()#combine these ones
        flair_img = nib.load(flair_path).get_fdata()#combine these ones
        
        # Combine the MRI scans into a single 4-channel image
        combined_mri = np.stack([t1_img, t1ce_img, t2_img, flair_img], axis=0)  
        
        # Convert combined_mri and seg_img to torch tensors
        # combined_mri = torch.tensor(combined_mri, dtype=torch.float32)
        # seg_img = torch.tensor(seg_img, dtype=torch.float32)
        combined_mri = torch.tensor(combined_mri, dtype=torch.int32)
        seg_img = torch.tensor(seg_img, dtype=torch.int32)

        
        #convert to binary problem:
        seg_img[seg_img != 0] = 1
        
        return combined_mri, seg_img

val_dataset = ImageDataset(data_path_dictionary_valid)


# Load the Model

In [7]:
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        # Encoder (downsampling path)
        self.encoder_conv1 = self.conv_block(in_channels, 64)
        
        self.encoder_conv2 = self.conv_block(64, 128)
        
        self.encoder_conv3 = self.conv_block(128, 256)
        
        self.encoder_conv4 = self.conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder (upsampling path)
        self.decoder_upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder_conv1 = self.conv_block(1024, 512)
        
        self.decoder_upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder_conv2 = self.conv_block(512, 256)
        
        self.decoder_upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder_conv3 = self.conv_block(256, 128)
        
        self.decoder_upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_conv4 = self.conv_block(128, 64)

        # Output layer
        self.output_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        enc1 = self.encoder_conv1(x)
        enc2 = self.encoder_conv2(F.max_pool2d(enc1, kernel_size=2, stride=2))
        enc3 = self.encoder_conv3(F.max_pool2d(enc2, kernel_size=2, stride=2))
        enc4 = self.encoder_conv4(F.max_pool2d(enc3, kernel_size=2, stride=2))

        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2, stride=2))

        # Decoder
        dec1 = self.decoder_upconv1(bottleneck)
        dec1 = self.decoder_conv1(torch.cat([enc4, dec1], dim=1))
        
        dec2 = self.decoder_upconv2(dec1)
        dec2 = self.decoder_conv2(torch.cat([enc3, dec2], dim=1))
        
        dec3 = self.decoder_upconv3(dec2)
        dec3 = self.decoder_conv3(torch.cat([enc2, dec3], dim=1))
        
        dec4 = self.decoder_upconv4(dec3)
        dec4 = self.decoder_conv4(torch.cat([enc1, dec4], dim=1))

        # Output layer
        output = self.output_conv(dec4)

        return torch.sigmoid(output)
        # output = F.softmax(output, dim=1)
        # return output

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    
# Define the number of classes
num_output_classes = 1 #normally should be 5
num_input_channels = 4

# Custom Models
model = UNet(num_input_channels, num_output_classes)

# Load the saved weights
final_project_path 
saved_models_path = os.path.join(final_project_path, 'saved_models')

####### Change the name here #########
saved_model_path = os.path.join(saved_models_path, 'train_unet1_bce_lr1e-5_6e')

# Add the file name
saved_weights_path = 'last_unet.pth'
saved_weights_path = os.path.join(saved_model_path, saved_weights_path)
saved_weights = torch.load(saved_weights_path)

# Load the weights into the model
model.load_state_dict(saved_weights)
model = model.to(device)

# Evaluation on a Valid Patient

In [13]:
patient_number = 5 # range between 1-19

patient_scan = val_dataset.__getitem__(patient_number)
print("patient_scan", patient_scan)

patient_scan (tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         ...,

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
   