In [1]:
import json
import logging
from io import BytesIO
from pathlib import Path

import albumentations as A
import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    ProjectConfiguration,
    set_seed,
)
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.metrics import auc, roc_curve
from sklearn.metrics import roc_auc_score as compute_auc
from timm import create_model
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from tqdm import tqdm

from isic_helper import DotDict, get_folds

In [2]:
def dev_augment(image_size):
    """Build the training-time albumentations pipeline.

    Each step fires with its own probability: flips/transpose, brightness and
    contrast, one of blur/noise, one of optical/grid/elastic distortion, CLAHE,
    hue/saturation/value shift and shift-scale-rotate. The pipeline then resizes
    to ``image_size`` x ``image_size`` and converts the result with ToTensorV2.

    Args:
        image_size: output side length in pixels.

    Returns:
        An ``A.Compose`` that always runs (p=1.0).
    """
    # Exactly one of these is applied, 70% of the time.
    blur_or_noise = A.OneOf(
        [
            A.MotionBlur(blur_limit=(5, 7)),
            A.MedianBlur(blur_limit=(5, 7)),
            A.GaussianBlur(blur_limit=(5, 7)),
            A.GaussNoise(var_limit=(5.0, 30.0)),
        ],
        p=0.7,
    )
    # Exactly one geometric distortion is applied, 70% of the time.
    distortion = A.OneOf(
        [
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.0),
            A.ElasticTransform(alpha=3),
        ],
        p=0.7,
    )
    steps = [
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.75),
        blur_or_noise,
        distortion,
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(
            hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5
        ),
        A.ShiftScaleRotate(
            shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85
        ),
        A.Resize(image_size, image_size),
        # A Cutout step (one hole of ~37.5% of the side, p=0.7) is disabled here.
        ToTensorV2(),
    ]
    return A.Compose(steps, p=1.0)


def val_augment(image_size):
    """Build the deterministic evaluation pipeline.

    The pipeline resizes to ``image_size`` x ``image_size``, then applies ToTensorV2.
    """
    steps = [A.Resize(image_size, image_size), ToTensorV2()]
    return A.Compose(steps, p=1.0)


class ISICDataset(Dataset):
    """Map-style dataset yielding ``{"image": ..., "target": ...}`` records.

    Args:
        metadata: DataFrame read positionally (``iloc``). It needs an ``isic_id``
            column and, unless ``infer`` is true, a ``target`` column.
        images: container indexed by ``isic_id`` whose entries are read with
            ``[()]`` to obtain encoded image bytes (presumably an open h5py File;
            verify against caller).
        augment: callable invoked as ``augment(image=array)`` that returns a dict
            with an ``"image"`` entry.
        infer: when true, records carry no ``"target"`` key.
    """

    def __init__(self, metadata, images, augment, infer=False):
        self.metadata = metadata
        self.images = images
        self.augment = augment
        self.length = len(metadata)
        self.infer = infer

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        row = self.metadata.iloc[index]

        # Decode the stored bytes into a numpy array, then augment it.
        encoded = self.images[row["isic_id"]][()]
        pixels = np.array(Image.open(BytesIO(encoded)))
        record = {"image": self.augment(image=pixels)["image"]}

        if self.infer:
            return record
        record["target"] = torch.tensor(row["target"]).float()
        return record


class ISICNet(nn.Module):
    """timm backbone with a single-logit linear head and multi-sample dropout.

    The backbone is created with ``num_classes=0`` and ``global_pool=""``. The
    forward pass then pools its output with ``adaptive_avg_pool2d``, which
    presumes a (N, C, H, W) feature map; confirm if the backbone changes.

    Args:
        model_name: timm model name passed to ``create_model``.
        pretrained: whether to load pretrained backbone weights.
        infer: stored on the instance; ``forward`` does not read it.
    """

    def __init__(self, model_name, pretrained=True, infer=False):
        super().__init__()
        self.infer = infer
        self.model = create_model(
            model_name=model_name,
            pretrained=pretrained,
            in_chans=3,
            num_classes=0,
            global_pool="",
        )
        self.classifier = nn.Linear(self.model.num_features, 1)
        # Five independent dropout layers (p=0.5) for multi-sample dropout.
        self.dropouts = nn.ModuleList(nn.Dropout(0.5) for _ in range(5))

    def forward(self, batch):
        """Return logits of shape (batch_size, 1) for ``batch["image"]``.

        Pixel values are divided by 255, so inputs are expected in [0, 255].
        """
        pixels = batch["image"].float() / 255
        features = self.model(pixels)
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(pixels.shape[0], -1)

        if not self.training:
            return self.classifier(pooled)
        # Training: average the head's output over the independent dropout masks.
        return sum(self.classifier(drop(pooled)) for drop in self.dropouts) / len(
            self.dropouts
        )


