In [None]:
import os
import sys

# Make the project's `src/` directory importable from this notebook.
# Assumes the kernel's working directory is one level below the project root.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
src_dir = os.path.join(project_root, "src")
# Guard against piling up duplicate entries when this cell is re-run in the same kernel.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
from typing import Tuple

import pytorch_lightning as pl
import torch
from monai.losses import DiceLoss
from monai.networks.nets import UNETR, SegResNet
from torch import Tensor


class SegResModel(pl.LightningModule):
    """Lightning wrapper around MONAI's ``SegResNet`` trained with a sigmoid Dice loss.

    ``forward`` returns raw per-class logits; ``DiceLoss(sigmoid=True)`` applies the
    sigmoid inside the loss, so callers that want probabilities must apply
    ``torch.sigmoid`` themselves.

    Args:
        in_channels: Number of input channels of the image volumes.
        out_channels: Number of output segmentation channels (one per class).
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, in_channels, out_channels, learning_rate=1e-3):
        super().__init__()
        # Store the constructor arguments in ``self.hparams`` / checkpoints so the
        # model can be rebuilt with ``load_from_checkpoint`` without re-passing them.
        self.save_hyperparameters()
        self.model = SegResNet(in_channels=in_channels, out_channels=out_channels)
        self.loss_fn = DiceLoss(
            smooth_nr=0,
            smooth_dr=1e-5,
            squared_pred=True,
            # Labels are used as-is (no one-hot conversion); they must already have
            # one channel per output class.
            to_onehot_y=False,
            sigmoid=True,
        )
        self.learning_rate = learning_rate

    def forward(self, x: Tensor) -> Tensor:
        """Return raw (pre-sigmoid) segmentation logits for ``x``."""
        return self.model(x)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute and log the Dice loss for one training batch."""
        loss, _ = self._common_step(batch, batch_idx)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute and log the Dice loss for one validation batch.

        Without this hook Lightning skips the data module's validation loader.
        """
        loss, _ = self._common_step(batch, batch_idx)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute and log the Dice loss for one test batch."""
        loss, _ = self._common_step(batch, batch_idx)
        self.log("test_loss", loss)
        return loss

    def _common_step(
        self, batch: Tuple[Tensor, Tensor], batch_idx: int
    ) -> Tuple[Tensor, Tensor]:
        """Run the network on ``(images, labels)`` and return ``(loss, logits)``."""
        images, labels = batch
        preds = self(images)
        loss = self.loss_fn(preds, labels)
        return loss, preds

    def predict_step(
        self, batch: Tuple[Tensor, Tensor], batch_idx: int
    ) -> Tuple[Tensor, Tensor]:
        """Return ``(logits, labels)`` for an ``(images, labels)`` batch."""
        images, labels = batch
        return self(images), labels

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Adam over all parameters at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
class EnsembleModel(pl.LightningModule):
    """Average the outputs of several models applied to the same input.

    Args:
        model_list: Iterable of ``torch.nn.Module`` members. Every member must accept
            the same input and return tensors of identical shape so they can be stacked.
        num_classes: Number of output classes; stored on the instance but not used
            by ``forward``.

    Raises:
        ValueError: If ``model_list`` is empty.
    """

    def __init__(self, model_list, num_classes):
        super().__init__()
        # ModuleList (unlike a plain list) registers the members as submodules, so
        # .to(device), .eval(), .parameters() and state_dict() reach them.
        self.models = torch.nn.ModuleList(model_list)
        if len(self.models) == 0:
            raise ValueError("EnsembleModel needs at least one model")
        self.num_classes = num_classes

    def forward(self, x):
        """Return the element-wise mean of every member's output for ``x``.

        The raw member outputs are averaged (for SegResModel these are logits).
        """
        predictions = [model(x) for model in self.models]
        return torch.stack(predictions).mean(dim=0)

In [None]:
from dataloader import BrainTumourDataModule, BrainTumourDataset

# Training volumes and their segmentation masks, relative to the notebook's working directory.
image_path = "../data/BrainTumourData/imagesTr/"
label_path = "../data/BrainTumourData/labelsTr/"

data_module = BrainTumourDataModule(
    data_path=image_path,
    seg_path=label_path,
    img_dim=(8, 8),
)
# Run the data module's preparation hooks eagerly so its datasets exist before training.
for hook in (data_module.prepare_data, data_module.setup):
    hook()

In [None]:
# Train each ensemble member independently, then average them in EnsembleModel.
SEED = 42
IN_CHANNELS = 2  # input channels of each image volume
NUM_CLASSES = 4  # segmentation output channels
NUM_ENSEMBLE_MODELS = 1  # increase to train additional ensemble members
MAX_EPOCHS = 1

# Seed Python, NumPy and torch RNGs so weight init and data shuffling are reproducible.
pl.seed_everything(SEED)

