# BraTS Brain Tumor Segmentation - Final Training Pipeline

This notebook implements the complete training pipeline for BraTS brain tumor segmentation. Its features, model architecture, and output files are summarized below.

# Features
- **Early Stopping** with configurable patience
- **Checkpoint Saving** after each epoch
- **Best Model Tracking** saved separately
- **Comprehensive Logging** visible on Kaggle
- **Learning Rate Scheduling** with Cosine Annealing
- **Mixed Precision Training** for faster training
- **Deep Supervision** for better gradients
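
For intuition about the cosine annealing item in the list above, the closed form behind PyTorch's `CosineAnnealingLR` (without restarts) can be evaluated by hand. This is a minimal sketch, assuming the values this notebook sets later (initial rate `1e-4`, floor `1e-7`, `T_max` of 50 epochs) and one scheduler step per epoch; the helper name `cosine_lr` is hypothetical.

```python
import math

def cosine_lr(epoch: int, base_lr: float = 1e-4, eta_min: float = 1e-7, t_max: int = 50) -> float:
    """Closed-form cosine annealing: eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

# Print the rate at the start, midpoint, and end of the schedule.
for epoch in (0, 25, 50):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch):.3e}")
```

The rate starts at the initial value, is roughly halved at the midpoint, and reaches the floor only at `T_max`, so a run that stops early never gets to the smallest learning rates.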

# Model Architecture
- Modified 3D U-Net with 6-level encoder
- Deep supervision heads for auxiliary losses
- Instance normalization + LeakyReLU

# Output Files
After training, use `improved_submission.ipynb` for inference with:
- `/kaggle/working/best_model.pth` - Best model weights
- `/kaggle/working/checkpoints/` - All epoch checkpoints

---

# 1: Install Dependencies

In [1]:
# Install required libraries
!pip install -q monai nibabel scipy scikit-image tqdm matplotlib


# 2: Imports and Setup

In [2]:
# =============================================================================
# IMPORTS AND ENVIRONMENT SETUP
# =============================================================================
import os
import sys
import glob
import time
import logging
import warnings
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# MONAI imports
from monai.utils import set_determinism
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd,
    NormalizeIntensityd, CropForegroundd, RandCropByPosNegLabeld,
    MapLabelValued, RandFlipd, RandRotate90d, RandShiftIntensityd,
    ConvertToMultiChannelBasedOnBratsClassesd, SpatialPadd,
    ConcatItemsd, DeleteItemsd,
)
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

# Suppress warnings
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# =============================================================================
# LOGGER SETUP - Visible on Kaggle
# =============================================================================
def setup_logger(name: str = "BraTS", log_file: Optional[str] = None) -> logging.Logger:
    """Setup logger that prints to stdout (visible on Kaggle) and optionally to file."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.handlers = []  # Clear existing handlers
    logger.propagate = False  # Prevent duplicate logging with INFO:BraTS prefix
    
    # Console handler - prints to stdout (visible on Kaggle)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_format = logging.Formatter(
        '%(asctime)s | %(levelname)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(console_format)
    logger.addHandler(console_handler)
    
    # File handler (optional)
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(console_format)
        logger.addHandler(file_handler)
    
    return logger

# Initialize logger
logger = setup_logger("BraTS", "/kaggle/working/training.log")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Device: {device}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
    logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


2026-02-03 23:24:48 | INFO | Device: cuda


2026-02-03 23:24:48 | INFO | GPU: Tesla P100-PCIE-16GB


2026-02-03 23:24:48 | INFO | GPU Memory: 17.1 GB


# 3: Configuration

In [3]:
# =============================================================================
# CONFIGURATION - All hyperparameters in one place
# =============================================================================

@dataclass
class Config:
    """Central configuration for the BraTS segmentation pipeline."""
    
    # ----- Paths -----
    data_root: str = "/kaggle/input/brain-tumor-segmentation-hackathon"
    output_dir: str = "/kaggle/working"
    checkpoint_dir: str = "/kaggle/working/checkpoints"
    
    # ----- Training Hyperparameters -----
    epochs: int = 50                    # Maximum epochs (early stopping will likely trigger before)
    batch_size: int = 1                 # Batch size (1 for P100 16GB with 128^3 patches)
    learning_rate: float = 1e-4         # Initial learning rate
    weight_decay: float = 1e-5          # AdamW weight decay
    grad_clip: float = 12.0             # Gradient clipping max norm
    
    # ----- Early Stopping -----
    early_stopping_patience: int = 10   # Stop if no improvement for N epochs
    early_stopping_min_delta: float = 0.001  # Minimum improvement to reset patience
    
    # ----- Model Architecture -----
    patch_size: Tuple[int, int, int] = (128, 128, 128)
    in_channels: int = 4                # FLAIR, T1, T1ce, T2
    out_channels: int = 3               # TC, WT, ET
    features: Optional[List[int]] = None  # None falls back to [64, 96, 128, 192, 256, 384]
    deep_supervision: bool = True       # Enable auxiliary loss heads
    
    # ----- Data -----
    val_split: float = 0.2              # Validation split ratio
    target_spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    use_augmentation: bool = True
    cache_rate: float = 0.1             # Fraction of data to cache
    num_workers: int = 4
    pos_neg_ratio: Tuple[int, int] = (1, 1)
    num_samples_per_image: int = 1
    
    # ----- Validation -----
    val_interval: int = 1               # Validate every N epochs
    
    # ----- Inference -----
    sw_batch_size: int = 2              # Sliding window batch size
    overlap: float = 0.5                # Sliding window overlap
    
    # ----- Mixed Precision -----
    use_amp: bool = True                # Use automatic mixed precision
    
    # ----- Reproducibility -----
    seed: int = 42
    
    # ----- Checkpointing -----
    save_every_epoch: bool = True       # Save checkpoint after each epoch
    keep_last_n_checkpoints: int = 3    # Keep only last N checkpoints (0 = keep all)


# Initialize config
config = Config()

# Create directories
os.makedirs(config.checkpoint_dir, exist_ok=True)
os.makedirs(config.output_dir, exist_ok=True)

# Set determinism
set_determinism(seed=config.seed)

# Log configuration
logger.info("="*60)
logger.info("CONFIGURATION")
logger.info("="*60)
logger.info(f"Data root: {config.data_root}")
logger.info(f"Output dir: {config.output_dir}")
logger.info(f"Max epochs: {config.epochs}")
logger.info(f"Batch size: {config.batch_size}")
logger.info(f"Patch size: {config.patch_size}")
logger.info(f"Learning rate: {config.learning_rate}")
logger.info(f"Early stopping patience: {config.early_stopping_patience}")
logger.info(f"Deep supervision: {config.deep_supervision}")
logger.info(f"Mixed precision: {config.use_amp}")
logger.info("="*60)



2026-02-03 23:24:48 | INFO | CONFIGURATION




2026-02-03 23:24:48 | INFO | Data root: /kaggle/input/brain-tumor-segmentation-hackathon


2026-02-03 23:24:48 | INFO | Output dir: /kaggle/working


2026-02-03 23:24:48 | INFO | Max epochs: 50


2026-02-03 23:24:48 | INFO | Batch size: 1


2026-02-03 23:24:48 | INFO | Patch size: (128, 128, 128)


2026-02-03 23:24:48 | INFO | Learning rate: 0.0001


2026-02-03 23:24:48 | INFO | Early stopping patience: 10


2026-02-03 23:24:48 | INFO | Deep supervision: True


2026-02-03 23:24:48 | INFO | Mixed precision: True
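
The `features` field in `Config` uses a `None` sentinel because dataclasses reject a bare list as a default value. One alternative, shown here as a sketch and not as what this notebook does, is `dataclasses.field(default_factory=...)`. `SketchConfig` is a hypothetical stand-in for `Config` with only two fields.

```python
from dataclasses import dataclass, field, replace
from typing import List

@dataclass
class SketchConfig:
    epochs: int = 50
    # default_factory builds a fresh list per instance, so instances never share state.
    features: List[int] = field(default_factory=lambda: [64, 96, 128, 192, 256, 384])

base = SketchConfig()
# replace() copies the config with selected overrides, e.g. for a short smoke test.
smoke = replace(base, epochs=2, features=[16, 32, 64])

print(base)
print(smoke)
```

Copying with `replace()` keeps the baseline configuration untouched, which can help when comparing a quick trial run against the full settings.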




# 4: Data Loading Utilities

In [4]:
# =============================================================================
# FILE DISCOVERY
# =============================================================================

MODALITY_KEYS = ["flair", "t1", "t1ce", "t2"]

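def find_nifti_file(base_path: str, keyword: str) -> str:
    """Find a NIfTI file containing the keyword in nested folders."""
    candidates = glob.glob(os.path.join(base_path, "**", f"*{keyword}*.nii*"), recursive=True)
    real_files = [f for f in candidates if os.path.isfile(f)]
    
    # "*t1*" also matches t1ce files, so prefer names whose stem ends with the keyword
    exact = [f for f in real_files if os.path.basename(f).split(".nii")[0].lower().endswith(keyword)]
    if exact:
        real_files = exact
    
    if not real_files:
        raise FileNotFoundError(f"No valid file found for '{keyword}' in {base_path}")
    
    return max(real_files, key=len)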


def get_file_list(data_dir: str) -> list:
    """Scan the dataset directory and build a list of file dictionaries."""
    data_dicts = []
    patient_folders = sorted(glob.glob(os.path.join(data_dir, "BraTS*")))
    
    logger.info(f"Scanning {len(patient_folders)} patient folders...")
    
    for pat_path in patient_folders:
        patient_id = os.path.basename(pat_path)
        try:
            flair = find_nifti_file(pat_path, "flair")
            t1 = find_nifti_file(pat_path, "t1")
            t1ce = find_nifti_file(pat_path, "t1ce")
            t2 = find_nifti_file(pat_path, "t2")
            seg = find_nifti_file(pat_path, "seg")

            data_dicts.append({
                "flair": flair,
                "t1": t1,
                "t1ce": t1ce,
                "t2": t2,
                "label": seg,
                "patient_id": patient_id
            })
        except FileNotFoundError as e:
            logger.warning(f"Skipping {patient_id}: {e}")
    
    logger.info(f"Successfully loaded {len(data_dicts)} valid patient volumes")
    return data_dicts
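
A wildcard pattern such as `*t1*` also matches `t1ce` file names, which matters for any keyword-based lookup like `find_nifti_file`. The sketch below reproduces the overlap in a temporary directory. The file names are hypothetical BraTS-style names, and the real dataset layout may differ.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Hypothetical file names used only for this demonstration.
    for name in ("case_t1.nii", "case_t1ce.nii"):
        open(os.path.join(root, name), "w").close()

    # The wildcard match returns both modalities for the keyword "t1".
    matches = sorted(os.path.basename(p) for p in glob.glob(os.path.join(root, "*t1*.nii*")))
    # Keep only files whose stem (the part before ".nii") ends with the keyword.
    exact = [m for m in matches if m.split(".nii")[0].endswith("t1")]

print(matches)  # ['case_t1.nii', 'case_t1ce.nii']
print(exact)    # ['case_t1.nii']
```

Picking the longest path among the wildcard matches would select the `t1ce` file here, so a stem check of this kind is one way to keep the T1 and T1ce channels distinct.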

# 5: Transform Pipelines

In [5]:
# =============================================================================
# TRANSFORM BUILDERS
# =============================================================================

def get_train_transforms(config: Config) -> Compose:
    """Build the complete training transform pipeline."""
    return Compose([
        # Load each modality separately
        LoadImaged(keys=MODALITY_KEYS + ["label"]),
        EnsureChannelFirstd(keys=MODALITY_KEYS + ["label"]),
        
        # Concatenate modalities into single 4-channel image
        ConcatItemsd(keys=MODALITY_KEYS, name="image", dim=0),
        DeleteItemsd(keys=MODALITY_KEYS),
        
        # Label preprocessing
        # MapLabelValued(keys=["label"], orig_labels=[4], target_labels=[3]), # REMOVED: FIX FOR ET CLASS
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        
        # Spacing and normalization
        Spacingd(
            keys=["image", "label"],
            pixdim=config.target_spacing,
            mode=("bilinear", "nearest")
        ),
        ConvertToMultiChannelBasedOnBratsClassesd(keys=["label"]),
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        
        # Ensure minimum size
        SpatialPadd(keys=["image", "label"], spatial_size=config.patch_size),
        
        # Patch extraction
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=config.patch_size,
            pos=config.pos_neg_ratio[0],
            neg=config.pos_neg_ratio[1],
            num_samples=config.num_samples_per_image,
            image_key="image",
            image_threshold=0,
        ),
        
        # Augmentation
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
    ])


def get_val_transforms(config: Config) -> Compose:
    """Build the validation transform pipeline."""
    return Compose([
        LoadImaged(keys=MODALITY_KEYS + ["label"]),
        EnsureChannelFirstd(keys=MODALITY_KEYS + ["label"]),
        ConcatItemsd(keys=MODALITY_KEYS, name="image", dim=0),
        DeleteItemsd(keys=MODALITY_KEYS),
        # MapLabelValued(keys=["label"], orig_labels=[4], target_labels=[3]), # REMOVED: FIX FOR ET CLASS
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=config.target_spacing,
            mode=("bilinear", "nearest")
        ),
        ConvertToMultiChannelBasedOnBratsClassesd(keys=["label"]),
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        SpatialPadd(keys=["image", "label"], spatial_size=config.patch_size),
    ])

logger.info("Transform functions defined")

2026-02-03 23:24:49 | INFO | Transform functions defined
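
To make the label conversion step concrete, the following sketch mimics what `ConvertToMultiChannelBasedOnBratsClassesd` is expected to produce for a handful of voxels. It assumes the 1/2/4 label convention (necrotic core, edema, enhancing tumor), which the removed `MapLabelValued` line suggests. The enhancing-tumor value can differ between datasets and MONAI versions, so check the installed transform. `to_brats_channels` is a hypothetical helper, not a MONAI function.

```python
from typing import Dict, List

def to_brats_channels(labels: List[int], et_label: int = 4) -> Dict[str, List[int]]:
    """Map integer labels to three overlapping binary masks (TC, WT, ET)."""
    tc = [int(v in (1, et_label)) for v in labels]     # tumor core: necrotic core + enhancing
    wt = [int(v in (1, 2, et_label)) for v in labels]  # whole tumor: core + edema
    et = [int(v == et_label) for v in labels]          # enhancing tumor only
    return {"TC": tc, "WT": wt, "ET": et}

voxels = [0, 1, 2, 4]  # background, necrotic core, edema, enhancing tumor
channels = to_brats_channels(voxels)
for name, mask in channels.items():
    print(name, mask)
# TC [0, 1, 0, 1]
# WT [0, 1, 1, 1]
# ET [0, 0, 0, 1]
```

Because the three regions are nested, a voxel can be positive in several channels at once, which is why the pipeline treats each channel as an independent binary problem.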


# 6: Model Architecture

In [6]:
# =============================================================================
# U-NET BUILDING BLOCKS
# =============================================================================

class ConvBlock(nn.Module):
    """Basic convolutional block: Conv3D -> InstanceNorm -> LeakyReLU (x2)"""
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2
        
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.norm1 = nn.InstanceNorm3d(out_channels, affine=True)
        self.act1 = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.norm2 = nn.InstanceNorm3d(out_channels, affine=True)
        self.act2 = nn.LeakyReLU(negative_slope=0.01, inplace=True)
    
    def forward(self, x):
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.act2(self.norm2(self.conv2(x)))
        return x


class DownBlock(nn.Module):
    """Encoder block: ConvBlock followed by MaxPool3D"""
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
    
    def forward(self, x):
        skip = self.conv(x)
        pooled = self.pool(skip)
        return pooled, skip


class UpBlock(nn.Module):
    """Decoder block: Upsample -> Concatenate skip -> ConvBlock"""
    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        # Named 'upsample' to match improved_submission.ipynb
        self.upsample = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels + skip_channels, out_channels)
    
    def forward(self, x, skip):
        x = self.upsample(x)
        
        # Handle spatial size mismatch (channel counts differ by design, so compare spatial dims only)
        if x.shape[2:] != skip.shape[2:]:
            diff_d = skip.shape[2] - x.shape[2]
            diff_h = skip.shape[3] - x.shape[3]
            diff_w = skip.shape[4] - x.shape[4]
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                         diff_h // 2, diff_h - diff_h // 2,
                         diff_d // 2, diff_d - diff_d // 2])
        
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

In [7]:
# =============================================================================
# U-NET MODEL (Compatible with improved_submission.ipynb)
# =============================================================================

class UNet3D(nn.Module):
    """
    Modified 3D U-Net for BraTS segmentation with deep supervision.
    
    Architecture matches improved_submission.ipynb for seamless inference:
    - Layer names: 'upsample' (not 'up'), 'output_head' (not 'out')
    - Deep supervision heads saved but only used during training
    """
    
    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 3,
        features: list = None,
        deep_supervision: bool = True,
    ):
        super().__init__()
        
        # 6-level encoder: 128 -> 64 -> 32 -> 16 -> 8 -> 4
        if features is None:
            features = [64, 96, 128, 192, 256, 384]
        
        self.deep_supervision = deep_supervision
        self.features = features
        
        # Encoder
        self.encoders = nn.ModuleList()
        prev_channels = in_channels
        for feat in features[:-1]:
            self.encoders.append(DownBlock(prev_channels, feat))
            prev_channels = feat
        
        # Bottleneck
        self.bottleneck = ConvBlock(prev_channels, features[-1])
        
        # Decoder
        self.decoders = nn.ModuleList()
        prev_channels = features[-1]
        for feat in reversed(features[:-1]):
            self.decoders.append(UpBlock(prev_channels, feat, feat))
            prev_channels = feat
        
        # Output head (named to match improved_submission.ipynb)
        self.output_head = nn.Conv3d(features[0], out_channels, kernel_size=1)
        
        # Deep supervision heads
        if deep_supervision and len(features) >= 5:
            self.ds_head1 = nn.Conv3d(features[1], out_channels, kernel_size=1)
            self.ds_head2 = nn.Conv3d(features[2], out_channels, kernel_size=1)
        else:
            self.ds_head1 = None
            self.ds_head2 = None
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.InstanceNorm3d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # Encoder
        skips = []
        for encoder in self.encoders:
            x, skip = encoder(x)
            skips.append(skip)
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Decoder
        ds_outputs = []
        skips = skips[::-1]
        
        for i, (decoder, skip) in enumerate(zip(self.decoders, skips)):
            x = decoder(x, skip)
            
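            # Deep supervision outputs (only during training)
            # Decoder i emits features[n_dec - 1 - i] channels, so ds_head2 (features[2])
            # and ds_head1 (features[1]) attach to the third- and second-to-last decoders.
            if self.deep_supervision and self.training:
                n_dec = len(self.decoders)
                if i == n_dec - 3 and self.ds_head2 is not None:
                    ds_outputs.append(self.ds_head2(x))
                elif i == n_dec - 2 and self.ds_head1 is not None:
                    ds_outputs.append(self.ds_head1(x))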
        
        output = self.output_head(x)
        
        if self.deep_supervision and self.training and len(ds_outputs) > 0:
            return output, ds_outputs
        
        return output
    
    def get_num_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def create_model(config: Config) -> UNet3D:
    """Create and initialize the U-Net model."""
    model = UNet3D(
        in_channels=config.in_channels,
        out_channels=config.out_channels,
        features=config.features or [64, 96, 128, 192, 256, 384],
        deep_supervision=config.deep_supervision,
    )
    logger.info(f"Model created with {model.get_num_parameters():,} parameters")
    return model
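
The channel widths and resolutions inside `UNet3D` can be traced without building the model. This sketch assumes the default feature list and a 128-voxel patch edge, and it only does the bookkeeping: each encoder level halves the spatial size and each decoder restores the size and width of its matching skip connection. The `trace` list is an illustrative name.

```python
features = [64, 96, 128, 192, 256, 384]  # default channel widths from the notebook
size = 128                               # patch edge length

# Encoder: each DownBlock convolves at the current size, then halves it.
encoder_sizes = []
for _ in features[:-1]:
    encoder_sizes.append(size)
    size //= 2
print(f"bottleneck: {features[-1]} channels at {size}^3")

# Decoder: decoder i restores the resolution and width of the matching skip.
heads = {2: "ds_head2", 3: "ds_head1", 4: "output_head"}
trace = []
for i, (ch, sz) in enumerate(zip(reversed(features[:-1]), reversed(encoder_sizes))):
    trace.append((i, ch, sz))
    print(f"decoder {i}: {ch} channels at {sz}^3, head: {heads.get(i, '-')}")
```

With these defaults the two auxiliary heads see 32^3 and 64^3 feature maps, which is why the deep supervision loss has to downsample the targets before comparing.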

# 7: Loss Function

In [8]:
# =============================================================================
# LOSS FUNCTIONS
# =============================================================================

class BraTSLoss(nn.Module):
    """Combined Dice + BCE loss for BraTS segmentation."""
    
    def __init__(self):
        super().__init__()
        self.dice = DiceLoss(sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5, batch=True)
        self.ce = nn.BCEWithLogitsLoss()
    
    def _loss(self, prediction, target):
        return self.dice(prediction, target) + self.ce(prediction, target.float())
    
    def forward(self, predictions, targets):
        # Per-region losses: TC, WT, ET
        loss_tc = self._loss(predictions[:, 0:1], targets[:, 0:1])
        loss_wt = self._loss(predictions[:, 1:2], targets[:, 1:2])
        loss_et = self._loss(predictions[:, 2:3], targets[:, 2:3])
        
        return loss_tc + loss_wt + loss_et


class DeepSupervisionLoss(nn.Module):
    """Loss wrapper for deep supervision."""
    
    def __init__(self, weights=None):
        super().__init__()
        self.base_loss = BraTSLoss()
        self.weights = weights or [1.0, 0.5, 0.25]
    
    def forward(self, outputs, targets):
        if isinstance(outputs, tuple):
            main_output, aux_outputs = outputs
        else:
            return self.base_loss(outputs, targets)
        
        # Main loss
        total_loss = self.weights[0] * self.base_loss(main_output, targets)
        
        # Auxiliary losses with downsampled targets
        for i, aux_output in enumerate(aux_outputs):
            if i + 1 < len(self.weights):
                aux_target = F.interpolate(targets.float(), size=aux_output.shape[2:], mode='nearest')
                total_loss = total_loss + self.weights[i + 1] * self.base_loss(aux_output, aux_target)
        
        return total_loss


def create_loss(config: Config) -> nn.Module:
    """Create the appropriate loss function."""
    if config.deep_supervision:
        return DeepSupervisionLoss()
    return BraTSLoss()

logger.info("Loss functions defined")

2026-02-03 23:24:49 | INFO | Loss functions defined
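
The per-channel term in `BraTSLoss` adds a soft Dice loss to a binary cross-entropy. The sketch below computes both on a flattened list of logits in plain Python. It is a simplified illustration under stated assumptions: it ignores batching and the numerically stable formulation that `BCEWithLogitsLoss` uses, and `dice_bce` is a hypothetical helper rather than the MONAI or PyTorch implementation.

```python
import math
from typing import List

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def dice_bce(logits: List[float], target: List[float], smooth: float = 1e-5) -> float:
    """Soft Dice loss plus mean binary cross-entropy for one flattened channel."""
    probs = [sigmoid(z) for z in logits]
    intersection = sum(p * t for p, t in zip(probs, target))
    # Soft Dice: 1 - (2 * overlap) / (sum of predictions + sum of targets), with smoothing.
    dice = 1.0 - (2.0 * intersection + smooth) / (sum(probs) + sum(target) + smooth)
    # Mean binary cross-entropy over all voxels.
    bce = -sum(t * math.log(p) + (1 - t) * math.log(1 - p) for p, t in zip(probs, target)) / len(probs)
    return dice + bce

target = [1.0, 1.0, 0.0, 0.0]
good = dice_bce([4.0, 4.0, -4.0, -4.0], target)  # confident and correct
bad = dice_bce([-4.0, -4.0, 4.0, 4.0], target)   # confident and wrong
print(f"good: {good:.4f}  bad: {bad:.4f}")
```

The Dice term is bounded by 1 while the cross-entropy term grows with confident mistakes, so the sum penalizes both poor overlap and badly calibrated voxels.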


# 8: Early Stopping

In [9]:
# =============================================================================
# EARLY STOPPING
# =============================================================================

class EarlyStopping:
    """
    Early stopping to terminate training when validation metric stops improving.
    
    Args:
        patience: Number of epochs with no improvement after which training will be stopped.
        min_delta: Minimum change in monitored metric to qualify as an improvement.
        mode: 'max' for metrics where higher is better (e.g., Dice), 'min' for loss.
    """
    
    def __init__(self, patience: int = 10, min_delta: float = 0.001, mode: str = 'max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_epoch = 0
    
    def __call__(self, score: float, epoch: int) -> bool:
        """
        Check if training should stop.
        
        Args:
            score: Current validation metric
            epoch: Current epoch number
        
        Returns:
            True if training should stop, False otherwise
        """
        if self.best_score is None:
            self.best_score = score
            self.best_epoch = epoch
            return False
        
        if self.mode == 'max':
            improved = score > (self.best_score + self.min_delta)
        else:
            improved = score < (self.best_score - self.min_delta)
        
        if improved:
            self.best_score = score
            self.best_epoch = epoch
            self.counter = 0
            return False
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                return True
            return False
    
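    def status(self) -> str:
        """Return current early stopping status."""
        if self.best_score is None:
            # No score recorded yet, so there is nothing to format
            return f"Patience: {self.counter}/{self.patience}, Best: n/a"
        return f"Patience: {self.counter}/{self.patience}, Best: {self.best_score:.4f} @ epoch {self.best_epoch}"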

logger.info("Early stopping class defined")

2026-02-03 23:24:49 | INFO | Early stopping class defined
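
The patience logic can be checked on a short score history. This sketch re-implements the `'max'`-mode rule of `EarlyStopping` as a standalone function. The name `stop_epoch`, the patience of 3, and the Dice values are illustrative assumptions, not results from this run.

```python
from typing import List, Optional

def stop_epoch(scores: List[float], patience: int = 3, min_delta: float = 0.001) -> Optional[int]:
    """Return the 1-based epoch at which 'max'-mode early stopping triggers, or None."""
    best, counter = None, 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or score > best + min_delta:
            best, counter = score, 0  # sufficient improvement: reset patience
        else:
            counter += 1              # improvement smaller than min_delta counts as none
            if counter >= patience:
                return epoch
    return None

dice_history = [0.70, 0.75, 0.7505, 0.74, 0.75, 0.80]
print(stop_epoch(dice_history))  # 5
```

Training stops at epoch 5 here even though epoch 6 would have improved, because the gain at epoch 3 was below `min_delta` and the counter kept running. This is the trade-off to keep in mind when choosing `patience`.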


# 9: Trainer Class

In [10]:
# =============================================================================
# TRAINER CLASS WITH EARLY STOPPING, CHECKPOINT SAVING AND ENHANCED LOGGING
# =============================================================================

class Trainer:
    """Complete training loop with early stopping, checkpointing, and logging."""
    
    def __init__(self, model, train_loader, val_loader, config: Config):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        
        # Scheduler
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=config.epochs,
            eta_min=1e-7
        )
        
        # Loss
        self.criterion = create_loss(config)
        
        # Mixed precision
        self.scaler = torch.amp.GradScaler('cuda') if config.use_amp else None
        
        # Metrics - Mean Dice (for checking early stopping)
        self.dice_metric = DiceMetric(include_background=True, reduction="mean", num_classes=3)
        # Metrics - Per-Class Dice (for detailed logging)
        # Channels: 0=TC, 1=WT, 2=ET
        self.dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch", num_classes=3)

        
        # Early stopping
        self.early_stopping = EarlyStopping(
            patience=config.early_stopping_patience,
            min_delta=config.early_stopping_min_delta,
            mode='max'
        )
        
        # Tracking
        self.best_metric = -1
        self.best_epoch = -1
        self.train_losses = []
        self.val_metrics = []
        self.learning_rates = []
        self.saved_checkpoints = []
    
    def save_checkpoint(self, epoch: int, val_metric: float, is_best: bool = False):
        """Save model checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_metric': self.best_metric,
            'train_loss': self.train_losses[-1] if self.train_losses else None,
            'val_metric': val_metric,
            'config': {
                'epochs': self.config.epochs,
                'learning_rate': self.config.learning_rate,
                'patch_size': self.config.patch_size,
            }
        }
        
        # Save epoch checkpoint
        if self.config.save_every_epoch:
            checkpoint_path = os.path.join(
                self.config.checkpoint_dir, 
                f"checkpoint_epoch{epoch:02d}_dice{val_metric:.4f}.pth"
            )
            torch.save(checkpoint, checkpoint_path)
            self.saved_checkpoints.append(checkpoint_path)
            logger.info(f"💾 Checkpoint saved: {checkpoint_path}")
            
            # Cleanup old checkpoints
            if self.config.keep_last_n_checkpoints > 0:
                while len(self.saved_checkpoints) > self.config.keep_last_n_checkpoints:
                    old_ckpt = self.saved_checkpoints.pop(0)
                    if os.path.exists(old_ckpt):
                        os.remove(old_ckpt)
                        logger.info(f"🗑️ Removed old checkpoint: {os.path.basename(old_ckpt)}")
        
        # Always save best model separately
        if is_best:
            best_path = os.path.join(self.config.output_dir, "best_model.pth")
            torch.save(self.model.state_dict(), best_path)
            logger.info(f"⭐ NEW BEST MODEL SAVED! Dice: {val_metric:.4f}")
            
            # Also save full checkpoint for best
            best_ckpt_path = os.path.join(self.config.checkpoint_dir, "best_checkpoint.pth")
            torch.save(checkpoint, best_ckpt_path)
    
    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch."""
        self.model.train()
        epoch_loss = 0
        
        pbar = tqdm(
            self.train_loader, 
            desc=f"Epoch {epoch}/{self.config.epochs} [Train]",
            file=sys.stdout,
            ncols=100,
            ascii=False,
            leave=True,
            bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}'
        )
        for batch_data in pbar:
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            
            self.optimizer.zero_grad()
            
            if self.config.use_amp:
                with torch.amp.autocast('cuda'):
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, labels)
                
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
                self.optimizer.step()
            
            epoch_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        
        print()  # New line after progress bar for clean logging
        return epoch_loss / len(self.train_loader)
    
    @torch.no_grad()
    def validate(self) -> Tuple[float, np.ndarray]:
        """Run validation and compute metrics (Mean, [TC, WT, ET])."""
        self.model.eval()
        self.dice_metric.reset()
        self.dice_metric_batch.reset()
        
        inferer = SlidingWindowInferer(
            roi_size=self.config.patch_size,
            sw_batch_size=self.config.sw_batch_size,
            overlap=self.config.overlap,
            mode="gaussian"
        )
        
        for batch_data in tqdm(
            self.val_loader, 
            desc="Validating",
            file=sys.stdout,
            ncols=100,
            ascii=False,
            leave=True,
            bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}'
        ):
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            
            # autocast is a no-op when enabled=False, so one code path covers both cases
            with torch.amp.autocast('cuda', enabled=self.config.use_amp):
                outputs = inferer(inputs, self.model)
            
            # Apply sigmoid and threshold
            outputs = (torch.sigmoid(outputs) >= 0.5).float()
            
            # Update metrics
            self.dice_metric(y_pred=outputs, y=labels)
            self.dice_metric_batch(y_pred=outputs, y=labels)
        
        print()  # New line after progress bar for clean logging
        
        # Aggregate results
        mean_dice = self.dice_metric.aggregate().item()
        
        # Per-channel Dice. With ConvertToMultiChannelBasedOnBratsClassesd the
        # channel order is [TC, WT, ET].
        # Convert to a plain list so the return value matches the annotation.
        class_dice = self.dice_metric_batch.aggregate().cpu().tolist()
        
        return mean_dice, class_dice
    
    def fit(self) -> float:
        """Main training loop with early stopping."""
        logger.info("="*70)
        logger.info("🚀 STARTING TRAINING with Enhanced Logging")
        logger.info("="*70)
        logger.info(f"📁 Checkpoints: {self.config.checkpoint_dir}")
        logger.info(f"🛑 Early stopping patience: {self.config.early_stopping_patience}")
        logger.info(f"🔄 Epochs: {self.config.epochs}")
        logger.info(f"📦 Batch Size: {self.config.batch_size}")
        logger.info(f"📏 Patch Size: {self.config.patch_size}")
        logger.info(f"🧠 Initial LR: {self.config.learning_rate}")
        logger.info("="*70)
        sys.stdout.flush()
        
        training_start_time = time.time()
        
        for epoch in range(1, self.config.epochs + 1):
            epoch_start_time = time.time()
            
            # Train
            train_loss = self.train_epoch(epoch)
            self.train_losses.append(train_loss)
            
            # Update learning rate
            self.scheduler.step()
            current_lr = self.scheduler.get_last_lr()[0]
            self.learning_rates.append(current_lr)
            
            # Validate
            val_metric = 0.0
            class_metrics = [0.0, 0.0, 0.0]
            
            if epoch % self.config.val_interval == 0:
                val_metric, class_metrics = self.validate()
                self.val_metrics.append(val_metric)
                
                # Check if best
                is_best = val_metric > self.best_metric
                if is_best:
                    self.best_metric = val_metric
                    self.best_epoch = epoch
                
                # Save checkpoint
                self.save_checkpoint(epoch, val_metric, is_best=is_best)
                
                # Early stopping check
                should_stop = self.early_stopping(val_metric, epoch)
                
                # Log epoch results with detailed per-class breakdown
                epoch_time = time.time() - epoch_start_time
                # Handle class_metrics safely - it may have 2 or 3 elements depending on configuration
                if len(class_metrics) >= 3:
                    tc_dice, wt_dice, et_dice = class_metrics[0], class_metrics[1], class_metrics[2]
                elif len(class_metrics) == 2:
                    # If only 2 metrics, assume they are WT and ET (TC might be missing or combined)
                    tc_dice, wt_dice, et_dice = 0.0, class_metrics[0], class_metrics[1]
                else:
                    # Fallback if unexpected number of metrics
                    tc_dice, wt_dice, et_dice = 0.0, 0.0, 0.0
                
                logger.info("-" * 70)
                logger.info(f"📊 EPOCH {epoch}/{self.config.epochs} RESULTS:")
                logger.info(f"   Train Loss: {train_loss:.4f}")
                logger.info(f"   Validation: Mean Dice = {val_metric:.4f} {'🔥 BEST!' if is_best else ''}")
                if len(class_metrics) >= 3:
                    logger.info(f"     ├─ TC (Tumor Core):    {tc_dice:.4f}")
                    logger.info(f"     ├─ WT (Whole Tumor):   {wt_dice:.4f}")
                    logger.info(f"     └─ ET (Enhancing):     {et_dice:.4f} {'⚠️ LOW' if et_dice < 0.1 else ''}")
                elif len(class_metrics) == 2:
                    logger.info(f"     ├─ Class 0:            {wt_dice:.4f}")
                    logger.info(f"     └─ Class 1:            {et_dice:.4f} {'⚠️ LOW' if et_dice < 0.1 else ''}")
                logger.info(f"   LR: {current_lr:.2e} | Time: {epoch_time/60:.1f} min")
                logger.info(f"   Status: {self.early_stopping.status()}")
                logger.info("-" * 70)
                sys.stdout.flush()
                
                # Stop once the early-stopping criterion has been met
                if should_stop:
                    logger.info("="*70)
                    logger.info(f"🛑 EARLY STOPPING TRIGGERED at epoch {epoch}")
                    logger.info(f"   No improvement for {self.config.early_stopping_patience} epochs")
                    logger.info("="*70)
                    break
            else:
                epoch_time = time.time() - epoch_start_time
                logger.info(f"Epoch {epoch}: Loss={train_loss:.4f}, LR={current_lr:.2e}, Time={epoch_time/60:.1f}min")
            
            # Clear CUDA cache
            torch.cuda.empty_cache()
        
        # Training complete
        total_time = time.time() - training_start_time
        logger.info("="*70)
        logger.info("✅ TRAINING COMPLETE!")
        logger.info("="*70)
        logger.info(f"🏆 Best Mean Dice: {self.best_metric:.4f}")
        logger.info(f"📅 Best Epoch:     {self.best_epoch}")
        logger.info(f"⏱️ Total Time:     {total_time/3600:.2f} hours")
        logger.info(f"💾 Best Model:     {os.path.join(self.config.output_dir, 'best_model.pth')}")
        logger.info("="*70)
        
        # Save training history
        self._save_training_history()
        
        return self.best_metric
    
    def _save_training_history(self):
        """Save training history to file.

        The history dict is pickled into a 0-d object array, so it must be read
        back with np.load(history_path, allow_pickle=True).item().
        """
        history = {
            'train_losses': self.train_losses,
            'val_metrics': self.val_metrics,
            'learning_rates': self.learning_rates,
            'best_metric': self.best_metric,
            'best_epoch': self.best_epoch,
        }
        history_path = os.path.join(self.config.output_dir, 'training_history.npy')
        np.save(history_path, history)
        logger.info(f"Training history saved to {history_path}")

logger.info("Trainer class defined with enhanced logging")

2026-02-03 23:24:49 | INFO | Trainer class defined with enhanced logging
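
The `fit()` loop calls `self.scheduler.step()` once per epoch, and the notebook describes the schedule as Cosine Annealing. As a rough sanity check, the closed-form cosine schedule can be evaluated by hand. This is a minimal sketch, assuming `T_max = 50` (the logged number of epochs) and `eta_min = 0`; neither value is shown in this section, and the helper name `cosine_lr` is hypothetical. Under those assumptions the values agree with the learning rates that the training log reports for epochs 1 to 5 (9.99e-05, 9.96e-05, 9.91e-05, 9.84e-05, 9.76e-05).

```python
import math


def cosine_lr(epoch: int, base_lr: float = 1e-4, t_max: int = 50, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: learning rate after `epoch` scheduler steps.

    base_lr matches the logged initial LR; t_max and eta_min are assumptions.
    """
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


# Print the schedule for the first five epochs in the same format the trainer logs.
for epoch in range(1, 6):
    print(f"epoch {epoch}: LR = {cosine_lr(epoch):.2e}")
```

If the real scheduler used a different `T_max` or a non-zero `eta_min`, the logged values would drift away from this curve, so the comparison is a cheap way to confirm the scheduler configuration.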


# 10: Data Preparation

In [11]:
# =============================================================================
# DATA PREPARATION - Manual Split (IDs 1333-1666 for Val, others for Train)
# =============================================================================

logger.info("="*60)
logger.info("LOADING DATA")
logger.info("="*60)

# Get file list
all_files = get_file_list(config.data_root)

# Manual split based on patient ID
train_files = []
val_files = []

for f in all_files:
    patient_id = f["patient_id"]
    # Extract numeric ID (e.g., "BraTS2021_01333" -> 1333)
    try:
        num_id = int(patient_id.split("_")[-1])
    except ValueError:
        # If can't parse, put in training
        train_files.append(f)
        continue
    
    # IDs 1333-1666 go to validation, all others to training
    if 1333 <= num_id <= 1666:
        val_files.append(f)
    else:
        train_files.append(f)

logger.info(f"Training samples:   {len(train_files)}")
logger.info(f"Validation samples: {len(val_files)} (full set)")

# Use only 50% of validation files to save time
np.random.seed(config.seed)
val_subset_size = len(val_files) // 2
val_indices = np.random.choice(len(val_files), size=val_subset_size, replace=False)
val_files = [val_files[i] for i in sorted(val_indices)]

logger.info(f"Validation subset:  {len(val_files)} (50% for faster training)")
logger.info("="*60)



2026-02-03 23:24:49 | INFO | LOADING DATA




2026-02-03 23:24:49 | INFO | Scanning 1251 patient folders...


2026-02-03 23:25:09 | INFO | Successfully loaded 1251 valid patient volumes


2026-02-03 23:25:09 | INFO | Training samples:   917


2026-02-03 23:25:09 | INFO | Validation samples: 334 (full set)


2026-02-03 23:25:09 | INFO | Validation subset:  167 (50% for faster training)
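
The split rule in the cell above can be isolated into a small function and exercised on toy records. This is only a sketch: the function name `split_by_patient_id` and the five demo IDs are made up for illustration, and the real run used the 1251 patient records returned by `get_file_list`.

```python
from typing import Dict, List, Tuple


def split_by_patient_id(
    files: List[Dict[str, str]],
    val_range: Tuple[int, int] = (1333, 1666),
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Send records whose numeric ID falls in val_range (inclusive) to validation."""
    lo, hi = val_range
    train, val = [], []
    for f in files:
        try:
            # "BraTS2021_01333" -> 1333
            num_id = int(f["patient_id"].split("_")[-1])
        except ValueError:
            # Unparseable IDs fall back to the training set, as in the cell above.
            train.append(f)
            continue
        (val if lo <= num_id <= hi else train).append(f)
    return train, val


# Hypothetical IDs covering both range boundaries and an unparseable name.
demo = [
    {"patient_id": pid}
    for pid in (
        "BraTS2021_00005",
        "BraTS2021_01333",
        "BraTS2021_01666",
        "BraTS2021_01667",
        "BraTS2021_unknown",
    )
]
train_demo, val_demo = split_by_patient_id(demo)
print([f["patient_id"] for f in val_demo])  # ['BraTS2021_01333', 'BraTS2021_01666']
```

Splitting on the patient ID rather than on a random shuffle keeps the partition reproducible across notebook runs, independent of the random seed.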




# 11: Create Datasets and DataLoaders

In [12]:
# =============================================================================
# CREATE DATASETS AND DATALOADERS
# =============================================================================

logger.info("Creating datasets...")

# Transforms
train_transforms = get_train_transforms(config)
val_transforms = get_val_transforms(config)

# Datasets
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=config.cache_rate,
    num_workers=config.num_workers,
)

val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=config.cache_rate,
    num_workers=config.num_workers,
)

# DataLoaders
train_loader = DataLoader(
    train_ds,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)

val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=config.num_workers,
)

logger.info("Datasets created")
logger.info(f"   Train batches: {len(train_loader)}")
logger.info(f"   Val batches:   {len(val_loader)}")

2026-02-03 23:25:09 | INFO | Creating datasets...


Loading dataset: 100%|██████████| 91/91 [01:02<00:00,  1.45it/s]

Loading dataset: 100%|██████████| 16/16 [00:12<00:00,  1.30it/s]

2026-02-03 23:26:24 | INFO | Datasets created


2026-02-03 23:26:24 | INFO |    Train batches: 917


2026-02-03 23:26:24 | INFO |    Val batches:   167
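
The two `Loading dataset` progress bars count 91 and 16 items rather than 917 and 167, because `CacheDataset` pre-computes only a fraction of each dataset. The arithmetic below is a sketch under two assumptions: that MONAI derives the cached count as `min(cache_num, int(len(data) * cache_rate), len(data))`, and that `config.cache_rate` is 0.1. The latter is inferred from the logged counts, since the config value is not shown in this section. The helper name `expected_cache_count` is hypothetical.

```python
import sys


def expected_cache_count(n_items: int, cache_rate: float, cache_num: int = sys.maxsize) -> int:
    """Number of items cached up front, under the assumed MONAI rule."""
    return min(int(cache_num), int(n_items * cache_rate), n_items)


print(expected_cache_count(917, 0.1))  # training set
print(expected_cache_count(167, 0.1))  # validation subset
```

Items outside the cached fraction are loaded and transformed on every access, which is one plausible contributor to the roughly 52-minute epochs in the log.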





# 12: Initialize Model and Start Training

In [13]:
# =============================================================================
# TRAINING
# =============================================================================

# Create model
model = create_model(config)

# Create trainer
trainer = Trainer(model, train_loader, val_loader, config)

# Start training!
best_dice = trainer.fit()

2026-02-03 23:26:24 | INFO | Model created with 27,186,057 parameters




2026-02-03 23:26:24 | INFO | 🚀 STARTING TRAINING with Enhanced Logging
2026-02-03 23:26:24 | INFO | 📁 Checkpoints: /kaggle/working/checkpoints
2026-02-03 23:26:24 | INFO | 🛑 Early stopping patience: 10
2026-02-03 23:26:24 | INFO | 🔄 Epochs: 50
2026-02-03 23:26:24 | INFO | 📦 Batch Size: 1
2026-02-03 23:26:24 | INFO | 📏 Patch Size: (128, 128, 128)
2026-02-03 23:26:24 | INFO | 🧠 Initial LR: 0.0001




Epoch 1/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]

Validating:   0%|          | 0/167 [00:00<?, ?it/s]

2026-02-04 00:18:11 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch01_dice0.7345.pth
2026-02-04 00:18:12 | INFO | ⭐ NEW BEST MODEL SAVED! Dice: 0.7345
2026-02-04 00:18:12 | INFO | ----------------------------------------------------------------------
2026-02-04 00:18:12 | INFO | 📊 EPOCH 1/50 RESULTS:
2026-02-04 00:18:12 | INFO |    Train Loss: 2.5739
2026-02-04 00:18:12 | INFO |    Validation: Mean Dice = 0.7345 🔥 BEST!
2026-02-04 00:18:12 | INFO |      ├─ TC (Tumor Core):    0.6943
2026-02-04 00:18:12 | INFO |      ├─ WT (Whole Tumor):   0.8198
2026-02-04 00:18:12 | INFO |      └─ ET (Enhancing):     0.6979
2026-02-04 00:18:12 | INFO |    LR: 9.99e-05 | Time: 51.8 min
2026-02-04 00:18:12 | INFO |    Status: Patience: 0/10, Best: 0.7345 @ epoch 1
2026-02-04 00:18:12 | INFO | ----------------------------------------------------------------------


Epoch 2/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]

Validating:   0%|          | 0/167 [00:00<?, ?it/s]

2026-02-04 01:09:56 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch02_dice0.7776.pth
2026-02-04 01:09:56 | INFO | ⭐ NEW BEST MODEL SAVED! Dice: 0.7776
2026-02-04 01:09:56 | INFO | ----------------------------------------------------------------------
2026-02-04 01:09:56 | INFO | 📊 EPOCH 2/50 RESULTS:
2026-02-04 01:09:56 | INFO |    Train Loss: 1.2727
2026-02-04 01:09:56 | INFO |    Validation: Mean Dice = 0.7776 🔥 BEST!
2026-02-04 01:09:56 | INFO |      ├─ TC (Tumor Core):    0.7466
2026-02-04 01:09:56 | INFO |      ├─ WT (Whole Tumor):   0.8553
2026-02-04 01:09:56 | INFO |      └─ ET (Enhancing):     0.7411
2026-02-04 01:09:56 | INFO |    LR: 9.96e-05 | Time: 51.7 min
2026-02-04 01:09:56 | INFO |    Status: Patience: 0/10, Best: 0.7776 @ epoch 2
2026-02-04 01:09:56 | INFO | ----------------------------------------------------------------------


Epoch 3/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]

Validating:   0%|          | 0/167 [00:00<?, ?it/s]

2026-02-04 02:03:08 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch03_dice0.7648.pth
2026-02-04 02:03:08 | INFO | ----------------------------------------------------------------------
2026-02-04 02:03:08 | INFO | 📊 EPOCH 3/50 RESULTS:
2026-02-04 02:03:08 | INFO |    Train Loss: 1.0050
2026-02-04 02:03:08 | INFO |    Validation: Mean Dice = 0.7648
2026-02-04 02:03:08 | INFO |      ├─ TC (Tumor Core):    0.7240
2026-02-04 02:03:08 | INFO |      ├─ WT (Whole Tumor):   0.8627
2026-02-04 02:03:08 | INFO |      └─ ET (Enhancing):     0.7083
2026-02-04 02:03:08 | INFO |    LR: 9.91e-05 | Time: 53.2 min
2026-02-04 02:03:08 | INFO |    Status: Patience: 1/10, Best: 0.7776 @ epoch 2
2026-02-04 02:03:08 | INFO | ----------------------------------------------------------------------


Epoch 4/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]

Validating:   0%|          | 0/167 [00:00<?, ?it/s]

2026-02-04 02:54:54 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch04_dice0.8183.pth
2026-02-04 02:54:54 | INFO | 🗑️ Removed old checkpoint: checkpoint_epoch01_dice0.7345.pth
2026-02-04 02:54:54 | INFO | ⭐ NEW BEST MODEL SAVED! Dice: 0.8183
2026-02-04 02:54:55 | INFO | ----------------------------------------------------------------------
2026-02-04 02:54:55 | INFO | 📊 EPOCH 4/50 RESULTS:
2026-02-04 02:54:55 | INFO |    Train Loss: 0.9120
2026-02-04 02:54:55 | INFO |    Validation: Mean Dice = 0.8183 🔥 BEST!
2026-02-04 02:54:55 | INFO |      ├─ TC (Tumor Core):    0.8044
2026-02-04 02:54:55 | INFO |      ├─ WT (Whole Tumor):   0.8865
2026-02-04 02:54:55 | INFO |      └─ ET (Enhancing):     0.7636
2026-02-04 02:54:55 | INFO |    LR: 9.84e-05 | Time: 51.8 min
2026-02-04 02:54:55 | INFO |    Status: Patience: 0/10, Best: 0.8183 @ epoch 4
2026-02-04 02:54:55 | INFO | ----------------------------------------------------------------------


Epoch 5/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]

Validating:   0%|          | 0/167 [00:00<?, ?it/s]

2026-02-04 03:46:46 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch05_dice0.8172.pth
2026-02-04 03:46:46 | INFO | 🗑️ Removed old checkpoint: checkpoint_epoch02_dice0.7776.pth
2026-02-04 03:46:46 | INFO | ----------------------------------------------------------------------
2026-02-04 03:46:46 | INFO | 📊 EPOCH 5/50 RESULTS:
2026-02-04 03:46:46 | INFO |    Train Loss: 0.8605
2026-02-04 03:46:46 | INFO |    Validation: Mean Dice = 0.8172
2026-02-04 03:46:46 | INFO |      ├─ TC (Tumor Core):    0.7993
2026-02-04 03:46:46 | INFO |      ├─ WT (Whole Tumor):   0.8771
2026-02-04 03:46:46 | INFO |      └─ ET (Enhancing):     0.7818
2026-02-04 03:46:46 | INFO |    LR: 9.76e-05 | Time: 51.9 min
2026-02-04 03:46:46 | INFO |    Status: Patience: 1/10, Best: 0.8183 @ epoch 4
2026-02-04 03:46:46 | INFO | ----------------------------------------------------------------------


Epoch 6/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]                                         

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e0ef41ba8e0>




Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__


    

self._shutdown_workers()




  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers


    

if w.is_alive():




 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^




  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




 

 

 

 

 

 

 

 

 

 

 

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^

^




AssertionError

: 

can only test a child process







Validating:   0%|          | 0/167 [00:00<?, ?it/s]
2026-02-04 04:40:36 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch06_dice0.8178.pth
2026-02-04 04:40:36 | INFO | 🗑️ Removed old checkpoint: checkpoint_epoch03_dice0.7648.pth
2026-02-04 04:40:36 | INFO | ----------------------------------------------------------------------
2026-02-04 04:40:36 | INFO | 📊 EPOCH 6/50 RESULTS:
2026-02-04 04:40:36 | INFO |    Train Loss: 0.7933
2026-02-04 04:40:36 | INFO |    Validation: Mean Dice = 0.8178
2026-02-04 04:40:36 | INFO |      ├─ TC (Tumor Core):    0.8008
2026-02-04 04:40:36 | INFO |      ├─ WT (Whole Tumor):   0.8790
2026-02-04 04:40:36 | INFO |      └─ ET (Enhancing):     0.7784
2026-02-04 04:40:36 | INFO |    LR: 9.65e-05 | Time: 53.8 min
2026-02-04 04:40:36 | INFO |    Status: Patience: 2/10, Best: 0.8183 @ epoch 4
2026-02-04 04:40:36 | INFO | ----------------------------------------------------------------------

Epoch 7/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]
Validating:   0%|          | 0/167 [00:00<?, ?it/s]
2026-02-04 05:32:22 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch07_dice0.7988.pth
2026-02-04 05:32:22 | INFO | 🗑️ Removed old checkpoint: checkpoint_epoch04_dice0.8183.pth
2026-02-04 05:32:22 | INFO | ----------------------------------------------------------------------
2026-02-04 05:32:22 | INFO | 📊 EPOCH 7/50 RESULTS:
2026-02-04 05:32:22 | INFO |    Train Loss: 0.7519
2026-02-04 05:32:22 | INFO |    Validation: Mean Dice = 0.7988
2026-02-04 05:32:22 | INFO |      ├─ TC (Tumor Core):    0.7630
2026-02-04 05:32:22 | INFO |      ├─ WT (Whole Tumor):   0.8835
2026-02-04 05:32:22 | INFO |      └─ ET (Enhancing):     0.7601
2026-02-04 05:32:22 | INFO |    LR: 9.52e-05 | Time: 51.8 min
2026-02-04 05:32:22 | INFO |    Status: Patience: 3/10, Best: 0.8183 @ epoch 4
2026-02-04 05:32:22 | INFO | ----------------------------------------------------------------------

Epoch 8/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]
Validating:   0%|          | 0/167 [00:00<?, ?it/s]
2026-02-04 06:25:29 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch08_dice0.8212.pth
2026-02-04 06:25:29 | INFO | 🗑️ Removed old checkpoint: checkpoint_epoch05_dice0.8172.pth
2026-02-04 06:25:29 | INFO | ⭐ NEW BEST MODEL SAVED! Dice: 0.8212
2026-02-04 06:25:30 | INFO | ----------------------------------------------------------------------
2026-02-04 06:25:30 | INFO | 📊 EPOCH 8/50 RESULTS:
2026-02-04 06:25:30 | INFO |    Train Loss: 0.7224
2026-02-04 06:25:30 | INFO |    Validation: Mean Dice = 0.8212 🔥 BEST!
2026-02-04 06:25:30 | INFO |      ├─ TC (Tumor Core):    0.8013
2026-02-04 06:25:30 | INFO |      ├─ WT (Whole Tumor):   0.8866
2026-02-04 06:25:30 | INFO |      └─ ET (Enhancing):     0.7849
2026-02-04 06:25:30 | INFO |    LR: 9.38e-05 | Time: 53.1 min
2026-02-04 06:25:30 | INFO |    Status: Patience: 0/10, Best: 0.8212 @ epoch 8
2026-02-04 06:25:30 | INFO | ----------------------------------------------------------------------

Epoch 9/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]
Validating:   0%|          | 0/167 [00:00<?, ?it/s]
2026-02-04 07:17:17 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch09_dice0.8351.pth
2026-02-04 07:17:17 | INFO | 🗑️ Removed old checkpoint: checkpoint_epoch06_dice0.8178.pth
2026-02-04 07:17:17 | INFO | ⭐ NEW BEST MODEL SAVED! Dice: 0.8351
2026-02-04 07:17:17 | INFO | ----------------------------------------------------------------------
2026-02-04 07:17:17 | INFO | 📊 EPOCH 9/50 RESULTS:
2026-02-04 07:17:17 | INFO |    Train Loss: 0.7032
2026-02-04 07:17:17 | INFO |    Validation: Mean Dice = 0.8351 🔥 BEST!
2026-02-04 07:17:17 | INFO |      ├─ TC (Tumor Core):    0.8187
2026-02-04 07:17:17 | INFO |      ├─ WT (Whole Tumor):   0.8966
2026-02-04 07:17:17 | INFO |      └─ ET (Enhancing):     0.7956
2026-02-04 07:17:17 | INFO |    LR: 9.22e-05 | Time: 51.8 min
2026-02-04 07:17:17 | INFO |    Status: Patience: 0/10, Best: 0.8351 @ epoch 9
2026-02-04 07:17:17 | INFO | ----------------------------------------------------------------------

Epoch 10/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]
Validating:   0%|          | 0/167 [00:00<?, ?it/s]
2026-02-04 08:09:09 | INFO | 💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch10_dice0.8142.pth
2026-02-04 08:09:09 | INFO | 🗑️ Removed old checkpoint: checkpoint_epoch07_dice0.7988.pth
2026-02-04 08:09:09 | INFO | ----------------------------------------------------------------------
2026-02-04 08:09:09 | INFO | 📊 EPOCH 10/50 RESULTS:
2026-02-04 08:09:09 | INFO |    Train Loss: 0.6995
2026-02-04 08:09:09 | INFO |    Validation: Mean Dice = 0.8142
2026-02-04 08:09:09 | INFO |      ├─ TC (Tumor Core):    0.7931
2026-02-04 08:09:09 | INFO |      ├─ WT (Whole Tumor):   0.8962
2026-02-04 08:09:09 | INFO |      └─ ET (Enhancing):     0.7575
2026-02-04 08:09:09 | INFO |    LR: 9.05e-05 | Time: 51.9 min
2026-02-04 08:09:09 | INFO |    Status: Patience: 1/10, Best: 0.8351 @ epoch 9
2026-02-04 08:09:09 | INFO | ----------------------------------------------------------------------
Epoch 11/50 [Train]:   0%|          | 0/917 [00:00<?, ?it/s]

# 13: Visualize Training Progress

In [None]:
# =============================================================================
# TRAINING VISUALIZATION
# =============================================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

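# Training loss (plotted against 1-based epochs so it lines up with the validation plot)
train_epochs = range(1, len(trainer.train_losses) + 1)
axes[0].plot(train_epochs, trainer.train_losses, 'b-', linewidth=2)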
axes[0].set_xlabel("Epoch", fontsize=12)
axes[0].set_ylabel("Loss", fontsize=12)
axes[0].set_title("Training Loss", fontsize=14)
axes[0].grid(True, alpha=0.3)

# Validation Dice
val_epochs = list(range(config.val_interval, len(trainer.val_metrics) * config.val_interval + 1, config.val_interval))
axes[1].plot(val_epochs, trainer.val_metrics, 'g-', linewidth=2, marker='o')
axes[1].axhline(y=trainer.best_metric, color='r', linestyle='--', linewidth=2, label=f'Best: {trainer.best_metric:.4f}')
axes[1].axvline(x=trainer.best_epoch, color='r', linestyle=':', alpha=0.5)
axes[1].set_xlabel("Epoch", fontsize=12)
axes[1].set_ylabel("Mean Dice", fontsize=12)
axes[1].set_title("Validation Dice Score", fontsize=14)
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

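# Learning rate (assumes one recorded value per epoch; plotted against 1-based epochs)
lr_epochs = range(1, len(trainer.learning_rates) + 1)
axes[2].plot(lr_epochs, trainer.learning_rates, 'm-', linewidth=2)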
axes[2].set_xlabel("Epoch", fontsize=12)
axes[2].set_ylabel("Learning Rate", fontsize=12)
axes[2].set_title("Learning Rate Schedule", fontsize=14)
axes[2].set_yscale('log')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(config.output_dir, "training_curves.png"), dpi=150, bbox_inches='tight')
plt.show()

logger.info(f"Training curves saved to {config.output_dir}/training_curves.png")
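
The learning rates logged during training can be cross-checked against the closed-form cosine annealing schedule. The sketch below assumes a base learning rate of 1e-4, a horizon of 50 epochs, and a minimum learning rate of 0. None of these values appear in this section; they are inferred because they reproduce the values logged for epochs 6 to 10. `cosine_lr` is a hypothetical helper, not part of the notebook.

```python
import math

def cosine_lr(epoch: int, base_lr: float = 1e-4, t_max: int = 50, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: learning rate after `epoch` scheduler steps."""
    cos_term = (1.0 + math.cos(math.pi * epoch / t_max)) / 2.0
    return eta_min + (base_lr - eta_min) * cos_term

# Learning rates reported in the epoch summaries of the training log.
logged = {6: 9.65e-05, 7: 9.52e-05, 8: 9.38e-05, 9: 9.22e-05, 10: 9.05e-05}

for epoch, lr_logged in logged.items():
    print(f"epoch {epoch:2d}: formula {cosine_lr(epoch):.2e} | logged {lr_logged:.2e}")
```

If the real configuration uses a different base rate or a nonzero minimum, the curve keeps the same shape but the printed values will differ.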

# 14: Summary

In [None]:
# =============================================================================
# TRAINING SUMMARY
# =============================================================================

print("=" * 60)
print(" TRAINING SUMMARY")
print("=" * 60)
print()
print(f" Best Validation Dice: {trainer.best_metric:.4f}")
print(f" Best Epoch:           {trainer.best_epoch}")
print(f" Final Training Loss:  {trainer.train_losses[-1]:.4f}")
print()
print(" OUTPUT FILES:")
print("   • best_model.pth     - Use this for inference")
print("   • checkpoints/       - Most recent epoch checkpoints (older ones are pruned)")
print("   • training_curves.png")
print("   • training_history.npy")
print("   • training.log")
print()
print("=" * 60)
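
The training log shows that each time a checkpoint is written, the one from three epochs earlier is removed (saving epoch 6 removes epoch 3, and so on), so `checkpoints/` holds only the most recent files. A minimal sketch of that kind of rotation follows. The `keep_last=3` value is inferred from the log, and `prune_checkpoints` is a hypothetical helper rather than the notebook's own code; only the file-name pattern is taken from the log.

```python
import re
import tempfile
from pathlib import Path

# File-name pattern seen in the log, e.g. checkpoint_epoch06_dice0.8178.pth
CKPT_PATTERN = re.compile(r"checkpoint_epoch(\d+)_dice(\d+\.\d+)\.pth$")

def prune_checkpoints(ckpt_dir: Path, keep_last: int = 3) -> list:
    """Delete all but the `keep_last` highest-epoch checkpoints; return removed names."""
    found = []
    for path in ckpt_dir.glob("checkpoint_epoch*.pth"):
        match = CKPT_PATTERN.search(path.name)
        if match:
            found.append((int(match.group(1)), path))
    found.sort()  # ascending by epoch number
    to_remove = found[:-keep_last] if keep_last > 0 else found
    removed = []
    for _, path in to_remove:
        path.unlink()
        removed.append(path.name)
    return removed

# Demo on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for name in ["checkpoint_epoch03_dice0.7648.pth",
                 "checkpoint_epoch04_dice0.8183.pth",
                 "checkpoint_epoch05_dice0.8172.pth",
                 "checkpoint_epoch06_dice0.8178.pth"]:
        (ckpt_dir / name).touch()
    removed = prune_checkpoints(ckpt_dir, keep_last=3)
    remaining = sorted(p.name for p in ckpt_dir.iterdir())
    print("removed:", removed)
    print("remaining:", remaining)
```

Because the best weights are stored separately as `best_model.pth`, pruning an old checkpoint does not lose the best model. In the log, the epoch 4 checkpoint was removed while epoch 4 was still the best epoch.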

---

# Quick Reference

## Model Weights Location
```
/kaggle/working/best_model.pth
```

## For Inference (improved_submission.ipynb)
Update the config path:
```python
config.model_path = "/kaggle/input/YOUR_DATASET_NAME/best_model.pth"
```

## Model Architecture Compatibility
This training notebook uses the **same architecture** as `improved_submission.ipynb`:
- Layer names: `upsample`, `output_head`
- Deep supervision heads: `ds_head1`, `ds_head2` (loaded with `strict=False`)
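
To make the `strict=False` note concrete, the sketch below compares two sets of parameter names the way a non-strict load reports them. It uses plain Python sets rather than PyTorch so it runs anywhere. `compare_keys` and the full parameter names are hypothetical (only the prefixes `upsample`, `output_head`, `ds_head1`, and `ds_head2` come from the notes above), and it assumes the inference model omits the deep supervision heads.

```python
from typing import Iterable, NamedTuple

class KeyReport(NamedTuple):
    missing_keys: list      # expected by the model, absent from the checkpoint
    unexpected_keys: list   # present in the checkpoint, unknown to the model

def compare_keys(model_keys: Iterable, checkpoint_keys: Iterable) -> KeyReport:
    """Mimic the key report of a non-strict state-dict load."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return KeyReport(
        missing_keys=sorted(model_set - ckpt_set),
        unexpected_keys=sorted(ckpt_set - model_set),
    )

# Hypothetical parameter names; only the layer prefixes come from the notes above.
checkpoint_keys = [
    "upsample.weight", "output_head.weight", "output_head.bias",
    "ds_head1.weight", "ds_head1.bias", "ds_head2.weight", "ds_head2.bias",
]
# Assumed inference model without the deep supervision heads.
inference_keys = ["upsample.weight", "output_head.weight", "output_head.bias"]

report = compare_keys(inference_keys, checkpoint_keys)
print("missing:", report.missing_keys)
print("unexpected:", report.unexpected_keys)

# A non-strict load is only safe if nothing the model needs is missing and
# every leftover key belongs to a deep supervision head.
safe = not report.missing_keys and all(
    key.startswith(("ds_head1.", "ds_head2.")) for key in report.unexpected_keys
)
print("safe to load non-strictly:", safe)
```

In PyTorch, `model.load_state_dict(state_dict, strict=False)` returns an object with `missing_keys` and `unexpected_keys` fields, so the same check can be applied to its return value instead of silently ignoring mismatches.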

## Early Stopping
- Patience: 10 epochs
- Min delta: 0.001
- Metric: Validation Dice (higher is better)
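
These settings can be read as a simple counter. The sketch below is one common way to implement it, assuming a score counts as an improvement only when it exceeds the best score by more than `min_delta`. The notebook's exact rule is not shown here, and `EarlyStopping` is an illustrative class. The scores are the validation Dice values for epochs 4 to 10, taken from the log and the checkpoint file names.

```python
class EarlyStopping:
    """Track a higher-is-better metric and count epochs without improvement."""

    def __init__(self, patience: int = 10, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = float("-inf")
        self.best_epoch = 0
        self.counter = 0

    def step(self, score: float, epoch: int) -> bool:
        """Record one validation score; return True when training should stop."""
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

# Validation Dice per epoch, from the log and checkpoint file names.
dice_by_epoch = {4: 0.8183, 5: 0.8172, 6: 0.8178, 7: 0.7988,
                 8: 0.8212, 9: 0.8351, 10: 0.8142}

stopper = EarlyStopping(patience=10, min_delta=0.001)
counters = []
for epoch, dice in dice_by_epoch.items():
    should_stop = stopper.step(dice, epoch)
    counters.append(stopper.counter)
    print(f"epoch {epoch:2d}: dice {dice:.4f} | patience {stopper.counter}/10 | stop={should_stop}")
```

With these inputs the counter matches the `Patience: 2/10`, `3/10`, `0/10`, `0/10`, `1/10` values logged for epochs 6 to 10.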

---