def compute_pauc(y_true, y_pred, min_tpr: float = 0.80) -> float:
    """
    2024 ISIC Challenge metric: pAUC

    Returns the partial area under the receiver operating characteristic
    curve (pAUC), restricted to the region where the true positive rate (TPR)
    is at least ``min_tpr`` (0.80 for the challenge).
    https://en.wikipedia.org/wiki/Partial_Area_Under_the_ROC_Curve.

    (c) 2024 Nicholas R Kurtansky, MSKCC

    Args:
        y_true: array-like ground truth of 1s and 0s.
        y_pred: array-like predicted scores (e.g. probabilities in [0, 1]);
            higher means more likely positive.
        min_tpr: lower TPR bound of the integrated region; must lie in [0, 1).
            ``min_tpr == 0`` yields the full ROC AUC.

    Returns:
        Float value in the range [0, 1 - min_tpr].

    Raises:
        ValueError: if ``min_tpr`` is not in [0, 1) (including NaN).
    """
    # Validate min_tpr itself. Deriving max_fpr via abs(1 - min_tpr) would let
    # values in (1, 2] through; e.g. min_tpr=2 would silently give the full AUC.
    if not 0 <= min_tpr < 1:
        raise ValueError("Expected min_tpr in range [0, 1), got: %r" % min_tpr)
    max_fpr = 1.0 - min_tpr

    # Swap class labels (0 <-> 1) and negate the scores. On the flipped ROC
    # curve, FPR <= max_fpr corresponds to TPR >= min_tpr on the original one,
    # which is the region sklearn's curve utilities let us truncate.
    v_gt = np.abs(np.asarray(y_true) - 1)
    v_pred = -1.0 * np.asarray(y_pred)

    fpr, tpr, _ = roc_curve(v_gt, v_pred, sample_weight=None)
    if max_fpr == 1:
        return auc(fpr, tpr)

    # Truncate the curve at max_fpr, adding one linearly interpolated end point.
    stop = np.searchsorted(fpr, max_fpr, "right")
    x_interp = [fpr[stop - 1], fpr[stop]]
    y_interp = [tpr[stop - 1], tpr[stop]]
    tpr = np.append(tpr[:stop], np.interp(max_fpr, x_interp, y_interp))
    fpr = np.append(fpr[:stop], max_fpr)

    # Note: sklearn's roc_auc_score(..., max_fpr=max_fpr) returns a standardized
    # partial AUC in [0.5, 1]. To reproduce this value it would have to be
    # rescaled to [0.5 * max_fpr**2, max_fpr].
    return auc(fpr, tpr)

In [3]:
logger = get_logger(__name__)

# Competition data is read from INPUT_DIR; run artifacts go under ARTIFACTS_DIR.
INPUT_DIR = Path("/kaggle/input")
ARTIFACTS_DIR = Path(".")

args = DotDict()

# --- data location and run identity ---------------------------------------
args.data_dir = INPUT_DIR.joinpath("isic-2024-challenge")
args.model_name = "resnet18"
args.version = "v3"
args.model_identifier = "_".join((args.model_name, args.version))
args.model_dir = ARTIFACTS_DIR.joinpath(args.model_identifier)
args.model_dir.mkdir(exist_ok=True, parents=True)
args.logging_dir = "logs"
args.fold = 1  # validation fold; all other folds are used for training

# --- training hyper-parameters ---------------------------------------------
args.mixed_precision = "fp16"
args.pos_weight = 10  # sampling weight given to rows with target == 1
args.image_size = 64
args.train_batch_size = 256
args.val_batch_size = 512
args.num_workers = 2
args.learning_rate = 1e-3  # peak learning rate of the OneCycle schedule
args.num_epochs = 2
args.tta = True  # average validation predictions over flips / rotations
args.seed = 2022

