In [1]:
import time
import json
from io import BytesIO
from pathlib import Path

import albumentations as A
import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import (
    DistributedDataParallelKwargs,
    ProjectConfiguration,
    set_seed,
)
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.metrics import auc, roc_curve, roc_auc_score
from timm import create_model
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from tqdm import tqdm

from isic_helper import DotDict, get_folds

In [2]:
# Maps the raw diagnosis strings of each dataset year onto one shared set of
# class names, so the three datasets can be trained together. `get_data` turns
# the class names into integer indices (sorted unique values of this mapping).
# "unknown" is the catch-all class for diagnoses without a dedicated class.
label_mapping = {
    # 2024: applied to the `iddx_3` column in `get_data`.
    "2024": {
        "Hidradenoma": "unknown",
        "Lichen planus like keratosis": "BKL",
        "Pigmented benign keratosis": "BKL",
        "Seborrheic keratosis": "BKL",
        "Solar lentigo": "BKL",
        "Nevus": "NV",
        "Angiofibroma": "unknown",
        "Dermatofibroma": "DF",
        "Fibroepithelial polyp": "unknown",
        "Scar": "unknown",
        "Hemangioma": "unknown",
        "Trichilemmal or isthmic-catagen or pilar cyst": "unknown",
        "Lentigo NOS": "BKL",
        "Verruca": "unknown",
        "Solar or actinic keratosis": "AKIEC",
        "Atypical intraepithelial melanocytic proliferation": "unknown",
        "Atypical melanocytic neoplasm": "unknown",
        "Basal cell carcinoma": "BCC",
        "Squamous cell carcinoma in situ": "SCC",
        "Squamous cell carcinoma, Invasive": "SCC",
        "Squamous cell carcinoma, NOS": "SCC",
        "Melanoma in situ": "MEL",
        "Melanoma Invasive": "MEL",
        "Melanoma metastasis": "MEL",
        "Melanoma, NOS": "MEL",
    },
    # 2020: applied to the `diagnosis` column in `get_data`.
    "2020": {
        "nevus": "NV",
        "melanoma": "MEL",
        "seborrheic keratosis": "BKL",
        "lentigo NOS": "BKL",
        "lichenoid keratosis": "BKL",
        "other": "unknown",
        "solar lentigo": "BKL",
        "scar": "unknown",
        "cafe-au-lait macule": "unknown",
        "atypical melanocytic proliferation": "unknown",
        "pigmented benign keratosis": "BKL",
    },
    # 2019: applied to the `diagnosis` column in `get_data`.
    "2019": {
        "nevus": "NV",
        "melanoma": "MEL",
        "seborrheic keratosis": "BKL",
        "pigmented benign keratosis": "BKL",
        "dermatofibroma": "DF",
        "squamous cell carcinoma": "SCC",
        "basal cell carcinoma": "BCC",
        "vascular lesion": "VASC",
        "actinic keratosis": "AKIEC",
        "solar lentigo": "BKL",
    },
}


def dev_augment(image_size):
    """Build the training-time augmentation pipeline.

    Args:
        image_size: Side length passed to the final ``A.Resize`` (square output).

    Returns:
        An ``A.Compose`` that applies, in order: random flips/transposition,
        brightness/contrast jitter, one blur-or-noise transform, one geometric
        distortion, CLAHE, hue/saturation jitter, shift/scale/rotate, a resize
        and the conversion to a tensor.
    """
    flips = [
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
    ]
    brightness_contrast = A.RandomBrightnessContrast(
        brightness_limit=0.2, contrast_limit=0.2, p=0.75
    )
    # Exactly one of these is picked whenever the group fires.
    blur_or_noise = A.OneOf(
        [
            A.MotionBlur(blur_limit=(5, 7)),
            A.MedianBlur(blur_limit=(5, 7)),
            A.GaussianBlur(blur_limit=(5, 7)),
            A.GaussNoise(var_limit=(5.0, 30.0)),
        ],
        p=0.7,
    )
    distortion = A.OneOf(
        [
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.0),
            A.ElasticTransform(alpha=3),
        ],
        p=0.7,
    )
    colour = [
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(
            hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5
        ),
    ]
    shift_scale_rotate = A.ShiftScaleRotate(
        shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85
    )
    # Disabled experiment, kept for reference:
    # A.Cutout(max_h_size=int(image_size * 0.375), max_w_size=int(image_size * 0.375), num_holes=1, p=0.7)
    steps = [
        *flips,
        brightness_contrast,
        blur_or_noise,
        distortion,
        *colour,
        shift_scale_rotate,
        A.Resize(image_size, image_size),
        ToTensorV2(),
    ]
    return A.Compose(steps, p=1.0)


