# Load Packages and Define Paths


In [1]:
import os
import json
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
import nibabel as nib

# Root of the project checkout on the VM; the BraTS2020 data lives beneath it.
final_project_path = '/home/jws2215/e6691-2024spring-project-jwss-jws2215'  # vm
data_folder_path = os.path.join(final_project_path, 'BraTS2020')
train_folder_path, valid_folder_path = [
    os.path.join(data_folder_path, split_name) for split_name in ('train', 'valid')
]

# Pick the compute device. On a GPU, also report its name and memory figures,
# converted from bytes to gigabytes (1024 ** 3 bytes) and printed to 2 decimals.
if not torch.cuda.is_available():
    print("CUDA is not available. Cannot print memory usage.")
    device = torch.device('cpu')
else:
    device = torch.device('cuda')
    print("GPU name: ", torch.cuda.get_device_name(0))
    allocated_memory = torch.cuda.memory_allocated() / 1024 ** 3
    cached_memory = torch.cuda.memory_reserved() / 1024 ** 3
    print(f"Allocated Memory: {allocated_memory:.2f} GB")
    print(f"Cached Memory: {cached_memory:.2f} GB")
    total_memory = (
        torch.cuda.get_device_properties(device).total_memory / 1024 ** 3
    )
    print(f"Total GPU Memory: {total_memory:.2f} GB")



GPU name:  Tesla T4
Allocated Memory: 0.00 GB
Cached Memory: 0.00 GB
Total GPU Memory: 14.58 GB


# Create the Dataset

In [None]:
## All source images are clean and 640x640 pixels; the transform defined below resizes them to 240x240.

class ImageDataset(Dataset):
    """Image / binary-mask dataset built from polygon annotations.

    ``annotations`` is a dict with:
      * ``"images"``: list of entries whose ``"file_name"`` is relative to
        ``data_folder``;
      * ``"annotations"``: list of entries whose ``"segmentation"`` is a list of
        flat ``[x1, y1, x2, y2, ...]`` polygons.

    NOTE(review): image ``idx`` is paired with ``annotations["annotations"][idx]``,
    i.e. exactly one annotation per image, in the same order. Confirm this holds
    for the exported dataset (COCO-style files normally link the two lists via
    ``image_id`` instead of by position).

    Args:
        data_folder: Directory containing the image files.
        annotations: Parsed annotation dict as described above.
        image_size: (height, width) of the mask canvas used when
            ``generate_mask_from_annotations`` is called without ``mask_shape``.
        transform: Optional callable applied to both the RGB image and the mask.
    """

    def __init__(self, data_folder, annotations, image_size=(240, 240), transform=None):
        self.data_folder = data_folder
        # Must be stored: __len__ / __getitem__ / mask generation all read it.
        self.annotations = annotations
        self.image_size = image_size
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the annotations."""
        return len(self.annotations["images"])

    def __getitem__(self, idx):
        """Load image ``idx`` (as RGB) together with its binary mask.

        Returns:
            ``(image, mask)``, each passed through ``self.transform`` when set.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_name = os.path.join(self.data_folder, self.annotations["images"][idx]["file_name"])
        image = cv2.imread(img_name)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_name}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Rasterise the mask at the image's own size so it stays aligned with
        # the image; the transform then resizes both identically. (Assumes the
        # polygon coordinates are in original-image pixel space -- confirm.)
        original_shape = image.shape[:2]

        if self.transform:
            image = self.transform(image)

        mask = self.generate_mask_from_annotations(idx, mask_shape=original_shape)
        return image, mask

    def generate_mask_from_annotations(self, idx, mask_shape=None):
        """Rasterise the polygons of annotation ``idx`` into a binary mask.

        Args:
            idx: Position in ``annotations["annotations"]``.
            mask_shape: (height, width) of the canvas to draw on; defaults to
                ``self.image_size`` when omitted.

        Returns:
            The mask (transformed by ``self.transform`` when set) with every
            filled pixel set to 1 and everything else 0.
        """
        if mask_shape is None:
            mask_shape = self.image_size
        mask = np.zeros(mask_shape, dtype=np.uint8)

        # The category is ignored: all polygons go into a single foreground class.
        polygons_list = self.annotations["annotations"][idx]["segmentation"]

        for points in polygons_list:
            # Flat [x1, y1, x2, y2, ...] -> (N, 2) integer (x, y) vertices.
            polygon = np.array(points, dtype=np.int32).reshape((-1, 2))
            mask = cv2.fillPoly(mask, [polygon], color=255)

        if self.transform:
            mask = self.transform(mask)

        # Single-class segmentation: collapse any non-zero value to 1.
        mask[mask > 0] = 1
        return mask

# Deterministic preprocessing (no random augmentation). ImageDataset applies
# its `transform` to both images and masks:
#   1. ToPILImage: numpy array -> PIL image
#   2. Resize to 240x240 (240 works fine; the Facebook model used 224)
#   3. ToTensor: PIL image -> float CxHxW tensor (scaled to [0, 1] for uint8 input)
# NOTE(review): Resize interpolates bilinearly by default, so binary masks come
# out with soft edges; ImageDataset re-binarises them afterwards.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(240, 240)),
        transforms.ToTensor(),
    ]
)

