# Experiment: Derm Foundation SCIN Classifier (Implementation Plan Version)

Objective:
- Train a deployable multi-label skin-condition classifier on top of Derm Foundation embeddings.
- Export artifacts required by the Momnetrix/MamaGuard pipeline.

What this notebook does:
- Loads SCIN metadata (labels + image paths).
- Loads precomputed Derm Foundation embeddings (`scin_dataset_precomputed_embeddings.npz`).
- Trains a 10-condition multi-label sklearn classifier.
- Evaluates with Hamming loss + per-label ROC-AUC.
- Exports deployment artifacts: `derm_classifier.pkl`, `derm_labels.json`, `derm_config.json`, `sample_prediction.json`.

Deployment context:
- Runtime endpoint receives a photo.
- Derm Foundation (TensorFlow/Keras) creates a 6144-d embedding.
- This classifier maps embedding -> `{condition: probability}`.
- Gemini then applies pregnancy-specific context and escalation guidance.


In [None]:
from __future__ import annotations

import ast
import json
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import hamming_loss, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Single seed reused for the train/test split and the classifier.
SEED = 42

# Top-10 conditions from the implementation plan / audited notebook.
# The order of this list is the label order handed to MultiLabelBinarizer.
CONDITIONS_TO_PREDICT = [
    "Eczema",
    "Allergic Contact Dermatitis",
    "Insect Bite",
    "Urticaria",
    "Psoriasis",
    "Folliculitis",
    "Irritant Contact Dermatitis",
    "Tinea",
    "Herpes Zoster",
    "Drug Rash",
]

# Paths can be changed for local vs Colab usage.
WORK_DIR = Path(".")
ARTIFACT_DIR = WORK_DIR.joinpath("artifacts", "derm")
CACHE_DIR = WORK_DIR.joinpath("data", "scin_cache")

# Create both output locations up front (cache first, then artifacts).
for _required_dir in (CACHE_DIR, ARTIFACT_DIR):
    _required_dir.mkdir(parents=True, exist_ok=True)

# Echo the effective configuration so a run log shows where files will land.
_run_summary = {
    "seed": SEED,
    "artifact_dir": str(ARTIFACT_DIR.resolve()),
    "cache_dir": str(CACHE_DIR.resolve()),
    "num_conditions": len(CONDITIONS_TO_PREDICT),
}
print(_run_summary)


## Data Loading Strategy (Local or Colab)

- This notebook avoids `google.colab.*` imports so it runs locally and in Colab.
- We use:
  - `huggingface_hub` for precomputed embeddings (`google/derm-foundation` repo).
  - public SCIN CSVs (`scin_cases.csv`, `scin_labels.csv`) via direct URL.
- If direct URL fetching fails in your environment, download the two CSVs manually and set `cases_csv_url` / `labels_csv_url` on `ScinSources` (next cell) to the local file paths.

Expected columns used:
- `case_id`
- `image_1_path`, `image_2_path`, `image_3_path`
- `dermatologist_gradable_for_skin_condition_1`
- `dermatologist_skin_condition_on_label_name`
- `dermatologist_skin_condition_confidence`


In [None]:
from huggingface_hub import hf_hub_download


@dataclass
class ScinSources:
    """Locations of the SCIN metadata CSVs and the precomputed embeddings file.

    Every field has a default, so ``ScinSources()`` can be built without
    arguments.
    """

    # Cases table; read with pd.read_csv in load_scin_dataframe.
    cases_csv_url: str = "https://storage.googleapis.com/dx-scin-public-data/dataset/scin_cases.csv"
    # Labels table; merged with the cases table on ``case_id``.
    labels_csv_url: str = "https://storage.googleapis.com/dx-scin-public-data/dataset/scin_labels.csv"
    # Repository id passed to hf_hub_download in load_embeddings.
    hf_repo_id: str = "google/derm-foundation"
    # Name of the .npz file to fetch from that repository.
    hf_embeddings_filename: str = "scin_dataset_precomputed_embeddings.npz"


# Default source configuration used by the loading calls in this cell.
SOURCES = ScinSources()


