# Classification Training: ResNet-18 for Benign vs Malignant

**Goals:**
- Train ResNet-18 with 5-fold cross-validation (CV)
- Use focal loss to handle class imbalance
- Track metrics: ROC-AUC, balanced accuracy, sensitivity, specificity
- Apply temperature scaling for calibration
- Measure expected calibration error (ECE) before and after calibration

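For reference, these are the standard definitions behind the tracked metrics. They assume malignant is the positive class (an assumption about the label encoding), TP/FP/TN/FN are counted at a fixed decision threshold, and ECE is computed over $M$ equal-width confidence bins $B_m$ on $n$ samples:

```latex
\text{Sensitivity} = \frac{TP}{TP + FN}, \qquad
\text{Specificity} = \frac{TN}{TN + FP}, \qquad
\text{Balanced accuracy} = \tfrac{1}{2}\left(\text{Sensitivity} + \text{Specificity}\right)

\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n}\,\bigl|\mathrm{acc}(B_m) - \mathrm{conf}(B_m)\bigr|
```

Here $\mathrm{acc}(B_m)$ is the fraction of correct predictions in bin $m$ and $\mathrm{conf}(B_m)$ is the mean predicted confidence in that bin.
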
In [None]:
# Imports
import sys
sys.path.append('..')  # make the project's src/ package importable from this notebook's folder

import torch
import lightning as L
from pathlib import Path
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import roc_curve, auc, confusion_matrix

from src.utils import seed_everything
from src.datamodules.bus_uc import BusUcClsDataModule
from src.models.cls_resnet18 import ResNet18Classifier, LightningClsModel
from src.losses import FocalLoss
from src.metrics import compute_roc_auc, expected_calibration_error
from src.calibration import ModelWithTemperature

## 1. Load Configuration

In [None]:
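# Load the training configuration from ../configs/cls_resnet18.yaml
# (a context manager closes the file handle after parsing)
config_path = Path('../configs/cls_resnet18.yaml')
with config_path.open() as f:
    config = yaml.safe_load(f)
print(config)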

## 2. Data Preparation

In [None]:
# TODO: Initialize DataModule with WeightedRandomSampler
# datamodule = BusUcClsDataModule(...)
# datamodule.setup()

# TODO: Verify class distribution in train/val/test
# TODO: Visualize sample images from each class
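Whether `BusUcClsDataModule` builds its `WeightedRandomSampler` internally is not shown here. As a sanity check for the class-distribution TODO, here is a minimal sketch of inverse-frequency sample weights, assuming labels are encoded as 0 = benign and 1 = malignant. `class_counts` and `inverse_frequency_weights` are hypothetical helpers, not part of `src`:

```python
import numpy as np


def class_counts(labels, num_classes=2):
    """Count samples per class (assumed encoding: 0 = benign, 1 = malignant)."""
    return np.bincount(np.asarray(labels), minlength=num_classes)


def inverse_frequency_weights(labels, num_classes=2):
    """Per-sample weight = 1 / (number of samples in that sample's class).

    With these weights every class carries the same total weight, so a
    weighted sampler draws each class equally often in expectation.
    """
    labels = np.asarray(labels)
    counts = class_counts(labels, num_classes)
    class_weights = 1.0 / np.maximum(counts, 1)  # guard against an empty class
    return class_weights[labels]


# Toy labels for illustration only: 6 benign, 2 malignant.
toy_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1])
print(class_counts(toy_labels))  # [6 2]
weights = inverse_frequency_weights(toy_labels)
# Total weight per class is (up to rounding) equal.
print(weights[toy_labels == 0].sum(), weights[toy_labels == 1].sum())
```

In PyTorch these weights would be passed as `torch.utils.data.WeightedRandomSampler(weights=torch.as_tensor(weights, dtype=torch.double), num_samples=len(weights), replacement=True)`. Oversampling with replacement balances the batches, but malignant images are then repeated more often, so augmentation matters.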

## 3. Model Initialization

In [None]:
# TODO: Initialize ResNet-18 model
# model = ResNet18Classifier(pretrained=True, freeze_epochs=5)
# loss_fn = FocalLoss(gamma=1.5)
# lightning_model = LightningClsModel(model, loss_fn, ...)
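The implementation of `src.losses.FocalLoss` is not shown. To illustrate what `gamma=1.5` does, here is a minimal NumPy sketch of the binary focal loss of Lin et al. (2017), $\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log p_t$. It assumes the classifier emits one logit per image for the malignant class. The optional `alpha` weighting is an assumption; the project's loss may not include it:

```python
import numpy as np


def binary_focal_loss(logits, targets, gamma=1.5, alpha=None):
    """Mean binary focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t).

    logits:  raw scores for the positive (malignant) class, shape (N,)
    targets: 0/1 labels, shape (N,)
    alpha:   optional weight for the positive class; None disables it.
    """
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets, dtype=float)
    sign = 2.0 * targets - 1.0                    # +1 for malignant, -1 for benign
    log_p_t = -np.logaddexp(0.0, -sign * logits)  # stable log-probability of the true class
    p_t = np.exp(log_p_t)
    loss = -((1.0 - p_t) ** gamma) * log_p_t      # (1 - p_t)**gamma down-weights easy samples
    if alpha is not None:
        alpha_t = np.where(targets == 1, alpha, 1.0 - alpha)
        loss = alpha_t * loss
    return float(loss.mean())


logits = np.array([3.0, -2.0, 0.5, -0.1])
targets = np.array([1, 0, 0, 1])
bce = binary_focal_loss(logits, targets, gamma=0.0)    # gamma = 0 recovers plain BCE
focal = binary_focal_loss(logits, targets, gamma=1.5)
print(bce, focal)  # focal < bce: confident, correct predictions contribute less
```

Larger `gamma` shifts the loss toward hard, misclassified images. This complements the weighted sampler rather than replacing it, because the sampler rebalances classes while the focal term rebalances easy versus hard samples.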

## 4. Training (Single Fold)

In [None]:
# TODO: Set up Trainer with callbacks
# - ModelCheckpoint (save best by val_roc_auc)
# - EarlyStopping
# - LearningRateMonitor

# TODO: Train with frozen backbone layers first, then unfreeze
# trainer.fit(lightning_model, datamodule=datamodule)

## 5. Evaluation

In [None]:
# TODO: Evaluate on test set
# - Compute predictions and probabilities
# - Plot ROC curve
# - Compute confusion matrix
# - Calculate sensitivity, specificity, balanced accuracy

## 6. Calibration

In [None]:
# TODO: Temperature scaling
# model_with_temp = ModelWithTemperature(model)
# model_with_temp.fit_temperature(val_loader)

# TODO: Measure ECE before/after calibration
# TODO: Plot reliability diagram before/after

## 7. Cross-Validation (All Folds)

In [None]:
# TODO: Run training for all 5 folds
# - Collect metrics per fold
# - Aggregate: mean ± std
# - Save results to CSV
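It is not shown whether the DataModule defines the folds itself. Here is a minimal sketch of stratified 5-fold splitting and the mean ± std aggregation, assuming per-fold metrics come back as dicts such as those returned by an evaluation helper. `make_folds`, `summarize_folds`, and `save_fold_results` are hypothetical helpers. If BUS-UC holds several images per patient (not stated here), `sklearn.model_selection.StratifiedGroupKFold` with patient IDs as groups would keep a patient's images in one fold:

```python
import csv

import numpy as np
from sklearn.model_selection import StratifiedKFold


def make_folds(labels, n_splits=5, seed=42):
    """List of (train_idx, val_idx) pairs; each fold keeps the class ratio."""
    labels = np.asarray(labels)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return list(skf.split(np.zeros(len(labels)), labels))


def summarize_folds(fold_metrics):
    """{metric: (mean, sample std with ddof=1)} across folds; needs >= 2 folds."""
    keys = fold_metrics[0].keys()
    return {
        k: (float(np.mean([m[k] for m in fold_metrics])),
            float(np.std([m[k] for m in fold_metrics], ddof=1)))
        for k in keys
    }


def save_fold_results(fold_metrics, path):
    """Write one CSV row per fold, followed by 'mean' and 'std' rows."""
    summary = summarize_folds(fold_metrics)
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["fold", *fold_metrics[0].keys()])
        writer.writeheader()
        for i, m in enumerate(fold_metrics):
            writer.writerow({"fold": i, **m})
        writer.writerow({"fold": "mean", **{k: v[0] for k, v in summary.items()}})
        writer.writerow({"fold": "std", **{k: v[1] for k, v in summary.items()}})


# Placeholder numbers only, to show the output format (not real results).
dummy = [{"roc_auc": 0.90, "ece": 0.04}, {"roc_auc": 0.92, "ece": 0.06}]
for name, (mean, std) in summarize_folds(dummy).items():
    print(f"{name}: {mean:.3f} ± {std:.3f}")
```

The sample standard deviation (`ddof=1`) is used because five folds are a small sample. Whichever convention is chosen, it is worth stating next to the reported numbers.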

## 8. Error Analysis

In [None]:
# TODO: Analyze misclassified samples
# - Visualize false positives and false negatives
# - Examine prediction confidence distribution
# - Check if errors correlate with lesion size or other factors
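A minimal sketch of the error-analysis TODOs. It assumes `y_prob` is P(malignant) with a 0.5 threshold, and that a per-image lesion size (for example, mask area in pixels) is available. That input is hypothetical and is not produced by the cells above. `split_errors` and `error_size_correlation` are hypothetical helpers:

```python
import numpy as np


def split_errors(y_true, y_prob, threshold=0.5):
    """Indices of false positives and false negatives, most confident errors first."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob, dtype=float)
    y_pred = (y_prob >= threshold).astype(int)
    fp = np.flatnonzero((y_pred == 1) & (y_true == 0))
    fn = np.flatnonzero((y_pred == 0) & (y_true == 1))
    fp = fp[np.argsort(-y_prob[fp])]  # FP: highest P(malignant) first
    fn = fn[np.argsort(y_prob[fn])]   # FN: lowest P(malignant) first
    return fp, fn


def error_size_correlation(y_true, y_prob, lesion_size, threshold=0.5):
    """Point-biserial correlation between 'misclassified' (0/1) and lesion size.

    Returns nan (with a NumPy warning) if there are no errors or no size variation.
    """
    y_pred = (np.asarray(y_prob, dtype=float) >= threshold).astype(int)
    is_error = (y_pred != np.asarray(y_true)).astype(float)
    return float(np.corrcoef(is_error, np.asarray(lesion_size, dtype=float))[0, 1])


# Toy inputs for illustration only.
y_true = [0, 0, 1, 1, 0, 1]
y_prob = [0.7, 0.2, 0.3, 0.9, 0.6, 0.1]
fp_idx, fn_idx = split_errors(y_true, y_prob)
print(fp_idx, fn_idx)  # [0 4] [5 2]
print(error_size_correlation(y_true, y_prob, [1.0, 5.0, 1.5, 6.0, 2.0, 0.5]))
```

The sorted indices can drive an image grid of the worst false negatives, which matter most clinically. `plt.hist(y_prob[fp_idx])` and `plt.hist(y_prob[fn_idx])` give the confidence distributions. A negative correlation would suggest that errors concentrate on small lesions, but with few errors per fold the estimate is noisy.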

## Summary

TODO: Print final results:
- ROC-AUC (mean ± std)
- Balanced accuracy
- Sensitivity & Specificity
- ECE before/after calibration
- Quality gate: does the model reach ROC-AUC ≥ 0.90 and ECE ≤ 0.05?
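
A minimal sketch of the quality-gate check. It assumes the gate applies to the mean cross-validated ROC-AUC and the post-calibration ECE; the summary above does not say which aggregate is meant. `quality_gate` is a hypothetical helper:

```python
def quality_gate(roc_auc_mean, ece_after, min_auc=0.90, max_ece=0.05):
    """Return (passed, reasons) for the gate ROC-AUC >= min_auc and ECE <= max_ece."""
    reasons = []
    if roc_auc_mean < min_auc:
        reasons.append(f"ROC-AUC {roc_auc_mean:.3f} < {min_auc:.2f}")
    if ece_after > max_ece:
        reasons.append(f"ECE {ece_after:.3f} > {max_ece:.2f}")
    return (not reasons), reasons


# Placeholder inputs only (not real results).
passed, reasons = quality_gate(0.88, 0.07)
print("PASS" if passed else "FAIL: " + "; ".join(reasons))
```

Returning the failure reasons, rather than a bare boolean, makes the printed summary state which threshold was missed.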