In [1]:
import io

import h5py
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
import torchvision
from torch.utils.data import Sampler
from PIL import Image
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.decomposition import KernelPCA
from sklearn.feature_selection import VarianceThreshold
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import RandomOverSampler
from torchmetrics import Metric
from torch.utils.data import DataLoader, Dataset, Subset
import time
from pytorch_lightning.callbacks import (
    Callback,
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    ProgressBar
)
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision.transforms import transforms
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

pl.seed_everything(42, workers=True)

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 42


42

In [2]:
"""
2024 ISIC Challenge primary prize scoring metric

Given a list of binary labels, an associated list of prediction 
scores ranging from [0,1], this function produces, as a single value, 
the partial area under the receiver operating characteristic (pAUC) 
above a given true positive rate (TPR).
https://en.wikipedia.org/wiki/Partial_Area_Under_the_ROC_Curve.

(c) 2024 Nicholas R Kurtansky, MSKCC
"""

import numpy as np
import pandas as pd
import pandas.api.types
from sklearn.metrics import roc_curve, auc, roc_auc_score
from collections import Counter

class PartialAUROC(Metric):
    def __init__(
        self,
        min_tpr: float = 0.80,
        dist_sync_on_step: bool = False,
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.min_tpr = min_tpr
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        self.preds.append(preds)
        self.target.append(target)

    def compute(self):
        preds = torch.cat(self.preds)
        target = torch.cat(self.target)
        return self._partial_auroc(target, preds, self.min_tpr)

    def _partial_auroc(
        self, y_true: torch.Tensor, y_score: torch.Tensor, min_tpr: float
    ) -> float:
        y_true = torch.abs(y_true - 1)
        y_score = -y_score

        fpr, tpr, _ = self._roc_curve(y_true, y_score)
        max_fpr = 1.0 - min_tpr

        # print(f"Computed FPR: {fpr}")
        # print(f"Computed TPR: {tpr}")

        if max_fpr == 1:
            return self._auc(fpr, tpr)
        if max_fpr <= 0 or max_fpr > 1:
            raise ValueError(f"Expected min_tpr in range [0, 1), got: {min_tpr}")

        stop = torch.searchsorted(fpr, torch.tensor(max_fpr), right=True)
        x_interp = fpr[stop - 1 : stop + 1]
        y_interp = tpr[stop - 1 : stop + 1]

        # print(f"x_interp: {x_interp}")
        # print(f"y_interp: {y_interp}")

        if len(x_interp) == 1:
            interp_tpr = y_interp[0]
        else:
            interp_tpr = y_interp[0] + (max_fpr - x_interp[0]) * (
                y_interp[1] - y_interp[0]
            ) / (x_interp[1] - x_interp[0])

        tpr = torch.cat([tpr[:stop], torch.tensor([interp_tpr])])
        fpr = torch.cat([fpr[:stop], torch.tensor([max_fpr])])

        partial_auc = self._auc(fpr, tpr)
        return partial_auc

    def _roc_curve(self, y_true: torch.Tensor, y_score: torch.Tensor):
        desc_score_indices = torch.argsort(y_score, descending=True)
        y_score = y_score[desc_score_indices]
        y_true = y_true[desc_score_indices]

        distinct_value_indices = torch.where(torch.diff(y_score))[0]
        threshold_idxs = torch.cat(
            [distinct_value_indices, torch.tensor([y_true.numel() - 1])]
        )

        tps = torch.cumsum(y_true, dim=0)[threshold_idxs]
        fps = 1 + threshold_idxs - tps
        
        # Handle the case where there are no positive samples
        if tps[-1] == 0:
            tpr = torch.zeros_like(tps)
        else:
            tpr = tps / tps[-1]
        
        fpr = fps / fps[-1]
        thresholds = y_score[threshold_idxs]

        # print(f"tps: {tps}")
        # print(f"fps: {fps}")
        # print(f"tpr: {tpr}")
        # print(f"fpr: {fpr}")
        # print(f"thresholds: {thresholds}")

        return fpr, tpr, thresholds

    def _auc(self, x: torch.Tensor, y: torch.Tensor) -> float:
        if torch.all(y == 0):
            print("Warning: All TPR values are zero. AUC is undefined.")
            return 0.0

        direction = 1
        dx = torch.diff(x)
        if torch.any(dx < 0):
            if torch.all(dx <= 0):
                direction = -1
            else:
                raise ValueError("x is neither increasing nor decreasing")
        auc_value = direction * torch.trapz(y, x).item()
        # print(f"Computed AUC: {auc_value}")
        return auc_value

In [3]:
print(torch.__version__)
print(torchvision.__version__)

2.0.1+cu117
0.15.2+cu117


In [4]:
class InvertedResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, expand_ratio, stride):
        super(InvertedResidualBlock, self).__init__()
        hidden_dim = in_channels * expand_ratio
        self.use_res_connect = stride == 1 and in_channels == out_channels

        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNActivation(in_channels, hidden_dim, kernel_size=1))
        layers.extend(
            [
                ConvBNActivation(
                    hidden_dim, hidden_dim, stride=stride, groups=hidden_dim
                ),
                nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        )
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class ConvBNActivation(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNActivation, self).__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                groups=groups,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Mish(),
        )