def _parse_listlike(value: Any) -> list[Any]:
    """Coerce a cell that holds a list (or its string repr) into a Python list.

    Args:
        value: A list or tuple, a string containing a Python literal such as
            ``"['Eczema', 'Tinea']"``, or a missing value (``None`` / NaN).

    Returns:
        The items as a list. A list is returned as-is (not copied). Missing
        values, blank strings and any other unsupported type yield ``[]``.

    Raises:
        ValueError, SyntaxError: If a non-blank string is not a valid Python
            literal (propagated from ``ast.literal_eval``).
        TypeError: If the parsed literal is not iterable.
    """
    # Containers are checked first. Calling pd.isna() on a list returns an
    # element-wise array, and using that array in an ``if`` raises for lists
    # with more than one element. Scalars that are missing (None, NaN, pd.NA)
    # are neither containers nor strings, so they reach the final fallback.
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    if isinstance(value, str):
        text = value.strip()
        if not text:
            # An empty CSV cell read as "" carries no labels.
            return []
        return list(ast.literal_eval(text))
    return []


def load_scin_dataframe(sources: ScinSources) -> pd.DataFrame:
    """Read the SCIN cases and labels CSVs and join them into one frame.

    Args:
        sources: Supplies ``cases_csv_url`` and ``labels_csv_url``; both are
            handed directly to ``pd.read_csv``.

    Returns:
        The two tables merged on ``case_id`` (pandas' default join type),
        indexed by ``case_id`` as strings.
    """
    key = "case_id"
    # Read the key as text so IDs are not reinterpreted as numbers.
    read_options = {"dtype": {key: str}}

    cases = pd.read_csv(sources.cases_csv_url, **read_options)
    labels = pd.read_csv(sources.labels_csv_url, **read_options)

    joined = cases.merge(labels, on=key)
    joined[key] = joined[key].astype(str)
    return joined.set_index(key)


def load_embeddings(sources: ScinSources, cache_dir: Path) -> dict[str, np.ndarray]:
    """Download the precomputed embeddings archive and load it into memory.

    Args:
        sources: Supplies ``hf_repo_id`` and ``hf_embeddings_filename`` for the
            Hugging Face Hub download.
        cache_dir: Directory passed to ``hf_hub_download`` as ``local_dir``.

    Returns:
        A plain dict mapping each array name in the ``.npz`` archive to its
        array.
    """
    file_path = hf_hub_download(
        repo_id=sources.hf_repo_id,
        filename=sources.hf_embeddings_filename,
        local_dir=str(cache_dir),
    )
    # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle object
    # arrays, which can run arbitrary code if the downloaded file is not
    # trusted. It is kept to preserve existing behaviour; switch to False if
    # the archive only holds plain numeric arrays -- TODO confirm.
    #
    # The archive object keeps the underlying file open until it is closed, so
    # every array is materialised inside the ``with`` block and the handle is
    # released before returning.
    with np.load(file_path, allow_pickle=True) as archive:
        return {name: archive[name] for name in archive.files}


def prepare_xy(
    scin_df: pd.DataFrame,
    embeddings: dict[str, np.ndarray],
    conditions_to_predict: list[str],
    min_confidence: int = 0,
) -> tuple[np.ndarray, np.ndarray, MultiLabelBinarizer, dict[str, int]]:
    """Build the feature matrix and multi-label target matrix from SCIN rows.

    Each row rated as having sufficient image quality contributes one sample
    per image path (``image_1_path`` .. ``image_3_path``) that has an entry in
    ``embeddings``. All images of a row share that row's filtered label list,
    and a row is kept even when that list ends up empty.

    Args:
        scin_df: Merged SCIN metadata with the gradability, label-name,
            label-confidence and image-path columns.
        embeddings: Mapping from image path to its embedding vector.
        conditions_to_predict: Labels to keep; also the class order given to
            the binarizer.
        min_confidence: Labels whose confidence is below this are dropped.

    Returns:
        ``(X, y, mlb, stats)``: the stacked embeddings, the binarized labels,
        the fitted ``MultiLabelBinarizer`` and a dict of counters describing
        what was skipped.
    """
    features: list[np.ndarray] = []
    targets: list[list[str]] = []

    stats = dict.fromkeys(
        (
            "rows_total",
            "rows_poor_quality",
            "missing_embedding",
            "label_not_tracked",
            "label_low_confidence",
        ),
        0,
    )

    for case in scin_df.itertuples():
        stats["rows_total"] += 1

        # Only rows rated as having sufficient image quality are used.
        gradable = case.dermatologist_gradable_for_skin_condition_1
        if gradable != "DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT":
            stats["rows_poor_quality"] += 1
            continue

        names = _parse_listlike(case.dermatologist_skin_condition_on_label_name)
        scores = _parse_listlike(case.dermatologist_skin_condition_confidence)

        # Keep labels that are tracked and meet the confidence floor.
        kept: list[str] = []
        for name, score in zip(names, scores):
            if name not in conditions_to_predict:
                stats["label_not_tracked"] += 1
            elif score < min_confidence:
                stats["label_low_confidence"] += 1
            else:
                kept.append(name)

        # One sample per available image; every image carries the same labels.
        for path in (case.image_1_path, case.image_2_path, case.image_3_path):
            if pd.isna(path):
                continue
            if path in embeddings:
                features.append(embeddings[path])
                targets.append(kept)
            else:
                stats["missing_embedding"] += 1

    mlb = MultiLabelBinarizer(classes=conditions_to_predict)
    y = mlb.fit_transform(targets)

    return np.asarray(features), y, mlb, stats


