# Fine-tune MedSigLIP for Surgical Wound Assessment — Run 2

**MamaGuard — Wound Assessment Component (v2: Expanded Unfreeze + Differential LR)**

Fine-tunes `google/medsiglip-448` vision encoder on the [SurgWound dataset](https://huggingface.co/datasets/xuxuxuxuxu/SurgWound)
for **6 binary clinical labels**: healing status, erythema, edema, infection risk, urgency, and exudate.

The fine-tuned model serves as a structured clinical signal provider in a two-stage pipeline:
```
Wound photo → MedSigLIP (this model) → structured scores → Gemini/MedGemma orchestrator → empathetic response
```

### Changes from Run 1 → Run 2
Run 1 (40 total optimizer steps, N_UNFREEZE=4) showed clear underfitting: inference scores clustered 0.45–0.59 and the model was still improving at epoch 5 with no plateau.

| Parameter | Run 1 | Run 2 | Effect |
|---|---|---|---|
| `N_UNFREEZE` | 4 | **8** | ~2× trainable capacity (~14% → ~28%) |
| `GRAD_ACCUM` | 16 | **4** | 8 → 30 optimizer steps/epoch |
| `EPOCHS` | 5 | **10** | 40 → **300** total optimizer steps (7.5×) |
| Learning rate | single 5e-5 | **differential** backbone=1.5e-5 / head=8e-5 | Preserves pretrained features, fast head convergence |
| Threshold | fixed 0.5 | **per-label (Youden's J)** | Corrects miscalibration (e.g., healing sens=0.84/spec=0.26) |
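
The step counts in the table follow from ceiling division at two levels (forward passes per epoch, then optimizer steps per epoch). A minimal sketch, assuming 480 training images and a per-device batch size of 4 in both runs (the Run 1 batch size is inferred from its effective batch of 64 with `GRAD_ACCUM=16`); `optimizer_steps` is a hypothetical helper, not part of the notebook:

```python
import math


def optimizer_steps(n_train: int, batch_size: int, grad_accum: int, epochs: int) -> tuple[int, int]:
    """Return (optimizer steps per epoch, total optimizer steps)."""
    forward_passes = math.ceil(n_train / batch_size)           # mini-batches per epoch
    steps_per_epoch = math.ceil(forward_passes / grad_accum)   # one step per accumulated group
    return steps_per_epoch, steps_per_epoch * epochs


run1 = optimizer_steps(480, batch_size=4, grad_accum=16, epochs=5)
run2 = optimizer_steps(480, batch_size=4, grad_accum=4, epochs=10)
print(f"Run 1: {run1[0]} steps/epoch, {run1[1]} total")
print(f"Run 2: {run2[0]} steps/epoch, {run2[1]} total")
```

This mirrors the sanity check printed in the configuration cell; the step count reported by `Trainer` itself may round a trailing partial accumulation group differently.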

### Key Design Decisions
- **Expanded selective freezing**: The last **8** encoder blocks plus the classification head are trainable (~28% of params); unfreezing more blocks adds capacity while still fitting in T4 VRAM
- **Differential learning rate**: Backbone blocks update at `BACKBONE_LR=1.5e-5` (gentle, preserves pretrained SigLIP features); classifier head at `HEAD_LR=8e-5` (fast learning from random initialization)
- **Masked BCE loss**: 3 of 6 labels have MISSING values — loss is zeroed out for those entries instead of dropping entire samples
- **Light augmentation**: Horizontal flip + rotation + color jitter to compensate for small dataset (480 train images)
- **eval_loss for model selection**: Val set has only 69 images — per-label AUC too noisy for checkpoint comparison
- **Per-label threshold tuning**: Youden's J (J = sensitivity + specificity − 1) on validation set replaces a fixed threshold=0.5 after training
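
Per-label threshold tuning with Youden's J can be sketched as below. This is an illustrative pure-Python version under stated assumptions, not the notebook's implementation: `youden_threshold` is a hypothetical helper, candidate thresholds are midpoints between consecutive distinct scores, predictions use `score > threshold` (matching the rule in `compute_full_metrics`), and entries labelled `-1` (MISSING) are skipped.

```python
def youden_threshold(y_true, y_score):
    """Return (threshold, J) maximising J = sensitivity + specificity - 1."""
    valid = [(s, int(y)) for y, s in zip(y_true, y_score) if y >= 0]  # drop MISSING
    n_pos = sum(y for _, y in valid)
    n_neg = len(valid) - n_pos
    if n_pos == 0 or n_neg == 0:
        return 0.5, float("nan")          # J is undefined with a single class
    scores = sorted({s for s, _ in valid})
    candidates = [(a + b) / 2 for a, b in zip(scores, scores[1:])]
    if not candidates:
        return 0.5, 0.0                   # all scores tied: nothing to separate
    best_t, best_j = 0.5, float("-inf")
    for t in candidates:
        tp = sum(1 for s, y in valid if y == 1 and s > t)
        tn = sum(1 for s, y in valid if y == 0 and s <= t)
        j = tp / n_pos + tn / n_neg - 1.0
        if j > best_j:                    # keep the first maximum on ties
            best_t, best_j = t, j
    return best_t, best_j


# Toy validation scores for one label; the last entry is MISSING and is ignored.
y_true = [0, 0, 0, 1, 0, 1, 1, -1]
y_score = [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 0.95]
threshold, j_stat = youden_threshold(y_true, y_score)
print(f"threshold={threshold:.3f}  J={j_stat:.3f}")
```

With only 69 validation images, a threshold chosen this way can itself be noisy, so it is worth checking on the test split rather than assuming it transfers.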

### Dataset
- **Source**: SurgWound (686 images: 480 train / 69 val / 137 test)
- **Kaggle**: Upload as a Kaggle dataset named `surgwound-dataset` so it mounts under `/kaggle/input/`

### References
- [MedSigLIP model](https://huggingface.co/google/medsiglip-448)
- [Google's fine-tuning notebook](https://github.com/google-health/medsiglip/blob/main/notebooks/fine_tune_for_image_classification.ipynb)
- [SurgWound dataset](https://huggingface.co/datasets/xuxuxuxuxu/SurgWound)

## 1. Setup

### GPU Requirements
This notebook is designed for **T4 (16GB)** on Kaggle free tier.
It will also work on P100 (16GB), L4 (24GB), or A100 (40GB+).

### Dataset Setup
Before running, upload the `data/surgwound/` folder (containing `labels.csv` + `images/` with 686 JPGs)
as a Kaggle dataset named `surgwound-dataset`. It will be accessible at `/kaggle/input/surgwound-dataset/`.

In [None]:
# ── Install dependencies ──────────────────────────────────────────────────────
# Require transformers >= 4.46.0 (a lower bound, not an exact pin) for Trainer compatibility
!pip install --upgrade --quiet \
    "transformers>=4.46.0" \
    accelerate \
    datasets \
    evaluate \
    tensorboard \
    scikit-learn \
    tqdm \
    pillow

In [None]:
# ── Authenticate with Hugging Face ────────────────────────────────────────────
# Required to download gated model: google/medsiglip-448
#
# On Kaggle: Add your HF token as a Kaggle Secret named "HF_TOKEN"
# On Colab:  Add it to Colab Secrets, or it will prompt notebook_login()

import os
import sys

if os.path.exists("/kaggle"):
    # Running on Kaggle
    try:
        from kaggle_secrets import UserSecretsClient
        secrets = UserSecretsClient()
        os.environ["HF_TOKEN"] = secrets.get_secret("HF_TOKEN")
        print("✓ HF_TOKEN loaded from Kaggle Secrets")
    except Exception as e:
        print(f"⚠ Could not load HF_TOKEN from Kaggle Secrets: {e}")
        print("  Falling back to huggingface_hub login...")
        from huggingface_hub import notebook_login
        notebook_login()
elif "google.colab" in sys.modules:
    # Running on Colab
    try:
        from google.colab import userdata
        os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
        print("✓ HF_TOKEN loaded from Colab Secrets")
    except Exception:
        from huggingface_hub import notebook_login
        notebook_login()
else:
    # Local / other environment
    from huggingface_hub import get_token
    if get_token() is None:
        from huggingface_hub import notebook_login
        notebook_login()
    else:
        print("✓ HF token already configured")

## 2. Configuration

All hyperparameters and label definitions in one place.
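
The `POS_WEIGHT` values in this section are negative-to-positive count ratios on the non-MISSING training samples, rounded to two decimals. The sketch below recomputes them from the counts quoted in the configuration comments and also recovers the number of MISSING entries per label; `TRAIN_COUNTS` and `N_TRAIN` are illustrative names, not notebook variables.

```python
N_TRAIN = 480  # training split size

# (negative count, positive count) per label, non-MISSING training samples only.
TRAIN_COUNTS = {
    "healing_status": (282, 198),
    "erythema": (334, 129),
    "edema": (328, 50),
    "infection_risk": (402, 78),
    "urgency": (423, 57),
    "exudate": (367, 70),
}

# pos_weight > 1 up-weights the rarer positive class in the BCE loss.
pos_weight = [round(neg / pos, 2) for neg, pos in TRAIN_COUNTS.values()]

# Whatever is not counted as negative or positive must be MISSING.
missing = {name: N_TRAIN - (neg + pos) for name, (neg, pos) in TRAIN_COUNTS.items()}

print("pos_weight:", pos_weight)
print("missing   :", missing)
```

The recovered MISSING counts give a quick consistency check against the label-encoding table in section 4.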

In [None]:
# ── Force single-GPU mode on multi-GPU Kaggle runtimes ──────────────────────
# Must run BEFORE importing torch.
import os

if os.environ.get("CUDA_VISIBLE_DEVICES") != "0":
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")

import torch
import numpy as np

# ── Paths ─────────────────────────────────────────────────────────────────────
# Adjust BASE_PATH depending on your environment:
#   Kaggle:  /kaggle/input/surgwound-dataset
#   Local:   ./data/surgwound

if os.path.exists("/kaggle/input/datasets/kkfkmf/surgwound-dataset"):
    BASE_PATH = "/kaggle/input/datasets/kkfkmf/surgwound-dataset"  # confirmed working in Run 1
elif os.path.exists("/kaggle/input/surgwound-dataset"):
    BASE_PATH = "/kaggle/input/surgwound-dataset"
elif os.path.exists("/kaggle/input/surgwound"):
    BASE_PATH = "/kaggle/input/surgwound"
else:
    BASE_PATH = "./data/surgwound"  # Local development

LABELS_CSV = os.path.join(BASE_PATH, "labels.csv")
IMAGES_DIR = os.path.join(BASE_PATH, "images")

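# If the Kaggle dataset was uploaded with directory mode=zip, images may arrive as images.zip.
# Extraction writes into BASE_PATH, so it needs a writable location; /kaggle/input is mounted read-only.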
IMAGES_ZIP = os.path.join(BASE_PATH, "images.zip")
if not os.path.isdir(IMAGES_DIR) and os.path.isfile(IMAGES_ZIP):
    import zipfile
    print(f"Extracting {IMAGES_ZIP} ...")
    with zipfile.ZipFile(IMAGES_ZIP, "r") as zf:
        zf.extractall(BASE_PATH)
    print(f"✓ Extracted images to {IMAGES_DIR}")

# ── Model ─────────────────────────────────────────────────────────────────────
MODEL_ID   = "google/medsiglip-448"
OUTPUT_DIR = "medsiglip-448-surgwound-v2"   # separate dir so Run 1 checkpoint is preserved

# ── Label definitions ─────────────────────────────────────────────────────────
# 6 binary labels predicted from wound images
LABEL_NAMES = [
    "healing_status",   # 0: Healed, 1: Not Healed
    "erythema",         # 0: Non-existent, 1: Existent       (has MISSING)
    "edema",            # 0: Non-existent, 1: Existent       (has MISSING)
    "infection_risk",   # 0: Low, 1: Medium+High
    "urgency",          # 0: Green (home care), 1: Yellow+Red (needs attention)
    "exudate",          # 0: Non-existent, 1: Any exudate    (has MISSING)
]
NUM_LABELS = len(LABEL_NAMES)

id2label = {i: name for i, name in enumerate(LABEL_NAMES)}
label2id = {name: i for i, name in enumerate(LABEL_NAMES)}

# Precomputed from training split (non-MISSING samples only):
#   healing:   neg=282, pos=198  → 282/198 = 1.42
#   erythema:  neg=334, pos=129  → 334/129 = 2.59
#   edema:     neg=328, pos=50   → 328/50  = 6.56
#   infection: neg=402, pos=78   → 402/78  = 5.15
#   urgency:   neg=423, pos=57   → 423/57  = 7.42
#   exudate:   neg=367, pos=70   → 367/70  = 5.24
POS_WEIGHT = torch.tensor([1.42, 2.59, 6.56, 5.15, 7.42, 5.24])

# ── Freezing strategy ────────────────────────────────────────────────────────
N_UNFREEZE = 8   # Unfreeze last N encoder blocks + classification head (Run 1: 4)
                 # ~2× trainable capacity; differential LR mitigates overfit risk

# ── Training hyperparameters ─────────────────────────────────────────────────
BATCH_SIZE   = 4
GRAD_ACCUM   = 4         # Effective batch = 4 × 4 = 16; gives 30 steps/epoch (Run 1: 16, 8 steps/epoch)
EPOCHS       = 10        # 300 total optimizer steps vs Run 1's 40
BACKBONE_LR  = 1.5e-5    # Gentle updates to preserve pretrained SigLIP features
HEAD_LR      = 8e-5      # Fast learning for randomly initialized classifier head
LR           = HEAD_LR   # TrainingArguments base LR — cosine scheduler scales both groups proportionally
WARMUP_STEPS = 15        # 15/300 = 5% of total steps (Run 1: 10/40 = 25%, too long for a 40-step run)
WEIGHT_DECAY = 0.015     # Slightly stronger regularisation for ~2× trainable capacity
SCHEDULER    = "cosine"
FP16         = True      # T4 doesn't support bf16; fp16 saves VRAM

# ── Device ───────────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    visible_gpu_count = torch.cuda.device_count()
    print(f"Visible GPU count: {visible_gpu_count}")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem  = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✓ GPU: {gpu_name} ({gpu_mem:.1f} GB)")
    # For larger GPUs: increase BATCH_SIZE and reduce GRAD_ACCUM proportionally
    # so that the effective batch size AND optimizer step count stay identical to T4.
    # T4:  batch=4,  accum=4, eff=16, forward=120, steps/epoch=30
    # A100: batch=16, accum=1, eff=16, forward=30,  steps/epoch=30 ← same!
    if gpu_mem >= 30:  # A100-class (40 GB+); a 24 GB L4 stays on the T4 settings
        BATCH_SIZE = 16
        GRAD_ACCUM = 1
        print(f"  → Adjusted: batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM} (effective={BATCH_SIZE*GRAD_ACCUM}, same step count as T4)")
else:
    print("⚠ No GPU detected — training will be extremely slow")
    FP16 = False

# ── Step count verification (critical — print and sanity-check) ───────────────
forward_passes  = (480 + BATCH_SIZE - 1) // BATCH_SIZE
optimizer_steps = (forward_passes + GRAD_ACCUM - 1) // GRAD_ACCUM
total_steps     = EPOCHS * optimizer_steps
warmup_pct      = 100 * WARMUP_STEPS / total_steps

print(f"\nEffective batch size:    {BATCH_SIZE * GRAD_ACCUM}  (Run 1: 64)")
print(f"Forward passes / epoch:  {forward_passes}")
print(f"Optimizer steps / epoch: {optimizer_steps}  (Run 1: 8)")
print(f"Total optimizer steps:   {total_steps}  (Run 1: 40)")
print(f"Warmup: {WARMUP_STEPS} steps = {warmup_pct:.1f}% of total  (Run 1: 25.0%)")
print(f"Backbone LR: {BACKBONE_LR}  |  Head LR: {HEAD_LR}")
print(f"Dataset path: {BASE_PATH}")

## 3. Load & Validate Dataset

Read `labels.csv`, verify split counts (480/69/137), and print label distributions.
This cell fails fast if the dataset is corrupted or missing.

In [None]:
import pandas as pd
from pathlib import Path

# ── Load CSV ──────────────────────────────────────────────────────────────────
df = pd.read_csv(LABELS_CSV)
print(f"Loaded {len(df)} rows from {LABELS_CSV}")
print(f"Columns: {list(df.columns)}\n")

# ── Verify split counts ──────────────────────────────────────────────────────
split_counts = df["split"].value_counts().to_dict()
print("Split counts:", split_counts)

assert split_counts.get("train", 0) == 480, f"Expected 480 train, got {split_counts.get('train', 0)}"
assert split_counts.get("validation", 0) == 69, f"Expected 69 val, got {split_counts.get('validation', 0)}"
assert split_counts.get("test", 0) == 137, f"Expected 137 test, got {split_counts.get('test', 0)}"
print("✓ Split counts verified: 480 train / 69 val / 137 test\n")

# ── Verify all images exist on disk ──────────────────────────────────────────
missing_images = []
for _, row in df.iterrows():
    img_path = os.path.join(BASE_PATH, row["image_path"])
    if not os.path.exists(img_path):
        missing_images.append(img_path)

if missing_images:
    print(f"✗ {len(missing_images)} images missing! First 5:")
    for p in missing_images[:5]:
        print(f"  {p}")
    raise FileNotFoundError(f"{len(missing_images)} images not found on disk")
else:
    print(f"✓ All {len(df)} images verified on disk\n")

# ── Print label distributions for training split ─────────────────────────────
train_df = df[df["split"] == "train"]
print("=" * 60)
print("TRAINING SPLIT LABEL DISTRIBUTIONS")
print("=" * 60)
for col in ["healing_status", "erythema", "edema", "infection_risk", "urgency_level", "exudate_type"]:
    counts = train_df[col].value_counts()
    print(f"\n{col}:")
    for val, cnt in counts.items():
        pct = 100 * cnt / len(train_df)
        marker = " ← MISSING" if val == "MISSING" else ""
        print(f"  {val:55s} {cnt:4d} ({pct:5.1f}%){marker}")

print("\n" + "=" * 60)

## 4. Label Encoding

Convert raw CSV labels into a 6-dimensional binary vector per image.

| # | Label | 0 (negative) | 1 (positive) | MISSING → -1 |
|---|---|---|---|---|
| 0 | healing_status | Healed | Not Healed | never |
| 1 | erythema | Non-existent | Existent | yes (17 in train) |
| 2 | edema | Non-existent | Existent | yes (102 in train) |
| 3 | infection_risk | Low | Medium or High | never |
| 4 | urgency | Home Care (Green) | Clinic Visit or Emergency | never |
| 5 | exudate | Non-existent | Any type present | yes (43 in train) |

In [None]:
def encode_labels(row: pd.Series) -> list[float]:
    """
    Convert a single CSV row into a 6-dim label vector.

    Returns:
        List of 6 floats: 0.0 (negative), 1.0 (positive), or -1.0 (MISSING).
        The masked loss function will ignore -1.0 entries.
    """
    labels = []

    # 0. healing_status: "Not Healed" → 1 (positive), "Healed" → 0
    labels.append(1.0 if row["healing_status"] == "Not Healed" else 0.0)

    # 1. erythema: "Existent" → 1, "Non-existent" → 0, "MISSING" → -1
    if row["erythema"] == "MISSING":
        labels.append(-1.0)
    else:
        labels.append(1.0 if row["erythema"] == "Existent" else 0.0)

    # 2. edema: "Existent" → 1, "Non-existent" → 0, "MISSING" → -1
    if row["edema"] == "MISSING":
        labels.append(-1.0)
    else:
        labels.append(1.0 if row["edema"] == "Existent" else 0.0)

    # 3. infection_risk: "Medium" or "High" → 1, "Low" → 0
    labels.append(1.0 if row["infection_risk"] in ("Medium", "High") else 0.0)

    # 4. urgency: anything other than "Home Care (Green)..." → 1
    labels.append(0.0 if row["urgency_level"].startswith("Home Care") else 1.0)

    # 5. exudate: "Non-existent" → 0, "MISSING" → -1, anything else → 1
    if row["exudate_type"] == "MISSING":
        labels.append(-1.0)
    elif row["exudate_type"] == "Non-existent":
        labels.append(0.0)
    else:
        labels.append(1.0)  # Serous, Sanguineous, Purulent, Seropurulent

    return labels


# ── Verify encoding on a few known examples ──────────────────────────────────
sample_row = train_df.iloc[0]
sample_labels = encode_labels(sample_row)
print(f"Sample row (img_id={sample_row['img_id']}):")
print(f"  healing_status = {str(sample_row['healing_status']):20s} → {sample_labels[0]}")
print(f"  erythema       = {str(sample_row['erythema']):20s} → {sample_labels[1]}")
print(f"  edema          = {str(sample_row['edema']):20s} → {sample_labels[2]}")
print(f"  infection_risk = {str(sample_row['infection_risk']):20s} → {sample_labels[3]}")
print(f"  urgency_level  = {str(sample_row['urgency_level'])[:30]:30s} → {sample_labels[4]}")
print(f"  exudate_type   = {str(sample_row['exudate_type']):20s} → {sample_labels[5]}")
print(f"\n  Encoded vector: {sample_labels}")

## 5. Create HuggingFace Datasets

Build `Dataset` objects for train, validation, and test splits.
Each sample has an `image` (PIL) and `label` (6-dim float list).

In [None]:
from datasets import Dataset, Features, Value, Sequence, Image as HFImage
import PIL.Image  # Ensure PIL.Image is in sys.modules for datasets internals
from tqdm.auto import tqdm


def build_dataset_from_split(split_df: pd.DataFrame, split_name: str) -> Dataset:
    """
    Build a HuggingFace Dataset from a pandas DataFrame for one split.

    Loads images from disk and encodes labels as 6-dim float vectors.
    Uses explicit Features schema to avoid datasets library PIL type-inference bugs.
    """
    image_paths = []
    labels = []
    skipped = 0

    for _, row in tqdm(split_df.iterrows(), total=len(split_df), desc=f"Loading {split_name}"):
        img_path = os.path.join(BASE_PATH, row["image_path"])
        if not os.path.exists(img_path):
            skipped += 1
            continue
        image_paths.append(img_path)
        labels.append(encode_labels(row))

    if skipped > 0:
        print(f"  ⚠ Skipped {skipped} missing images in {split_name}")

    # Explicit schema avoids PIL.Image.Image isinstance check inside datasets
    features = Features({
        "image": HFImage(),
        "label": Sequence(Value("float32"), length=6),
    })

    ds = Dataset.from_dict(
        {"image": image_paths, "label": labels},
        features=features,
    )

    print(f"  ✓ {split_name}: {len(ds)} samples")
    return ds


# ── Build all three splits ────────────────────────────────────────────────────
train_ds_raw = build_dataset_from_split(df[df["split"] == "train"], "train")
val_ds_raw   = build_dataset_from_split(df[df["split"] == "validation"], "validation")
test_ds_raw  = build_dataset_from_split(df[df["split"] == "test"], "test")

print(f"\nDataset sizes: train={len(train_ds_raw)}, val={len(val_ds_raw)}, test={len(test_ds_raw)}")

## 6. Image Preprocessing

Following Google's MedSigLIP preprocessing exactly:
1. **Zero-pad to square** — pads shorter dimension with black pixels to preserve aspect ratio
2. **Resize** to 448×448 (bilinear)
3. **Normalize** — scale to [-1, 1] with mean=0.5, std=0.5

Training adds light augmentation (horizontal flip, rotation, color jitter) to compensate for the small dataset.

> **Note**: Uses pure PIL + numpy instead of torchvision to avoid Kaggle's torch/torchvision version conflict.
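
The geometry and value range of the three preprocessing steps can be checked with plain arithmetic. The sketch below is illustrative only: `pad_geometry` and `normalise` are hypothetical helpers, and the mean/std of 0.5 are the values this section quotes for the processor.

```python
def pad_geometry(width: int, height: int) -> tuple[int, tuple[int, int]]:
    """Return (canvas side, (left, top) paste offset) for centred zero-padding."""
    side = max(width, height)
    return side, ((side - width) // 2, (side - height) // 2)


def normalise(pixel: int, mean: float = 0.5, std: float = 0.5) -> float:
    """Map an 8-bit channel value to the normalised input range."""
    return (pixel / 255.0 - mean) / std


side, offset = pad_geometry(600, 400)   # landscape photo, 600 wide by 400 high
print(f"canvas {side}x{side}, paste at {offset}")
print(f"black -> {normalise(0)}, white -> {normalise(255)}")
```

One consequence is that the black padding bars reach the model as -1.0 rather than 0.0 after normalisation.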

In [None]:
import random
import numpy as np
import torch
from PIL import Image as PILImage, ImageEnhance
from transformers import AutoImageProcessor

# ── Load processor to get canonical image size and normalization ──────────────
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
IMG_SIZE = image_processor.size["height"]   # 448
IMG_MEAN = image_processor.image_mean       # [0.5, 0.5, 0.5]
IMG_STD  = image_processor.image_std        # [0.5, 0.5, 0.5]
print(f"✓ Image size: {IMG_SIZE}, mean: {IMG_MEAN}, std: {IMG_STD}")


# ── Pure PIL/numpy transform functions (no torchvision dependency) ────────────

def _pil_to_tensor(img: PILImage.Image) -> torch.Tensor:
    """Convert PIL RGB image → float tensor (C, H, W) in [-1, 1]."""
    arr = np.array(img, dtype=np.float32) / 255.0                        # [0, 1]
    arr = (arr - np.array(IMG_MEAN, dtype=np.float32)) / \
          np.array(IMG_STD, dtype=np.float32)                             # [-1, 1]
    return torch.from_numpy(arr).permute(2, 0, 1)                        # (C, H, W)


def _zero_pad_to_square(img: PILImage.Image) -> PILImage.Image:
    """
    Zero-pad shorter dimension to match longer — replicates Google's
    CenterCrop(max(image.size)) trick exactly.
    """
    w, h = img.size
    max_dim = max(w, h)
    if w == h:
        return img
    padded = PILImage.new("RGB", (max_dim, max_dim), (0, 0, 0))
    padded.paste(img, ((max_dim - w) // 2, (max_dim - h) // 2))
    return padded


def _augment(img: PILImage.Image) -> PILImage.Image:
    """Light training augmentation: flip, rotation, brightness/contrast jitter."""
    if random.random() < 0.5:
        img = img.transpose(PILImage.Transpose.FLIP_LEFT_RIGHT)
    angle = random.uniform(-10, 10)
    img = img.rotate(angle, resample=PILImage.Resampling.BILINEAR,
                     fillcolor=(0, 0, 0))
    img = ImageEnhance.Brightness(img).enhance(random.uniform(0.9, 1.1))
    img = ImageEnhance.Contrast(img).enhance(random.uniform(0.9, 1.1))
    return img


def _process_image(img: PILImage.Image, augment: bool) -> torch.Tensor:
    """Full pipeline: RGB → pad-to-square → [augment] → resize → normalise."""
    img = img.convert("RGB")
    img = _zero_pad_to_square(img)
    if augment:
        img = _augment(img)
    img = img.resize((IMG_SIZE, IMG_SIZE), PILImage.Resampling.BILINEAR)
    return _pil_to_tensor(img)


def preprocess_train(examples: dict) -> dict:
    """Preprocess training examples with augmentation."""
    examples["pixel_values"] = [
        _process_image(img, augment=True) for img in examples["image"]
    ]
    return examples


def preprocess_eval(examples: dict) -> dict:
    """Preprocess validation/test examples without augmentation."""
    examples["pixel_values"] = [
        _process_image(img, augment=False) for img in examples["image"]
    ]
    return examples


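# ── Apply preprocessing ───────────────────────────────────────────────────────
# NOTE: Dataset.map runs once and stores its output, so each training image gets a
# single fixed augmentation that is reused every epoch. For fresh augmentation per
# epoch, apply the transform lazily instead (e.g. via Dataset.with_transform).
print("Preprocessing training data (augmentation applied once at map time)...")
train_ds = train_ds_raw.map(preprocess_train, batched=True, remove_columns=["image"])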

print("Preprocessing validation data...")
val_ds = val_ds_raw.map(preprocess_eval, batched=True, remove_columns=["image"])

print("Preprocessing test data...")
test_ds = test_ds_raw.map(preprocess_eval, batched=True, remove_columns=["image"])

# ── Sanity check ──────────────────────────────────────────────────────────────
sample_pv = torch.tensor(train_ds[0]["pixel_values"])
assert sample_pv.shape == torch.Size([3, IMG_SIZE, IMG_SIZE]), \
    f"Unexpected shape: {sample_pv.shape}"
assert sample_pv.min() >= -1.5 and sample_pv.max() <= 1.5, \
    f"Pixel range unexpected: [{sample_pv.min():.2f}, {sample_pv.max():.2f}]"

print(f"\n✓ Preprocessed: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")
print(f"  pixel_values shape : {list(sample_pv.shape)}")
print(f"  pixel value range  : [{sample_pv.min():.3f}, {sample_pv.max():.3f}]")
print(f"  label              : {train_ds[0]['label']}")

## 7. Visual Sanity Check

Display sample images with their encoded labels to verify the pipeline.

In [None]:
import matplotlib.pyplot as plt


def show_sample(ds_raw, ds_processed, idx, title_prefix=""):
    """Display an image alongside its encoded label vector."""
    raw_img = ds_raw[idx]["image"]
    label_vec = ds_processed[idx]["label"]

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    ax.imshow(raw_img)
    ax.set_title(f"{title_prefix} sample {idx}", fontsize=10)
    ax.axis("off")

    label_str = "\n".join(
        f"  {LABEL_NAMES[i]}: {label_vec[i]:+.0f}"
        + (" (MISSING)" if label_vec[i] == -1.0 else "")
        for i in range(NUM_LABELS)
    )
    ax.text(
        1.05, 0.5, label_str,
        transform=ax.transAxes, fontsize=9, verticalalignment="center",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8),
    )
    plt.tight_layout()
    plt.show()


# Show 4 samples: 2 train, 2 val
for i in [0, 50]:
    show_sample(train_ds_raw, train_ds, i, title_prefix="Train")
for i in [0, 30]:
    show_sample(val_ds_raw, val_ds, i, title_prefix="Val")

# ── Verify MISSING encoding ──────────────────────────────────────────────────
missing_found = False
for idx in range(len(train_ds)):
    label_vec = train_ds[idx]["label"]
    if -1.0 in label_vec:
        missing_indices = [i for i, v in enumerate(label_vec) if v == -1.0]
        missing_names = [LABEL_NAMES[i] for i in missing_indices]
        print(f"✓ Found MISSING in train[{idx}]: {missing_names}")
        print(f"  Full label vector: {label_vec}")
        missing_found = True
        break

if not missing_found:
    print("⚠ No MISSING values found in training data — check encoding!")

## 8. Load Model with Selective Freezing

Load `google/medsiglip-448` with a 6-label classification head.
Freeze all parameters except the last `N_UNFREEZE` (8) encoder blocks + the new head.

This reduces trainable params from ~400M to roughly 28% of the total (on the order of 110M), which still fits in T4 16GB VRAM; the cell below prints the exact counts.

In [None]:
from transformers import AutoModelForImageClassification

# ── Load pretrained model with new classification head ────────────────────────
model = AutoModelForImageClassification.from_pretrained(
    MODEL_ID,
    problem_type="multi_label_classification",
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # Head size mismatch: pretrained != 6
)

# ── Freeze all parameters ────────────────────────────────────────────────────
for param in model.parameters():
    param.requires_grad = False

# ── Unfreeze classification head (randomly initialized — MUST be trainable) ──
for param in model.classifier.parameters():
    param.requires_grad = True

# ── Unfreeze last N encoder blocks ───────────────────────────────────────────
encoder_layers = model.vision_model.encoder.layers
total_layers = len(encoder_layers)
print(f"Encoder has {total_layers} layers. Unfreezing last {N_UNFREEZE}...")

for layer in encoder_layers[-N_UNFREEZE:]:
    for param in layer.parameters():
        param.requires_grad = True

# ── Print parameter counts ───────────────────────────────────────────────────
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params

print(f"\nParameter summary:")
print(f"  Total:     {total_params:>12,}")
print(f"  Trainable: {trainable_params:>12,} ({100*trainable_params/total_params:.1f}%)")
print(f"  Frozen:    {frozen_params:>12,} ({100*frozen_params/total_params:.1f}%)")

# ── Move to device and check VRAM ────────────────────────────────────────────
model = model.to(device)

if torch.cuda.is_available():
    vram_used = torch.cuda.memory_allocated() / 1e9
    vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"\nVRAM after model load: {vram_used:.2f} / {vram_total:.1f} GB")
    if vram_used > vram_total * 0.7:
        print("⚠ VRAM usage is high — consider reducing N_UNFREEZE or BATCH_SIZE")

## 9. Masked BCE Loss Function

Three labels (erythema, edema, exudate) have MISSING values encoded as `-1.0`.
Instead of dropping entire rows (losing signal for other valid labels),
the masked loss zeros out gradient contributions for MISSING entries.

Every image contributes to every label it has data for.
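
The masking idea can be worked through by hand for a single image. This is a pure-Python sketch under stated assumptions, not the notebook's loss: `masked_weighted_bce` is a hypothetical helper that applies the weighted BCE formula `-(w*y*log(p) + (1-y)*log(1-p))` with `p = sigmoid(logit)` and averages over valid entries only.

```python
import math


def masked_weighted_bce(logits, labels, pos_weight):
    """Mean weighted BCE over entries labelled 0 or 1; entries labelled -1 are skipped."""
    total, n_valid = 0.0, 0
    for x, y, w in zip(logits, labels, pos_weight):
        if y < 0:                      # MISSING: adds neither loss nor count
            continue
        p = 1.0 / (1.0 + math.exp(-x))
        total += -(w * y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
        n_valid += 1
    return total / n_valid if n_valid else 0.0


weights = [1.42, 2.59, 6.56, 5.15, 7.42, 5.24]
labels = [1.0, 0.0, -1.0, 0.0, 1.0, -1.0]        # edema and exudate are MISSING
loss_a = masked_weighted_bce([0.5, 0.3, -0.2, 0.1, 0.4, 0.6], labels, weights)
loss_b = masked_weighted_bce([0.5, 0.3, 9.0, 0.1, 0.4, -9.0], labels, weights)
print(f"loss_a={loss_a:.4f}  loss_b={loss_b:.4f}")
```

The two calls differ only in the logits at the MISSING positions, so they return the same loss: predictions for unlabelled entries cannot influence training.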

In [None]:
from torch.nn import BCEWithLogitsLoss


def masked_bce_loss(
    outputs: dict,
    labels: torch.Tensor,
) -> torch.Tensor:
    """
    BCE loss with per-element masking for MISSING values and class-imbalance weighting.

    Args:
        outputs: Model output dict containing 'logits' of shape (batch, 6).
        labels:  Tensor of shape (batch, 6) with values in {0.0, 1.0, -1.0}.
                 -1.0 indicates MISSING — loss for that entry is masked out.

    Returns:
        Scalar loss tensor (mean over valid entries only).
    """
    logits = outputs.get("logits")             # (batch, 6)
    mask = (labels >= 0).float()                # 1 where valid, 0 where MISSING
    safe_labels = labels.clamp(min=0.0)         # Replace -1 with 0 for BCE math

    pos_weight = POS_WEIGHT.to(logits.device)
    loss_fct = BCEWithLogitsLoss(pos_weight=pos_weight, reduction="none")
    per_element_loss = loss_fct(logits, safe_labels)  # (batch, 6)
    masked_loss = per_element_loss * mask               # Zero out MISSING

    # Mean over valid entries only (prevents bias from batch MISSING counts)
    num_valid = mask.sum()
    if num_valid == 0:
        return torch.tensor(0.0, device=logits.device, requires_grad=True)

    return masked_loss.sum() / num_valid


# ── Unit test: verify MISSING entries are masked ─────────────────────────────
print("Running masked loss unit test...")

# Fake logits and labels with known MISSING entries
_test_logits = torch.tensor([[0.5, 0.3, -0.2, 0.1, 0.4, 0.6]])
_test_labels_valid = torch.tensor([[1.0, 0.0, 1.0, 0.0, 1.0, 0.0]])   # All valid
_test_labels_missing = torch.tensor([[1.0, 0.0, -1.0, 0.0, 1.0, -1.0]])  # Entries 2,5 MISSING

# Compute loss with all valid
_loss_valid = masked_bce_loss({"logits": _test_logits}, _test_labels_valid)

# Compute loss with MISSING
_loss_masked = masked_bce_loss({"logits": _test_logits}, _test_labels_missing)

print(f"  Loss (all valid):     {_loss_valid.item():.4f}  (6/6 entries contribute)")
print(f"  Loss (2 MISSING):     {_loss_masked.item():.4f}  (4/6 entries contribute)")
print(f"  Losses differ: {_loss_valid.item() != _loss_masked.item()} (expected: True)")

# Verify gradient flows for valid entries but not MISSING
_test_logits_grad = torch.tensor([[0.5, 0.3, -0.2, 0.1, 0.4, 0.6]], requires_grad=True)
_loss_for_grad = masked_bce_loss({"logits": _test_logits_grad}, _test_labels_missing)
_loss_for_grad.backward()
grad = _test_logits_grad.grad[0]
print(f"  Gradients: {[round(g, 4) for g in grad.tolist()]}")
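# Fail loudly instead of printing "passed" unconditionally
assert _loss_valid.item() != _loss_masked.item(), "Masking did not change the loss"
assert grad[2].item() == 0.0 and grad[5].item() == 0.0, "MISSING positions received gradient"
print("  MISSING positions (2,5) have zero grad: True")
print("✓ Masked loss unit test passed!")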

## 10. Evaluation Metrics

Macro-averaged per-label ROC AUC (each label is an independent binary task) + per-label sensitivity/specificity.
For labels with MISSING values in the eval set, metrics are computed only on non-MISSING samples.
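
A small worked example of the masking and macro-averaging, using the pairwise definition of ROC AUC (the fraction of positive/negative pairs in which the positive scores higher, with ties counted as 0.5). `pairwise_auc` is a hypothetical stand-in for `roc_auc_score` so the arithmetic can be followed by hand; the toy labels and scores are made up.

```python
import math


def pairwise_auc(y_true, y_score):
    """ROC AUC over non-MISSING entries; NaN if only one class is present."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]   # y == -1 falls in neither list
    if not pos or not neg:
        return float("nan")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Rows are images, columns are two labels; -1 marks MISSING.
labels = [[1, -1], [0, 1], [-1, 0], [1, 1], [0, 0]]
scores = [[0.9, 0.8], [0.2, 0.7], [0.99, 0.3], [0.4, 0.6], [0.6, 0.1]]

per_label = [
    pairwise_auc([row[i] for row in labels], [row[i] for row in scores])
    for i in range(2)
]
valid = [a for a in per_label if not math.isnan(a)]
macro_auc = sum(valid) / len(valid)
print("per-label AUC:", per_label, " macro:", macro_auc)
```

Each column is scored on its own non-MISSING rows, so the two labels here are evaluated on different subsets of the five images.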

In [None]:
from sklearn.metrics import roc_auc_score


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def compute_metrics(eval_pred):
    """
    Compute evaluation metrics, handling MISSING values (-1) in labels.

    Returns:
        Dict with 'roc_auc_macro' and per-label AUC.
    """
    logits, labels = eval_pred
    scores = sigmoid(logits)

    results = {}
    per_label_auc = []

    for i, name in enumerate(LABEL_NAMES):
        # Mask out MISSING entries for this label
        valid_mask = labels[:, i] >= 0
        if valid_mask.sum() == 0:
            continue

        y_true = labels[valid_mask, i]
        y_score = scores[valid_mask, i]

        # Need at least one sample of each class for AUC
        if len(np.unique(y_true)) < 2:
            results[f"auc_{name}"] = float("nan")
            continue

        try:
            auc = roc_auc_score(y_true, y_score)
            per_label_auc.append(auc)
            results[f"auc_{name}"] = auc
        except ValueError:
            results[f"auc_{name}"] = float("nan")

    # Macro-averaged AUC (only over labels with valid AUC)
    if per_label_auc:
        results["roc_auc_macro"] = np.mean(per_label_auc)
    else:
        results["roc_auc_macro"] = float("nan")

    return results


def compute_full_metrics(logits, labels, threshold=0.5):
    """
    Compute detailed metrics including sensitivity/specificity per label.
    Used for final test evaluation (not during training).
    """
    scores = sigmoid(logits)
    results = {}

    print("\n" + "=" * 70)
    print(f"{'Label':<20} {'AUC':>8} {'Sens':>8} {'Spec':>8} {'N valid':>8}")
    print("=" * 70)

    all_aucs = []
    for i, name in enumerate(LABEL_NAMES):
        valid_mask = labels[:, i] >= 0
        n_valid = valid_mask.sum()
        y_true = labels[valid_mask, i]
        y_score = scores[valid_mask, i]
        y_pred = (y_score > threshold).astype(int)

        # AUC
        try:
            auc = roc_auc_score(y_true, y_score) if len(np.unique(y_true)) >= 2 else float("nan")
        except ValueError:
            auc = float("nan")

        # Sensitivity and specificity
        tp = ((y_pred == 1) & (y_true == 1)).sum()
        tn = ((y_pred == 0) & (y_true == 0)).sum()
        fn = ((y_pred == 0) & (y_true == 1)).sum()
        fp = ((y_pred == 1) & (y_true == 0)).sum()
        sens = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
        spec = tn / (tn + fp) if (tn + fp) > 0 else float("nan")

        if not np.isnan(auc):
            all_aucs.append(auc)

        results[f"auc_{name}"] = auc
        results[f"sens_{name}"] = sens
        results[f"spec_{name}"] = spec

        print(f"{name:<20} {auc:>8.4f} {sens:>8.4f} {spec:>8.4f} {n_valid:>8d}")

    macro_auc = np.mean(all_aucs) if all_aucs else float("nan")
    results["roc_auc_macro"] = macro_auc
    print("=" * 70)
    print(f"{'Macro AUC':<20} {macro_auc:>8.4f}")
    print("=" * 70)

    return results


print("✓ Metrics functions defined")

## 11. Data Collator

In [None]:
def collate_fn(examples):
    """
    Collate function for Trainer.

    Stacks pixel_values into a (batch, 3, 448, 448) tensor
    and labels into a (batch, 6) float tensor (with -1.0 for MISSING).
    """
    # as_tensor accepts both nested lists and existing tensors (no copy warning for the latter)
    pixel_values = torch.stack([torch.as_tensor(ex["pixel_values"]) for ex in examples])
    labels = torch.tensor([ex["label"] for ex in examples], dtype=torch.float)
    return {"pixel_values": pixel_values, "labels": labels}


# ── Quick verification ────────────────────────────────────────────────────────
_test_batch = collate_fn([train_ds[0], train_ds[1]])
print(f"Collated batch shapes:")
print(f"  pixel_values: {_test_batch['pixel_values'].shape}")   # (2, 3, 448, 448)
print(f"  labels:       {_test_batch['labels'].shape}")          # (2, 6)
print(f"  labels dtype: {_test_batch['labels'].dtype}")          # float32
print(f"  labels[0]:    {_test_batch['labels'][0].tolist()}")

## 12. Trainer Setup

Uses a **subclassed Trainer** to inject the masked BCE loss (via `compute_loss`) and **differential learning rates** (via `create_optimizer`).
Overriding these methods avoids relying on the `compute_loss_func` argument, which was introduced around `transformers` 4.46.0 and may be unavailable on some Kaggle kernels.

### Training Configuration Summary
| Parameter | Value | Rationale |
|---|---|---|
| Effective batch size | 16 (4×4) | More optimizer steps per epoch (Run 1 used 64) |
| Epochs | 10 | 300 total steps; `load_best_model_at_end` guards against overfit |
| Backbone LR | 1.5e-5 | Gentle updates to preserve pretrained SigLIP features |
| Head LR | 8e-5 | Fast learning for randomly initialized classifier head |
| Warmup steps | 15 | 5% of 300 total steps (Run 1 spent 25% of its 40 steps in warmup, leaving little time at peak LR) |
| Scheduler | Cosine | Scales both param groups proportionally |
| fp16 | True | T4 VRAM optimization |
| Model selection | eval_loss | Val set too small (69) for reliable AUC |
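
The claim that one scheduler scales both parameter groups proportionally can be checked with a small standalone sketch. It reimplements the usual linear-warmup plus half-cosine multiplier in plain Python rather than calling the `transformers` scheduler, so the function names `cosine_with_warmup` and `group_lrs` are illustrative and not part of the notebook. The constants mirror the table above.

```python
import math

# Values taken from the configuration table above.
BACKBONE_LR = 1.5e-5
HEAD_LR = 8e-5
WARMUP_STEPS = 15
TOTAL_STEPS = 300


def cosine_with_warmup(step, warmup_steps=WARMUP_STEPS, total_steps=TOTAL_STEPS):
    """Schedule multiplier in [0, 1]: linear warmup, then half-cosine decay."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


def group_lrs(step):
    """Both groups share one multiplier, so their ratio never changes."""
    factor = cosine_with_warmup(step)
    return BACKBONE_LR * factor, HEAD_LR * factor


for step in (5, 15, 150, 299):
    backbone, head = group_lrs(step)
    print(f"step {step:>3}: backbone={backbone:.2e}  head={head:.2e}  ratio={head / backbone:.3f}")
```

Because the multiplier is shared, the head stays at roughly 5.3× the backbone rate for the whole run; only the absolute values move.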

In [None]:
from transformers import Trainer, TrainingArguments


class WoundClassificationTrainer(Trainer):
    """
    Custom Trainer with two enhancements over a plain HF Trainer:
      1. Masked BCE loss for MISSING label handling (compute_loss).
      2. Differential learning rates (create_optimizer):
           - backbone encoder blocks (unfrozen): BACKBONE_LR = 1.5e-5
           - classifier head (random-init):      HEAD_LR     = 8e-5
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = masked_bce_loss(outputs, labels)
        return (loss, outputs) if return_outputs else loss

    def create_optimizer(self):
        """
        Build AdamW with 3 param groups for differential learning rates.

        Groups:
          1. backbone trainable params that receive weight decay  → BACKBONE_LR + WEIGHT_DECAY
          2. backbone trainable params exempt from weight decay   → BACKBONE_LR + 0.0
          3. classifier head params                               → HEAD_LR + WEIGHT_DECAY

        The HF Trainer's cosine scheduler uses LambdaLR, which multiplies each
        group's stored base_lr by the same schedule factor → warmup and cosine
        decay apply proportionally to both backbone and head groups.
        """
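        # Substring match: SigLIP names its norms layer_norm1 / layer_norm2 / post_layernorm,
        # so "layer_norm.weight" alone would miss them.
        no_decay  = {"bias", "layer_norm", "layernorm", "LayerNorm"}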
        head_params = list(self.model.classifier.parameters())
        head_ids    = {id(p) for p in head_params}

        backbone_decay, backbone_nodecay = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad or id(param) in head_ids:
                continue
            if any(nd in name for nd in no_decay):
                backbone_nodecay.append(param)
            else:
                backbone_decay.append(param)

        param_groups = [
            {"params": backbone_decay,   "lr": BACKBONE_LR, "weight_decay": WEIGHT_DECAY},
            {"params": backbone_nodecay, "lr": BACKBONE_LR, "weight_decay": 0.0},
            {"params": head_params,      "lr": HEAD_LR,     "weight_decay": WEIGHT_DECAY},
        ]

        self.optimizer = torch.optim.AdamW(param_groups)

        print(f"  AdamW param groups:")
        print(f"    backbone (w/ decay):  {len(backbone_decay):>4} tensors @ lr={BACKBONE_LR:.1e}")
        print(f"    backbone (no decay):  {len(backbone_nodecay):>4} tensors @ lr={BACKBONE_LR:.1e}")
        print(f"    classifier head:      {len(head_params):>4} tensors @ lr={HEAD_LR:.1e}")
        return self.optimizer


# ── Training arguments ────────────────────────────────────────────────────────
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE * 2,  # Larger batch for eval (no grad)
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,                           # = HEAD_LR; scheduler scales both groups proportionally
    weight_decay=WEIGHT_DECAY,
    warmup_steps=WARMUP_STEPS,
    lr_scheduler_type=SCHEDULER,
    fp16=FP16,
    logging_steps=1,                    # 30 steps/epoch — log every step
    save_strategy="epoch",
    eval_strategy="epoch",              # replaces evaluation_strategy (deprecated since transformers 4.41)
    metric_for_best_model="eval_loss",  # Val set too small for reliable AUC
    greater_is_better=False,            # Lower loss = better
    load_best_model_at_end=True,
    report_to="tensorboard",
    push_to_hub=False,                  # Push manually after evaluation
    remove_unused_columns=False,        # Keep our custom columns
    dataloader_num_workers=2,
    save_total_limit=3,                 # Keep at most 3 checkpoints (the best one is always retained)
)

# ── Create Trainer ───────────────────────────────────────────────────────────
trainer = WoundClassificationTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

forward_passes  = (len(train_ds) + BATCH_SIZE - 1) // BATCH_SIZE
optimizer_steps = (forward_passes + GRAD_ACCUM - 1) // GRAD_ACCUM
print(f"✓ Trainer initialized")
print(f"  Effective batch size:    {BATCH_SIZE * GRAD_ACCUM}")
print(f"  Forward passes / epoch:  {forward_passes}")
print(f"  Optimizer steps / epoch: {optimizer_steps}")
print(f"  Total optimizer steps:   {EPOCHS * optimizer_steps}")
print(f"  Model selection: best eval_loss (lower is better)")

## 13. Train

Expected training time: **25–35 minutes on T4**.
Increased from Run 1 (12.8 min) due to: (a) 10 epochs vs 5, (b) 8 unfrozen blocks vs 4 (heavier backward pass per step), (c) 30 optimizer steps/epoch vs 8.
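
The step counts quoted above follow from simple ceiling arithmetic. The sketch below reproduces them, assuming a per-device batch size of 4 in both runs (implied by the effective batch sizes of 64 and 16) and 480 training images; `optimizer_steps` is an illustrative helper, not a notebook function.

```python
import math


def optimizer_steps(n_train, batch_size, grad_accum, epochs):
    """Return (optimizer steps per epoch, total optimizer steps)."""
    forward_passes = math.ceil(n_train / batch_size)
    steps_per_epoch = math.ceil(forward_passes / grad_accum)
    return steps_per_epoch, steps_per_epoch * epochs


run1 = optimizer_steps(n_train=480, batch_size=4, grad_accum=16, epochs=5)
run2 = optimizer_steps(n_train=480, batch_size=4, grad_accum=4, epochs=10)

print(f"Run 1: {run1[0]} steps/epoch, {run1[1]} total")   # 8 steps/epoch, 40 total
print(f"Run 2: {run2[0]} steps/epoch, {run2[1]} total")   # 30 steps/epoch, 300 total
```

The number of forward passes per epoch (120) is the same in both runs; the extra wall-clock time comes mainly from the doubled epoch count and the deeper backward pass.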

In [None]:
import time

print("Starting training...\n")
start_time = time.time()

train_result = trainer.train()

elapsed = time.time() - start_time
print(f"\n{'='*60}")
print(f"Training complete in {elapsed/60:.1f} minutes")
print(f"{'='*60}")

# ── Check for NaN (fp16 safety) ──────────────────────────────────────────────
# Note: train_result.training_loss is the mean loss over all training steps,
# not the loss of the last step.
final_loss = train_result.training_loss
if np.isnan(final_loss):
    print("\n⚠ WARNING: Training loss is NaN!")
    print("  This likely means fp16 caused numerical instability.")
    print("  Try re-running with FP16 = False in the config cell.")
else:
    print(f"Mean training loss: {final_loss:.4f}")
    print(f"Best model loaded from: {trainer.state.best_model_checkpoint}")

## 14. Evaluate on Test Set

Run inference on the 137 test images and compute per-label metrics:
AUC, sensitivity, and specificity.

In [None]:
# ── Run prediction on test set ───────────────────────────────────────────────
print("Running inference on test set (137 images)...")
predictions = trainer.predict(test_ds)

test_logits = predictions.predictions  # (137, 6)
test_labels = predictions.label_ids     # (137, 6)

print(f"Logits shape: {test_logits.shape}")
print(f"Labels shape: {test_labels.shape}")

# ── Compute detailed metrics ─────────────────────────────────────────────────
test_metrics = compute_full_metrics(test_logits, test_labels, threshold=0.5)

# ── Summary ──────────────────────────────────────────────────────────────────
print(f"\nTest set macro-averaged ROC AUC: {test_metrics['roc_auc_macro']:.4f}")

## 15. Per-Label Threshold Tuning

Run 1 showed clear threshold miscalibration at the fixed 0.5 cutoff (all labels shared the same decision boundary):
- `healing_status`: sens=0.84 / spec=0.26 → threshold too low (model over-predicts "not healed")
- `infection_risk`: sens=0.35 / spec=0.68 → threshold too high (26 of 40 infection cases missed)

**Youden's J statistic** (`J = sensitivity + specificity − 1`) finds the operating point that maximises true detection rate minus false alarm rate. It is equivalent to maximising the vertical distance from the ROC diagonal, and is computed independently per label.

Thresholds are tuned on the **validation set** (69 images) and then applied to the test set and inference demo — no data leakage. Note that Macro AUC is threshold-independent (area under the full ROC curve), so it stays identical regardless of threshold choice.
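
As a toy illustration of the selection rule, the sketch below sweeps the same 0.10 to 0.90 grid over eight made-up scores and picks the threshold with the largest J. The data and the helper names (`sens_spec`, `youden_threshold`) are invented for this example and are not from SurgWound or the notebook.

```python
def sens_spec(y_true, y_score, thresh):
    """Sensitivity and specificity when predicting positive for score > thresh."""
    tp = sum(1 for y, s in zip(y_true, y_score) if y == 1 and s > thresh)
    fn = sum(1 for y, s in zip(y_true, y_score) if y == 1 and s <= thresh)
    tn = sum(1 for y, s in zip(y_true, y_score) if y == 0 and s <= thresh)
    fp = sum(1 for y, s in zip(y_true, y_score) if y == 0 and s > thresh)
    sens = tp / (tp + fn) if (tp + fn) else 0.0
    spec = tn / (tn + fp) if (tn + fp) else 0.0
    return sens, spec


def youden_threshold(y_true, y_score, grid):
    """Return (best threshold, best J); ties keep the lowest threshold."""
    best_t, best_j = None, float("-inf")
    for t in grid:
        sens, spec = sens_spec(y_true, y_score, t)
        j = sens + spec - 1.0
        if j > best_j:
            best_t, best_j = t, j
    return best_t, best_j


# Toy data: four negatives, four positives (illustrative only).
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_score = [0.10, 0.30, 0.35, 0.60, 0.40, 0.55, 0.70, 0.90]
grid = [i / 100 for i in range(10, 91)]

sens05, spec05 = sens_spec(y_true, y_score, 0.5)
best_t, best_j = youden_threshold(y_true, y_score, grid)
print(f"J at 0.50: {sens05 + spec05 - 1.0:.2f}")
print(f"Best J:    {best_j:.2f} at threshold {best_t:.2f}")
```

On this toy set, lowering the threshold recovers the positive scored 0.40 without admitting another false positive, which is the kind of shift the sweep looks for on `infection_risk`.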

In [None]:
from sklearn.metrics import roc_auc_score

# ── Get validation set predictions for threshold tuning ──────────────────────
print("Running inference on validation set for threshold tuning (69 images)...")
val_preds        = trainer.predict(val_ds)
val_logits_t     = val_preds.predictions    # (69, 6)
val_labels_t     = val_preds.label_ids      # (69, 6)
val_scores_t     = sigmoid(val_logits_t)

# ── Sweep thresholds per label, maximise Youden's J = sens + spec - 1 ────────
# Tuned on val set only — applied to test (no leakage).
PRED_THRESHOLDS = {}
threshold_grid  = np.arange(0.10, 0.91, 0.01)

print(f"\n{'Label':<20} {'J@0.50':>7} {'Best-J':>7} {'Thresh':>7} {'Δ sens':>8} {'Δ spec':>8}")
print("-" * 62)

for i, name in enumerate(LABEL_NAMES):
    valid_mask = val_labels_t[:, i] >= 0
    y_true     = val_labels_t[valid_mask, i]
    y_score    = val_scores_t[valid_mask, i]

    # Baseline at threshold=0.5
    y_pred05 = (y_score > 0.5).astype(int)
    tp05 = ((y_pred05 == 1) & (y_true == 1)).sum()
    tn05 = ((y_pred05 == 0) & (y_true == 0)).sum()
    fn05 = ((y_pred05 == 0) & (y_true == 1)).sum()
    fp05 = ((y_pred05 == 1) & (y_true == 0)).sum()
    sens05 = tp05 / (tp05 + fn05) if (tp05 + fn05) > 0 else 0.0
    spec05 = tn05 / (tn05 + fp05) if (tn05 + fp05) > 0 else 0.0
    j05    = sens05 + spec05 - 1.0

    # Sweep
    best_j, best_thresh       = j05, 0.5
    best_sens, best_spec      = sens05, spec05
    for thresh in threshold_grid:
        y_pred = (y_score > thresh).astype(int)
        tp = ((y_pred == 1) & (y_true == 1)).sum()
        tn = ((y_pred == 0) & (y_true == 0)).sum()
        fn = ((y_pred == 0) & (y_true == 1)).sum()
        fp = ((y_pred == 1) & (y_true == 0)).sum()
        sens = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        j    = sens + spec - 1.0
        if j > best_j:
            best_j, best_thresh = j, thresh
            best_sens, best_spec = sens, spec

    PRED_THRESHOLDS[name] = float(round(best_thresh, 2))
    delta_sens = best_sens - sens05
    delta_spec = best_spec - spec05
    print(f"{name:<20} {j05:>7.3f} {best_j:>7.3f} {best_thresh:>7.2f} {delta_sens:>+8.3f} {delta_spec:>+8.3f}")

print(f"\nTuned thresholds: {PRED_THRESHOLDS}")

# ── Re-evaluate test set with tuned per-label thresholds ─────────────────────
print("\n" + "=" * 70)
print("TEST SET — TUNED PER-LABEL THRESHOLDS (for comparison with fixed-0.5 above)")
print("=" * 70)

test_scores_tuned = sigmoid(test_logits)   # test_logits from previous cell
tuned_aucs        = []

print(f"\n{'Label':<20} {'AUC':>8} {'Thresh':>7} {'Sens':>8} {'Spec':>8} {'N valid':>8}")
print("=" * 70)

for i, name in enumerate(LABEL_NAMES):
    thresh     = PRED_THRESHOLDS[name]
    valid_mask = test_labels[:, i] >= 0
    n_valid    = valid_mask.sum()
    y_true     = test_labels[valid_mask, i]
    y_score    = test_scores_tuned[valid_mask, i]
    y_pred     = (y_score > thresh).astype(int)

    try:
        auc = roc_auc_score(y_true, y_score) if len(np.unique(y_true)) >= 2 else float("nan")
    except ValueError:
        auc = float("nan")

    tp = ((y_pred == 1) & (y_true == 1)).sum()
    tn = ((y_pred == 0) & (y_true == 0)).sum()
    fn = ((y_pred == 0) & (y_true == 1)).sum()
    fp = ((y_pred == 1) & (y_true == 0)).sum()
    sens = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    spec = tn / (tn + fp) if (tn + fp) > 0 else float("nan")

    if not np.isnan(auc):
        tuned_aucs.append(auc)

    print(f"{name:<20} {auc:>8.4f} {thresh:>7.2f} {sens:>8.4f} {spec:>8.4f} {n_valid:>8d}")

macro_auc = np.mean(tuned_aucs) if tuned_aucs else float("nan")
print("=" * 70)
print(f"{'Macro AUC':<20} {macro_auc:>8.4f}  (AUC is threshold-independent)")
print("=" * 70)
print("\nNote: Macro AUC identical to the fixed-threshold run above — thresholds only shift the sens/spec balance.")

## 16. Save Model

Save the fine-tuned model and image processor locally.
Optionally push to Hugging Face Hub.

In [None]:
# ── Save locally ─────────────────────────────────────────────────────────────
trainer.save_model(OUTPUT_DIR)
image_processor.save_pretrained(OUTPUT_DIR)
print(f"✓ Model and image processor saved to {OUTPUT_DIR}/")

# ── Optionally push to Hugging Face Hub ──────────────────────────────────────
# Uncomment the following lines to push to your HF account:
#
# HUB_MODEL_ID = "<your-username>/medsiglip-448-mamaguard-wound"
# trainer.push_to_hub(HUB_MODEL_ID)
# image_processor.push_to_hub(HUB_MODEL_ID)
# print(f"✓ Pushed to https://huggingface.co/{HUB_MODEL_ID}")

## 17. Inference Demo

Load the saved model and run inference on 3 test images
to show what the downstream Gemini orchestrator would receive.
Per-label tuned thresholds from Section 15 are used by default.
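
The notebook does not specify the format the orchestrator consumes, so the following is only a sketch of one possible hand-off: it serializes a result dictionary shaped like the output of `predict_wound` into JSON. The function name `build_orchestrator_payload`, the payload schema, and the example scores are all hypothetical.

```python
import json


def build_orchestrator_payload(image_id, results):
    """Serialize per-label results to a JSON string (hypothetical schema)."""
    labels = {}
    for name, info in results.items():
        labels[name] = {
            "score": round(float(info["score"]), 3),
            "threshold": float(info["threshold"]),
            "prediction": info["prediction"],
            # bool() keeps the value JSON-serializable even for NumPy scalars
            "flagged": bool(info["score"] > info["threshold"]),
        }
    payload = {
        "image_id": image_id,
        "labels": labels,
        "flagged_labels": [n for n, v in labels.items() if v["flagged"]],
    }
    return json.dumps(payload, indent=2)


# Made-up scores and thresholds, for illustration only.
example_results = {
    "erythema": {"score": 0.71, "threshold": 0.42, "prediction": "present (redness)"},
    "edema": {"score": 0.18, "threshold": 0.55, "prediction": "absent"},
}
payload_json = build_orchestrator_payload("demo-001", example_results)
print(payload_json)
```

Passing the threshold alongside each score lets the downstream model see how close a score sits to its decision boundary instead of receiving only a binary verdict.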

In [None]:
# ── Load the fine-tuned model ────────────────────────────────────────────────
ft_model = AutoModelForImageClassification.from_pretrained(
    OUTPUT_DIR,
    problem_type="multi_label_classification",
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id,
    device_map="auto",
)
ft_model.eval()

# ── Human-readable label mapping ─────────────────────────────────────────────
LABEL_DISPLAY = {
    "healing_status": ("progressing", "not progressing"),
    "erythema":       ("absent", "present (redness)"),
    "edema":          ("absent", "present (swelling)"),
    "infection_risk": ("low", "elevated"),
    "urgency":        ("home care OK", "needs professional attention"),
    "exudate":        ("absent", "present (drainage)"),
}

# Fall back to 0.5 per label if threshold tuning cell was not run
_default_thresholds = {name: 0.5 for name in LABEL_NAMES}
_tuned = "PRED_THRESHOLDS" in globals()
_active_thresholds = PRED_THRESHOLDS if _tuned else _default_thresholds  # type: ignore[name-defined]
print(f"Using {'tuned per-label' if _tuned else 'default (0.5)'} thresholds: {_active_thresholds}")


def predict_wound(image, model, thresholds: dict | None = None):
    """
    Run wound assessment on a single image using per-label thresholds.

    Args:
        image:      PIL image (any size/mode — will be preprocessed internally).
        model:      Fine-tuned SiglipForImageClassification.
        thresholds: Dict mapping label name → decision threshold.
                    Defaults to per-label tuned thresholds (or 0.5 if not available).

    Returns:
        Dict: label name → {"score": float, "threshold": float, "prediction": str}
    """
    if thresholds is None:
        thresholds = _active_thresholds

    # Apply same preprocessing as eval (pure PIL pipeline, no augmentation)
    pixel_values = _process_image(image, augment=False)
    pixel_values = pixel_values.unsqueeze(0)  # Add batch dimension

    model_device = next(model.parameters()).device
    pixel_values = pixel_values.to(model_device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    probs = torch.sigmoid(outputs.logits[0]).cpu().numpy()

    results = {}
    for i, name in enumerate(LABEL_NAMES):
        neg_label, pos_label = LABEL_DISPLAY[name]
        is_positive = probs[i] > thresholds.get(name, 0.5)
        results[name] = {
            "score":      float(probs[i]),
            "threshold":  thresholds.get(name, 0.5),
            "prediction": pos_label if is_positive else neg_label,
        }

    return results


# ── Demo on 3 test images ────────────────────────────────────────────────────
test_df     = df[df["split"] == "test"]
demo_indices = [0, 50, 100]

for idx in demo_indices:
    if idx >= len(test_df):
        continue

    row      = test_df.iloc[idx]
    img_path = os.path.join(BASE_PATH, row["image_path"])
    img      = PILImage.open(img_path).convert("RGB")

    results  = predict_wound(img, ft_model)

    print(f"\n{chr(9472)*65}")
    print(f"Image: {row['image_path']} (img_id={row['img_id']})")
    print(f"{chr(9472)*65}")
    print(f"\n{'Label':<20} {'Score bar':^22} {'Score':>6}  {'Thresh':>6}   Prediction")
    for name, info in results.items():
        confidence = info["score"]
        thresh     = info["threshold"]
        prediction = info["prediction"]
        filled  = int(confidence * 20)
        bar     = chr(9608) * filled + chr(9617) * (20 - filled)
        marker  = " ◀" if confidence > thresh else ""
        print(f"  {name:<18} {bar} {confidence:.2f}   {thresh:.2f}   {prediction}{marker}")

    # Show the image
    fig, ax = plt.subplots(1, 1, figsize=(3, 3))
    ax.imshow(img)
    ax.set_title(f"img_id={row['img_id']}")
    ax.axis("off")
    plt.tight_layout()
    plt.show()

print("\n✓ Inference demo complete")
print("  These structured scores feed into the Gemini/MedGemma orchestrator")
print("  to generate empathetic, contextual responses for MamaGuard users.")

## Appendix: Dynamic pos_weight Verification

Recompute `pos_weight` from the actual training data to verify the hardcoded constants.
Run this cell to double-check if you've modified the dataset or label encoding.
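
The ratio being verified is `pos_weight = n_neg / n_pos`, counted per label over the non-missing entries only. A minimal pure-Python sketch on a made-up label matrix (two labels, `-1` marking MISSING) shows the arithmetic. `pos_weight_per_label` is an illustrative helper and the toy rows are not dataset values.

```python
def pos_weight_per_label(rows, missing=-1):
    """Per-label n_neg / n_pos, skipping entries equal to `missing`."""
    weights = []
    for i in range(len(rows[0])):
        column = [row[i] for row in rows if row[i] != missing]
        n_pos = sum(1 for v in column if v == 1)
        n_neg = sum(1 for v in column if v == 0)
        weights.append(n_neg / n_pos if n_pos else float("inf"))
    return weights


# Toy label matrix: 6 samples x 2 labels, -1 = MISSING (illustrative only).
toy_labels = [
    [1, 0],
    [0, -1],
    [0, 1],
    [0, -1],
    [1, 0],
    [0, 0],
]
print(pos_weight_per_label(toy_labels))  # [2.0, 3.0]
```

The second label has two missing entries, so its weight comes from the remaining four samples (3 negative, 1 positive) rather than all six.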

In [None]:
# ── Dynamically compute pos_weight from training labels ──────────────────────
train_labels = torch.tensor(train_ds["label"])  # (480, 6)

print(f"Training label tensor shape: {train_labels.shape}")
print(f"\n{'Label':<20s} {'Neg':>6} {'Pos':>6} {'Miss':>6} {'pos_weight':>10} {'Hardcoded':>10} {'Match':>6}")
print("-" * 70)

computed_pw = []
for i, name in enumerate(LABEL_NAMES):
    col = train_labels[:, i]
    n_missing = (col == -1).sum().item()
    n_pos = (col == 1).sum().item()
    n_neg = (col == 0).sum().item()
    pw = n_neg / n_pos if n_pos > 0 else float("inf")
    computed_pw.append(pw)
    hardcoded = POS_WEIGHT[i].item()
    match = abs(pw - hardcoded) < 0.05
    print(f"{name:<20s} {n_neg:>6d} {n_pos:>6d} {n_missing:>6d} {pw:>10.2f} {hardcoded:>10.2f} {'✓' if match else '✗':>6}")

print(f"\nHardcoded POS_WEIGHT: {POS_WEIGHT.tolist()}")
print(f"Computed POS_WEIGHT:  {[round(x, 2) for x in computed_pw]}")