# BraTS Improved Submission Pipeline

This notebook implements key improvements to increase Dice score:

1. **Correct RLE Encoding**: C-order (row-major), 1-indexed
2. **Proper Orientation**: RAS for inference → LPS for submission
3. **Threshold Optimization**: Per-class thresholds
4. **Post-processing**: Remove small components, fill holes, enforce hierarchy
5. **Test-Time Augmentation (TTA)**: Average predictions from flipped volumes

Run all cells in order.

In [None]:
# Cell 1: Install dependencies
!pip install -q monai nibabel scipy scikit-image tqdm pandas numpy torch

In [None]:
import os
import sys
import glob
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import nibabel as nib
from tqdm.auto import tqdm
from scipy.ndimage import zoom, label as scipy_label, binary_fill_holes, binary_dilation, binary_erosion
from skimage.morphology import remove_small_objects
from dataclasses import dataclass
from typing import Tuple, List, Dict

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ConcatItemsd, DeleteItemsd,
    Orientationd, Spacingd, NormalizeIntensityd, CropForegroundd, SpatialPadd,
)
from monai.inferers import SlidingWindowInferer

In [None]:
def setup_logger(name: str = "BraTS_Inference", log_file: str = None) -> logging.Logger:
    """Configure and return a logger that writes to stdout and, optionally, a file.

    Safe to call repeatedly (e.g. when a notebook cell is re-run): handlers left
    by a previous call are closed before being replaced, so file handles do not
    leak and messages are not duplicated.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Optional log-file path. Its parent directory is created when
            needed; if the file cannot be opened, a warning is emitted and the
            logger falls back to console-only output.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Close (not just drop) previous handlers: a discarded FileHandler keeps its file open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.propagate = False  # Prevent duplicate logging with INFO:BraTS prefix

    formatter = logging.Formatter(
        '%(asctime)s | %(levelname)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler - prints to stdout (visible on Kaggle)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler (optional, best effort)
    if log_file:
        try:
            log_dir = os.path.dirname(log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            file_handler = logging.FileHandler(log_file)
        except OSError as err:
            # The log file is a convenience; it must not abort the pipeline.
            logger.warning(f"Could not open log file {log_file!r} ({err}); logging to console only")
        else:
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger


# Module-wide logger: stdout (visible in Kaggle's output pane) plus a copy in /kaggle/working.
logger = setup_logger(name="BraTS_Inference", log_file="/kaggle/working/inference.log")

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
sys.stdout.flush()


In [None]:

@dataclass
class Config:
    """Settings for the BraTS inference / submission pipeline.

    Paths default to the Kaggle input/working layout. The three thresholds are
    applied to the sigmoid probability of the corresponding model output channel.
    """

    # Paths - UPDATE THESE FOR KAGGLE
    test_dir: str = "/kaggle/input/instant-odc-ai-hackathon/test"
    model_path: str = "/kaggle/input/best-5-epoch-model/best_model.pth"
    output_csv: str = "/kaggle/working/submission.csv"

    # Model architecture (patch_size is also the sliding-window ROI and the pad target)
    patch_size: Tuple[int, int, int] = (128, 128, 128)
    in_channels: int = 4
    out_channels: int = 3

    # Inference settings
    sw_batch_size: int = 2  # windows per forward pass in the sliding-window inferer
    overlap: float = 0.6  # Good for boundary consistency
    use_amp: bool = True
    target_spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0)

    # BALANCED per-channel probability thresholds (between 0.7137 and 0.6465 runs)
    threshold_tc: float = 0.52  # Slightly higher than original 0.5
    threshold_wt: float = 0.47  # Between 0.45 and 0.50
    threshold_et: float = 0.57  # Between 0.55 and 0.60

    # Post-processing (MODERATE)
    min_component_size: int = 150  # voxels; smaller connected components are removed
    fill_holes: bool = True
    enforce_hierarchy: bool = True  # ET ⊆ TC ⊆ WT
    apply_erosion: bool = False  # DISABLED - was too aggressive
    erosion_iterations: int = 0

    # Test-Time Augmentation
    use_tta: bool = True
    # Axes to flip for TTA. The declared default ``None`` is replaced with a fresh
    # [0, 1, 2] per instance in __post_init__ (a shared mutable default would leak
    # edits between Config objects).
    # NOTE(review): values look like spatial-axis indices - confirm against the TTA caller.
    tta_flips: List[int] = None

    # Evaluation / Validation (optional - for datasets with ground truth)
    enable_evaluation: bool = True  # When True, compute Dice scores if ground_truth_dir exists
    ground_truth_dir: str = "/kaggle/input/brain-tumor-segmentation-hackathon"  # Path to hackathon dataset with labels

    def __post_init__(self):
        # Materialise the TTA default per instance.
        if self.tta_flips is None:
            self.tta_flips = [0, 1, 2]


config = Config()
config.tta_flips = [0, 1, 2]

# Modality file keywords; ConcatItemsd stacks them as input channels in this order.
MODALITY_KEYS = ["flair", "t1", "t1ce", "t2"]
ORIGINAL_SHAPE = (240, 240, 155)

# Echo the effective settings so the run log records exactly what was used.
for _line in (
    "Config loaded (BALANCED):",
    f"  Thresholds: TC={config.threshold_tc}, WT={config.threshold_wt}, ET={config.threshold_et}",
    f"  TTA enabled: {config.use_tta}",
    f"  Post-processing: min_size={config.min_component_size}, fill_holes={config.fill_holes}",
    f"  Erosion: {config.apply_erosion}",
    f"  Overlap: {config.overlap}",
):
    logger.info(_line)
del _line
sys.stdout.flush()

In [None]:
class ConvBlock(nn.Module):
    """Two 3x3x3 convolutions, each followed by affine InstanceNorm and LeakyReLU(0.01).

    Attribute names (conv1/norm1/conv2/norm2) determine the state_dict keys, so
    they must stay as-is for checkpoint loading.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.InstanceNorm3d(out_ch, affine=True)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm3d(out_ch, affine=True)
        self.act = nn.LeakyReLU(negative_slope=0.01, inplace=True)

    def forward(self, x):
        hidden = self.act(self.norm1(self.conv1(x)))
        hidden = self.act(self.norm2(self.conv2(hidden)))
        return hidden


