In [None]:
import os
import sys

# Resolve the project root as the parent of the current working directory
# (assumes the kernel was started from a sub-folder of the project - TODO confirm).
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Make the modules under `<project_root>/src` importable. The membership check
# keeps the cell idempotent: re-running it no longer appends duplicate entries.
src_path = os.path.join(project_root, "src")
if src_path not in sys.path:
    sys.path.append(src_path)

In [None]:
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib.axes import Axes
from monai.losses import DiceLoss
from monai.networks.nets import SegResNet, UNETR
from torch import Tensor
from tqdm import tqdm

from dataloader import BrainTumourDataModule

import wandb

In [None]:
# Log in to Weights & Biases before any training cell runs, since the trainers
# below attach a WandbLogger to every run.
wandb.login()

## Wrapping SegResNet

The input data must have the shape (B, N, H, W, D), where:

- B - batch size
- N - number of channels (input channels for the images, number of classes for the labels and predictions)
- H - height
- W - width
- D - depth


In [None]:
class SegResModel(pl.LightningModule):
    """Lightning wrapper that trains a MONAI ``SegResNet`` with a softmax Dice loss."""

    def __init__(self, in_channels, out_channels, learning_rate=1e-3):
        """Create the network, the loss and store the optimizer step size.

        Args:
            in_channels: Number of channels in the input volumes.
            out_channels: Number of channels produced by the network.
            learning_rate: Learning rate handed to AdamW.
        """
        super().__init__()
        self.model = SegResNet(in_channels=in_channels, out_channels=out_channels)
        self.dice_loss = DiceLoss(softmax=True)
        self.learning_rate = learning_rate

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an AdamW optimizer over all parameters of this module."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped SegResNet on ``x`` and return its raw output."""
        return self.model(x)

    def _common_step(
        self, batch: Tuple[Tensor, Tensor], batch_idx: int
    ) -> Tuple[Tensor, Tensor]:
        """Shared forward pass: returns ``(dice_loss, predictions)`` for one batch."""
        inputs, targets = batch
        outputs = self.forward(inputs)
        return self.dice_loss(outputs, targets), outputs

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute the Dice loss for a training batch and log it as ``train_loss``."""
        step_loss = self._common_step(batch, batch_idx)[0]
        self.log("train_loss", step_loss, on_epoch=True)
        return step_loss

    def validation_step(self, batch, batch_idx):
        """Compute the Dice loss for a validation batch and log it as ``val_loss``."""
        step_loss = self._common_step(batch, batch_idx)[0]
        self.log("val_loss", step_loss, prog_bar=True)
        return step_loss

In [None]:
class UNetModel(pl.LightningModule):
    """Lightning wrapper that trains a MONAI ``UNETR`` with a softmax Dice loss."""

    def __init__(
        self,
        in_channels,
        out_channels,
        learning_rate=1e-3,
        img_size=(128, 128, 128),
    ):
        """Create the network, the loss and store the optimizer step size.

        Args:
            in_channels: Number of channels in the input volumes.
            out_channels: Number of channels produced by the network.
            learning_rate: Learning rate handed to Adam.
            img_size: Spatial size of the input volumes passed to ``UNETR``.
                Defaults to the previously hard-coded ``(128, 128, 128)``.
        """
        super().__init__()
        # Forward the constructor arguments instead of hard-coding 4/4, so the
        # values given by the caller are actually honoured.
        self.model = UNETR(
            in_channels=in_channels,
            out_channels=out_channels,
            img_size=img_size,
        )
        self.dice_loss = DiceLoss(softmax=True)
        self.learning_rate = learning_rate

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped UNETR on ``x`` and return its raw output."""
        return self.model(x)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute the Dice loss for a training batch and log it as ``train_loss``."""
        loss, _ = self._common_step(batch, batch_idx)

        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute the Dice loss for a validation batch and log it as ``val_loss``.

        The training cell monitors ``val_loss`` with ``ModelCheckpoint``; without
        this step the metric would never be produced for this model.
        """
        loss, _ = self._common_step(batch, batch_idx)

        self.log("val_loss", loss, prog_bar=True)
        return loss

    def _common_step(
        self, batch: Tuple[Tensor, Tensor], batch_idx: int
    ) -> Tuple[Tensor, Tensor]:
        """Shared forward pass: returns ``(dice_loss, predictions)`` for one batch."""
        images, labels = batch
        preds = self.forward(images)
        loss = self.dice_loss(preds, labels)
        return loss, preds

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters of this module."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
class EnsembleModel(pl.LightningModule):
    """Inference-time ensemble that averages the outputs of several models."""

    def __init__(self, model_list, num_classes):
        """Store the ensemble members.

        Args:
            model_list: Non-empty iterable of modules that accept the same input
                and return tensors of identical shape.
            num_classes: Number of segmentation classes (stored, not used by
                ``forward``).

        Raises:
            ValueError: If ``model_list`` is empty.
        """
        super().__init__()
        members = list(model_list)
        if not members:
            raise ValueError("EnsembleModel needs at least one model")

        # A ModuleList (rather than a plain list) registers the members as
        # sub-modules, so `.to(device)`, `.parameters()` and `state_dict()`
        # on the ensemble also cover them.
        self.models = torch.nn.ModuleList(members)
        self.num_classes: int = num_classes

    def forward(self, x):
        """Return the element-wise mean of every member's output for ``x``."""
        # Members are always evaluated in eval mode, whatever mode they were left in.
        for model in self.models:
            model.eval()

        predictions = [model(x) for model in self.models]
        averaged_prediction = torch.mean(torch.stack(predictions), dim=0)
        return averaged_prediction

