In [None]:
import os
import sys

import pytorch_lightning as pl
import torch
from monai.losses import DiceLoss
from monai.networks.nets import UNETR, SegResNet

from dataloader import BrainTumourDataModule

In [None]:
# Make the project's `src/` directory importable from this notebook.
# The parent of the current working directory is treated as the project root.
# NOTE(review): `dataloader` is imported in the first cell, before this path is
# added; if that module lives under `src/`, the import only succeeds on a fresh
# kernel when this setup runs first — confirm and reorder the cells if needed.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
src_dir = os.path.join(project_root, "src")
# Guard the append so re-running this cell in the same kernel does not pile up
# duplicate entries on sys.path.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
class SegResModel(pl.LightningModule):
    """Lightning wrapper that trains a MONAI ``SegResNet`` with a Dice loss.

    Args:
        in_channels: Forwarded to ``SegResNet`` as ``in_channels``.
        out_channels: Forwarded to ``SegResNet`` as ``out_channels``.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, in_channels, out_channels, learning_rate=1e-3):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = SegResNet(in_channels=in_channels, out_channels=out_channels)
        # Dice loss configuration: the loss applies the sigmoid itself
        # (sigmoid=True), so the network is trained on raw outputs, and the
        # labels are used as provided (to_onehot_y=False).
        dice_options = dict(
            smooth_nr=0,
            smooth_dr=1e-5,
            squared_pred=True,
            to_onehot_y=False,
            sigmoid=True,
        )
        self.loss_fn = DiceLoss(**dice_options)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw output."""
        network_output = self.model(x)
        return network_output

    def training_step(self, batch, batch_idx):
        """Compute and log the Dice loss for one ``(images, labels)`` batch."""
        inputs, targets = batch
        train_loss = self.loss_fn(self.model(inputs), targets)
        self.log("train_loss", train_loss)
        return train_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
class EnsembleModel(pl.LightningModule):
    """Ensemble that averages the raw outputs of several models.

    Args:
        model_list: Iterable of ``torch.nn.Module`` instances whose outputs
            have identical shapes for the same input.
        num_classes: Number of output classes. Stored on the instance; this
            class does not otherwise use it.
    """

    def __init__(self, model_list, num_classes):
        super().__init__()
        # A plain Python list hides the members from PyTorch: they would not
        # follow `.to(device)`, `.eval()` or `state_dict()` of the ensemble.
        # ModuleList registers them as proper sub-modules while still
        # supporting iteration, indexing and `append`.
        self.models = torch.nn.ModuleList(model_list)
        self.num_classes = num_classes

    def forward(self, x):
        """Return the element-wise mean of every member's output for ``x``."""
        # Collect predictions from each model in the ensemble
        predictions = [model(x) for model in self.models]
        # Stack along a new leading axis and average over the members
        return torch.mean(torch.stack(predictions), dim=0)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Unpack the batch and run the ensemble on the images only.

        Lightning's default ``predict_step`` passes the whole batch to
        ``forward``. The dataloader used here yields ``[tensor, tensor]``
        batches, so the list itself reached ``conv3d`` and raised a TypeError.
        """
        # Assumes the images are the first element, as in the training batches
        # (`images, labels = batch`) — TODO confirm for the test dataloader.
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(images)

In [None]:
# Locations of the training images and their segmentation labels,
# relative to the notebook's working directory.
image_path, label_path = (
    "../data/BrainTumourData/imagesTr/",
    "../data/BrainTumourData/labelsTr/",
)

# Build the data module and run its preparation/setup hooks explicitly so the
# dataloaders can be requested directly in later cells.
data_module = BrainTumourDataModule(
    img_dim=(8, 8),
    seg_path=label_path,
    data_path=image_path,
)
data_module.prepare_data()
data_module.setup()

In [None]:
# Build the (currently single) ensemble member: 2 input channels, 4 output channels.
model1 = SegResModel(in_channels=2, out_channels=4)

# Fit the member for one epoch on the data module.
# Only one member is trained here; further members can be added by training
# more SegResModel instances and including them in the list below.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model1, data_module)

# Wrap the trained member(s) in the averaging ensemble.
ensemble_model = EnsembleModel(model_list=[model1], num_classes=4)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params | Mode 
----------------------------------------------
0 | model   | SegResNet | 1.2 M  | train
1 | loss_fn | DiceLoss  | 0      | train
----------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.706     Total estimated model params size (MB)
137       Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 328/328 [01:50<00:00,  2.97it/s, v_num=27]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 328/328 [01:50<00:00,  2.97it/s, v_num=27]


In [None]:
# Run ensemble inference over the test split; `trainer.predict` returns the
# collected per-batch outputs.
test_dataloader = data_module.test_dataloader()
ensemble_predictions = trainer.predict(
    model=ensemble_model,
    dataloaders=test_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0:   0%|          | 0/59 [02:58<?, ?it/s]


TypeError: conv3d() received an invalid combination of arguments - got (list, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, tuple of ints padding = 0, tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!list of [Tensor, Tensor]!, !Parameter!, !NoneType!, !tuple of (int, int, int)!, !tuple of (int, int, int)!, !tuple of (int, int, int)!, !int!)
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, str padding = "valid", tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!list of [Tensor, Tensor]!, !Parameter!, !NoneType!, !tuple of (int, int, int)!, !tuple of (int, int, int)!, !tuple of (int, int, int)!, !int!)
