### Dataset 
The measles images used for model training are from two sources: 
1. IEEE DataPort (DOI: 10.21227/9r41-4x79) at https://ieee-dataport.org/documents/image-dataset-various-skin-conditions-and-rashes
2. Mpox Skin Lesion Dataset Version 2.0 (MSLD v2.0) on Kaggle: https://www.kaggle.com/datasets/joydippaul/mpox-skin-lesion-dataset-version-20-msld-v20/data

After combining the two datasets, our final collection comprised 2,070 skin images, including 212 depicting measles rashes. The dataset was imbalanced with respect to race and ethnicity: ~73.6% White, ~11.2% Black or African American, ~10.6% Hispanic or Latino, ~1% Asian, and the remainder representing other minority groups, including Native Hawaiian. The current implementation does not address this imbalance; it will be remedied in a future version.

In [2]:
import os
import glob
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import numpy as np
import re

In [3]:
# Reproducibility: seed NumPy, PyTorch (CPU) and every CUDA device with the
# same fixed value.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [4]:
# Config parameters
img_size = 224                        # side length (pixels) used by the crop/resize transforms
batch_size = 32
epochs = 30
lr = 3e-4                             # learning rate handed to the optimizer
model_name = 'deit_base_patch16_224'  # architecture name passed to timm.create_model
data_root = 'data'                    # holds one sub-directory per class
output_dir = 'models_cv'              # per-fold checkpoints are written here
os.makedirs(output_dir, exist_ok=True)

