In [14]:
from pathlib import Path
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import nibabel as nib
import numpy as np
import os
from sklearn.model_selection import train_test_split
import cv2
import torch
import torch.nn.functional as F

# Root of the extracted dataset; image and label volumes live in sibling
# folders and are matched by identical file names.
DATA_ROOT = "../BrainTumourData/"
TRAINING_DATA_PATH = DATA_ROOT + "imagesTr/"
TRAINING_SEGMENTATION_PATH = DATA_ROOT + "labelsTr/"

# Default side length (pixels) that every slice is resized to.
IMG_SIZE = 128


class BrainTumourDataset(Dataset):
    """One item per NIfTI volume: resized FLAIR/T1w slices plus one-hot labels.

    The image is read from ``TRAINING_DATA_PATH`` and the label map from
    ``TRAINING_SEGMENTATION_PATH``, both under the same file name.

    Args:
        file_ids: File names (not full paths) of the volumes in this split.
        img_dim: Target ``(height, width)`` of every resized slice.
        transform: Optional callable ``transform(X, Y) -> (X, Y)`` applied to
            each loaded item (image and label together, so spatial
            augmentations stay aligned). ``None`` disables it.

    Each item is ``(X, Y)``:
        X: ``float32`` array ``(num_slices, height, width, 2)``; channel 0 is
           read from image channel 0 (FLAIR), channel 1 from image channel 1
           (T1w). Both channels are divided by the joint volume maximum.
        Y: ``float32`` tensor ``(num_slices, 4, height, width)``, one-hot over
           label values 0..3.
    """

    def __init__(
        self,
        file_ids,
        img_dim=(IMG_SIZE, IMG_SIZE),
        transform=None,
    ):
        self.dim = img_dim
        self.file_ids = file_ids
        self.transform = transform
        self.n_channels = 2

    def __len__(self):
        return len(self.file_ids)

    def __getitem__(self, idx):
        file_id = self.file_ids[idx]
        X, y = self.__data_generation(file_id)
        # The transform used to be stored but silently ignored.
        if self.transform is not None:
            X, y = self.transform(X, y)
        return X, y

    def __data_generation(self, file_id):
        """Load one volume and its label map, resize per slice, normalise."""
        data_path = os.path.join(TRAINING_DATA_PATH, file_id)
        seg_path = os.path.join(TRAINING_SEGMENTATION_PATH, file_id)

        data = nib.load(data_path).get_fdata()
        seg = nib.load(seg_path).get_fdata()

        num_slices = data.shape[2]
        height, width = self.dim
        # cv2.resize takes dsize as (width, height) -- the reverse of numpy's
        # (rows, cols). Passing self.dim directly breaks non-square img_dim.
        dsize = (width, height)

        X = np.zeros((num_slices, height, width, self.n_channels), dtype=np.float32)
        y = np.zeros((num_slices, height, width), dtype=np.float32)

        flair = data[:, :, :, 0]
        t1w = data[:, :, :, 1]

        for i in range(num_slices):
            X[i, :, :, 0] = cv2.resize(flair[:, :, i], dsize)
            X[i, :, :, 1] = cv2.resize(t1w[:, :, i], dsize)
            # Nearest-neighbour keeps label values integral (no blended classes).
            y[i] = cv2.resize(
                seg[:, :, i],
                dsize,
                interpolation=cv2.INTER_NEAREST,
            )

        y_tensor = torch.from_numpy(y).long()
        # (slices, H, W) -> one-hot (slices, H, W, 4) -> (slices, 4, H, W)
        Y = F.one_hot(y_tensor, num_classes=4).permute(0, 3, 1, 2).float()

        # Divide by the volume maximum; an all-zero volume would otherwise be
        # 0/0 and fill X with NaN, so leave it unscaled in that case.
        max_value = X.max()
        if max_value > 0:
            X = X / max_value
        return X, Y


