# UW-Madison GI Tract Segmentation — Plan to Medal

Objectives:
- Build a robust CV pipeline (patient-wise splits) whose scores track the LB.
- Establish a fast, reliable baseline, then iterate toward a medal.

Milestones:
1) Environment + GPU gate
   - Verify GPU (nvidia-smi). Install torch/cu121 stack, smp, albumentations.
   - Add timing/progress logging utilities.
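A minimal sketch of the timing/progress helper, assuming a context manager that prints start/end markers is enough. The name `timed` and the log format are hypothetical.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label: str):
    """Print start/end markers around a block and record its wall-clock duration."""
    info = {"label": label, "elapsed": None}
    start = time.perf_counter()
    print(f"[START] {label}", flush=True)
    try:
        yield info
    finally:
        # Runs even if the block raises, so interrupted folds still log their duration
        info["elapsed"] = time.perf_counter() - start
        print(f"[DONE] {label} in {info['elapsed']:.1f}s", flush=True)
```

Wrapping each fold (`with timed(f"fold {k}") as t:`) makes a stalled run easy to spot in the log.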

2) Data audit and EDA
   - Inspect train.csv/test.csv schema; parse RLE strings by class.
   - Confirm image paths, dimensions, per-case slice counts, empty-mask ratio.
   - Visual sanity checks (a few samples with mask overlays).
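The RLE parsing can be sketched as a decode/encode pair. This assumes the usual Kaggle convention for this data: space-separated, 1-indexed `start length` pairs over the row-major flattened `(H, W)` mask, with empty masks stored as NaN. Confirm the ordering with the overlay check before relying on it; the helper names are hypothetical.

```python
import numpy as np

def rle_decode(rle, shape):
    """Decode 'start length ...' into a uint8 mask of shape (H, W).

    Assumption: starts are 1-indexed and pixels are numbered row-major.
    """
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    if isinstance(rle, str) and rle.strip():  # NaN or '' means an empty mask
        nums = np.array([int(x) for x in rle.split()], dtype=np.int64)
        starts, lengths = nums[0::2] - 1, nums[1::2]
        for start, length in zip(starts, lengths):
            mask[start:start + length] = 1
    return mask.reshape(shape)

def rle_encode(mask):
    """Inverse of rle_decode; returns '' for an empty mask."""
    flat = (np.asarray(mask) > 0).astype(np.uint8).ravel()
    pixels = np.concatenate([[0], flat, [0]])  # pad so every run has a start and an end
    runs = np.flatnonzero(pixels[1:] != pixels[:-1]) + 1  # 1-indexed run boundaries
    runs[1::2] -= runs[0::2]  # turn end positions into run lengths
    return " ".join(map(str, runs))
```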

3) Validation protocol
   - GroupKFold by case/patient, so all days and slices of a patient stay in one fold (no leakage).
   - 5 folds, deterministic seed; reuse fixed folds throughout.
   - OOF Dice per class; track a proxy for the Dice + Hausdorff competition metric (mean Dice + HD95 via medpy/skimage).
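A dependency-free, per-slice sketch of the validation proxy. Check these assumptions against the competition's evaluation page: the 0.4/0.6 Dice/Hausdorff weights, normalising the distance by the image diagonal, and scoring a one-sided miss as the full diagonal. The official metric is described as using a 3D Hausdorff distance per case. This sketch also uses the plain Hausdorff distance rather than HD95.

```python
import numpy as np

def dice_2d(pred, gt):
    """Dice between two binary masks; two empty masks score 1.0."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0
    return float(2.0 * np.logical_and(pred, gt).sum() / denom)

def hausdorff_2d(pred, gt):
    """Symmetric Hausdorff distance (pixels) between the foreground point sets."""
    p, g = np.argwhere(pred), np.argwhere(gt)
    if len(p) == 0 and len(g) == 0:
        return 0.0
    if len(p) == 0 or len(g) == 0:
        return float(np.hypot(*pred.shape))  # assumption: a one-sided miss costs the full diagonal
    # All-pairs distances: fine for small masks, but O(|p| * |g|) memory for large ones
    d = np.sqrt(((p[:, None, :] - g[None, :, :]) ** 2).sum(axis=-1))
    return float(max(d.min(axis=1).max(), d.min(axis=0).max()))

def proxy_score(pred, gt, w_dice=0.4, w_hd=0.6):
    """w_dice * Dice + w_hd * (1 - Hausdorff / diagonal); weights are assumed, not confirmed."""
    diag = float(np.hypot(*gt.shape))
    return w_dice * dice_2d(pred, gt) + w_hd * (1.0 - hausdorff_2d(pred, gt) / diag)
```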

4) Baseline data pipeline
   - Load grayscale PNGs; stack 2.5D context (e.g., prev/cur/next slices → 3ch).
   - Resize to 256x256 for the baseline, preserving aspect ratio via pad/crop.
   - Augmentations: flips, small affine, brightness/contrast, light elastic.
   - Convert RLE→mask for 3 classes; sample a mix of empty and non-empty slices.
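The 2.5D stacking step can be sketched as below. It assumes the slices of one case/day are already loaded into a `(num_slices, H, W)` array sorted by slice number. Out-of-range neighbours are clamped to the first/last slice, which is one of several reasonable edge policies. `stack_25d` is a hypothetical helper.

```python
import numpy as np

def stack_25d(volume, idx, offsets=(-1, 0, 1)):
    """Return an (H, W, len(offsets)) stack of slices idx + offset, clamped to the volume."""
    n = volume.shape[0]
    # Clamp each neighbour index into [0, n - 1] so edge slices repeat instead of failing
    channels = [volume[min(max(idx + o, 0), n - 1)] for o in offsets]
    return np.stack(channels, axis=-1)
```

Passing `offsets=(-2, -1, 0, 1, 2)` gives a 5-channel window, and `offsets=(-2, 0, 2)` gives a strided 3-channel one, so the context-window ablation only changes this argument.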

5) Baseline model and loss
   - UNet/UNet++ with ImageNet encoder (ResNet34/EfficientNet-b0) via segmentation_models_pytorch.
   - Loss: 0.5*BCEWithLogits + 0.5*SoftDice; per-class weighting if imbalance observed.
   - Optim: AdamW, LR ~1e-3 with linear warmup followed by cosine annealing; EMA of weights.
   - Mixed precision (amp), gradient clipping.
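A minimal PyTorch sketch of the milestone-5 loss and LR schedule. The loss treats the three classes as independent sigmoid channels, so overlapping labels are allowed, and it computes a smoothed soft Dice per class over the batch. The `class_weights` knob and the warmup-then-cosine factor for `torch.optim.lr_scheduler.LambdaLR` are illustrative assumptions, not a fixed recipe.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class BCESoftDiceLoss(nn.Module):
    """bce_weight * BCEWithLogits + dice_weight * (1 - soft Dice), Dice per class then averaged."""

    def __init__(self, bce_weight=0.5, dice_weight=0.5, class_weights=None, smooth=1.0):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth
        # Optional per-class weights for the Dice term (hypothetical knob for class imbalance)
        cw = None if class_weights is None else torch.as_tensor(class_weights, dtype=torch.float32)
        self.register_buffer("class_weights", cw)

    def forward(self, logits, targets):
        # logits, targets: (N, C, H, W); targets hold 0/1 values, one channel per class.
        # Cast to fp32 so the large spatial sums stay safe under AMP.
        logits, targets = logits.float(), targets.float()
        bce = F.binary_cross_entropy_with_logits(logits, targets)
        probs = torch.sigmoid(logits)
        dims = (0, 2, 3)  # reduce over batch and space -> one Dice value per class
        inter = (probs * targets).sum(dims)
        denom = probs.sum(dims) + targets.sum(dims)
        dice_loss = 1.0 - (2.0 * inter + self.smooth) / (denom + self.smooth)
        if self.class_weights is not None:
            dice_loss = (dice_loss * self.class_weights).sum() / self.class_weights.sum()
        else:
            dice_loss = dice_loss.mean()
        return self.bce_weight * bce + self.dice_weight * dice_loss


def warmup_cosine_factor(step, warmup_steps, total_steps, min_factor=0.0):
    """LR multiplier: linear warmup to 1.0, then cosine decay to min_factor at total_steps."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_factor + (1.0 - min_factor) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

Per-iteration usage would look like `LambdaLR(optimizer, lambda s: warmup_cosine_factor(s, 500, total_steps))` with `scheduler.step()` after each optimizer step. The step counts are placeholders.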

6) Training strategy
   - Epochs: ~40-60 at 256 res for quick OOF; early stopping on val dice.
   - Save best by val score per fold; log per-epoch dice per class.
   - Cache fold splits, OOF preds (npz) and test logits for later blends.
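The early-stopping and save-best rules can live in a small tracker. This sketch assumes the monitored score is higher-is-better (e.g., mean val Dice). The `patience` and `min_delta` defaults are placeholders, not tuned values.

```python
class BestScoreTracker:
    """Track the best validation score (higher is better) and signal early stopping."""

    def __init__(self, patience=8, min_delta=1e-4):  # placeholder defaults
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def update(self, score, epoch):
        """Return True when `score` is a new best, i.e. the caller should save a checkpoint."""
        if score > self.best + self.min_delta:
            self.best, self.best_epoch, self.bad_epochs = score, epoch, 0
            return True
        self.bad_epochs += 1
        return False

    @property
    def should_stop(self):
        return self.bad_epochs >= self.patience
```

In the epoch loop, call `update(val_dice, epoch)`, save the weights when it returns True, and break once `should_stop` is set.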

7) Inference & post-processing
   - TTA: average predictions over the original, h-flipped and v-flipped inputs.
   - Threshold tuning per class via OOF.
   - Morphology: remove small blobs (class-wise min area); keep the largest connected component for bowel if it improves OOF.
   - RLE encode to submission.csv.
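A sketch of the flip TTA, small-blob removal and per-class threshold search. It uses `scipy.ndimage.label` for connected components (4-connectivity by default). It assumes `predict` maps an `(H, W, C)` image to `(H, W, K)` per-class probabilities. The helper names, the threshold grid and any min-area values are hypothetical.

```python
import numpy as np
from scipy import ndimage

def flip_tta(predict, image):
    """Average predictions over identity, h-flip and v-flip, undoing each flip on the output."""
    views = [
        (image, lambda p: p),
        (image[:, ::-1], lambda p: p[:, ::-1]),  # horizontal flip and its inverse
        (image[::-1], lambda p: p[::-1]),        # vertical flip and its inverse
    ]
    # ascontiguousarray avoids negative strides, which e.g. torch.from_numpy rejects
    preds = [undo(predict(np.ascontiguousarray(view))) for view, undo in views]
    return np.mean(preds, axis=0)

def remove_small_blobs(mask, min_area):
    """Drop connected components smaller than min_area pixels from a binary mask."""
    labeled, n = ndimage.label(mask)
    if n == 0:
        return mask.astype(bool)
    areas = np.bincount(labeled.ravel())
    keep = areas >= min_area
    keep[0] = False  # label 0 is background
    return keep[labeled]

def tune_threshold(probs, gts, grid=np.arange(0.3, 0.71, 0.05)):
    """Pick the threshold that maximises mean per-slice Dice for one class on OOF data."""
    def dice(p, g):
        s = p.sum() + g.sum()
        return 1.0 if s == 0 else 2.0 * np.logical_and(p, g).sum() / s

    scores = [np.mean([dice(p >= t, g.astype(bool)) for p, g in zip(probs, gts)]) for t in grid]
    best = int(np.argmax(scores))
    return float(grid[best]), float(scores[best])
```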

8) Iterations to medal
   - Resolution ablation: 256 → 384/512 if memory allows; compare OOF.
   - 2.5D context window ablation: 5ch [-2,-1,0,+1,+2] vs 3ch strided [-2,0,+2].
   - Encoder sweep: ResNet34 → tf_efficientnet_b3/b4.
   - Loss sweep: add Tversky/FocalDice; class weights.
   - Blend diverse seeds/encoders/resolutions (weighted by OOF).
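Blending weighted by OOF can start as a normalised weighted mean of per-model probability maps. Making the weights proportional to each model's OOF score is one simple assumption; rank-based or optimised weights are common alternatives. `blend_probs` is a hypothetical helper.

```python
import numpy as np

def blend_probs(prob_maps, oof_scores):
    """Weighted mean of same-shaped probability maps.

    Assumption: weights are proportional to each model's OOF score.
    """
    weights = np.asarray(oof_scores, dtype=np.float64)
    if weights.ndim != 1 or len(weights) != len(prob_maps) or weights.sum() <= 0:
        raise ValueError("need one score per model with a positive sum")
    weights = weights / weights.sum()
    stacked = np.stack(prob_maps, axis=0)  # (num_models, ...) one slab per model
    return np.tensordot(weights, stacked, axes=1)  # contract over the model axis
```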

9) Risk controls
   - Strict fold reuse; no leakage.
   - Sanity checks: mask overlay, non-empty ratio, OOF vs LB tracking.
   - Log progress and time per fold; interrupt if stalled.
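The "strict fold reuse; no leakage" control can be enforced with a cheap check that runs whenever folds are loaded. This sketch assumes a dataframe with `case` and `fold` columns; both column names are hypothetical defaults.

```python
import pandas as pd

def check_no_group_leakage(df: pd.DataFrame, fold_col: str = "fold", group_col: str = "case") -> None:
    """Raise ValueError if any group (patient/case) appears in more than one fold."""
    folds_per_group = df.groupby(group_col)[fold_col].nunique()
    leaked = folds_per_group[folds_per_group > 1]
    if not leaked.empty:
        raise ValueError(
            f"{len(leaked)} {group_col} value(s) span multiple folds, e.g. {leaked.index.tolist()[:5]}"
        )
```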

Next actions:
- Verify GPU and install torch/cu121 + deps.
- EDA of csvs (schema, counts, empties).
- Implement fold splitter (GroupKFold by case, so one patient's days never straddle folds).
- Build baseline dataset/loader + UNet(R34, 256) and run a 5-fold smoke test (a few epochs).

In [9]:
# Environment & GPU gate + Torch/cu121 stack install
import os, sys, subprocess, shutil, time, textwrap, json
from pathlib import Path

def run(cmd):
    print("> ", " ".join(cmd), flush=True)
    return subprocess.run(cmd, check=False, capture_output=True, text=True)

print("[GPU CHECK] nvidia-smi:", flush=True)
out = run(["bash","-lc","nvidia-smi || true"])
print(out.stdout)

# Hard reset any prior torch stacks
for pkg in ("torch","torchvision","torchaudio"):
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", pkg], check=False)

# Clean stray site dirs that can shadow correct wheels (idempotent)
for d in (
    "/app/.pip-target/torch",
    "/app/.pip-target/torchvision",
    "/app/.pip-target/torchaudio",
    "/app/.pip-target/torchgen",
    "/app/.pip-target/functorch",
):
    if os.path.exists(d):
        print("Removing", d); shutil.rmtree(d, ignore_errors=True)

def pip(*args):
    print("> pip", " ".join(args), flush=True)
    subprocess.run([sys.executable, "-m", "pip", *args], check=True)

# Install exact cu121 torch stack
pip("install",
    "--index-url", "https://download.pytorch.org/whl/cu121",
    "--extra-index-url", "https://pypi.org/simple",
    "torch==2.4.1", "torchvision==0.19.1", "torchaudio==2.4.1")

# Freeze constraints
Path("constraints.txt").write_text("\n".join([
    "torch==2.4.1",
    "torchvision==0.19.1",
    "torchaudio==2.4.1",
]))

# albumentations>=1.4 depends on albucore; drop any leftover copy, since the pinned 1.3.1 does not use it
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "albucore"], check=False)

# Install non-torch deps for this competition
deps = [
    "segmentation-models-pytorch==0.3.3",
    "timm==0.9.2",  # SMP 0.3.3 pins timm==0.9.2
    # Use albumentations 1.3.1 (no albucore dependency) to avoid runtime import issues
    "albumentations==1.3.1",
    "opencv-python-headless==4.10.0.84",
    "scikit-image",
    "medpy",
    "scikit-learn",
    "pandas",
    "numpy",
    "matplotlib",
    "pillow",
]
pip("install", "-c", "constraints.txt", *deps, "--upgrade-strategy", "only-if-needed")

import torch
print("torch:", torch.__version__, "built CUDA:", getattr(torch.version, "cuda", None))
print("CUDA available:", torch.cuda.is_available())
assert str(getattr(torch.version, "cuda", "")).startswith("12.1"), f"Wrong CUDA build: {torch.version.cuda}"
assert torch.cuda.is_available(), "CUDA not available"
print("GPU:", torch.cuda.get_device_name(0))
print("[ENV READY]")

[GPU CHECK] nvidia-smi:
>  bash -lc nvidia-smi || true
Thu Sep 25 01:27:44 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.06             Driver Version: 550.144.06     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A10-24Q                 On  |   00000002:00:00.0 Off |                    0 |
| N/A   N/A    P0             N/A /  N/A  |     412MiB /  24512MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

Found existing installation: torch 2.4.1+cu121
Uninstalling torch-2.4.1+cu121:
  Successfully uninstalled torch-2.4.1+cu121
Found existing installation: torchvision 0.19.1+cu121
Uninstalling torchvision-0.19.1+cu121:
  Successfully uninstalled torchvision-0.19.1+cu121
Found existing installation: torchaudio 2.4.1+cu121
Uninstalling torchaudio-2.4.1+cu121:
  Successfully uninstalled torchaudio-2.4.1+cu121
> pip install --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1


Installing collected packages: mpmath, typing-extensions, sympy, pillow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch, torchvision, torchaudio


Successfully installed MarkupSafe-3.0.2 filelock-3.19.1 fsspec-2025.9.0 jinja2-3.1.6 mpmath-1.3.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 pillow-11.3.0 sympy-1.14.0 torch-2.4.1+cu121 torchaudio-2.4.1+cu121 torchvision-0.19.1+cu121 triton-3.0.0 typing-extensions-4.15.0




> pip install -c constraints.txt segmentation-models-pytorch==0.3.3 timm==0.9.2 albumentations==1.3.1 opencv-python-headless==4.10.0.84 scikit-image medpy scikit-learn pandas numpy matplotlib pillow --upgrade-strategy only-if-needed




Building wheels for collected packages: efficientnet-pytorch, pretrainedmodels, medpy
Successfully built efficientnet-pytorch pretrainedmodels medpy


Installing collected packages: SimpleITK, pytz, mpmath, urllib3, tzdata, typing-extensions, tqdm, threadpoolctl, sympy, six, safetensors, pyyaml, pyparsing, pillow, packaging, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, munch, MarkupSafe, kiwisolver, joblib, idna, hf-xet, fsspec, fonttools, filelock, cycler, charset_normalizer, certifi, triton, tifffile, scipy, requests, python-dateutil, opencv-python-headless, nvidia-cusparse-cu12, nvidia-cudnn-cu12, lazy-loader, jinja2, imageio, contourpy, scikit-learn, scikit-image, pandas, nvidia-cusolver-cu12, medpy, matplotlib, huggingface-hub, torch, qudida, torchvision, efficientnet-pytorch, albumentations, timm, pretrainedmodels, segmentation-models-pytorch


Successfully installed MarkupSafe-3.0.2 SimpleITK-2.5.2 albumentations-1.3.1 certifi-2025.8.3 charset_normalizer-3.4.3 contourpy-1.3.3 cycler-0.12.1 efficientnet-pytorch-0.7.1 filelock-3.19.1 fonttools-4.60.0 fsspec-2025.9.0 hf-xet-1.1.10 huggingface-hub-0.35.1 idna-3.10 imageio-2.37.0 jinja2-3.1.6 joblib-1.5.2 kiwisolver-1.4.9 lazy-loader-0.4 matplotlib-3.10.6 medpy-0.5.2 mpmath-1.3.0 munch-4.0.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 opencv-python-headless-4.10.0.84 packaging-25.0 pandas-2.3.2 pillow-11.3.0 pretrainedmodels-0.7.4 pyparsing-3.2.5 python-dateutil-2.9.0.post0 pytz-2025.2 pyyaml-6.0.2 qudida-0.0.4 requests-2.32.5 safetensors-0.6.2 scik



torch: 2.4.1+cu121 built CUDA: 12.1
CUDA available: True
GPU: NVIDIA A10-24Q
[ENV READY]


In [4]:
# EDA: CSV schema, ID parsing, positive ratios, folds, and path resolver
import pandas as pd, numpy as np, os, re, sys, json, math, time, glob
from pathlib import Path
from sklearn.model_selection import StratifiedGroupKFold

DATA_DIR = Path('.')
TRAIN_CSV = DATA_DIR / 'train.csv'
TEST_CSV = DATA_DIR / 'test.csv'
# Potential roots (local repo mounts first; add common Kaggle-style mounts if present at runtime)
TRAIN_IMG_ROOTS = [
    DATA_DIR / 'train',
    # Kaggle official
    Path('/kaggle/input/uw-madison-gi-tract-image-segmentation/train'),
    Path('/kaggle/input/uw-madison-gi-tract-image-segmentation/train_png'),
    # Common mirrors / alternate mounts
    Path('/kaggle/input/uw-madison-gi-tract-image-segmentation-256x256/train'),
    Path('/kaggle/input/uwmadison-gi-tract-image-segmentation/train'),
    Path('/kaggle/input/uw-madison-gi-tract-image-segmentation-resized/train'),
    Path('/kaggle/temp/uw-madison-gi-tract-image-segmentation/train'),
    Path('/kaggle/working/uw-madison-gi-tract-image-segmentation/train'),
    Path('/content/uw-madison-gi-tract-image-segmentation/train'),
    Path('/mnt/input/uw-madison-gi-tract-image-segmentation/train'),
    Path('/mnt/data/uw-madison-gi-tract-image-segmentation/train'),
    Path('/data/uw-madison-gi-tract-image-segmentation/train'),
    Path('/workspace/uw-madison-gi-tract-image-segmentation/train'),
    Path('/datasets/uw-madison-gi-tract-image-segmentation/train'),
    Path('/opt/data/uw-madison-gi-tract-image-segmentation/train'),
    Path('/app/data/uw-madison-gi-tract-image-segmentation/train'),
]
TEST_IMG_ROOTS = [
    DATA_DIR / 'test',
    # Kaggle official
    Path('/kaggle/input/uw-madison-gi-tract-image-segmentation/test'),
    Path('/kaggle/input/uw-madison-gi-tract-image-segmentation/test_png'),
    # Common mirrors / alternate mounts
    Path('/kaggle/input/uw-madison-gi-tract-image-segmentation-256x256/test'),
    Path('/kaggle/input/uwmadison-gi-tract-image-segmentation/test'),
    Path('/kaggle/input/uw-madison-gi-tract-image-segmentation-resized/test'),
    Path('/kaggle/temp/uw-madison-gi-tract-image-segmentation/test'),
    Path('/kaggle/working/uw-madison-gi-tract-image-segmentation/test'),
    Path('/content/uw-madison-gi-tract-image-segmentation/test'),
    Path('/mnt/input/uw-madison-gi-tract-image-segmentation/test'),
    Path('/mnt/data/uw-madison-gi-tract-image-segmentation/test'),
    Path('/data/uw-madison-gi-tract-image-segmentation/test'),
    Path('/workspace/uw-madison-gi-tract-image-segmentation/test'),
    Path('/datasets/uw-madison-gi-tract-image-segmentation/test'),
    Path('/opt/data/uw-madison-gi-tract-image-segmentation/test'),
    Path('/app/data/uw-madison-gi-tract-image-segmentation/test'),
]

# Inject the extracted archive. Many mirrors ship train-only, so the train dir doubles as a
# test-ID root only when the archive has no separate test dir.
EXTERNAL_ROOT = Path('external_data/uw-madison-gi-tract-image-segmentation')
EXTERNAL_TRAIN = EXTERNAL_ROOT / 'train'
EXTERNAL_TEST = EXTERNAL_ROOT / 'test'
if EXTERNAL_TRAIN.exists():
    TRAIN_IMG_ROOTS.insert(0, EXTERNAL_TRAIN)
    TEST_IMG_ROOTS.insert(0, EXTERNAL_TEST if EXTERNAL_TEST.exists() else EXTERNAL_TRAIN)

# Dynamic discovery: scan Kaggle inputs for uw*gi* patterns and append discovered roots
def _append_dynamic_roots(roots_list, split_name):
    try:
        for base in Path('/kaggle/input').glob('*uw*gi*/*'):
            if not base.is_dir():
                continue
            cand = base / split_name
            if cand.exists():
                roots_list.append(cand)
    except Exception:
        pass

# Extra dynamic discovery on multiple prefixes (expanded)
def _append_dynamic_roots_generic(roots_list, split_name, prefixes=('/data', '/mnt', '/opt/data', '/app/data', '/datasets', '/workspace', '/workspace/data')):
    for pref in prefixes:
        try:
            p = Path(pref)
            if not p.exists():
                continue
            for base in p.glob('*uw*gi*/*'):
                if not base.is_dir():
                    continue
                cand = base / split_name
                if cand.exists():
                    roots_list.append(cand)
        except Exception:
            pass

_append_dynamic_roots(TRAIN_IMG_ROOTS, 'train')
_append_dynamic_roots(TEST_IMG_ROOTS, 'test')
_append_dynamic_roots_generic(TRAIN_IMG_ROOTS, 'train')  # uses the default prefix tuple
_append_dynamic_roots_generic(TEST_IMG_ROOTS, 'test')

def _unique_existing(paths):
    seen = set(); out = []
    for p in paths:
        ps = str(p)
        if ps in seen:
            continue
        seen.add(ps)
        if Path(p).exists():
            out.append(Path(p))
    return out

TRAIN_IMG_ROOTS = _unique_existing(TRAIN_IMG_ROOTS) or TRAIN_IMG_ROOTS
TEST_IMG_ROOTS = _unique_existing(TEST_IMG_ROOTS) or TEST_IMG_ROOTS
print('[PATH ROOTS] Train roots existing:', [str(p) for p in TRAIN_IMG_ROOTS if Path(p).exists()])
print('[PATH ROOTS] Test roots existing:', [str(p) for p in TEST_IMG_ROOTS if Path(p).exists()])

print('[LOAD] Reading CSVs...')
train_df = pd.read_csv(TRAIN_CSV)
test_df = pd.read_csv(TEST_CSV)
print(train_df.head(3))
print(test_df.head(3))
print(f"train rows={len(train_df)} unique ids={train_df['id'].nunique()} classes={train_df['class'].unique().tolist()}")

# Parse id: case###_day###_slice_####
id_pat = re.compile(r'^case(\d+)_day(\d+)_slice_(\d+)$')
def parse_id(s):
    m = id_pat.match(s)
    if not m:
        return (None, None, None)
    return tuple(int(x) for x in m.groups())

parsed = train_df['id'].apply(parse_id)
train_df[['case','day','slice']] = pd.DataFrame(parsed.tolist(), index=train_df.index)
parsed_t = test_df['id'].apply(parse_id)
test_df[['case','day','slice']] = pd.DataFrame(parsed_t.tolist(), index=test_df.index)

assert train_df['case'].notna().all(), 'ID parse failed for train'
assert test_df['case'].notna().all(), 'ID parse failed for test'

# Basic stats
per_id_any_pos = (train_df.assign(has_pos=train_df['segmentation'].notna())
                           .groupby('id')['has_pos'].any().rename('any_pos'))
pos_ratio = per_id_any_pos.mean()
print(f"[EDA] Positive-slice ratio (any class): {pos_ratio:.3f}")
per_case_ratio = (train_df.assign(has_pos=train_df['segmentation'].notna())
                           .groupby(['case','id'])['has_pos'].any().groupby('case').mean())
per_case_len = train_df.drop_duplicates('id').groupby('case')['id'].count()
print('[EDA] Per-case positive ratio stats:')
print(per_case_ratio.describe())
print('[EDA] Per-case slice-count stats:')
print(per_case_len.describe())

# Build StratifiedGroupKFold by case with combined stratification (pos-ratio bin x len bin)
n_folds = 5
cases = per_case_ratio.index.values
y_cont = per_case_ratio.values
lens = per_case_len.reindex(cases).values
# Bins
n_bins_pos = int(np.minimum(8, max(2, len(y_cont)//10)))
pos_bins = pd.qcut(y_cont, q=n_bins_pos, duplicates='drop', labels=False).astype(int) if len(np.unique(y_cont))>1 else np.zeros_like(y_cont, dtype=int)
n_bins_len = int(np.minimum(5, max(2, len(lens)//15)))
len_bins = pd.qcut(lens, q=n_bins_len, duplicates='drop', labels=False).astype(int) if len(np.unique(lens))>1 else np.zeros_like(lens, dtype=int)
combo_bins = (pos_bins.astype(int) * 10 + len_bins.astype(int)).astype(int)
sgkf = StratifiedGroupKFold(n_splits=n_folds, shuffle=True, random_state=42)
case_to_fold = {}
for fold, (_, val_idx) in enumerate(sgkf.split(cases, combo_bins, groups=cases)):
    for c in cases[val_idx]:
        case_to_fold[int(c)] = fold
print('[CV] Fold distribution (cases per fold):',
      pd.Series(case_to_fold).value_counts().sort_index().to_dict())

# Map id -> fold via case
id_case = train_df.drop_duplicates('id')[['id','case','day','slice']].copy()  # copy avoids SettingWithCopyWarning below
id_case['fold'] = id_case['case'].map(case_to_fold)
assert id_case['fold'].notna().all(), 'Some ids missing fold assignment'
id_case.to_csv('folds.csv', index=False)
print('[CV] Saved folds.csv with columns: id, case, day, slice, fold')

# Hardened path resolver with glob and multi-root search
def id_to_rel_candidates(id_str):
    case, day, sl = parse_id(id_str)
    # primary pattern under scans/ (official)
    rel1 = Path(f'case{case}') / f'day{day}' / 'scans' / f'slice_{sl:04d}*'
    # mirrors with case{case}_day{day} folder name
    rel2 = Path(f'case{case}') / f'case{case}_day{day}' / 'scans' / f'slice_{sl:04d}*'
    # fallback without scans/
    rel3 = Path(f'case{case}') / f'day{day}' / f'slice_{sl:04d}*'
    rel4 = Path(f'case{case}') / f'case{case}_day{day}' / f'slice_{sl:04d}*'
    return [rel1, rel2, rel3, rel4]

def resolve_path(id_str, roots):
    for rel_glob in id_to_rel_candidates(id_str):
        for r in roots:
            base = Path(r)
            if not base.exists():
                continue
            matches = sorted(base.glob(str(rel_glob)))
            if matches:
                return Path(os.path.normpath(str(matches[0])))
    # deterministic fallback (expected canonical path under scans with .png)
    case, day, sl = parse_id(id_str)
    return Path(roots[0]) / f'case{case}' / f'day{day}' / 'scans' / f'slice_{sl:04d}.png'

# Quick existence check on a few samples
sample_ids = id_case['id'].sample(min(5, len(id_case)), random_state=0).tolist()
missing = 0
for s in sample_ids:
    p = resolve_path(s, TRAIN_IMG_ROOTS)
    ex = p.exists()
    print(f'[PATH] {s} -> {p} exists={ex}')
    missing += (not ex)
print(f'[PATH] Missing among samples: {missing}/{len(sample_ids)} (expected early if data not mounted)')

print('[EDA DONE]')

# expose resolve_path and parse_id for later cells
globals()['resolve_path'] = resolve_path
globals()['parse_id'] = parse_id

[PATH ROOTS] Train roots existing: ['external_data/uw-madison-gi-tract-image-segmentation/train', 'train']
[PATH ROOTS] Test roots existing: ['external_data/uw-madison-gi-tract-image-segmentation/train', 'test']
[LOAD] Reading CSVs...
                        id        class segmentation
0  case77_day20_slice_0001  large_bowel          NaN
1  case77_day20_slice_0001  small_bowel          NaN
2  case77_day20_slice_0001      stomach          NaN
                         id        class
0  case123_day20_slice_0001  large_bowel
1  case123_day20_slice_0001  small_bowel
2  case123_day20_slice_0001      stomach
train rows=95088 unique ids=31696 classes=['large_bowel', 'small_bowel', 'stomach']


[EDA] Positive-slice ratio (any class): 0.428
[EDA] Per-case positive ratio stats:
count    76.000000
mean      0.430776
std       0.066099
min       0.243056
25%       0.383681
50%       0.438368
75%       0.472222
max       0.570312
Name: has_pos, dtype: float64
[EDA] Per-case slice-count stats:
count     76.000000
mean     417.052632
std      126.290870
min      144.000000
25%      420.000000
50%      432.000000
75%      432.000000
max      576.000000
Name: id, dtype: float64
[CV] Fold distribution (cases per fold): {0: 15, 1: 15, 2: 16, 3: 15, 4: 15}
[CV] Saved folds.csv with columns: id, case, day, slice, fold
[PATH] case20_day24_slice_0084 -> external_data/uw-madison-gi-tract-image-segmentation/train/case20/case20_day24/scans/slice_0084_266_266_1.50_1.50.png exists=True
[PATH] case111_day19_slice_0055 -> external_data/uw-madison-gi-tract-image-segmentation/train/case111/case111_day19/scans/slice_0055_266_266_1.50_1.50.png exists=True
[PATH] case33_day0_slice_0014 -> external_data

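Folds are assigned per case, so every day and slice of a case should land in exactly one fold, which is what the plan's "no leakage across days/slices" requires. A cheap guard is to re-read `folds.csv` and assert that each case maps to a single fold. The sketch below uses a hypothetical `check_no_case_leakage` helper and a synthetic frame (case-to-fold by modulo) standing in for the real file.

```python
import pandas as pd

def check_no_case_leakage(folds: pd.DataFrame) -> None:
    """Fail if any case appears in more than one fold (hypothetical helper)."""
    per_case = folds.groupby('case')['fold'].nunique()
    leaked = per_case[per_case > 1].index.tolist()
    if leaked:
        raise AssertionError(f'cases split across folds: {leaked}')
    if folds['fold'].isna().any():
        raise AssertionError('some ids have no fold')

# Synthetic stand-in for folds.csv: 10 cases x 2 days x 3 slices, case -> fold by modulo
rows = [(c, d, s) for c in range(10) for d in (0, 1) for s in range(1, 4)]
folds = pd.DataFrame(rows, columns=['case', 'day', 'slice'])
folds['fold'] = folds['case'] % 5
check_no_case_leakage(folds)

# Corrupt one slice's fold to show the check fires
bad = folds.copy()
bad.loc[0, 'fold'] = (bad.loc[0, 'fold'] + 1) % 5
try:
    check_no_case_leakage(bad)
    caught = False
except AssertionError as e:
    caught = True
    print('[CHECK]', e)
```

On the real data this would be `check_no_case_leakage(pd.read_csv('folds.csv'))`, run once after the fold cell.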


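The resolver log shows files named like `slice_0084_266_266_1.50_1.50.png` under `case20/case20_day24/scans/`. Assuming the trailing fields are width, height and pixel spacing, as described in the competition's data description (worth verifying against `cv2.imread` shapes, which are `(height, width)`), the slice size needed by `rle_decode` can be read from the filename without decoding the PNG. A minimal sketch with hypothetical helpers `find_slice` and `parse_slice_filename`, exercised on a temporary directory:

```python
import re
import tempfile
from pathlib import Path

# Assumed layout: slice_<num>_<width>_<height>_<spacing_w>_<spacing_h>.png
FNAME_RE = re.compile(r'^slice_(\d{4})_(\d+)_(\d+)_([\d.]+)_([\d.]+)\.png$')

def parse_slice_filename(path):
    m = FNAME_RE.match(Path(path).name)
    if m is None:
        return None
    num, w, h, sw, sh = m.groups()
    return {'slice': int(num), 'width': int(w), 'height': int(h),
            'spacing_w': float(sw), 'spacing_h': float(sh)}

def find_slice(root, case, day, sl):
    # Mirrors the case{c}/case{c}_day{d}/scans layout seen in the resolver log
    pattern = f'case{case}/case{case}_day{day}/scans/slice_{sl:04d}_*.png'
    matches = sorted(Path(root).glob(pattern))
    return matches[0] if matches else None

with tempfile.TemporaryDirectory() as tmp:
    d = Path(tmp) / 'case20' / 'case20_day24' / 'scans'
    d.mkdir(parents=True)
    (d / 'slice_0084_266_266_1.50_1.50.png').touch()
    hit = find_slice(tmp, 20, 24, 84)
    miss = find_slice(tmp, 20, 24, 85)
    meta = parse_slice_filename(hit)
print(hit.name, meta)
```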
In [6]:
# Utilities: RLE encode/decode, image loader with normalization + body crop, 2.5D stack, Dataset
import numpy as np, cv2, math, warnings
from skimage.measure import label, regionprops
import albumentations as A
import torch
from torch.utils.data import Dataset

CLASSES = ['large_bowel','small_bowel','stomach']  # canonical order
IMG_SIZE = 384
CTX_OFFSETS = [-2,-1,0,1,2]

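# RLE utils: 1-indexed starts, decoded here in column-major (Fortran) order. Many public notebooks decode this
# dataset in row-major (C) order instead; on square slices the wrong order only transposes the mask, so confirm with an overlay.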
def rle_decode(rle, shape):
    if not isinstance(rle, str) or rle.strip() == '':
        return np.zeros(shape, dtype=np.uint8)
    s = list(map(int, rle.split()))
    starts, lengths = s[0::2], s[1::2]
    starts = np.asarray(starts) - 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')

def rle_encode(mask):
    # mask: HxW, binary {0,1}; returns 'start length ...' with Fortran order
    pixels = mask.T.flatten()  # Fortran order equivalent
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def decode_row_to_mask(row, shape):
    return rle_decode(row['segmentation'] if isinstance(row['segmentation'], str) else '', shape)

def build_id_mask(train_df, id_str, shape):
    m = np.zeros((len(CLASSES), *shape), dtype=np.uint8)
    sub = train_df[train_df['id']==id_str]
    cls_to_ch = {c:i for i,c in enumerate(CLASSES)}
    for _, r in sub.iterrows():
        ch = cls_to_ch[r['class']]
        m[ch] = decode_row_to_mask(r, shape)
    return m

# Robust intensity normalization and body crop
def robust_norm(img_u16, clip_low=0.5, clip_high=99.5, eps=1e-3):
    img = img_u16.astype(np.float32)
    lo = np.percentile(img, clip_low)
    hi = np.percentile(img, clip_high)
    if hi <= lo:
        hi = lo + 1.0
    img = np.clip(img, lo, hi)
    img = (img - lo) / (hi - lo + eps)
    return img

def body_crop_bbox(image01, thresh=0.1, margin=32):
    # image01 in [0,1], HxW; returns (x1,y1,x2,y2)
    mask = (image01 > thresh).astype(np.uint8)
    if mask.sum() == 0:
        h, w = image01.shape[:2]
        return (0, 0, w, h)
    lbl = label(mask, connectivity=1)
    regions = regionprops(lbl)
    if not regions:
        h, w = image01.shape[:2]
        return (0, 0, w, h)
    rp = max(regions, key=lambda r: r.area)
    minr, minc, maxr, maxc = rp.bbox
    h, w = image01.shape[:2]
    minr = max(0, minr - margin); minc = max(0, minc - margin)
    maxr = min(h, maxr + margin); maxc = min(w, maxc + margin)
    return (minc, minr, maxc, maxr)  # x1,y1,x2,y2

def apply_crop(img, bbox):
    x1,y1,x2,y2 = bbox
    return img[y1:y2, x1:x2]

def resize_to_square(img, size=IMG_SIZE):
    h, w = img.shape[:2]
    scale = min(size / h, size / w) if (h>0 and w>0) else 1.0
    nh, nw = max(1,int(round(h*scale))), max(1,int(round(w*scale)))
    img_r = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
    out = np.zeros((size, size), dtype=img_r.dtype)
    y0 = (size - nh)//2; x0 = (size - nw)//2
    out[y0:y0+nh, x0:x0+nw] = img_r
    return out, (x0, y0, nw, nh, h, w)  # pad+scale meta for the cropped image

def warp_mask_like(mask, meta):
    x0, y0, nw, nh, h0, w0 = meta
    if mask.size == 0:
        return np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    mask_r = cv2.resize(mask.astype(np.uint8), (nw, nh), interpolation=cv2.INTER_NEAREST)
    out = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    out[y0:y0+nh, x0:x0+nw] = mask_r
    return out

def inverse_unwarp_mask(mask_sq, meta, bbox, orig_shape):
    # mask_sq: IMG_SIZExIMG_SIZE; meta=(x0,y0,nw,nh,h_crop,w_crop); bbox=(x1,y1,x2,y2); orig_shape=(H0,W0)
    x0, y0, nw, nh, h_crop, w_crop = meta
    x1, y1, x2, y2 = bbox
    H0, W0 = orig_shape
    crop_space = np.zeros((h_crop, w_crop), dtype=np.uint8)
    if nh>0 and nw>0:
        inner = mask_sq[y0:y0+nh, x0:x0+nw].astype(np.uint8)
        if inner.size > 0:
            crop_space = cv2.resize(inner, (w_crop, h_crop), interpolation=cv2.INTER_NEAREST)
    full = np.zeros((H0, W0), dtype=np.uint8)
    # guard bbox within image
    x1c, y1c = max(0, x1), max(0, y1)
    x2c, y2c = min(W0, x2), min(H0, y2)
    if (y2c>y1c) and (x2c>x1c):
        full[y1c:y2c, x1c:x2c] = crop_space[(y1c - y1):(y2c - y1), (x1c - x1):(x2c - x1)]
    return full

def inverse_unwarp_probs(prob_sq, meta, bbox, orig_shape):
    # prob_sq: IMG_SIZExIMG_SIZE float32 in [0,1]
    x0, y0, nw, nh, h_crop, w_crop = map(int, meta)
    x1, y1, x2, y2 = map(int, bbox)
    H0, W0 = map(int, orig_shape)
    inner = prob_sq[y0:y0+nh, x0:x0+nw].astype(np.float32)
    if inner.size == 0 or h_crop <= 0 or w_crop <= 0:
        crop_prob = np.zeros((h_crop, w_crop), dtype=np.float32)
    else:
        crop_prob = cv2.resize(inner, (w_crop, h_crop), interpolation=cv2.INTER_LINEAR)
    full = np.zeros((H0, W0), dtype=np.float32)
    x1c, y1c = max(0, x1), max(0, y1)
    x2c, y2c = min(W0, x2), min(H0, y2)
    if (y2c > y1c) and (x2c > x1c):
        full[y1c:y2c, x1c:x2c] = crop_prob[(y1c - y1):(y2c - y1), (x1c - x1):(x2c - x1)]
    return full

def read_png_u16(path):
    img = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(path)
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if img.dtype != np.uint16:
        img = img.astype(np.uint16)
    return img

def get_neighbor_ids(center_id, all_slices_sorted):
    case, day, sl = parse_id(center_id)
    idx = all_slices_sorted.index(sl)
    res = []
    for off in CTX_OFFSETS:
        j = idx + off
        j = min(max(j, 0), len(all_slices_sorted)-1)
        res.append(all_slices_sorted[j])
    return [f"case{case}_day{day}_slice_{s:04d}" for s in res]

class UWGITractDataset(Dataset):
    def __init__(self, df_ids, train_df=None, roots=None, mode='train', aug=None):
        # df_ids: dataframe with columns id, case, day, slice; one row per unique id
        self.df_ids = df_ids.reset_index(drop=True)
        self.train_df = train_df
        self.roots = roots or [Path('train')]
        self.mode = mode
        self.aug = aug
        g = self.df_ids.groupby(['case','day'])['slice'].apply(lambda s: sorted(s.tolist()))
        self.slice_map = {(int(c),int(d)): lst for (c,d), lst in g.items()}

    def __len__(self):
        return len(self.df_ids)

    def _proc_image(self, id_str, bbox=None):
        p = resolve_path(id_str, self.roots)
        img_u16 = read_png_u16(p)
        img01 = robust_norm(img_u16)
        if bbox is None:
            bbox = body_crop_bbox(img01)
        img_crop = apply_crop(img01, bbox)
        img_sq, meta = resize_to_square(img_crop, IMG_SIZE)
        return img_sq.astype(np.float32), bbox, meta, img_u16.shape[:2]

    def __getitem__(self, idx):
        row = self.df_ids.iloc[idx]
        id_str = row['id']
        case, day, sl = int(row['case']), int(row['day']), int(row['slice'])
        # Center first to establish bbox/meta for alignment across neighbors
        center_img, bbox, center_meta, orig_shape_center = self._proc_image(id_str, bbox=None)
        neighbors = get_neighbor_ids(id_str, self.slice_map[(case,day)])
        chans = []
        for nid in neighbors:
            try:
                img_sq, _, _, _ = self._proc_image(nid, bbox=bbox)  # use center bbox
            except FileNotFoundError:
                # Neighbor missing: fallback to center slice to keep channel count/stability
                img_sq = center_img
            chans.append(img_sq)
        img5 = np.stack(chans, axis=0)  # 5xHxW

        if self.mode != 'test':
            # Build center mask aligned to center image using center bbox + meta
            H0, W0 = orig_shape_center  # reuse the shape from the center read instead of re-reading the PNG
            sub = self.train_df[self.train_df['id']==id_str]
            m3 = np.zeros((len(CLASSES), IMG_SIZE, IMG_SIZE), dtype=np.uint8)
            x1,y1,x2,y2 = bbox
            for ci, cls in enumerate(CLASSES):
                r = sub[sub['class']==cls].iloc[0]
                mask0 = decode_row_to_mask(r, (H0, W0))
                mask_crop = mask0[y1:y2, x1:x2]
                m3[ci] = warp_mask_like(mask_crop, center_meta)
            # Albumentations joint augs: geometric transforms move the HxWxC image and the masks together; photometric ones touch the image only
            if self.aug is not None:
                img_hwk = np.transpose(img5, (1,2,0))  # HxWx5
                masks_list = [m for m in m3]
                out = self.aug(image=img_hwk, masks=masks_list)
                img_hwk = out['image']
                masks_list = out['masks']
                img5 = np.transpose(img_hwk, (2,0,1))
                m3 = np.stack(masks_list, axis=0).astype(np.uint8)
            img_t = torch.from_numpy(img5).float()
            mask_t = torch.from_numpy(m3).float()
            return img_t, mask_t, id_str
        # test mode: return metadata for inverse mapping
        img_t = torch.from_numpy(img5).float()
        return img_t, id_str, bbox, center_meta, orig_shape_center

def get_train_aug():
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=8, p=0.5, border_mode=cv2.BORDER_REFLECT101),
        A.ElasticTransform(alpha=20, sigma=5, alpha_affine=5, p=0.15, border_mode=cv2.BORDER_REFLECT101),
        A.GridDistortion(distort_limit=0.15, p=0.3, border_mode=cv2.BORDER_REFLECT101),
        A.RandomBrightnessContrast(p=0.3),
        A.RandomGamma(gamma_limit=(80,120), p=0.3),
        A.GaussianBlur(blur_limit=3, p=0.2),
        A.GaussNoise(var_limit=(5e-4, 1e-3), p=0.2),
    ])

def get_valid_aug():
    return A.Compose([])

print('[UTILS READY] Dataset aligns neighbors to center crop and warps masks consistently. Includes inverse_unwarp_mask()/inverse_unwarp_probs() and test metadata.')

[UTILS READY] Dataset aligns neighbors to center crop and warps masks consistently. Includes inverse_unwarp_mask()/inverse_unwarp_probs() and test metadata.
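`rle_encode` and `rle_decode` share one flattening order, so the round-trip unit test in the next cell passes whichever order the dataset actually uses. On square slices (266×266 in the path log) a mismatched order does not fail loudly; it transposes the mask. The toy sketch below (hypothetical `rle_encode_order`/`rle_decode_order` helpers parameterized by order) shows that effect; on real data the practical check is to overlay a decoded training mask on its slice and confirm the organ sits where it should.

```python
import numpy as np

def rle_encode_order(mask, order):
    """Encode a binary mask as 1-indexed 'start length' pairs after flattening in `order` ('C' or 'F')."""
    pixels = np.concatenate([[0], mask.flatten(order=order), [0]])
    runs = np.flatnonzero(pixels[1:] != pixels[:-1]) + 1
    runs[1::2] -= runs[::2]
    return ' '.join(map(str, runs))

def rle_decode_order(rle, shape, order):
    s = np.asarray(rle.split(), dtype=int)
    starts, lengths = s[0::2] - 1, s[1::2]
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, ln in zip(starts, lengths):
        flat[lo:lo + ln] = 1
    return flat.reshape(shape, order=order)

m = np.zeros((4, 4), np.uint8)
m[0, 1:3] = 1  # short horizontal bar in the top row

rle_c = rle_encode_order(m, 'C')  # row-major runs
rle_f = rle_encode_order(m, 'F')  # column-major runs
wrong = rle_decode_order(rle_c, m.shape, 'F')  # decode with the mismatched order
print(rle_c, '|', rle_f, '| transposed:', np.array_equal(wrong, m.T))
```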


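`UWGITractDataset.__getitem__` filters `self.train_df[self.train_df['id']==id_str]` on every call, which scans all 95,088 rows (the train row count in the log) per item. Building a dict once, for example in `__init__`, makes the per-item lookup constant time. A minimal sketch with a hypothetical `build_rle_lookup` helper on a toy frame:

```python
import pandas as pd

CLASSES = ['large_bowel', 'small_bowel', 'stomach']

def build_rle_lookup(train_df: pd.DataFrame, classes=CLASSES) -> dict:
    """id -> list of RLE strings in `classes` order ('' where empty). Hypothetical helper."""
    cls_idx = {c: i for i, c in enumerate(classes)}
    lookup = {}
    for id_str, cls, rle in train_df[['id', 'class', 'segmentation']].itertuples(index=False):
        slot = lookup.setdefault(id_str, [''] * len(classes))
        slot[cls_idx[cls]] = rle if isinstance(rle, str) else ''
    return lookup

toy = pd.DataFrame({
    'id': ['case1_day0_slice_0001'] * 3 + ['case1_day0_slice_0002'] * 3,
    'class': CLASSES * 2,
    'segmentation': [None, None, None, '10 5', None, '3 2'],
})
lookup = build_rle_lookup(toy)
print(lookup['case1_day0_slice_0002'])
```

Inside the dataset, the per-class `sub[sub['class']==cls]` lookup would then become `rle_decode(self.rle_lookup[id_str][ci], (H0, W0))`, where `rle_lookup` is a hypothetical attribute name.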
In [7]:
# Caching, unit tests, and model/loss skeleton (no training yet)
import os, math, time, json, gc
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler
import pandas as pd

# Unit test: RLE encode/decode round-trip
def _unit_test_rle():
    rng = np.random.default_rng(0)
    H, W = 64, 64
    m = (rng.random((H,W)) > 0.8).astype(np.uint8)
    r = rle_encode(m)
    m2 = rle_decode(r, (H,W))
    assert np.array_equal(m, m2), 'RLE round-trip failed'
    print('[TEST] RLE round-trip OK')

# Only run the unit test if the RLE helpers from the utilities cell are defined in this kernel
if 'rle_encode' in globals() and 'rle_decode' in globals():
    try:
        _unit_test_rle()
    except Exception as e:
        print('[TEST] RLE round-trip skipped due to error:', e)
else:
    print('[TEST] Skipping RLE round-trip (helpers not yet defined in kernel)')

# Cache builder: persists preprocessed stacks and metadata to disk
def build_cache(df_ids, train_df=None, roots=None, out_dir='cache/train', mode='train', log_every=200):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    df_ids = df_ids.reset_index(drop=True)
    n = len(df_ids)
    # Build each dataset ONCE over all ids: a one-row dataset has a one-slice slice_map,
    # which clamps every 2.5D neighbor to the center slice (5 identical channels).
    ds_meta = UWGITractDataset(df_ids, train_df=None, roots=roots, mode='test', aug=None)
    ds_train = None if mode == 'test' else UWGITractDataset(df_ids, train_df=train_df, roots=roots, mode='train', aug=None)
    t0 = time.time()
    for i, row in df_ids.iterrows():
        id_str = row['id']
        out_path = out_dir / f"{id_str}.npz"
        if out_path.exists():
            if (i % log_every)==0:
                print(f"[CACHE] ({i}/{n}) skip exists {out_path}")
            continue
        try:
            # Test-mode item: image stack plus bbox/meta/orig_shape for inverse mapping
            img_t, _id, bbox, meta, orig_shape = ds_meta[i]
            arrays = dict(img5=img_t.numpy().astype(np.float16),
                          bbox=np.array(bbox, np.int32),
                          meta=np.array(meta, np.int32),
                          orig_shape=np.array(orig_shape, np.int32))
            if ds_train is not None:
                # Train mode (aug=None): same image stack, plus the 3-class mask aligned to it
                _, mask_t, _ = ds_train[i]
                arrays['m3'] = mask_t.numpy().astype(np.uint8)
            np.savez_compressed(out_path, **arrays)
        except FileNotFoundError:
            if (i % log_every)==0:
                print(f"[CACHE] ({i}/{n}) MISSING image for {id_str}")
        if (i % log_every)==0 and i>0:
            dt = time.time()-t0
            print(f"[CACHE] {i}/{n} done in {dt/60:.1f} min")
            gc.collect()
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
    print('[CACHE] Done:', out_dir)

# Sampler weights to target ~60-65% positive slices
def build_pos_oversampler(df_ids, train_df, target_pos_frac=0.62):
    any_pos = (train_df.assign(has_pos=train_df['segmentation'].notna())
                        .groupby('id')['has_pos'].any())
    ids = df_ids['id'].values
    flags = any_pos.reindex(ids).fillna(False).values.astype(np.uint8)
    pos = flags.mean()
    n = len(flags); n_pos = flags.sum(); n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        weights = np.ones(n, dtype=np.float32)
    else:
        w_neg = 1.0
        w_pos = (target_pos_frac * n_neg * w_neg) / ( (1 - target_pos_frac) * n_pos )
        w_pos = float(max(w_pos, 1e-3))
        weights = np.where(flags==1, w_pos, w_neg).astype(np.float32)
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    return sampler

# Model factory: UNet++ with a timm tf_efficientnet_b3 encoder (SMP's 'tu-' prefix), in_channels=5, classes=3 (canonical order)
def build_model(device='cuda', encoder='tu-tf_efficientnet_b3', in_ch=5, classes=3):
    # Lazy import to avoid heavy import at cell-exec time
    import segmentation_models_pytorch as smp
    model = smp.UnetPlusPlus(encoder_name=encoder, in_channels=in_ch, classes=classes, activation=None)
    return model.to(device)

# Loss: BCEWithLogits + Tversky(alpha=0.7, beta=0.3) with class weights
_printed_combo_debug = {'done': False}

def _ensure_chw_targets(t):
    # t can be (B,3,H,W) or (B,H,W,3); convert to (B,3,H,W)
    if t.dim() == 3:  # (3,H,W) single sample (unlikely here)
        t = t.unsqueeze(0)
    if t.dim() == 4:
        if t.shape[-1] == 3 and t.shape[1] != 3:
            return t.permute(0, 3, 1, 2).contiguous()
        c = t.shape[1]  # named c to avoid shadowing the albumentations alias A
        if c not in (1, 3) and t.shape[-1] in (1, 3):
            return t.permute(0, 3, 1, 2).contiguous()
    return t

class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.7, beta=0.3, eps=1e-6):
        super().__init__(); self.alpha=alpha; self.beta=beta; self.eps=eps
    def forward(self, logits, targets):
        # compute in fp32 to stabilize under amp
        with torch.cuda.amp.autocast(enabled=False):
            logits = logits.float()
            targets = _ensure_chw_targets(targets.float())
            probs = torch.sigmoid(logits)
            dims = (0,2,3)
            tp = (probs*targets).sum(dim=dims)
            fp = (probs*(1-targets)).sum(dim=dims)
            fn = ((1-probs)*targets).sum(dim=dims)
            t = (tp + self.alpha*fp + self.beta*fn + self.eps)
            return 1.0 - (tp + self.eps)/t

class ComboLoss(nn.Module):
    def __init__(self, bce_weight=0.5, tv_weight=0.5, tv_alpha=0.7, tv_beta=0.3, class_weights=(1.1,1.35,1.0)):
        super().__init__()
        # store raw per-class weights
        self.pos_w = nn.Parameter(torch.tensor(class_weights, dtype=torch.float32), requires_grad=False)
        self.tvl = TverskyLoss(alpha=tv_alpha, beta=tv_beta)
        self.bw = bce_weight; self.tw = tv_weight
    def forward(self, logits, targets):
        # Enforce layout to (B,3,H,W) for both
        if logits.dim() == 4 and logits.shape[1] not in (1,3) and logits.shape[-1] in (1,3):
            logits = logits.permute(0,3,1,2).contiguous()
        targets = _ensure_chw_targets(targets)
        # Build per-element weight map: 1 + (pos_w-1)*targets, where pos_w is per-channel
        w = self.pos_w.to(logits.device).reshape(-1)  # (C,)
        pw = w[None, :, None, None]  # (1,C,1,1)
        ew = 1.0 + (pw - 1.0) * targets
        if not _printed_combo_debug['done']:
            try:
                print('[LOSS-DBG] logits', tuple(logits.shape), 'targets', tuple(targets.shape), 'elem_w', tuple(ew.shape))
            finally:
                _printed_combo_debug['done'] = True
        bce = F.binary_cross_entropy_with_logits(logits, targets, weight=ew)
        tv = self.tvl(logits, targets).mean()
        return self.bw*bce + self.tw*tv

print('[CACHE/MODEL UTILS READY] Cache saves img5(float16)+masks+metadata; Tversky computed in fp32 under AMP. Lazy-imported SMP in build_model().')

[TEST] RLE round-trip OK
[CACHE/MODEL UTILS READY] Cache saves img5(float16)+masks+metadata; Tversky computed in fp32 under AMP. Lazy-imported SMP in build_model().

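`build_pos_oversampler` chooses `w_pos` so that positives make up `target_pos_frac` of the expected draws: solving `n_pos*w_pos / (n_pos*w_pos + n_neg*w_neg) = f` for `w_pos` gives exactly the formula in the code. The sketch below checks that algebra with illustrative counts matching the 0.428 positive-slice ratio from the EDA log, then draws with replacement in proportion to the weights (as `WeightedRandomSampler` does) to confirm the empirical fraction. The helper names are hypothetical.

```python
import numpy as np

def pos_weight_for_target(n_pos, n_neg, target_pos_frac, w_neg=1.0):
    # Solve n_pos*w_pos / (n_pos*w_pos + n_neg*w_neg) = f for w_pos
    f = target_pos_frac
    return f * n_neg * w_neg / ((1.0 - f) * n_pos)

def expected_pos_frac(n_pos, n_neg, w_pos, w_neg=1.0):
    return n_pos * w_pos / (n_pos * w_pos + n_neg * w_neg)

# Illustrative counts matching the ~0.428 positive-slice ratio in the EDA log
n_pos, n_neg = 428, 572
w_pos = pos_weight_for_target(n_pos, n_neg, 0.62)
expected = expected_pos_frac(n_pos, n_neg, w_pos)

# Empirical check: sample indices with replacement, proportional to the weights
rng = np.random.default_rng(0)
flags = np.r_[np.ones(n_pos, bool), np.zeros(n_neg, bool)]
weights = np.where(flags, w_pos, 1.0)
draws = rng.choice(len(flags), size=100_000, replace=True, p=weights / weights.sum())
empirical = flags[draws].mean()
print(f'w_pos={w_pos:.3f} expected={expected:.3f} empirical={empirical:.3f}')
```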

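`TverskyLoss` computes `1 - TP / (TP + alpha*FP + beta*FN)` per class, so `alpha` weights false positives and `beta` false negatives; with `alpha=0.7, beta=0.3` over-segmentation costs more than under-segmentation. With `alpha = beta = 0.5` the index reduces to soft Dice. A small numpy sketch of the same formula (not the torch module itself), holding TP fixed so only the error type differs:

```python
import numpy as np

def tversky_loss(probs, targets, alpha=0.7, beta=0.3, eps=1e-6):
    """Same index as TverskyLoss, on flat numpy arrays: alpha weights FP, beta weights FN."""
    tp = float((probs * targets).sum())
    fp = float((probs * (1 - targets)).sum())
    fn = float(((1 - probs) * targets).sum())
    return 1.0 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)

# Same TP count (4) in both cases
t_fp = np.array([1, 1, 1, 1, 0, 0], float)
p_fp = np.ones(6)                              # 2 false positives
t_fn = np.ones(6)
p_fn = np.array([1, 1, 1, 1, 0, 0], float)     # 2 false negatives

loss_fp = tversky_loss(p_fp, t_fp)
loss_fn = tversky_loss(p_fn, t_fn)
dice_like = tversky_loss(p_fp, t_fp, alpha=0.5, beta=0.5)
print(f'FP-heavy={loss_fp:.4f}  FN-heavy={loss_fn:.4f}  alpha=beta=0.5 -> {dice_like:.4f}')
```

If recall on thin small-bowel loops matters more, swapping to `alpha=0.3, beta=0.7` is the usual direction; whether that helps here would need an OOF comparison.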
In [None]:
# Filesystem scan for PNG sources (soft timebox: the clock is only checked after a match is found)
import os, time, glob, fnmatch
from pathlib import Path

def scan_for_slices(base_dirs, patterns=("**/case*/day*/scans/slice_*.png", "**/case*/day*/slice_*.png"),
                    max_matches=200, timeout_sec=60):
    t0 = time.time()
    found = []
    checked_dirs = []
    for b in base_dirs:
        b = Path(b)
        if not b.exists():
            continue
        checked_dirs.append(str(b))
        for pat in patterns:
            try:
                for p in b.rglob(pat):
                    found.append(str(p))
                    if len(found) >= max_matches or (time.time()-t0) > timeout_sec:
                        return found, checked_dirs
            except Exception as e:
                print(f"[SCAN] Error scanning {b} with {pat}: {e}")
        if (time.time()-t0) > timeout_sec:
            break
    return found, checked_dirs

candidate_dirs = [
    Path('.'),
    Path('./train'), Path('./test'),
    Path('/kaggle/input'),
    Path('/mnt'),
    Path('/data'),
    Path('/workspace'),
]
print('[SCAN] Searching for slice_*.png under candidates (timeboxed)...')
found, checked = scan_for_slices(candidate_dirs, max_matches=50, timeout_sec=30)
print('[SCAN] Checked roots:', checked)
print(f"[SCAN] Found {len(found)} sample files")
if found:
    for p in found[:10]:
        print(' ', p)
else:
    print('[SCAN] No PNGs found. Likely images are not mounted in this environment.')

# If any found under a recognizable uw-madison path, suggest updating TRAIN_IMG_ROOTS/TEST_IMG_ROOTS accordingly.
print('[SCAN DONE]')
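In `scan_for_slices` the timeout is checked only after `rglob` yields a match, so a large tree with no matching PNGs can block well past `timeout_sec` (and `rglob` already prepends `**/`, making the `**/` in the patterns redundant). A sketch of an alternative, hypothetical `scan_for_slices_walk`, that walks with `os.walk` and checks the clock once per directory; it still cannot interrupt a single slow directory listing, and it matches file names only, not the `case*/day*` directory structure.

```python
import fnmatch
import os
import tempfile
import time
from pathlib import Path

def scan_for_slices_walk(base_dirs, name_pattern='slice_*.png', max_matches=200, timeout_sec=60.0):
    """Return (paths, timed_out). The clock is checked once per directory visited."""
    t0 = time.monotonic()
    found = []
    for base in base_dirs:
        if not os.path.isdir(base):
            continue
        for dirpath, _dirnames, filenames in os.walk(base):
            if time.monotonic() - t0 > timeout_sec:
                return found, True
            for fn in fnmatch.filter(filenames, name_pattern):
                found.append(os.path.join(dirpath, fn))
                if len(found) >= max_matches:
                    return found, False
    return found, False

with tempfile.TemporaryDirectory() as tmp:
    scans = Path(tmp) / 'case1' / 'case1_day0' / 'scans'
    scans.mkdir(parents=True)
    (scans / 'slice_0001_266_266_1.50_1.50.png').touch()
    (scans / 'notes.txt').touch()
    hits, timed_out = scan_for_slices_walk([tmp], timeout_sec=5.0)
    _, forced_timeout = scan_for_slices_walk([tmp], timeout_sec=-1.0)  # negative budget forces a timeout
print(len(hits), timed_out, forced_timeout)
```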

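In the next cell, `EMA.apply_to` copies the shadow weights into the model and keeps no backup, so evaluating with EMA weights mid-training would silently continue training from the averaged weights unless the caller snapshots them first. It also tracks only `named_parameters()`, so BatchNorm running statistics are not averaged, and with `decay=0.9995` the averaging horizon is roughly `1/(1-decay) = 2000` updates. The framework-agnostic sketch below (numpy arrays standing in for parameters; `NumpyEMA`, `swap_in` and `restore` are hypothetical names) shows the swap-in/restore pattern.

```python
import numpy as np

class NumpyEMA:
    """EMA over a dict name -> float np.ndarray, with swap-in/restore around evaluation."""
    def __init__(self, params, decay=0.9995):
        self.decay = decay
        self.shadow = {k: v.copy() for k, v in params.items()}
        self.backup = None

    def update(self, params):
        for k, v in params.items():
            self.shadow[k] *= self.decay
            self.shadow[k] += (1.0 - self.decay) * v

    def swap_in(self, params):
        # Save live weights, then load EMA weights for evaluation
        self.backup = {k: v.copy() for k, v in params.items()}
        for k in params:
            params[k][...] = self.shadow[k]

    def restore(self, params):
        # Put live weights back so training continues from them, not from the EMA
        for k in params:
            params[k][...] = self.backup[k]
        self.backup = None

params = {'w': np.zeros(3)}
ema = NumpyEMA(params, decay=0.5)   # small decay so the effect is visible in 3 steps
params['w'][:] = 1.0                # pretend training moved the weights
for _ in range(3):
    ema.update(params)              # shadow: 0 -> 0.5 -> 0.75 -> 0.875
ema.swap_in(params)
evaluated_with = params['w'].copy() # EMA weights used for validation
ema.restore(params)
print(evaluated_with, params['w'])
```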
In [None]:
# Training & Inference skeleton (5-fold, AMP, cosine, EMA, H-flip TTA + post-proc)
import os, math, time, gc, json, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from skimage.measure import label, regionprops
import cv2
from scipy.ndimage import binary_fill_holes

# Memory/throughput guards
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
try:
    cv2.setNumThreads(0)
except Exception:
    pass

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True

# Post-processing defaults (order: [large, small, stomach]) per expert advice
PP_THRESH = [0.50, 0.42, 0.47]
PP_MIN_AREA = [1200, 900, 800]
# Optionally override with tuned values if available
try:
    if Path('tuned_pp.json').exists():
        _pp = json.loads(Path('tuned_pp.json').read_text())
        if isinstance(_pp.get('thr'), (list, tuple)) and isinstance(_pp.get('min_area'), (list, tuple)):
            PP_THRESH = [float(x) for x in _pp['thr']]
            PP_MIN_AREA = [int(x) for x in _pp['min_area']]
            print('[PP] Overridden from tuned_pp.json:', PP_THRESH, PP_MIN_AREA)
except Exception as _e:
    print('[PP] tuned_pp.json load failed:', _e)

def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def dice_score(pred, targ, eps=1e-6):
    # pred, targ: (H,W) bool or {0,1} integer arrays (bitwise & requires a non-float dtype); two empty masks score 1.0
    inter = (pred & targ).sum()
    d = (2*inter + eps) / (pred.sum() + targ.sum() + eps)
    return float(d)

# HD95 proxy helpers (empty-safe). We will use later for OOF tuning.
def _surface_distances(a, b):
    # Fast symmetric approximation: chessboard distances from every foreground pixel (not just boundary
    # pixels) of one mask to the nearest pixel of the other; not exact HD95
    import scipy.ndimage as ndi
    a = a.astype(bool); b = b.astype(bool)
    if not a.any() and not b.any():
        return np.array([0.0])
    if not a.any() or not b.any():
        # exactly one mask is empty: return a fixed penalty distance of 100
        return np.array([100.0])
    a_dt = ndi.distance_transform_cdt(~a, metric='chessboard')
    b_dt = ndi.distance_transform_cdt(~b, metric='chessboard')
    a_b = a_dt[b]
    b_a = b_dt[a]
    if a_b.size == 0: a_b = np.array([0.0])
    if b_a.size == 0: b_a = np.array([0.0])
    return np.concatenate([a_b, b_a]).astype(np.float32)

def hd95_proxy(a, b):
    d = _surface_distances(a, b)
    return float(np.percentile(d, 95)) if d.size else 0.0

class EMA:
    def __init__(self, model, decay=0.9995):
        self.decay = decay
        self.shadow = {}
        for n, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[n] = p.detach().clone()
    def update(self, model):
        for n, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1-self.decay)
    def apply_to(self, model):
        for n, p in model.named_parameters():
            if p.requires_grad:
                p.data.copy_(self.shadow[n])

def make_loaders(fold, batch_size=10, num_workers=4, target_pos_frac=0.62):
    folds = pd.read_csv('folds.csv')
    tr_ids = folds[folds['fold']!=fold][['id','case','day','slice']].reset_index(drop=True)
    va_ids = folds[folds['fold']==fold][['id','case','day','slice']].reset_index(drop=True)
    train_ds = UWGITractDataset(tr_ids, train_df=train_df, roots=TRAIN_IMG_ROOTS, mode='train', aug=get_train_aug())
    valid_ds = UWGITractDataset(va_ids, train_df=train_df, roots=TRAIN_IMG_ROOTS, mode='valid', aug=get_valid_aug())
    sampler = build_pos_oversampler(tr_ids, train_df, target_pos_frac=target_pos_frac)
    # Safer loader settings to avoid hangs
    pf = None if num_workers == 0 else 2
    train_dl = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=True, drop_last=True, persistent_workers=False, prefetch_factor=pf)
    valid_dl = DataLoader(valid_ds, batch_size=max(1,batch_size//2), shuffle=False, num_workers=num_workers, pin_memory=True, persistent_workers=False, prefetch_factor=pf)
    return train_dl, valid_dl, va_ids

def _find_encoder_stem_conv(enc):
    # Placeholder to keep API compatibility; unused in TinyUNet path
    return None

def _build_tmp_3ch_b3(device='cpu'):
    # Unused in TinyUNet path; keep for interface completeness
    return None

def _force_stem_mean_rgb_mean(model, device='cuda'):
    # Unused in TinyUNet path
    return False

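def build_model_b3(device='cuda'):
    # Primary model per expert advice: SMP Unet with a ResNet34 (ImageNet) encoder, in_ch=5, classes=3.
    # The function name is kept for compatibility; despite the "_b3" suffix, no EfficientNet-B3 is built.
    gc.collect()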
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass
    import segmentation_models_pytorch as smp
    model = smp.Unet(encoder_name='resnet34', encoder_weights='imagenet', in_channels=5, classes=3, activation=None)
    return model.to(device)

def _collect_valid_metadata(va_ids):
    # Build id -> (bbox, meta, orig_shape) using test-mode dataset
    ds_meta = UWGITractDataset(va_ids, train_df=None, roots=TRAIN_IMG_ROOTS, mode='test', aug=None)
    dl_meta = DataLoader(ds_meta, batch_size=8, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=False)
    meta_map = {}
    with torch.no_grad():
        for batch in dl_meta:
            imgs, ids, bboxes, metas, orig_shapes = batch
            for i, id_str in enumerate(ids):
                meta_map[id_str] = (tuple(int(x) for x in bboxes[i]), tuple(int(x) for x in metas[i]), tuple(int(x) for x in orig_shapes[i]))
    return meta_map

def _ensure_nchw(t):
    # Convert (B,H,W,C) to (B,C,H,W) if detected
    if t.dim() == 4 and t.shape[1] != 3 and t.shape[-1] == 3:
        return t.permute(0,3,1,2).contiguous()
    return t

def _align_logits_targets(logits, masks):
    # Ensure both are (B,3,H,W). Handle ambiguous NHWC/NCHW cases.
    if logits.dim() == 4 and logits.shape[1] not in (1,3) and logits.shape[-1] in (1,3):
        logits = logits.permute(0,3,1,2).contiguous()
    if masks.dim() == 4 and masks.shape[1] not in (1,3) and masks.shape[-1] in (1,3):
        masks = masks.permute(0,3,1,2).contiguous()
    # If still mismatched, try swapping last and channel dims of logits to match masks
    if logits.shape != masks.shape:
        if logits.dim()==4 and masks.dim()==4 and logits.shape[-1]==3 and masks.shape[1]==3:
            logits = logits.permute(0,3,1,2).contiguous()
        elif logits.dim()==4 and masks.dim()==4 and masks.shape[-1]==3 and logits.shape[1]==3:
            masks = masks.permute(0,3,1,2).contiguous()
    return logits, masks

def train_one_fold(fold, epochs=40, lr=1e-3, wd=1e-4, batch_size=10, num_workers=4, device='cuda', patience=6, min_lr=1e-6):
    print(f"[TRAIN] Fold {fold} start")
    train_dl, valid_dl, va_ids = make_loaders(fold, batch_size=batch_size, num_workers=num_workers)
    # Clear caches before model init to avoid CUDA init errors
    gc.collect();
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass
    model = build_model_b3(device=device)
    loss_fn = ComboLoss(bce_weight=0.5, tv_weight=0.5, tv_alpha=0.7, tv_beta=0.3, class_weights=(1.1,1.45,1.0))
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    steps_per_epoch = max(1, len(train_dl))
    total_steps = steps_per_epoch * epochs
    warmup = min(int(0.05*total_steps), max(steps_per_epoch, 1))
    def lr_schedule(step):
        if step < warmup:
            return step / max(1, warmup)
        t = (step - warmup) / max(1, total_steps - warmup)
        return min_lr/lr + (1 - min_lr/lr) * 0.5 * (1 + math.cos(math.pi * t))
    # Disable AMP to avoid kernel death during backward (diagnostic forward succeeded)
    scaler = GradScaler(enabled=False)
    ema = EMA(model, decay=0.9995)
    best_score = -1.0
    best_epoch = 0
    out_dir = Path('oof'); out_dir.mkdir(exist_ok=True, parents=True)
    log_every = 50
    step = 0
    for epoch in range(1, epochs+1):
        t0 = time.time()
        model.train()
        train_loss = 0.0
        for it, batch in enumerate(train_dl):
            imgs, masks, _ids = batch
            imgs = imgs.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            for g in opt.param_groups:
                g['lr'] = lr * lr_schedule(step)
            # Disable autocast to avoid potential mixed precision instability
            with autocast(enabled=False):
                logits = model(imgs)
                logits, masks = _align_logits_targets(logits, masks)
                if it == 0 and epoch == 1:
                    try:
                        print(f"[DBG] imgs={tuple(imgs.shape)} logits={tuple(logits.shape)} masks={tuple(masks.shape)}", flush=True)
                    except Exception:
                        pass
                loss = loss_fn(logits, masks)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            opt.zero_grad(set_to_none=True)
            ema.update(model)
            train_loss += loss.item()
            if (it+1) % log_every == 0:
                print(f"[Fold {fold}] epoch {epoch} it {it+1}/{len(train_dl)} loss {train_loss/(it+1):.4f} lr {opt.param_groups[0]['lr']:.2e}")
            step += 1
        # Validation using EMA weights without a second GPU model to save memory
        backup_sd = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        ema.apply_to(model)
        model.eval()
        dices = []
        with torch.no_grad():
            for imgs, masks, _ids in valid_dl:
                imgs = imgs.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)
                logits = model(imgs)
                logits, masks = _align_logits_targets(logits, masks)
                probs = torch.sigmoid(logits).float().cpu().numpy()
                tgts = masks.float().cpu().numpy()
                for b in range(probs.shape[0]):
                    for c in range(3):
                        p = (probs[b,c] > 0.5).astype(np.uint8)
                        t = (tgts[b,c] > 0.5).astype(np.uint8)
                        dices.append(dice_score(p, t))
        mean_dice = float(np.mean(dices)) if dices else 0.0
        model.load_state_dict(backup_sd, strict=True)
        model.train()
        dt = time.time()-t0
        print(f"[Fold {fold}] epoch {epoch} train_loss {train_loss/max(1,len(train_dl)):.4f} val_dice {mean_dice:.4f} time {dt/60:.1f}m")
        improved = mean_dice > best_score + 1e-5
        if improved:
            best_score = mean_dice
            best_epoch = epoch
            ema.apply_to(model)
            torch.save(model.state_dict(), f"model_fold{fold}.pt")
            model.load_state_dict(backup_sd, strict=True)
            print(f"[Fold {fold}] Saved best EMA model, dice {best_score:.4f}")
        if (epoch - best_epoch) >= patience:
            print(f"[Fold {fold}] Early stopping at epoch {epoch} (best {best_epoch})")
            break
        gc.collect();
        torch.cuda.empty_cache()
    print(f"[TRAIN] Fold {fold} done. Best dice {best_score:.4f} at epoch {best_epoch}")

    # Compute and save OOF square probs + metadata for this fold using best EMA model
    print(f"[OOF] Collecting OOF predictions for fold {fold} ...")
    meta_map = _collect_valid_metadata(va_ids)
    model_best = build_model_b3(device=device)
    model_best.load_state_dict(torch.load(f"model_fold{fold}.pt", map_location=device), strict=True)
    model_best.eval()
    ids_all, probs_all = [], []
    with torch.no_grad():
        for imgs, masks, _ids in valid_dl:
            imgs = imgs.to(device, non_blocking=True)
            logits = model_best(imgs)
            logits, _ = _align_logits_targets(logits, masks)
            probs = torch.sigmoid(logits).float().cpu().numpy()  # Bx3xHxW (square space)
            probs_all.append(probs)
            ids_all += list(_ids)
    probs_all = np.concatenate(probs_all, axis=0).astype(np.float16)
    np.save(f"oof_fold{fold}_ids.npy", np.array(ids_all, dtype=object))
    np.save(f"oof_fold{fold}_probs_sq.npy", probs_all)
    # Save metadata aligned to ids order for later inverse mapping and HD-aware tuning
    bboxes = np.array([meta_map[_id][0] for _id in ids_all], dtype=np.int32)
    metas = np.array([meta_map[_id][1] for _id in ids_all], dtype=np.int32)
    origs = np.array([meta_map[_id][2] for _id in ids_all], dtype=np.int32)
    np.savez_compressed(f"oof_fold{fold}_meta.npz", bbox=bboxes, meta=metas, orig_shape=origs)
    print(f"[OOF] Saved oof_fold{fold}_*.npy/npz")

def tta_hflip_predict(model, imgs):
    # imgs: Bx5xHxW
    logits = model(imgs)
    imgs_h = torch.flip(imgs, dims=[-1])
    logits_h = model(imgs_h)
    logits_h = torch.flip(logits_h, dims=[-1])
    return (logits + logits_h) / 2.0

def post_process_full(mask, cls_index):
    # mask: HxW uint8
    lbl = label(mask)
    if lbl.max() == 0:
        return mask
    areas = [(i, (lbl==i).sum()) for i in range(1, lbl.max()+1)]
    areas.sort(key=lambda x: x[1], reverse=True)
    keep = np.zeros_like(mask)
    kept = 0
    for i, a in areas:
        if a >= PP_MIN_AREA[cls_index]:
            keep[lbl==i] = 1
            kept += 1
            if cls_index==2 and kept>=1: break  # stomach: largest 1
            if cls_index==0 and kept>=3: break  # large: top 3
            if cls_index==1 and kept>=5: break  # small: top 5
    if cls_index == 2:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
        keep = cv2.morphologyEx(keep, cv2.MORPH_CLOSE, kernel, iterations=1)
        keep = binary_fill_holes(keep.astype(bool)).astype(np.uint8)
    if cls_index == 1:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
        keep = cv2.morphologyEx(keep, cv2.MORPH_OPEN, kernel, iterations=1)
    return keep

def _z_smooth_groups(id_info, window=3):
    # Smooth probs in square space per (case,day) along slice order
    from collections import defaultdict
    groups = defaultdict(list)
    for id_str in id_info.keys():
        c, d, s = parse_id(id_str)
        groups[(c,d)].append((s, id_str))
    for key, lst in groups.items():
        lst.sort(key=lambda x: x[0])
        ids_sorted = [k for _, k in lst]
        P = [id_info[k]['probs'] for k in ids_sorted]  # T x 3 x H x W
        T = len(P)
        if T >= 2 and window >= 3:
            k = window
            P_pad = [P[0]]*(k//2) + P + [P[-1]]*(k//2)
            for t in range(T):
                acc = None
                for j in range(t, t+k):
                    X = P_pad[j]
                    acc = X if acc is None else acc + X
                sm = acc / float(k)
                id_info[ids_sorted[t]]['probs'] = sm
    return id_info

def _apply_z_consistency(masks_map):
    # masks_map: dict[id_str] -> np array (3,H,W) uint8 after per-slice PP
    from collections import defaultdict
    groups = defaultdict(list)
    for id_str in masks_map.keys():
        c, d, s = parse_id(id_str)
        groups[(c,d)].append((s, id_str))
    for (c,d), lst in groups.items():
        lst.sort(key=lambda x: x[0])
        ids_sorted = [k for _, k in lst]
        T = len(ids_sorted)
        for cls_index in [0,1]:  # bowels only
            for t, id_cur in enumerate(ids_sorted):
                cur = masks_map[id_cur][cls_index].copy()
                if cur.sum() == 0:
                    continue
                prev = masks_map[ids_sorted[t-1]][cls_index] if (t-1) >= 0 else None
                nxt = masks_map[ids_sorted[t+1]][cls_index] if (t+1) < T else None
                support = ((prev is not None and prev.any()) or (nxt is not None and nxt.any()))
                if support:
                    continue
                # drop 1-slice small CCs below 1.2 * min_area
                lbl = label(cur)
                if lbl.max() == 0:
                    continue
                keep = np.zeros_like(cur)
                for i in range(1, lbl.max()+1):
                    a = (lbl==i).sum()
                    if a >= int(1.2 * PP_MIN_AREA[cls_index]):
                        keep[lbl==i] = 1
                masks_map[id_cur][cls_index] = keep
    return masks_map

def infer_test_and_submit(device='cuda'):
    print('[INFER] Loading models...')
    models = []
    for fold in range(5):
        p = Path(f"model_fold{fold}.pt")
        if not p.exists():
            print(f"[INFER] Missing model {p}, skipping fold {fold}")
            continue
        m = build_model_b3(device=device)
        sd = torch.load(p, map_location=device)
        m.load_state_dict(sd, strict=True); m.eval()
        models.append(m)
    assert models, 'No trained models found'
    sub = pd.read_csv('test.csv')
    uniq_ids = sub['id'].unique().tolist()
    df_ids = pd.DataFrame({'id':uniq_ids})
    parsed = df_ids['id'].apply(parse_id)
    df_ids[['case','day','slice']] = pd.DataFrame(parsed.tolist(), index=df_ids.index)
    ds = UWGITractDataset(df_ids, train_df=None, roots=TEST_IMG_ROOTS, mode='test', aug=None)
    dl = DataLoader(ds, batch_size=4, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=False)
    id_info = {}  # id -> dict(probs, bbox, meta, orig_shape)
    print('[INFER] Predicting...')
    with torch.no_grad():
        t0 = time.time()
        for bi, batch in enumerate(dl):
            imgs, ids, bboxes, metas, orig_shapes = batch
            imgs = imgs.to(device)
            logits_sum = None
            for m in models:
                logits = tta_hflip_predict(m, imgs)
                logits_sum = logits if logits_sum is None else (logits_sum + logits)
            probs = torch.sigmoid(logits_sum / len(models)).float().cpu().numpy()
            assert probs.shape[0] == len(ids) == len(bboxes) == len(metas) == len(orig_shapes)
            for i, id_str in enumerate(ids):
                bb = tuple(int(x) for x in bboxes[i])
                me = tuple(int(x) for x in metas[i])
                osz = tuple(int(x) for x in orig_shapes[i])
                id_info[id_str] = {'probs': probs[i], 'bbox': bb, 'meta': me, 'orig_shape': osz}
            if (bi+1) % 25 == 0:
                print(f"[INFER] batch {bi+1}/{len(dl)} elapsed {(time.time()-t0):.1f}s")
    # z-smoothing per (case,day) before thresholding/post-proc
    id_info = _z_smooth_groups(id_info, window=3)
    # Build per-id masks with inverse mapping, threshold, and per-slice post-processing
    print('[INFER] Post-processing and z-consistency...')
    masks_map = {}  # id -> (3,H,W) uint8
    for id_str, info in id_info.items():
        m3 = []
        for ch in range(3):
            full_prob = inverse_unwarp_probs(info['probs'][ch], info['meta'], info['bbox'], info['orig_shape'])
            full_mask = (full_prob >= PP_THRESH[ch]).astype(np.uint8)
            full_pp = post_process_full(full_mask, ch)
            m3.append(full_pp.astype(np.uint8))
        masks_map[id_str] = np.stack(m3, axis=0)
    # z-consistency for bowels
    masks_map = _apply_z_consistency(masks_map)
    # Encode submission
    rows = []
    for _, r in sub.iterrows():
        id_str = r['id']; cls = r['class']
        if id_str not in masks_map:
            rows.append('')
            continue
        ch = CLASSES.index(cls)
        mm = masks_map[id_str][ch]
        rle = rle_encode(mm.astype(np.uint8)) if mm.sum()>0 else ''
        rows.append(rle)
    sub['segmentation'] = rows
    sub.to_csv('submission.csv', index=False)
    print('[INFER] Saved submission.csv')

print('[TRAIN/INFER SKELETON READY] Defaults set per expert advice. When images are mounted, call train_one_fold(f) per fold, then infer_test_and_submit().')
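`make_loaders` passes `target_pos_frac=0.62` to `build_pos_oversampler`, whose body is not shown here. A minimal sketch of one way such a sampler can hit a target positive fraction, assuming a slice counts as positive when any class mask is non-empty: every positive slice gets weight `target/n_pos` and every negative `(1 - target)/n_neg`, so a weighted draw lands on a positive with probability `target`. The helper name `pos_oversample_weights` is hypothetical.

```python
import numpy as np

def pos_oversample_weights(is_pos, target_pos_frac=0.62):
    """Per-sample draw probabilities so that positives make up target_pos_frac of draws.

    Hypothetical helper: is_pos[i] is True when slice i has any non-empty mask.
    """
    is_pos = np.asarray(is_pos, dtype=bool)
    n = is_pos.size
    n_pos = int(is_pos.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        # Nothing to rebalance: fall back to uniform sampling
        return np.full(n, 1.0 / n)
    # Positives share target_pos_frac of the probability mass, negatives share the rest
    return np.where(is_pos, target_pos_frac / n_pos, (1.0 - target_pos_frac) / n_neg)

# Usage: 20% of slices are positive, but about 62% of draws should be
is_pos = np.array([True] * 200 + [False] * 800)
w = pos_oversample_weights(is_pos, 0.62)
rng = np.random.default_rng(0)
draws = rng.choice(is_pos.size, size=100_000, replace=True, p=w)
print(round(float(w[is_pos].sum()), 6), round(float(is_pos[draws].mean()), 3))
```

In PyTorch the same weights can be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`; the sampler does not require the weights to sum to one, so only their ratio matters.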

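`post_process_full` (and `_pp_per_slice` in the tuning cell) measures each component with `(lbl==i).sum()`, one full-image pass per component. A hedged sketch of the same "keep at most k largest components above a minimum area" rule that gets all areas in a single `np.bincount` pass, assuming `scipy` is available. Because `skimage.measure.label` defaults to full (8-)connectivity in 2D while `scipy.ndimage.label` defaults to 4-connectivity, the sketch passes a 3x3 structuring element. `keep_topk_components` is a hypothetical name, and the morphology steps are omitted.

```python
import numpy as np
from scipy import ndimage

def keep_topk_components(mask, min_area, max_keep):
    """Keep at most max_keep largest connected components with area >= min_area."""
    # 8-connectivity, matching skimage.measure.label's 2D default
    structure = np.ones((3, 3), dtype=bool)
    lbl, n = ndimage.label(mask > 0, structure=structure)
    if n == 0:
        return np.zeros(mask.shape, dtype=np.uint8)
    areas = np.bincount(lbl.ravel())[1:]          # areas[k] = size of component k + 1
    order = np.argsort(-areas, kind="stable")      # largest first, ties in label order
    # Areas are sorted, so those passing min_area form a prefix: cap it at max_keep
    keep_labels = [k + 1 for k in order[:max_keep] if areas[k] >= min_area]
    return np.isin(lbl, keep_labels).astype(np.uint8)

# Usage: three separate blobs of area 9, 4 and 1
m = np.zeros((10, 10), dtype=np.uint8)
m[0:3, 0:3] = 1
m[5:7, 5:7] = 1
m[9, 0] = 1
print(keep_topk_components(m, min_area=2, max_keep=2).sum())  # 13
```

The per-class caps in the code above correspond to `max_keep=3`, `5` and `1` for large bowel, small bowel and stomach.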
In [None]:
# Fallback: create empty-mask submission (safety net; replace after real inference)
import pandas as pd
sub = pd.read_csv('test.csv').copy()
sub['segmentation'] = ''
sub.to_csv('submission.csv', index=False)
print('[FALLBACK] Wrote empty submission.csv with', len(sub), 'rows')
print(sub.head())

In [None]:
# Quick fix for albumentations/albucore mismatch
import sys, subprocess
def pip(*args):
    print("> pip", " ".join(args), flush=True)
    subprocess.run([sys.executable, "-m", "pip", *args], check=True)

# Albumentations 1.4.x requires a recent albucore; install a compatible version explicitly.
# (Superseded by the next cell, which pins albumentations==1.3.1 and removes albucore.)
pip("install", "-c", "constraints.txt", "albucore==0.0.13", "--upgrade-strategy", "only-if-needed")
import albumentations as A
print("albumentations:", A.__version__)
import albucore
from albucore import utils as _u
print("albucore:", getattr(albucore, "__version__", "?"))
print("has preserve_channel_dim:", hasattr(_u, "preserve_channel_dim"))
print("[ALBU READY]")

In [None]:
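# Hard fix: ensure albumentations==1.3.1 is active and no albucore remains.
# If albumentations was imported earlier in this kernel, restart the kernel after this cell:
# a re-import returns the cached module, so the version printed below may be stale.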
import sys, subprocess
def pip(*args):
    print("> pip", " ".join(args), flush=True)
    subprocess.run([sys.executable, "-m", "pip", *args], check=True)

subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "albumentations", "albucore"], check=False)
pip("install", "-c", "constraints.txt", "albumentations==1.3.1", "opencv-python-headless==4.10.0.84", "--upgrade-strategy", "only-if-needed")
import albumentations as A
print("albumentations version:", A.__version__)
print("Has HorizontalFlip:", hasattr(A, "HorizontalFlip"))
print("[ALBU FIXED]")
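Both install cells run `pip` inside a kernel that may already have imported the package; `import albumentations` then returns the cached module from `sys.modules`, so `A.__version__` can report the old version. A small sketch that compares the installed distribution version (read from package metadata with the standard-library `importlib.metadata`) with the loaded module's `__version__` to decide whether a restart is needed. Both helper names are hypothetical.

```python
import sys
from importlib import metadata

def installed_version(dist_name):
    """Version recorded in the installed distribution's metadata, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None

def needs_restart(dist_name, module_name):
    """True when a module is already imported but its __version__ differs from what is installed."""
    mod = sys.modules.get(module_name)
    if mod is None:
        return False  # not imported yet: the next import picks up the installed version
    loaded = getattr(mod, "__version__", None)
    installed = installed_version(dist_name)
    return loaded is not None and installed is not None and loaded != installed

print("albumentations installed:", installed_version("albumentations"))
print("restart needed:", needs_restart("albumentations", "albumentations"))
```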

In [None]:
# Poller: periodically scan for mounted PNG images and stop on first hit
import time
from pathlib import Path

def poll_for_images(interval_sec=90, max_minutes=45, max_show=10):
    start = time.time()
    deadline = start + max_minutes * 60.0
    attempt = 0
    candidates = [
        Path('.'), Path('./train'), Path('./test'),
        Path('/mnt'), Path('/data'), Path('/kaggle/input'), Path('/workspace')
    ]
    # Path.rglob() already prepends "**/", so the patterns are given relative to each root
    patterns = ("case*/day*/scans/slice_*.png", "case*/day*/slice_*.png")
    print(f"[POLL] Starting image poll: every {interval_sec}s for up to {max_minutes} min")
    while time.time() < deadline:
        attempt += 1
        found = []
        checked = []
        t0 = time.time()
        for b in candidates:
            if not b.exists():
                continue
            checked.append(str(b))
            for pat in patterns:
                try:
                    for p in b.rglob(pat):
                        found.append(str(p))
                        if len(found) >= max_show:
                            break
                except Exception as e:
                    print(f"[POLL] Error scanning {b} with {pat}: {e}")
                # Stop trying further patterns once enough samples are found
                if len(found) >= max_show:
                    break
            if len(found) >= max_show:
                break
        dt = time.time() - t0
        ts = time.strftime('%Y-%m-%d %H:%M:%S')
        if found:
            print(f"[POLL] {ts} attempt {attempt}: FOUND {len(found)} samples (scanned {len(checked)} roots in {dt:.1f}s)")
            for p in found[:max_show]:
                print('  ', p)
            print("[POLL] Images detected. Proceed to build_cache/train.")
            return found
        else:
            remaining = max(0, int(deadline - time.time()))
            print(f"[POLL] {ts} attempt {attempt}: none found (scanned {len(checked)} roots in {dt:.1f}s). Next check in {interval_sec}s. Time left ~{remaining//60}m{remaining%60:02d}s")
            time.sleep(interval_sec)
    print("[POLL] Timeout reached. No images detected.")
    return []

print('[POLL CELL READY] Call poll_for_images(interval_sec=90, max_minutes=45) to wait for data mount.')
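`Path.rglob(pattern)` behaves like `Path.glob('**/' + pattern)`, so multi-component patterns such as `case*/day*/scans/slice_*.png` match at any depth, including directly under the root. A self-contained sketch of the poller's search on a temporary directory tree covering both layouts; `find_slices` is a hypothetical helper.

```python
import tempfile
from pathlib import Path

PATTERNS = ("case*/day*/scans/slice_*.png", "case*/day*/slice_*.png")

def find_slices(root, limit=10):
    """Return up to `limit` slice PNG paths under root, searched at any depth."""
    found = []
    for pat in PATTERNS:
        for p in Path(root).rglob(pat):  # rglob adds the leading "**/" itself
            found.append(p)
            if len(found) >= limit:
                return found
    return found

with tempfile.TemporaryDirectory() as tmp:
    # One nested layout (with scans/) and one flat layout directly under the root
    a = Path(tmp, "mount", "case1", "day2", "scans"); a.mkdir(parents=True)
    (a / "slice_0001.png").touch()
    b = Path(tmp, "case3", "day0"); b.mkdir(parents=True)
    (b / "slice_0002.png").touch()
    hits = sorted(p.name for p in find_slices(tmp))
    print(hits)  # ['slice_0001.png', 'slice_0002.png']
```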

In [None]:
# Start polling for data mounts (blocks until images are found or the timeout expires)
found_samples = poll_for_images(interval_sec=90, max_minutes=45)
print('[POLL RESULT] Found samples:', len(found_samples))

In [None]:
# OOF tuning utilities: per-class threshold/min-area grid search with HD-aware proxy and parity PP
import json, numpy as np, pandas as pd, cv2
from pathlib import Path
from scipy.ndimage import binary_fill_holes

def z_smooth_probs(ids, probs_list, window=3):
    # Moving average along slice order within each (case, day) group, sorted by slice
    from collections import defaultdict
    smoothed = [None]*len(ids)
    by_group = defaultdict(list)
    for i, id_str in enumerate(ids):
        c, d, s = parse_id(id_str)
        by_group[(c, d)].append((s, i))
    k = window
    for (c,d), lst in by_group.items():
        lst.sort(key=lambda x: x[0])
        idxs_sorted = [i for _, i in lst]
        P = np.stack([probs_list[ii] for ii in idxs_sorted], axis=0)
        if len(idxs_sorted) >= 2 and k >= 3:
            P_pad = np.pad(P, ((k//2, k//2), (0,0), (0,0), (0,0)), mode='edge')
            P_ma = np.zeros_like(P)
            for t in range(len(idxs_sorted)):
                P_ma[t] = P_pad[t:t+k].mean(axis=0)
            for j, ii in enumerate(idxs_sorted):
                smoothed[ii] = P_ma[j]
        else:
            for j, ii in enumerate(idxs_sorted):
                smoothed[ii] = P[j]
    return smoothed

def load_all_oof():
    ids_all, probs_all, bbox_all, meta_all, orig_all = [], [], [], [], []
    for f in range(5):
        p_ids = Path("oof_fold{f}_ids.npy".format(f=f))
        p_probs = Path("oof_fold{f}_probs_sq.npy".format(f=f))
        p_meta = Path("oof_fold{f}_meta.npz".format(f=f))
        if not (p_ids.exists() and p_probs.exists() and p_meta.exists()):
            continue
        ids = np.load(p_ids, allow_pickle=True).tolist()
        probs = np.load(p_probs)
        meta = np.load(p_meta)
        ids_all += ids
        probs_all.append(probs)
        bbox_all.append(meta['bbox'])
        meta_all.append(meta['meta'])
        orig_all.append(meta['orig_shape'])
    if not probs_all:
        raise FileNotFoundError('No OOF artifacts found')
    probs_all = np.concatenate(probs_all, axis=0)
    bbox_all = np.concatenate(bbox_all, axis=0)
    meta_all = np.concatenate(meta_all, axis=0)
    orig_all = np.concatenate(orig_all, axis=0)
    return ids_all, probs_all, bbox_all, meta_all, orig_all

def _pp_per_slice(mask, cls_index, min_area, caps=(3,5,1)):
    # mask: HxW uint8, returns post-processed uint8 with class caps and morphology
    from skimage.measure import label
    lbl = label(mask)
    if lbl.max() == 0:
        return mask.astype(np.uint8)
    areas = [(i, (lbl==i).sum()) for i in range(1, lbl.max()+1)]
    areas.sort(key=lambda x: x[1], reverse=True)
    keep = np.zeros_like(mask, dtype=np.uint8)
    kept = 0
    cap = caps[cls_index]
    for i, a in areas:
        if a >= min_area[cls_index]:
            keep[lbl==i] = 1
            kept += 1
            if kept >= cap:
                break
    if cls_index == 2:
        ker = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
        keep = cv2.morphologyEx(keep, cv2.MORPH_CLOSE, ker, iterations=1)
        keep = binary_fill_holes(keep.astype(bool)).astype(np.uint8)
    if cls_index == 1:
        ker = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
        keep = cv2.morphologyEx(keep, cv2.MORPH_OPEN, ker, iterations=1)
    return keep.astype(np.uint8)

def _apply_z_consistency_local(masks_map, min_area):
    # masks_map: id -> (3,H,W) uint8; drop isolated 1-slice CCs for bowels if <1.2*min_area with no ±1 support
    from collections import defaultdict
    from skimage.measure import label
    groups = defaultdict(list)
    for id_str in masks_map.keys():
        c, d, s = parse_id(id_str)
        groups[(c,d)].append((s, id_str))
    for (c,d), lst in groups.items():
        lst.sort(key=lambda x: x[0])
        ids_sorted = [k for _, k in lst]
        T = len(ids_sorted)
        for cls_index in [0,1]:
            for t, id_cur in enumerate(ids_sorted):
                cur = masks_map[id_cur][cls_index].copy()
                if cur.sum() == 0:
                    continue
                prev = masks_map[ids_sorted[t-1]][cls_index] if (t-1) >= 0 else None
                nxt = masks_map[ids_sorted[t+1]][cls_index] if (t+1) < T else None
                support = ((prev is not None and prev.any()) or (nxt is not None and nxt.any()))
                if support:
                    continue
                lbl = label(cur)
                if lbl.max() == 0:
                    continue
                keep = np.zeros_like(cur)
                thr = int(1.2 * min_area[cls_index])
                for i in range(1, lbl.max()+1):
                    a = (lbl==i).sum()
                    if a >= thr:
                        keep[lbl==i] = 1
                masks_map[id_cur][cls_index] = keep
    return masks_map

def oof_proxy_score(thr, min_area, ids, probs_sq, bbox, meta, orig_shape, classes=('large_bowel','small_bowel','stomach')):
    # Build per-id masks with inverse mapping and PP parity, apply z-consistency, then score with Dice+HD95 proxy
    masks_map = {}  # id -> (3,H,W) uint8
    for i, id_str in enumerate(ids):
        mpp = []
        for ci, cls in enumerate(classes):
            prob_sq = probs_sq[i, ci]
            full_prob = inverse_unwarp_probs(prob_sq, meta[i], bbox[i], orig_shape[i])
            pred = (full_prob >= thr[ci]).astype(np.uint8)
            pp = _pp_per_slice(pred, ci, min_area)
            mpp.append(pp.astype(np.uint8))
        masks_map[id_str] = np.stack(mpp, axis=0)
    masks_map = _apply_z_consistency_local(masks_map, min_area)
    per_example = []
    for i, id_str in enumerate(ids):
        sub = train_df[train_df['id']==id_str]
        H0, W0 = orig_shape[i]
        for ci, cls in enumerate(classes):
            predm = masks_map[id_str][ci]
            r = sub[sub['class']==cls].iloc[0]
            tgt = rle_decode(r['segmentation'] if isinstance(r['segmentation'], str) else '', (H0, W0)).astype(np.uint8)
            inter = (predm & tgt).sum()
            dice = (2*inter + 1e-6)/ (predm.sum() + tgt.sum() + 1e-6)
            hd = hd95_proxy(predm, tgt)
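            # Proxy weighting: 0.6*Dice + 0.4*(1 - HD/100, clipped). The official metric weights Dice 0.4 and 3D Hausdorff 0.6.
            score = 0.6 * dice + 0.4 * (1 - min(hd/100.0, 1.0))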
            per_example.append(score)
    return float(np.mean(per_example)) if per_example else 0.0

def grid_tune_oof(z_window=3):
    # Pruned Stage-1 grid per expert advice, then Stage-2 refine
    ids, probs, bbox, meta, orig = load_all_oof()
    probs_list = [probs[i] for i in range(len(ids))]
    probs_sm = z_smooth_probs(ids, probs_list, window=z_window)
    probs_sm = np.stack(probs_sm, axis=0)
    # Stage 1 (pruned): thresholds and min_area candidates
    thr_candidates = [
        [0.45, 0.50, 0.55],  # large bowel
        [0.45, 0.50, 0.55],  # small bowel
        [0.45, 0.50, 0.55],  # stomach
    ]
    area_candidates = [
        [1000, 1400, 1800],  # large
        [800, 1000, 1200],   # small
        [700, 900],          # stomach
    ]
    best = {'score': -1, 'thr': None, 'min_area': None}
    for t0 in thr_candidates[0]:
        for t1 in thr_candidates[1]:
            for t2 in thr_candidates[2]:
                thr = [float(t0), float(t1), float(t2)]
                for a0 in area_candidates[0]:
                    for a1 in area_candidates[1]:
                        for a2 in area_candidates[2]:
                            mins = [int(a0), int(a1), int(a2)]
                            sc = oof_proxy_score(thr, mins, ids, probs_sm, bbox, meta, orig)
                            if sc > best['score'] + 1e-6:
                                best = {'score': float(sc), 'thr': thr, 'min_area': mins}
    print('[TUNE][Stage1] Best:', best)
    # Stage 2 refine around best
    bt = best['thr']; ba = best['min_area']
    thr_ref = []
    for x in bt:
        lo = max(0.0, x - 0.03); hi = min(1.0, x + 0.03)
        thr_ref.append(np.round(np.arange(lo, hi+1e-9, 0.01), 2))
    area_ref = []
    for i, a in enumerate(ba):
        lo = max(0, a - 200); hi = a + 200
        area_ref.append(np.arange(lo, hi+1e-9, 100).astype(int))
    best2 = dict(best)
    for t0 in thr_ref[0]:
        for t1 in thr_ref[1]:
            for t2 in thr_ref[2]:
                thr = [float(t0), float(t1), float(t2)]
                for a0 in area_ref[0]:
                    for a1 in area_ref[1]:
                        for a2 in area_ref[2]:
                            mins = [int(a0), int(a1), int(a2)]
                            sc = oof_proxy_score(thr, mins, ids, probs_sm, bbox, meta, orig)
                            if sc > best2['score'] + 1e-6:
                                best2 = {'score': float(sc), 'thr': thr, 'min_area': mins}
    Path('tuned_pp.json').write_text(json.dumps(best2, indent=2))
    print('[TUNE][Stage2] Best:', best2)
    return best2

print('[OOF TUNING UTILS READY] Parity with inference: stomach close+fill, small-bowel opening, z-smoothing(sorted)=3, z-consistency(edge-safe). Pruned Stage-1 + Stage-2 refine enabled.')
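`oof_proxy_score` relies on `hd95_proxy`, which is defined elsewhere in the notebook. For reference, a minimal 2D sketch of a 95th-percentile symmetric surface distance using `scipy.ndimage`, assuming masks in pixel units: surface pixels are those removed by one binary erosion, and `distance_transform_edt` gives every pixel's distance to the nearest surface pixel of the other mask. The empty-mask conventions (0 when both are empty, the image diagonal when only one is) are assumptions, and this is a per-slice proxy, not the competition's 3D Hausdorff computation.

```python
import numpy as np
from scipy.ndimage import binary_erosion, distance_transform_edt

def hd95_2d(pred, gt):
    """95th-percentile symmetric surface distance between two binary 2D masks (pixels)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    if not pred.any() and not gt.any():
        return 0.0                                  # assumed: both empty is a perfect match
    if not pred.any() or not gt.any():
        return float(np.hypot(*pred.shape))         # assumed: worst case = image diagonal
    pred_surf = pred & ~binary_erosion(pred)
    gt_surf = gt & ~binary_erosion(gt)
    # Distance from every pixel to the nearest surface pixel of each mask
    dist_to_gt = distance_transform_edt(~gt_surf)
    dist_to_pred = distance_transform_edt(~pred_surf)
    d = np.concatenate([dist_to_gt[pred_surf], dist_to_pred[gt_surf]])
    return float(np.percentile(d, 95))

a = np.zeros((32, 32), dtype=np.uint8); a[10:20, 10:20] = 1
b = np.zeros_like(a); b[10:20, 13:23] = 1           # same square shifted 3 px right
print(hd95_2d(a, a), hd95_2d(a, b))
```

With the proxy's normalization `1 - min(hd/100, 1)`, a 3-pixel boundary offset like the one above costs at most 3% of the HD term.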

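In `oof_proxy_score`, each class's masks depend only on that class's threshold and minimum area (`_pp_per_slice` and `_apply_z_consistency_local` read `min_area[cls_index]` and touch only that channel), and every id contributes exactly one score per class. The mean score is therefore the average of three per-class means, so tuning each class independently reaches the same optimum as the joint grid (up to ties). That would cut Stage 1 from 27 x 18 = 486 full evaluations to 9 + 9 + 6 = 24 per-class ones, and Stage 2 from roughly 343 x 125 = 42,875 to about 3 x 35 = 105. The toy sketch below checks the equivalence with a made-up score; `class_score` stands in for a hypothetical per-class variant of `oof_proxy_score`.

```python
from itertools import product

def joint_grid(total_score, grids):
    """Exhaustive search over the Cartesian product of per-class grids."""
    best_combo, best_score = None, float("-inf")
    for combo in product(*grids):
        s = total_score(combo)
        if s > best_score:
            best_combo, best_score = combo, s
    return best_combo

def separable_grid(class_score, grids):
    """Tune each class on its own; exact when the total is a mean of per-class terms."""
    return tuple(max(grid, key=lambda p, ci=ci: class_score(ci, p)) for ci, grid in enumerate(grids))

# Toy per-class score: each class prefers one (threshold, min_area) pair
targets = [(0.50, 1400), (0.45, 1000), (0.55, 900)]
def class_score(ci, p):
    t, a = targets[ci]
    return -((p[0] - t) ** 2 + ((p[1] - a) / 1000.0) ** 2)

def total_score(combo):
    return sum(class_score(ci, p) for ci, p in enumerate(combo)) / len(combo)

thr = [0.45, 0.50, 0.55]
grids = [
    list(product(thr, [1000, 1400, 1800])),  # large bowel
    list(product(thr, [800, 1000, 1200])),   # small bowel
    list(product(thr, [700, 900])),          # stomach
]
print(separable_grid(class_score, grids) == joint_grid(total_score, grids))  # True
```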
In [None]:
# Synthetic smoke test (optional while waiting for real data mounts)
import numpy as np, cv2, pandas as pd, torch, os, shutil, math, time, re
from pathlib import Path
import torch.nn as nn

# Ensure CLASSES is defined if prior cells weren't executed in this kernel
try:
    CLASSES
except NameError:
    CLASSES = ['large_bowel','small_bowel','stomach']

# Provide a local resolve_path fallback if not defined (for synthetic data only)
if 'resolve_path' not in globals():
    def resolve_path(id_str, roots):
        m = re.match(r'^case(\d+)_day(\d+)_slice_(\d+)$', id_str)
        if not m:
            raise FileNotFoundError(id_str)
        case, day, sl = int(m.group(1)), int(m.group(2)), int(m.group(3))
        roots = roots or [Path('train_syn')]
        for r in roots:
            p = Path(r) / f'case{case}' / f'day{day}' / 'scans' / f'slice_{sl:04d}.png'
            if p.exists():
                return p
        # return canonical path under first root even if missing (upstream will handle)
        return Path(roots[0]) / f'case{case}' / f'day{day}' / 'scans' / f'slice_{sl:04d}.png'

def make_syn_blob(H=512, W=512, center=None, radius=60):
    y,x = np.ogrid[:H, :W]
    if center is None:
        cy, cx = H//2 + np.random.randint(-30,30), W//2 + np.random.randint(-30,30)
    else:
        cy, cx = center
    r2 = (y-cy)**2 + (x-cx)**2
    return (r2 <= radius*radius).astype(np.uint8)

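def mask_to_rle_fortran(mask):
    # Column-major (Fortran-order) RLE with 1-based run starts; must use the same pixel order as rle_decode
    pixels = mask.T.flatten()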
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

# Local ID parser to avoid dependency on earlier cells
def _parse_id_local(s):
    m = re.match(r'^case(\d+)_day(\d+)_slice_(\d+)$', s)
    if not m:
        return (0,0,0)
    return (int(m.group(1)), int(m.group(2)), int(m.group(3)))

def build_synthetic_dataset(root='train_syn', n_cases=1, n_slices=8, H=512, W=512, classes=None):
    if classes is None:
        classes = ['large_bowel','small_bowel','stomach']
    root = Path(root);
    if root.exists():
        shutil.rmtree(root)
    ids = []
    rows = []
    for case in range(900, 900+n_cases):
        day = 0
        for s in range(1, n_slices+1):
            id_str = f"case{case}_day{day}_slice_{s:04d}"
            ids.append(id_str)
            d = root / f"case{case}" / f"day{day}" / "scans"
            d.mkdir(parents=True, exist_ok=True)
            img = (np.random.rand(H,W)*60000).astype(np.uint16)
            # Add brighter foreground ellipse to simulate body
            body = make_syn_blob(H,W, radius=min(H,W)//2 - 40).astype(bool)
            img[~body] = (img[~body]*0.05).astype(np.uint16)
            cv2.imwrite(str(d / f"slice_{s:04d}.png"), img)
            # simple masks (only some slices positive)
            for cls in classes:
                if (s % 3 == 0) and cls in ('stomach','large_bowel'):
                    mask = make_syn_blob(H,W, radius=40 if cls=='stomach' else 55)
                    rle = mask_to_rle_fortran(mask)
                else:
                    rle = ''
                rows.append({'id': id_str, 'class': cls, 'segmentation': rle})
    train_df_syn = pd.DataFrame(rows)
    df_ids = pd.DataFrame({'id': ids})
    parsed = df_ids['id'].apply(_parse_id_local)
    df_ids[['case','day','slice']] = pd.DataFrame(parsed.tolist(), index=df_ids.index)
    return train_df_syn, df_ids, Path(root)

class TinySegNet(nn.Module):
    def __init__(self, in_ch=5, num_classes=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, kernel_size=1)
        )
    def forward(self, x):
        return self.net(x)

def smoke_test_pipeline():
    print('[SMOKE] Building synthetic dataset...')
    train_df_syn, df_ids_syn, root = build_synthetic_dataset()
    print('[SMOKE] Creating Datasets...')
    ds_tr = UWGITractDataset(df_ids_syn.iloc[:6], train_df=train_df_syn, roots=[root], mode='train', aug=get_valid_aug())
    ds_te = UWGITractDataset(df_ids_syn.iloc[:6], train_df=None, roots=[root], mode='test', aug=None)
    x, y, _ = ds_tr[0]
    print('[SMOKE] Train sample img5/mask3 shapes:', tuple(x.shape), tuple(y.shape))
    xt, _id, bbox, meta, orig = ds_te[0]
    print('[SMOKE] Test meta bbox/meta/orig:', bbox, meta, orig)
    # Model forward using a tiny local CNN (no SMP) on CPU to avoid instability
    print('[SMOKE] Model forward...')
    device = 'cpu'
    model = TinySegNet().to(device)
    with torch.no_grad():
        xb = torch.stack([x, x], dim=0).to(device)
        out = model(xb)
    print('[SMOKE] Logits shape:', tuple(out.shape))
    # Loss eval
    loss_fn = ComboLoss()
    loss = loss_fn(out.cpu(), torch.stack([y, y], dim=0).float())
    print('[SMOKE] Loss OK:', float(loss))
    # Inverse unwarp sanity
    probs = torch.sigmoid(out[:1]).cpu().numpy()[0]
    full_prob0 = inverse_unwarp_probs(probs[2], meta, bbox, orig)  # stomach channel
    print('[SMOKE] Inverse unwarp prob shape:', full_prob0.shape, 'range', (float(full_prob0.min()), float(full_prob0.max())))
    print('[SMOKE DONE]')

print('[SMOKE CELL READY] Call smoke_test_pipeline() to validate end-to-end components without real data.')
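
`mask_to_rle_fortran` is defined in an earlier cell and is not shown here. The sketch below can serve as a sanity check. It implements a column-major (Fortran-order) RLE with 1-indexed start positions, plus its inverse. Both the ordering and the 1-indexing are assumptions: they must match the real encoder and the competition's decoder. A round trip on a random mask is a cheap way to catch an order mismatch before it silently corrupts a submission. The names `rle_encode_sketch` and `rle_decode_sketch` are hypothetical.

```python
import numpy as np

def rle_encode_sketch(mask, order='F'):
    """Encode a binary HxW mask as 'start length ...' with 1-indexed starts (assumed convention)."""
    flat = np.asarray(mask).astype(np.uint8).flatten(order=order)
    # Pad with zeros so every run has a detectable rising and falling edge
    padded = np.concatenate([[0], flat, [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1  # 1-indexed edge positions
    starts, ends = edges[::2], edges[1::2]
    return ' '.join(f'{s} {e - s}' for s, e in zip(starts, ends))

def rle_decode_sketch(rle, shape, order='F'):
    """Inverse of rle_encode_sketch; an empty string decodes to an all-zero mask."""
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    if rle:
        nums = np.array(rle.split(), dtype=np.int64)
        for start, length in zip(nums[0::2], nums[1::2]):
            flat[start - 1:start - 1 + length] = 1
    return flat.reshape(shape, order=order)
```

In Fortran order, a 2x2 mask whose left column is foreground encodes as `'1 2'`. In C order, the same mask encodes as `'1 1 3 1'`. That difference is what the round trip is meant to expose.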

In [None]:
# Run synthetic smoke test while awaiting data mounts
import re
from pathlib import Path

# Guard: ensure parse_id exists for older function defs
if 'parse_id' not in globals():
    _id_pat = re.compile(r'^case(\d+)_day(\d+)_slice_(\d+)$')
    def parse_id(s):
        m = _id_pat.match(s)
        if not m:
            return (0,0,0)
        return tuple(int(x) for x in m.groups())

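# Unconditionally (re)define a simple resolve_path fallback for the smoke test so UWGITractDataset can find synthetic PNGs.
# NOTE: this overwrites any earlier resolve_path in globals; re-run the real resolver cell before working with official data.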
def resolve_path(id_str, roots):
    m = re.match(r'^case(\d+)_day(\d+)_slice_(\d+)$', id_str)
    if not m:
        raise FileNotFoundError(id_str)
    case, day, sl = int(m.group(1)), int(m.group(2)), int(m.group(3))
    roots = roots or [Path('train_syn')]
    for r in roots:
        p = Path(r) / f'case{case}' / f'day{day}' / 'scans' / f'slice_{sl:04d}.png'
        if p.exists():
            return p
    return Path(roots[0]) / f'case{case}' / f'day{day}' / 'scans' / f'slice_{sl:04d}.png'

smoke_test_pipeline()

In [None]:
# Poll only official roots (exclude synthetic) for mounted PNG images
import time
from pathlib import Path

def poll_for_official_images(interval_sec=120, max_minutes=60, max_show=10):
    start = time.time()
    deadline = start + max_minutes * 60.0
    attempt = 0
    candidates = [Path('./train'), Path('./test'), Path('/mnt'), Path('/data')]
    patterns = ("**/case*/day*/scans/slice_*.png", "**/case*/day*/slice_*.png")
    print(f"[POLL-OFF] Starting official image poll: every {interval_sec}s for up to {max_minutes} min")
    while time.time() < deadline:
        attempt += 1
        found = []
        checked = []
        t0 = time.time()
        for b in candidates:
            if not b.exists():
                continue
            checked.append(str(b))
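            for pat in patterns:
                # Stop before starting another pattern once enough samples were collected
                if len(found) >= max_show:
                    break
                try:
                    for p in b.rglob(pat):
                        sp = str(p)
                        if 'train_syn' in sp:
                            continue  # skip synthetic smoke-test data
                        found.append(sp)
                        if len(found) >= max_show:
                            break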
                except Exception as e:
                    print(f"[POLL-OFF] Error scanning {b} with {pat}: {e}")
            if len(found) >= max_show:
                break
        dt = time.time() - t0
        ts = time.strftime('%Y-%m-%d %H:%M:%S')
        if found:
            print(f"[POLL-OFF] {ts} attempt {attempt}: FOUND {len(found)} samples (scanned {len(checked)} roots in {dt:.1f}s)")
            for p in found[:max_show]:
                print('  ', p)
            print("[POLL-OFF] Official images detected. Proceed to build_cache/train.")
            return found
        else:
            remaining = max(0, int(deadline - time.time()))
            print(f"[POLL-OFF] {ts} attempt {attempt}: none found (scanned {len(checked)} roots in {dt:.1f}s). Next check in {interval_sec}s. Time left ~{remaining//60}m{remaining%60:02d}s")
            time.sleep(interval_sec)
    print("[POLL-OFF] Timeout reached. No official images detected.")
    return []

print('[POLL-OFF CELL READY] Running official mount poller...')
found_official = poll_for_official_images(interval_sec=120, max_minutes=60)
print('[POLL-OFF RESULT] Found samples:', len(found_official))
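
The poller only needs a handful of hits, but its nested `for` loops make an exact cap awkward. One alternative is sketched below under the hypothetical name `first_n_matches`. It chains the lazy `Path.rglob` generators and lets `itertools.islice` stop the search after `n` results. Along the way it skips synthetic paths and de-duplicates hits in case the patterns overlap.

```python
import itertools
from pathlib import Path

def first_n_matches(roots, patterns, n, exclude=('train_syn',)):
    """Return up to n distinct matching paths (as str) across roots and patterns."""
    def _gen():
        seen = set()
        for root in map(Path, roots):
            if not root.exists():
                continue
            for pat in patterns:
                for p in root.rglob(pat):
                    s = str(p)
                    # Skip synthetic smoke-test data and paths already yielded by another pattern
                    if s in seen or any(tok in s for tok in exclude):
                        continue
                    seen.add(s)
                    yield s
    # islice stops consuming the generator after n items, so the search ends early
    return list(itertools.islice(_gen(), n))
```

Inside the poll loop, `found = first_n_matches(candidates, patterns, max_show)` would replace the whole per-root scan.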

In [None]:
# Improved official-only poller (expanded roots per expert advice)
import time
from pathlib import Path

def poll_for_official_images_v2(interval_sec=120, max_minutes=60, max_show=12):
    start = time.time()
    deadline = start + max_minutes * 60.0
    attempt = 0
    candidates = [
        Path('./train'), Path('./test'),
        Path('/mnt'), Path('/data'), Path('/kaggle/input'),
        Path('/opt/data'), Path('/app/data'), Path('/datasets'), Path('/workspace/data')
    ]
    patterns = ("**/case*/day*/scans/slice_*.png", "**/case*/day*/slice_*.png")
    print(f"[POLL-OFF V2] Starting official image poll: every {interval_sec}s for up to {max_minutes} min")
    while time.time() < deadline:
        attempt += 1
        found = []
        checked = []
        t0 = time.time()
        for b in candidates:
            if not b.exists():
                continue
            checked.append(str(b))
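            for pat in patterns:
                # Stop before starting another pattern once enough samples were collected
                if len(found) >= max_show:
                    break
                try:
                    for p in b.rglob(pat):
                        sp = str(p)
                        if 'train_syn' in sp:
                            continue  # skip synthetic smoke-test data
                        found.append(sp)
                        if len(found) >= max_show:
                            break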
                except Exception as e:
                    print(f"[POLL-OFF V2] Error scanning {b} with {pat}: {e}")
            if len(found) >= max_show:
                break
        dt = time.time() - t0
        ts = time.strftime('%Y-%m-%d %H:%M:%S')
        if found:
            print(f"[POLL-OFF V2] {ts} attempt {attempt}: FOUND {len(found)} samples (scanned {len(checked)} roots in {dt:.1f}s)")
            for p in found[:max_show]:
                print('  ', p)
            print("[POLL-OFF V2] Official images detected. Proceed to build_cache/train.")
            return found
        else:
            remaining = max(0, int(deadline - time.time()))
            print(f"[POLL-OFF V2] {ts} attempt {attempt}: none found (scanned {len(checked)} roots in {dt:.1f}s). Next check in {interval_sec}s. Time left ~{remaining//60}m{remaining%60:02d}s")
            time.sleep(interval_sec)
    print("[POLL-OFF V2] Timeout reached. No official images detected.")
    return []

print('[POLL-OFF V2 CELL READY] When ready, interrupt Cell 15 and run: found_official_v2 = poll_for_official_images_v2(interval_sec=120, max_minutes=60)')

In [None]:
# Start improved official-only poller
found_official_v2 = poll_for_official_images_v2(interval_sec=120, max_minutes=60)
print('[POLL-OFF V2 RESULT] Found samples:', len(found_official_v2))

In [2]:
# Orchestration helpers: cache -> train -> tune -> infer
import pandas as pd
from pathlib import Path
import json, time, gc

def build_train_test_cache(train_out='cache/train', test_out='cache/test', log_every=200):
    print('[ORCH] Building train cache...')
    df_ids_tr = (train_df.drop_duplicates('id')[['id','case','day','slice']].reset_index(drop=True))
    t0 = time.time()
    build_cache(df_ids_tr, train_df=train_df, roots=TRAIN_IMG_ROOTS, out_dir=train_out, mode='train', log_every=log_every)
    print(f"[ORCH] Train cache done in {(time.time()-t0)/60:.1f}m")
    print('[ORCH] Building test cache...')
    df_ids_te = (test_df.drop_duplicates('id')[['id','case','day','slice']].reset_index(drop=True))
    t1 = time.time()
    build_cache(df_ids_te, train_df=None, roots=TEST_IMG_ROOTS, out_dir=test_out, mode='test', log_every=log_every)
    print(f"[ORCH] Test cache done in {(time.time()-t1)/60:.1f}m")
    gc.collect()

def train_all_folds(epochs=40, batch_size=10, device='cuda'):
    for f in range(5):
        print('='*40); print(f'[ORCH] Training fold {f}'); print('='*40)
        train_one_fold(f, epochs=epochs, batch_size=batch_size, device=device)
        gc.collect()

def tune_pp_and_save(z_window=3):
    print('[ORCH] Running OOF tuning...')
    best = grid_tune_oof(z_window=z_window)
    Path('tuned_pp.json').write_text(json.dumps(best, indent=2))
    print('[ORCH] Saved tuned_pp.json:', best)
    return best

def full_infer():
    print('[ORCH] Inference to submission.csv...')
    infer_test_and_submit()
    print('[ORCH] submission.csv written')

print('[ORCH READY] When mounts appear: 1) interrupt poller, 2) run build_train_test_cache(), 3) train_all_folds(), 4) tune_pp_and_save(), 5) full_infer().')

[ORCH READY] When mounts appear: 1) interrupt poller, 2) run build_train_test_cache(), 3) train_all_folds(), 4) tune_pp_and_save(), 5) full_infer().
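
`grid_tune_oof` is defined elsewhere and its search space is not shown. The NumPy-only sketch below illustrates one form per-class post-processing tuning on OOF predictions can take. It grid-searches a probability threshold and a minimum predicted area for each class. Any slice whose total predicted area falls below `min_area` is zeroed, which mainly helps on empty slices. The sketch makes two assumptions:

- OOF probabilities and masks are stacked as `(N, C, H, W)` arrays.
- A slice where both prediction and target are empty scores Dice 1.

`tune_thresholds_sketch`, `dice_per_slice` and the default grids are hypothetical.

```python
import itertools
import numpy as np

def dice_per_slice(pred, target, eps=1e-7):
    """Dice per slice for boolean (N, H, W) arrays; empty-vs-empty counts as 1 (assumed convention)."""
    inter = (pred & target).sum(axis=(1, 2))
    total = pred.sum(axis=(1, 2)) + target.sum(axis=(1, 2))
    return np.where(total == 0, 1.0, 2.0 * inter / (total + eps))

def tune_thresholds_sketch(probs, masks, thresholds=(0.3, 0.4, 0.5, 0.6), min_areas=(0, 50, 100, 200)):
    """Grid-search (threshold, min_area) per class to maximize mean slice Dice on OOF data."""
    best = {}
    targets = masks.astype(bool)
    for c in range(probs.shape[1]):
        scores = {}
        for thr, min_area in itertools.product(thresholds, min_areas):
            pred = probs[:, c] > thr  # new boolean array, safe to modify
            # Zero out whole slices whose predicted area is below min_area
            small = pred.sum(axis=(1, 2)) < min_area
            pred[small] = False
            scores[(thr, min_area)] = dice_per_slice(pred, targets[:, c]).mean()
        (thr, min_area), score = max(scores.items(), key=lambda kv: kv[1])
        best[c] = {'threshold': thr, 'min_area': min_area, 'dice': float(score)}
    return best
```

The result is a plain dict, so it can be written with `json.dumps` in the same way `tune_pp_and_save` writes `tuned_pp.json`.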


In [None]:
# Programmatic data fetch via Kaggle API (optional; requires kaggle.json credentials)
import os, sys, subprocess, shutil, json, time, glob
from pathlib import Path

def _run(cmd):
    print('> ', ' '.join(cmd), flush=True)
    return subprocess.run(cmd, check=False, capture_output=True, text=True)

def try_kaggle_download():
    # Check credentials
    kaggle_json = Path.home()/'.kaggle'/'kaggle.json'
    if not kaggle_json.exists():
        print('[KAGGLE] ~/.kaggle/kaggle.json not found. Skipping Kaggle API download.')
        print('[KAGGLE] If available, place kaggle.json and chmod 600, then re-run this cell.')
        return False
    kaggle_json.chmod(0o600)
    # Ensure kaggle package
    _run([sys.executable, '-m', 'pip', 'install', 'kaggle', '--upgrade', '--quiet'])
    dl_root = Path('/kaggle/working') if Path('/kaggle').exists() else Path('kaggledl')
    dl_root.mkdir(parents=True, exist_ok=True)
    print('[KAGGLE] Download root:', dl_root)
    # Preferred 384x384 PNG mirror preserving structure
    ds_slug = 'andrewmvd/uw-madison-gi-tract-image-segmentation-2d'
    print('[KAGGLE] Downloading dataset:', ds_slug)
    res = _run(['kaggle', 'datasets', 'download', '-d', ds_slug, '-p', str(dl_root), '--unzip'])
    if res.returncode != 0:
        print('[KAGGLE] Download failed:', res.stderr.strip())
        return False
    # Detect train/test dirs within download (also accept train_png/test_png names).
    # next(...) stops at the first matching PNG instead of materializing every match.
    def _has_slices(d):
        return d.is_dir() and next(d.rglob('case*/day*/scans/slice_*.png'), None) is not None
    train_cands = sorted({p for name in ('train', 'train_png') for p in dl_root.rglob(name) if _has_slices(p)})
    test_cands = sorted({p for name in ('test', 'test_png') for p in dl_root.rglob(name) if _has_slices(p)})
    if not train_cands or not test_cands:
        print('[KAGGLE] Could not find train/test directories after unzip.')
        return False
    # Prepend to resolver roots
    tr0, te0 = train_cands[0], test_cands[0]
    print('[KAGGLE] Using roots:', tr0, te0)
    if 'TRAIN_IMG_ROOTS' in globals():
        TRAIN_IMG_ROOTS.insert(0, tr0)
    if 'TEST_IMG_ROOTS' in globals():
        TEST_IMG_ROOTS.insert(0, te0)
    # Quick sanity: count PNGs
    def _count_pngs(root):
        try:
            return sum(1 for _ in root.rglob('case*/day*/scans/slice_*.png'))
        except Exception:
            return 0
    n_tr = _count_pngs(tr0)
    n_te = _count_pngs(te0)
    print(f'[KAGGLE] train PNGs: {n_tr}, test PNGs: {n_te}')
    # Spot read a few files
    samples = list(tr0.rglob('case*/day*/scans/slice_*.png'))[:3]
    print('[KAGGLE] sample files:')
    for s in samples:
        print(' ', s)
    print('[KAGGLE] Download and path injection complete.')
    return True

ok = try_kaggle_download()
print('[KAGGLE DONE] success=', ok)

In [None]:
# Auto-extract archives (zip/tgz) if present and inject roots
import os, sys, tarfile, zipfile, shutil
from pathlib import Path

def safe_extract_zip(zp, dest):
    # zipfile.extractall already strips absolute paths and '..' components from member names
    with zipfile.ZipFile(zp, 'r') as zf:
        zf.extractall(dest)

def safe_extract_tar(tp, dest):
    # 'r:*' lets tarfile auto-detect gzip/bz2/xz or uncompressed archives
    with tarfile.open(tp, 'r:*') as tf:
        abs_dest = os.path.abspath(dest)
        safe_members = []
        for m in tf.getmembers():
            target = os.path.abspath(os.path.join(dest, m.name))
            # Skip members that would land outside dest (absolute paths or '..' traversal)
            if os.path.commonpath([abs_dest, target]) != abs_dest:
                print('[EXTRACT] Skipping unsafe member:', m.name)
                continue
            safe_members.append(m)
        # Extract only the vetted members
        tf.extractall(dest, members=safe_members)

def scan_and_extract_archives():
    roots = [Path('/kaggle/input'), Path('/mnt'), Path('/data'), Path('.')]
    ex_root = Path('external_data'); ex_root.mkdir(exist_ok=True, parents=True)
    found_archives = []
    for r in roots:
        if not r.exists():
            continue
        for p in r.rglob('*'):
            s = str(p)
            if p.is_file() and (s.endswith('.zip') or s.endswith('.tar.gz') or s.endswith('.tgz')):
                found_archives.append(p)
    if not found_archives:
        print('[EXTRACT] No archives found under candidates')
        return False
    print(f'[EXTRACT] Found {len(found_archives)} archives')
    for a in found_archives:
        out = ex_root / a.stem.replace('.tar','')
        if out.exists() and any(out.iterdir()):
            print('[EXTRACT] Skip existing:', out)
            continue
        out.mkdir(parents=True, exist_ok=True)
        try:
            if str(a).endswith('.zip'):
                print('[EXTRACT] Unzipping', a, '->', out)
                safe_extract_zip(a, out)
            else:
                print('[EXTRACT] Untarring', a, '->', out)
                safe_extract_tar(a, out)
        except Exception as e:
            print('[EXTRACT] Failed for', a, e)
    # After extraction, search for train/test roots and prepend
    def _has_slices(d):
        # next(...) stops at the first matching PNG instead of listing all of them
        return d.is_dir() and next(d.rglob('case*/day*/scans/slice_*.png'), None) is not None
    train_cands = sorted({p for name in ('train', 'train_png') for p in ex_root.rglob(name) if _has_slices(p)})
    test_cands = sorted({p for name in ('test', 'test_png') for p in ex_root.rglob(name) if _has_slices(p)})
    if train_cands and test_cands:
        tr0, te0 = train_cands[0], test_cands[0]
        print('[EXTRACT] Using roots:', tr0, te0)
        if 'TRAIN_IMG_ROOTS' in globals():
            TRAIN_IMG_ROOTS.insert(0, tr0)
        if 'TEST_IMG_ROOTS' in globals():
            TEST_IMG_ROOTS.insert(0, te0)
        # Quick counts
        def _count_pngs(root):
            try:
                return sum(1 for _ in root.rglob('case*/day*/scans/slice_*.png'))
            except Exception:
                return 0
        print(f"[EXTRACT] Counts train={_count_pngs(tr0)} test={_count_pngs(te0)}")
        return True
    else:
        print('[EXTRACT] No valid train/test structure found post-extraction')
        return False

ok = scan_and_extract_archives()
print('[EXTRACT DONE] success=', ok)
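
`scan_and_extract_archives` treats any non-empty output directory as already extracted. As a result, a run interrupted halfway (for example by a kernel death) is skipped on the next attempt. The sketch below uses the hypothetical helper `extract_once` and the hypothetical marker file `.extract_done`. It writes the marker only after extraction returns, so partial directories are retried. `extract_fn` would be one of the notebook's `safe_extract_zip` / `safe_extract_tar` functions.

```python
import zipfile
from pathlib import Path

MARKER = '.extract_done'  # hypothetical marker file name

def extract_once(archive, out_dir, extract_fn):
    """Run extract_fn(archive, out_dir) unless a completion marker from a previous run exists."""
    out_dir = Path(out_dir)
    marker = out_dir / MARKER
    if marker.exists():
        return False  # already fully extracted
    out_dir.mkdir(parents=True, exist_ok=True)
    extract_fn(archive, out_dir)
    # Written only after extract_fn returns, so an interrupted run is retried next time
    marker.write_text(str(archive))
    return True
```

In the extraction loop, `extract_once(a, out, safe_extract_zip)` would replace the `any(out.iterdir())` check for zip files.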

In [None]:
# Inspect extracted archive structure to locate train/test PNGs
from pathlib import Path
import os, itertools

base = Path('external_data')
print('[INSPECT] Listing immediate subdirs under external_data:')
for p in base.iterdir():
    if p.is_dir():
        print(' -', p, '(', sum(1 for _ in p.iterdir()), 'items)')

root = base / 'uw-madison-gi-tract-image-segmentation'
print('[INSPECT] Root exists:', root.exists(), root)
if root.exists():
    print('[INSPECT] Top-level entries:')
    for p in root.iterdir():
        print('   ', p.name, '(dir)' if p.is_dir() else '(file)')
    # Try common expected structures
    candidates = [
        root / 'train',
        root / 'test',
        root / 'train_png',
        root / 'test_png',
    ]
    for c in candidates:
        print('[INSPECT] Candidate', c, 'exists=', c.exists())
        if c.exists():
            n_png = sum(1 for _ in c.rglob('slice_*.png'))
            print('   -> PNG count:', n_png)
    # Fallback: search for any slice_*.png anywhere under root
    any_pngs = list(itertools.islice(root.rglob('slice_*.png'), 10))
    print('[INSPECT] Found any slice_*.png samples (up to 10):', len(any_pngs))
    for p in any_pngs:
        print('   ', p)
else:
    print('[INSPECT] root not found; check extraction path names.')
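
Once a root is located, a per-(case, day) slice count helps catch incomplete extractions before caching. The count flags days with missing slices or unexpectedly small scans, in line with the plan's data-audit step. The sketch below assumes the `case*/day*/scans/slice_*.png` layout that the rest of this notebook globs for. It parses the case and day numbers from the directory names with `re.search`, so it also tolerates names such as `case123_day20`. `slice_counts` is a hypothetical helper.

```python
import re
from collections import Counter
from pathlib import Path

def slice_counts(root, pattern='case*/day*/scans/slice_*.png'):
    """Count slice PNGs per (case, day) under root, parsing ids from the directory names."""
    counts = Counter()
    for p in Path(root).rglob(pattern):
        day_dir = p.parent.parent  # .../dayM/scans/slice_*.png
        case_dir = day_dir.parent
        mc = re.search(r'case(\d+)', case_dir.name)
        md = re.search(r'day(\d+)', day_dir.name)
        if mc and md:
            counts[(int(mc.group(1)), int(md.group(1)))] += 1
    return counts
```

For example, `Counter(slice_counts(root / 'train').values())` gives a histogram of slices per scan.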

In [8]:
# Kick off caching for train and test using discovered roots
print('[RUN] build_train_test_cache start')
build_train_test_cache(train_out='cache/train', test_out='cache/test', log_every=300)
print('[RUN] build_train_test_cache done')

[RUN] build_train_test_cache start
[ORCH] Building train cache...
[CACHE] (0/31696) skip exists cache/train/case77_day20_slice_0001.npz
[CACHE] (300/31696) skip exists cache/train/case77_day18_slice_0013.npz
[CACHE] (600/31696) skip exists cache/train/case133_day25_slice_0025.npz
[CACHE] (900/31696) skip exists cache/train/case129_day20_slice_0037.npz
[CACHE] (1200/31696) skip exists cache/train/case129_day24_slice_0049.npz
[CACHE] (1500/31696) skip exists cache/train/case130_day0_slice_0061.npz
[CACHE] (1800/31696) skip exists cache/train/case130_day22_slice_0073.npz
[CACHE] (2100/31696) skip exists cache/train/case88_day36_slice_0085.npz
[CACHE] (2400/31696) skip exists cache/train/case44_day0_slice_0097.npz
[CACHE] (2700/31696) skip exists cache/train/case44_day19_slice_0109.npz
[CACHE] (3000/31696) skip exists cache/train/case145_day0_slice_0121.npz
[CACHE] (3300/31696) skip exists cache/train/case15_day20_slice_0133.npz
[CACHE] (3600/31696) skip exists cache/train/case42_day17_sli

[CACHE] (6000/31696) skip exists cache/train/case65_day28_slice_0097.npz
[CACHE] (6300/31696) skip exists cache/train/case65_day0_slice_0109.npz
[CACHE] (6600/31696) skip exists cache/train/case122_day18_slice_0121.npz
[CACHE] (6900/31696) skip exists cache/train/case122_day0_slice_0133.npz
[CACHE] (7200/31696) skip exists cache/train/case125_day0_slice_0001.npz
[CACHE] (7500/31696) skip exists cache/train/case117_day0_slice_0077.npz
[CACHE] (7800/31696) skip exists cache/train/case140_day10_slice_0073.npz
[CACHE] (8100/31696) skip exists cache/train/case134_day22_slice_0085.npz
[CACHE] (8400/31696) skip exists cache/train/case134_day21_slice_0097.npz
[CACHE] (8700/31696) skip exists cache/train/case9_day20_slice_0109.npz
[CACHE] (9000/31696) skip exists cache/train/case113_day19_slice_0121.npz
[CACHE] (9300/31696) skip exists cache/train/case113_day16_slice_0133.npz
[CACHE] (9600/31696) skip exists cache/train/case90_day29_slice_0001.npz
[CACHE] (9900/31696) skip exists cache/train/ca

[CACHE] (12300/31696) skip exists cache/train/case154_day16_slice_0109.npz
[CACHE] (12600/31696) skip exists cache/train/case135_day0_slice_0121.npz
[CACHE] (12900/31696) skip exists cache/train/case84_day23_slice_0133.npz
[CACHE] (13200/31696) skip exists cache/train/case147_day0_slice_0001.npz
[CACHE] (13500/31696) skip exists cache/train/case147_day14_slice_0013.npz
[CACHE] (13800/31696) skip exists cache/train/case101_day20_slice_0025.npz
[CACHE] (14100/31696) skip exists cache/train/case101_day26_slice_0037.npz
[CACHE] (14400/31696) skip exists cache/train/case7_day19_slice_0049.npz
[CACHE] (14700/31696) skip exists cache/train/case119_day0_slice_0061.npz
[CACHE] (15000/31696) skip exists cache/train/case119_day19_slice_0073.npz
[CACHE] (15300/31696) skip exists cache/train/case32_day18_slice_0085.npz
[CACHE] (15600/31696) skip exists cache/train/case32_day0_slice_0097.npz
[CACHE] (15900/31696) skip exists cache/train/case24_day0_slice_0109.npz
[CACHE] (16200/31696) skip exists ca

[CACHE] 16500/31696 done in 0.2 min
[CACHE] 16800/31696 done in 0.6 min
[CACHE] 17100/31696 done in 1.1 min
[CACHE] 17400/31696 done in 1.5 min
[CACHE] 17700/31696 done in 2.0 min
[CACHE] 18000/31696 done in 2.4 min
[CACHE] 18300/31696 done in 2.8 min
[CACHE] 18600/31696 done in 3.3 min
[CACHE] 18900/31696 done in 3.8 min
[CACHE] 19200/31696 done in 4.3 min
[CACHE] 19500/31696 done in 4.7 min
[CACHE] 19800/31696 done in 5.2 min
[CACHE] 20100/31696 done in 5.6 min
[CACHE] 20400/31696 done in 6.1 min
[CACHE] 20700/31696 done in 6.5 min
[CACHE] 21000/31696 done in 7.0 min
[CACHE] 21300/31696 done in 7.4 min
[CACHE] 21600/31696 done in 7.8 min
[CACHE] 21900/31696 done in 8.3 min
[CACHE] 22200/31696 done in 8.7 min
[CACHE] 22500/31696 done in 9.2 min
[CACHE] 22800/31696 done in 9.6 min
[CACHE] 23100/31696 done in 10.1 min
[CACHE] 23400/31696 done in 10.5 min
[CACHE] 23700/31696 done in 11.0 min
[CACHE] 24000/31696 done in 11.4 min
[CACHE] 24300/31696 done in 11.9 min
[CACHE] 24600/31696 done in 12.3 min
[CACHE] 24900/31696 done in 12.8 min
[CACHE] 25200/31696 done in 13.3 min
[CACHE] 25500/31696 done in 13.7 min
[CACHE] 25800/31696 done in 14.2 min
[CACHE] 26100/31696 done in 14.7 min
[CACHE] 26400/31696 done in 15.1 min
[CACHE] 26700/31696 done in 15.6 min
[CACHE] 27000/31696 done in 16.0 min
[CACHE] 27300/31696 done in 16.5 min
[CACHE] 27600/31696 done in 16.9 min
[CACHE] 27900/31696 done in 17.4 min
[CACHE] 28200/31696 done in 17.8 min
[CACHE] 28500/31696 done in 18.3 min
[CACHE] 28800/31696 done in 18.8 min
[CACHE] 29100/31696 done in 19.3 min
[CACHE] 29400/31696 done in 19.7 min
[CACHE] 29700/31696 done in 20.2 min
[CACHE] 30000/31696 done in 20.7 min
[CACHE] 30300/31696 done in 21.1 min
[CACHE] 30600/31696 done in 21.6 min
[CACHE] 30900/31696 done in 22.0 min
[CACHE] 31200/31696 done in 22.5 min
[CACHE] 31500/31696 done in 22.9 min
[CACHE] Done: cache/train
[ORCH] Train cache done in 23.2m
[ORCH] Building test cache...
[CACHE] 300/6800 done in 0.3 min
[CACHE] 600/6800 done in 0.6 min
[CACHE] 900/6800 done in 0.9 min
[CACHE] 1200/6800 done in 1.2 min
[CACHE] 1500/6800 done in 1.5 min
[CACHE] 1800/6800 done in 1.9 min
[CACHE] 2100/6800 done in 2.2 min
[CACHE] 2400/6800 done in 2.5 min
[CACHE] 2700/6800 done in 2.8 min
[CACHE] 3000/6800 done in 3.1 min
[CACHE] 3300/6800 done in 3.4 min
[CACHE] 3600/6800 done in 3.7 min
[CACHE] 3900/6800 done in 4.0 min
[CACHE] 4200/6800 done in 4.3 min
[CACHE] 4500/6800 done in 4.7 min
[CACHE] 4800/6800 done in 5.0 min
[CACHE] 5100/6800 done in 5.3 min
[CACHE] 5400/6800 done in 5.6 min
[CACHE] 5700/6800 done in 5.9 min
[CACHE] 6000/6800 done in 6.2 min
[CACHE] 6300/6800 done in 6.5 min
[CACHE] 6600/6800 done in 6.8 min
[CACHE] Done: cache/test
[ORCH] Test cache done in 7.0m
[RUN] build_train_test_cache done
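
The cache log shows `build_cache` skipping any `.npz` that already exists. `build_cache` is not shown. If it writes files in place, a kernel death mid-write (the cells below chase exactly that) could leave a truncated file that is then skipped on every rerun. The sketch below, using the hypothetical helper `save_npz_atomic`, avoids this by writing to a temporary file and renaming it with `os.replace`. The rename is atomic when source and destination are on the same filesystem.

```python
import os
from pathlib import Path
import numpy as np

def save_npz_atomic(path, overwrite=False, **arrays):
    """Write arrays to path (.npz) via a temp file + os.replace; skip if path already exists."""
    path = Path(path)
    if path.exists() and not overwrite:
        return False  # mirrors the log's "skip exists" behaviour
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_name(path.name + '.tmp')
    # Passing a file object stops np.savez_compressed from appending another '.npz' suffix
    with open(tmp, 'wb') as f:
        np.savez_compressed(f, **arrays)
    os.replace(tmp, path)  # atomic rename on the same filesystem
    return True
```

With this pattern, a file under `cache/train` either exists completely or not at all, so the "skip exists" shortcut stays safe across restarts.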


In [None]:
# Start full 5-fold training now (384, bs=10); cache not required
print('[RUN] Starting 5-fold training @384, bs=10, epochs=40')
train_all_folds(epochs=40, batch_size=10, device='cuda')
print('[RUN] Training complete')
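
The plan calls for AdamW with cosine annealing and warmup. `train_one_fold` is not shown, so whether it already schedules the learning rate this way is unknown. The sketch below is a dependency-free learning-rate multiplier, `warmup_cosine_factor` (a hypothetical name). It ramps linearly to 1.0 over `warmup_steps`, then follows a cosine curve down to `min_factor` at `total_steps`. It could be wrapped as `torch.optim.lr_scheduler.LambdaLR(opt, lambda s: warmup_cosine_factor(s, warmup, total))`, with `scheduler.step()` called once per optimizer step.

```python
import math

def warmup_cosine_factor(step, warmup_steps, total_steps, min_factor=0.01):
    """LR multiplier: linear warmup to 1.0, then cosine decay to min_factor at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)  # hold at min_factor after total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_factor + (1.0 - min_factor) * cosine
```

Stepping per batch rather than per epoch keeps the warmup short in wall-clock terms. For example, 40 epochs with `len(train_dl)` batches each gives `total_steps = 40 * len(train_dl)`.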

In [None]:
# Sanity-run a single fold with 1 epoch to confirm training stability before full 5-fold
import gc, torch
print('[RUN] Sanity training fold 0 for 1 epoch @384, bs=4, workers=0')
gc.collect()
try:
    torch.cuda.empty_cache()
except Exception:
    pass
train_one_fold(0, epochs=1, batch_size=4, num_workers=0, device='cuda')
print('[RUN] Sanity fold 0 done')

In [None]:
# Diagnostic: isolate SMP model build and single forward to find kernel-death root cause
import torch, gc, time
import pandas as pd  # used to read folds.csv below
from torch.utils.data import DataLoader
print('[DIAG] CUDA is_available:', torch.cuda.is_available(), 'Device:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu')
try:
    t0 = time.time()
    # Force garbage collection and empty cache before heavy import
    gc.collect()
    torch.cuda.empty_cache()
    # Attempt lazy import + model init
    from segmentation_models_pytorch import UnetPlusPlus
    print('[DIAG] SMP imported OK in', f"{time.time()-t0:.2f}s")
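    # SMP addresses timm encoders as 'tu-<timm name>'; a bare 'tf_efficientnet_b3' is not a registered SMP encoder name
    model = UnetPlusPlus(encoder_name='tu-tf_efficientnet_b3', encoder_weights=None, in_channels=5, classes=3, activation=None).cuda()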
    n_params = sum(p.numel() for p in model.parameters())
    print('[DIAG] Model built. Params:', n_params)
    # Build a tiny loader (num_workers=0) and run 1 forward pass
    folds = pd.read_csv('folds.csv')
    va_ids = folds[folds['fold']==0][['id','case','day','slice']].reset_index(drop=True).iloc[:4]
    ds = UWGITractDataset(va_ids, train_df=train_df, roots=TRAIN_IMG_ROOTS, mode='train', aug=get_valid_aug())
    dl = DataLoader(ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
    xb, yb, _ = next(iter(dl))
    xb = xb.cuda(non_blocking=True)
    # Forward-only check: no_grad avoids storing activations for a backward pass
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
        yhat = model(xb)
    print('[DIAG] Forward OK. logits shape:', tuple(yhat.shape))
    del model, xb, yhat; gc.collect(); torch.cuda.empty_cache()
    print('[DIAG DONE]')
except Exception as e:
    print('[DIAG] Exception:', repr(e))
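
A `try/except` cannot catch a crash inside native code: a segfault takes the whole Jupyter kernel down with it. To diagnose kernel deaths, it can help to run the suspect snippet in a separate interpreter and inspect how that process exited. On POSIX, `subprocess` reports a child killed by signal N as return code `-N`. The helper `run_isolated` below is a hypothetical sketch. The snippet passed to it shares no state with the notebook and must do its own imports (for example `segmentation_models_pytorch` and a dummy `torch.randn` batch).

```python
import subprocess
import sys
import textwrap

def run_isolated(code, timeout=600):
    """Run a Python snippet in a fresh interpreter and report how it exited."""
    proc = subprocess.run([sys.executable, '-c', textwrap.dedent(code)],
                          capture_output=True, text=True, timeout=timeout)
    if proc.returncode < 0:
        # POSIX only: e.g. -11 means SIGSEGV, typical of a crash in a native extension
        status = f'killed by signal {-proc.returncode}'
    else:
        status = f'exit code {proc.returncode}'
    return proc.returncode, status, proc.stdout, proc.stderr
```

If the child dies from a signal while the same code raises nothing in-process, the fault is in native code (CUDA kernels, cuDNN, a compiled extension) rather than in Python logic.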

In [None]:
# Lightweight fallback UNet (no SMP/timm) for stability
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class TinyUNet(nn.Module):
    def __init__(self, in_ch=5, num_classes=3, base=32):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(base, base*2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(base*2, base*4)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = ConvBlock(base*4, base*8)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = ConvBlock(base*8, base*16)
        self.up4 = nn.ConvTranspose2d(base*16, base*8, 2, stride=2)
        self.dec4 = ConvBlock(base*16, base*8)
        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, stride=2)
        self.dec3 = ConvBlock(base*8, base*4)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec2 = ConvBlock(base*4, base*2)
        self.up1 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec1 = ConvBlock(base*2, base)
        self.head = nn.Conv2d(base, num_classes, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))
        b = self.bottleneck(self.pool4(e4))
        d4 = self.up4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)
        d3 = self.up3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        return self.head(d1)

print('[FALLBACK MODEL READY] TinyUNet(in_ch=5, classes=3) defined. Modify build_model_b3 to use TinyUNet if SMP is unstable.')
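
`TinyUNet` halves the resolution four times (`pool1` to `pool4`) and concatenates each upsampled map with its skip connection. The upsampled and skip feature maps therefore only line up when H and W are divisible by 2^4 = 16. 384 qualifies; a size such as 266 does not, and `torch.cat` would fail on mismatched shapes. The sketch below, with hypothetical helpers `pad_to_multiple` and `crop_back`, zero-pads the bottom and right edges of a `(..., H, W)` NumPy array and crops outputs back afterwards. The PyTorch equivalent would be `torch.nn.functional.pad(x, (0, pw, 0, ph))`.

```python
import numpy as np

def pad_to_multiple(x, multiple=16):
    """Zero-pad the last two axes of a (..., H, W) array up to multiples of `multiple`.

    Returns the padded array and the original (H, W) so outputs can be cropped back.
    """
    h, w = x.shape[-2:]
    ph = (-h) % multiple  # rows needed to reach the next multiple
    pw = (-w) % multiple  # columns needed to reach the next multiple
    pad = [(0, 0)] * (x.ndim - 2) + [(0, ph), (0, pw)]
    return np.pad(x, pad, mode='constant'), (h, w)

def crop_back(y, orig_hw):
    """Undo pad_to_multiple on a (..., H', W') array."""
    h, w = orig_hw
    return y[..., :h, :w]
```

Padding only the bottom and right edges keeps pixel (0, 0) fixed, so cropping is a plain slice with no offset bookkeeping.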

In [None]:
# Diagnostic: inspect first training batch shapes and aligned shapes + forward/loss check
import torch
print('[DIAG-BATCH] Building loaders for fold 0 ...', flush=True)
train_dl, valid_dl, _ = make_loaders(0, batch_size=2, num_workers=2)
batch = next(iter(train_dl))
imgs, masks, ids = batch
print('[DIAG-BATCH] Raw shapes imgs/masks:', tuple(imgs.shape), tuple(masks.shape))

# Local align helper (in case global not defined)
def _align_logits_targets_local(logits, masks):
    if logits.dim()==4 and logits.shape[1] not in (1,3) and logits.shape[-1] in (1,3):
        logits = logits.permute(0,3,1,2).contiguous()
    if masks.dim()==4 and masks.shape[1] not in (1,3) and masks.shape[-1] in (1,3):
        masks = masks.permute(0,3,1,2).contiguous()
    if logits.shape != masks.shape:
        if logits.dim()==4 and masks.dim()==4 and logits.shape[-1]==3 and masks.shape[1]==3:
            logits = logits.permute(0,3,1,2).contiguous()
        elif logits.dim()==4 and masks.dim()==4 and masks.shape[-1]==3 and logits.shape[1]==3:
            masks = masks.permute(0,3,1,2).contiguous()
    return logits, masks

# Dummy logits in NCHW
logits_dummy = torch.zeros((imgs.size(0), 3, imgs.size(-2), imgs.size(-1)))
try:
    logits_a, masks_a = _align_logits_targets_local(logits_dummy, masks)
    print('[DIAG-BATCH] After local align -> logits/masks:', tuple(logits_a.shape), tuple(masks_a.shape))
except Exception as e:
    print('[DIAG-BATCH] Align error:', repr(e))

# Model forward + loss check
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_model_b3(device=device)
imgs_dev = imgs.to(device)
masks_dev = masks.to(device)
with torch.no_grad():
    logits = model(imgs_dev)
print('[DIAG-BATCH] Model logits shape:', tuple(logits.shape))
try:
    lf = ComboLoss()
    lg, mg = _align_logits_targets_local(logits, masks_dev)
    print('[DIAG-BATCH] Pre-loss shapes lg/mg:', tuple(lg.shape), tuple(mg.shape))
    loss = lf(lg, mg)
    print('[DIAG-BATCH] Loss OK:', float(loss))
except Exception as e:
    print('[DIAG-BATCH] Loss error:', repr(e))
print('[DIAG-BATCH] Done.')

In [None]:
# DIAG: single-batch train step (no AMP) to isolate kernel death
import torch, gc, time, pandas as pd
from torch.utils.data import DataLoader
print('[DIAG-TRAINSTEP] Start')
gc.collect()
try:
    torch.cuda.empty_cache()
except Exception:
    pass
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('[DIAG-TRAINSTEP] CUDA avail:', torch.cuda.is_available(), 'device:', device)
try:
    # Build a tiny dataset/loader directly (avoid sampler dependency)
    folds = pd.read_csv('folds.csv')
    tr_ids = folds[folds['fold']!=0][['id','case','day','slice']].reset_index(drop=True).iloc[:8]
    train_ds = UWGITractDataset(tr_ids, train_df=train_df, roots=TRAIN_IMG_ROOTS, mode='train', aug=get_valid_aug())
    train_dl = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0, pin_memory=True)
    batch = next(iter(train_dl))
    imgs, masks, ids = batch
    imgs = imgs.to(device, non_blocking=True)
    masks = masks.to(device, non_blocking=True)
    # Model + loss + opt
    model = build_model_b3(device=device)
    loss_fn = ComboLoss(bce_weight=0.5, tv_weight=0.5, tv_alpha=0.7, tv_beta=0.3, class_weights=(1.1,1.45,1.0))
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    # One train step, AMP disabled
    model.train()
    t0 = time.time()
    logits = model(imgs)
    logits, masks = _align_logits_targets(logits, masks)
    loss = loss_fn(logits, masks)
    print('[DIAG-TRAINSTEP] fwd ok, loss=', float(loss))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step(); opt.zero_grad(set_to_none=True)
    print('[DIAG-TRAINSTEP] backward/step ok, elapsed', f"{time.time()-t0:.2f}s")
    del model, imgs, masks, logits, loss
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass
    print('[DIAG-TRAINSTEP] Done')
except Exception as e:
    print('[DIAG-TRAINSTEP] Exception:', repr(e))
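
The train-step cell configures `ComboLoss` with a Tversky term (`tv_alpha=0.7, tv_beta=0.3`), but `ComboLoss` itself is not shown. The NumPy sketch below computes a soft Tversky index TP / (TP + α·FP + β·FN), which can be used to check what those weights do. It assumes α weights false positives and β false negatives, which is the convention used by `segmentation_models_pytorch`'s `TverskyLoss`. `ComboLoss` may define them the other way round. With α = β = 0.5 the index reduces to soft Dice. `tversky_index` and `tversky_loss` are hypothetical names.

```python
import numpy as np

def tversky_index(probs, target, alpha=0.7, beta=0.3, eps=1e-7):
    """Soft Tversky index TP / (TP + alpha*FP + beta*FN) over all pixels.

    Assumes alpha weights false positives and beta false negatives.
    """
    p = np.asarray(probs, dtype=np.float64).ravel()
    t = np.asarray(target, dtype=np.float64).ravel()
    tp = (p * t).sum()
    fp = (p * (1.0 - t)).sum()
    fn = ((1.0 - p) * t).sum()
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)

def tversky_loss(probs, target, **kw):
    """1 - Tversky index; 0 for a perfect prediction."""
    return 1.0 - tversky_index(probs, target, **kw)
```

Under this convention, α = 0.7 makes over-segmentation cost more than under-segmentation. Swapping the two weights is the usual lever when recall on small structures is the problem instead.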

In [None]:
# Override: use TinyUNet for stability (avoids SMP-related kernel deaths)
import gc, torch

def build_model_b3(device='cuda'):
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass
    # TinyUNet defined in Cell 26
    model = TinyUNet(in_ch=5, num_classes=3, base=32)
    return model.to(device)

print('[MODEL OVERRIDE] build_model_b3 -> TinyUNet(in_ch=5, classes=3)')