In [4]:
import os
import torch
from torch.utils.data import DataLoader
from src.dataset import ISLESDataset3D
from src.augmentations import get_train_augmentations
from src.model import get_unet
from src.ensemble import load_ensemble, ensemble_predict
from src.utils import dice_score, plot_sample

print(f"Imports done")
# 1. Set up paths and parameters
DATA_DIR = "../data"  # or your actual data folder path
N_MODELS = 1
EPOCHS = 1
BATCH_SIZE = 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Prepare dataset and dataloader")
train_dataset = ISLESDataset3D(
    root_dir=DATA_DIR)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
print(f"Number of samples in dataset: {len(train_dataset)}")
print(f"Number of batches in DataLoader: {len(train_loader)}")

print(f"#Train multiple base models for ensemble")
for i in range(N_MODELS):
    model = get_unet(in_channels=9).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(EPOCHS):
        model.train()
        print(f"Training model {i}, epoch {epoch+1}/{EPOCHS}")
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(DEVICE, dtype=torch.float), y.to(DEVICE, dtype=torch.float)
            y = y.unsqueeze(1)
            optimizer.zero_grad()
            out = model(x)
            bce = torch.nn.BCELoss()
            loss = 0.5 * bce(out, y) + 0.5 * (1 - dice_score(out, y))
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f"  Model {i} Epoch {epoch+1} Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.4f}")
    torch.save(model.state_dict(), f"base_model_{i}.pth")
    print(f"Saved base_model_{i}.pth")

print(f"#Load ensemble models")
model_paths = [f"base_model_{i}.pth" for i in range(N_MODELS)]
ensemble_models = load_ensemble(model_paths, DEVICE)

print(f"#Ensemble prediction on a batch")
x, y = next(iter(train_loader))
x = x.to(DEVICE, dtype=torch.float)
final_mask, averaged_probs = ensemble_predict(ensemble_models, x)

print(f"#Evaluate and visualize")
for i in range(min(3, x.shape[0])):  # Show up to 3 samples
    print(f"Dice score (ensemble): {dice_score(final_mask[i], y[i])}")
    plot_sample(x[i].cpu(), y[i].cpu(), final_mask[i].cpu())


  from .autonotebook import tqdm as notebook_tqdm


Imports done
Prepare dataset and dataloader
entering 3D samples
Total 3D samples: 248
Number of samples in dataset: 248
Number of batches in DataLoader: 248
#Train multiple base models for ensemble
Training model 0, epoch 1/1


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 2, 64, 64, 16]