class DownBlock(nn.Module):
    """Encoder stage: a ConvBlock followed by 2x max-pooling.

    ``forward`` returns ``(pooled, features)``, where ``features`` is the
    pre-pooling activation used as the skip connection.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.pool = nn.MaxPool3d(kernel_size=2)

    def forward(self, x):
        features = self.conv(x)
        downsampled = self.pool(features)
        return downsampled, features


class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, size-align to the skip, concat, ConvBlock."""

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        # Named 'upsample' (not 'up') so state_dict keys match the checkpoint.
        self.upsample = nn.ConvTranspose3d(in_ch, in_ch, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.upsample(x)
        if upsampled.shape != skip.shape:
            # F.pad lists (before, after) pairs starting from the LAST axis,
            # so walk the spatial axes W, H, D; negative amounts trim.
            padding = []
            for axis in (4, 3, 2):
                diff = skip.shape[axis] - upsampled.shape[axis]
                padding.extend([diff // 2, diff - diff // 2])
            upsampled = F.pad(upsampled, padding)
        return self.conv(torch.cat([upsampled, skip], dim=1))


class UNet3D(nn.Module):
    """3D U-Net: encoder stages, a bottleneck, mirrored decoder stages and a 1x1x1 head.

    Args:
        in_channels: Number of input channels (default 4 modalities).
        out_channels: Number of output logit channels (default 3).
        features: Channel widths from the shallowest stage to the bottleneck.
            One encoder/decoder pair is built per width except the last, which
            is the bottleneck width.

    The defaults reproduce the original hard-coded network, so state_dict keys
    (encoders.*, bottleneck.*, decoders.*, output_head.*) are unchanged.
    Checkpoints may also contain deep-supervision heads; they are not defined
    here and are skipped by ``load_state_dict(strict=False)``.

    Raises:
        ValueError: if fewer than two feature widths are given.
    """

    def __init__(self, in_channels: int = 4, out_channels: int = 3,
                 features: Tuple[int, ...] = (64, 96, 128, 192, 256, 384)):
        super().__init__()
        f = list(features)
        if len(f) < 2:
            raise ValueError(f"features needs at least two widths, got {f}")
        self.encoders = nn.ModuleList(
            [DownBlock(in_channels if i == 0 else f[i - 1], c) for i, c in enumerate(f[:-1])])
        self.bottleneck = ConvBlock(f[-2], f[-1])
        # Decoder i upsamples width f[-1 - i] and merges the skip of width f[-2 - i].
        self.decoders = nn.ModuleList(
            [UpBlock(f[-1 - i], f[-2 - i], f[-2 - i]) for i in range(len(f) - 1)])
        # Named 'output_head' (not 'out') to match checkpoint keys.
        self.output_head = nn.Conv3d(f[0], out_channels, 1)

    def forward(self, x):
        skips = []
        for encoder in self.encoders:
            x, skip = encoder(x)
            skips.append(skip)
        x = self.bottleneck(x)
        # Decoders run deepest-first, so pair them with the skips in reverse order.
        for decoder, skip in zip(self.decoders, reversed(skips)):
            x = decoder(x, skip)
        return self.output_head(x)


# Load model
logger.info("=" * 70)
logger.info("MODEL LOADING")
logger.info("=" * 70)
model = UNet3D().to(device)
if os.path.exists(config.model_path):
    logger.info(f"Loading model from {config.model_path}...")
    sys.stdout.flush()
    try:
        state_dict = torch.load(config.model_path, map_location=device)
        # Full training checkpoints nest the weights under 'model_state_dict'.
        if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

        # Weights saved from an nn.DataParallel / DistributedDataParallel wrapper carry a
        # 'module.' prefix on every key; strip it so the names line up with this model.
        if (isinstance(state_dict, dict) and state_dict
                and all(isinstance(k, str) and k.startswith("module.") for k in state_dict)):
            state_dict = {k[len("module."):]: v for k, v in state_dict.items()}

        # strict=False tolerates checkpoint entries the model does not define (e.g. the
        # auxiliary deep-supervision heads), but it would also silently skip weights the
        # model needs - so report both lists instead of assuming success.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.unexpected_keys:
            logger.info(f"Ignored {len(load_result.unexpected_keys)} checkpoint entries not used by the model "
                        f"(first few: {load_result.unexpected_keys[:5]})")
        if load_result.missing_keys:
            logger.warning(f"{len(load_result.missing_keys)} model parameters were NOT found in the checkpoint "
                           f"and keep their random initialisation (first few: {load_result.missing_keys[:5]})")
        else:
            logger.info(f"‚úÖ Model loaded successfully (strict=False)")
        sys.stdout.flush()
    except Exception as e:
        logger.error(f"‚ùå Failed to load model: {e}")
        logger.warning("Continuing with randomly initialised weights - predictions will not be meaningful.")
        sys.stdout.flush()
else:
    logger.error(f"‚ùå Model not found at {config.model_path}")
    logger.warning("Continuing with randomly initialised weights - predictions will not be meaningful.")
    sys.stdout.flush()

model.eval()
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
logger.info("=" * 70)
sys.stdout.flush()


In [None]:
def rle_encode_c_order(mask: np.ndarray) -> str:
    """
    Run-length encode a mask in C (row-major) order with 1-indexed starts.
    THIS IS THE CORRECT FORMAT FOR THE COMPETITION!

    Every non-zero voxel counts as foreground, so bool, 0/1 uint8 and
    label-valued masks (e.g. 4) all encode the same way.

    Args:
        mask: N-D array, e.g. a (240, 240, 155) volume.

    Returns:
        RLE string "start1 length1 start2 length2 ...", or "" for an empty mask.
    """
    # Binarize first: comparing raw values would split runs wherever two
    # different non-zero values touch and break the start/end pairing.
    flat = (np.asarray(mask) != 0).ravel(order='C')

    if not flat.any():
        return ""

    # Pad with background on both ends so every run has a rising and a falling edge.
    padded = np.concatenate(([False], flat, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])

    starts = edges[0::2] + 1  # Convert to 1-indexed
    lengths = edges[1::2] - edges[0::2]

    return " ".join(f"{s} {n}" for s, n in zip(starts, lengths))


def rle_decode_c_order(rle_string: str, shape: tuple = (240, 240, 155)) -> np.ndarray:
    """
    Rebuild a uint8 mask of ``shape`` from a C-order, 1-indexed RLE string.

    A falsy or whitespace-only string yields an all-zero mask.
    """
    if not rle_string or not rle_string.strip():
        return np.zeros(shape, dtype=np.uint8)

    flat = np.zeros(np.prod(shape), dtype=np.uint8)
    tokens = rle_string.split()

    # Tokens alternate: 1-indexed start, run length.
    for k in range(0, len(tokens), 2):
        begin = int(tokens[k]) - 1
        run = int(tokens[k + 1])
        flat[begin:begin + run] = 1

    return flat.reshape(shape, order='C')


# Progress marker: confirms this cell's RLE helpers were defined.
logger.info("‚úÖ RLE encoding functions (C-order) loaded")


# Cell 5.5: Dice Score Calculation (for per-sample diagnostics)
def calculate_dice_score(pred: np.ndarray, gt: np.ndarray) -> float:
    """
    Dice coefficient of two masks: 2 * |P ∩ G| / (|P| + |G|).

    Inputs are cast to bool, so any non-zero value counts as foreground.
    Two empty masks score 1.0; exactly one empty mask scores 0.0.
    """
    p = pred.astype(bool)
    g = gt.astype(bool)
    p_count = p.sum()
    g_count = g.sum()

    if p_count == 0 or g_count == 0:
        # Both empty -> perfect agreement; only one empty -> no overlap possible.
        return 1.0 if p_count == g_count else 0.0

    overlap = np.logical_and(p, g).sum()
    return 2.0 * overlap / (p_count + g_count)


def calculate_per_class_dice(pred_masks: Dict[int, np.ndarray], gt_masks: Dict[int, np.ndarray],
                             classes: Tuple[int, ...] = (1, 2, 4)) -> Dict[int, float]:
    """
    Calculate Dice scores for each class.

    A class missing from one dict is treated as an all-zero mask shaped like the
    masks on the other side. With ``calculate_dice_score`` this means a missing
    class scores 0.0 against a non-empty mask and 1.0 against an empty one.

    Args:
        pred_masks: Dict mapping class to predicted binary mask
        gt_masks: Dict mapping class to ground truth binary mask
        classes: Classes to score (default: BraTS labels 1, 2, 4)

    Returns:
        Dict mapping class to Dice score
    """
    # Shape templates for absent classes, looked up without indexing so an
    # empty dict on either side is handled.
    first_gt = next(iter(gt_masks.values()), None)
    first_pred = next(iter(pred_masks.values()), None)
    pred_template = first_gt if first_gt is not None else first_pred
    gt_template = first_pred if first_pred is not None else first_gt

    def _empty_like(template):
        # No masks on either side: a zero-size array still scores 1.0 (both empty).
        return np.zeros(0, dtype=np.uint8) if template is None else np.zeros_like(template)

    dice_scores = {}
    for cls in classes:
        pred = pred_masks[cls] if cls in pred_masks else _empty_like(pred_template)
        gt = gt_masks[cls] if cls in gt_masks else _empty_like(gt_template)
        dice_scores[cls] = calculate_dice_score(pred, gt)
    return dice_scores


def load_ground_truth_masks(patient_folder: str, target_shape: tuple = (240, 240, 155)) -> Dict[int, np.ndarray]:
    """
    Read a patient's segmentation and split it into one binary mask per BraTS label.

    The label volume is reoriented to LPS axis order (the submission layout)
    when its affine says otherwise, and nearest-neighbour resized to
    ``target_shape`` if the shapes differ.

    BraTS label convention:
    - 0: Background
    - 1: NCR/NET (Necrotic/Non-enhancing Tumor Core)
    - 2: ED (Peritumoral Edema)
    - 4: ET (Enhancing Tumor)

    Args:
        patient_folder: Folder searched (recursively) for a ``*seg*.nii*`` file.
        target_shape: Output grid shape (should match the submission format).

    Returns:
        Dict mapping label (1, 2, 4) to a uint8 binary mask.
    """
    seg_nii = nib.load(find_nifti_file(patient_folder, "seg"))
    seg_data = seg_nii.get_fdata().astype(np.uint8)

    # Bring the array into LPS axis order if the affine reports anything else.
    current_ornt = nib.orientations.io_orientation(seg_nii.affine)
    lps_ornt = nib.orientations.axcodes2ornt(('L', 'P', 'S'))
    if not np.array_equal(current_ornt, lps_ornt):
        seg_data = nib.orientations.apply_orientation(
            seg_data, nib.orientations.ornt_transform(current_ornt, lps_ornt))

    # order=0 (nearest neighbour) keeps label values intact; then trim to the exact size.
    if seg_data.shape != target_shape:
        factors = [want / have for want, have in zip(target_shape, seg_data.shape)]
        seg_data = zoom(seg_data.astype(np.float32), factors, order=0).astype(np.uint8)
        seg_data = seg_data[:target_shape[0], :target_shape[1], :target_shape[2]]

    return {label: (seg_data == label).astype(np.uint8) for label in (1, 2, 4)}


def find_ground_truth_folder(patient_id: str, gt_dir: str) -> str:
    """
    Find the ground truth folder for a given patient ID.

    Lookup order:
      1. ``gt_dir/<patient_id>`` if it is a directory.
      2. Directories in ``gt_dir`` whose name contains the last ``_``-separated
         token of ``patient_id`` (e.g. ``00000`` for ``BraTS2021_00000``).
         Names equal to that token or ending in ``_<token>`` win over plain
         substring matches. Ties go to the first path in sorted order, so the
         result is deterministic.

    Args:
        patient_id: Patient ID from test set (e.g., "BraTS2021_00000")
        gt_dir: Root directory of ground truth dataset

    Returns:
        Path to ground truth patient folder, or None if not found
    """
    exact_path = os.path.join(gt_dir, patient_id)
    if os.path.isdir(exact_path):
        return exact_path

    num_id = patient_id.split("_")[-1]
    if not num_id:
        # An empty token would make the pattern "**" and match every folder.
        return None

    # Escape both parts so '[', ']', '*', '?' in paths or IDs are literal.
    pattern = os.path.join(glob.escape(gt_dir), f"*{glob.escape(num_id)}*")
    folders = sorted(c for c in glob.glob(pattern) if os.path.isdir(c))
    if not folders:
        return None

    suffix_matches = [
        c for c in folders
        if os.path.basename(c) == num_id or os.path.basename(c).endswith(f"_{num_id}")
    ]
    return (suffix_matches or folders)[0]


# Progress marker: confirms this cell's Dice / ground-truth helpers were defined.
logger.info("‚úÖ Dice score calculation functions loaded")


In [None]:
def find_nifti_file(base_path: str, keyword: str) -> str:
    """Find the NIfTI file for ``keyword`` anywhere under ``base_path``.

    Files whose name contains ``keyword`` as a whole token (e.g.
    ``..._t1.nii.gz``) are preferred. Without this, "t1" would also match
    ``..._t1ce.nii.gz``, and the longest-path rule below would pick the t1ce
    file. If no candidate matches as a whole token, any substring match is
    accepted. Among the remaining candidates the longest path wins.

    Raises:
        FileNotFoundError: if no file under ``base_path`` contains ``keyword``.
    """
    import re

    pattern = os.path.join(glob.escape(base_path), "**", f"*{glob.escape(keyword)}*.nii*")
    real_files = [f for f in glob.glob(pattern, recursive=True) if os.path.isfile(f)]
    if not real_files:
        raise FileNotFoundError(f"No file found for '{keyword}' in {base_path}")

    # keyword must be delimited by a non-alphanumeric character (or string edge).
    token = re.compile(rf"(?:^|[^a-z0-9]){re.escape(keyword.lower())}(?:[^a-z0-9]|$)")

    def _stem(path: str) -> str:
        return re.sub(r"\.nii(\.gz)?$", "", os.path.basename(path).lower())

    whole_token = [f for f in real_files if token.search(_stem(f))]
    return max(whole_token or real_files, key=len)


def get_inference_transforms(config: Config) -> Compose:
    """
    Build the MONAI preprocessing pipeline whose output is fed to the network.

    Steps:
      1. Load the four modalities and stack them as channels.
      2. Reorient to RAS and resample to ``config.target_spacing``.
      3. Apply a per-channel z-score over nonzero voxels.
      4. Crop to the foreground box and pad up to ``config.patch_size``.
    """
    modalities = MODALITY_KEYS
    load_and_stack = [
        LoadImaged(keys=modalities, image_only=False),
        EnsureChannelFirstd(keys=modalities),
        ConcatItemsd(keys=modalities, name="image", dim=0),
        DeleteItemsd(keys=modalities),
    ]
    geometry_and_intensity = [
        Orientationd(keys=["image"], axcodes="RAS"),  # RAS for inference (like training)
        Spacingd(keys=["image"], pixdim=config.target_spacing, mode="bilinear"),
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    ]
    crop_and_pad = [
        CropForegroundd(keys=["image"], source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=config.patch_size),
    ]
    return Compose(load_and_stack + geometry_and_intensity + crop_and_pad)


def get_pre_crop_transforms(target_spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0)) -> Compose:
    """
    Preprocessing up to, but excluding, foreground cropping.

    The output grid is the RAS volume that cropped predictions are pasted back
    into. ``target_spacing`` must therefore equal the spacing used by
    ``get_inference_transforms`` (``config.target_spacing``), or the crop
    offsets will not line up. The default keeps the previous hard-coded
    (1.0, 1.0, 1.0).
    """
    return Compose([
        LoadImaged(keys=MODALITY_KEYS, image_only=False),
        EnsureChannelFirstd(keys=MODALITY_KEYS),
        ConcatItemsd(keys=MODALITY_KEYS, name="image", dim=0),
        DeleteItemsd(keys=MODALITY_KEYS),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=target_spacing, mode="bilinear"),
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    ])


def get_crop_info_transforms(target_spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0)) -> Compose:
    """
    Preprocessing through foreground cropping, without padding.

    Its output shape is the exact crop size, used to undo the crop.
    ``target_spacing`` must equal the spacing used by ``get_inference_transforms``
    (``config.target_spacing``). The default keeps the previous hard-coded
    (1.0, 1.0, 1.0).
    """
    return Compose([
        LoadImaged(keys=MODALITY_KEYS, image_only=False),
        EnsureChannelFirstd(keys=MODALITY_KEYS),
        ConcatItemsd(keys=MODALITY_KEYS, name="image", dim=0),
        DeleteItemsd(keys=MODALITY_KEYS),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=target_spacing, mode="bilinear"),
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
        CropForegroundd(keys=["image"], source_key="image"),
    ])


