# 0) Environment Setup

In [None]:
!pip -q install nibabel tqdm torch

# 1) Imports and Global State

In [None]:
import os, glob, random, logging, sys, time
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Make sure the log directory exists before the FileHandler tries to open it.
os.makedirs("/kaggle/working/logs", exist_ok=True)
# Send every record at INFO or above both to a log file and to stdout.
logging.basicConfig(
    level=logging.INFO, 
    format="%(asctime)s | %(levelname)s | %(message)s", 
    handlers=[
        logging.FileHandler("/kaggle/working/logs/training.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
LOG = logging.getLogger("BraTS_Fast_nnUNet")

# Seed the three RNGs this notebook draws from: Python's `random`, NumPy and torch.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset locations; the training and validation bases are derived from ROOT.
ROOT = "/kaggle/input/brats20-dataset-training-validation"
TRAIN_BASE = f"{ROOT}/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
VAL_BASE = f"{ROOT}/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData"

LOG.info(f"Initialized: {DEVICE}")

# 2) Data Discovery

In [None]:
def find_cases(base, is_val=False):
    """Collect the case folders under ``base`` together with their NIfTI paths.

    Args:
        base: Directory that contains one sub-folder per case.
        is_val: When True, look for ``BraTS20_Validation_*`` folders and
            require only the four image modalities. When False, look for
            ``BraTS20_Training_*`` folders and additionally require the
            segmentation file.

    Returns:
        A list of ``(folder_name, paths)`` tuples ordered by folder path.
        ``paths`` maps ``"flair"``, ``"t1"``, ``"t1ce"``, ``"t2"`` and
        ``"seg"`` to a file path, or to ``None`` when no file matched
        (only possible for ``"seg"`` of a validation case). Folders that
        miss a required file are left out.
    """
    modalities = ("flair", "t1", "t1ce", "t2")
    required = modalities if is_val else modalities + ("seg",)
    pattern = "BraTS20_Validation_*" if is_val else "BraTS20_Training_*"

    cases = []
    for case_dir in sorted(glob.glob(os.path.join(base, pattern))):
        paths = {}
        for key in modalities + ("seg",):
            # glob returns matches in filesystem order; sorting makes the
            # pick reproducible when e.g. both .nii and .nii.gz are present.
            matches = sorted(glob.glob(os.path.join(case_dir, f"*_{key}.nii*")))
            paths[key] = matches[0] if matches else None

        if all(paths[key] is not None for key in required):
            cases.append((os.path.basename(case_dir), paths))

    return cases

# Index both splits once; training cases need a segmentation file, validation cases do not.
train_cases = find_cases(TRAIN_BASE, is_val=False)
val_cases = find_cases(VAL_BASE, is_val=True)
LOG.info(f"Dataset Structure: {len(train_cases)} Train | {len(val_cases)} Val")

# 3) Training Configuration

In [None]:
PATCH_SIZE = (128, 128, 128)  # spatial size of the crops cut by FastBraTSDataset
BATCH_SIZE = 12  # patches per batch of the training DataLoader
EPOCHS = 4  # number of training epochs; also the horizon of the poly LR schedule
LR_START = 0.01  # initial learning rate handed to the SGD optimizer

# 4) Dataset with Foreground-Biased Sampling

In [None]:
class FastBraTSDataset(Dataset):
    """Random-patch dataset over a list of ``(case_id, paths)`` tuples.

    The index passed to ``__getitem__`` is ignored: every item is built from
    a randomly chosen case, so ``samples`` only sets the nominal epoch length.

    Each item is a pair ``(x, y)`` of float32 tensors:

    * ``x`` stacks the modalities listed in ``MODALITIES``, each z-scored
      over its non-zero voxels (when there are more than 10 of them).
    * ``y`` stacks three binary channels derived from the segmentation:
      ``seg > 0``, ``seg in {1, 4}`` and ``seg == 4``.

    Both are cropped to the module-level ``PATCH_SIZE``. With probability
    0.66 the crop is centred (as far as the volume bounds allow) on a random
    voxel of the first label channel; otherwise it is placed uniformly at
    random. Every spatial axis is then flipped with probability 0.5.

    Args:
        cases: Sequence of ``(case_id, paths)`` tuples, where ``paths`` maps
            the modality names and ``"seg"`` to NIfTI file paths.
        samples: Value reported by ``len()``.
    """

    MODALITIES = ("flair", "t1", "t1ce", "t2")

    def __init__(self, cases, samples=1000):
        self.cases = cases
        self.samples = samples

    def __len__(self):
        return self.samples

    @staticmethod
    def _load_normalized(path):
        """Load one volume as float32 and z-score its non-zero voxels."""
        # Asking nibabel for float32 directly avoids a float64 intermediate.
        data = nib.load(path).get_fdata(dtype=np.float32)
        mask = data != 0
        if mask.sum() > 10:
            voxels = data[mask]
            data[mask] = (voxels - voxels.mean()) / (voxels.std() + 1e-8)
        return data

    def __getitem__(self, _):
        _, files = random.choice(self.cases)
        x = np.stack([self._load_normalized(files[k]) for k in self.MODALITIES])

        seg = nib.load(files["seg"]).get_fdata(dtype=np.float32)
        y = np.stack([seg > 0, (seg == 1) | (seg == 4), seg == 4]).astype(np.float32)

        patch = tuple(PATCH_SIZE)

        # Zero-pad at the far end if the volume is smaller than the patch on
        # any axis, so that the crop below always has the requested size.
        deficit = [max(p - s, 0) for p, s in zip(patch, x.shape[1:])]
        if any(deficit):
            pad = [(0, 0)] + [(0, extra) for extra in deficit]
            x, y = np.pad(x, pad), np.pad(y, pad)

        # Largest valid start index per spatial axis.
        limits = [s - p for s, p in zip(x.shape[1:], patch)]

        start = None
        if random.random() < 0.66:
            foreground = np.nonzero(y[0])
            if len(foreground[0]) > 0:
                pick = random.randint(0, len(foreground[0]) - 1)
                start = [
                    int(np.clip(coords[pick] - p // 2, 0, limit))
                    for coords, p, limit in zip(foreground, patch, limits)
                ]
        if start is None:
            # Either the coin flip chose a uniform crop or the case has no foreground.
            start = [random.randint(0, limit) for limit in limits]

        window = (slice(None),) + tuple(slice(s, s + p) for s, p in zip(start, patch))
        x_p, y_p = x[window], y[window]

        for axis in (1, 2, 3):
            if random.random() > 0.5:
                x_p, y_p = np.flip(x_p, axis), np.flip(y_p, axis)

        # One contiguous copy at the end: torch cannot wrap the negative
        # strides left by np.flip, and copying lets the full volume be freed.
        return (
            torch.from_numpy(np.ascontiguousarray(x_p)),
            torch.from_numpy(np.ascontiguousarray(y_p)),
        )

# 5) Model Architecture (nnU-Net with Deep Supervision)

In [None]:
class ConvBlock(nn.Module):
    """Single 3x3x3 convolution stage: conv -> instance norm -> LeakyReLU -> dropout.

    The convolution keeps the spatial size (stride 1, padding 1) and has no
    bias term. Dropout3d is applied with probability 0.2.

    Args:
        i: Number of input channels.
        o: Number of output channels.
    """

    def __init__(self, i, o):
        super().__init__()
        layers = [
            nn.Conv3d(
                in_channels=i,
                out_channels=o,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.InstanceNorm3d(num_features=o),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout3d(p=0.2),
        ]
        # Kept under the attribute name `c` so state-dict keys stay `c.<idx>.*`.
        self.c = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the four layers in order."""
        return self.c(x)

class nnUNet(nn.Module):
    """Three-level 3D U-Net with two auxiliary deep-supervision heads.

    The encoder has three ConvBlock stages (32, 64, 128 channels) separated
    by 2x max-pooling, followed by a 256-channel bottleneck. The decoder
    mirrors it with stride-2 transposed convolutions and skip concatenation.
    The network takes 4 input channels and every head emits 3 channels.

    ``forward`` returns ``(full_res, half_res, quarter_res)`` logits.
    NOTE(review): spatial dims presumably need to be divisible by 8 so the
    upsampled tensors match their skip connections; confirm with callers.
    """

    def __init__(self):
        super().__init__()
        c1, c2, c3, c4 = 32, 64, 128, 256

        # Encoder stages.
        self.e1 = ConvBlock(4, c1)
        self.e2 = ConvBlock(c1, c2)
        self.e3 = ConvBlock(c2, c3)
        self.pool = nn.MaxPool3d(2)
        self.bottleneck = ConvBlock(c3, c4)

        # Learned 2x upsampling, deepest level first.
        self.u3 = nn.ConvTranspose3d(c4, c3, kernel_size=2, stride=2)
        self.u2 = nn.ConvTranspose3d(c3, c2, kernel_size=2, stride=2)
        self.u1 = nn.ConvTranspose3d(c2, c1, kernel_size=2, stride=2)

        # Decoder stages; inputs are upsampled features concatenated with the skip.
        self.d3 = ConvBlock(c4, c3)
        self.d2 = ConvBlock(c3, c2)
        self.d1 = ConvBlock(c2, c1)

        # 1x1x1 heads: two auxiliary ones and the full-resolution output.
        self.ds3 = nn.Conv3d(c3, 3, kernel_size=1)
        self.ds2 = nn.Conv3d(c2, 3, kernel_size=1)
        self.out = nn.Conv3d(c1, 3, kernel_size=1)

    def forward(self, x):
        """Return logits from the full-resolution head and both auxiliary heads."""
        skip1 = self.e1(x)
        skip2 = self.e2(self.pool(skip1))
        skip3 = self.e3(self.pool(skip2))
        deepest = self.bottleneck(self.pool(skip3))

        dec3 = self.d3(torch.cat((self.u3(deepest), skip3), dim=1))
        dec2 = self.d2(torch.cat((self.u2(dec3), skip2), dim=1))
        dec1 = self.d1(torch.cat((self.u1(dec2), skip1), dim=1))

        return self.out(dec1), self.ds2(dec2), self.ds3(dec3)

# 6) Hybrid Loss Function

In [None]:
def hybrid_loss(outputs, target):
    """Soft-Dice + BCE-with-logits loss summed over deep-supervision heads.

    Head ``k`` (0-based) is weighted by ``0.5 ** k``, i.e. 1, 0.5, 0.25, ...
    For every head whose spatial size differs from ``target`` the target is
    resized with nearest-neighbour interpolation first. Dice is computed over
    all elements of the head at once (batch and channels pooled).

    Args:
        outputs: Sequence of logit tensors, highest resolution first. A bare
            tensor is treated as a single head.
        target: Float tensor of binary labels with the same batch and channel
            layout as the logits.

    Returns:
        Scalar tensor holding the weighted sum of the per-head losses.

    Raises:
        ValueError: If ``outputs`` is an empty sequence.
    """
    if torch.is_tensor(outputs):
        outputs = (outputs,)
    if len(outputs) == 0:
        raise ValueError("hybrid_loss needs at least one output head")

    def region_loss(logits, labels):
        # Summing millions of half-precision probabilities can overflow when
        # this runs outside an autocast region, so do the math in float32.
        if logits.dtype in (torch.float16, torch.bfloat16):
            logits = logits.float()
        probs = torch.sigmoid(logits)
        overlap = (probs * labels).sum()
        dice = 1 - (2 * overlap + 1e-6) / (probs.sum() + labels.sum() + 1e-6)
        return dice + F.binary_cross_entropy_with_logits(logits, labels)

    total = None
    for level, logits in enumerate(outputs):
        labels = target
        if logits.shape[2:] != target.shape[2:]:
            labels = F.interpolate(target, logits.shape[2:], mode='nearest')
        term = region_loss(logits, labels)
        total = term if level == 0 else total + (0.5 ** level) * term
    return total

# 7) Training Loop with Poly LR Scheduler

In [None]:
if len(train_cases) > 0:
    # Hold out 10% of the labelled cases for a quick per-epoch Dice check.
    train_split, internal_val = train_test_split(train_cases, test_size=0.1, random_state=SEED)
    train_loader = DataLoader(FastBraTSDataset(train_split), batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
    val_loader = DataLoader(FastBraTSDataset(internal_val, samples=20), batch_size=2)

    net = nn.DataParallel(nnUNet()).to(DEVICE)
    optimizer = torch.optim.SGD(net.parameters(), lr=LR_START, momentum=0.99, nesterov=True)
    # Polynomial decay: lr = LR_START * (1 - epoch / EPOCHS) ** 0.9, stepped once per epoch.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda ep: (1 - ep/EPOCHS)**0.9)

    # Mixed precision only makes sense on the GPU; on a CPU fallback both the
    # autocast context and the gradient scaler are turned into no-ops.
    use_amp = DEVICE.type == "cuda"
    scaler = torch.amp.GradScaler(enabled=use_amp)

    for epoch in range(EPOCHS):
        net.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for x, y in pbar:
            # The loader pins its batches, so the copies can overlap with compute.
            x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                outputs = net(x)
                loss = hybrid_loss(outputs, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        net.eval()
        torch.cuda.empty_cache()
        d_scores = []
        with torch.no_grad():
            for vx, vy in val_loader:
                vx = vx.to(DEVICE)
                with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                    v_out, _, _ = net(vx)
                # Threshold the full-resolution head and score it on the CPU,
                # pooling all three label channels into one Dice value.
                p = (torch.sigmoid(v_out) > 0.5).float().cpu()
                d_scores.append((2.*(p*vy).sum()/(p.sum()+vy.sum()+1e-6)).item())

        LOG.info(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Val Dice: {np.mean(d_scores):.4f}")
        scheduler.step()
else:
    # Later cells rely on `net` and `val_loader`; say why they are missing.
    LOG.warning("No training cases found under %s; training skipped.", TRAIN_BASE)

# 8) Uncertainty Inference Function

In [None]:
@torch.no_grad()
def infer_uncertainty(x, model, mc_iter=8):
    """Monte-Carlo-dropout prediction with a per-voxel uncertainty score.

    The model is switched to train mode so that dropout stays active, run
    ``mc_iter`` times on ``x``, and the sigmoid of the first element of each
    result is averaged. The train/eval flag of every sub-module is restored
    before returning, so the caller's model is left as it was found.

    Args:
        x: Input passed unchanged to ``model``.
        model: Module whose call result is indexed with ``[0]`` to get logits.
        mc_iter: Number of stochastic forward passes.

    Returns:
        ``(mean_p, uncertainty)`` where ``mean_p`` is the mean probability and
        ``uncertainty`` is the binary entropy of ``mean_p`` expressed as a
        percentage of its maximum (``ln 2``), clamped to ``[0, 100]``.
    """
    previous_modes = [(module, module.training) for module in model.modules()]
    model.train()
    try:
        preds = [torch.sigmoid(model(x)[0]) for _ in range(mc_iter)]
    finally:
        # Restore per-module flags rather than calling model.eval(), so a
        # model that was deliberately in train mode is not altered.
        for module, was_training in previous_modes:
            module.training = was_training

    mean_p = torch.stack(preds).mean(0)
    # The 1e-8 offsets keep log() finite at probabilities of exactly 0 or 1.
    uncertainty = - (mean_p * torch.log(mean_p + 1e-8) + (1-mean_p) * torch.log(1-mean_p + 1e-8))
    return mean_p, (uncertainty / np.log(2) * 100).clamp(0, 100)

# 9) Model Export

In [None]:
SAVE_FILE = "/kaggle/working/weights/nnunet_brats2020.pth"
os.makedirs(os.path.dirname(SAVE_FILE), exist_ok=True)
try:
    # Save the wrapped module's weights when the network sits inside
    # DataParallel, so the state-dict keys carry no "module." prefix.
    torch.save(
        (net.module if isinstance(net, nn.DataParallel) else net).state_dict(),
        SAVE_FILE,
    )
    LOG.info(f"Exported: {SAVE_FILE} ({os.path.getsize(SAVE_FILE)/1024**2:.1f}MB)")
except Exception as e:
    # Best effort: a failed export is logged and the notebook carries on.
    LOG.error(f"Save Error: {e}")

# 10) Visualization Helper Function

In [None]:
import matplotlib.pyplot as plt

def visualize_comparison(image, ground_truth, prediction, slice_idx=None):
    """Show one slice of the input, the label and the prediction side by side.

    Only channel 0 of each array is drawn; the panel titles label it as
    FLAIR for the image and as whole tumor for the two masks.

    Args:
        image: Array indexed as ``[channel, slice, row, col]``.
        ground_truth: Label array with the same indexing.
        prediction: Predicted mask with the same indexing.
        slice_idx: Index along axis 1 to display; defaults to the middle
            slice of ``image``.
    """
    if slice_idx is None:
        slice_idx = image.shape[1] // 2

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # (volume, title, imshow keyword arguments) for each panel, left to right.
    panels = (
        (image, f'FLAIR MRI (Slice {slice_idx})', {'cmap': 'gray'}),
        (ground_truth, 'Ground Truth (Whole Tumor)', {'cmap': 'jet', 'alpha': 0.8}),
        (prediction, 'Model Prediction (Whole Tumor)', {'cmap': 'jet', 'alpha': 0.8}),
    )
    for ax, (volume, title, style) in zip(axes, panels):
        ax.imshow(volume[0, slice_idx, :, :], **style)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# 11) Run Inference and Visualize

In [None]:
# Eval mode switches dropout off, so the displayed prediction is a single
# deterministic forward pass.
net.eval()

with torch.no_grad():
    # Pull one batch from the internal validation loader.
    images, targets = next(iter(val_loader))
    images = images.to(DEVICE)

    # The network returns three heads; only the full-resolution one is shown.
    outputs, _, _ = net(images)

    # Binarise the probabilities at 0.5 and bring the mask back to the CPU.
    preds = torch.sigmoid(outputs).gt(0.5).float().cpu()

# First sample of the batch, converted to NumPy for matplotlib.
sample_img, sample_gt, sample_pred = (
    images[0].cpu().numpy(),
    targets[0].numpy(),
    preds[0].numpy(),
)

visualize_comparison(sample_img, sample_gt, sample_pred)

# 12) Tumor Sub-Region Visualizer

In [None]:
def plot_tumor_regions(gt, pred, slice_idx=64):
    """Plot the three label channels of ``gt`` (top row) against ``pred`` (bottom row).

    Column ``i`` shows channel ``i`` of both arrays at ``slice_idx``; the
    columns are titled Whole Tumor, Tumor Core and Enhancing Tumor.

    Args:
        gt: Ground-truth array indexed as ``[channel, slice, row, col]``.
        pred: Predicted array with the same indexing.
        slice_idx: Index along axis 1 to display.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    region_names = ('Whole Tumor', 'Tumor Core', 'Enhancing Tumor')
    for channel, region in enumerate(region_names):
        # (row, volume, colormap, title) for the ground truth and the prediction.
        rows = (
            (0, gt, 'gray', f'GT: {region}'),
            (1, pred, 'Reds', f'Pred: {region}'),
        )
        for row, volume, cmap, title in rows:
            ax = axes[row, channel]
            ax.imshow(volume[channel, slice_idx, :, :], cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.suptitle(f"Segmentation Comparison - Slice {slice_idx}")
    plt.show()

# Compare all three label channels of the sample at the default slice index.
plot_tumor_regions(sample_gt, sample_pred)