def val_augment(image_size):
    """Build the validation pipeline: square resize, then tensor conversion.

    Args:
        image_size: Side length passed to ``A.Resize``.

    Returns:
        An ``A.Compose`` with no random augmentation steps.
    """
    steps = [A.Resize(image_size, image_size), ToTensorV2()]
    return A.Compose(steps, p=1.0)


class ISICDataset(Dataset):
    """Dataset that decodes encoded images looked up by ``isic_id``.

    Args:
        metadata: Table with one row per sample; must have an ``isic_id``
            column and, unless ``infer`` is set, a ``label`` column.
        images: Mapping-like store indexed by ``isic_id`` whose entries return
            the encoded image bytes when indexed with ``[()]``.
        augment: Callable invoked as ``augment(image=array)`` that returns a
            dict holding the transformed image under ``"image"``.
        infer: When true, items are returned without a label.
    """

    def __init__(self, metadata, images, augment, infer=False):
        self.metadata = metadata
        self.images = images
        self.augment = augment
        self.infer = infer
        # Cached once so __len__ does not re-measure the table.
        self.length = len(metadata)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        record = self.metadata.iloc[index]

        encoded = self.images[record["isic_id"]][()]
        pixels = np.array(Image.open(BytesIO(encoded)))
        tensor = self.augment(image=pixels)["image"]

        # Scale the augmented tensor by 1/255 after casting to float.
        data = tensor.float().div(255)

        if self.infer:
            return data

        return data, torch.tensor(record["label"]).long()


def _assign_sample_weights(metadata, is_malignant, is_strong):
    """Write the ``sample_weight`` column of ``metadata`` in place.

    Malignant rows get 10, benign rows with a strong label get 0.6 and benign
    rows with a weak label get 0.4.
    """
    metadata.loc[is_malignant, "sample_weight"] = 10
    metadata.loc[~is_malignant & is_strong, "sample_weight"] = 0.6
    metadata.loc[~is_malignant & ~is_strong, "sample_weight"] = 0.4


def _load_external_year(
    data_dir, year, fill_missing, label2idx, malignant_idx, out_dim, debug, seed
):
    """Load one external (2019/2020) dataset and derive label, strength and weights."""
    metadata = pd.read_csv(f"{data_dir}/train-metadata.csv", low_memory=False)
    images = h5py.File(f"{data_dir}/train-image.hdf5", mode="r")

    diagnosis = metadata["diagnosis"]
    if fill_missing:
        diagnosis = diagnosis.fillna("unknown")
    metadata["label"] = diagnosis.replace(label_mapping[year]).map(label2idx)
    # A label counts as "strong" only when confirmed by histopathology.
    metadata["strength"] = np.where(
        metadata["diagnosis_confirm_type"] == "histopathology", "strong", "weak"
    )

    # `label` holds class *indices* at this point, so malignancy must be tested
    # against `malignant_idx`. Comparing against the class-name strings never
    # matches and silently turns every external sample into a benign one.
    is_malignant = metadata["label"].isin(malignant_idx)
    if out_dim == 2:
        metadata["label"] = np.where(is_malignant, 1, 0)
    _assign_sample_weights(metadata, is_malignant, metadata["strength"] == "strong")

    if debug:
        metadata = metadata.sample(frac=0.05, random_state=seed).reset_index(
            drop=True
        )
    return metadata, images