# Progress marker: confirms this cell's file-lookup and transform builders were defined.
logger.info("‚úÖ Transform functions loaded")


In [None]:
def post_process_mask(mask: np.ndarray, min_size: int = 100, fill_holes: bool = True,
                      apply_erosion: bool = False, erosion_iterations: int = 1) -> np.ndarray:
    """
    Post-process a binary mask:
    1. Remove connected components smaller than ``min_size`` voxels
    2. Fill holes
    3. Optional erosion to shrink boundaries

    Args:
        mask: 3D binary mask
        min_size: Minimum component size (voxels) to keep; <= 1 keeps every component
        fill_holes: Whether to fill holes in the mask
        apply_erosion: Whether to apply erosion
        erosion_iterations: Number of erosion iterations (erosion is skipped when <= 0)

    Returns:
        Cleaned uint8 mask. An all-zero input is returned as-is, the same
        object with its original dtype.
    """
    if mask.sum() == 0:
        return mask

    cleaned = mask.copy().astype(np.uint8)

    # Remove small components
    labeled, num_features = scipy_label(cleaned)
    if num_features > 0:
        component_sizes = np.bincount(labeled.ravel())
        keep = component_sizes >= min_size
        # Label 0 is background: never keep it. A size check alone would pass it
        # when min_size <= 0 and turn the whole volume into foreground.
        keep[0] = False
        cleaned = keep[labeled].astype(np.uint8)

    # Fill holes
    if fill_holes and cleaned.sum() > 0:
        cleaned = binary_fill_holes(cleaned).astype(np.uint8)

    # Erode to shrink boundaries. The iterations > 0 guard matters: scipy repeats
    # erosion until nothing changes when iterations < 1.
    if apply_erosion and erosion_iterations > 0 and cleaned.sum() > 0:
        cleaned = binary_erosion(cleaned, iterations=erosion_iterations).astype(np.uint8)

    return cleaned


