In [None]:
%pip install nibabel
%pip install opencv-python
%pip install albumentations
%pip install segmentation_models_pytorch 



import os
import torch
from torch.utils.data import DataLoader
from src.dataset import ISLESDataset3D
from src.augmentations import get_train_augmentations
from src.model import get_unet
from src.ensemble import load_ensemble, ensemble_predict
from src.utils import dice_score, plot_sample
import torch.nn.functional as F
import warnings

def safe_unsqueeze_mask(y):
    """Return the mask ``y`` shaped ``[B, 1, D, H, W]``.

    - 4-D input is treated as ``[B, D, H, W]``; a channel axis is inserted at dim 1.
    - 5-D input with more than one channel keeps only the first channel.
    - 5-D input that already has a single channel is returned unchanged.

    Raises:
        ValueError: if ``y`` has fewer than 4 or more than 5 dimensions.
    """
    if y.ndim < 4:
        raise ValueError(f"Mask shape too small: {y.shape}")
    if y.ndim > 5:
        # Passing such a tensor through would violate the [B, 1, D, H, W]
        # contract and fail later with a much less obvious error.
        raise ValueError(f"Mask shape too large: {y.shape}")
    if y.ndim == 4:
        return y.unsqueeze(1)
    # 5-D: keep only the first channel if there are several.
    if y.shape[1] != 1:
        return y[:, :1, ...]
    return y

def safe_pad_or_crop(x, target_shape):
    """Zero-pad and/or crop the last three axes of ``x`` to ``target_shape``.

    Args:
        x: a 5-D tensor ``[B, C, D, H, W]``, or a 4-D tensor which is treated as
            ONE unbatched sample ``[C, D, H, W]`` and given a leading batch axis
            of size 1. A batched channel-less ``[B, D, H, W]`` tensor is NOT
            recognised as such (it would become ``[1, B, D, H, W]``); add the
            channel axis first, e.g. with ``unsqueeze(1)``.
        target_shape: ``(tD, tH, tW)``, the spatial size to produce.

    Returns:
        A 5-D tensor ``[B, C, tD, tH, tW]``. Zeros are appended at the end of
        any axis that is too short; axes that are too long keep their leading
        ``t*`` elements.

    Raises:
        ValueError: if ``x`` is not 4-D or 5-D, or ``target_shape`` does not
            have exactly three entries.
    """
    if x.ndim == 4:
        x = x.unsqueeze(0)  # [C, D, H, W] -> [1, C, D, H, W]
    if x.ndim != 5:
        raise ValueError(f"Input shape not supported: {x.shape}")
    if len(target_shape) != 3:
        raise ValueError(f"target_shape must have 3 entries, got: {target_shape}")
    depth, height, width = x.shape[-3:]
    t_d, t_h, t_w = target_shape
    # F.pad takes (before, after) pairs starting from the LAST axis:
    # (W_before, W_after, H_before, H_after, D_before, D_after).
    pads = [0, max(t_w - width, 0), 0, max(t_h - height, 0), 0, max(t_d - depth, 0)]
    x = F.pad(x, pads)
    return x[:, :, :t_d, :t_h, :t_w]

print("Imports done")

# 1. Set up paths and parameters
DATA_DIR = "../data"  # adjust to the actual data folder location
N_MODELS = 3  # number of base models trained for the ensemble
EPOCHS = 2  # training epochs per base model
BATCH_SIZE = 3
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print("Prepare dataset and dataloader", DEVICE)
def pad_collate(batch):
    """Collate ``(x, y)`` samples of differing spatial size into two batch tensors.

    The largest size of each of the last three axes is taken over all ``x``
    tensors in the batch. Each ``x`` is zero-padded at the END of those axes up
    to that size, and its ``y`` is padded by exactly the same amounts. The
    padded tensors are then stacked along a new leading batch axis.

    NOTE(review): because ``y`` reuses its ``x``'s pad amounts, this assumes each
    mask has the same trailing three-axis size as its input.
    """
    inputs, masks = zip(*batch)
    # Largest extent of axes -3, -2, -1 across the batch's inputs.
    target = [max(t.shape[axis] for t in inputs) for axis in (-3, -2, -1)]

    padded_inputs, padded_masks = [], []
    for inp, msk in zip(inputs, masks):
        # F.pad consumes (before, after) pairs starting from the LAST axis.
        amounts = []
        for axis in (-1, -2, -3):
            amounts.extend((0, target[axis] - inp.shape[axis]))
        amounts = tuple(amounts)
        padded_inputs.append(F.pad(inp, amounts))
        padded_masks.append(F.pad(msk, amounts))
    return torch.stack(padded_inputs), torch.stack(padded_masks)

def pad_or_crop_to_shape(x, target_shape):
    """Zero-pad (at the end) and then crop the last three axes of ``x``.

    ``x`` must be 5-D, read as ``[B, C, D, H, W]`` (the shape unpacking below
    requires exactly five dimensions). The result has spatial size
    ``target_shape = (tD, tH, tW)``.
    """
    _, _, d, h, w = x.shape
    want_d, want_h, want_w = target_shape
    # F.pad consumes (before, after) pairs starting from the last axis.
    padding = []
    for want, have in ((want_w, w), (want_h, h), (want_d, d)):
        padding.extend((0, max(want - have, 0)))
    padded = F.pad(x, padding)
    return padded[:, :, :want_d, :want_h, :want_w]

train_dataset = ISLESDataset3D(root_dir=DATA_DIR)

