1 – Imports

In [2]:
import os
import pandas as pd
import numpy as np

from PIL import Image

from collections import Counter

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve
)

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from torch.nn.functional import sigmoid

import matplotlib.pyplot as plt


 2 –  paths

In [3]:
# Adjust these paths according to your repo structure
DATA_DIR = "../data/HAM10000_images"
META_CSV = "../data/HAM10000_metadata.csv" 

assert os.path.isdir(DATA_DIR), "DATA_DIR not found"
assert os.path.isfile(META_CSV), "META_CSV not found"


In [4]:
df = pd.read_csv(META_CSV)
print("Original rows:", len(df))
print(df.head())


Original rows: 10015
     lesion_id      image_id   dx dx_type   age   sex localization
0  HAM_0000118  ISIC_0027419  bkl   histo  80.0  male        scalp
1  HAM_0000118  ISIC_0025030  bkl   histo  80.0  male        scalp
2  HAM_0002730  ISIC_0026769  bkl   histo  80.0  male        scalp
3  HAM_0002730  ISIC_0025661  bkl   histo  80.0  male        scalp
4  HAM_0001466  ISIC_0031633  bkl   histo  75.0  male          ear


In [5]:
#Remove data with no pic

def image_exists(image_id):
    jpg_name = image_id + ".jpg"
    return os.path.isfile(os.path.join(DATA_DIR, jpg_name))

df["image_exists"] = df["image_id"].apply(image_exists)
df = df[df["image_exists"] == True].copy()
df = df.drop(columns=["image_exists"])

print("After removing missing images:", len(df))

After removing missing images: 0


In [6]:
# remove duplicate image_id

before = len(df)
df = df.drop_duplicates(subset=["image_id"])
after = len(df)

print(f"Removed {before - after} duplicate rows. Final rows: {after}")


Removed 0 duplicate rows. Final rows: 0


4 – Map High / Low risk

In [7]:
risk_map = {
    "mel": 1,   # melanoma - high risk
    "bcc": 1,   # basal cell carcinoma - high risk
    "akiec": 1, # actinic keratoses - high risk
    "nv": 0,    # nevus - low risk
    "bkl": 0,   # benign keratosis - low risk
    "df": 0,    # dermatofibroma - low risk
    "vasc": 0   # vascular lesions - low risk
}

df["risk_label"] = df["dx"].map(risk_map)

print(df["risk_label"].value_counts())
print(df["risk_label"].value_counts(normalize=True))


Series([], Name: count, dtype: int64)
Series([], Name: proportion, dtype: float64)


 5 – Train / Val / Test split (stratified)

In [9]:
train_df, temp_df = train_test_split(
    df,
    test_size=0.30,
    stratify=df["risk_label"],
    random_state=42
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,
    stratify=temp_df["risk_label"],
    random_state=42
)

print("Train size:", len(train_df))
print("Val size:", len(val_df))
print("Test size:", len(test_df))

print("Train dist:\n", train_df["risk_label"].value_counts(normalize=True))
print("Val dist:\n",   val_df["risk_label"].value_counts(normalize=True))
print("Test dist:\n",  test_df["risk_label"].value_counts(normalize=True))


ValueError: With n_samples=0, test_size=0.3 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.

6 – Transforms & Dataset / DataLoader

In [None]:
IMG_SIZE = 160
BATCH_SIZE = 64

train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet
        std=[0.229, 0.224, 0.225]
    ),
])

eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])


In [None]:
class HamDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_id = row["image_id"]
        label = row["risk_label"]

        img_path = os.path.join(self.img_dir, image_id + ".jpg")
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        label = torch.tensor(label, dtype=torch.float32)
        return image, label


In [None]:
train_dataset = HamDataset(train_df, DATA_DIR, transform=train_transform)
val_dataset   = HamDataset(val_df, DATA_DIR, transform=eval_transform)
test_dataset  = HamDataset(test_df, DATA_DIR, transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


7 – Model (ResNet50) + Loss + Optimizer

In [None]:
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
device


In [None]:
base_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

in_features = base_model.fc.in_features
base_model.fc = nn.Linear(in_features, 1)

model = base_model.to(device)
print(model.fc)


In [None]:
# class weights for imbalance
train_counts = Counter(train_df["risk_label"])
num_neg = train_counts[0]
num_pos = train_counts[1]
pos_weight_value = num_neg / max(num_pos, 1)

criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor([pos_weight_value], device=device)
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
NUM_EPOCHS = 8

print("Class counts:", train_counts)
print("pos_weight:", pos_weight_value)


8 – function: train & eval

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        
        outputs = model(images).squeeze(1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        
        probs = sigmoid(outputs)
        preds = (probs >= 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


In [None]:
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            
            probs = sigmoid(outputs)
            preds = (probs >= 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            
            all_labels.extend(labels.cpu().numpy().tolist())
            all_probs.extend(probs.cpu().numpy().tolist())
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc, np.array(all_labels), np.array(all_probs)


9 - Train loop and best model

In [None]:
best_val_loss = float("inf")
best_state_dict = None

train_history = []
val_history = []

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)
    
    train_history.append((train_loss, train_acc))
    val_history.append((val_loss, val_acc))
    
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} "
          f"- Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} "
          f"- Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state_dict = model.state_dict().copy()

model.load_state_dict(best_state_dict)
torch.save(model.state_dict(), "../best_resnet50_triage.pth")


10 – Charts: Train/Val

In [None]:
train_losses = [x[0] for x in train_history]
val_losses   = [x[0] for x in val_history]
train_accs   = [x[1] for x in train_history]
val_accs     = [x[1] for x in val_history]

plt.figure()
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

plt.figure()
plt.plot(train_accs, label="Train Acc")
plt.plot(val_accs, label="Val Acc")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")
plt.show()


11 – Evaluate on Test Set (Confusion Matrix + ROC)

In [None]:
test_loss, test_acc, y_true, y_prob = evaluate(model, test_loader, criterion, device)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)


In [None]:
y_pred = (y_prob >= 0.5).astype(int)

print(classification_report(y_true, y_pred, target_names=["Low Risk", "High Risk"]))


In [None]:
cm = confusion_matrix(y_true, y_pred)
cm


In [None]:
fig, ax = plt.subplots()
im = ax.imshow(cm)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_xticks([0,1])
ax.set_yticks([0,1])
ax.set_xticklabels(["Low", "High"])
ax.set_yticklabels(["Low", "High"])
for i in range(2):
    for j in range(2):
        ax.text(j, i, cm[i, j], ha="center", va="center")
ax.set_title("Confusion Matrix")
plt.show()


In [None]:
roc_auc = roc_auc_score(y_true, y_prob)
fpr, tpr, _ = roc_curve(y_true, y_prob)

print("ROC-AUC:", roc_auc)

plt.figure()
plt.plot(fpr, tpr, label=f"ROC curve (AUC = {roc_auc:.3f})")
plt.plot([0,1], [0,1], linestyle="--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.show()