def enforce_tumor_hierarchy(tc: np.ndarray, wt: np.ndarray, et: np.ndarray) -> Tuple[
    np.ndarray, np.ndarray, np.ndarray]:
    """
    Clip the masks so that ET ⊆ TC ⊆ WT.

    WT is the outermost region and is only cast to uint8. TC is intersected
    with WT, and ET with both TC and the clipped TC.

    Args:
        tc: Tumor Core mask
        wt: Whole Tumor mask
        et: Enhancing Tumor mask

    Returns:
        Corrected (tc, wt, et), all uint8
    """
    whole = wt.astype(np.uint8)
    core = (tc & wt).astype(np.uint8)
    enhancing_in_core = (et & tc).astype(np.uint8)
    # Intersect once more with the clipped core so ET also lies inside WT.
    enhancing = (enhancing_in_core & core).astype(np.uint8)
    return core, whole, enhancing


def keep_largest_component(mask: np.ndarray) -> np.ndarray:
    """
    Return only the largest connected component of ``mask`` as uint8.

    An empty mask, or one with a single component, is returned unchanged (same
    object). Ties go to the lowest component label.
    """
    if mask.sum() == 0:
        return mask

    components, count = scipy_label(mask)
    if count < 2:
        return mask

    sizes = np.bincount(components.ravel())
    sizes[0] = 0  # background must never win
    return (components == int(np.argmax(sizes))).astype(np.uint8)