# Pick the compute device, probing in the order MPS -> CUDA -> CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [5]:
# Transforms
# Per-channel normalisation statistics shared by both pipelines (these are the
# values commonly quoted as the ImageNet mean / std).
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/rotation/colour jitter for augmentation,
# followed by tensor conversion and normalisation.
_train_steps = [
    transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
train_transforms = transforms.Compose(_train_steps)

# Evaluation pipeline: no randomness, just a fixed resize, tensor conversion
# and the same normalisation as in training.
_eval_steps = [
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
val_transforms = transforms.Compose(_eval_steps)

In [6]:
# Dataset
class ImageDataset(Dataset):
    """Dataset that yields ``(image, label)`` pairs read from image files.

    Args:
        image_paths: Indexable collection of image file paths.
        labels: Indexable collection of labels, aligned with ``image_paths``.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths, self.labels, self.transform = image_paths, labels, transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the image at position ``idx`` as RGB and return it with its label."""
        rgb = Image.open(self.image_paths[idx]).convert('RGB')
        target = self.labels[idx]
        if not self.transform:
            # No transform configured: hand back the PIL image unchanged.
            return rgb, target
        return self.transform(rgb), target

In [7]:
# ========== LOAD IMAGE PATHS AND LABELS ==========
# Every class has its own sub-directory under data_root; the directory name
# decides the integer label assigned to the files inside it.
class_map = {'non_measles': 0, 'measles': 1}
image_label_pairs = [
    (p, label)
    for cls, label in class_map.items()
    for p in glob.glob(os.path.join(data_root, cls, '*'))
]

# Fail early with a readable message; otherwise the unpacking below raises an
# opaque "not enough values to unpack" error when nothing was found.
if not image_label_pairs:
    raise FileNotFoundError(
        f"No files found under {data_root!r}; "
        f"expected sub-directories: {', '.join(class_map)}"
    )

# Sort the (path, label) pairs together so paths and labels stay aligned
image_label_pairs.sort(key=lambda x: x[0])
image_paths, labels = zip(*image_label_pairs)
labels = np.array(labels)

print(f"Total images: {len(image_paths)}, Measles: {labels.sum().item()}, Non-measles: {(labels == 0).sum().item()}")

Total images: 2070, Measles: 212, Non-measles: 1858


In [8]:
# Some individuals contribute several images, so the data must not be split at
# the image level: the same person could then appear in both the training and
# the validation set, leaking information and inflating the scores.
# Images of one individual therefore receive the same group id.
import re

# Full file stem of the form <PREFIX>_<digits>_<digits> ...
pattern = re.compile(r"^(?:MSL|MKP|CHP|CWP|HFMD|HEALTHY)_\d+_\d+$")
# ... whose leading <PREFIX>_<digits> part serves as the group id.
grp_pat = re.compile(r"^(?:MSL|MKP|CHP|CWP|HFMD|HEALTHY)_\d+")


def _group_id(path):
    """Return the group id for one image path.

    File stems matching ``pattern`` are reduced to their ``<PREFIX>_<digits>``
    prefix; any other stem is used as its own group id.
    """
    stem, _ext = os.path.splitext(os.path.basename(path))
    if pattern.match(stem) is None:
        return stem
    return grp_pat.match(stem).group(0)


groups = [_group_id(p) for p in image_paths]

In [None]:
# ========== CROSS VALIDATION SPLIT ==========
# Five-fold, group-aware, stratified cross-validation. In every fold the
# held-out part is split once more (again group-aware and stratified) into a
# validation part, used to pick the best epoch, and a test part, used only for
# the final per-fold report.


def _make_dataset(indices, transform):
    """Build an ImageDataset restricted to the given sample indices."""
    return ImageDataset([image_paths[i] for i in indices],
                        [labels[i].item() for i in indices],
                        transform=transform)


def _predict(model, loader, desc):
    """Run the model over a loader in eval mode; return (predictions, targets) as lists."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(loader, desc=desc):
            outputs = model(imgs.to(device))
            preds += outputs.argmax(1).cpu().tolist()
            targets += lbls.tolist()
    return preds, targets


skf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

for fold, (train_idx, temp_idx) in enumerate(skf.split(image_paths, labels, groups), start=1):
    print(f"\n===== Fold {fold} =====")

    # Further split temp_idx into val_idx and test_idx
    temp_image_paths = [image_paths[i] for i in temp_idx]
    temp_labels = labels[temp_idx]
    temp_groups = [groups[i] for i in temp_idx]

    skf_temp = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=42)
    val_idx_sub, test_idx_sub = next(skf_temp.split(temp_image_paths, temp_labels, temp_groups))
    # The sub-split indexes into temp_idx; map back to indices of the full dataset.
    val_idx = [temp_idx[i] for i in val_idx_sub]
    test_idx = [temp_idx[i] for i in test_idx_sub]

    # Augmentation is applied to the training data only.
    train_ds = _make_dataset(train_idx, train_transforms)
    val_ds = _make_dataset(val_idx, val_transforms)
    test_ds = _make_dataset(test_idx, val_transforms)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    # Load a fresh pretrained model for every fold; num_classes=2 requests a
    # two-class model from timm.
    model = timm.create_model(model_name, pretrained=True, num_classes=2)

    # Fine-tune all layers
    for param in model.parameters():
        param.requires_grad = True

    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Start below any reachable F1 so the first epoch always records a state;
    # this also prevents a stale state from the previous fold being reused.
    best_f1 = -1.0
    best_model_state = None
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")

        # ===== Train =====
        model.train()
        total_loss = 0
        train_preds, train_labels = [], []
        for imgs, lbls in tqdm(train_loader, desc="Training"):
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, lbls)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_preds += outputs.argmax(1).cpu().tolist()
            train_labels += lbls.cpu().tolist()

        train_acc = accuracy_score(train_labels, train_preds)
        train_loss = total_loss / max(len(train_loader), 1)  # mean loss per batch

        # ===== Eval =====
        val_preds, val_labels = _predict(model, val_loader, "Evaluating")

        val_acc = accuracy_score(val_labels, val_preds)
        # NOTE(review): 'weighted' averaging weights every class by its support,
        # so the majority class dominates this score; confirm that this is the
        # intended model-selection metric for the imbalanced measles task.
        val_prec, val_rec, val_f1, _ = precision_recall_fscore_support(
            val_labels, val_preds, average='weighted')

        print(f"\nTrain Acc: {train_acc:.4f}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Acc: {val_acc:.4f}, Precision: {val_prec:.4f}, Recall: {val_rec:.4f}, F1: {val_f1:.4f}")

        scheduler.step()

        # Save best model
        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), os.path.join(output_dir, f"fold{fold}_best.pth"))
            # state_dict() returns references to the live parameter tensors, so
            # keep an independent copy; a plain reference would silently follow
            # the weights through all later epochs.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"✅ Saved best model for fold {fold} (F1: {best_f1:.4f})")

    # Load best model for test evaluation
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    test_preds, test_labels = _predict(model, test_loader, "Testing")

    test_acc = accuracy_score(test_labels, test_preds)
    test_prec, test_rec, test_f1, _ = precision_recall_fscore_support(
        test_labels, test_preds, average='weighted')

    print(f"\nTest Acc: {test_acc:.4f}, Precision: {test_prec:.4f}, Recall: {test_rec:.4f}, F1: {test_f1:.4f}")