In [1]:
import argparse
import json
import logging
import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import (
    DistributedDataParallelKwargs,
    ProjectConfiguration,
    set_seed,
)
from accelerate.logging import get_logger
from io import BytesIO

import albumentations as A
import h5py
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader, RandomSampler
from timm import create_model
import torch.nn.functional as F
from safetensors import safe_open

from dataclasses import dataclass
from isic_helper import get_folds
from isic_helper import compute_auc, compute_pauc

In [2]:
@dataclass
class Config:
    """Static hyper-parameters for the fine-tuning run.

    Only the fields declared here are dataclass fields. The notebook attaches
    further run-specific attributes (model name, fold, directories, ...) to
    the instance after construction.
    """

    # Precision mode handed to ``Accelerator(mixed_precision=...)``. This is a
    # mode string such as "fp16", not an on/off flag, hence ``str``.
    mixed_precision: str = "fp16"
    # Side length (pixels) every image is resized to by the augment pipelines.
    image_size: int = 64
    # Batch sizes for the training and validation dataloaders.
    train_batch_size: int = 64
    val_batch_size: int = 512
    # Worker processes used by both dataloaders.
    num_workers: int = 4
    # Base learning rate for AdamW; the OneCycleLR peak is derived from it.
    init_lr: float = 3e-5
    # Number of epochs the LR schedule is sized for.
    num_epochs: int = 20
    # Number of test-time-augmentation views averaged during validation.
    n_tta: int = 8
    # Seed passed to accelerate's ``set_seed`` and to the debug sub-sampling.
    seed: int = 2022

    # NOTE(review): ``ext`` and ``only_malignant`` are not read anywhere in the
    # visible notebook cells - presumably left over from the pre-training
    # script; confirm before relying on them.
    ext: str = "2020,2019"
    only_malignant: bool = False
    # When True the notebook trains on a tiny sample for 2 epochs.
    debug: bool = False

# Instantiate the defaults, then attach run-specific settings as plain
# instance attributes (they are not declared dataclass fields on Config).
args = Config()
args.model_name = "efficientnet_b0"
args.version = "v3"
# "<model>_<version>" tag used to derive the output directory below.
args.model_identifier = f"{args.model_name}_{args.version}"
# Kaggle input dataset holding the earlier training run's checkpoints;
# underscores in the model name become hyphens in the dataset slug.
args.pretrained_weights_path = f"/kaggle/input/isic-scd-{args.model_name.replace('_', '-')}-{args.version}-train"
# Fold held out for validation; all other folds are used for training.
args.fold = 1
# Output directory for this fine-tuning run, and its logging sub-directory.
args.model_dir = f"{args.model_identifier}_finetune"
args.logging_dir = "logs"

In [3]:
def dev_augment(image_size, mean=None, std=None):
    """Assemble the albumentations pipeline applied to training images.

    Args:
        image_size: Side length of the square output; also sizes the
            coarse-dropout hole (37.5% of the side).
        mean: Optional normalisation mean. Used only when ``std`` is given too.
        std: Optional normalisation std. Used only when ``mean`` is given too.

    Returns:
        An ``A.Compose`` that ends with normalisation and ``ToTensorV2``.
    """
    # Custom statistics apply only when both are supplied; otherwise
    # A.Normalize keeps its own default mean/std.
    norm_kwargs = {"max_pixel_value": 255.0, "p": 1.0}
    if mean is not None and std is not None:
        norm_kwargs.update(mean=mean, std=std)
    normalize = A.Normalize(**norm_kwargs)

    flips = [
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
    ]
    brightness_contrast = A.RandomBrightnessContrast(
        brightness_limit=0.2, contrast_limit=0.2, p=0.75
    )
    # Exactly one blur/noise op is picked whenever this group fires.
    blur_or_noise = A.OneOf(
        [
            A.MotionBlur(blur_limit=(5, 7)),
            A.MedianBlur(blur_limit=(5, 7)),
            A.GaussianBlur(blur_limit=(5, 7)),
            A.GaussNoise(var_limit=(5.0, 30.0)),
        ],
        p=0.7,
    )
    # Likewise, one geometric distortion out of three.
    distortion = A.OneOf(
        [
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.0),
            A.ElasticTransform(alpha=3),
        ],
        p=0.7,
    )
    colour = [
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(
            hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5
        ),
    ]
    shift_scale_rotate = A.ShiftScaleRotate(
        shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85
    )
    resize = A.Resize(image_size, image_size)
    # A single square hole is cut after resizing, so its size is in output pixels.
    hole_side = int(image_size * 0.375)
    cutout = A.CoarseDropout(
        max_height=hole_side,
        max_width=hole_side,
        max_holes=1,
        min_holes=1,
        p=0.7,
    )

    steps = [
        *flips,
        brightness_contrast,
        blur_or_noise,
        distortion,
        *colour,
        shift_scale_rotate,
        resize,
        cutout,
        normalize,
        ToTensorV2(),
    ]
    return A.Compose(steps, p=1.0)