def get_data(data_dir, data_2020_dir, data_2019_dir, out_dim, debug, seed):
    """Load the 2024 training data plus the optional 2020 / 2019 datasets.

    Every returned metadata frame carries a ``label`` column (binary target for
    ``out_dim == 2``, class index for ``out_dim == 9``) and a ``sample_weight``
    column (10 for malignant rows, 0.6 / 0.4 for benign rows with a strong /
    weak label).

    Args:
        data_dir: Directory holding the 2024 ``train-metadata.csv`` and
            ``train-image.hdf5``.
        data_2020_dir: Same layout for the 2020 data, or ``None`` to skip it.
        data_2019_dir: Same layout for the 2019 data, or ``None`` to skip it.
        out_dim: Number of model outputs; must be 2 or 9.
        debug: When true, every loaded frame is subsampled to 5% of its rows.
        seed: Random state used for the debug subsampling.

    Returns:
        Tuple ``(train_metadata, train_images, train_metadata_2020,
        train_images_2020, train_metadata_2019, train_images_2019,
        malignant_idx)``. A skipped dataset is returned as an empty DataFrame
        with ``None`` for its images. The HDF5 files are returned open.

    Raises:
        ValueError: If ``out_dim`` is neither 2 nor 9.
    """
    if out_dim not in (2, 9):
        raise ValueError(f"Invalid out_dim: {out_dim}")

    # np.unique sorts, so the class indices are stable across runs.
    all_labels = np.unique(
        list(label_mapping["2024"].values())
        + list(label_mapping["2020"].values())
        + list(label_mapping["2019"].values())
    )
    label2idx = {label: idx for idx, label in enumerate(all_labels)}
    malignant_labels = ["BCC", "MEL", "SCC"]
    malignant_idx = [label2idx[label] for label in malignant_labels]

    train_metadata = pd.read_csv(f"{data_dir}/train-metadata.csv", low_memory=False)
    train_images = h5py.File(f"{data_dir}/train-image.hdf5", mode="r")

    folds_df = get_folds()
    train_metadata = train_metadata.merge(
        folds_df, on=["isic_id", "patient_id"], how="inner"
    )

    # For the 2024 data a label is "strong" when the row has a lesion_id.
    has_lesion_id = train_metadata["lesion_id"].notnull()
    if out_dim == 2:
        train_metadata["label"] = train_metadata["target"]
        is_malignant = train_metadata["label"] == 1
    else:
        # NOTE(review): diagnoses missing from label_mapping["2024"] map to NaN
        # here -- confirm the mapping covers every value present in `iddx_3`.
        train_metadata["label"] = (
            train_metadata["iddx_3"]
            .fillna("unknown")
            .replace(label_mapping["2024"])
            .map(label2idx)
        )
        train_metadata["strength"] = np.where(has_lesion_id, "strong", "weak")
        is_malignant = train_metadata["label"].isin(malignant_idx)
    _assign_sample_weights(train_metadata, is_malignant, has_lesion_id)

    if debug:
        train_metadata = train_metadata.sample(
            frac=0.05, random_state=seed
        ).reset_index(drop=True)

    if data_2020_dir is not None:
        train_metadata_2020, train_images_2020 = _load_external_year(
            data_2020_dir, "2020", True, label2idx, malignant_idx, out_dim, debug, seed
        )
    else:
        train_metadata_2020 = pd.DataFrame()
        train_images_2020 = None

    if data_2019_dir is not None:
        # NOTE(review): unlike 2020, missing 2019 diagnoses are not filled with
        # "unknown" (kept as in the original) -- confirm there are none.
        train_metadata_2019, train_images_2019 = _load_external_year(
            data_2019_dir, "2019", False, label2idx, malignant_idx, out_dim, debug, seed
        )
    else:
        train_metadata_2019 = pd.DataFrame()
        train_images_2019 = None

    return (
        train_metadata,
        train_images,
        train_metadata_2020,
        train_images_2020,
        train_metadata_2019,
        train_images_2019,
        malignant_idx,
    )