class BrainTumourDataModule(pl.LightningDataModule):
    """Split the volume files in ``dir_path`` into train / val / test loaders.

    20% of the files go to validation; of the remaining 80%, 15% go to test
    (about 12% overall) and the rest to training. The split is computed once
    and then reused for every stage.

    Args:
        dir_path: Directory whose (non-hidden) files are the volume IDs.
        batch_size: Volumes per batch for every loader.
        num_workers: ``DataLoader`` worker processes.
        img_dim: Forwarded to ``BrainTumourDataset``.
        transform: Forwarded to ``BrainTumourDataset`` for every split.
        seed: ``random_state`` for both splits. ``None`` reproduces the old
            non-deterministic behaviour.
    """

    def __init__(
        self,
        dir_path,
        batch_size=1,
        num_workers=4,
        img_dim=(IMG_SIZE, IMG_SIZE),
        transform=None,
        seed=42,
    ):
        super().__init__()
        self.dir_path = dir_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.img_dim = img_dim
        self.transform = transform
        self.seed = seed
        self.training_datas = None
        self.train_ids = self.val_ids = self.test_ids = None

    def _list_volume_files(self):
        """Return the sorted names of non-hidden regular files in ``dir_path``."""
        # Sorting makes the seeded split independent of filesystem order.
        # Dotfiles (e.g. macOS "._*" metadata) are never volumes.
        with os.scandir(self.dir_path) as entries:
            return sorted(
                entry.name
                for entry in entries
                if entry.is_file() and not entry.name.startswith(".")
            )

    def prepare_data(self):
        self.training_datas = self._list_volume_files()

    def setup(self, stage=None):
        # Lightning calls setup() for each stage (fit/validate/test), and it may
        # also be called by hand. Re-splitting every time would move files
        # between splits and leak test volumes into training.
        if self.train_ids is not None:
            return

        # prepare_data() only runs on one process, so list here if needed.
        if self.training_datas is None:
            self.training_datas = self._list_volume_files()
        if not self.training_datas:
            raise ValueError(f"No volume files found in {self.dir_path!r}")

        train_test_ids, val_ids = train_test_split(
            self.training_datas, test_size=0.2, random_state=self.seed
        )
        train_ids, test_ids = train_test_split(
            train_test_ids, test_size=0.15, random_state=self.seed
        )

        self.train_ids = train_ids
        self.val_ids = val_ids
        self.test_ids = test_ids

        # Create datasets for each split
        self.train_dataset = BrainTumourDataset(
            self.train_ids,
            img_dim=self.img_dim,
            transform=self.transform,
        )
        self.val_dataset = BrainTumourDataset(
            self.val_ids,
            img_dim=self.img_dim,
            transform=self.transform,
        )
        self.test_dataset = BrainTumourDataset(
            self.test_ids,
            img_dim=self.img_dim,
            transform=self.transform,
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

In [15]:
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from typing import Tuple


class BaselineModel(pl.LightningModule):
    """Intensity-threshold baseline for 4-class segmentation.

    There is no gradient training. ``training_step`` accumulates, per class,
    the sum and count of channel-0 (FLAIR) intensities under that class's
    ground-truth mask. ``forward`` converts these into per-class mean
    intensities and labels pixels by thresholding against them.
    """

    def __init__(self) -> None:
        super().__init__()
        # Running per-class sum of FLAIR intensities seen during training.
        self.intensity_sum = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0}
        # Number of pixels contributing to each entry of ``intensity_sum``.
        self.class_count = {0: 0, 1: 0, 2: 0, 3: 0}
        # Per-class mean intensity, (re)derived from the two dicts above.
        self.mean_intensity = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0}

    def forward(self, image_data: torch.Tensor) -> torch.Tensor:
        """Predict a one-hot class map for one unbatched volume.

        Args:
            image_data: Rank-4 tensor whose last axis is the channel axis;
                channel 0 (FLAIR) is used. In this notebook it is a dataset
                item shaped ``(slices, H, W, channels)``.

        Returns:
            ``float32`` CPU tensor ``(slices, num_classes, H, W)``.
        """
        self._calculate_mean_intensity()

        image = image_data.detach().cpu()[:, :, :, 0].numpy()
        num_classes = len(self.mean_intensity)
        mean_intensities = np.array(list(self.mean_intensity.values()))

        pixel_classes = np.zeros(image.shape, dtype=int)
        # NOTE(review): later classes overwrite earlier ones, so this acts as a
        # threshold ladder only if the means increase with class index.
        for class_index in range(num_classes):
            pixel_classes[image > mean_intensities[class_index]] = class_index

        # Anything at or below the average of the class means is background.
        pixel_classes[image <= np.mean(mean_intensities)] = 0

        one_hot = np.eye(num_classes)[pixel_classes]  # class axis appended last
        one_hot = np.moveaxis(one_hot, -1, 1)  # move class axis to dim 1
        return torch.tensor(one_hot, dtype=torch.float32)

    def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> None:
        """Accumulate per-class FLAIR intensity sums and pixel counts.

        ``batch`` is the collated ``(X, y)`` pair from the DataLoader. Returning
        ``None`` makes Lightning skip the optimisation step.
        """
        X, y = batch

        # Channel 0 of the 5-D batch; for BrainTumourDataset items this is
        # (batch, slices, H, W), matching y[:, :, class_index].
        flair_image: torch.Tensor = X[:, :, :, :, 0]

        for class_index in self.intensity_sum:
            class_mask = y[:, :, class_index, :, :]
            class_pixels = flair_image[class_mask == 1]
            self.class_count[class_index] += class_pixels.numel()
            self.intensity_sum[class_index] += class_pixels.sum().item()

    def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> None:
        return

    def configure_optimizers(self) -> None:
        # No trainable parameters: Lightning runs fit without an optimizer.
        return

    def _calculate_mean_intensity(self) -> None:
        """Recompute ``mean_intensity`` from the accumulated sums and counts.

        This is idempotent because the sums are stored separately. The old
        in-place division corrupted the means on every repeated ``forward``.
        Classes with no pixels keep a mean of 0.0.
        """
        for class_index, total in self.intensity_sum.items():
            count = self.class_count[class_index]
            self.mean_intensity[class_index] = total / count if count > 0 else 0.0