def val_augment(image_size, mean=None, std=None):
    """Assemble the deterministic pipeline used for validation images.

    Args:
        image_size: Side length of the square output.
        mean: Optional normalisation mean. Used only when ``std`` is given too.
        std: Optional normalisation std. Used only when ``mean`` is given too.

    Returns:
        An ``A.Compose`` that resizes, normalises and converts to a tensor.
    """
    norm_kwargs = {"max_pixel_value": 255.0, "p": 1.0}
    if mean is not None and std is not None:
        # Both statistics present: override A.Normalize's defaults.
        norm_kwargs.update(mean=mean, std=std)

    steps = [
        A.Resize(image_size, image_size),
        A.Normalize(**norm_kwargs),
        ToTensorV2(),
    ]
    return A.Compose(steps, p=1.0)


class ISICDataset(Dataset):
    """Map-style dataset that decodes one encoded image per metadata row.

    Args:
        metadata: DataFrame with an ``isic_id`` column and, unless ``infer``
            is set, a ``target`` column.
        images: Mapping indexed by ``isic_id`` whose entries yield the encoded
            image bytes via ``[()]`` (an open h5py file in this notebook).
        augment: Callable invoked as ``augment(image=...)`` returning a dict
            with an ``"image"`` entry, or ``None`` to skip augmentation.
        infer: When True, items are images only (no target).
    """

    def __init__(self, metadata, images, augment, infer=False):
        self.metadata, self.images = metadata, images
        self.augment = augment
        self.infer = infer
        # Cached once; the frame is not expected to change afterwards.
        self.length = len(metadata)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        record = self.metadata.iloc[index]
        encoded = self.images[record["isic_id"]][()]
        image = np.array(Image.open(BytesIO(encoded)))

        if self.augment is not None:
            image = self.augment(image=image)["image"].float()

        if self.infer:
            return image
        return image, torch.tensor(record["target"])