# Progress marker: confirms this cell's post-processing helpers were defined.
logger.info("‚úÖ Post-processing functions loaded (with erosion support)")


In [None]:
def run_inference_with_tta(model, image_tensor: torch.Tensor, inferer,
                           flip_dims=(2, 3, 4), use_amp: bool = True) -> torch.Tensor:
    """
    Run inference with test-time augmentation (flip augmentations).

    One pass runs on the original input, plus one pass per entry of
    ``flip_dims`` on the input flipped along that dim. Each flipped prediction
    is flipped back before averaging the sigmoid probabilities.

    Args:
        model: PyTorch model (switched to eval mode)
        image_tensor: Input tensor [B, C, D, H, W]
        inferer: Callable ``inferer(inputs, model)``, e.g. MONAI SlidingWindowInferer.
            A tuple output is reduced to its first element.
        flip_dims: Iterable of tensor dims to flip (2=D, 3=H, 4=W for 3D)
        use_amp: Run the forward passes under CUDA autocast

    Returns:
        Averaged prediction probabilities (float32), same shape as the model output
    """
    model.eval()

    def _probabilities(batch: torch.Tensor) -> torch.Tensor:
        with autocast(enabled=use_amp):
            logits = inferer(batch, model)
        if isinstance(logits, tuple):
            logits = logits[0]
        # Sigmoid in float32: under AMP the logits may be half precision.
        return torch.sigmoid(logits.float())

    with torch.no_grad():
        # Running sum instead of stacking every pass keeps peak memory at ~2 outputs.
        total = _probabilities(image_tensor)
        n_passes = 1
        for dim in flip_dims:
            flipped = torch.flip(image_tensor, dims=[dim])
            total += torch.flip(_probabilities(flipped), dims=[dim])  # flip prediction back
            n_passes += 1

    return total / n_passes


# Progress marker: confirms the TTA inference helper was defined.
logger.info("‚úÖ TTA function loaded")