In [16]:
data_module = BrainTumourDataModule(TRAINING_DATA_PATH)
data_module.prepare_data()
data_module.setup()

# Qualitative check volume: the first item of the *test* split, not the
# shuffled training loader, so predictions below are shown on held-out data.
# batch_size defaults to 1, so index 0 just strips the batch axis.
X, y = next(iter(data_module.test_dataloader()))
test_data = X[0]  # (slices, H, W, channels)
ground_truth = y[0]  # (slices, classes, H, W)

In [None]:
# Sanity-check one input slice; channel 0 is FLAIR in BrainTumourDataset.
_, ax = plt.subplots()
ax.imshow(test_data[50, :, :, 0])
ax.set_title("FLAIR input, slice 50")
ax.axis("off");

In [None]:
# The baseline only accumulates intensity sums and counts, whose ratio does not
# change with more passes over the data, so a single epoch suffices.
model = BaselineModel()
trainer = pl.Trainer(max_epochs=1)

In [None]:
# Pass the DataModule via `datamodule=`; Lightning runs prepare_data/setup
# and pulls the train/val loaders from it.
trainer.fit(model, datamodule=data_module)

In [20]:
output = model(test_data)
# Prediction and label are both (slices, classes, H, W); fail fast if not.
assert output.shape == ground_truth.shape, (output.shape, ground_truth.shape)

In [None]:
index_to_name = {
    0: "Background",
    1: "Edema",
    2: "Non-Enhancing Tumor",
    3: "Enhancing Tumor",
}

fig, axs = plt.subplots(2, 4, figsize=(15, 7))

# Top row: prediction; bottom row: ground truth; one column per class (slice 50).
panels = (("Prediction", output), ("Ground Truth", ground_truth))
for row, (label, volume) in enumerate(panels):
    for class_index, class_name in index_to_name.items():
        ax = axs[row, class_index]
        ax.imshow(volume[50, class_index, :, :])
        ax.set_title(f"{label} - Class {class_name}")
        ax.axis("off")

plt.tight_layout()
plt.show()