class ISICNet(nn.Module):
    """timm backbone followed by a single-logit linear head.

    The backbone is created without its own classifier or pooling; ``forward``
    average-pools its output itself. In training mode the head is evaluated
    under five independent dropout masks and the logits are averaged; in eval
    mode the head is applied once with no dropout.

    Args:
        model_name: timm architecture name passed to ``create_model``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.model = create_model(
            model_name=model_name,
            pretrained=False,
            in_chans=3,
            num_classes=0,
            global_pool="",
        )
        self.classifier = nn.Linear(self.model.num_features, 1)
        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])

    def forward(self, images):
        features = self.model(images)
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(len(images), -1)

        if not self.training:
            return self.classifier(pooled)

        # Multi-sample dropout: average the head's output over every mask.
        total = 0
        for dropout in self.dropouts:
            total += self.classifier(dropout(pooled))
        return total / len(self.dropouts)

def train_epoch(
    epoch,
    model,
    optimizer,
    criterion,
    dev_dataloader,
    lr_scheduler,
    accelerator,
    log_interval=100,
):
    """Run one optimisation pass over the training dataloader.

    Args:
        epoch: Epoch number, used only in the progress messages.
        model: Network returning logits of shape ``(batch, 1)``.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss applied to *probabilities* (sigmoid of the logits)
            and float targets of shape ``(batch, 1)``.
        dev_dataloader: Iterable of ``(images, targets)`` batches with ``len``.
        lr_scheduler: Scheduler stepped once per batch.
        accelerator: Object providing ``backward(loss)`` and ``gather(tensor)``.
        log_interval: Print progress on the first step and every this many steps.

    Returns:
        Mean of the per-step loss values over the epoch.
    """
    model.train()
    train_loss = []
    total_steps = len(dev_dataloader)
    for step, (images, targets) in enumerate(dev_dataloader):
        optimizer.zero_grad()
        logits = model(images)
        probs = torch.sigmoid(logits)
        targets = targets.float().unsqueeze(1)
        loss = criterion(probs, targets)
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()

        # gather() returns one loss per process, so with more than one process
        # the result has several elements and a bare .item() would raise.
        # Averaging first gives the cross-process mean and leaves the
        # single-process value unchanged.
        loss_value = accelerator.gather(loss.detach()).mean().item()
        train_loss.append(loss_value)
        # Running mean over (at most) the last 100 steps.
        smooth_loss = sum(train_loss[-100:]) / min(len(train_loss), 100)
        if (step == 0) or ((step + 1) % log_interval == 0):
            print(
                f"Epoch: {epoch} | Step: {step + 1}/{total_steps} |"
                f" Loss: {loss_value:.5f} | Smooth loss: {smooth_loss:.5f}"
            )
    train_loss = np.mean(train_loss)
    return train_loss


def get_trans(img, iteration):
    """Return the test-time-augmentation view of ``img`` for ``iteration``.

    ``img`` is indexed on dims 2 and 3 (the spatial axes of a batched image
    tensor). Views 0-5 are: identity, flip along dim 2, flip along dim 3, and
    rotations by 90, 180 and 270 degrees. From iteration 6 onwards the same
    cycle is applied to the spatially transposed image.
    """
    base = img.transpose(2, 3) if iteration >= 6 else img
    views = {
        0: lambda t: t,
        1: lambda t: torch.flip(t, dims=[2]),
        2: lambda t: torch.flip(t, dims=[3]),
        3: lambda t: torch.rot90(t, 1, dims=[2, 3]),
        4: lambda t: torch.rot90(t, 2, dims=[2, 3]),
        5: lambda t: torch.rot90(t, 3, dims=[2, 3]),
    }
    view = views.get(iteration % 6)
    # A remainder outside 0..5 (non-integral iteration) yields None.
    return view(base) if view is not None else None


def val_epoch(
    epoch,
    model,
    criterion,
    val_dataloader,
    accelerator,
    n_tta,
    log_interval=10,
):
    """Evaluate the model on the validation dataloader with TTA.

    For every batch the sigmoid outputs of ``n_tta`` augmented views (see
    ``get_trans``) are averaged; loss and metrics are computed on that mean.

    Args:
        epoch: Epoch number, used only in the progress messages.
        model: Network returning logits of shape ``(batch, 1)``.
        criterion: Loss applied to probabilities and float targets.
        val_dataloader: Iterable of ``(images, targets)`` batches with ``len``.
        accelerator: Accelerate ``Accelerator`` used to collect predictions.
        n_tta: Number of TTA views to average (must be >= 1).
        log_interval: Print progress on the first step and every this many steps.

    Returns:
        Tuple ``(val_loss, val_auc, val_pauc, val_probs, val_targets)`` where
        the loss is the mean of the per-batch losses on this process and the
        probabilities / targets are numpy arrays over all gathered samples.
    """
    model.eval()
    val_probs = []
    val_targets = []
    val_loss = []
    total_steps = len(val_dataloader)
    with torch.no_grad():
        for step, (images, targets) in enumerate(val_dataloader):
            # Average probabilities (not logits) across the TTA views.
            probs = 0
            for i in range(n_tta):
                probs += torch.sigmoid(model(get_trans(images, i)))
            probs /= n_tta

            targets = targets.float().unsqueeze(1)
            loss = criterion(probs, targets)
            val_loss.append(loss.detach().cpu().numpy())

            # gather_for_metrics drops the samples accelerate duplicates to pad
            # the last batch in distributed runs; plain gather() would keep
            # them and bias the metrics. Single-process behaviour is unchanged.
            probs, targets = accelerator.gather_for_metrics((probs, targets))
            val_probs.append(probs)
            val_targets.append(targets)

            if (step == 0) or ((step + 1) % log_interval == 0):
                print(f"Epoch: {epoch} | Step: {step + 1}/{total_steps}")

    val_loss = np.mean(val_loss)
    val_probs = torch.cat(val_probs).cpu().numpy()
    val_targets = torch.cat(val_targets).cpu().numpy()
    val_auc = compute_auc(val_targets, val_probs)
    val_pauc = compute_pauc(val_targets, val_probs, min_tpr=0.8)
    return (
        val_loss,
        val_auc,
        val_pauc,
        val_probs,
        val_targets,
    )


In [4]:
# Accelerate-aware logger. NOTE(review): `logger` is not used in any of the
# visible cells; progress is reported with print() instead.
logger = get_logger(__name__)
# Logs live under "<model_dir>/<logging_dir>".
logging_dir = Path(args.model_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
    project_dir=args.model_dir, logging_dir=str(logging_dir)
)
# Default DDP kwargs; only relevant when more than one process is launched.
kwargs = DistributedDataParallelKwargs()
accelerator = Accelerator(
    mixed_precision=args.mixed_precision,
    project_config=accelerator_project_config,
    kwargs_handlers=[kwargs],
)

# Root logging configuration (timestamped, INFO level).
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# Show the resolved device / process / precision setup.
print(accelerator.state)

# Seed the RNGs handled by accelerate's set_seed for reproducibility.
if args.seed is not None:
    set_seed(args.seed)

Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16



In [5]:
# Column-name constants. NOTE(review): none of these is referenced in the
# visible cells - the merge below and ISICDataset spell the names literally.
id_column = "isic_id"
target_column = "target"
group_column = "patient_id"

INPUT_PATH = Path("/kaggle/input/isic-2024-challenge/")

# "NA" strings are parsed as missing values in both metadata tables.
train_metadata = pd.read_csv(INPUT_PATH / "train-metadata.csv", low_memory=False, na_values=["NA"])
test_metadata = pd.read_csv(INPUT_PATH / "test-metadata.csv", low_memory=False, na_values=["NA"])

# Attach the fold assignment from the project helper; the inner join keeps
# only rows present in both frames. Later cells filter on the "fold" column.
folds_df = get_folds()
train_metadata = train_metadata.merge(folds_df, on=["isic_id", "patient_id"], how="inner")
print(f"Train data size: {train_metadata.shape}")
print(f"Test data size: {test_metadata.shape}")

# Opened read-only and intentionally left open: the handle is passed to the
# datasets built in a later cell, which read images from it on demand.
train_images = h5py.File(INPUT_PATH / "train-image.hdf5", mode="r")

Train data size: (401059, 57)
Test data size: (3, 44)


In [6]:
# Select row indices for the training ("dev") and validation splits.
# The validation split is the configured fold; everything else is training.
if args.debug:
    # Debug mode: shorten the run and sub-sample both splits
    # (3 training batches and 10 validation batches worth of rows).
    # Note this mutates args.num_epochs, which later sizes the LR schedule.
    args.num_epochs = 2
    dev_index = (
        train_metadata[train_metadata["fold"] != args.fold]
        .sample(args.train_batch_size * 3, random_state=args.seed)
        .index
    )
    val_index = (
        train_metadata[train_metadata["fold"] == args.fold]
        .sample(args.val_batch_size * 10, random_state=args.seed)
        .index
    )
else:
    dev_index = train_metadata[train_metadata["fold"] != args.fold].index
    val_index = train_metadata[train_metadata["fold"] == args.fold].index

# Re-index from 0 because ISICDataset addresses rows positionally via .iloc.
dev_metadata = train_metadata.loc[dev_index, :].reset_index(drop=True)
val_metadata = train_metadata.loc[val_index, :].reset_index(drop=True)

# No custom statistics: both pipelines fall back to A.Normalize's defaults.
mean = None
std = None

# Both datasets read from the same open HDF5 handle; they differ only in
# their metadata rows and augmentation pipeline.
dev_dataset = ISICDataset(
    dev_metadata,
    train_images,
    augment=dev_augment(args.image_size, mean=mean, std=std),
    infer=False,
)
val_dataset = ISICDataset(
    val_metadata,
    train_images,
    augment=val_augment(args.image_size, mean=mean, std=std),
    infer=False,
)

# Training batches are drawn in random order.
sampler = RandomSampler(dev_dataset)

dev_dataloader = DataLoader(
    dev_dataset,
    batch_size=args.train_batch_size,
    sampler=sampler,
    num_workers=args.num_workers,
    pin_memory=True,
)
# Validation keeps the metadata order and every sample (no shuffle, no drop).
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.val_batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    drop_last=False,
    pin_memory=True,
)

In [7]:
model = ISICNet(
    model_name=args.model_name
)
model = model.to(accelerator.device)
# Load the backbone weights from the earlier run's checkpoint for this fold,
# skipping every tensor whose key contains "classifier" so the head keeps its
# fresh initialisation.
tensors = {}
with safe_open(f"{args.pretrained_weights_path}/models/fold_{args.fold}/model.safetensors", framework="pt") as f:
    for key in f.keys():
        if "classifier" not in key:
            tensors[key] = f.get_tensor(key)
# strict=False tolerates the deliberately missing classifier keys; the printed
# result lists what was missing / unexpected so the load can be verified.
msg = model.load_state_dict(tensors, strict=False)
print(msg)
# BCELoss expects probabilities; train_epoch / val_epoch apply the sigmoid.
criterion = nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=0.001, amsgrad=True)
# One-cycle schedule sized for args.num_epochs epochs: peak LR is
# 10 * init_lr, div_factor=10 makes the starting LR equal init_lr, and
# pct_start=1/num_epochs puts the warm-up in (roughly) the first epoch.
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    pct_start=1 / args.num_epochs,
    max_lr=args.init_lr * 10,
    div_factor=10,
    epochs=args.num_epochs,
    steps_per_epoch=len(dev_dataloader),
)
# Let accelerate wrap everything for the configured device / precision.
(
    model,
    optimizer,
    dev_dataloader,
    val_dataloader,
    lr_scheduler,
) = accelerator.prepare(
    model, optimizer, dev_dataloader, val_dataloader, lr_scheduler
)

# Best-so-far bookkeeping; model selection is driven by validation pAUC.
best_val_auc = 0
best_val_pauc = 0
best_val_loss = 0
best_epoch = 0
best_val_probs = None
# Per-epoch history.
train_losses = []
val_losses = []
val_paucs = []
val_aucs = []
for epoch in range(1, args.num_epochs + 1):
    print(f"Fold {args.fold} | Epoch {epoch}")
    start_time = time.time()
    # LR at the start of the epoch (the scheduler steps once per batch).
    lr = optimizer.param_groups[0]["lr"]
    train_loss = train_epoch(
        epoch,
        model,
        optimizer,
        criterion,
        dev_dataloader,
        lr_scheduler,
        accelerator,
    )
    (
        val_loss,
        val_auc,
        val_pauc,
        val_probs,
        val_targets,
    ) = val_epoch(
        epoch,
        model,
        criterion,
        val_dataloader,
        accelerator,
        args.n_tta,
    )
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_paucs.append(val_pauc)
    val_aucs.append(val_auc)
    print(
        f"Fold: {args.fold} | Epoch: {epoch} | LR: {lr:.7f} |"
        f" Train loss: {train_loss:.5f} | Val loss: {val_loss:.5f} |"
        f" Val AUC: {val_auc:.5f} | Val pAUC: {val_pauc:.5f}"
    )
    # Checkpoint only when validation pAUC strictly improves.
    if val_pauc > best_val_pauc:
        print(
            f"pAUC: {best_val_pauc:.5f} --> {val_pauc:.5f}, saving model..."
        )
        best_val_pauc = val_pauc
        best_val_auc = val_auc
        best_val_loss = val_loss
        best_epoch = epoch
        best_val_probs = val_probs
        output_dir = f"{args.model_dir}/models/fold_{args.fold}"
        accelerator.save_state(output_dir)
    else:
        print(
            f"pAUC: {best_val_pauc:.5f} --> {val_pauc:.5f}, skipping model save..."
        )
    elapsed_time = time.time() - start_time
    elapsed_mins = int(elapsed_time // 60)
    elapsed_secs = int(elapsed_time % 60)
    print(f"Epoch {epoch} took {elapsed_mins}m {elapsed_secs}s")
    # NOTE(review): hard stop after 3 epochs although the OneCycleLR schedule
    # above is sized for args.num_epochs, so the LR never reaches its annealing
    # tail. Looks like a leftover run-time cap - confirm intent, and if it is
    # deliberate, size the scheduler for the epochs actually run.
    if epoch == 3:
        break

# Always save the last state as well, next to the best checkpoint.
output_dir = f"{args.model_dir}/models/fold_{args.fold}/final"
accelerator.save_state(output_dir)


_IncompatibleKeys(missing_keys=['classifier.weight', 'classifier.bias'], unexpected_keys=[])
Fold 1 | Epoch 1
Epoch: 1 | Step: 1/5014 | Loss: 0.58040 | Smooth loss: 0.58040
Epoch: 1 | Step: 100/5014 | Loss: 0.20837 | Smooth loss: 0.35927
Epoch: 1 | Step: 200/5014 | Loss: 0.07850 | Smooth loss: 0.12965
Epoch: 1 | Step: 300/5014 | Loss: 0.03784 | Smooth loss: 0.05594
Epoch: 1 | Step: 400/5014 | Loss: 0.02317 | Smooth loss: 0.03373
Epoch: 1 | Step: 500/5014 | Loss: 0.01504 | Smooth loss: 0.01932
Epoch: 1 | Step: 600/5014 | Loss: 0.01120 | Smooth loss: 0.02051
Epoch: 1 | Step: 700/5014 | Loss: 0.00844 | Smooth loss: 0.01341
Epoch: 1 | Step: 800/5014 | Loss: 0.00664 | Smooth loss: 0.01156
Epoch: 1 | Step: 900/5014 | Loss: 0.00514 | Smooth loss: 0.01121
Epoch: 1 | Step: 1000/5014 | Loss: 0.00434 | Smooth loss: 0.01023
Epoch: 1 | Step: 1100/5014 | Loss: 0.00377 | Smooth loss: 0.01024
Epoch: 1 | Step: 1200/5014 | Loss: 0.00358 | Smooth loss: 0.00869
Epoch: 1 | Step: 1300/5014 | Loss: 0.00286 |

PosixPath('efficientnet_b0_v3_finetune/models/fold_1/final')