class ISICNet(nn.Module):
    """timm backbone followed by a linear classifier with multi-sample dropout.

    Args:
        model_name: Architecture name passed to ``create_model``.
        out_dim: Number of output logits.
        pretrained: Whether ``create_model`` should load pretrained weights.
        infer: Stored on the instance; not read by ``forward``.
    """

    def __init__(self, model_name, out_dim, pretrained=True, infer=False):
        super().__init__()
        self.infer = infer
        # num_classes=0 and global_pool="" make the backbone return the
        # unpooled feature map; pooling is done in forward().
        self.model = create_model(
            model_name=model_name,
            pretrained=pretrained,
            in_chans=3,
            num_classes=0,
            global_pool="",
        )
        self.classifier = nn.Linear(self.model.num_features, out_dim)

        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])

    def forward(self, data):
        """Return logits for a batch of images.

        In training mode the classifier is applied once per dropout module and
        the resulting logits are averaged; in eval mode it is applied once
        without dropout.
        """
        features = self.model(data)
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(data.shape[0], -1)

        if not self.training:
            return self.classifier(pooled)

        summed = sum(self.classifier(drop(pooled)) for drop in self.dropouts)
        return summed / len(self.dropouts)

    
def train_epoch(
    epoch,
    model,
    optimizer,
    criterion,
    dev_dataloader,
    lr_scheduler,
    accelerator,
    log_interval=10,
):
    """Run one training epoch and return the mean batch loss.

    Args:
        epoch: Epoch number, used only in the log lines.
        model: Model called as ``model(data)``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, target)``.
        dev_dataloader: Iterable of ``(data, target)`` batches with a length.
        lr_scheduler: Scheduler stepped once per batch.
        accelerator: Object providing ``backward(loss)`` and ``gather(tensor)``.
        log_interval: A progress line is printed every this many steps.

    Returns:
        Mean of the per-step loss values over the epoch.
    """
    model.train()
    train_loss = []
    total_steps = len(dev_dataloader)
    for step, (data, target) in enumerate(dev_dataloader):
        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, target)
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()

        # gather() returns one loss per process, so reduce before .item();
        # calling .item() on the gathered tensor fails with >1 process.
        loss_value = accelerator.gather(loss.detach()).mean().item()
        train_loss.append(loss_value)
        # Running mean over (at most) the last 100 steps.
        smooth_loss = sum(train_loss[-100:]) / min(len(train_loss), 100)
        if step % log_interval == 0:
            print(
                f"Epoch: {epoch} | Step: {step + 1}/{total_steps} |"
                f" Loss: {loss_value:.5f} | Smooth loss: {smooth_loss:.5f}"
            )
    train_loss = np.mean(train_loss)
    return train_loss


def get_trans(img, iteration):
    """Return the test-time-augmentation view of ``img`` for ``iteration``.

    ``iteration % 6`` picks the view: 0 identity, 1 flip along dim 2, 2 flip
    along dim 3, 3/4/5 rotation by one/two/three quarter turns in dims (2, 3).
    From iteration 6 onwards dims 2 and 3 are swapped first.

    Args:
        img: Tensor with at least four dimensions.
        iteration: Index of the view to produce.
    """
    if iteration >= 6:
        img = img.transpose(2, 3)

    views = {
        0: lambda t: t,
        1: lambda t: torch.flip(t, dims=[2]),
        2: lambda t: torch.flip(t, dims=[3]),
        3: lambda t: torch.rot90(t, 1, dims=[2, 3]),
        4: lambda t: torch.rot90(t, 2, dims=[2, 3]),
        5: lambda t: torch.rot90(t, 3, dims=[2, 3]),
    }
    view = views.get(iteration % 6)
    return view(img) if view is not None else None