ensemble_members = []
for _ in range(NUM_ENSEMBLE_MODELS):
    member = SegResModel(in_channels=IN_CHANNELS, out_channels=NUM_CLASSES)
    # Build a fresh Trainer per member: a Trainer that already finished `fit` keeps its
    # loop progress, so reusing it for another model may end training immediately.
    trainer = pl.Trainer(max_epochs=MAX_EPOCHS)
    trainer.fit(member, data_module)
    ensemble_members.append(member)

# Later cells inspect the first member directly.
model1 = ensemble_members[0]
ensemble_model = EnsembleModel(ensemble_members, num_classes=NUM_CLASSES)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kristof/Melytanulas/medical-image-segmentation/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/home/kristof/Melytanulas/medical-image-segmentation/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params | Mode 


Epoch 0: 100%|██████████| 328/328 [02:25<00:00,  2.25it/s, v_num=38]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 328/328 [02:26<00:00,  2.25it/s, v_num=38]


In [None]:
# Single-volume dataset for a quick check of model1; img_dim must match the data module's.
# NOTE(review): BRATS_001 comes from the same imagesTr/labelsTr folders the data module
# trains on, so it may have been seen during training — confirm it is held out before
# reading the scores below as test performance.
test_ds = BrainTumourDataset(
    image_path,
    label_path,
    ["BRATS_001.nii.gz"],
    img_dim=(8, 8),
)

In [None]:
# The evaluation dataset holds a single volume; split it into the model input
# and its ground-truth mask.
test_data_example = test_ds[0]
[image, label] = test_data_example

In [None]:
import numpy as np

In [None]:
# Prepend a batch axis so the single volume matches the batched layout the model consumes.
# NOTE: this rebinds `image`, so re-running the cell adds another leading axis.
image = torch.tensor(np.expand_dims(np.asarray(image), axis=0))
image.size()

torch.Size([1, 2, 8, 8, 128])

In [None]:
# Prepend a batch axis so the mask lines up with the batched prediction.
# NOTE: this rebinds `label`, so re-running the cell adds another leading axis.
label = torch.tensor(np.expand_dims(np.asarray(label), axis=0))
label.size()

torch.Size([1, 4, 8, 8, 128])

In [None]:
# SegResModel is trained with DiceLoss(sigmoid=True), so the network returns raw logits.
# Metrics computed on logits can fall outside [0, 1], so convert to per-class
# probabilities first. TODO(review): confirm whether the metrics helpers expect
# probabilities or thresholded masks.
# eval() + no_grad(): inference only, no autograd graph.
model1.eval()
with torch.no_grad():
    pred = torch.sigmoid(model1(image))

In [None]:
from metrics import *

In [None]:
# Inspect the prediction tensor's layout before reordering its axes.
pred.size()

torch.Size([1, 4, 8, 8, 128])

In [None]:
# Move the last axis to position 1: (1, 4, 8, 8, 128) -> (1, 128, 4, 8, 8).
# NOTE(review): presumably the layout the metrics helpers expect — confirm.
pred = pred.movedim(4, 1)

In [None]:
# Confirm the reordered prediction layout.
pred.size()

torch.Size([1, 128, 4, 8, 8])

In [None]:
# The mask should still be in the same layout the prediction had before its reorder.
label.size()

torch.Size([1, 4, 8, 8, 128])

In [None]:
# Apply the same axis move as the prediction so both tensors line up element-wise.
label = label.movedim(4, 1)

In [None]:
# Score the prediction against the ground truth, per class and aggregated.
dice_scores = dice_score(pred, label)
mean_dsc = mean_dice_score(dice_scores)
recall_scores, precision_scores = recall_precision(pred, label)
# alpha has one weight fewer than there are classes; weighted_recall appears to weight
# classes 1..3 only (background excluded). NOTE(review): confirm in the metrics module.
weighted_recall_score = weighted_recall(recall_scores, alpha=[0.2, 0.3, 0.5])
confusion_matrix = compute_confusion_matrix(pred, label)

summary_rows = [
    ("Dice Scores for each class:", dice_scores),
    ("Mean Dice Score:", mean_dsc),
    ("Recall Scores for each class:", recall_scores),
    ("Precision Scores for each class:", precision_scores),
    ("Weighted Recall:", weighted_recall_score),
]
for caption, value in summary_rows:
    print(caption, value)
print(confusion_matrix)

Dice Scores for each class: [1.192777156829834, -0.006856401450932026, -0.009154651314020157, -0.000811282021459192]
Mean Dice Score: 0.29398870551085565
Recall Scores for each class: [1.4979944229125977, 1.9840309619903564, 1.4429582357406616, 0.200379878282547]
Precision Scores for each class: [0.9908838272094727, -0.0034222870599478483, -0.004562851507216692, -0.0004048215050715953]
Weighted Recall: 0.9298836022615433
[[7999    0   90    0]
 [   8    0   26    0]
 [  18    0   21    0]
 [  24    0    6    0]]