args.debug = False  # True -> work on a 5% subsample

In [4]:
# Accelerate handles device placement, mixed precision and checkpoint paths.
logging_dir = Path(args.model_dir).joinpath(args.logging_dir)
accelerator_project_config = ProjectConfiguration(
    project_dir=args.model_dir,
    logging_dir=str(logging_dir),
)
kwargs = DistributedDataParallelKwargs()  # default DDP settings
accelerator = Accelerator(
    kwargs_handlers=[kwargs],
    mixed_precision=args.mixed_precision,
    project_config=accelerator_project_config,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger.info(accelerator.state, main_process_only=False)

if args.seed is not None:
    set_seed(args.seed)

In [5]:
train_metadata = pd.read_csv(f"{args.data_dir}/train-metadata.csv", low_memory=False)
# Encoded images keyed by isic_id; the file stays open so datasets can read lazily.
train_images = h5py.File(f"{args.data_dir}/train-image.hdf5", mode="r")

# Attach the fold assignment; rows missing from the fold table are dropped (inner join).
folds_df = get_folds()
train_metadata = pd.merge(
    train_metadata, folds_df, how="inner", on=["isic_id", "patient_id"]
)

if args.debug:
    # Reproducible 5% subsample for quick smoke runs.
    train_metadata = train_metadata.sample(frac=0.05, random_state=args.seed)
    train_metadata = train_metadata.reset_index(drop=True)

y_train = train_metadata["target"]
# Per-row sampling weight: rows with target == 1 get args.pos_weight, others 1.
train_metadata["sample_weight"] = np.where(y_train.eq(1), args.pos_weight, 1)

In [6]:
# Hold out args.fold for validation and train on every other fold.
dev_index = train_metadata.index[train_metadata["fold"] != args.fold]
val_index = train_metadata.index[train_metadata["fold"] == args.fold]

dev_metadata = train_metadata.loc[dev_index].reset_index(drop=True)
val_metadata = train_metadata.loc[val_index].reset_index(drop=True)

dev_dataset = ISICDataset(dev_metadata, train_images, augment=dev_augment(args.image_size))
val_dataset = ISICDataset(val_metadata, train_images, augment=val_augment(args.image_size))

# Oversample positives: rows are drawn with replacement in proportion to their
# sample_weight, len(dev_dataset) draws per epoch.
weighted_sampler = WeightedRandomSampler(
    dev_metadata["sample_weight"],
    len(dev_dataset),
    replacement=True,
)
dev_dataloader = DataLoader(
    dev_dataset,
    sampler=weighted_sampler,
    batch_size=args.train_batch_size,
    pin_memory=True,
    num_workers=args.num_workers,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.val_batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
    num_workers=args.num_workers,
)

In [7]:
model = ISICNet(args.model_name, pretrained=True, infer=False).to(accelerator.device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=args.learning_rate / 5)
# The OneCycle schedule peaks at args.learning_rate and is stepped once per batch.
# steps_per_epoch is read from the loader *before* accelerator.prepare().
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=args.learning_rate,
    steps_per_epoch=len(dev_dataloader),
    epochs=args.num_epochs,
)

model, optimizer, dev_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dev_dataloader, val_dataloader, lr_scheduler
)

# Best-epoch bookkeeping; epochs are selected by validation pAUC.
best_pauc, best_auc, best_epoch = 0, 0, 0
best_val_preds = None

def _tta_views(image, tta):
    """Yield the image tensors to score for one batch.

    Always yields ``image`` itself. When ``tta`` is true, it also yields flips
    along dims 2 and 3 and ``rot90`` by k = 1, 2, 3 in the (2, 3) plane, for 6
    views in total. Assumes NCHW batches so dims 2/3 are spatial; TODO confirm if
    the transforms change.
    """
    yield image
    if not tta:
        return
    yield torch.flip(image, dims=[2])
    yield torch.flip(image, dims=[3])
    for k in (1, 2, 3):
        yield torch.rot90(image, k, dims=[2, 3])