def val_epoch(
    epoch,
    model,
    criterion,
    val_dataloader,
    accelerator,
    out_dim,
    n_tta,
    malignant_idx,
    log_interval=50,
):
    """Evaluate the model with test-time augmentation and score AUC / pAUC.

    For every batch the logits and softmax probabilities of ``n_tta``
    augmented views (see ``get_trans``) are averaged. The loss is computed on
    the averaged logits; the metrics use the averaged probabilities.

    Args:
        epoch: Epoch number, used only in the log lines.
        model: Model called as ``model(images)``.
        criterion: Loss called as ``criterion(logits, target)``.
        val_dataloader: Iterable of ``(data, target)`` batches with a length.
        accelerator: Object providing ``device`` and ``gather``.
        out_dim: Number of model outputs. With 9 the malignant probability is
            the sum over ``malignant_idx``; otherwise column 1 is used.
        n_tta: Number of augmented views per batch; must be at least 1.
        malignant_idx: Class indices counted as malignant (any number of them).
        log_interval: A progress line is printed every this many steps.

    Returns:
        Tuple ``(val_loss, val_auc, val_pauc, val_probs, val_targets,
        binary_probs, binary_targets)``.

    Raises:
        ValueError: If ``n_tta`` is smaller than 1.
    """
    if n_tta < 1:
        # Averaging over zero views would silently produce NaN predictions.
        raise ValueError(f"n_tta must be >= 1, got: {n_tta}")

    model.eval()
    val_probs = []
    val_targets = []
    val_loss = []
    total_steps = len(val_dataloader)
    with torch.no_grad():
        for step, (data, target) in enumerate(val_dataloader):
            logits = torch.zeros((data.shape[0], out_dim), device=accelerator.device)
            probs = torch.zeros_like(logits)
            for idx in range(n_tta):
                logits_iter = model(get_trans(data, idx))
                logits += logits_iter
                probs += logits_iter.softmax(1)
            logits /= n_tta
            probs /= n_tta

            loss = criterion(logits, target)
            val_loss.append(loss.detach().cpu().numpy())

            probs, targets = accelerator.gather((probs, target))
            val_probs.append(probs)
            val_targets.append(targets)

            if step % log_interval == 0:
                print(f"Epoch: {epoch} | Step: {step + 1}/{total_steps}")

    val_loss = np.mean(val_loss)
    val_probs = torch.cat(val_probs).cpu().numpy()
    val_targets = torch.cat(val_targets).cpu().numpy()
    if out_dim == 9:
        # Collapse the multi-class output into "malignant vs. rest".
        binary_probs = val_probs[:, malignant_idx].sum(1)
        binary_targets = np.isin(val_targets, malignant_idx)
    else:
        binary_probs = val_probs[:, 1]
        binary_targets = val_targets

    val_auc = compute_auc(binary_targets, binary_probs)
    val_pauc = compute_pauc(binary_targets, binary_probs, min_tpr=0.8)
    return (
        val_loss,
        val_auc,
        val_pauc,
        val_probs,
        val_targets,
        binary_probs,
        binary_targets,
    )

def compute_auc(y_true, y_pred) -> float:
    """
    Area under the ROC curve, delegated to sklearn's ``roc_auc_score``.

    Args:
        y_true: ground truth of 1s and 0s
        y_pred: prediction scores, higher meaning more likely positive

    Returns:
        The value returned by ``roc_auc_score(y_true, y_pred)``
    """
    score = roc_auc_score(y_true, y_pred)
    return score


