# ROI‑Supervised Deep Learning for HER2 Breast Cancer Classification

## Abstract


## Methods overview
- Phase 1 — ROI‑supervised classification: ResNet‑50 backbone; patch extraction strictly within annotated ROIs; wandb metrics and artifacts.
- Phase 2 — MIL fine‑tuning: attention‑based MIL initialized from Phase 1; bag‑level training; attention/Grad‑CAM visualizations.
- Phase 3 — ROI‑derived segmentation: U‑Net with ResNet backbone; ROI polygons converted to masks; Dice/IoU and pixel accuracy.

## Reproducible setup and execution
- Prerequisites: Python 3.8+, PyTorch (CUDA‑enabled if available), GPU ≥8 GB VRAM recommended, 16–32 GB RAM, SVS slides with XML annotations.
- Execution protocol: restart the kernel and run cells sequentially; verify data paths and CUDA availability in the config. wandb logging is enabled by default.
- Expected outputs: model checkpoints and configs, wandb runs with metrics/artifacts, Grad‑CAM overlays, summary tables/figures.

## Experiment tracking (wandb)
- Metrics: loss (train/val), LR, early‑stopping; accuracy, precision/recall/F1, ROC‑AUC; confusion matrix; segmentation Dice/IoU.
- Artifacts and lineage: versioned checkpoints (by phase/fold), exported configs, selected overlays; dataset fingerprints and code version when available.
- Multi‑phase consistency: runs are linked across Phase 1→2→3 to analyze feature transfer and ROI utilization; attention weights and Grad‑CAM outputs are included when generated.
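
As an illustration of the classification metrics listed above, the following sketch derives accuracy, precision, recall, and F1 from confusion-matrix counts and computes ROC-AUC as a pairwise ranking statistic. The function names `binary_metrics` and `roc_auc` are hypothetical helpers written for this example, not part of the pipeline; in practice a library such as scikit-learn would normally be used.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Confusion-matrix counts and derived metrics for labels in {0, 1}."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    total = tp + tn + fp + fn
    # Guard every ratio against an empty denominator
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    accuracy = (tp + tn) / total if total else 0.0
    return {"tp": tp, "tn": tn, "fp": fp, "fn": fn,
            "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


def roc_auc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """ROC-AUC as the probability that a random positive outranks a random negative.

    Ties count as half a win. This O(P*N) form is for clarity, not speed.
    """
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))
```

The dictionary returned by `binary_metrics` is flat, so it could be handed to a logger that accepts key/value metrics (for example `wandb.log`) without further reshaping.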

## Notes
- Data privacy/IRB compliance is required for clinical datasets.
- Seeds are fixed where supported; minor nondeterminism can remain on CUDA/cuDNN.
- Report hardware specifications and preserve artifacts to facilitate replication.
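
A minimal sketch of what "seeds are fixed where supported" can look like is given below. The helper name `set_seed` and the default value 42 are assumptions for this example; the pipeline's own seeding code is not shown in this notebook. NumPy and PyTorch are seeded only if they are installed.

```python
import os
import random


def set_seed(seed: int = 42) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are available."""
    random.seed(seed)
    # Only affects hash randomization in processes started after this point
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is absent
        # Trade some speed for repeatable cuDNN algorithm selection
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass
```

Even with these settings, some CUDA kernels remain nondeterministic, which is consistent with the caveat in the note above.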

In [1]:
"""
Environment setup and dependencies
==================================

Configures the runtime environment and imports required libraries for the HER2 classification pipeline.
"""

import os
import sys
import warnings
from pathlib import Path

# Environment configuration for stable, reproducible runs
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'  # OpenMP compatibility
os.environ['OMP_NUM_THREADS'] = '1'          # Threading optimization
warnings.filterwarnings('ignore')             # Suppress non-critical warnings

# Compatibility shim: collections.Callable was removed in Python 3.10; restore the alias for older dependencies
import collections
import collections.abc
if not hasattr(collections, 'Callable'):
    collections.Callable = collections.abc.Callable

# Core scientific computing libraries
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from IPython.display import clear_output, display

# Jupyter notebook configuration
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Configure matplotlib for publication-quality figures
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['font.size'] = 12
plt.rcParams['axes.titlesize'] = 14
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['legend.fontsize'] = 10

# Add project root to Python path
sys.path.append('.')

# Import pipeline modules
from scripts.train import (Config, train_phase1, train_phase2, train_segmentation, 
                          explain_predictions, optimize_hyperparameters)
from scripts.augmentations import (get_classification_transforms, get_segmentation_transforms, 
                                 AugmentationConfig)
# NOTE: Configuration factories are defined later in this notebook (no import from scripts.config)

# Verify CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    print(f"CUDA available: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("CUDA not available - using CPU")

CUDA available: NVIDIA GeForce RTX 4060 Laptop GPU
Memory: 8.6 GB


In [2]:
"""
Memory Management and Kernel Cleanup
====================================

Ensures clean execution environment by clearing previous variables and optimizing memory usage.
This is particularly important for reproducible results in deep learning experiments.
"""

import gc

# Print messages only once per kernel session
SENTINEL = "_CLEANUP_MESSAGES_PRINTED"
first_time = not globals().get(SENTINEL, False)

is_torch = 'torch' in globals()
# Short-circuits to False when torch has not been imported in this kernel
is_gpu = is_torch and globals()['torch'].cuda.is_available()

if first_time:
    print("Cleaning up environment")

# Clear potential variables from previous runs
cleanup_variables = [
    'att_weights', 'attention_weights', 'cam', 'grayscale_cam', 
    'img_mil', 'img_tensor', 'model', 'outputs', 'predicted',
    'target_layer', 'target_layers', 'test_image', 'wrapper_model'
]

cleaned_count = 0
for var_name in cleanup_variables:
    if var_name in globals():
        del globals()[var_name]
        cleaned_count += 1

# Force garbage collection
gc.collect()

# Clear GPU memory cache if available (always execute, only print once)
if is_gpu:
    globals()['torch'].cuda.empty_cache()

if first_time:
    if is_gpu:
        gpu_memory = globals()['torch'].cuda.get_device_properties(0).total_memory / 1e9
        print(f"✅ GPU memory cache cleared ({gpu_memory:.1f} GB available)")
    else:
        print("No GPU available or torch not imported; skipping GPU cache clear.")
    
    print(f"Environment initialized successfully")
    print(f"   Variables cleaned: {cleaned_count}")
    print(f"   Memory optimization: Complete")
    print("Ready for reproducible experiment execution")
    
    # Set sentinel so subsequent runs are silent
    globals()[SENTINEL] = True

Cleaning up environment
✅ GPU memory cache cleared (8.6 GB available)
Environment initialized successfully
   Variables cleaned: 0
   Memory optimization: Complete
Ready for reproducible experiment execution


In [3]:
"""
Notebook-scoped configuration (no external import)
=================================================

Defines dataclasses for pipeline configuration and three presets: quick_test, default, production.
The training code adapts these objects via coerce_to_train_config.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

@dataclass
class DataConfigNB:
    data_dir: str = "data"
    annotations_dir: str = "Annotations"
    output_dir: str = "output"
    checkpoints_dir: str = "checkpoints"
    patch_size: int = 512

    def __post_init__(self):
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        Path(self.checkpoints_dir).mkdir(parents=True, exist_ok=True)

@dataclass
class ModelConfigNB:
    num_classes: int = 2
    batch_size: int = 16
    learning_rate: float = 1e-4
    num_epochs: int = 50
    backbone: str = "resnet50"

@dataclass
class TrainingConfigNB:
    device: str = "auto"  # "cuda" | "cpu" | "auto"
    cross_validation_folds: int = 5
    num_workers: int = 0
    # New: patch sampling controls per phase
    patches_per_slide_phase1: int = 100
    patches_per_slide_phase2: int = 200
    patches_per_slide_seg: int = 50
    # Optional fast-mode overrides
    fast_patches_per_slide_phase1: Optional[int] = 32
    fast_patches_per_slide_phase2: Optional[int] = 64
    fast_patches_per_slide_seg: Optional[int] = 16

@dataclass
class AugmentationConfigNB:
    elastic_deform_prob: float = 0.3
    stain_augment_prob: float = 0.5
    use_otsu_tissue_mask: bool = True

@dataclass
class PipelineConfigNB:
    data: DataConfigNB = field(default_factory=DataConfigNB)
    model: ModelConfigNB = field(default_factory=ModelConfigNB)
    training: TrainingConfigNB = field(default_factory=TrainingConfigNB)
    augment: AugmentationConfigNB = field(default_factory=AugmentationConfigNB)

def create_quick_test_config():
    cfg = PipelineConfigNB()
    cfg.model.batch_size = 4
    cfg.model.num_epochs = 5
    cfg.model.learning_rate = 3e-4
    cfg.training.num_workers = 0  # Windows-friendly
    cfg.data.patch_size = 256
    # Lighter sampling for quick tests
    cfg.training.patches_per_slide_phase1 = 48
    cfg.training.patches_per_slide_phase2 = 96
    cfg.training.patches_per_slide_seg = 24
    return cfg

def create_default_config():
    cfg = PipelineConfigNB()
    cfg.model.batch_size = 8
    cfg.model.num_epochs = 30
    cfg.model.learning_rate = 1e-4
    cfg.training.num_workers = 0
    cfg.data.patch_size = 512
    # Balanced defaults
    cfg.training.patches_per_slide_phase1 = 100
    cfg.training.patches_per_slide_phase2 = 200
    cfg.training.patches_per_slide_seg = 50
    return cfg

def create_production_config():
    cfg = PipelineConfigNB()
    cfg.model.batch_size = 1
    cfg.model.num_epochs = 100
    cfg.model.learning_rate = 5e-5
    cfg.training.num_workers = 4
    cfg.data.patch_size = 512
    # Heavier sampling for robust training
    cfg.training.patches_per_slide_phase1 = 256
    cfg.training.patches_per_slide_phase2 = 512
    cfg.training.patches_per_slide_seg = 128
    return cfg

In [4]:
"""
Fast mode toggle (notebook-friendly)
===================================

Enable this to speed up training in the notebook with fewer batches per epoch and smaller workloads.
"""

# Toggle fast mode for quick iteration
POC_FAST_MODE = True   # Set to False for full runs

# Optional batch caps per epoch when fast mode is enabled
FAST_TRAIN_CAP = 4
FAST_VAL_CAP = 2

print("Fast mode:", POC_FAST_MODE)
print("Batch caps per epoch (train/val):", (FAST_TRAIN_CAP if POC_FAST_MODE else None), "/", (FAST_VAL_CAP if POC_FAST_MODE else None))

Fast mode: True
Batch caps per epoch (train/val): 4 / 2


In [5]:
"""
Dataset metadata generation
===========================

Creates a CSV linking SVS whole-slide images with XML annotations. Enables ROI-supervised training by mapping pathologist annotations to slides.
"""

print("Generating dataset metadata")
print("=" * 32)

# Define data paths
data_dir = Path("data")
svs_dir = data_dir / "SVS"
annotations_dir = data_dir / "Annotations"
metadata_file = data_dir / "metadata.csv"

print(f"Data directory: {data_dir}")
print(f"SVS files: {svs_dir}")
print(f"Annotations: {annotations_dir}")

# Validate directory structure
if not svs_dir.exists():
    print(f"Warning: SVS directory not found at {svs_dir}")
if not annotations_dir.exists():
    print(f"Warning: Annotations directory not found at {annotations_dir}")

# Collect slide information
slides_data = []
her2_neg_count = 0
her2_pos_count = 0
annotated_count = 0

# Process all SVS files
svs_files = list(svs_dir.glob("*.svs")) if svs_dir.exists() else []
print(f"\nProcessing {len(svs_files)} SVS files...")

for svs_file in svs_files:
    slide_name = svs_file.stem
    
    # Extract HER2 status from filename
    if slide_name.startswith("Her2Neg"):
        her2_status = 0
        label = "HER2-"
        her2_neg_count += 1
    elif slide_name.startswith("Her2Pos"):
        her2_status = 1
        label = "HER2+"
        her2_pos_count += 1
    else:
        print(f"Unrecognized slide naming: {slide_name}")
        continue
    
    # Check for corresponding annotation
    annotation_file = annotations_dir / f"{slide_name}.xml"
    has_annotation = annotation_file.exists()
    if has_annotation:
        annotated_count += 1
    
    slide_info = {
        'slide_name': slide_name,
        'slide_path': str(svs_file),
        'her2_status': her2_status,
        'label': label,
        'has_annotation': has_annotation,
        'annotation_path': str(annotation_file) if has_annotation else None
    }
    slides_data.append(slide_info)

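# Create structured DataFrame (explicit columns keep sort_values safe when no slides were found)
metadata_columns = ['slide_name', 'slide_path', 'her2_status', 'label',
                    'has_annotation', 'annotation_path']
df = pd.DataFrame(slides_data, columns=metadata_columns)
df = df.sort_values('slide_name').reset_index(drop=True)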

# Save metadata
df.to_csv(metadata_file, index=False)

# Display dataset statistics
print(f"\nDataset statistics:")
print(f"  Total slides: {len(df)}")
print(f"  HER2- cases: {her2_neg_count}")
print(f"  HER2+ cases: {her2_pos_count}")
print(f"  Annotated slides: {annotated_count}")
print(f"  Annotation coverage: {(annotated_count/len(df)*100):.1f}%" if len(df) > 0 else "  Annotation coverage: 0%")

# Display sample data
if len(df) > 0:
    print(f"\nSample metadata (first 3 rows):")
    print(df[['slide_name', 'label', 'has_annotation']].head(3).to_string(index=False))

# Annotation distribution analysis
if annotated_count > 0:
    annotated_df = df[df['has_annotation']]
    roi_neg = len(annotated_df[annotated_df['her2_status'] == 0])
    roi_pos = len(annotated_df[annotated_df['her2_status'] == 1])
    
    print(f"\nROI annotation distribution:")
    print(f"  HER2- with ROIs: {roi_neg}")
    print(f"  HER2+ with ROIs: {roi_pos}")
    print(f"  Class balance: {(roi_pos/(roi_neg+roi_pos)*100):.1f}% HER2+" if (roi_neg+roi_pos) > 0 else "  Class balance: No data")

print(f"\nMetadata saved: {metadata_file}")
print("Dataset preparation complete")

Generating dataset metadata
Data directory: data
SVS files: data\SVS
Annotations: data\Annotations

Processing 192 SVS files...

Dataset statistics:
  Total slides: 192
  HER2- cases: 99
  HER2+ cases: 93
  Annotated slides: 187
  Annotation coverage: 97.4%

Sample metadata (first 3 rows):
     slide_name label  has_annotation
Her2Neg_Case_01 HER2-            True
Her2Neg_Case_02 HER2-            True
Her2Neg_Case_03 HER2-            True

ROI annotation distribution:
  HER2- with ROIs: 97
  HER2+ with ROIs: 90
  Class balance: 48.1% HER2+

Metadata saved: data\metadata.csv
Dataset preparation complete


In [6]:
# Diagnostics: preflight checks (optional to run before training)
print("Running diagnostics (preflight checks)…")
from scripts import diagnostics as diag
import importlib

# Reload to pick up recent changes to diagnostics module
try:
    diag = importlib.reload(diag)
except Exception as e:
    print(f"[diagnostics] reload note: {e}")

# Resolve a config object robustly even if prior cells weren't run
try:
    cfg = legacy_config  # from earlier cells, if present
except NameError:
    try:
        cfg = config  # pipeline config from earlier cells
    except NameError:
        from scripts.train import Config as TrainConfig
        cfg = TrainConfig()
        print("[diagnostics] No prior config found; using default TrainConfig()")

# Persist to legacy_config so downstream cells can rely on it
legacy_config = cfg

try:
    # Enable a CPU-only compile probe on Windows to surface compile_probe results
    results = diag.run_all_checks(legacy_config, try_compile_probe_on_windows=True)
    diag._print_human(results)
    from pathlib import Path
    import json
    out_dir = Path("output") / "logs"
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "diagnostics.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"Saved diagnostics to {out_dir / 'diagnostics.json'}")
except Exception as e:
    print(f"Diagnostics failed: {e}")

Running diagnostics (preflight checks)…
[diagnostics] No prior config found; using default TrainConfig()


Diagnostics: 100%|██████████| 8/8 [00:26<00:00,  3.36s/it]

Diagnostics Summary
Python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:20:11) [MSC v.1938 64 bit (AMD64)]
PyTorch: 2.8.0+cu128 | CUDA: Yes
  GPU: NVIDIA GeForce RTX 4060 Laptop GPU | VRAM: 8.59 GB | bf16: Yes
torch.compile: Yes | Triton: Yes | compile_probe: Yes
Optional deps:
  - openslide: Yes
  - monai: Yes
  - wandb: Yes
  - tensorboard: Yes
  - pytorch-gradcam: Yes
  - sklearn: Yes
  - opencv: Yes
  - triton: Yes
bf16->NumPy safe path: Yes
Data paths: data=Yes, ann=Yes, svs=Yes (count=192)
Tiny forward probes:
  AttentionMIL: Yes
  SegmentationUNet: Yes
Grad-CAM smoke:
  explain fn: Yes | pkg: Yes
  checkpoint: checkpoints\best_model.pth | executed: Yes
Overall OK: Yes
Saved diagnostics to output\logs\diagnostics.json





In [7]:
# Enable and display performance optimization toggles for training
from scripts.train import Config as TrainConfig
import importlib.util

# Adopt legacy_config if present, otherwise create a fresh TrainConfig
cfg = legacy_config if 'legacy_config' in globals() else TrainConfig()

# Detect Triton availability (required for torch.compile on CUDA/Inductor)
def _has_triton():
    try:
        return importlib.util.find_spec("triton") is not None
    except Exception:
        return False

# Decide whether to enable torch.compile
if torch.cuda.is_available() and not _has_triton():
    cfg.USE_TORCH_COMPILE = False
    print("[compile] Triton not found; disabling torch.compile on CUDA to avoid Inductor error")
else:
    cfg.USE_TORCH_COMPILE = True

# Other performance toggles
cfg.USE_CHANNELS_LAST = True
cfg.USE_FUSED_ADAMW = True
cfg.ZERO_SET_TO_NONE = True

# Prefer bf16 when supported, else fp16
try:
    if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_bf16_supported():
        cfg.AMP_DTYPE = torch.bfloat16
    else:
        cfg.AMP_DTYPE = torch.float16
except Exception:
    cfg.AMP_DTYPE = torch.float16

# Print the active optimization settings for visibility
print("Performance settings:")
print({
    'USE_TORCH_COMPILE': cfg.USE_TORCH_COMPILE,
    'USE_CHANNELS_LAST': cfg.USE_CHANNELS_LAST,
    'AMP_DTYPE': str(cfg.AMP_DTYPE),
    'USE_FUSED_ADAMW': cfg.USE_FUSED_ADAMW,
    'ZERO_SET_TO_NONE': cfg.ZERO_SET_TO_NONE,
})

# Persist back to legacy_config so downstream cells use the same object
legacy_config = cfg

Performance settings:
{'USE_TORCH_COMPILE': True, 'USE_CHANNELS_LAST': True, 'AMP_DTYPE': 'torch.bfloat16', 'USE_FUSED_ADAMW': True, 'ZERO_SET_TO_NONE': True}


In [8]:
# Weights & Biases: links and monitored signals (concise)
print("Weights & Biases: ROI-supervised training runs")
print("=" * 48)

# Replace with your project URL if different
print("Project: https://wandb.ai/thanakornbua/her2-breast-cancer")
print("Runs (illustrative):")
print("  • Phase 1 (ROI classification): Phase1_ROI_Supervised_fold0")
print("  • Phase 2 (MIL): Phase2_MIL_FineTuning_fold0")
print("  • Phase 3 (Segmentation): Phase3_Segmentation_fold0")

print("\nMonitored signals:")
print("  • Loss (train/val), LR schedule, early stopping")
print("  • Accuracy, Precision/Recall/F1, ROC-AUC")
print("  • Confusion matrix and classification report")
print("  • ROI coverage and patch sampling stats")
print("  • Segmentation Dice/IoU (if enabled)")

Weights & Biases: ROI-supervised training runs
Project: https://wandb.ai/thanakornbua/her2-breast-cancer
Runs (illustrative):
  • Phase 1 (ROI classification): Phase1_ROI_Supervised_fold0
  • Phase 2 (MIL): Phase2_MIL_FineTuning_fold0
  • Phase 3 (Segmentation): Phase3_Segmentation_fold0

Monitored signals:
  • Loss (train/val), LR schedule, early stopping
  • Accuracy, Precision/Recall/F1, ROC-AUC
  • Confusion matrix and classification report
  • ROI coverage and patch sampling stats
  • Segmentation Dice/IoU (if enabled)


## Phase 1 — ROI-supervised classification
Run Phase 1 to train a ResNet-50 AttentionMIL strictly on annotated ROIs.
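
Strict ROI sampling can be pictured as rejection sampling: draw a candidate patch centre inside the ROI's bounding box, keep it only if it falls inside the annotated polygon, and give up after a fixed number of attempts instead of falling back to tissue outside the ROI. The sketch below illustrates that idea under stated assumptions. `point_in_polygon` and `sample_patch_in_roi` are hypothetical names, the polygon is a plain list of `(x, y)` vertices, and the real sampler in `scripts.train` may work differently.

```python
import random
from typing import Optional, Sequence, Tuple

Point = Tuple[float, float]


def point_in_polygon(x: float, y: float, polygon: Sequence[Point]) -> bool:
    """Even-odd ray-casting test for a point against a closed polygon."""
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        # Consider only edges that straddle the horizontal line through y
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def sample_patch_in_roi(
    polygon: Sequence[Point],
    patch_size: int,
    max_attempts: int = 20,
    rng: Optional[random.Random] = None,
) -> Optional[Tuple[int, int]]:
    """Return the top-left (x, y) of a patch whose centre lies inside the ROI.

    Returns None when no valid centre is found within max_attempts,
    mirroring a strict mode with no fallback outside the ROI.
    """
    rng = rng or random.Random()
    xs = [p[0] for p in polygon]
    ys = [p[1] for p in polygon]
    for _ in range(max_attempts):
        cx = rng.uniform(min(xs), max(xs))
        cy = rng.uniform(min(ys), max(ys))
        if point_in_polygon(cx, cy, polygon):
            return int(cx - patch_size / 2), int(cy - patch_size / 2)
    return None
```

A caller that receives `None` would skip the slide or ROI for that draw. This sketch only guarantees that the patch centre is inside the ROI; requiring the whole patch to lie inside would need an additional check on the patch corners.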

In [None]:
# Phase 1: ROI-supervised classification
print("Phase 1: ROI-supervised classification")
print("=" * 48)

try:
    # Coerce config for training functions (accepts PipelineConfig)
    from scripts.train import coerce_to_train_config
    legacy_config = coerce_to_train_config(legacy_config)

    # Ensure strict ROI flags are enabled
    legacy_config.REQUIRE_ROI_FOR_PHASE1 = True
    legacy_config.ALLOW_FALLBACK_OUTSIDE_ROI = False
    legacy_config.ROI_MAX_SAMPLING_ATTEMPTS = 20
    
    # Use fast mode toggle from notebook
    from scripts.train import apply_fast_mode_overrides
    if 'POC_FAST_MODE' in globals() and POC_FAST_MODE:
        legacy_config.FAST_MODE = True
        legacy_config.FAST_IGNORE_CHECKPOINTS = True
        # Cap batches if provided
        if 'FAST_TRAIN_CAP' in globals():
            legacy_config.FAST_MAX_TRAIN_BATCHES_PER_EPOCH = FAST_TRAIN_CAP
        if 'FAST_VAL_CAP' in globals():
            legacy_config.FAST_MAX_VAL_BATCHES_PER_EPOCH = FAST_VAL_CAP
        apply_fast_mode_overrides(legacy_config)
    
    # Safety: disable torch.compile on CUDA if Triton is missing
    import importlib.util
    if getattr(legacy_config, 'USE_TORCH_COMPILE', False) and torch.cuda.is_available():
        if importlib.util.find_spec('triton') is None:
            print("[compile] Triton missing at runtime; disabling torch.compile for this session")
            legacy_config.USE_TORCH_COMPILE = False
    
    # Display key settings (robust to both config styles)
    print("Config summary:")
    summary_fields = {
        'PATCH_SIZE': getattr(legacy_config, 'PATCH_SIZE', getattr(getattr(legacy_config, 'data', object()), 'patch_size', 'n/a')),
        'BATCH_SIZE': getattr(legacy_config, 'BATCH_SIZE', getattr(getattr(legacy_config, 'model', object()), 'batch_size', 'n/a')),
        'N_FOLDS': getattr(legacy_config, 'N_FOLDS', getattr(getattr(legacy_config, 'training', object()), 'cross_validation_folds', 'n/a')),
        'PATCHES_PER_SLIDE_P1': getattr(legacy_config, 'PATCHES_PER_SLIDE_PHASE1', getattr(getattr(legacy_config, 'training', object()), 'patches_per_slide_phase1', 'n/a')),
        'PATCHES_PER_SLIDE_P2': getattr(legacy_config, 'PATCHES_PER_SLIDE_PHASE2', getattr(getattr(legacy_config, 'training', object()), 'patches_per_slide_phase2', 'n/a')),
        'PATCHES_PER_SLIDE_SEG': getattr(legacy_config, 'PATCHES_PER_SLIDE_SEG', getattr(getattr(legacy_config, 'training', object()), 'patches_per_slide_seg', 'n/a')),
        'USE_TORCH_COMPILE': getattr(legacy_config, 'USE_TORCH_COMPILE', False),
        'USE_CHANNELS_LAST': getattr(legacy_config, 'USE_CHANNELS_LAST', False),
        'AMP_DTYPE': str(getattr(legacy_config, 'AMP_DTYPE', 'n/a')),
        'USE_FUSED_ADAMW': getattr(legacy_config, 'USE_FUSED_ADAMW', False),
        'ZERO_SET_TO_NONE': getattr(legacy_config, 'ZERO_SET_TO_NONE', True),
        'REQUIRE_ROI_FOR_PHASE1': getattr(legacy_config, 'REQUIRE_ROI_FOR_PHASE1', True),
        'ALLOW_FALLBACK_OUTSIDE_ROI': getattr(legacy_config, 'ALLOW_FALLBACK_OUTSIDE_ROI', False),
        'FAST_MAX_TRAIN_BATCHES_PER_EPOCH': getattr(legacy_config, 'MAX_TRAIN_BATCHES_PER_EPOCH', None),
        'FAST_MAX_VAL_BATCHES_PER_EPOCH': getattr(legacy_config, 'MAX_VAL_BATCHES_PER_EPOCH', None),
    }
    print(summary_fields)
    
    # Run Phase 1 for a chosen fold (default 0)
    fold = 0
    from torch.utils.tensorboard import SummaryWriter
    from pathlib import Path
    tb_dir = Path(legacy_config.LOG_DIR) / "tensorboard"
    tb_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(tb_dir / f"phase1_fold{fold}"))
    
    best_auc = train_phase1(legacy_config, fold=fold, writer=writer)
    print(f"Phase 1 completed. Best AUC: {best_auc:.4f}")
    
except Exception as e:
    print(f"Phase 1 failed: {e}")
    import traceback
    traceback.print_exc()

Phase 1: ROI-supervised classification
[FAST] Fast mode enabled: INPUT_SIZE=256, EPOCHS(P1/P2)=5/5, PATCHES/SLIDE(P1/P2/Seg)=32/64/16, SEG_PATCH=192, OTSU=False, CKPT=False
Config summary:
{'PATCH_SIZE': 512, 'BATCH_SIZE': 16, 'N_FOLDS': 5, 'PATCHES_PER_SLIDE_P1': 32, 'PATCHES_PER_SLIDE_P2': 64, 'PATCHES_PER_SLIDE_SEG': 16, 'USE_TORCH_COMPILE': True, 'USE_CHANNELS_LAST': True, 'AMP_DTYPE': 'torch.bfloat16', 'USE_FUSED_ADAMW': True, 'ZERO_SET_TO_NONE': True, 'REQUIRE_ROI_FOR_PHASE1': True, 'ALLOW_FALLBACK_OUTSIDE_ROI': False, 'FAST_MAX_TRAIN_BATCHES_PER_EPOCH': 4, 'FAST_MAX_VAL_BATCHES_PER_EPOCH': 2}
[Data] Slides with ROI annotations: 187/192
[ROI] Phase 1 ROI-only mode: filtered train 153->148, val 39->39
[Data] Slides with ROI annotations: 187/192
[ROI] Phase 1 ROI-only mode: filtered train 153->148, val 39->39


wandb: You can find your API key in your browser here: http://localhost:8080/authorize?ref=models
wandb: Appending key for localhost:8080 to your netrc file: C:\Users\tanth\_netrc
wandb: Currently logged in as: thanakornbua to http://localhost:8080. Use `wandb login --relogin` to force relogin


## Phase 2 — Multiple instance learning (MIL) fine-tuning

Objective: fine-tune an attention-based MIL classifier using slide-level labels, initialized from Phase 1 ROI-supervised features.

- Initialization: load the best Phase 1 checkpoint; freeze early layers; tune the attention head and classifier.
- Data: group ROI-derived patches into bags; optional tissue masking to exclude background.
- Outputs: best-performing MIL checkpoint, attention visualizations, and tracked metrics.
- Metrics: bag-level ROC-AUC, precision/recall/F1, confusion matrix; learning dynamics.
- Reproducibility: all configs and artifacts are versioned; seeds fixed where supported.
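
To make the bag-level aggregation concrete, here is a framework-free sketch of attention pooling in the form commonly used for attention-based MIL: each instance embedding `h_k` receives a score `w · tanh(V h_k)`, the scores are normalised with a softmax, and the bag embedding is the weighted sum of instances. The names `attention_pool`, `V`, and `w` are assumptions for this example; the notebook's `AttentionMIL` module is not shown here and may use a gated or otherwise different variant built on PyTorch tensors.

```python
import math
from typing import List, Sequence, Tuple


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(
    bag: Sequence[Sequence[float]],
    V: Sequence[Sequence[float]],
    w: Sequence[float],
) -> Tuple[List[float], List[float]]:
    """Pool K instance embeddings (dim D) into one bag embedding.

    V is an H x D projection and w an H-vector; both would be learned
    parameters in a real model and are passed in explicitly here.
    """
    scores = []
    for h in bag:
        hidden = [math.tanh(sum(v * x for v, x in zip(row, h))) for row in V]
        scores.append(sum(w_i * u_i for w_i, u_i in zip(w, hidden)))
    weights = softmax(scores)  # one attention weight per instance, summing to 1
    dim = len(bag[0])
    pooled = [sum(a * h[d] for a, h in zip(weights, bag)) for d in range(dim)]
    return pooled, weights
```

The returned `weights` are the per-patch attention values that an attention visualization would map back onto patch locations.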

In [9]:
# Phase 2: Multiple instance learning (MIL) fine-tuning

print("Phase 2: MIL fine-tuning with frozen backbone")
print("=" * 48)

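try:
    # Coerce pipeline-style configs to the training Config so attributes such as
    # EPOCHS_PHASE2 exist even when the Phase 1 cell was not run in this session
    from scripts.train import coerce_to_train_config
    legacy_config = coerce_to_train_config(legacy_config)

    # Ensure fast mode carries over to phase 2 as well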
    try:
        from scripts.train import apply_fast_mode_overrides
        if 'POC_FAST_MODE' in globals() and POC_FAST_MODE:
            legacy_config.FAST_MODE = True
            legacy_config.FAST_IGNORE_CHECKPOINTS = True
            apply_fast_mode_overrides(legacy_config)
    except Exception as e:
        print(f"[FAST] Note: could not re-apply fast mode for Phase 2: {e}")

    print("Configuration:")
    print(f"  Epochs: {legacy_config.EPOCHS_PHASE2}")
    print(f"  Learning rate: {legacy_config.LR_PHASE2}")
    print(f"  Patches per slide: {legacy_config.PATCHES_PER_SLIDE_PHASE2}")
    print(f"  Otsu tissue masking: {'Enabled' if legacy_config.USE_OTSU_TISSUE_MASK else 'Disabled'}")
    print("  Strategy: Freeze early layers; fine-tune attention and classifier")

    # Check if Phase 1 model exists
    phase1_model_path = legacy_config.CHECKPOINT_DIR / "phase1_fold0_best.pth"
    if not phase1_model_path.exists():
        print("Phase 1 model not found. Skipping Phase 2.")
        print("Run Phase 1 training first to generate the initialization model.")
    else:
        print("Starting MIL fine-tuning…")
        print("Using ROI-trained Phase 1 model for initialization.")

        # Run Phase 2 training (without TensorBoard writer)
        train_phase2(legacy_config, fold=0, writer=None)

        print("MIL fine-tuning completed.")

        # Generate explanations for Phase 2 model
        phase2_model_path = legacy_config.CHECKPOINT_DIR / "phase2_fold0_best.pth"
        if phase2_model_path.exists():
            print("Generating MIL attention/Grad-CAM visualizations…")
            explain_predictions(legacy_config, str(phase2_model_path), fold=0, num_samples=3)
            print("MIL explanations generated.")

except KeyboardInterrupt:
    print("Phase 2 training interrupted by user.")
except Exception as e:
    print(f"Phase 2 training failed: {e}")
    import traceback
    traceback.print_exc()
finally:
    print("Phase 2 training session complete.")


Phase 2: MIL fine-tuning with frozen backbone
Configuration:
Phase 2 training failed: 'PipelineConfig' object has no attribute 'EPOCHS_PHASE2'
Phase 2 training session complete.


Traceback (most recent call last):
  File "C:\Users\tanth\AppData\Local\Temp\ipykernel_49696\3930901177.py", line 18, in <module>
    print(f"  Epochs: {legacy_config.EPOCHS_PHASE2}")
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'PipelineConfig' object has no attribute 'EPOCHS_PHASE2'
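
The `AttributeError` above is what one would expect when a nested, lower-case pipeline config reaches code that reads flat upper-case attributes. The project's `coerce_to_train_config` is the intended adapter, but its internals are not shown in this notebook, so the following is only a hypothetical illustration of such a mapping. The classes `ModelCfg`, `TrainingCfg`, `PipelineCfg` and the function `to_flat_config` are stand-ins invented for this sketch, and reusing one `num_epochs` and `learning_rate` for Phase 2 is an assumption.

```python
from dataclasses import dataclass, field
from types import SimpleNamespace


@dataclass
class ModelCfg:  # hypothetical stand-in for a nested model config
    batch_size: int = 16
    num_epochs: int = 50
    learning_rate: float = 1e-4


@dataclass
class TrainingCfg:  # hypothetical stand-in for a nested training config
    patches_per_slide_phase2: int = 200


@dataclass
class PipelineCfg:
    model: ModelCfg = field(default_factory=ModelCfg)
    training: TrainingCfg = field(default_factory=TrainingCfg)


def to_flat_config(cfg: PipelineCfg) -> SimpleNamespace:
    """Map nested fields to the flat UPPER_CASE attributes the phase cells read."""
    return SimpleNamespace(
        BATCH_SIZE=cfg.model.batch_size,
        # Assumption: Phase 2 reuses the single epoch count and learning rate
        EPOCHS_PHASE2=cfg.model.num_epochs,
        LR_PHASE2=cfg.model.learning_rate,
        PATCHES_PER_SLIDE_PHASE2=cfg.training.patches_per_slide_phase2,
    )
```

With an adapter of this kind applied first, a lookup such as `flat.EPOCHS_PHASE2` succeeds, whereas the same attribute on the nested object does not exist.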


## Phase 3 — ROI-derived segmentation

Objective: train a U-Net–based segmentation model to delineate tissue regions using masks derived from ROI annotations.

- Architecture: U-Net (ResNet backbone).
- Data: convert ROI polygons to binary masks; use patch-based sampling.
- Augmentation: elastic deformation and other MONAI transforms (configurable).
- Outputs: best-performing segmentation checkpoint and selected overlays.
- Metrics: Dice, IoU, pixel accuracy, and per-class breakdown.
- Notes: Segmentation complements classification by localizing ROI-consistent patterns.
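
The two ingredients named above, polygon-to-mask conversion and overlap metrics, can be sketched without any imaging library. `polygon_to_mask` marks a pixel as foreground when its centre lies inside the polygon (even-odd scanline fill), and `dice_iou` computes Dice and IoU for binary masks. Both function names are hypothetical, masks are plain nested lists of 0/1, and a production pipeline would more likely rasterise with OpenCV or PIL and compute metrics on tensors.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Mask = List[List[int]]


def polygon_to_mask(polygon: Sequence[Point], height: int, width: int) -> Mask:
    """Rasterise a polygon into a binary mask using pixel centres."""
    mask = [[0] * width for _ in range(height)]
    n = len(polygon)
    for row in range(height):
        y = row + 0.5
        # x-coordinates where polygon edges cross this scanline
        crossings = []
        for i in range(n):
            x1, y1 = polygon[i]
            x2, y2 = polygon[(i + 1) % n]
            if (y1 > y) != (y2 > y):
                crossings.append(x1 + (y - y1) * (x2 - x1) / (y2 - y1))
        crossings.sort()
        # Fill between successive pairs of crossings (even-odd rule)
        for left, right in zip(crossings[0::2], crossings[1::2]):
            for col in range(width):
                if left <= col + 0.5 < right:
                    mask[row][col] = 1
    return mask


def dice_iou(pred: Mask, target: Mask, eps: float = 1e-7) -> Tuple[float, float]:
    """Dice = 2|A∩B| / (|A|+|B|), IoU = |A∩B| / |A∪B| for binary masks.

    eps avoids division by zero; two empty masks therefore score 1.0.
    """
    inter = sum(p & t for pr, tr in zip(pred, target) for p, t in zip(pr, tr))
    pred_sum = sum(map(sum, pred))
    target_sum = sum(map(sum, target))
    union = pred_sum + target_sum - inter
    dice = (2 * inter + eps) / (pred_sum + target_sum + eps)
    iou = (inter + eps) / (union + eps)
    return dice, iou
```

Dice is never lower than IoU for the same pair of masks, so reporting both, as this phase does, gives a more complete picture than either alone.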

In [10]:
# Phase 3: Segmentation training (ROI-derived masks)
print("Phase 3: Segmentation training with U-Net")
print("Segmentation masks are derived from ROI annotations")
print("=" * 48)

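try:
    # Coerce pipeline-style configs to the training Config so attributes such as
    # PATCH_SIZE_SEG exist even when earlier phase cells were not run in this session
    from scripts.train import coerce_to_train_config
    legacy_config = coerce_to_train_config(legacy_config)

    # Ensure fast mode also affects segmentation stage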
    try:
        from scripts.train import apply_fast_mode_overrides
        if 'POC_FAST_MODE' in globals() and POC_FAST_MODE:
            legacy_config.FAST_MODE = True
            legacy_config.FAST_IGNORE_CHECKPOINTS = True
            apply_fast_mode_overrides(legacy_config)
    except Exception as e:
        print(f"[FAST] Note: could not re-apply fast mode for Segmentation: {e}")

    # Safety: disable torch.compile if Triton is missing on CUDA
    import importlib.util
    if getattr(legacy_config, 'USE_TORCH_COMPILE', False) and torch.cuda.is_available() and importlib.util.find_spec('triton') is None:
        print("[compile] Triton missing at runtime; disabling torch.compile for Segmentation")
        legacy_config.USE_TORCH_COMPILE = False

    # Enforce ROI-only segmentation patches (positive-mask requirement)
    legacy_config.REQUIRE_ROI_FOR_SEGMENTATION = True
    legacy_config.REQUIRE_POSITIVE_MASK_PATCHES = True
    legacy_config.POS_MASK_MAX_ATTEMPTS = max(getattr(legacy_config, 'POS_MASK_MAX_ATTEMPTS', 20), 20)
    print("Segmentation ROI-only: Enabled (patches must contain positive mask pixels)")

    print("Configuration:")
    print(f"  Architecture: U-Net with ResNet backbone")
    print(f"  Patch size: {legacy_config.PATCH_SIZE_SEG}")
    print(f"  Batch size: {legacy_config.BATCH_SIZE}")
    print(f"  Elastic deformation: {legacy_config.ELASTIC_DEFORM_PROB:.0%} probability")
    print("  Augmentations: MONAI transforms")
    print("  Task: Binary segmentation from ROI-derived masks")

    print("\nStarting segmentation training…")
    print("Using ROI annotations to generate ground truth masks.")

    # Run Segmentation training (without TensorBoard writer)
    train_segmentation(legacy_config, fold=0, writer=None)

    print("Segmentation training completed.")

    # Check for segmentation model
    seg_model_path = legacy_config.CHECKPOINT_DIR / "segmentation_fold0_best.pth"
    if seg_model_path.exists():
        print("Segmentation model saved.")
        print(f"  Model path: {seg_model_path}")
        print("  Model trained to segment ROI-consistent regions.")

except KeyboardInterrupt:
    print("Segmentation training interrupted by user.")
except Exception as e:
    print(f"Segmentation training failed: {e}")
    import traceback
    traceback.print_exc()
finally:
    print("Segmentation training session complete.")


Phase 3: Segmentation training with U-Net
Segmentation masks are derived from ROI annotations
Segmentation ROI-only: Enabled (patches must contain positive mask pixels)
Configuration:
  Architecture: U-Net with ResNet backbone
Segmentation training failed: 'PipelineConfig' object has no attribute 'PATCH_SIZE_SEG'
Segmentation training session complete.


Traceback (most recent call last):
  File "C:\Users\tanth\AppData\Local\Temp\ipykernel_49696\3275890057.py", line 31, in <module>
    print(f"  Patch size: {legacy_config.PATCH_SIZE_SEG}")
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'PipelineConfig' object has no attribute 'PATCH_SIZE_SEG'


In [11]:
# Hyperparameter Optimization with Optuna
print("🔧 Hyperparameter Optimization with Bayesian Optimization")
print("=" * 60)

# Configuration for optimization
ENABLE_OPTIMIZATION = False  # Set to True to run optimization
N_TRIALS = 20  # Reduced for notebook demo

if ENABLE_OPTIMIZATION:
    try:
        print("🎯 Optuna configuration:")
        print(f"   - Number of trials: {N_TRIALS}")
        print("   - Sampler: TPE (Tree-structured Parzen Estimator)")
        print("   - Objective: Maximize F1-score")
        print("   - Pruning: Enabled for early stopping")
        
        print("🚀 Starting hyperparameter optimization...")
        
        # Run Optuna optimization
        optimized_config = optimize_hyperparameters(legacy_config, n_trials=N_TRIALS)
        
        print("✅ Hyperparameter optimization completed!")
        
        # Display optimized parameters
        print("\n📊 Optimized Hyperparameters:")
        print(f"   - Learning Rate: {optimized_config.LR_PHASE1}")
        print(f"   - Batch Size: {optimized_config.BATCH_SIZE}")
        print(f"   - Weight Decay: {optimized_config.WEIGHT_DECAY}")
        
        # Save optimized configuration
        import json
        optimized_params = {
            'learning_rate': optimized_config.LR_PHASE1,
            'batch_size': optimized_config.BATCH_SIZE,
            'weight_decay': optimized_config.WEIGHT_DECAY,
            'n_trials_requested': N_TRIALS  # requested budget; pruned trials count toward it
        }
        
        with open('optimized_config_notebook.json', 'w') as f:
            json.dump(optimized_params, f, indent=2, default=str)
        
        print("💾 Optimized configuration saved to: optimized_config_notebook.json")
        
    except Exception as e:
        print(f"❌ Hyperparameter optimization failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print("⚠️ Hyperparameter optimization is disabled")
    print("   Set ENABLE_OPTIMIZATION = True to run optimization")
    print("   Note: This may take considerable time depending on N_TRIALS")

🔧 Hyperparameter Optimization with Bayesian Optimization
⚠️ Hyperparameter optimization is disabled
   Set ENABLE_OPTIMIZATION = True to run optimization
   Note: This may take considerable time depending on N_TRIALS
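
The cell above delegates to the project's `optimize_hyperparameters`, whose internals are not shown here. As a rough sketch of the setup it describes (TPE sampler, maximize a score, pruning enabled), the following uses Optuna's public API with a toy objective. Everything named `surrogate_f1`, `run_study`, and `summarize`, as well as the search ranges, is an assumption made for illustration and not the project's configuration.

```python
import math


def surrogate_f1(lr, weight_decay, batch_size, epoch):
    """Toy stand-in for validation F1; replace with a real train/validate step."""
    lr_term = math.exp(-(math.log10(lr) + 3.0) ** 2)                   # peaks at lr = 1e-3
    wd_term = math.exp(-(math.log10(weight_decay) + 4.0) ** 2 / 4.0)   # peaks at 1e-4
    bs_term = 1.0 if batch_size == 32 else 0.9
    warmup = 1.0 - 0.5 ** (epoch + 1)                                  # improves with epochs
    return lr_term * wd_term * bs_term * warmup


def run_study(n_trials=20, n_epochs=5, seed=42):
    """Run a TPE search with median pruning over assumed search ranges."""
    import optuna  # imported lazily so the helper above works without Optuna

    def objective(trial):
        lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
        wd = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
        bs = trial.suggest_categorical("batch_size", [16, 32, 64])
        score = 0.0
        for epoch in range(n_epochs):
            score = surrogate_f1(lr, wd, bs, epoch)
            trial.report(score, step=epoch)   # intermediate value for the pruner
            if trial.should_prune():
                raise optuna.TrialPruned()
        return score

    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=seed),
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=1),
    )
    study.optimize(objective, n_trials=n_trials)
    return study


def summarize(study):
    """Collect best parameters plus how many trials actually ran to completion."""
    import optuna

    done = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    return {**study.best_params, "best_value": study.best_value,
            "trials_completed": len(done), "trials_total": len(study.trials)}
```

With Optuna installed, `summarize(run_study(n_trials=20))` returns a dictionary that could be written to JSON. Counting trials by state matters when pruning is on, because the number of completed trials can be smaller than the requested budget.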


In [12]:
# Results analysis and visualization (ROI-focused)
print("Results analysis and visualization")
print("ROI-supervised training results")
print("=" * 36)

import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import numpy as np
import pandas as pd

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")

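# Check for training results. The recorded run shows that this config object
# lacks CHECKPOINT_DIR, so fall back to the default output locations listed
# at the end of this notebook (checkpoints/, output/logs).
checkpoint_dir = Path(getattr(legacy_config, "CHECKPOINT_DIR", "checkpoints"))
log_dir = Path(getattr(legacy_config, "LOG_DIR", Path("output") / "logs"))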

print("Checking available model checkpoints…")

# Display available models
model_files = list(checkpoint_dir.glob("*.pth"))
if model_files:
    print(f"Found {len(model_files)} checkpoint(s):")
    for model_file in model_files:
        size_mb = model_file.stat().st_size / (1024 * 1024)
        print(f"  - {model_file.name} ({size_mb:.1f} MB)")
        if "phase1" in model_file.name:
            print("    Type: ROI-supervised classification model")
        elif "phase2" in model_file.name:
            print("    Type: MIL model (initialized from ROI-trained base)")
        elif "segmentation" in model_file.name:
            print("    Type: Segmentation model (ROI-derived masks)")
else:
    print("No checkpoints found. Run training phases first.")

# Check for training logs
if log_dir.exists():
    log_files = list(log_dir.glob("*.log"))
    if log_files:
        print(f"\nFound {len(log_files)} training log file(s)")
        for log_file in log_files:
            print(f"  - {log_file.name}")
    else:
        print("\nNo log files found.")

# Display augmentation examples (synthetic ROI-like image)
print("\nTesting augmentation pipeline (synthetic example)…")
try:
    from scripts.augmentations import get_classification_transforms
    try:
        from scripts.augmentations import elastic_deformation, NativeStainNormalizer
        native_stain_available = True
    except ImportError:
        print("Native stain normalization not available")
        native_stain_available = False

    # Create a sample ROI-like image
    sample_image = np.random.randint(50, 200, (256, 256, 3), dtype=np.uint8)
    for _ in range(15):
        x, y = np.random.randint(20, 236, 2)
        size = np.random.randint(4, 12)
        Y, X = np.ogrid[:256, :256]
        mask = (X - x)**2 + (Y - y)**2 <= size**2
        sample_image[mask, 0] = np.clip(120 + np.random.randint(-20, 20), 0, 255)
        sample_image[mask, 1] = np.clip(100 + np.random.randint(-15, 15), 0, 255)
        sample_image[mask, 2] = np.clip(180 + np.random.randint(-15, 15), 0, 255)

    results = [sample_image]
    titles = ["Original ROI-like image"]

    if native_stain_available:
        try:
            elastic_result = elastic_deformation(sample_image, alpha=100, sigma=10)
            results.append(elastic_result)
            titles.append("Elastic deformation")

            normalizer = NativeStainNormalizer()
            stain_result = normalizer.fit_transform([sample_image])[0]
            results.append(stain_result)
            titles.append("Native H&E normalization")
        except Exception as e:
            print(f"Augmentation error: {e}")

    n_images = len(results)
    fig, axes = plt.subplots(1, n_images, figsize=(5*n_images, 5))
    if n_images == 1:
        axes = [axes]
    for i, (img, title) in enumerate(zip(results, titles)):
        axes[i].imshow(img)
        axes[i].set_title(title, fontweight='bold')
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()
    print(f"Augmentation preview rendered with {n_images} panel(s).")

except Exception as e:
    print(f"Augmentation test failed: {e}")

# ROI training summary
print("\nROI-focused pipeline summary")
print("=" * 28)
print("Phase 1: ROI-supervised classification — patches from XML annotations")
print("Phase 2: MIL fine-tuning — initialized from Phase 1")
print("Phase 3: Segmentation — ROI regions converted to masks")
print("Augmentation pipeline — native H&E and elastic deformation")
print("Interpretability — Grad-CAM/attention overlays on selected samples")

# ROI annotation coverage
metadata_file = legacy_config.DATA_DIR / "metadata.csv"
if metadata_file.exists():
    df = pd.read_csv(metadata_file)
    total_slides = len(df)
    roi_slides = int(df['has_annotation'].eq(True).sum())
    coverage = (roi_slides / total_slides) * 100 if total_slides > 0 else 0
    print("\nROI annotation coverage:")
    print(f"  Total slides: {total_slides}")
    print(f"  Slides with ROI annotations: {roi_slides}")
    print(f"  Coverage: {coverage:.1f}%")

# Next steps
print("\nNext steps:")
print("1. Review wandb runs for metrics and artifacts.")
print("2. Run inference on held-out test data.")
print("3. Generate Grad-CAM or attention visualizations for qualitative review.")
print("4. Validate performance on ROI-annotated regions.")
print("5. Compare ROI-focused vs. random patch sampling if applicable.")

Results analysis and visualization
ROI-supervised training results


AttributeError: 'PipelineConfig' object has no attribute 'CHECKPOINT_DIR'
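
The augmentation check in the cell above calls the project's `elastic_deformation(sample_image, alpha=100, sigma=10)`, but its implementation is not part of this notebook. To make the roles of `alpha` and `sigma` concrete, here is a sketch of one common recipe: draw random per-pixel displacements, smooth them with a Gaussian of width `sigma`, scale by `alpha`, and resample the image. The function name `elastic_deformation_sketch`, the `seed` argument, the uint8 input assumption, and the use of SciPy are all assumptions; the project's version may differ.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_deformation_sketch(image, alpha=100.0, sigma=10.0, seed=None):
    """Warp an HxWxC uint8 image with a smoothed random displacement field."""
    rng = np.random.default_rng(seed)
    h, w = image.shape[:2]

    # Random displacements in [-1, 1], smoothed so neighbouring pixels move
    # together, then scaled by alpha to set the deformation strength.
    dx = gaussian_filter(rng.uniform(-1, 1, (h, w)), sigma, mode="reflect") * alpha
    dy = gaussian_filter(rng.uniform(-1, 1, (h, w)), sigma, mode="reflect") * alpha

    # Sampling coordinates: the identity grid plus the displacement field.
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    coords = np.array([yy + dy, xx + dx])

    out = np.empty_like(image)
    for c in range(image.shape[2]):
        # Bilinear resampling of each channel at the displaced coordinates.
        warped = map_coordinates(image[..., c].astype(np.float32), coords,
                                 order=1, mode="reflect")
        out[..., c] = np.clip(np.rint(warped), 0, 255).astype(image.dtype)
    return out


# Synthetic input, analogous to the random sample image used in the cell above.
demo = np.random.default_rng(0).integers(50, 200, (64, 64, 3), dtype=np.uint8)
warped_demo = elastic_deformation_sketch(demo, alpha=100, sigma=10, seed=1)
```

A larger `sigma` gives smoother, more global warps, while a larger `alpha` moves pixels further. For segmentation, the same displacement field would need to be applied to the mask so that image and label stay aligned.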

In [None]:
"""
Experiment completion and resource management
============================================

Final cleanup and structured summary of the experimental pipeline. Ensures proper resource management for reproducible runs.
"""

print("ROI-supervised HER2 classification pipeline complete")
print("=" * 48)

# Perform systematic cleanup
print("Performing resource cleanup…")

try:
    # Memory optimization
    cleanup_variables = ['fig', 'axes', 'sample_image', 'results', 'titles', 'df', 'roi_slides']
    cleaned_count = 0
    
    # Notebook variables live in the module namespace, so deleting from
    # globals() is sufficient; mutating locals() is not reliable.
    for var_name in cleanup_variables:
        if var_name in globals():
            del globals()[var_name]
            cleaned_count += 1
    
    # Force garbage collection
    import gc
    gc.collect()
    
    # GPU memory management
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU memory cache cleared ({gpu_memory:.1f} GB total)")
    
    print(f"Resource cleanup complete ({cleaned_count} variables cleared)")
    
except Exception as e:
    print(f"Cleanup note: {e}")

# Pipeline component summary (static text; it does not check whether each phase succeeded)
print("\nPipeline component summary")
print("  ROI-supervised training framework initialized")
print("  Native H&E stain normalization pipeline available")
print("  Multi-phase training architecture configured")
print("  Experiment tracking enabled (wandb)")
print("  Model interpretability integrated (Grad-CAM)")
print("  Analysis tools prepared for publication")

# Output directory summary
try:
    from pathlib import Path
    checkpoints_dir = Path("checkpoints")
    output_dir = Path("output")
    
    print("\nOutput locations:")
    print(f"  Model checkpoints: {checkpoints_dir}")
    print(f"  Training logs: {output_dir / 'logs'}")
    print(f"  Visualizations: {output_dir / 'logs' / 'explanations'}")
    print("  Experiment data: Weights & Biases dashboard")
    
except Exception:
    print("\nStandard output directories: checkpoints/, output/")

# Research and publication guidelines
print("\nResearch guidelines:")
print("  • Ensure proper data licensing and ethical approval")
print("  • Validate results with independent test datasets")
print("  • Report confidence intervals and statistical significance")
print("  • Document hardware specifications and runtime requirements")
print("  • Preserve experiment artifacts for reproducibility")

print("\nClinical validation recommendations:")
print("  • Collaborate with board-certified pathologists")
print("  • Perform inter-observer agreement studies")
print("  • Validate across multiple institutions")
print("  • Consider prospective studies when appropriate")

print("\nEnd of notebook. Confirm that each phase completed without errors before reporting results.")
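

The guideline above about reporting confidence intervals can be made concrete with a percentile bootstrap. The sketch below uses only the standard library; `bootstrap_ci` and `accuracy` are hypothetical helper names, and the labels are made-up values for illustration, not results from this pipeline.

```python
import random


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def bootstrap_ci(y_true, y_pred, metric, n_resamples=1000, alpha=0.05, seed=0):
    """Percentile bootstrap confidence interval for metric(y_true, y_pred)."""
    rng = random.Random(seed)
    n = len(y_true)
    scores = []
    for _ in range(n_resamples):
        # Resample cases with replacement and recompute the metric.
        idx = [rng.randrange(n) for _ in range(n)]
        scores.append(metric([y_true[i] for i in idx], [y_pred[i] for i in idx]))
    scores.sort()
    lo = scores[int((alpha / 2) * n_resamples)]
    hi = scores[min(n_resamples - 1, int((1 - alpha / 2) * n_resamples))]
    return lo, hi


# Made-up labels purely for illustration.
y_true = [0, 1, 1, 0, 1, 0, 1, 1, 0, 0] * 5
y_pred = [0, 1, 0, 0, 1, 0, 1, 1, 1, 0] * 5
point = accuracy(y_true, y_pred)
low, high = bootstrap_ci(y_true, y_pred, accuracy)
print(f"accuracy = {point:.2f}, 95% CI = [{low:.2f}, {high:.2f}]")
```

For whole-slide data, patches from the same slide or patient are correlated, so resampling at the slide or patient level (rather than per patch) is usually the safer choice. Any metric with the same `(y_true, y_pred)` signature, such as F1, can be passed in place of `accuracy`.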