class DenseBlock(nn.Module):
    def __init__(self, in_channels, num_layers, growth_rate, dropout_rate=0.2):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList(
            [
                DenseLayer(in_channels + i * growth_rate, growth_rate, dropout_rate)
                for i in range(num_layers)
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = torch.cat([x, layer(x)], 1)
        return x


class DenseLayer(nn.Sequential):
    def __init__(self, in_channels, growth_rate, dropout_rate):
        super(DenseLayer, self).__init__(
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            nn.Conv2d(in_channels, 4 * growth_rate, 1, bias=False),
            nn.BatchNorm2d(4 * growth_rate),
            nn.Mish(),
            nn.Conv2d(4 * growth_rate, growth_rate, 3, padding=1, bias=False),
            nn.Dropout2d(dropout_rate),
        )


class TransitionLayer(nn.Sequential):
    def __init__(self, in_channels, compression_factor=0.5):
        out_channels = int(in_channels * compression_factor)
        super(TransitionLayer, self).__init__(
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.AvgPool2d(2, stride=2),
        )


class AttentionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, 1, 1)
        self.bn3 = nn.BatchNorm2d(1)

    def forward(self, x):
        g = self.bn1(self.conv1(x))
        x = self.bn2(self.conv2(x))
        att = nn.Hardswish()(g + x)
        att = nn.Sigmoid()(self.bn3(self.conv3(att)))
        return x * att


class InceptionBlock(nn.Module):
    def __init__(self, in_channels, filters):
        super(InceptionBlock, self).__init__()
        f1, f2, f3 = filters
        self.branch1 = ConvBNActivation(in_channels, f1, kernel_size=1)
        self.branch2 = nn.Sequential(
            ConvBNActivation(in_channels, f2[0], kernel_size=1),
            ConvBNActivation(f2[0], f2[1], kernel_size=3),
        )
        self.branch3 = nn.Sequential(
            ConvBNActivation(in_channels, f3[0], kernel_size=1),
            ConvBNActivation(f3[0], f3[1], kernel_size=5),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            ConvBNActivation(in_channels, f1, kernel_size=1),
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        return torch.cat([branch1, branch2, branch3, branch4], 1)


class GatedResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, strides):
        super(GatedResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=strides,
            padding=kernel_size // 2,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, 1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.activation = nn.Mish()

        # Add a shortcut connection if input and output dimensions don't match
        self.shortcut = nn.Sequential()
        if strides != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=strides, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        residual = self.shortcut(x)

        x = self.activation(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        gate = nn.Sigmoid()(self.bn3(self.conv3(x)))
        x = x * gate
        x += residual
        return self.activation(x)


class GuruNet(pl.LightningModule):
    def __init__(
        self,
        input_shape=(139, 139, 3),
        metadata_shape=None,
        classes=2,
    ):
        super(GuruNet, self).__init__()
        self.input_shape = input_shape
        self.metadata_shape = metadata_shape

        # Initial convolutional layer
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.activation = nn.Hardswish()

        # Inverted Residual Blocks
        self.inv_res_blocks = nn.ModuleList()
        block_params = [
            # expand_ratio, filters, strides, repeats
            (6, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 40, 2, 2),
            (6, 80, 2, 3),
            (6, 112, 1, 3),
            (6, 128, 2, 4),
            (6, 196, 1, 1),
        ]

        in_channels = 128
        for i, (expand_ratio, filters, strides, repeats) in enumerate(block_params):
            for j in range(repeats):
                if j > 0:
                    strides = 1
                self.inv_res_blocks.append(
                    InvertedResidualBlock(in_channels, filters, expand_ratio, strides)
                )
                in_channels = filters

        # Dense Block
        self.dense_block = DenseBlock(in_channels, num_layers=16, growth_rate=32)
        in_channels += 16 * 32  # Update in_channels after dense block

        # Transition Layer
        self.transition = TransitionLayer(in_channels, compression_factor=0.5)
        in_channels = int(in_channels * 0.5)

        # Attention Block
        self.attention = AttentionBlock(in_channels, 128)
        in_channels = 128

        # Average Pooling
        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        # Inception Block
        self.inception = InceptionBlock(in_channels, [128, (128, 192), (32, 96)])
        in_channels = 128 + 192 + 96 + 128

        # Attention Block
        self.attention2 = AttentionBlock(in_channels, 128)
        in_channels = 128

        # Gated Residual Block
        self.gated_res = GatedResidualBlock(in_channels, 256, kernel_size=3, strides=2)
        in_channels = 256

        # Attention Block
        self.attention3 = AttentionBlock(in_channels, 128)
        in_channels = 128

        # Global Average Pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

        # Fully connected layers
        self.fc1 = nn.Linear(in_channels, 4096)
        self.bn_fc1 = nn.BatchNorm1d(4096)
        self.fc2 = nn.Linear(4096, 1024)
        self.bn_fc2 = nn.BatchNorm1d(1024)
        self.fc3 = nn.Linear(1024, 256)
        self.bn_fc3 = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 128)
        self.bn_fc4 = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(0.5)

        self.metadata_fc1 = nn.Linear(40, 4096)
        self.metadata_bn1 = nn.BatchNorm1d(4096)
        self.metadata_fc2 = nn.Linear(4096, 1024)
        self.metadata_bn2 = nn.BatchNorm1d(1024)
        self.metadata_fc3 = nn.Linear(1024, 512)
        self.metadata_bn3 = nn.BatchNorm1d(512)
        self.metadata_fc4 = nn.Linear(512, 128)
        self.metadata_bn4 = nn.BatchNorm1d(128)
        self.final_fc = nn.Linear(128 + 128, classes)
        self.final_activation = nn.Softmax()
        self.scaler = GradScaler()
        self.loss = self.loss = nn.CrossEntropyLoss()
        self.auroc = PartialAUROC(min_tpr=0.8)

    def forward(self, x, metadata):
        x = self.activation(self.bn1(self.conv1(x)))

        # Inverted Residual Blocks
        for block in self.inv_res_blocks:
            x = block(x)

        # Dense Block
        x = self.dense_block(x)

        # Transition Layer
        x = self.transition(x)

        # Attention Block
        x = self.attention(x)

        # Average Pooling
        x = self.avg_pool(x)

        # Inception Block
        x = self.inception(x)

        # Attention Block
        x = self.attention2(x)

        # Gated Residual Block
        x = self.gated_res(x)

        # Attention Block
        x = self.attention3(x)

        x = self.global_avg_pool(x)
        x = self.flatten(x)
        x = self.activation(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        x = self.activation(self.bn_fc2(self.fc2(x)))
        x = self.dropout(x)
        x = self.activation(self.bn_fc3(self.fc3(x)))
        x = self.dropout(x)
        x = self.activation(self.bn_fc4(self.fc4(x)))

        metadata = self.activation(self.metadata_bn1(self.metadata_fc1(metadata)))
        metadata = self.dropout(metadata)
        metadata = self.activation(self.metadata_bn2(self.metadata_fc2(metadata)))
        metadata = self.dropout(metadata)
        metadata = self.activation(self.metadata_bn3(self.metadata_fc3(metadata)))
        metadata = self.dropout(metadata)
        metadata = self.activation(self.metadata_bn4(self.metadata_fc4(metadata)))

        x = torch.cat([x, metadata], dim=1)

        x = self.final_fc(x)
        # Apply sigmoid to ensure output is between 0 and 1
        x = self.final_activation(x)

        return x

    def training_step(self, batch, batch_idx):
        (images, metadata), targets = batch
        outputs = self(images, metadata)
        loss = self.loss(outputs, targets)  # targets is already one-hot encoded
        # Get the probability of the positive class
        pos_probs = outputs[:, 1].float().cpu()

        # Convert one-hot encoded targets to binary labels
        targets_binary = targets[:, 1].int().cpu()
        rocauc = self.auroc(
            pos_probs, targets_binary
        )  # Use class 1 probability

        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "train_pAUC",
            rocauc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        (images, metadata), targets = batch
        outputs = self(images, metadata)
        loss = self.loss(outputs, targets)  # targets is already one-hot encoded
        # Get the probability of the positive class
        pos_probs = outputs[:, 1].float().cpu()

        # Convert one-hot encoded targets to binary labels
        targets_binary = targets[:, 1].int().cpu()
        rocauc = self.auroc(
            pos_probs, targets_binary
        )

        # Use class 1 probability

        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "val_pAUC",
            rocauc,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

        return loss

    def test_step(self, batch, batch_idx):
        (images, metadata), targets = batch
        outputs = self(images, metadata)
        loss = self.loss(outputs, targets)  # targets is already one-hot encoded

        # Get the probability of the positive class
        pos_probs = outputs[:, 1].float().cpu()

        # Convert one-hot encoded targets to binary labels
        targets_binary = targets[:, 1].int().cpu()
        rocauc = self.auroc(
            pos_probs, targets_binary
        )

        self.log(
            "test_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "test_pAUC",
            rocauc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )

        return loss

    def configure_optimizers(self):
        optimizer = optim.NAdam(
            self.parameters(), lr=0.001, momentum_decay=0.5, weight_decay=1e-5
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, patience=1, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "train_loss",
            },
        }

In [5]:
from sklearn.model_selection import train_test_split
from joblib import Parallel, delayed
import multiprocessing

import os
from sklearn.decomposition import PCA
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
import numpy as np
import pandas as pd
import torch
import joblib
import hashlib


def prepare_df(
    df: pd.DataFrame, is_training=True, n_components=0.7, variance_threshold=0.95
):
    print("Preparing DataFrame...")
    df_hash = hashlib.md5(pd.util.hash_pandas_object(df).values).hexdigest()
    cache_dir = "./cache"
    param_string = f"{is_training}_{n_components}_{variance_threshold}"
    cache_file = os.path.join(cache_dir, f"prepared_df_{df_hash}_{param_string}.joblib")

    # Check if cached version exists
    if os.path.exists(cache_file):
        print("Loading cached prepared DataFrame...")
        return joblib.load(cache_file)
    start_time = time.time()

    drop_columns_train = [
        "lesion_id",
        "iddx_full",
        "iddx_1",
        "iddx_2",
        "iddx_3",
        "iddx_4",
        "iddx_5",
        "mel_mitotic_index",
        "mel_thick_mm",
        "tbp_lv_dnn_lesion_confidence",
    ]
    drop_columns_test = ["attribution", "copyright_license"]

    # train_metadata_df.drop(drop_columns_train, axis=1, inplace=True)
    # train_metadata_df.drop(drop_columns_test, axis=1, inplace=True)
    # test_metadata_df.drop(drop_columns_test, axis=1, inplace=True)
    if is_training:
        df.drop(drop_columns_train, axis=1, inplace=True)
    df.drop(drop_columns_test, axis=1, inplace=True)
    target_columns = ["target"] if is_training else []
    X = df.drop(target_columns + ["isic_id"], axis=1)
    y = torch.tensor(df["target"].values, dtype=torch.int8) if is_training else None

    # Separate features by type
    integer_features = X.select_dtypes(include=["int64", "int32", "int16"]).columns
    float_features = X.select_dtypes(include=["float64", "float32", "float16"]).columns
    categorical_features = X.select_dtypes(include=["object"]).columns

    # Handle NaN values and type conversions
    for feature in float_features:
        X[feature] = X[feature].fillna(X[feature].mean()).astype("float32")

    for feature in integer_features:
        X[feature] = X[feature].fillna(X[feature].median()).astype("int32")

    for feature in categorical_features:
        X[feature] = X[feature].astype(str).fillna("Unknown")
        X[feature] = pd.Categorical(X[feature]).codes

    # Feature Engineering
    print("Performing feature engineering...")

    # 1. Polynomial features for numeric columns
    numeric_features = list(float_features) + list(integer_features)
    poly = PolynomialFeatures(degree=2, include_bias=True)
    poly_features = poly.fit_transform(X[numeric_features])
    poly_feature_names = poly.get_feature_names_out(numeric_features)
    X_poly = pd.DataFrame(poly_features, columns=poly_feature_names, index=X.index)

    # 2. Interaction terms between categorical and numeric features
    for cat_feature in categorical_features:
        for num_feature in numeric_features:
            X[f"{cat_feature}_{num_feature}_interaction"] = (
                X[cat_feature] * X[num_feature]
            )

    # 3. Binning for numeric features
    for feature in numeric_features:
        X[f"{feature}_binned"] = pd.qcut(
            X[feature], q=5, labels=False, duplicates="drop"
        )

    # Combine all features
    X = pd.concat([X, X_poly], axis=1)

    # Standardize all numeric features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    X_scaled = pd.DataFrame(X_scaled, columns=X.columns, index=X.index)

    # Dimensionality Reduction
    print("Performing dimensionality reduction...")

    # 1. Variance Threshold
    selector = VarianceThreshold(threshold=variance_threshold)
    X_var_threshold = selector.fit_transform(X_scaled)

    # 2. PCA
    pca = PCA(n_components=40)
    X_pca = pca.fit_transform(X_var_threshold)

    print(f"Original number of features: {X.shape[1]}")
    print(f"Number of features after Variance Threshold: {X_var_threshold.shape[1]}")
    print(f"Number of features after PCA: {X_pca.shape[1]}")

    # Final dataset
    X_final = pd.DataFrame(X_pca, index=X.index)
    print(X_final.columns)

    # Final check for any remaining NaN values
    assert (
        not X_final.isnull().any().any()
    ), "There are still NaN values in the processed data"

    print("Data shape after preprocessing:", X_final.shape)
    print("Number of NaN values after preprocessing:", X_final.isnull().sum().sum())

    if is_training:
        print("Class distribution:")
        print(df["target"].value_counts(normalize=True))

    print(f"DataFrame prepared in {time.time() - start_time:.2f} seconds")
    print(f"Metadata Shape: {X_final.shape}")

    # Cache the results
    os.makedirs(cache_dir, exist_ok=True)
    joblib.dump((X_final, y, df["isic_id"]), cache_file)
    return X_final, y, df["isic_id"]


class ISICDataset(Dataset):
    def __init__(self, hdf5_path, metadata_df, is_training=True, transform=None):
        self.hdf5_path = hdf5_path
        self.metadata_df = metadata_df
        self.is_training = is_training
        self.transform = transform
        self.X, self.y, self.image_names = prepare_df(metadata_df, is_training)
        self.metadata_shape = self.X.shape
        self.train_transform = get_transforms(is_training=True)
        self.test_transform = get_transforms(is_training=False)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if isinstance(idx, tuple):
            idx, augment = idx
        else:
            augment = False

        isic_id = self.image_names[idx]
        metadata = torch.tensor(self.X.iloc[idx].values, dtype=torch.float16)

        with h5py.File(self.hdf5_path, "r") as hdf:
            image_data = hdf[str(isic_id)][()]
            image = Image.open(io.BytesIO(image_data))

        if self.is_training and augment:
            image = self.train_transform(image)
        elif self.transform:
            image = self.transform(image)

        if self.is_training:
            target = self.y[idx]
            target_long = target.long()
            del target
            target_one_hot = nn.functional.one_hot(target_long, num_classes=2).float()
            return (image, metadata), target_one_hot
        else:
            return (image, metadata)


# Create separate transforms for training and validation
def get_transforms(is_training=True):
    # Define augmentation parameters
    ROTATION_RANGE = 90
    BRIGHTNESS_RANGE = (0.9, 1.1)
    CONTRAST_RANGE = (0.9, 1.1)
    SATURATION_RANGE = (0.9, 1.1)
    HUE_RANGE = (-0.001, 0.001)
    base_transforms = [
        transforms.ToTensor(),
        transforms.Resize((139, 139), antialias=True),
    ]

    if is_training:
        train_transforms = [
            transforms.RandomResizedCrop(
                size=(139, 139), scale=(0.9, 1.1), antialias=True
            ),
            transforms.RandomRotation(degrees=ROTATION_RANGE, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ColorJitter(
                brightness=BRIGHTNESS_RANGE,
                contrast=CONTRAST_RANGE,
                saturation=SATURATION_RANGE,
                hue=HUE_RANGE,
            ),
        ]
        return transforms.Compose(train_transforms + base_transforms)
    else:
        return transforms.Compose(base_transforms)


class ISICDataModule(pl.LightningDataModule):

    def __init__(
        self,
        train_hdf5_path: str,
        test_hdf5_path: str,
        train_metadata_df: pd.DataFrame,
        test_metadata_df: pd.DataFrame,
        batch_size: int = 32,
    ):
        super().__init__()
        self.train_hdf5_path = train_hdf5_path
        self.test_hdf5_path = test_hdf5_path
        self.batch_size = batch_size

        self.train_metadata_df = train_metadata_df
        self.test_metadata_df = test_metadata_df

    def setup(self, stage=None):
        full_dataset = ISICDataset(
            self.train_hdf5_path,
            self.train_metadata_df,
            True,
            transform=get_transforms(is_training=True),
        )
        self.metadata_shape = full_dataset.metadata_shape
        # Get targets for stratification
        targets = self.train_metadata_df["target"].values
        balanced_indices = self.balance_dataset(np.arange(len(full_dataset)), targets)
        balanced_targets = targets[balanced_indices]
        print(f"Unique indices: {np.unique(balanced_indices)}")
        print(f"Unique targets: {np.unique(balanced_targets)}")
        print(len(balanced_indices))
        print(len(balanced_targets))
        unique, counts = np.unique(balanced_targets, return_counts=True)
        print(dict(zip(unique, counts)))
        # Perform stratified split
        train_indices, temp_indices, train_targets, temp_targets = train_test_split(
            balanced_indices,
            balanced_targets,
            test_size=0.2,
            stratify=balanced_targets,
            random_state=42,
        )

        val_indices, test_indices, val_targets, test_targets = train_test_split(
            temp_indices,
            temp_targets,
            test_size=0.5,
            stratify=temp_targets,
            random_state=42,
        )

        # Create subset datasets
        if stage in ["fit", "validate", "test"]:
            self.train_dataset = Subset(full_dataset, train_indices)
            self.val_dataset = Subset(full_dataset, val_indices)
            self.test_dataset = Subset(full_dataset, test_indices)

        # Check for class balance
        self._check_class_balance(train_targets.flatten(), "Train")
        self._check_class_balance(val_targets.flatten(), "Validation")
        self._check_class_balance(test_targets.flatten(), "Test")

        print(f"Length of full_dataset: {len(full_dataset)}")
        print(
            f"Length of train_indices: {len(train_indices)}, max index: {max(train_indices)}"
        )
        print(
            f"Length of val_indices: {len(val_indices)}, max index: {max(val_indices)}"
        )
        print(
            f"Length of test_indices: {len(test_indices)}, max index: {max(test_indices)}"
        )

    def balance_dataset(self, indices, targets):
        np.random.seed(42)
        labels = targets
        positive_indices = indices[np.where(labels[indices] == 1)[0]]
        negative_indices = indices[np.where(labels[indices] == 0)[0]]

        num_positive_samples = len(positive_indices)
        num_negative_samples = len(negative_indices)

        num_avg_samples =  num_positive_samples

        # Upsample positive indices
        upsampled_positive_indices = np.random.choice(positive_indices, size=num_avg_samples, replace=True)
        
        # Add augmentation flag to upsampled positive indices
        upsampled_positive_indices = [(idx, True) for idx in upsampled_positive_indices]

        # Downsample negative indices (no augmentation needed)
        downsampled_negative_indices = np.random.choice(negative_indices, size=num_avg_samples, replace=False)
        downsampled_negative_indices = [(idx, False) for idx in downsampled_negative_indices]

        balanced_indices = upsampled_positive_indices + downsampled_negative_indices
        np.random.shuffle(balanced_indices)

        return balanced_indices

    def balance_dataset_parallel(self, dataset, indices, targets):
        num_cores = multiprocessing.cpu_count()
        print(f"Number of CPU cores available: {num_cores}")

        # Use joblib for parallel processing
        balanced_indices = Parallel(n_jobs=num_cores)(
            delayed(self.balance_indices)(dataset, indices[i], targets)
            for i in range(len(indices))
        )

        return [idx for sublist in balanced_indices for idx in sublist]

    def _check_class_balance(self, targets, split_name):
        class_counts = np.bincount(targets)
        print(
            f"{split_name} class distribution: {class_counts / len(targets)}, {len(targets)}"
        )
        if len(class_counts) < 2 or min(class_counts) == 0:
            raise ValueError(f"Imbalanced classes in {split_name} split")

    def train_dataloader(self):
        data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=16,
            pin_memory=True,
        )
        print(f"Number of batches in train_loader: {len(data_loader)}")
        return data_loader

    def val_dataloader(self):
        data_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
        print(f"Number of batches in val_loader: {len(data_loader)}")
        return data_loader

    def test_dataloader(self):

        data_loader = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
        print(f"Number of batches in test_loader: {len(data_loader)}")
        return data_loader

In [6]:
# Define parameters
img_height, img_width = 139, 139

# Load metadata
train_metadata_df = pd.read_csv("train-metadata.csv")
test_metadata_df = pd.read_csv("test-metadata.csv")

  train_metadata_df = pd.read_csv("train-metadata.csv")


In [7]:
import os

os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
torch.cuda.memory.empty_cache()

batch_size = 128
epochs = 20

logger = TensorBoardLogger("tb_logs", name="gurunet_model")

checkpoint_callback = ModelCheckpoint(
    dirpath=f"checkpoints/version_{logger.version}",
    filename="gurunet-{epoch:02d}-{val_loss:.2f}",
    save_top_k=3,
    monitor="val_loss",
    mode="min",
    verbose=True,
)

early_stop_callback = EarlyStopping(monitor="val_loss", patience=3, mode="min")

lr_monitor = LearningRateMonitor(logging_interval="step")

# Initialize your data module
data_module = ISICDataModule(
    "train-image.hdf5",
    "test-image.hdf5",
    train_metadata_df,
    test_metadata_df,
    batch_size=batch_size,
)


# Initialize your model
model = GuruNet(
    input_shape=(139, 139, 3),
    metadata_shape=(None, 37),
    classes=2,
)
# Initialize a trainer
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    gpus=1 if torch.cuda.is_available() else 0,
    callbacks=[
        checkpoint_callback,
        early_stop_callback,
        lr_monitor,
    ],
    logger=logger,
    precision=16,
    deterministic=True,
    # accumulate_grad_batches=2,
)

# Train the model
trainer.fit(model, data_module)

# Test the model

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Preparing DataFrame...
Performing feature engineering...


  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_feature}_{num_feature}_interaction"] = (
  X[f"{cat_f

Performing dimensionality reduction...
Original number of features: 943
Number of features after Variance Threshold: 907
Number of features after PCA: 40
RangeIndex(start=0, stop=40, step=1)
Data shape after preprocessing: (401059, 40)
Number of NaN values after preprocessing: 0
Class distribution:
target
0    0.99902
1    0.00098
Name: proportion, dtype: float64
DataFrame prepared in 55.66 seconds
Metadata Shape: (401059, 40)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Unique indices: [     0      1    579    935   1280   1846   2200   2538   3133   3463
   3478   4188   4812   6295   6741   7134   7885   8532   8920   9217
  10007  10041  10162  10305  10400  10732  11114  12448  12742  13202
  13261  15163  15469  15947  16541  17219  17646  18669  18957  19551
  20007  20265  20740  21253  21781  22439  23053  23569  25000  25764
  25765  25840  25864  26697  27945  28079  28208  28354  29299  29686
  30007  30013  30396  31509  32489  33259  34343  36157  36210  36436
  36688  38474  39054  39934  40145  40459  41526  41888  41890  42127
  43028  43094  43326  43483  43578  44556  45317  46166  46405  47061
  47617  47682  48245  49295  49405  50337  51162  52008  53533  53747
  54515  56632  57014  57353  58719  58989  59441  59841  60739  61886
  62239  62793  62925  64457  64578  65053  65631  65736  67392  67519
  67748  67921  68909  69215  69644  74415  74523  76039  76292  77832
  78124  78384  78480  78569  78846  81522  81744  82544  829


   | Name             | Type               | Params
---------------------------------------------------------
0  | conv1            | Conv2d             | 3.6 K 
1  | bn1              | BatchNorm2d        | 256   
2  | activation       | Hardswish          | 0     
3  | inv_res_blocks   | ModuleList         | 1.8 M 
4  | dense_block      | DenseBlock         | 1.5 M 
5  | transition       | TransitionLayer    | 252 K 
6  | attention        | AttentionBlock     | 91.5 K
7  | avg_pool         | AvgPool2d          | 0     
8  | inception        | InceptionBlock     | 352 K 
9  | attention2       | AttentionBlock     | 140 K 
10 | gated_res        | GatedResidualBlock | 461 K 
11 | attention3       | AttentionBlock     | 66.4 K
12 | global_avg_pool  | AdaptiveAvgPool2d  | 0     
13 | flatten          | Flatten            | 0     
14 | fc1              | Linear             | 528 K 
15 | bn_fc1           | BatchNorm1d        | 8.2 K 
16 | fc2              | Linear             | 4.2 M 
17 | 

Validation sanity check: 0it [00:00, ?it/s]Number of batches in val_loader: 1
                                                                      

  x = self.final_activation(x)
Global seed set to 42
  rank_zero_warn(


Number of batches in train_loader: 5
Epoch 0: 100%|██████████| 6/6 [00:04<00:00,  1.70it/s, loss=0.587, v_num=74, train_loss_step=0.514, train_pAUC_step=0.141, val_loss=0.550, val_pAUC=0.158]

Epoch 0, global step 4: val_loss reached 0.55047 (best 0.55047), saving model to "checkpoints/version_74/gurunet-epoch=00-val_loss=0.55.ckpt" as top 3


Epoch 1: 100%|██████████| 6/6 [00:03<00:00,  1.96it/s, loss=0.541, v_num=74, train_loss_step=0.492, train_pAUC_step=0.141, val_loss=0.481, val_pAUC=0.160, train_loss_epoch=0.589, train_pAUC_epoch=0.0935]

Epoch 1, global step 9: val_loss reached 0.48089 (best 0.48089), saving model to "checkpoints/version_74/gurunet-epoch=01-val_loss=0.48.ckpt" as top 3


Epoch 2: 100%|██████████| 6/6 [00:03<00:00,  1.98it/s, loss=0.517, v_num=74, train_loss_step=0.494, train_pAUC_step=0.137, val_loss=0.468, val_pAUC=0.153, train_loss_epoch=0.495, train_pAUC_epoch=0.147]    

Epoch 2, global step 14: val_loss reached 0.46825 (best 0.46825), saving model to "checkpoints/version_74/gurunet-epoch=02-val_loss=0.47.ckpt" as top 3


Epoch 3: 100%|██████████| 6/6 [00:03<00:00,  1.99it/s, loss=0.498, v_num=74, train_loss_step=0.433, train_pAUC_step=0.165, val_loss=0.445, val_pAUC=0.161, train_loss_epoch=0.469, train_pAUC_epoch=0.153]   

Epoch 3, global step 19: val_loss reached 0.44506 (best 0.44506), saving model to "checkpoints/version_74/gurunet-epoch=03-val_loss=0.45.ckpt" as top 3


Epoch 4: 100%|██████████| 6/6 [00:03<00:00,  2.00it/s, loss=0.458, v_num=74, train_loss_step=0.433, train_pAUC_step=0.164, val_loss=0.440, val_pAUC=0.163, train_loss_epoch=0.442, train_pAUC_epoch=0.167]   

Epoch 4, global step 24: val_loss reached 0.44029 (best 0.44029), saving model to "checkpoints/version_74/gurunet-epoch=04-val_loss=0.44.ckpt" as top 3


Epoch 5: 100%|██████████| 6/6 [00:04<00:00,  1.61it/s, loss=0.438, v_num=74, train_loss_step=0.435, train_pAUC_step=0.155, val_loss=0.442, val_pAUC=0.165, train_loss_epoch=0.424, train_pAUC_epoch=0.174]   

Epoch 5, global step 29: val_loss reached 0.44193 (best 0.44029), saving model to "checkpoints/version_74/gurunet-epoch=05-val_loss=0.44.ckpt" as top 3


Epoch 6: 100%|██████████| 6/6 [00:03<00:00,  1.95it/s, loss=0.423, v_num=74, train_loss_step=0.414, train_pAUC_step=0.178, val_loss=0.439, val_pAUC=0.162, train_loss_epoch=0.416, train_pAUC_epoch=0.175]   

Epoch 6, global step 34: val_loss reached 0.43900 (best 0.43900), saving model to "checkpoints/version_74/gurunet-epoch=06-val_loss=0.44.ckpt" as top 3


Epoch 7: 100%|██████████| 6/6 [00:03<00:00,  1.93it/s, loss=0.413, v_num=74, train_loss_step=0.426, train_pAUC_step=0.168, val_loss=0.439, val_pAUC=0.163, train_loss_epoch=0.412, train_pAUC_epoch=0.179]   

Epoch 7, global step 39: val_loss reached 0.43854 (best 0.43854), saving model to "checkpoints/version_74/gurunet-epoch=07-val_loss=0.44.ckpt" as top 3


Epoch 8: 100%|██████████| 6/6 [00:03<00:00,  1.91it/s, loss=0.406, v_num=74, train_loss_step=0.378, train_pAUC_step=0.188, val_loss=0.434, val_pAUC=0.160, train_loss_epoch=0.399, train_pAUC_epoch=0.181]   

Epoch 8, global step 44: val_loss reached 0.43369 (best 0.43369), saving model to "checkpoints/version_74/gurunet-epoch=08-val_loss=0.43.ckpt" as top 3


Epoch 9: 100%|██████████| 6/6 [00:03<00:00,  2.00it/s, loss=0.398, v_num=74, train_loss_step=0.387, train_pAUC_step=0.186, val_loss=0.444, val_pAUC=0.156, train_loss_epoch=0.398, train_pAUC_epoch=0.180]   

Epoch 9, global step 49: val_loss was not in top 3


Epoch 10: 100%|██████████| 6/6 [00:03<00:00,  1.90it/s, loss=0.394, v_num=74, train_loss_step=0.401, train_pAUC_step=0.160, val_loss=0.446, val_pAUC=0.163, train_loss_epoch=0.384, train_pAUC_epoch=0.186]  

Epoch 10, global step 54: val_loss was not in top 3


Epoch 11: 100%|██████████| 6/6 [00:03<00:00,  1.91it/s, loss=0.39, v_num=74, train_loss_step=0.367, train_pAUC_step=0.176, val_loss=0.444, val_pAUC=0.161, train_loss_epoch=0.394, train_pAUC_epoch=0.180]    

Epoch 11, global step 59: val_loss was not in top 3


Epoch 11: 100%|██████████| 6/6 [00:03<00:00,  1.91it/s, loss=0.39, v_num=74, train_loss_step=0.367, train_pAUC_step=0.176, val_loss=0.444, val_pAUC=0.161, train_loss_epoch=0.394, train_pAUC_epoch=0.180]


In [8]:
trainer.test(model=model, datamodule=data_module)

  rank_zero_deprecation(


Preparing DataFrame...


KeyError: "['lesion_id', 'iddx_full', 'iddx_1', 'iddx_2', 'iddx_3', 'iddx_4', 'iddx_5', 'mel_mitotic_index', 'mel_thick_mm', 'tbp_lv_dnn_lesion_confidence'] not found in axis"