# In[ ]:
import os
import torch
from torch.utils.data import DataLoader
from src.dataset import ISLES2p5DDataset
from src.augmentations import get_train_augmentations
from src.model import get_unet
from src.ensemble import load_ensemble, ensemble_predict
from src.utils import dice_score, plot_sample

# 1. Paths and run-time hyperparameters
DATA_DIR = "../data"  # point this at your local data folder
N_MODELS = 3  # number of independently trained ensemble members
EPOCHS = 2  # full passes over the training loader per member
BATCH_SIZE = 8
# Prefer the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 2. Prepare dataset and dataloader
# The transform comes from get_train_augmentations, the augmentation factory
# imported from src.augmentations at the top of the file.
train_dataset = ISLES2p5DDataset(
    root_dir=DATA_DIR,
    modalities=['dwi', 'adc', 'flair'],
    slice_axis=2,
    slice_depth=3,
    transform=get_train_augmentations(),
    # Keep this a fixed size, e.g. (128, 128) or (192, 192); presumably so the
    # DataLoader's default collation can stack samples into one batch tensor.
    resize=(128, 128),
)

# Shuffled loader; num_workers=2 loads batches in background worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# 3. Train multiple base models for ensemble
# The loss module holds no per-batch state, so one instance is built up front and
# reused instead of being re-created on every iteration of the inner loop.
# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes get_unet's
# output already passes through a sigmoid — confirm; otherwise use BCEWithLogitsLoss.
bce = torch.nn.BCELoss()
for i in range(N_MODELS):
    # Each ensemble member gets a freshly constructed model and its own optimizer.
    # in_channels=9 is presumably 3 modalities x slice_depth 3 from the dataset — TODO confirm.
    model = get_unet(in_channels=9).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(EPOCHS):
        model.train()
        for x, y in train_loader:
            x, y = x.to(DEVICE, dtype=torch.float), y.to(DEVICE, dtype=torch.float)
            # Add a channel axis to the mask, presumably to match a (B, 1, H, W)
            # model output — TODO confirm against get_unet.
            y = y.unsqueeze(1)
            optimizer.zero_grad()
            out = model(x)
            # Equal-weight BCE + (1 - Dice) loss. Backpropagating through dice_score
            # assumes it is a differentiable (soft) Dice — TODO confirm in src.utils.
            loss = 0.5 * bce(out, y) + 0.5 * (1 - dice_score(out, y))
            loss.backward()
            optimizer.step()
    # One checkpoint per member, written to the current working directory.
    torch.save(model.state_dict(), f"base_model_{i}.pth")
    print(f"Saved base_model_{i}.pth")

# 4. Reload the checkpoints written by the training loop as one ensemble
model_paths = list(map(lambda k: f"base_model_{k}.pth", range(N_MODELS)))
ensemble_models = load_ensemble(model_paths, DEVICE)

# 5. Ensemble prediction on a batch
# NOTE: this batch is drawn from the shuffled training loader, so the Dice printed
# below is a sanity check, not a held-out evaluation.
x, y = next(iter(train_loader))
x = x.to(DEVICE, dtype=torch.float)
# Inference only: no autograd graph is needed, which saves memory and time.
with torch.no_grad():
    final_mask, averaged_probs = ensemble_predict(ensemble_models, x)

# 6. Evaluate and visualize
# The loader yields CPU labels; put them on the prediction's device (as float, like
# in training) so dice_score does not fail with a CPU/CUDA device mismatch.
y_eval = y.to(final_mask.device, dtype=torch.float)
for i in range(min(3, x.shape[0])):  # Show up to 3 samples
    # dice_score yields a scalar (the training loss built from it is backpropagated),
    # so convert it to a plain float for readable output.
    dice = float(dice_score(final_mask[i], y_eval[i]))
    print(f"Dice score (ensemble): {dice:.4f}")
    plot_sample(x[i].cpu(), y[i].cpu(), final_mask[i].cpu())
