This notebook trains a binary classifier (benign vs. malignant) on BCN20000 and reports Accuracy, Recall, AUROC.
During validation we also apply test-time augmentation (TTA): each image and its horizontal flip are scored, and the two probability vectors are averaged to stabilize AUROC.
A trained model checkpoint is saved to ../artifacts/checkpoints/ for reuse.

In [12]:
import sys, os
sys.path.append(os.path.abspath(".."))

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.models as models
from torchvision import transforms
import torchvision.transforms.functional as TF

from sklearn.metrics import (
    accuracy_score,
    recall_score,
    roc_auc_score,
    precision_recall_curve,
    roc_curve,
    confusion_matrix,
    average_precision_score,
    brier_score_loss
)

import matplotlib.pyplot as plt
import numpy as np
import csv, json, datetime

from src import load_bcn20000, get_transforms

import json
import csv
import datetime

---
Helper classes and utilities: device selection, dataset wrapper, and dataloaders with the binary mapping (malignant vs benign).

In [13]:
def get_device():
    """Pick the torch device to run on.

    Backends are probed in a fixed preference order (MPS, then CUDA) and the
    first one that reports itself as available is used; CPU is the fallback.

    Returns:
        torch.device: ``"mps"``, ``"cuda"`` or ``"cpu"``.
    """
    accelerators = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_available in accelerators:
        if is_available():
            return torch.device(backend_name)
    return torch.device("cpu")

class TorchImageDataset(Dataset):
    """Expose an indexable image dataset as a ``torch.utils.data.Dataset``.

    Each row of ``hf_ds`` is expected to be a mapping with an ``"image"`` entry
    (either a PIL image or something ``PIL.Image.open`` accepts) and, when
    ``has_labels`` is true, a ``"label"`` entry (a string or an int-like value).

    Args:
        hf_ds: Indexable dataset with ``__len__``. When ``has_labels`` is true
            it must also expose ``features["label"]`` with a ``str2int`` method,
            which is used to map string labels to integer ids.
        transform: Callable applied to the RGB PIL image; its result is returned
            as the sample.
        has_labels: If true, items are ``(x, y)`` pairs; otherwise just ``x``.
    """

    def __init__(self, hf_ds, transform, has_labels=True):
        self.ds = hf_ds
        self.tf = transform
        self.has_labels = has_labels
        # Kept so string labels can be converted to integer ids in __getitem__.
        self.label_feature = hf_ds.features["label"] if has_labels else None

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the transformed image at ``idx`` (with its int label if labelled)."""
        ex = self.ds[idx]
        img = ex["image"]
        if isinstance(img, Image.Image):
            # Already-decoded images can still be grayscale / RGBA / palette;
            # normalise them the same way as images loaded from a path.
            img = img.convert("RGB")
        else:
            # convert() returns a fully loaded copy, so the underlying file can
            # be closed as soon as the conversion is done.
            with Image.open(img) as opened:
                img = opened.convert("RGB")
        x = self.tf(img)
        if not self.has_labels:
            return x
        y = ex["label"]
        y = self.label_feature.str2int(y) if isinstance(y, str) else int(y)
        return x, y

def get_binary_mapping():
    """Map each diagnosis code to ``"malignant"`` or ``"benign"``.

    MEL, SCC and BCC are treated as malignant; every other listed code
    (NV, BKL, AK, DF, VASC) is treated as benign.

    Returns:
        dict[str, str]: diagnosis code -> binary class name, in the fixed
        code order MEL, SCC, NV, BCC, BKL, AK, DF, VASC.
    """
    malignant_codes = frozenset(("MEL", "SCC", "BCC"))
    mapping = {}
    for code in ("MEL", "SCC", "NV", "BCC", "BKL", "AK", "DF", "VASC"):
        mapping[code] = "malignant" if code in malignant_codes else "benign"
    return mapping

def make_loaders(batch_size=64):
    """Build the training and validation ``DataLoader``s for the binary task.

    Both splits are loaded through ``load_bcn20000`` with the mapping from
    ``get_binary_mapping``. The training split uses the training transforms
    and is shuffled; the validation split uses the evaluation transforms and
    keeps its order.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        tuple: ``(train_loader, val_loader, label_names)`` where ``label_names``
        is read from the training split's ``"label"`` feature.
    """
    split_kwargs = dict(
        filename_column="bcn_filename",
        label_column="diagnosis",
        label_mapping=get_binary_mapping(),
    )
    train_hf = load_bcn20000(split="train", **split_kwargs)
    val_hf = load_bcn20000(split="validation", **split_kwargs)
    label_names = train_hf.features["label"].names

    train_tf = get_transforms(train=True)
    eval_tf = get_transforms(train=False)

    loader_kwargs = dict(batch_size=batch_size, num_workers=0)
    train_loader = DataLoader(
        TorchImageDataset(train_hf, train_tf, has_labels=True),
        shuffle=True,
        **loader_kwargs,
    )
    val_loader = DataLoader(
        TorchImageDataset(val_hf, eval_tf, has_labels=True),
        shuffle=False,
        **loader_kwargs,
    )
    return train_loader, val_loader, label_names

---
Model: ResNet18 with a 2-class head. Training loop runs for a few epochs and returns the model plus its history.

In [14]:
def build_resnet18_binary(num_classes=2, use_pretrained=True):
    """Create a ResNet18 whose final fully connected layer has ``num_classes`` outputs.

    Args:
        num_classes: Size of the replacement classification head.
        use_pretrained: If true, try to start from torchvision's default
            ResNet18 weights.

    Returns:
        torch.nn.Module: the ResNet18 with a freshly initialised ``fc`` layer.

    Warns:
        RuntimeWarning: if pretrained weights were requested but could not be
            loaded. The model is then randomly initialised instead, which
            previously happened silently.
    """
    import warnings  # local import: only used on the fallback path

    model = None
    if use_pretrained:
        try:
            model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        except Exception as exc:
            # Stay best-effort, but make the downgrade visible: training from
            # scratch gives very different results from fine-tuning.
            warnings.warn(
                f"Could not load pretrained ResNet18 weights ({exc!r}); "
                "falling back to random initialization.",
                RuntimeWarning,
                stacklevel=2,
            )
    if model is None:
        model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def evaluate_basic(model, data_loader, device, criterion):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    The model is switched to eval mode and left there.

    Args:
        model: Module mapping an input batch to per-class scores (class axis 1).
        data_loader: Sized iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.
        criterion: Callable ``(outputs, targets) -> loss tensor``.

    Returns:
        tuple[float, float]: ``(mean_batch_loss, accuracy)``. The loss is the
        sum of per-batch losses divided by the number of batches; accuracy is
        correct predictions over samples seen. Both are 0.0 for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)
    n_batches = max(1, len(data_loader))
    return loss_sum / n_batches, n_correct / max(1, n_seen)

def train_model(model, train_loader, val_loader, device, epochs=15, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Args:
        model: Module to optimise in place.
        train_loader: Sized iterable of ``(inputs, targets)`` training batches.
        val_loader: Loader passed to ``evaluate_basic`` after each epoch.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        tuple: ``(model, history)`` where ``history`` has the per-epoch lists
        ``"loss"`` (mean training batch loss), ``"val_loss"`` and ``"val_acc"``.
    """
    from tqdm.auto import tqdm

    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history = {"loss": [], "val_loss": [], "val_acc": []}

    for epoch_no in range(1, epochs + 1):
        model.train()
        epoch_loss_sum = 0.0
        progress = tqdm(train_loader, desc=f"Epoch {epoch_no}/{epochs}", leave=True)
        for inputs, targets in progress:
            inputs, targets = inputs.to(device), targets.to(device)
            opt.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            opt.step()
            epoch_loss_sum += batch_loss.item()
            progress.set_postfix(batch_loss=f"{batch_loss.item():.4f}")

        train_loss = epoch_loss_sum / max(1, len(train_loader))
        val_loss, val_acc = evaluate_basic(model, val_loader, device, loss_fn)
        for key, value in (("loss", train_loss), ("val_loss", val_loss), ("val_acc", val_acc)):
            history[key].append(value)

    return model, history

---
Validation-time TTA: average probabilities from original and horizontally flipped images to stabilize AUROC and recall.

In [15]:
@torch.no_grad()
def predict_with_tta(model, xb, device):
    """Return class probabilities averaged over a batch and its horizontal flip.

    The model is put in eval mode, then run on the batch and on
    ``TF.hflip`` of the batch; the two softmax outputs (class axis 1) are
    averaged with equal weight.

    Args:
        model: Module producing per-class scores.
        xb: Input image batch (assumed ``(N, C, H, W)`` - TODO confirm).
        device: Device the batch is moved to.

    Returns:
        torch.Tensor: averaged probabilities, on ``device``.
    """
    model.eval()
    batch = xb.to(device)
    probs_plain = torch.softmax(model(batch), dim=1)
    probs_flipped = torch.softmax(model(TF.hflip(batch)), dim=1)
    return (probs_plain + probs_flipped) / 2.0

@torch.no_grad()
def collect_outputs(model, loader, device, use_tta=True):
    """Run ``model`` over ``loader`` and gather labels, predictions and scores.

    Args:
        model: Module producing per-class scores (class axis 1).
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the inputs are moved to.
        use_tta: If true, probabilities come from ``predict_with_tta``;
            otherwise from a single forward pass followed by softmax.

    Returns:
        tuple[list, list, list]: ``(ys, preds, probs)`` - the true labels, the
        argmax class per sample, and the probability assigned to class index 1.
    """
    # Inference must not run with Dropout/BatchNorm in training mode. The TTA
    # helper switches to eval itself, but the plain path previously did not.
    model.eval()
    ys, preds, probs = [], [], []
    for xb, yb in loader:
        if use_tta:
            pr = predict_with_tta(model, xb, device).cpu()
        else:
            pr = torch.softmax(model(xb.to(device)), dim=1).cpu()
        preds.extend(pr.argmax(dim=1).tolist())
        probs.extend(pr[:, 1].tolist())
        ys.extend(yb.tolist())
    return ys, preds, probs

def compute_core_metrics(y_true, y_pred, y_prob, pos_label_idx):
    """Compute accuracy, recall and AUROC for a binary classifier.

    Args:
        y_true: True integer labels.
        y_pred: Predicted integer labels.
        y_prob: Scores passed to ``roc_auc_score`` together with ``y_true``.
        pos_label_idx: Label treated as the positive class for recall.

    Returns:
        dict[str, float]: keys ``"accuracy"``, ``"recall"`` and ``"auroc"``.
        ``"auroc"`` is NaN when ``roc_auc_score`` rejects the input with a
        ``ValueError`` (e.g. only one class present in ``y_true``).
    """
    acc = accuracy_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred, pos_label=pos_label_idx, zero_division=0)
    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        # Only the "AUROC is undefined for this input" case is tolerated.
        # Anything else (e.g. a programming error) should surface rather than
        # be reported as NaN.
        auc = float("nan")
    return {"accuracy": float(acc), "recall": float(rec), "auroc": float(auc)}

---
Plotting utilities

In [16]:
def ensure_dir(path):
    """Create ``path`` (including missing parents) if needed and return it unchanged."""
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
    return path

def plot_loss_acc(history, outdir="../artifacts/plots", tag="binary_resnet18"):
    """Save the loss curves and the validation-accuracy curve as two PNG files.

    Args:
        history: Mapping with per-epoch lists under ``"loss"``, ``"val_loss"``
            and ``"val_acc"``.
        outdir: Directory the images are written to (created if missing).
        tag: Prefix for the file names and suffix of the plot titles.

    Returns:
        tuple[str, str]: paths of ``<tag>_loss.png`` and ``<tag>_val_acc.png``.
    """
    ensure_dir(outdir)

    # Figure 1: training vs. validation loss.
    loss_fig, loss_ax = plt.subplots(figsize=(7, 4))
    loss_ax.plot(history["loss"], label="train_loss")
    loss_ax.plot(history["val_loss"], label="val_loss")
    loss_ax.set_xlabel("epoch")
    loss_ax.set_ylabel("loss")
    loss_ax.legend()
    loss_ax.set_title(f"Loss: {tag}")
    loss_path = os.path.join(outdir, f"{tag}_loss.png")
    loss_fig.savefig(loss_path, bbox_inches="tight")
    plt.close(loss_fig)

    # Figure 2: validation accuracy.
    acc_fig, acc_ax = plt.subplots(figsize=(7, 4))
    acc_ax.plot(history["val_acc"], label="val_acc")
    acc_ax.set_xlabel("epoch")
    acc_ax.set_ylabel("accuracy")
    acc_ax.legend()
    acc_ax.set_title(f"Val Acc: {tag}")
    acc_path = os.path.join(outdir, f"{tag}_val_acc.png")
    acc_fig.savefig(acc_path, bbox_inches="tight")
    plt.close(acc_fig)

    return loss_path, acc_path

def plot_roc_curve(y_true, y_prob, outdir="../artifacts/plots", tag="binary_resnet18"):
    """Save the ROC curve, annotated with its AUROC, as ``<tag>_roc.png``.

    Args:
        y_true: True binary labels.
        y_prob: Scores for the positive class.
        outdir: Directory the image is written to (created if missing).
        tag: File-name prefix.

    Returns:
        str: path of the written PNG.
    """
    ensure_dir(outdir)
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_prob)
    area = roc_auc_score(y_true, y_prob)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(false_pos_rate, true_pos_rate, label=f"AUROC={area:.3f}")
    ax.plot([0, 1], [0, 1], linestyle="--")  # chance diagonal
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")
    ax.legend()

    out_path = os.path.join(outdir, f"{tag}_roc.png")
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    return out_path

def plot_pr_curve(y_true, y_prob, outdir="../artifacts/plots", tag="binary_resnet18"):
    """Save the precision-recall curve, annotated with average precision, as ``<tag>_pr.png``.

    Args:
        y_true: True binary labels.
        y_prob: Scores for the positive class.
        outdir: Directory the image is written to (created if missing).
        tag: File-name prefix.

    Returns:
        str: path of the written PNG.
    """
    ensure_dir(outdir)
    precisions, recalls, _ = precision_recall_curve(y_true, y_prob)
    avg_precision = average_precision_score(y_true, y_prob)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(recalls, precisions, label=f"AP={avg_precision:.3f}")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title("Precision-Recall Curve")
    ax.legend()

    out_path = os.path.join(outdir, f"{tag}_pr.png")
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    return out_path

def plot_confusion(y_true, y_pred, label_names, outdir="../artifacts/plots", tag="binary_resnet18"):
    """Save the 2x2 confusion matrix (labels 0 and 1) as ``<tag>_confusion.png``.

    Rows are true classes and columns are predicted classes; every cell is
    annotated with its count.

    Args:
        y_true: True integer labels.
        y_pred: Predicted integer labels.
        label_names: Tick labels for class 0 and class 1, in that order.
        outdir: Directory the image is written to (created if missing).
        tag: File-name prefix.

    Returns:
        str: path of the written PNG.
    """
    ensure_dir(outdir)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.imshow(cm, cmap="Blues")
    ax.set_xticks([0, 1])
    ax.set_xticklabels(label_names)
    ax.set_yticks([0, 1])
    ax.set_yticklabels(label_names)
    for (row, col), count in np.ndenumerate(cm):
        # x is the predicted class (column), y the true class (row).
        ax.text(col, row, str(count), ha="center", va="center")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")

    out_path = os.path.join(outdir, f"{tag}_confusion.png")
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    return out_path

def plot_calibration(y_true, y_prob, outdir="../artifacts/plots", tag="binary_resnet18", bins=10):
    """Save a reliability plot as ``<tag>_calibration.png``.

    Probabilities are grouped into ``bins`` equal-width intervals on [0, 1].
    For every non-empty bin the mean predicted probability (x) is plotted
    against the mean of ``y_true`` in that bin (y), next to the diagonal.

    Args:
        y_true: True binary labels (0/1).
        y_prob: Predicted probability of the positive class.
        outdir: Directory the image is written to (created if missing).
        tag: File-name prefix.
        bins: Number of equal-width probability bins.

    Returns:
        str: path of the written PNG.
    """
    ensure_dir(outdir)
    y_true = np.array(y_true)
    y_prob = np.array(y_prob)
    edges = np.linspace(0, 1, bins + 1)
    # np.digitize uses half-open intervals, so a probability of exactly 1.0
    # lands one past the last bin and used to be dropped from the plot.
    # Clipping folds it into the top bin (and guards the lower edge too).
    idx = np.clip(np.digitize(y_prob, edges) - 1, 0, bins - 1)
    acc = []
    conf = []
    for b in range(bins):
        m = idx == b
        if m.sum() == 0:
            continue  # empty bins contribute no point
        acc.append(y_true[m].mean())
        conf.append(y_prob[m].mean())
    fig = plt.figure(figsize=(6, 6))
    plt.plot([0, 1], [0, 1], linestyle="--")  # perfect-calibration diagonal
    plt.scatter(conf, acc)
    plt.xlabel("Predicted probability")
    plt.ylabel("Empirical accuracy")
    plt.title("Calibration Plot")
    path = os.path.join(outdir, f"{tag}_calibration.png")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
    return path

def plot_threshold_sweep(y_true, y_prob, label_names, outdir="../artifacts/plots", tag="binary_resnet18"):
    """Save recall and accuracy as a function of the decision threshold.

    101 evenly spaced thresholds in [0, 1] are tried; a sample is predicted
    as class 1 when its probability is at or above the threshold. The result
    is written to ``<tag>_threshold_sweep.png``.

    Args:
        y_true: True binary labels (0/1).
        y_prob: Predicted probability of class 1.
        label_names: Class names; entry 1 is used in the recall legend.
        outdir: Directory the image is written to (created if missing).
        tag: File-name prefix.

    Returns:
        str: path of the written PNG.
    """
    ensure_dir(outdir)
    y_true = np.array(y_true)
    y_prob = np.array(y_prob)
    thresholds = np.linspace(0.0, 1.0, 101)

    thresholded = [(y_prob >= cutoff).astype(int) for cutoff in thresholds]
    recalls = [
        recall_score(y_true, labels_at_cutoff, pos_label=1, zero_division=0)
        for labels_at_cutoff in thresholded
    ]
    accuracies = [accuracy_score(y_true, labels_at_cutoff) for labels_at_cutoff in thresholded]

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(thresholds, recalls, label=f"Recall({label_names[1]})")
    ax.plot(thresholds, accuracies, label="Accuracy")
    ax.set_xlabel("Threshold")
    ax.set_ylabel("Score")
    ax.set_title("Threshold Sweep")
    ax.legend()

    out_path = os.path.join(outdir, f"{tag}_threshold_sweep.png")
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    return out_path

---
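Persistence helpers: save a reusable checkpoint, a JSON file of the metrics, an append-only CSV run log, the per-epoch training history, and the per-sample validation predictions.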

In [17]:
def save_checkpoint(model, path="../artifacts/checkpoints/binary_resnet18.pt"):
    """Save ``model.state_dict()`` to ``path`` under the key ``"state_dict"``.

    Args:
        model: Module whose parameters are saved.
        path: Destination file; missing parent directories are created.

    Returns:
        str: ``path``.
    """
    parent = os.path.dirname(path)
    if parent:
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        os.makedirs(parent, exist_ok=True)
    torch.save({"state_dict": model.state_dict()}, path)
    return path

def save_metrics_json(metrics, path="../artifacts/binary_metrics.json"):
    """Write ``metrics`` to ``path`` as indented JSON, overwriting any existing file.

    Args:
        metrics: JSON-serialisable mapping.
        path: Destination file; missing parent directories are created.

    Returns:
        str: ``path``.
    """
    parent = os.path.dirname(path)
    if parent:
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        os.makedirs(parent, exist_ok=True)
    with open(path, "w") as f:
        json.dump(metrics, f, indent=2)
    return path

def append_metrics_log(metrics, log_path="../artifacts/metrics_log.csv", run_name="binary_resnet18"):
    """Append one timestamped row describing a run to a CSV log.

    The header row is written first whenever the log does not exist yet or is
    empty. Keys missing from ``metrics`` are written as empty cells.

    Args:
        metrics: Mapping that may contain ``accuracy``, ``recall``, ``auroc``,
            ``epochs``, ``batch_size`` and ``lr``.
        log_path: CSV file to append to; missing parent directories are created.
        run_name: Identifier stored in the ``run_name`` column.

    Returns:
        str: ``log_path``.
    """
    columns = ["accuracy", "recall", "auroc", "epochs", "batch_size", "lr"]
    parent = os.path.dirname(log_path)
    if parent:
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        os.makedirs(parent, exist_ok=True)
    # An existing but empty file (e.g. created by `touch`) still needs a header.
    needs_header = not os.path.isfile(log_path) or os.path.getsize(log_path) == 0
    with open(log_path, "a", newline="") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(["timestamp", "run_name"] + columns)
        timestamp = datetime.datetime.now().isoformat(timespec="seconds")
        writer.writerow([timestamp, run_name] + [metrics.get(key, "") for key in columns])
    return log_path

def save_history_csv(history, path="../artifacts/history_binary_resnet18.csv"):
    """Write the per-epoch training history to a CSV file, overwriting it.

    Args:
        history: Mapping with equally indexed lists under ``"loss"``,
            ``"val_loss"`` and ``"val_acc"``.
        path: Destination file; missing parent directories are created.

    Returns:
        str: ``path``.
    """
    parent = os.path.dirname(path)
    if parent:
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "val_loss", "val_acc"])
        per_epoch = zip(history["loss"], history["val_loss"], history["val_acc"])
        # Epochs are numbered from 1 in the file.
        for epoch, (tr, vl, va) in enumerate(per_epoch, start=1):
            writer.writerow([epoch, tr, vl, va])
    return path

def save_predictions_csv(y_true, y_pred, y_prob, label_names, path="../artifacts/val_predictions_binary.csv"):
    """Write one CSV row per sample: true class name, predicted class name, score.

    Args:
        y_true: True integer labels (indices into ``label_names``).
        y_pred: Predicted integer labels (indices into ``label_names``).
        y_prob: Per-sample score written to the ``p_malignant`` column.
        label_names: Sequence mapping a label index to its class name.
        path: Destination file; missing parent directories are created.

    Returns:
        str: ``path``.
    """
    parent = os.path.dirname(path)
    if parent:
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["y_true", "y_pred", "p_malignant"])
        for t, p, s in zip(y_true, y_pred, y_prob):
            writer.writerow([label_names[t], label_names[p], s])
    return path

---
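End-to-end run: set up the data and model, train, evaluate on the validation split with TTA, and save the metrics, checkpoint, CSV files, and plots.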

In [18]:
def run_all(epochs=15, batch_size=64, lr=1e-3, run_name="binary_resnet18", seed=None):
    """Train the binary ResNet18, evaluate it with TTA and write every artifact.

    Steps: pick a device, build the loaders and model, train, collect
    validation outputs with horizontal-flip TTA, compute the core metrics,
    then save the checkpoint, metrics JSON, CSV log row, history CSV,
    predictions CSV and all plots.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.
        run_name: Name stored in the CSV log and used as the plot file prefix.
        seed: Optional integer. When given, Python's ``random``, NumPy and
            torch RNGs are seeded before any data or model is created, and
            the seed is recorded in the metrics. This is best-effort
            reproducibility; it does not by itself force deterministic
            backend kernels. ``None`` (default) keeps the previous unseeded
            behaviour.

    Returns:
        dict: ``"metrics"`` plus the paths under ``"checkpoint"``,
        ``"metrics_json"``, ``"metrics_log"``, ``"history_csv"``,
        ``"predictions_csv"`` and the list ``"plots"``.
    """
    if seed is not None:
        import random  # local import: only needed when seeding is requested

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = get_device()
    train_loader, val_loader, label_names = make_loaders(batch_size=batch_size)
    model = build_resnet18_binary(num_classes=2, use_pretrained=True).to(device)
    model, history = train_model(model, train_loader, val_loader, device, epochs=epochs, lr=lr)

    y_true, y_pred, y_prob = collect_outputs(model, val_loader, device, use_tta=True)
    metrics = compute_core_metrics(y_true, y_pred, y_prob, pos_label_idx=label_names.index("malignant"))
    # Keep the hyper-parameters next to the scores so a metrics file is self-describing.
    metrics["epochs"] = epochs
    metrics["batch_size"] = batch_size
    metrics["lr"] = lr
    if seed is not None:
        metrics["seed"] = seed

    ckpt = save_checkpoint(model, "../artifacts/checkpoints/binary_resnet18.pt")
    mjson = save_metrics_json(metrics, "../artifacts/binary_metrics.json")
    mlog = append_metrics_log(metrics, "../artifacts/metrics_log.csv", run_name=run_name)
    hcsv = save_history_csv(history, "../artifacts/history_binary_resnet18.csv")
    pcsv = save_predictions_csv(y_true, y_pred, y_prob, label_names, "../artifacts/val_predictions_binary.csv")

    plots_dir = "../artifacts/plots"
    p1, p2 = plot_loss_acc(history, plots_dir, run_name)
    prc = plot_pr_curve(y_true, y_prob, plots_dir, run_name)
    roc = plot_roc_curve(y_true, y_prob, plots_dir, run_name)
    cmx = plot_confusion(y_true, y_pred, label_names, plots_dir, run_name)
    cal = plot_calibration(y_true, y_prob, plots_dir, run_name)
    thr = plot_threshold_sweep(y_true, y_prob, label_names, plots_dir, run_name)
    return {"metrics": metrics, "checkpoint": ckpt, "metrics_json": mjson, "metrics_log": mlog, "history_csv": hcsv, "predictions_csv": pcsv, "plots": [p1,p2,prc,roc,cmx,cal,thr]}

In [19]:
# Full 15-epoch run. The returned dict (metrics plus the paths of every saved
# artifact) is left as the last expression so it is shown as the cell output.
results = run_all(epochs=15, batch_size=64, lr=1e-3, run_name="binary_resnet18_tta15e")
results

Casting the dataset:   0%|          | 0/12413 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/12413 [00:00<?, ? examples/s]

Epoch 1/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 2/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 3/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 4/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 5/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 6/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 7/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 8/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 9/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 10/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 11/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 12/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 13/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 14/15:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 15/15:   0%|          | 0/194 [00:00<?, ?it/s]

{'metrics': {'accuracy': 0.8661290322580645,
  'recall': 0.9540229885057471,
  'auroc': 0.9561776729927994,
  'epochs': 15,
  'batch_size': 64,
  'lr': 0.001},
 'checkpoint': '../artifacts/checkpoints/binary_resnet18.pt',
 'metrics_json': '../artifacts/binary_metrics.json',
 'metrics_log': '../artifacts/metrics_log.csv',
 'history_csv': '../artifacts/history_binary_resnet18.csv',
 'predictions_csv': '../artifacts/val_predictions_binary.csv',
 'plots': ['../artifacts/plots/binary_resnet18_tta15e_loss.png',
  '../artifacts/plots/binary_resnet18_tta15e_val_acc.png',
  '../artifacts/plots/binary_resnet18_tta15e_pr.png',
  '../artifacts/plots/binary_resnet18_tta15e_roc.png',
  '../artifacts/plots/binary_resnet18_tta15e_confusion.png',
  '../artifacts/plots/binary_resnet18_tta15e_calibration.png',
  '../artifacts/plots/binary_resnet18_tta15e_threshold_sweep.png']}