# 1. Setup & Environment

In [1]:
import os, glob, random, logging, sys, time
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm

logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=[logging.StreamHandler(sys.stdout)])
LOG = logging.getLogger("BraTS_Metrics")

SEED, DEVICE = 42, torch.device("cuda" if torch.cuda.is_available() else "cpu")
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

PATCH_SIZE, BATCH_SIZE, EPOCHS, LR_START = (128, 128, 128), 12, 2, 0.01
TRAIN_BASE = "/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"

# 2. Data Discovery

In [2]:
def find_cases(base):
    start_time = time.time()
    cases = []
    found_dirs = sorted([d for d in glob.glob(os.path.join(base, "BraTS20_Training_*")) if os.path.isdir(d)])
    
    for d in found_dirs:
        case_id = os.path.basename(d)
        m = {}
        missing = False
        for k in ["flair", "t1", "t1ce", "t2", "seg"]:
            files = glob.glob(os.path.join(d, f"*{k}.nii*"))
            if files:
                m[k] = files[0]
            else:
                missing = True
                break
        if not missing:
            cases.append((case_id, m))
    
    LOG.info(f"--- Data Discovery ---")
    LOG.info(f"Total Folders Scanned: {len(found_dirs)}")
    LOG.info(f"Valid Cases Found: {len(cases)}")
    LOG.info(f"Discovery Time: {time.time() - start_time:.2f}s")
    return cases

train_cases = find_cases(TRAIN_BASE)
train_split, val_split = train_test_split(train_cases, test_size=0.1, random_state=SEED)

--- Data Discovery ---
Total Folders Scanned: 369
Valid Cases Found: 368
Discovery Time: 2.53s


# 3. Dataset Class

In [3]:
class BraTSDataset(Dataset):
    def __init__(self, cases, samples=1000):
        self.cases = cases
        self.samples = samples

    def __len__(self): return self.samples

    def __getitem__(self, _):
        cid, mf = random.choice(self.cases)
        imgs = []
        for k in ["flair", "t1", "t1ce", "t2"]:
            data = nib.load(mf[k]).get_fdata().astype(np.float32)
            mask = data != 0
            if mask.sum() > 0:
                data[mask] = (data[mask] - data[mask].mean()) / (data[mask].std() + 1e-8)
            imgs.append(data)
        
        x = np.stack(imgs)
        seg = nib.load(mf["seg"]).get_fdata()
        y = np.stack([(seg > 0), ((seg == 1) | (seg == 4)), (seg == 4)]).astype(np.float32)

        d, h, w = x.shape[1:]
        pd, ph, pw = PATCH_SIZE
        if random.random() < 0.66:
            z, i, j = np.where(y[0] > 0)
            idx = random.randint(0, len(z)-1) if len(z) > 0 else 0
            sd, sh, sw = np.clip(z[idx]-pd//2, 0, d-pd), np.clip(i[idx]-ph//2, 0, h-ph), np.clip(j[idx]-pw//2, 0, w-pw)
        else:
            sd, sh, sw = random.randint(0, d-pd), random.randint(0, h-ph), random.randint(0, w-pw)

        x_p, y_p = x[:, sd:sd+pd, sh:sh+ph, sw:sw+pw], y[:, sd:sd+pd, sh:sh+ph, sw:sw+pw]
        
        for axis in [1, 2, 3]:
            if random.random() > 0.5:
                x_p, y_p = np.flip(x_p, axis).copy(), np.flip(y_p, axis).copy()
        
        return torch.from_numpy(x_p), torch.from_numpy(y_p)

# 4. Model Architecture & Loss

In [4]:
class ConvBlock(nn.Module):
    def __init__(self, i, o):
        super().__init__()
        self.c = nn.Sequential(nn.Conv3d(i, o, 3, 1, 1, bias=False), nn.InstanceNorm3d(o), nn.LeakyReLU(0.01, True))
    def forward(self, x): return self.c(x)

class nnUNet(nn.Module):
    def __init__(self):
        super().__init__()
        f = [32, 64, 128, 256]
        self.e1, self.e2, self.e3 = ConvBlock(4, f[0]), ConvBlock(f[0], f[1]), ConvBlock(f[1], f[2])
        self.pool = nn.MaxPool3d(2)
        self.bottleneck = ConvBlock(f[2], f[3])
        self.u3, self.u2, self.u1 = nn.ConvTranspose3d(f[3], f[2], 2, 2), nn.ConvTranspose3d(f[2], f[1], 2, 2), nn.ConvTranspose3d(f[1], f[0], 2, 2)
        self.d3, self.d2, self.d1 = ConvBlock(f[3], f[2]), ConvBlock(f[2], f[1]), ConvBlock(f[1], f[0])
        self.out, self.ds2, self.ds3 = nn.Conv3d(f[0], 3, 1), nn.Conv3d(f[1], 3, 1), nn.Conv3d(f[2], 3, 1)

    def forward(self, x):
        s1 = self.e1(x); s2 = self.e2(self.pool(s1)); s3 = self.e3(self.pool(s2))
        b = self.bottleneck(self.pool(s3))
        d3 = self.d3(torch.cat([self.u3(b), s3], 1))
        d2 = self.d2(torch.cat([self.u2(d3), s2], 1))
        d1 = self.d1(torch.cat([self.u1(d2), s1], 1))
        return self.out(d1), self.ds2(d2), self.ds3(d3)

def hybrid_loss(outputs, target):
    def region_loss(p, y):
        p_sig = torch.sigmoid(p)
        dice = 1 - (2*(p_sig*y).sum()+1e-6)/(p_sig.sum()+y.sum()+1e-6)
        return dice + F.binary_cross_entropy_with_logits(p, y)
    return region_loss(outputs[0], target) + 0.5*region_loss(outputs[1], F.interpolate(target, outputs[1].shape[2:], mode='nearest')) + 0.25*region_loss(outputs[2], F.interpolate(target, outputs[2].shape[2:], mode='nearest'))

# 5. Training & Evaluation

In [5]:
train_loader = DataLoader(BraTSDataset(train_split), batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
val_loader = DataLoader(BraTSDataset(val_split, samples=40), batch_size=1)

model = nn.DataParallel(nnUNet()).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR_START, momentum=0.99, nesterov=True)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda ep: (1 - ep/EPOCHS)**0.9)
scaler = torch.amp.GradScaler(enabled=True)