# Sanity check: load one sample and show which label values its mask contains.
sample_x, sample_y = train_dataset[0]
print("Sample mask unique:", torch.unique(sample_y))

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # load batches in the main process
    collate_fn=pad_collate,  # zero-pads samples to a common size within each batch
)
print("Number of samples in dataset: {}".format(len(train_dataset)))
print("Number of batches in DataLoader: {}".format(len(train_loader)))

print(f"#Train multiple base models for ensemble")
# Loop-invariant: every batch is padded/cropped to this spatial size (the last
# three axes, as handled by safe_pad_or_crop), and one loss module is reused.
target_shape = (128, 128, 64)
bce = torch.nn.BCELoss()

for i in range(N_MODELS):
    model = get_unet(in_channels=2).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(EPOCHS):
        model.train()
        print(f"Training model {i}, epoch {epoch+1}/{EPOCHS}")
        # The batch loop must be nested inside the epoch loop; otherwise each
        # model makes a single pass over the data regardless of EPOCHS.
        for batch_idx, (x, y) in enumerate(train_loader):
            # Inputs: a 4-D x is treated as one unbatched sample; keep at most
            # the 2 channels the model was built with, then fix the spatial size.
            if x.ndim == 4:
                x = x.unsqueeze(0)
            if x.ndim == 5 and x.shape[1] > 2:
                x = x[:, :2, ...]
            x = safe_pad_or_crop(x, target_shape)
            x = x.to(DEVICE, dtype=torch.float)

            # Masks: [B, 1, D, H, W], same spatial size as x, values clamped to [0, 1].
            y = safe_unsqueeze_mask(y)
            y = safe_pad_or_crop(y, target_shape)
            y = y.to(DEVICE, dtype=torch.float)
            y = y.clamp(0, 1)

            if epoch == 0 and batch_idx == 0:
                # Print shapes once per model instead of on every batch.
                print("x shape:", x.shape)
                print("y shape:", y.shape)

            optimizer.zero_grad()
            out = model(x)
            # NOTE(review): BCELoss requires probabilities in [0, 1], so this assumes
            # get_unet ends in a sigmoid; the Dice term only yields gradients if
            # dice_score is a soft (differentiable) Dice — confirm in src/.
            loss = 0.5 * bce(out, y) + 0.5 * (1 - dice_score(out, y))
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f"  Model {i} Epoch {epoch+1} Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.4f}")

    torch.save(model.state_dict(), f"base_model_{i}.pth")
    print(f"Saved base_model_{i}.pth")

print("#Load ensemble models")
# One checkpoint per base model, named as saved by torch.save in the training loop.
model_paths = [f"base_model_{idx}.pth" for idx in range(N_MODELS)]
ensemble_models = load_ensemble(model_paths, DEVICE)

print(f"#Ensemble prediction on a batch")
# Take the first batch from a fresh iterator over the loader.
# NOTE(review): this batch comes from the TRAINING loader, so the Dice scores
# below are not a held-out estimate.
x, y = next(iter(train_loader))
print("Input min/max:", x.min(), x.max())
print("Mask unique values:", torch.unique(y))
x = x.to(DEVICE, dtype=torch.float)
y = y.to(DEVICE, dtype=torch.float)
# Keep only the middle slice along the last axis.
# NOTE(review): the base models were trained on padded 3-D volumes
# ([B, 2, 128, 128, 64] in the training log); confirm that ensemble_predict is
# meant to receive these 2-D slices rather than the full volume.
x = x[..., x.shape[-1] // 2]
y = y[..., y.shape[-1] // 2]
if x.shape[1] > 2:
    x = x[:, :2, ...]  # the base models are built with in_channels=2
with torch.no_grad():
    # Inference only: no autograd graph, so outputs can be moved/plotted directly.
    final_mask, averaged_probs = ensemble_predict(ensemble_models, x)

print(f"#Evaluate and visualize")
for i in range(min(3, x.shape[0])):  # Show up to 3 samples
    print(f"Dice score (ensemble): {dice_score(final_mask[i], y[i])}")
    plot_sample(x[i].cpu().squeeze(), y[i].cpu(), final_mask[i].cpu().squeeze(), channel=0)  # DWI





[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[3

  from .autonotebook import tqdm as notebook_tqdm


Imports done
Prepare dataset and dataloader cpu
entering 3D samples
Total 3D samples: 248
Sample mask unique: tensor([0., 1.])
Number of samples in dataset: 248
Number of batches in DataLoader: 83
#Train multiple base models for ensemble
Training model 0, epoch 1/2
Training model 0, epoch 2/2
x shape: torch.Size([3, 2, 128, 128, 64])
y shape: torch.Size([3, 1, 128, 128, 64])
skip shape: torch.Size([3, 256, 16, 16, 8])
x shape: torch.Size([3, 256, 16, 16, 8])
skip shape: torch.Size([3, 128, 32, 32, 16])
x shape: torch.Size([3, 128, 32, 32, 16])
skip shape: torch.Size([3, 64, 64, 64, 32])
x shape: torch.Size([3, 64, 64, 64, 32])
skip shape: torch.Size([3, 32, 128, 128, 64])
x shape: torch.Size([3, 32, 128, 128, 64])
  Model 0 Epoch 2 Batch 0/83 Loss: 0.8681
x shape: torch.Size([3, 2, 128, 128, 64])
y shape: torch.Size([3, 1, 128, 128, 64])
skip shape: torch.Size([3, 256, 16, 16, 8])
x shape: torch.Size([3, 256, 16, 16, 8])
skip shape: torch.Size([3, 128, 32, 32, 16])
x shape: torch.Size(