# --- Load metadata + embeddings and assemble the training matrices ----------
scin_df = load_scin_dataframe(SOURCES)
embeddings = load_embeddings(SOURCES, CACHE_DIR)
X, y, mlb, prep_stats = prepare_xy(
    scin_df,
    embeddings,
    CONDITIONS_TO_PREDICT,
    min_confidence=0,
)

print("Dataset prep stats:", prep_stats)
print("X shape:", X.shape, "y shape:", y.shape)
print("Class order:", list(mlb.classes_))

# --- Hold-out split ----------------------------------------------------------
# NOTE(review): the split is per embedding row, and prepare_xy emits up to
# three rows per case (one per image). Images of the same case can therefore
# land in both train and test, which may inflate the metrics below -- consider
# a case-grouped split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED
)

# --- Fit one logistic-regression head per label ------------------------------
base_estimator = LogisticRegression(max_iter=300, random_state=SEED)
classifier = MultiOutputClassifier(base_estimator)
classifier.fit(X_train, y_train)

# Convert the per-label estimators' outputs into a matrix [n_samples, n_labels];
# column 1 of each predict_proba result is taken as the positive-class score.
proba_cols = [est.predict_proba(X_test)[:, 1] for est in classifier.estimators_]
y_pred_proba = np.column_stack(proba_cols)
y_pred_binary = (y_pred_proba >= 0.5).astype(int)

# --- Metrics -----------------------------------------------------------------
test_hamming = hamming_loss(y_test, y_pred_binary)
print("Hamming loss:", round(float(test_hamming), 4))

auc_by_label: dict[str, float] = {}
for idx, label in enumerate(mlb.classes_):
    truth = y_test[:, idx]
    if len(np.unique(truth)) >= 2:
        auc_by_label[label] = float(roc_auc_score(truth, y_pred_proba[:, idx]))
    else:
        # Only one class present in the test column: record NaN instead of
        # calling roc_auc_score.
        auc_by_label[label] = float("nan")

print("ROC-AUC by label:")
for label in mlb.classes_:
    print(f"  {label}: {auc_by_label[label]:.3f}")


## Export + Inference Contract

This is the deployment-facing output section.

Artifacts produced:
- `derm_classifier.pkl`: trained sklearn classifier head.
- `derm_labels.json`: exact label order expected by classifier output.
- `derm_config.json`: thresholds and metadata.
- `sample_prediction.json`: reference API-shaped output for integration tests.

Runtime contract (for your Modal endpoint):
1. Endpoint receives image bytes.
2. Derm Foundation model converts image -> 6144-d embedding.
3. Load `derm_classifier.pkl` and `derm_labels.json` at startup.
4. Return `{condition: probability}` sorted by descending probability, plus summary fields (`top_k`, `max_confidence`, `low_confidence`, `threshold`).


In [None]:
# 1) Save classifier (binary pickle of the fitted MultiOutputClassifier)
classifier_path = ARTIFACT_DIR / "derm_classifier.pkl"
with open(classifier_path, "wb") as handle:
    pickle.dump(classifier, handle)

# 2) Save label order (critical for mapping index -> condition)
labels_path = ARTIFACT_DIR / "derm_labels.json"
labels_payload = {"labels": list(mlb.classes_)}
labels_json = json.dumps(labels_payload, indent=2)
labels_path.write_text(labels_json, encoding="utf-8")