for epoch in range(EPOCHS):
    model.train()
    epoch_start = time.time()
    train_losses = []
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for x, y in pbar:
        optimizer.zero_grad()
        with torch.amp.autocast(device_type="cuda"):
            outputs = model(x.to(DEVICE))
            loss = hybrid_loss(outputs, y.to(DEVICE))
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        train_losses.append(loss.item())
        pbar.set_postfix({"Loss": f"{loss.item():.4f}", "LR": f"{optimizer.param_groups[0]['lr']:.5f}"})

    model.eval()
    metrics = {"WT": [], "TC": [], "ET": []}
    with torch.no_grad():
        for vx, vy in val_loader:
            v_out, _, _ = model(vx.to(DEVICE))
            p = (torch.sigmoid(v_out) > 0.5).float().cpu()
            for i, label in enumerate(["WT", "TC", "ET"]):
                dice = (2.*(p[:,i]*vy[:,i]).sum()/(p[:,i].sum()+vy[:,i].sum()+1e-6)).item()
                metrics[label].append(dice)

    LOG.info(f"\n--- Epoch {epoch+1} Summary ---")
    LOG.info(f"Mean Loss: {np.mean(train_losses):.4f} | Time: {time.time()-epoch_start:.1f}s")
    LOG.info(f"Dice Scores -> WT: {np.mean(metrics['WT']):.3f} | TC: {np.mean(metrics['TC']):.3f} | ET: {np.mean(metrics['ET']):.3f}")
    scheduler.step()

Epoch 1: 100%|██████████| 84/84 [10:06<00:00,  7.22s/it, Loss=0.4744, LR=0.01000]



--- Epoch 1 Summary ---
Mean Loss: 1.4074 | Time: 674.1s
Dice Scores -> WT: 0.832 | TC: 0.646 | ET: 0.636


Epoch 2: 100%|██████████| 84/84 [10:51<00:00,  7.76s/it, Loss=0.5717, LR=0.00536] 



--- Epoch 2 Summary ---
Mean Loss: 0.4405 | Time: 720.2s
Dice Scores -> WT: 0.847 | TC: 0.753 | ET: 0.761


# 6. Save Metrics

In [6]:
save_path = "/kaggle/working/nnunet_brats.pth"
save_start = time.time()
torch.save(model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(), save_path)

LOG.info(f"--- Model Saving ---")
LOG.info(f"Saved to: {save_path}")
LOG.info(f"File Size: {os.path.getsize(save_path)/1024**2:.2f} MB")
LOG.info(f"Save Latency: {time.time()-save_start:.2f}s")

--- Model Saving ---
Saved to: /kaggle/working/nnunet_brats.pth
File Size: 10.20 MB
Save Latency: 0.18s