def compute_pauc(y_true, y_pred, min_tpr: float = 0.80) -> float:
    """
    2024 ISIC Challenge metric: pAUC

    Partial area under the receiver operating characteristic curve (pAUC)
    above a given true positive rate (TPR), 0.80 by default.
    https://en.wikipedia.org/wiki/Partial_Area_Under_the_ROC_Curve.

    (c) 2024 Nicholas R Kurtansky, MSKCC

    Args:
        y_true: ground truth of 1s and 0s (array-like)
        y_pred: prediction scores, higher meaning more likely positive
            (array-like)
        min_tpr: lower TPR bound of the region that is integrated; must be in
            [0, 1)

    Returns:
        Float value range [0, 1 - min_tpr]

    Raises:
        ValueError: if ``min_tpr`` is outside [0, 1)
    """
    # Validate min_tpr itself: checking abs(1 - min_tpr) would let values in
    # (1, 2] through even though they are not valid rates.
    if not 0 <= min_tpr < 1:
        raise ValueError("Expected min_tpr in range [0, 1), got: %r" % min_tpr)

    # sklearn only offers a max_fpr cut-off, so swap the classes (0 <-> 1) ...
    v_gt = np.abs(np.asarray(y_true) - 1)
    # ... and negate the scores so the ranking is reversed accordingly.
    v_pred = -1.0 * np.asarray(y_pred)

    max_fpr = 1 - min_tpr

    fpr, tpr, _ = roc_curve(v_gt, v_pred, sample_weight=None)
    if max_fpr == 1:
        return float(auc(fpr, tpr))

    # Add a single point at max_fpr by linear interpolation
    stop = np.searchsorted(fpr, max_fpr, "right")
    x_interp = [fpr[stop - 1], fpr[stop]]
    y_interp = [tpr[stop - 1], tpr[stop]]
    tpr = np.append(tpr[:stop], np.interp(max_fpr, x_interp, y_interp))
    fpr = np.append(fpr[:stop], max_fpr)

    return float(auc(fpr, tpr))


In [3]:
# --- Paths -------------------------------------------------------------------
INPUT_DIR = Path("/kaggle/input")
ARTIFACTS_DIR = Path(".")

args = DotDict()

# --- Data sources (2020 / 2019 are the optional external datasets) -----------
args.data_dir = INPUT_DIR / "isic-2024-challenge"
args.data_2020_dir = INPUT_DIR / "isic-2020-challenge"
args.data_2019_dir = INPUT_DIR / "isic-2019-challenge"

# --- Model identity and output locations -------------------------------------
args.model_name = "efficientnet_b3"
args.version = "v1"
args.model_identifier = f"{args.model_name}_{args.version}"
args.model_dir = ARTIFACTS_DIR / args.model_identifier
args.model_dir.mkdir(parents=True, exist_ok=True)
args.logging_dir = "logs"
args.fold = 1  # fold held out for validation

# --- Training hyper-parameters -----------------------------------------------
args.mixed_precision = "fp16"
args.image_size = 128
args.train_batch_size = 128
args.val_batch_size = 256
args.num_workers = 2
args.learning_rate = 1e-3
args.num_epochs = 1
args.n_tta = 6  # number of test-time-augmentation views at validation
args.seed = 2022
args.out_dim = 9
args.debug = False

In [4]:
# Logs live in a sub-directory of the model directory.
logging_dir = args.model_dir / args.logging_dir
accelerator_project_config = ProjectConfiguration(
    project_dir=args.model_dir,
    logging_dir=str(logging_dir),
)
kwargs = DistributedDataParallelKwargs()
accelerator = Accelerator(
    kwargs_handlers=[kwargs],
    mixed_precision=args.mixed_precision,
    project_config=accelerator_project_config,
)
print(accelerator.state)

# Seeding is skipped only when the seed is explicitly None.
if args.seed is not None:
    set_seed(args.seed)

Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16



In [5]:
# Load the 2024 data plus the configured external datasets.
(
    train_metadata,
    train_images,
    train_metadata_2020,
    train_images_2020,
    train_metadata_2019,
    train_images_2019,
    malignant_idx,
) = get_data(
    data_dir=args.data_dir,
    data_2020_dir=args.data_2020_dir,
    data_2019_dir=args.data_2019_dir,
    out_dim=args.out_dim,
    debug=args.debug,
    seed=args.seed,
)

In [6]:
# Hold out `args.fold` for validation; every other fold is used for training.
is_val_fold = train_metadata["fold"] == args.fold
dev_index = train_metadata.index[~is_val_fold]
val_index = train_metadata.index[is_val_fold]

dev_metadata = train_metadata.loc[dev_index].reset_index(drop=True)
val_metadata = train_metadata.loc[val_index].reset_index(drop=True)

# Fall back to uniform weights when no sample weights were assigned.
if "sample_weight" not in dev_metadata.columns:
    dev_metadata["sample_weight"] = 1
sample_weight = dev_metadata["sample_weight"].tolist()