def _predict_batch(accelerator, model, batch, tta):
    """Score one validation batch, averaging sigmoid probabilities over TTA views.

    Returns ``(probs, targets)`` as flat numpy arrays, both gathered across
    processes with ``accelerator.gather_for_metrics``. Overwrites
    ``batch["image"]`` with each view in turn. Meant to run under
    ``torch.no_grad()``.
    """
    image = batch["image"]
    targets = accelerator.gather_for_metrics(batch["target"])
    prob_sum = 0
    n_views = 0
    # Views are summed in a fixed order (identity, flips, rotations).
    for view in _tta_views(image, tta):
        batch["image"] = view
        probs = accelerator.gather_for_metrics(torch.sigmoid(model(batch)))
        prob_sum += probs.detach().cpu().numpy().reshape(-1)
        n_views += 1
    return prob_sum / n_views, targets.detach().cpu().numpy().reshape(-1)


for epoch in range(args.num_epochs):
    # --- train ---
    model.train()
    for batch in tqdm(dev_dataloader, total=len(dev_dataloader)):
        optimizer.zero_grad()
        output = model(batch)
        loss = F.binary_cross_entropy_with_logits(
            output, batch["target"].unsqueeze(1)
        )
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()  # OneCycleLR is stepped per batch, not per epoch

    # --- validate ---
    model.eval()
    val_preds = []
    val_targets = []
    with torch.no_grad():
        for batch in tqdm(val_dataloader, total=len(val_dataloader)):
            batch_preds, batch_targets = _predict_batch(
                accelerator, model, batch, args.tta
            )
            val_preds.append(batch_preds)
            val_targets.append(batch_targets)

    val_preds = np.concatenate(val_preds)
    val_targets = np.concatenate(val_targets)

    epoch_auc = compute_auc(val_targets, val_preds)
    epoch_pauc = compute_pauc(val_targets, val_preds, min_tpr=0.8)

    # Select by pAUC (the challenge metric); record that epoch's AUC alongside.
    if epoch_pauc > best_pauc:
        best_pauc = epoch_pauc
        best_auc = epoch_auc
        best_epoch = epoch
        best_val_preds = val_preds
    logger.info(
        f"Epoch {epoch} - Epoch pauc: {epoch_pauc} | Best auc: {best_auc} | Best pauc: {best_pauc} | Best "
        f"epoch: {best_epoch}"
    )

    # Every epoch's full training state is checkpointed, not only the best one.
    output_dir = f"{args.model_dir}/models/fold_{args.fold}/epoch_{epoch}"
    accelerator.save_state(output_dir)

logger.info(
    f"Fold: {args.fold} | Best pauc: {best_pauc} | Best auc: {best_auc} | Best epoch: {best_epoch}"
)

# Out-of-fold predictions from the best (highest-pAUC) epoch. Rows line up with
# val_metadata because the validation loader does not shuffle.
oof_df = pd.DataFrame(
    {
        "isic_id": val_metadata["isic_id"],
        "patient_id": val_metadata["patient_id"],
        "fold": args.fold,
        "target": val_metadata["target"],
        f"oof_{args.model_identifier}": best_val_preds,
    }
)

fold_metadata = {
    "fold": args.fold,
    "best_epoch": best_epoch,
    "best_auc": best_auc,
    "best_pauc": best_pauc,
}

# Every rank holds the same gathered predictions/metrics, so only the main
# process writes files; otherwise ranks would race on the same paths.
if accelerator.is_main_process:
    oof_df.to_csv(
        f"{args.model_dir}/oof_preds_{args.model_identifier}_fold_{args.fold}.csv",
        index=False,
    )
    # save_state() normally creates this directory, but not when no epoch ran.
    fold_dir = Path(args.model_dir) / "models" / f"fold_{args.fold}"
    fold_dir.mkdir(parents=True, exist_ok=True)
    with open(fold_dir / "metadata.json", "w") as f:
        json.dump(fold_metadata, f)
logger.info(f"Finished training fold {args.fold}")

model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

100%|██████████| 1254/1254 [30:10<00:00,  1.44s/it]
100%|██████████| 157/157 [00:43<00:00,  3.61it/s]
100%|██████████| 1254/1254 [28:43<00:00,  1.37s/it]
100%|██████████| 157/157 [00:43<00:00,  3.65it/s]


In [8]:
# Display this fold's summary: the best epoch (selected by validation pAUC) with its AUC and pAUC.
fold_metadata

{'fold': 1,
 'best_epoch': 1,
 'best_auc': 0.9334444957814936,
 'best_pauc': 0.15536507367496796}