# 3) Save config; embedding_dim is read from the feature matrix, so it records
#    the dimensionality the classifier was actually trained on.
config = {
    "model_name": "google/derm-foundation",
    "classifier_type": "sklearn.multioutput.MultiOutputClassifier(LogisticRegression)",
    "embedding_dim": int(X.shape[1]),
    "threshold": 0.5,
    "top_k": 5,
    "seed": SEED,
}
config_path = ARTIFACT_DIR / "derm_config.json"
config_json = json.dumps(config, indent=2)
config_path.write_text(config_json, encoding="utf-8")


def predict_embedding_proba(
    embedding: np.ndarray,
    clf: MultiOutputClassifier,
    labels: list[str],
    threshold: float = 0.5,
    top_k: int = 5,
) -> dict[str, Any]:
    """Score one embedding and return an API-shaped prediction payload.

    Args:
        embedding: A single embedding, either 1-D or already shaped
            ``(1, dim)``. If more than one row is supplied, only the first
            row is scored.
        clf: Fitted multi-output classifier; ``clf.estimators_[i]`` must
            correspond to ``labels[i]``.
        labels: Condition names in the classifier's output order.
        threshold: ``low_confidence`` is True when the best score is below it.
        top_k: Number of highest-scoring conditions to list; values below
            zero are treated as zero.

    Returns:
        A dict with ``scores`` (all conditions, sorted by descending score),
        ``top_k`` (list of ``{"condition", "score"}`` with scores rounded to
        4 places), ``max_confidence``, ``low_confidence`` and ``threshold``.

    Raises:
        ValueError: If ``labels`` and ``clf.estimators_`` differ in length,
            since the index -> condition mapping would then be unreliable.
    """
    estimators = clf.estimators_
    if len(labels) != len(estimators):
        raise ValueError(
            f"Got {len(labels)} labels for {len(estimators)} estimators; "
            "label order must match the trained classifier."
        )

    # Accept lists/tuples as well as arrays, and promote a flat vector to a
    # one-row batch because predict_proba expects 2-D input.
    batch = np.asarray(embedding)
    if batch.ndim == 1:
        batch = batch[np.newaxis, :]

    # Row 0, column 1: positive-class probability of the (first) sample.
    scores = {
        label: float(estimator.predict_proba(batch)[0, 1])
        for label, estimator in zip(labels, estimators)
    }
    sorted_scores = dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))

    # Clamp so a negative top_k cannot slice from the end of the ranking.
    top = list(sorted_scores.items())[: max(top_k, 0)]
    # Take the best score from all conditions rather than from the top-k
    # slice, so top_k=0 does not report a spurious 0.0 / low-confidence result.
    max_confidence = max(sorted_scores.values(), default=0.0)

    return {
        "scores": sorted_scores,
        "top_k": [{"condition": k, "score": round(v, 4)} for k, v in top],
        "max_confidence": round(float(max_confidence), 4),
        "low_confidence": bool(max_confidence < threshold),
        "threshold": threshold,
    }


# 4) Build one sample prediction payload for integration tests, scoring the
#    first held-out embedding with the thresholds stored in the config.
sample_payload = predict_embedding_proba(
    X_test[0],
    classifier,
    list(mlb.classes_),
    threshold=config["threshold"],
    top_k=config["top_k"],
)

sample_prediction = {
    "model": "derm-foundation-classifier-v1",
    "case": "scin_test_sample_0",
    "prediction": sample_payload,
}
sample_path = ARTIFACT_DIR / "sample_prediction.json"
sample_json = json.dumps(sample_prediction, indent=2)
sample_path.write_text(sample_json, encoding="utf-8")

# Report every file written by the export section, in the order it was saved.
print("Saved artifacts:")
for saved_path in (classifier_path, labels_path, config_path, sample_path):
    print("-", saved_path)
print("\nSample top-k:", sample_prediction["prediction"]["top_k"])


## Next Steps (Handoff)

1. Run this notebook fully and confirm artifacts exist in `artifacts/derm/`.
2. In Modal endpoint code, load `derm_classifier.pkl` and `derm_labels.json` at startup.
3. Ensure Derm Foundation inference returns a `(6144,)` embedding for every image.
4. Feed the embedding into the classifier and return a response that follows the schema in `sample_prediction.json`.
5. Pass prediction + gestational age + vitals + meds + original photo to Gemini for pregnancy-contextual guidance.
6. Persist final output in Visit Prep Summary.

Notes:
- This classifier is a screening head, not a diagnosis engine.
- Keep label order fixed; changing order without retraining breaks outputs.