## Data Module Loading


In [None]:
# Directories holding the volumes and their segmentation masks
# (relative to the current working directory).
image_path = "../data/BrainTumourData/imagesTr/"
label_path = "../data/BrainTumourData/labelsTr/"

# Settings forwarded to the data module.
img_dim = (128, 128)
batch_size = 1

data_module = BrainTumourDataModule(
    data_path=image_path,
    seg_path=label_path,
    img_dim=img_dim,
    batch_size=batch_size,
)

# Call the Lightning data hooks by hand; a later cell reads
# `data_module.test_dataloader()` directly, outside of a Trainer.
data_module.prepare_data()
data_module.setup()

## Training the Model


In [None]:
# SegResNet wrapper with 4 input and 4 output channels.
segresnet = SegResModel(in_channels=4, out_channels=4)

# Fresh W&B logger and a checkpoint callback tracking the lowest `val_loss`.
wandb_logger = pl.loggers.WandbLogger(
    project="medical-image-segmentation",
    log_model="all",
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
)

# Train for 20 epochs, then close the W&B run so the next model gets its own.
trainer = pl.Trainer(
    max_epochs=20,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)
trainer.fit(segresnet, data_module)
wandb.finish()

In [None]:
# UNETR wrapper with 4 input and 4 output channels.
unet = UNetModel(in_channels=4, out_channels=4)

# Fresh W&B logger and a checkpoint callback tracking the lowest `val_loss`.
wandb_logger = pl.loggers.WandbLogger(
    project="medical-image-segmentation",
    log_model="all",
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
)

# Train for 10 epochs, then close the W&B run.
trainer = pl.Trainer(
    max_epochs=10,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)
trainer.fit(unet, data_module)
wandb.finish()

In [None]:
# Combine the two trained models; EnsembleModel averages their outputs in forward().
ensemble_model = EnsembleModel([segresnet, unet], num_classes=4)

## Running Tests


In [None]:
# Per-batch results, concatenated along the batch axis once the loop is done.
image_list: List[Tensor] = []
prediction_list: List[Tensor] = []
label_list: List[Tensor] = []

# Inference only: autograd is disabled for the whole pass. The loop variables use
# `batch_*` names so they do not shadow the concatenated `images` / `labels`
# / `predictions` tensors defined below.
with torch.no_grad():
    for batch_images, batch_labels in tqdm(data_module.test_dataloader()):
        batch_preds: Tensor = ensemble_model(batch_images)
        image_list.append(batch_images)
        prediction_list.append(batch_preds)
        label_list.append(batch_labels)

# Fail with a readable message instead of torch.cat's error on an empty list.
if not image_list:
    raise RuntimeError("The test dataloader did not yield any batches")

images: Tensor = torch.cat(image_list, dim=0)
predictions: Tensor = torch.cat(prediction_list, dim=0)
labels: Tensor = torch.cat(label_list, dim=0)

print(f"images shape: {images.shape}")
print(f"predictions shape: {predictions.shape}")
print(f"labels shape: {labels.shape}")

## Prediction Visualization


In [None]:
def plot_slices(image: Tensor, label: Tensor, pred: Tensor, slice_index: int) -> None:
    """Draw one slice of the scan, its ground truth and the prediction side by side.

    Args:
        image: Volume indexed as ``[channel, :, :, slice]``; only channel 0 is drawn.
        label: Ground-truth volume, reduced to class ids with an argmax over axis 0.
        pred: Prediction volume, reduced the same way as ``label``.
        slice_index: Index along the last axis of the slice to display.
    """

    def _class_map_slice(volume: Tensor) -> np.ndarray:
        # Collapse axis 0 to class ids, then cut out the requested slice.
        return torch.argmax(volume, dim=0)[:, :, slice_index].cpu().numpy()

    # One (data, colour map, title) entry per subplot, in left-to-right order.
    panels = (
        (image[0, :, :, slice_index].cpu().numpy(), "gray", "Original MRI Image"),
        (_class_map_slice(label), "jet", "Ground Truth"),
        (_class_map_slice(pred), "jet", "Predictions"),
    )

    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (data, cmap, title) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")

    plt.show()

In [None]:
# Pick one test sample and pull out its scan, ground truth and ensemble output.
idx = 0
image, label, pred = images[idx], labels[idx], predictions[idx]

print(f"image shape: {image.shape}")
print(f"labels shape: {label.shape}")
print(f"preds shape: {pred.shape}")

In [None]:
# Show the middle slice: image.shape[3] is the size of the axis that
# plot_slices indexes with slice_index.
plot_slices(image, label, pred, slice_index=image.shape[3] // 2)