In [1]:
import os
import sys

# Make the project's ``src`` directory importable (``dataloader`` and ``metrics`` below
# live there). Assumes the notebook runs from a direct subdirectory of the project root.
# The path is prepended rather than appended so these generically named local modules
# win over any installed package of the same name, and it is only inserted when absent
# so re-running this cell does not keep growing ``sys.path``.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
src_dir = os.path.join(project_root, "src")
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from typing import List, Tuple


class BaselineModel(pl.LightningModule):
    """Non-learned intensity-threshold baseline for multi-class segmentation.

    No gradients are involved. During ``trainer.fit`` the model only gathers
    statistics: for every class it sums the channel-0 intensities of the pixels
    whose mask for that class equals 1, and counts those pixels. When training
    ends the sums are turned into per-class mean intensities, which ``forward``
    then uses as thresholds.

    Args:
        num_classes: Number of segmentation classes. Defaults to 4, the value
            previously hard-coded.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """

    def __init__(self, num_classes: int = 4) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        # Per-class mean channel-0 intensity; written by on_train_end, read by forward.
        self.mean_intensity = {class_index: 0.0 for class_index in range(num_classes)}
        # Per-class number of labelled pixels seen during the most recent fit.
        self.class_count = {class_index: 0 for class_index in range(num_classes)}
        # Running per-class intensity sums. Kept apart from mean_intensity so that
        # mean_intensity never holds raw sums, and a second fit() cannot add new
        # sums on top of already-averaged means.
        self._intensity_sum = {class_index: 0.0 for class_index in range(num_classes)}

    def forward(self, image_data: torch.Tensor) -> torch.Tensor:
        """Predict a one-hot class map by thresholding channel 0 against class means.

        Args:
            image_data: Rank-4 tensor indexed as ``[d0, d1, d2, channel]``; only
                channel 0 is used. In this notebook it is a single volume whose
                first axis is presumably the slice axis — TODO confirm with the
                dataloader.

        Returns:
            float32 CPU tensor of shape ``(d0, num_classes, d1, d2)`` holding a
            one-hot encoding of the predicted class of every pixel.
        """
        # detach() so an input that requires grad can still be converted to numpy.
        image = image_data.detach().cpu()[:, :, :, 0].numpy()
        thresholds = np.array(list(self.mean_intensity.values()))
        num_classes = len(thresholds)

        pixel_classes = np.zeros(image.shape, dtype=int)
        # A pixel gets the highest class index whose mean it exceeds (later classes
        # overwrite earlier ones). This implicitly assumes the class means increase
        # with the class index — NOTE(review): confirm for the trained means.
        for class_index, threshold in enumerate(thresholds):
            pixel_classes[image > threshold] = class_index

        # Pixels not brighter than the average of the class means are background.
        pixel_classes[image <= thresholds.mean()] = 0

        # np.eye(C)[labels] one-hot encodes along a new last axis; move it to axis 1.
        one_hot = np.moveaxis(np.eye(num_classes)[pixel_classes], -1, 1)
        return torch.tensor(one_hot, dtype=torch.float32)

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Accumulate per-class intensity sums and pixel counts for one batch.

        Args:
            batch: ``(X, y)`` pair. ``X`` is rank 5 with the channel on the last
                axis; ``y`` is rank 5 with the class on axis 2. After dropping
                those axes both must have the same shape, because one is used as a
                boolean mask on the other.
            batch_idx: Unused.

        Returns:
            None. There is no loss, so no optimisation step is taken.
        """
        X, y = batch
        # Channel 0 is treated as the FLAIR modality.
        flair_image = X[:, :, :, :, 0]

        for class_index in self._intensity_sum:
            class_pixels = flair_image[y[:, :, class_index, :, :] == 1]
            self.class_count[class_index] += class_pixels.numel()
            self._intensity_sum[class_index] += class_pixels.sum().item()

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """No-op: the baseline has nothing to validate during fitting."""
        return

    def configure_optimizers(self) -> None:
        """Return no optimizer; this model has no trainable parameters to update."""
        return

    def on_train_start(self) -> None:
        """Reset the accumulators so every ``fit`` computes its means from scratch."""
        for class_index in self._intensity_sum:
            self._intensity_sum[class_index] = 0.0
            self.class_count[class_index] = 0

    def on_train_end(self) -> None:
        """Convert the accumulated statistics into per-class mean intensities."""
        self._calculate_mean_intensity()

    def _calculate_mean_intensity(self) -> None:
        """Set ``mean_intensity`` from sums / counts; classes never seen get 0.0."""
        for class_index, total in self._intensity_sum.items():
            count = self.class_count[class_index]
            self.mean_intensity[class_index] = total / count if count > 0 else 0.0

In [3]:
from dataloader import BrainTumourDataModule

data_module = BrainTumourDataModule("../BrainTumourData/imagesTr/")
data_module.prepare_data()
data_module.setup()

# Grab one example volume for the qualitative plots below. It comes from the
# training loader, so despite the variable name it is NOT held-out data.
try:
    first_images, first_targets = next(iter(data_module.train_dataloader()))
except StopIteration:
    raise RuntimeError(
        "train_dataloader() yielded no batches; check the data path"
    ) from None

# Keep only the first sample of that batch.
test_data = first_images[0]
ground_truth = first_targets[0]

In [None]:
# Sanity-check the loaded volume: one slice of channel 0 (the channel the model
# reads as FLAIR), drawn in grayscale.
slice_index = 50
fig, ax = plt.subplots()
ax.imshow(test_data[slice_index, :, :, 0], cmap="gray")
ax.set_title(f"Input volume - slice {slice_index}, channel 0")
ax.axis("off")
plt.show()

In [None]:
model = BaselineModel()

# The baseline only gathers intensity statistics, so one epoch is all it needs;
# logging and checkpointing are switched off.
trainer = pl.Trainer(
    max_epochs=1,
    enable_checkpointing=False,
    logger=False,
)

In [None]:
# Hand the LightningDataModule over through the dedicated ``datamodule=`` argument
# rather than through ``train_dataloaders``.
trainer.fit(model, datamodule=data_module)

In [7]:
# Run inference the same way as the test-set evaluation: eval mode, no autograd.
model.eval()
with torch.no_grad():
    output = model(test_data)

In [None]:
index_to_name = {
    0: "Background",
    1: "Edema",
    2: "Non-Enhancing Tumor",
    3: "Enhancing Tumor",
}
slice_index = 50

# Row 0: predicted masks, row 1: ground-truth masks, one column per class.
# vmin/vmax are pinned to 0..1 (masks are presumably 0/1) so a colour means the
# same thing in every panel, including panels whose mask is empty.
fig, axs = plt.subplots(2, len(index_to_name), figsize=(15, 7), squeeze=False)

for class_index, class_name in index_to_name.items():
    panels = (
        (axs[0, class_index], output, "Prediction"),
        (axs[1, class_index], ground_truth, "Ground Truth"),
    )
    for ax, volume, label in panels:
        ax.imshow(volume[slice_index, class_index, :, :], vmin=0, vmax=1)
        ax.set_title(f"{label} - Class {class_name}")
        ax.axis("off")

fig.suptitle(f"Slice {slice_index}")
plt.tight_layout()
plt.show()

In [None]:
# Run the baseline over the whole test split and collect predictions and targets
# as float32 tensors of shape (num_volumes, *per_volume_shape) for the metrics below.
model.eval()

test_loader = data_module.test_dataloader()
total_batches = len(test_loader)

all_predictions = None
all_targets = None

with torch.no_grad():
    for current_index, (images, targets) in enumerate(test_loader):
        # Only the first sample of each batch is used, so any batch size other
        # than 1 would silently drop volumes — fail loudly instead.
        if images.shape[0] != 1:
            raise ValueError(
                f"expected test batches of size 1, got {images.shape[0]}"
            )
        predictions = model(images[0])
        if all_predictions is None:
            # Allocate once, sized from the first volume instead of hard-coded
            # slice / class / height / width constants.
            all_predictions = torch.empty(
                (total_batches, *predictions.shape), dtype=torch.float32
            )
            all_targets = torch.empty(
                (total_batches, *targets[0].shape), dtype=torch.float32
            )
        all_predictions[current_index] = predictions
        all_targets[current_index] = targets[0]

if all_predictions is None:
    raise RuntimeError("test_dataloader() yielded no batches")

print("Shape of all_predictions:", all_predictions.shape)
print("Shape of all_targets:", all_targets.shape)

In [None]:
from metrics import (
    compute_confusion_matrix,
    dice_score,
    mean_dice_score,
    recall_precision,
    weighted_recall,
)

# Per-class weights for weighted_recall. Three weights for four classes suggests
# background is excluded — NOTE(review): confirm ordering against metrics.weighted_recall.
recall_class_weights = [0.2, 0.3, 0.5]

dice_scores = dice_score(all_predictions, all_targets)
mean_dsc = mean_dice_score(dice_scores)
recall_scores, precision_scores = recall_precision(all_predictions, all_targets)
weighted_recall_score = weighted_recall(recall_scores, alpha=recall_class_weights)

print("Dice Scores for each class:", dice_scores)
print("Mean Dice Score:", mean_dsc)
print("Recall Scores for each class:", recall_scores)
print("Precision Scores for each class:", precision_scores)
print("Weighted Recall:", weighted_recall_score)

In [None]:
# Confusion matrix of the collected test-set predictions against their targets.
confusion_matrix = compute_confusion_matrix(
    all_predictions,
    all_targets,
)
print("Confusion Matrix:", confusion_matrix)