train_transform = dev_augment(args.image_size)
eval_transform = val_augment(args.image_size)
dev_dataset = ISICDataset(dev_metadata, train_images, augment=train_transform)
val_dataset = ISICDataset(val_metadata, train_images, augment=eval_transform)

# Append the external datasets to the training set. The weights are appended in
# the same order as the datasets are concatenated, so weight i belongs to item i.
if not train_metadata_2020.empty:
    print("Using 2020 data")
    if "sample_weight" not in train_metadata_2020.columns:
        train_metadata_2020["sample_weight"] = 1
    sample_weight.extend(train_metadata_2020["sample_weight"].tolist())
    train_dataset_2020 = ISICDataset(
        train_metadata_2020,
        train_images_2020,
        augment=dev_augment(args.image_size),
    )
    dev_dataset = torch.utils.data.ConcatDataset([dev_dataset, train_dataset_2020])

if not train_metadata_2019.empty:
    print("Using 2019 data")
    if "sample_weight" not in train_metadata_2019.columns:
        train_metadata_2019["sample_weight"] = 1
    sample_weight.extend(train_metadata_2019["sample_weight"].tolist())
    train_dataset_2019 = ISICDataset(
        train_metadata_2019,
        train_images_2019,
        augment=dev_augment(args.image_size),
    )
    dev_dataset = torch.utils.data.ConcatDataset([dev_dataset, train_dataset_2019])

# The sampler is always used; the message only flags non-uniform weights.
n_distinct_weights = np.unique(sample_weight).size
if n_distinct_weights > 1:
    print("Using Weighted sampler")
sampler = WeightedRandomSampler(
    weights=sample_weight,
    num_samples=len(sample_weight),
    replacement=True,
)
print(f"Building a model with {args.out_dim} classes")

# Options shared by the training and validation loaders.
loader_options = dict(num_workers=args.num_workers, pin_memory=True)
dev_dataloader = DataLoader(
    dev_dataset,
    batch_size=args.train_batch_size,
    sampler=sampler,
    **loader_options,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.val_batch_size,
    shuffle=False,
    drop_last=False,
    **loader_options,
)

model = ISICNet(
    model_name=args.model_name,
    out_dim=args.out_dim,
    pretrained=True,
    infer=False,
).to(accelerator.device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate / 20)
# One-cycle schedule over every optimizer step of the run, stepped per batch.
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=args.learning_rate,
    total_steps=args.num_epochs * len(dev_dataloader),
)

# Let accelerate wrap everything for the configured device / precision.
model, optimizer, dev_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dev_dataloader, val_dataloader, lr_scheduler
)

# Best-so-far validation results; the checkpoint is selected on pAUC.
best_val_auc = 0
best_val_pauc = 0
best_val_loss = 0
best_epoch = 0
best_val_probs = None

for epoch in range(1, args.num_epochs + 1):
    print(f"Fold {args.fold} | Epoch {epoch}")
    start_time = time.time()

    train_loss = train_epoch(
        epoch,
        model,
        optimizer,
        criterion,
        dev_dataloader,
        lr_scheduler,
        accelerator,
    )
    (
        val_loss,
        val_auc,
        val_pauc,
        val_probs,
        val_targets,
        binary_probs,
        binary_targets,
    ) = val_epoch(
        epoch,
        model,
        criterion,
        val_dataloader,
        accelerator,
        args.out_dim,
        args.n_tta,
        malignant_idx,
    )

    # Always keep the first epoch: otherwise a run whose pAUC never exceeds the
    # initial 0 (e.g. NaN) ends without checkpoint and OOF predictions.
    is_first_checkpoint = best_val_probs is None
    if is_first_checkpoint or val_pauc > best_val_pauc:
        print(
            f"pAUC: {best_val_pauc:.5f} --> {val_pauc:.5f}, saving model..."
        )
        best_val_pauc = val_pauc
        best_val_auc = val_auc
        best_val_loss = val_loss
        best_epoch = epoch
        best_val_probs = binary_probs
        output_dir = f"{args.model_dir}/models/fold_{args.fold}"
        accelerator.save_state(output_dir)
    else:
        print(
            f"pAUC: {best_val_pauc:.5f} --> {val_pauc:.5f}, skipping model save..."
        )
    print(
        f"Fold: {args.fold} | Epoch: {epoch} |"
        f" Train loss: {train_loss:.5f} | Val loss: {val_loss:.5f} |"
        f" Val AUC: {val_auc:.5f} | Val pAUC: {val_pauc:.5f}"
    )
    elapsed_time = time.time() - start_time
    elapsed_mins, elapsed_secs = divmod(int(elapsed_time), 60)
    print(f"Epoch {epoch} took {elapsed_mins}m {elapsed_secs}s")