In [None]:
def process_single_case(
        model,
        patient_folder: str,
        config: Config,
        inferer,
        inference_transforms,
        pre_crop_transforms,
        crop_info_transforms,
        return_masks: bool = False  # [EVAL] Return decoded masks for evaluation
) -> Dict[int, str]:
    """
    Segment one patient and return an RLE string for each BraTS label (1, 2, 4).

    Pipeline:
      1. MONAI preprocessing in RAS.
      2. Sliding-window inference, with optional flip TTA.
      3. Per-channel thresholds (channels TC, WT, ET).
      4. Undo pad and crop.
      5. Hierarchy enforcement and post-processing.
      6. RAS -> LPS flip, then resample to the FLAIR grid.
      7. Derive labels and encode as C-order RLE.

    Args:
        model: PyTorch model
        patient_folder: Path to patient folder
        config: Configuration object
        inferer: MONAI SlidingWindowInferer
        inference_transforms: Transform pipeline (crop + pad, network input)
        pre_crop_transforms: Pre-crop transform pipeline (defines the paste-back grid)
        crop_info_transforms: Crop info transform pipeline (exact cropped shape)
        return_masks: If True, also return decoded masks for evaluation

    Returns:
        If return_masks=False: Dict mapping class (1, 2, 4) to RLE string
        If return_masks=True: Tuple of (RLE dict, masks dict) where masks dict
                              maps class (1, 2, 4) to binary numpy arrays

    Raises:
        ValueError: if TTA is enabled and ``config.tta_flips`` holds an axis outside 0..2.
    """
    patient_id = os.path.basename(patient_folder)

    # The FLAIR volume defines the submission grid (e.g. 240 x 240 x 155).
    original_shape = nib.load(find_nifti_file(patient_folder, "flair")).shape

    # Resolve modality paths once; every MONAI pipeline gets its own copy of the dict.
    data = {k: find_nifti_file(patient_folder, k) for k in MODALITY_KEYS}

    # --- Pre-crop RAS volume: the grid that predictions are pasted back into ---
    pre_img = pre_crop_transforms(dict(data))["image"].numpy()
    pre_shape = list(pre_img.shape[1:])

    # --- Crop origin ---
    # CropForegroundd's default select_fn keeps voxels > 0 in any channel, not != 0.
    # After nonzero z-score normalisation brain voxels can be negative, so a != 0
    # box can start earlier than the crop MONAI actually applied and shift every
    # prediction. Mirror MONAI's criterion.
    fg_coords = np.nonzero(np.any(pre_img > 0, axis=0))
    has_foreground = fg_coords[0].size > 0
    crop_start = [int(c.min()) for c in fg_coords] if has_foreground else [0, 0, 0]

    # --- Exact cropped shape as produced by MONAI ---
    actual_crop_shape = list(crop_info_transforms(dict(data))["image"].shape[1:])
    crop_end = [start + size for start, size in zip(crop_start, actual_crop_shape)]
    if has_foreground:
        bbox_shape = [int(c.max()) - int(c.min()) + 1 for c in fg_coords]
        if bbox_shape != actual_crop_shape:
            logger.warning(f"  [{patient_id}] foreground box {bbox_shape} != MONAI crop "
                           f"{actual_crop_shape}; predictions may be misaligned")

    # --- Network input: cropped, then padded to patch_size ---
    img = inference_transforms(dict(data))["image"]
    # SpatialPadd (symmetric) puts floor(diff / 2) voxels before the data on each axis.
    pad_offset = [max(p - s, 0) // 2 for p, s in zip(config.patch_size, actual_crop_shape)]

    # --- Inference ---
    x = img.unsqueeze(0).to(device)
    if config.use_tta:
        spatial_axes = [0, 1, 2] if config.tta_flips is None else list(config.tta_flips)
        invalid = [a for a in spatial_axes if a not in (0, 1, 2)]
        if invalid:
            raise ValueError(f"config.tta_flips must contain spatial axes 0-2, got {invalid}")
        # Input is [B, C, D, H, W]: spatial axis a is tensor dim a + 2.
        pred_probs = run_inference_with_tta(model, x, inferer, flip_dims=[a + 2 for a in spatial_axes])
    else:
        with torch.no_grad():
            with autocast(enabled=config.use_amp):
                out = inferer(x, model)
            if isinstance(out, tuple):
                out = out[0]
            pred_probs = torch.sigmoid(out.float())

    probs = pred_probs.cpu().numpy()[0]  # [3, D, H, W]; channel 0 = TC, 1 = WT, 2 = ET

    # --- Threshold, strip padding, paste into the full pre-crop RAS grid ---
    unpad = tuple(slice(o, o + s) for o, s in zip(pad_offset, actual_crop_shape))
    paste = tuple(slice(c, e) for c, e in zip(crop_start, crop_end))
    thresholds = (config.threshold_tc, config.threshold_wt, config.threshold_et)
    ras_masks = []
    for channel, threshold in enumerate(thresholds):
        volume = np.zeros(pre_shape, dtype=np.uint8)
        volume[paste] = (probs[channel][unpad] > threshold).astype(np.uint8)
        ras_masks.append(volume)
    tc_ras, wt_ras, et_ras = ras_masks

    # --- Hierarchy + post-processing. Hierarchy is re-applied afterwards because
    #     component removal, hole filling and erosion can break it. ---
    if config.enforce_hierarchy:
        tc_ras, wt_ras, et_ras = enforce_tumor_hierarchy(tc_ras, wt_ras, et_ras)
    pp_args = (config.min_component_size, config.fill_holes,
               config.apply_erosion, config.erosion_iterations)
    tc_ras, wt_ras, et_ras = (post_process_mask(m, *pp_args) for m in (tc_ras, wt_ras, et_ras))
    if config.enforce_hierarchy:
        tc_ras, wt_ras, et_ras = enforce_tumor_hierarchy(tc_ras, wt_ras, et_ras)

    # --- RAS -> LPS (flip X and Y), then match the FLAIR grid if spacing differed ---
    def _to_submission_grid(mask_ras: np.ndarray) -> np.ndarray:
        mask_lps = np.flip(mask_ras, axis=(0, 1)).copy()
        if mask_lps.shape != original_shape:
            factors = [o / c for o, c in zip(original_shape, mask_lps.shape)]
            mask_lps = zoom(mask_lps.astype(np.float32), factors, order=0).astype(np.uint8)
            mask_lps = mask_lps[:original_shape[0], :original_shape[1], :original_shape[2]]
        return mask_lps

    tc_lps, wt_lps, et_lps = (_to_submission_grid(m) for m in (tc_ras, wt_ras, et_ras))

    # Derive BraTS submission classes:
    # Class 1 (NCR/NET): TC - ET
    # Class 2 (ED): WT - TC
    # Class 4 (ET): ET directly
    class_1 = ((tc_lps > 0) & (et_lps == 0)).astype(np.uint8)
    class_2 = ((wt_lps > 0) & (tc_lps == 0)).astype(np.uint8)
    class_4 = et_lps.astype(np.uint8)

    # [DIAGNOSTIC LOGGING] Final voxel counts after RAS→LPS conversion and resampling
    logger.info(f"  Final voxel counts: Class1={class_1.sum():,}, Class2={class_2.sum():,}, Class4={class_4.sum():,}")
    sys.stdout.flush()

    rle_dict = {
        1: rle_encode_c_order(class_1),
        2: rle_encode_c_order(class_2),
        4: rle_encode_c_order(class_4),
    }

    if return_masks:
        masks_dict = {1: class_1, 2: class_2, 4: class_4}
        return rle_dict, masks_dict

    return rle_dict


# Progress marker: confirms the per-case processing pipeline was defined.
logger.info("‚úÖ Main processing pipeline loaded (with enhanced post-processing)")


In [None]:
def _nan_eval_row(patient_id: str) -> dict:
    """Return an evaluation row marking ``patient_id``'s Dice scores as unavailable (NaN)."""
    return {
        "patient": patient_id,
        "dice_class1": np.nan,
        "dice_class2": np.nan,
        "dice_class4": np.nan,
        "mean_dice": np.nan,
    }


def _evaluate_case(patient_id: str, pred_masks: dict, ground_truth_dir: str) -> dict:
    """
    Score one case's predicted masks against its ground truth.

    Never raises. A missing ground-truth folder, or any failure while locating,
    loading or scoring, is logged as a warning and yields a NaN row. Evaluation
    problems therefore cannot affect the submission rows recorded for the case.

    Args:
        patient_id: Case identifier (test folder basename).
        pred_masks: Predicted masks keyed by submission class (1, 2, 4), as
            returned by ``process_single_case(..., return_masks=True)``.
        ground_truth_dir: Root directory passed to ``find_ground_truth_folder``.

    Returns:
        Dict with keys ``patient``, ``dice_class1``, ``dice_class2``,
        ``dice_class4`` and ``mean_dice``.
    """
    try:
        gt_folder = find_ground_truth_folder(patient_id, ground_truth_dir)
        if not gt_folder:
            logger.warning("  [EVAL] Ground truth not found")
            sys.stdout.flush()
            return _nan_eval_row(patient_id)

        gt_masks = load_ground_truth_masks(gt_folder, ORIGINAL_SHAPE)
        dice_scores = calculate_per_class_dice(pred_masks, gt_masks)
        mean_dice = np.mean(list(dice_scores.values()))

        row = {
            "patient": patient_id,
            "dice_class1": dice_scores[1],
            "dice_class2": dice_scores[2],
            "dice_class4": dice_scores[4],
            "mean_dice": mean_dice,
        }
        logger.info(f"  [EVAL] Dice Scores: Class1={dice_scores[1]:.4f}, Class2={dice_scores[2]:.4f}, Class4={dice_scores[4]:.4f}, Mean={mean_dice:.4f}")
        sys.stdout.flush()
        return row
    except Exception as eval_ex:
        logger.warning(f"  [EVAL ERROR] {eval_ex}")
        sys.stdout.flush()
        return _nan_eval_row(patient_id)


def generate_submission(
    config: Config, target_id: "str | None" = None
) -> Tuple[pd.DataFrame, pd.DataFrame, "pd.DataFrame | None"]:
    """
    Generate the submission CSV for all test cases, with optional Dice evaluation.

    Each ``BraTS*`` folder under ``config.test_dir`` is run through
    ``process_single_case`` using the module-level ``model``. The function
    writes one row per (case, class) for classes 1, 2 and 4. A case that fails
    still gets three rows with empty RLE strings, so the submission always
    contains every expected id exactly once.

    Args:
        config: Pipeline configuration. Provides paths, thresholds, sliding-window
            settings and evaluation flags.
        target_id: If given, only folders whose basename contains this
            substring are processed (debug mode).

    Returns:
        A tuple ``(submission_df, stats_df, eval_df)``:
        - ``submission_df``: columns ``id``, ``rle``, sorted by ``id``.
        - ``stats_df``: per-case voxel counts ``c1``, ``c2``, ``c4``.
        - ``eval_df``: per-case Dice scores, or ``None`` when evaluation is
          disabled or the ground-truth directory is missing.
    """
    # Setup inferer
    inferer = SlidingWindowInferer(
        roi_size=config.patch_size,
        sw_batch_size=config.sw_batch_size,
        overlap=config.overlap,
        mode="gaussian"
    )

    # Setup transforms
    inference_transforms = get_inference_transforms(config)
    pre_crop_transforms = get_pre_crop_transforms()
    crop_info_transforms = get_crop_info_transforms()

    # Find all test cases
    test_folders = sorted(glob.glob(os.path.join(config.test_dir, "BraTS*")))

    # Filter for specific case if requested
    if target_id:
        test_folders = [f for f in test_folders if target_id in os.path.basename(f)]
        logger.info(f"üîç DEBUG MODE: Processing ONLY case '{target_id}'")

    logger.info(f"Found {len(test_folders)} test cases")
    logger.info(
        f"Settings: TTA={config.use_tta}, Thresholds=[{config.threshold_tc}, {config.threshold_wt}, {config.threshold_et}]")
    sys.stdout.flush()

    # [EVAL] Evaluate only when requested AND the ground-truth directory exists.
    evaluation_enabled = False
    if config.enable_evaluation:
        if os.path.isdir(config.ground_truth_dir):
            evaluation_enabled = True
            logger.info(f"[EVAL] ‚úÖ Evaluation ENABLED - Ground truth dir: {config.ground_truth_dir}")
        else:
            logger.warning(f"[EVAL] ‚ö†Ô∏è Evaluation requested but ground truth dir not found: {config.ground_truth_dir}")
    sys.stdout.flush()

    all_rows = []
    stats = []
    eval_results = []  # [EVAL] Store per-case Dice scores

    logger.info("=" * 70)
    logger.info("STARTING INFERENCE ON TEST CASES")
    logger.info("=" * 70)
    sys.stdout.flush()

    total_cases = len(test_folders)
    for idx, patient_folder in enumerate(test_folders, 1):
        patient_id = os.path.basename(patient_folder)
        case_start_time = time.time()

        logger.info(f"Processing case {idx}/{total_cases}: {patient_id}")
        sys.stdout.flush()

        # Build this case's rows locally and commit them only after inference
        # and RLE bookkeeping fully succeed. A late failure therefore cannot
        # leave partial rows that the error path would then duplicate.
        try:
            # [EVAL] Request masks if evaluation is enabled
            result = process_single_case(
                model=model,
                patient_folder=patient_folder,
                config=config,
                inferer=inferer,
                inference_transforms=inference_transforms,
                pre_crop_transforms=pre_crop_transforms,
                crop_info_transforms=crop_info_transforms,
                return_masks=evaluation_enabled
            )

            if evaluation_enabled:
                rle_dict, pred_masks = result
            else:
                rle_dict, pred_masks = result, None

            case_rows = []
            voxel_counts = {}
            for cls, rle in rle_dict.items():
                if pred_masks is not None:
                    # Use already-decoded masks
                    voxel_counts[f"c{cls}"] = int(pred_masks[cls].sum())
                else:
                    mask = rle_decode_c_order(rle, ORIGINAL_SHAPE)
                    voxel_counts[f"c{cls}"] = int(mask.sum())
                case_rows.append({"id": f"{patient_id}_{cls}", "rle": rle if rle else ""})
        except Exception as ex:
            # Top-level per-case boundary: log with traceback and emit empty rows
            # so the submission still contains every expected id exactly once.
            logger.exception(f"  ‚ùå Error processing {patient_id}: {ex}")
            sys.stdout.flush()
            all_rows.extend({"id": f"{patient_id}_{cls}", "rle": ""} for cls in (1, 2, 4))
            stats.append({"patient": patient_id, "c1": 0, "c2": 0, "c4": 0})
            if evaluation_enabled:
                eval_results.append(_nan_eval_row(patient_id))
            continue

        all_rows.extend(case_rows)
        stats.append({"patient": patient_id, **voxel_counts})

        # Log processing time
        case_time = time.time() - case_start_time
        logger.info(f"  ‚úì Completed in {case_time:.2f}s")
        sys.stdout.flush()

        # [EVAL] Calculate Dice scores (never raises; failures become NaN rows)
        if evaluation_enabled:
            eval_results.append(_evaluate_case(patient_id, pred_masks, config.ground_truth_dir))

    # Create and save submission
    logger.info("=" * 70)
    logger.info("SAVING SUBMISSION")
    logger.info("=" * 70)
    # Explicit columns keep the header (and sort_values) valid even with zero cases.
    submission_df = (
        pd.DataFrame(all_rows, columns=["id", "rle"])
        .sort_values("id")
        .reset_index(drop=True)
    )
    submission_df.to_csv(config.output_csv, index=False)

    logger.info(f"‚úÖ Submission saved to: {config.output_csv}")
    logger.info(f"Total rows: {len(submission_df)}")
    sys.stdout.flush()

    # Print statistics
    stats_df = pd.DataFrame(stats) if stats else pd.DataFrame(columns=["patient", "c1", "c2", "c4"])
    logger.info("")
    logger.info("üìä Prediction Statistics:")
    if stats_df.empty:
        logger.warning("  No test cases were processed - check config.test_dir / target_id")
    else:
        logger.info(
            f"  Cases with tumor (any class > 0): {((stats_df['c1'] > 0) | (stats_df['c2'] > 0) | (stats_df['c4'] > 0)).sum()}/{len(stats_df)}")
        logger.info(f"  Mean Class 1 (NCR/NET) voxels: {stats_df['c1'].mean():.0f}")
        logger.info(f"  Mean Class 2 (Edema) voxels:   {stats_df['c2'].mean():.0f}")
        logger.info(f"  Mean Class 4 (ET) voxels:      {stats_df['c4'].mean():.0f}")
    sys.stdout.flush()

    # [EVAL] Print evaluation summary
    eval_df = None
    if evaluation_enabled and eval_results:
        eval_df = pd.DataFrame(eval_results)
        valid_evals = eval_df.dropna()

        logger.info("")
        logger.info("=" * 70)
        logger.info("üìä EVALUATION RESULTS (Dice Scores)")
        logger.info("=" * 70)
        logger.info(f"  Cases evaluated: {len(valid_evals)}/{len(eval_results)}")

        if len(valid_evals) > 0:
            logger.info("")
            logger.info("  üìà MEAN DICE SCORES:")
            logger.info(f"     Class 1 (NCR/NET):     {valid_evals['dice_class1'].mean():.4f}")
            logger.info(f"     Class 2 (Edema):       {valid_evals['dice_class2'].mean():.4f}")
            logger.info(f"     Class 4 (ET):          {valid_evals['dice_class4'].mean():.4f}")
            logger.info("     ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ")
            logger.info(f"     OVERALL MEAN DICE:     {valid_evals['mean_dice'].mean():.4f}")

            # idxmax/idxmin return the first occurrence, i.e. the earliest
            # processed case on ties.
            best_case = valid_evals.loc[valid_evals["mean_dice"].idxmax()]
            worst_case = valid_evals.loc[valid_evals["mean_dice"].idxmin()]
            logger.info("")
            logger.info(f"  üèÜ BEST CASE:  {best_case['patient']} (Mean Dice: {best_case['mean_dice']:.4f})")
            logger.info(f"  ‚ö†Ô∏è  WORST CASE: {worst_case['patient']} (Mean Dice: {worst_case['mean_dice']:.4f})")

            # Per-class statistics
            logger.info("")
            logger.info("  üìä PER-CLASS STATISTICS:")
            for cls, col in [(1, 'dice_class1'), (2, 'dice_class2'), (4, 'dice_class4')]:
                cls_data = valid_evals[col]
                logger.info(f"     Class {cls}: Mean={cls_data.mean():.4f}, Std={cls_data.std():.4f}, "
                            f"Min={cls_data.min():.4f}, Max={cls_data.max():.4f}")

        logger.info("=" * 70)

        # Save evaluation results next to the submission. Split only the real
        # extension, so a path without a ".csv" suffix (or with ".csv" elsewhere
        # in it) can never resolve back to, and overwrite, the submission file.
        csv_root, csv_ext = os.path.splitext(config.output_csv)
        eval_csv_path = f"{csv_root}_evaluation{csv_ext or '.csv'}"
        eval_df.to_csv(eval_csv_path, index=False)
        logger.info(f"  Evaluation results saved to: {eval_csv_path}")
        sys.stdout.flush()

    return submission_df, stats_df, eval_df

In [None]:
# Cell 11: Run Full Submission Generation
def _log_banner(title: str) -> None:
    """Log ``title`` between two 70-character '=' rules."""
    rule = "=" * 70
    for text in (rule, title, rule):
        logger.info(text)


_log_banner("GENERATING FINAL SUBMISSION (ALL TEST CASES)")
logger.info("")
logger.info("Pipeline Settings:")

# Every settings line shares the same indented check-mark prefix.
_ok = "  ‚úÖ "
logger.info(_ok + "C-order RLE encoding (correct format)")
logger.info(_ok + "RAS ‚Üí LPS orientation conversion")
logger.info(_ok + f"Per-class thresholds: TC={config.threshold_tc}, WT={config.threshold_wt}, ET={config.threshold_et}")
logger.info(_ok + f"Post-processing: min_size={config.min_component_size}, fill_holes={config.fill_holes}")
logger.info(_ok + f"Tumor hierarchy enforcement: {config.enforce_hierarchy}")
logger.info(_ok + f"Test-Time Augmentation: {config.use_tta}")
logger.info(_ok + f"Overlap: {config.overlap}")
logger.info(_ok + f"Evaluation: {config.enable_evaluation}")
logger.info("")
sys.stdout.flush()

# Run the full pipeline over every test case.
submission_df, stats_df, eval_df = generate_submission(config)

# Show the first rows of the submission table in the log.
logger.info("")
logger.info("üìã Submission Preview (first 15 rows):")
_preview = submission_df.head(15).to_string(index=False)
logger.info("\n" + _preview)
logger.info("")
_log_banner("‚úÖ SUBMISSION GENERATION COMPLETE")
sys.stdout.flush()

## ✅ Submission Complete

Your submission file has been saved to: `/kaggle/working/submission.csv`

### Pipeline Summary
| Component | Setting |
|-----------|---------|
| RLE Encoding | C-order (row-major), 1-indexed |
| Orientation | RAS inference → LPS submission |
| Thresholds | TC=0.52, WT=0.47, ET=0.57 |
| Post-processing | min_size=150, fill_holes=True |
| Hierarchy | ET ⊆ TC ⊆ WT enforced |
| TTA | Flip augmentation (X, Y, Z) |
| Overlap | 0.6 |

### Download and Submit
Run this in a new cell to get a clickable download link:
```python
from IPython.display import FileLink
FileLink('submission.csv')
```