# Load Packages and State Paths


In [2]:
import os
import json
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
import nibabel as nib

final_project_path = '/home/jws2215/e6691-2024spring-project-jwss-jws2215' # vm
data_folder_path = os.path.join(final_project_path, 'BraTS2020')
train_folder_path = os.path.join(data_folder_path, 'train')
valid_folder_path = os.path.join(data_folder_path, 'valid')


if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU name: ", torch.cuda.get_device_name(0))
    allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)  # Convert bytes to gigabytes
    cached_memory = torch.cuda.memory_reserved() / (1024 ** 3)  # Convert bytes to gigabytes
    print(f"Allocated Memory: {allocated_memory:.2f} GB")
    print(f"Cached Memory: {cached_memory:.2f} GB")
    total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # Convert bytes to gigabytes
    print(f"Total GPU Memory: {total_memory:.2f} GB")
else:
    print("CUDA is not available. Cannot print memory usage.")
    device = torch.device('cpu')



GPU name:  Tesla T4
Allocated Memory: 0.00 GB
Cached Memory: 0.00 GB
Total GPU Memory: 14.58 GB


# Create the Dataset

In [3]:
def create_data_dictionary(folder_path):
    data_dict = {}
    subfolders = [f for f in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, f))]
    
    for idx, subfolder in enumerate(subfolders):
        abs_path = os.path.join(folder_path, subfolder)
        data_dict[idx] = {'absolute_path': abs_path, 'folder_name': subfolder}
    
    return data_dict

data_path_dictionary_train = create_data_dictionary(train_folder_path)
data_path_dictionary_valid = create_data_dictionary(valid_folder_path)

print("Train Data Dictionary:")
for idx, data in data_path_dictionary_train.items():
    print(f"Index: {idx}, Absolute Path: {data['absolute_path']}, Folder Name: {data['folder_name']}")

print("\nValid Data Dictionary:")
for idx, data in data_path_dictionary_valid.items():
    print(f"Index: {idx}, Absolute Path: {data['absolute_path']}, Folder Name: {data['folder_name']}")


Train Data Dictionary:
Index: 0, Absolute Path: /home/jws2215/e6691-2024spring-project-jwss-jws2215/BraTS2020/train/BraTS20_Training_099, Folder Name: BraTS20_Training_099
Index: 1, Absolute Path: /home/jws2215/e6691-2024spring-project-jwss-jws2215/BraTS2020/train/BraTS20_Training_036, Folder Name: BraTS20_Training_036
Index: 2, Absolute Path: /home/jws2215/e6691-2024spring-project-jwss-jws2215/BraTS2020/train/BraTS20_Training_004, Folder Name: BraTS20_Training_004
Index: 3, Absolute Path: /home/jws2215/e6691-2024spring-project-jwss-jws2215/BraTS2020/train/BraTS20_Training_029, Folder Name: BraTS20_Training_029
Index: 4, Absolute Path: /home/jws2215/e6691-2024spring-project-jwss-jws2215/BraTS2020/train/BraTS20_Training_021, Folder Name: BraTS20_Training_021
Index: 5, Absolute Path: /home/jws2215/e6691-2024spring-project-jwss-jws2215/BraTS2020/train/BraTS20_Training_058, Folder Name: BraTS20_Training_058
Index: 6, Absolute Path: /home/jws2215/e6691-2024spring-project-jwss-jws2215/BraTS2

In [None]:
## all data is stored in (240,240,155) 

class ImageDataset(Dataset):
    def __init__(self, data_folder, data_path_dictionary, image_size=(240, 240), transform=None):
        self.data_folder = data_folder
        self.image_size = image_size
        self.transform = transform

    def __len__(self):
        # print(len(self.annotations["images"]))
        return len(self.data_path_dictionary)

    def __getitem__(self, idx):
        # Load image
        image_list = self.annotations["images"]
        # print("id", image_list[idx]["id"])
        img_name = os.path.join(self.data_folder, image_list[idx]['file_name'])
        image = cv2.imread(img_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        folder_name = data_path_dictionary["folder_name"]
        folder_path = data_path_dictionary["absolute_path"]
        
        seg_path = os.path.join(folder_path, folder_name + '_seg.nii')
        t1_path = os.path.join(folder_path, folder_name + '_t1.nii')
        t1ce_path = os.path.join(folder_path, folder_name + '_t1ce.nii')
        t2_path = os.path.join(folder_path, folder_name + '_t2.nii')
        flair_path = os.path.join(folder_path, folder_name + '_flair.nii')
        
        # Load .nii files as nparrays
        seg_img = nib.load(seg_path).get_fdata()
        
        t1_img = nib.load(t1_path).get_fdata() #combine these ones
        t1ce_img = nib.load(t1ce_path).get_fdata()#combine these ones
        t2_img = nib.load(t2_path).get_fdata()#combine these ones
        flair_img = nib.load(flair_path).get_fdata()#combine these ones
        
        # Combine the MRI scans into a single 4-channel image
        combined_mri = np.stack([t1_img, t1ce_img, t2_img, flair_img], axis=2)
        
        # Apply transformations if specified
        if self.transform:
            image = self.transform(image)


        return image, seg_img

    

# Define transforms for data augmentation
transform = transforms.Compose([
    transforms.ToTensor(),
])