print(
    f"Fold: {args.fold} | "
    f"Best Val pAUC: {best_val_pauc} | Best AUC: {best_val_auc} |"
    f" Best loss: {best_val_loss} |"
    f" Best epoch: {best_epoch}"
)

# Out-of-fold predictions of the best epoch, one row per validation sample.
oof_df = pd.DataFrame(
    {
        "isic_id": val_metadata["isic_id"],
        "patient_id": val_metadata["patient_id"],
        "fold": args.fold,
        "label": val_metadata["label"],
        "target": val_metadata["target"],
        f"oof_{args.model_identifier}": best_val_probs,
    }
)
oof_df.to_csv(
    f"{args.model_dir}/oof_preds_{args.model_identifier}_fold_{args.fold}.csv",
    index=False,
)

# Cast to builtin floats so json can serialise numpy scalars of any width.
fold_metadata = {
    "fold": args.fold,
    "best_epoch": best_epoch,
    "best_val_auc": float(best_val_auc),
    "best_val_pauc": float(best_val_pauc),
    "best_val_loss": float(best_val_loss),
}
# The fold directory only exists if a checkpoint was saved; create it so the
# metadata write cannot fail with FileNotFoundError.
fold_dir = Path(args.model_dir) / "models" / f"fold_{args.fold}"
fold_dir.mkdir(parents=True, exist_ok=True)
with open(fold_dir / "metadata.json", "w") as f:
    json.dump(fold_metadata, f)
print(f"Finished training fold {args.fold}")


Using 2020 data
Using 2019 data
Using Weighted sampler
Building a model with 9 classes


model.safetensors:   0%|          | 0.00/49.3M [00:00<?, ?B/s]

Fold 1 | Epoch 1
Epoch: 1 | Step: 1/2964 | Loss: 2.19641 | Smooth loss: 2.19641
Epoch: 1 | Step: 11/2964 | Loss: 2.04298 | Smooth loss: 2.11749
Epoch: 1 | Step: 21/2964 | Loss: 1.89369 | Smooth loss: 2.04205
Epoch: 1 | Step: 31/2964 | Loss: 1.74879 | Smooth loss: 1.96799
Epoch: 1 | Step: 41/2964 | Loss: 1.59513 | Smooth loss: 1.90007
Epoch: 1 | Step: 51/2964 | Loss: 1.44191 | Smooth loss: 1.82415
Epoch: 1 | Step: 61/2964 | Loss: 1.26834 | Smooth loss: 1.74888
Epoch: 1 | Step: 71/2964 | Loss: 1.21832 | Smooth loss: 1.67203
Epoch: 1 | Step: 81/2964 | Loss: 1.08460 | Smooth loss: 1.60216
Epoch: 1 | Step: 91/2964 | Loss: 0.85261 | Smooth loss: 1.53265
Epoch: 1 | Step: 101/2964 | Loss: 0.86563 | Smooth loss: 1.45846
Epoch: 1 | Step: 111/2964 | Loss: 0.75419 | Smooth loss: 1.32021
Epoch: 1 | Step: 121/2964 | Loss: 0.75705 | Smooth loss: 1.19914
Epoch: 1 | Step: 131/2964 | Loss: 0.66752 | Smooth loss: 1.08823
Epoch: 1 | Step: 141/2964 | Loss: 0.49256 | Smooth loss: 0.98411
Epoch: 1 | Step: 15