In [None]:
!git clone https://github.com/proxy-pylon/medical-deep-learning-project.git
%cd medical-deep-learning-project

Cloning into 'medical-deep-learning-project'...
remote: Enumerating objects: 92, done.[K
remote: Counting objects: 100% (92/92), done.[K
remote: Compressing objects: 100% (89/89), done.[K
remote: Total 92 (delta 34), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (92/92), 1.44 MiB | 18.23 MiB/s, done.
Resolving deltas: 100% (34/34), done.
/content/medical-deep-learning-project/medical-deep-learning-project


In [None]:
!pip install -r requirements.txt



In [None]:
import torch, torchvision, platform
print("torch:", torch.__version__, "cuda:", torch.version.cuda)
print("torchvision:", torchvision.__version__)
print("is_cuda_available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
    print("compute capability:", torch.cuda.get_device_capability(0))

print("To disable GPU and force CPU usage, write this in your terminal before running the main script:")
print('export CUDA_VISIBLE_DEVICES=\"\"')  # for Linux/Mac
print('set CUDA_VISIBLE_DEVICES=')        # for Windows

torch: 2.8.0+cu126 cuda: 12.6
torchvision: 0.23.0+cu126
is_cuda_available: True
device: Tesla T4
compute capability: (7, 5)
To disable GPU and force CPU usage, write this in your terminal before running the main script:
export CUDA_VISIBLE_DEVICES=""
set CUDA_VISIBLE_DEVICES=


In [None]:
pip install kaggle



In [18]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("kmader/skin-cancer-mnist-ham10000")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'skin-cancer-mnist-ham10000' dataset.
Path to dataset files: /kaggle/input/skin-cancer-mnist-ham10000


In [19]:
!ls "/root/.cache/kagglehub/datasets/kmader/skin-cancer-mnist-ham10000/versions/2"



ls: cannot access '/root/.cache/kagglehub/datasets/kmader/skin-cancer-mnist-ham10000/versions/2': No such file or directory


In [21]:
# Standard library
import os
import sys
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

# Third-party libraries
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm
from numpy.typing import ArrayLike
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    roc_curve,
    auc,
    brier_score_loss,
    log_loss,
    precision_recall_curve
)
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.models import ResNet50_Weights, EfficientNet_B0_Weights

# Warnings
warnings.filterwarnings('ignore')

import math
from collections import OrderedDict

# Add this to your imports section
class SEModule(nn.Module):
    """Squeeze-and-Excitation module"""
    def __init__(self, channels, reduction=16):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class SEBasicBlock(nn.Module):
    """Basic ResNet block with SE module"""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SEModule(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class SEBottleneck(nn.Module):
    """Bottleneck ResNet block with SE module"""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEModule(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

def se_resnet50(pretrained=False, **kwargs):
    """Constructs a SE-ResNet-50 model."""
    model = SEResNet(SEBottleneck, [3, 4, 6, 3], **kwargs)
    return model

class SEResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, reduction=16):
        self.inplanes = 64
        super(SEResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], reduction=reduction)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, reduction=reduction)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, reduction=reduction)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, reduction=reduction)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, reduction=16):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, reduction))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, reduction=reduction))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

class Config:
    """Container for training, model, and path configuration.

    Attributes
    ----------
    HAM10000_BASE : str
        Base directory for the HAM10000 dataset.
    ISIC_BASE : str
        Base directory for the ISIC dataset.
    OUTPUT_DIR : str
        Base output directory.
    CHECKPOINT_DIR : str
        Directory for checkpoints.
    RESULTS_DIR : str
        Directory for results and plots.
    MODEL_NAME : str
        Backbone model identifier. One of {'resnet50', 'efficientnet', 'vgg16'}.
    IMG_SIZE : int
        Input image size (square).
    NUM_CLASSES : int
        Number of output classes.
    PRETRAINED : bool
        Whether to use ImageNet pretrained weights.
    BATCH_SIZE : int
        Training batch size.
    NUM_EPOCHS : int
        Maximum number of epochs.
    LEARNING_RATE : float
        Base learning rate (not used directly when discriminative LRs are applied).
    WEIGHT_DECAY : float
        Weight decay for optimizers.
    EARLY_STOPPING_PATIENCE : int
        Patience for early stopping.
    FREEZE_EPOCHS : int
        Number of warmup epochs training the head only.
    HEAD_LR_WARMUP : float
        Learning rate for head during warmup.
    HEAD_LR_FINETUNE : float
        Learning rate for head during fine-tuning.
    BACKBONE_LR_LOW : float
        LR for earliest backbone layers.
    BACKBONE_LR_MID : float
        LR for mid backbone layers.
    BACKBONE_LR_HIGH : float
        LR for deepest backbone layers.
    TEST_SIZE : float
        Proportion for test split.
    VAL_SIZE : float
        Proportion for validation split (from the non-train part).
    RANDOM_STATE : int
        Random seed for splitting.
    USE_MIXUP : bool
        Placeholder flag for mixup usage.
    USE_CUTMIX : bool
        Placeholder flag for cutmix usage.
    DEVICE : str
        'cuda' if available else 'cpu'.
    NUM_WORKERS : int
        DataLoader workers.
    """

    # Data Paths
    HAM10000_BASE = '/kaggle/input/skin-cancer-mnist-ham10000'  # Change this path
    ISIC_BASE = 'not defined lol'
    # Output Paths
    OUTPUT_DIR = './output/'
    CHECKPOINT_DIR = OUTPUT_DIR + 'checkpoints'
    RESULTS_DIR = OUTPUT_DIR + 'results'

    # Model configurations
    MODEL_NAME = 'senet50'  # Options: 'resnet50', 'efficientnet', 'vgg16', 'senet50'
    IMG_SIZE = 224
    NUM_CLASSES = 2  # Binary: melanoma vs benign
    PRETRAINED = True

    # Training configurations
    BATCH_SIZE = 32
    NUM_EPOCHS = 250
    LEARNING_RATE = 0.001
    WEIGHT_DECAY = 1e-4
    EARLY_STOPPING_PATIENCE = 30

    # Warmup / fine-tune
    FREEZE_EPOCHS = 3
    HEAD_LR_WARMUP = 1e-3
    HEAD_LR_FINETUNE = 1e-4

    # Discriminative LRs for backbone
    BACKBONE_LR_LOW = 1e-5
    BACKBONE_LR_MID = 2e-5
    BACKBONE_LR_HIGH = 3e-5

    # Dataset split ratio and seeding
    TEST_SIZE = 0.30
    VAL_SIZE = 0.20
    RANDOM_STATE = 42

    # Augmentation flags
    USE_MIXUP = False
    USE_CUTMIX = False

    # Device
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    NUM_WORKERS = 2


def load_ham10000_data(base_path: Union[str, os.PathLike]) -> pd.DataFrame:
    """Load HAM10000 metadata and resolve image paths.

    Parameters
    ----------
    base_path : str or os.PathLike
        Directory containing `HAM10000_metadata.csv` and image folders.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns including:
        - 'image_id'
        - 'dx' (diagnosis)
        - 'image_path' (resolved path to JPEG)
        - 'binary_label' (1 for melanoma, 0 otherwise)
    """
    print('Loading HAM10000 dataset....')
    base_path = str(base_path)

    metadata_path = os.path.join(base_path, 'HAM10000_metadata.csv')
    df = pd.read_csv(metadata_path)

    def get_image_path(image_id: str) -> Optional[str]:
        """Return path to image file if present in either image folder."""
        part1 = os.path.join(base_path, 'HAM10000_images_part_1', f'{image_id}.jpg')
        part2 = os.path.join(base_path, 'HAM10000_images_part_2', f'{image_id}.jpg')
        if os.path.exists(part1):
            return part1
        if os.path.exists(part2):
            return part2
        return None

    df['image_path'] = df['image_id'].apply(get_image_path)
    df = df[df['image_path'].notna()].reset_index(drop=True)
    df['binary_label'] = (df['dx'] == 'mel').astype(int)

    print(f"Loaded {len(df)} images")
    print(f"Melanoma: {df['binary_label'].sum()}")
    print(f"Benign: {len(df) - df['binary_label'].sum()}")
    print(f"\nClass Distribution:")
    print(df['dx'].value_counts())

    return df


class MelanomaDataset(Dataset):
    """PyTorch Dataset for melanoma classification with Albumentations transforms.

    Parameters
    ----------
    dataframe : pd.DataFrame
        DataFrame with at least columns 'image_path' and 'binary_label'.
    transform : Optional[A.Compose]
        Albumentations transform to apply to the loaded image.

    Returns
    -------
    dict
        A sample dict containing:
        - 'image': torch.Tensor of shape [3, H, W]
        - 'label': torch.LongTensor scalar {0, 1}
    """

    def __init__(self, dataframe: pd.DataFrame, transform: Optional[A.Compose] = None) -> None:
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self) -> int:
        """Number of samples."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load one sample by index.

        Parameters
        ----------
        idx : int
            Row index.

        Returns
        -------
        dict
            Dict with 'image' and 'label' tensors.
        """
        img_path: str = self.df.loc[idx, 'image_path']
        label: int = int(self.df.loc[idx, 'binary_label'])

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return {
            'image': image,
            'label': torch.tensor(label, dtype=torch.long)
        }


def get_train_transform(img_size: int = 224) -> A.Compose:
    """Build the training augmentation pipeline.

    Parameters
    ----------
    img_size : int, default=224
        Target square image size.

    Returns
    -------
    A.Compose
        Albumentations composition.
    """
    return A.Compose([
        A.RandomResizedCrop(
            size=(img_size, img_size),
            scale=(0.90, 1.00),
            ratio=(0.9, 1.1),
            p=1.0
        ),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=20, border_mode=cv2.BORDER_REFLECT_101, p=0.7),
        A.CenterCrop(height=img_size, width=img_size, p=1.0),
        A.Affine(scale=(0.95, 1.05), translate_percent=(-0.02, 0.02),
                 shear=(-5, 5), mode=cv2.BORDER_REFLECT_101, p=0.5),
        A.RandomBrightnessContrast(0.10, 0.10, p=0.3),
        A.ColorJitter(0.05, 0.05, 0.05, 0.02, p=0.2),
        A.CLAHE(clip_limit=(1, 2), tile_grid_size=(8, 8), p=0.2),
        A.OneOf([
            A.GaussNoise(var_limit=(10.0, 50.0)),
            A.GaussianBlur(blur_limit=(3, 7)),
            A.MedianBlur(blur_limit=5),
        ], p=0.3),
        A.CoarseDropout(max_holes=1, min_holes=1,
                        max_height=int(0.08 * img_size), max_width=int(0.08 * img_size),
                        min_height=8, min_width=8, fill_value=0, p=0.15),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])


def get_val_transform(img_size: int = 224) -> A.Compose:
    """Build the validation/test preprocessing pipeline.

    Parameters
    ----------
    img_size : int, default=224
        Target square image size.

    Returns
    -------
    A.Compose
        Albumentations composition for eval.
    """
    return A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])


class MelanomaClassifier(nn.Module):
    """CNN classifier using a torchvision backbone with a custom head.

    Parameters
    ----------
    model_name : str, default='resnet50'
        One of {'resnet50', 'efficientnet'} currently supported.
    num_classes : int, default=2
        Number of output classes.
    pretrained : bool, default=True
        If True, initialize backbone with pretrained weights.

    Notes
    -----
    Forward returns both logits and the pooled backbone features.
    """
    def __init__(self, model_name: str = 'resnet50', num_classes: int = 2, pretrained: bool = True) -> None:
        super().__init__()

        if model_name == 'resnet50':
            self.backbone = models.resnet50(
                weights=ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            )
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()

        elif model_name == 'efficientnet':
            self.backbone = models.efficientnet_b0(
                weights=EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
            )
            num_features = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Identity()

        elif model_name == 'senet50':
            self.backbone = se_resnet50(pretrained=False)  #no pretrained weights available by default
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()

        else:
            raise ValueError(f"Unsupported model_name: {model_name}")

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        features = self.backbone(x)
        output = self.classifier(features)
        return output, features

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input batch of shape [N, 3, H, W].

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            - logits: [N, C]
            - features: pooled features from the backbone, shape [N, F]
        """
        features = self.backbone(x)
        output = self.classifier(features)
        return output, features


def set_backbone_trainable(model: nn.Module, trainable: bool) -> None:
    """Enable or disable gradient updates for backbone parameters.

    Parameters
    ----------
    model : nn.Module
        Model with attribute `backbone`.
    trainable : bool
        If True, unfreezes the backbone; if False, freezes it.
    """
    for p in model.backbone.parameters():
        p.requires_grad = trainable


def get_param_groups_discriminative(model: nn.Module, config: Config) -> List[Dict[str, Any]]:
    """Create optimizer parameter groups with discriminative learning rates.

    Parameters
    ----------
    model : nn.Module
        Model with `backbone` and `classifier`.
    config : Config
        Hyperparameters containing LR tiers and weight decay.

    Returns
    -------
    list of dict
        Parameter groups consumable by torch optimizer.
    """
    param_groups: List[Dict[str, Any]] = []

    # Classifier / head
    param_groups.append({
        "params": list(model.classifier.parameters()),
        "lr": config.HEAD_LR_FINETUNE,
        "weight_decay": config.WEIGHT_DECAY
    })

    bb = model.backbone
    low = config.BACKBONE_LR_LOW
    mid = config.BACKBONE_LR_MID
    high = config.BACKBONE_LR_HIGH

    if isinstance(bb, models.ResNet) or hasattr(bb, 'layer1'):  # Includes both ResNet and SENet
        tiers = [
            (["conv1", "bn1"], low),
            (["layer1"], low),
            (["layer2"], mid),
            (["layer3"], high),
            (["layer4"], high),
        ]
        for names, lr in tiers:
            params: List[nn.Parameter] = []
            for n in names:
                m = getattr(bb, n)
                params += list(m.parameters())
            param_groups.append({"params": params, "lr": lr, "weight_decay": config.WEIGHT_DECAY})

    elif hasattr(bb, "features"):  # EfficientNet-style
        feat = bb.features
        n = len(feat)
        cut1 = max(1, n // 3)
        cut2 = max(cut1 + 1, (2 * n) // 3)

        early = list(feat[:cut1].parameters())      # low
        middle = list(feat[cut1:cut2].parameters()) # mid
        late = list(feat[cut2:].parameters())       # high

        param_groups += [
            {"params": early,  "lr": low,  "weight_decay": config.WEIGHT_DECAY},
            {"params": middle, "lr": mid,  "weight_decay": config.WEIGHT_DECAY},
            {"params": late,   "lr": high, "weight_decay": config.WEIGHT_DECAY},
        ]
    else:
        param_groups.append({
            "params": [p for p in model.backbone.parameters() if p.requires_grad],
            "lr": mid,
            "weight_decay": config.WEIGHT_DECAY
        })

    return param_groups

def build_optimizer_warmup(model: nn.Module, config: Config) -> optim.Optimizer:
    """Create AdamW optimizer for the head-only warmup phase.

    Parameters
    ----------
    model : nn.Module
        Model with a `classifier` module.
    config : Config
        Hyperparameters.

    Returns
    -------
    torch.optim.Optimizer
        AdamW optimizer over head parameters.
    """
    head_params = [p for p in model.classifier.parameters() if p.requires_grad]
    return optim.AdamW(head_params, lr=config.HEAD_LR_WARMUP, weight_decay=config.WEIGHT_DECAY)


def build_optimizer_finetune(model: nn.Module, config: Config) -> optim.Optimizer:
    """Create AdamW optimizer with discriminative LRs for fine-tuning.

    Parameters
    ----------
    model : nn.Module
        Model with `backbone` and `classifier`.
    config : Config
        Hyperparameters.

    Returns
    -------
    torch.optim.Optimizer
        AdamW optimizer over grouped parameters.
    """
    groups = get_param_groups_discriminative(model, config)
    return optim.AdamW(groups)


class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance.

    Parameters
    ----------
    alpha : Optional[torch.Tensor or float], default=None
        Class weighting factor. Either scalar or tensor of shape [num_classes].
    gamma : float, default=2.0
        Focusing parameter to down-weight easy examples.
    reduction : {'none', 'mean', 'sum'}, default='mean'
        Reduction mode.

    Returns
    -------
    torch.Tensor
        Reduced loss according to `reduction`.
    """

    def __init__(self, alpha: Optional[Union[torch.Tensor, float]] = None,
                 gamma: float = 2.0, reduction: str = 'mean') -> None:
        super().__init__()
        self.alpha = alpha  # tensor of shape [num_classes] or scalar
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute focal loss.

        Parameters
        ----------
        inputs : torch.Tensor
            Logits of shape [N, C].
        targets : torch.Tensor
            Integer class labels of shape [N].

        Returns
        -------
        torch.Tensor
            Loss tensor reduced by `reduction`.
        """
        ce_loss = F.cross_entropy(inputs, targets, weight=self.alpha, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


class TemperatureScaler(nn.Module):
    """Post-hoc calibration via temperature scaling."""

    def __init__(self) -> None:
        super().__init__()
        self.log_temperature = nn.Parameter(torch.zeros(1))

    @property
    def T(self) -> torch.Tensor:
        """Return positive temperature parameter."""
        return self.log_temperature.exp()

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Scale logits by temperature.

        Parameters
        ----------
        logits : torch.Tensor
            Unnormalized model outputs of shape [N, C].

        Returns
        -------
        torch.Tensor
            Scaled logits of shape [N, C].
        """
        return logits / self.T


def _nll_criterion(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood (cross-entropy) for calibration."""
    return F.cross_entropy(logits, targets)


@torch.no_grad()
def collect_logits_and_labels(
    model: nn.Module,
    loader: DataLoader,
    device: Union[str, torch.device]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run model over a DataLoader and collect logits and labels.

    Parameters
    ----------
    model : nn.Module
        Trained classifier.
    loader : DataLoader
        DataLoader yielding dicts with 'image' and 'label'.
    device : str or torch.device
        Computation device.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        - logits: [N, C]
        - labels: [N]
    """
    model.eval()
    all_logits: List[torch.Tensor] = []
    all_labels: List[torch.Tensor] = []
    for batch in tqdm(loader, desc='Collecting logits for calibration'):
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        outputs, _ = model(images)
        all_logits.append(outputs.detach().cpu())
        all_labels.append(labels.detach().cpu())
    logits = torch.cat(all_logits, dim=0)
    labels = torch.cat(all_labels, dim=0)
    return logits, labels


def fit_temperature(
    model: nn.Module,
    val_loader: DataLoader,
    device: Union[str, torch.device],
    max_iter: int = 200,
    lr: float = 0.01,
    verbose: bool = True
) -> TemperatureScaler:
    """Fit a TemperatureScaler on validation data by minimizing NLL.

    Parameters
    ----------
    model : nn.Module
        Classifier to calibrate.
    val_loader : DataLoader
        Validation loader.
    device : str or torch.device
        Device for computation.
    max_iter : int, default=200
        Max iterations for the fallback Adam optimizer.
    lr : float, default=0.01
        LR for fallback Adam.
    verbose : bool, default=True
        Print calibration summary.

    Returns
    -------
    TemperatureScaler
        Fitted temperature scaler.
    """
    logits, labels = collect_logits_and_labels(model, val_loader, device)
    scaler = TemperatureScaler().to(device)
    optimizer = torch.optim.LBFGS(scaler.parameters(), lr=0.25, max_iter=50, line_search_fn='strong_wolfe')

    logits = logits.to(device)
    labels = labels.to(device)

    def closure() -> torch.Tensor:
        optimizer.zero_grad()
        loss = _nll_criterion(scaler(logits), labels)
        loss.backward()
        return loss

    try:
        optimizer.step(closure)
    except Exception:
        opt2 = torch.optim.Adam([scaler.log_temperature], lr=lr)
        for _ in range(max_iter):
            opt2.zero_grad()
            loss = _nll_criterion(scaler(logits), labels)
            loss.backward()
            opt2.step()

    if verbose:
        with torch.no_grad():
            before = _nll_criterion(logits, labels).item()
            after = _nll_criterion(scaler(logits), labels).item()
            print(f"Temperature learned: T={scaler.T.item():.4f} | NLL: {before:.4f} -> {after:.4f}")
    return scaler


def apply_temperature(
    logits: torch.Tensor,
    scaler: Optional[TemperatureScaler],
    device: Union[str, torch.device]
) -> torch.Tensor:
    """Apply a fitted temperature scaler to logits if provided.

    Parameters
    ----------
    logits : torch.Tensor
        Raw logits [N, C].
    scaler : Optional[TemperatureScaler]
        Fitted scaler or None.
    device : str or torch.device
        Device.

    Returns
    -------
    torch.Tensor
        Possibly scaled logits [N, C].
    """
    if scaler is None:
        return logits
    return scaler(logits.to(device))


def compute_ece(probs: ArrayLike, labels: ArrayLike, n_bins: int = 15) -> float:
    """Compute Expected Calibration Error (ECE) for binary classification.

    Parameters
    ----------
    probs : ArrayLike
        Predicted probabilities for the positive class, shape [N].
    labels : ArrayLike
        Binary labels {0, 1}, shape [N].
    n_bins : int, default=15
        Number of confidence bins.

    Returns
    -------
    float
        ECE value in [0, 1].
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    binids = np.digitize(probs, bins[1:-1], right=True)

    ece = 0.0
    for b in range(n_bins):
        mask = binids == b
        if not np.any(mask):
            continue
        conf = probs[mask].mean()
        acc = labels[mask].mean()
        w = mask.mean()
        ece += w * abs(acc - conf)
    return float(ece)


def plot_reliability_diagram(
    labels: ArrayLike,
    probs: ArrayLike,
    save_path: Union[str, os.PathLike],
    n_bins: int = 15
) -> None:
    """Plot and save a reliability diagram.

    Parameters
    ----------
    labels : ArrayLike
        Binary ground-truth labels [N].
    probs : ArrayLike
        Predicted positive-class probabilities [N].
    save_path : str or os.PathLike
        Output image path.
    n_bins : int, default=15
        Number of bins.
    """
    labels = np.asarray(labels)
    probs = np.asarray(probs)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    binids = np.digitize(probs, bins[1:-1], right=True)

    bin_acc: List[float] = []
    bin_conf: List[float] = []
    for b in range(n_bins):
        mask = binids == b
        if not np.any(mask):
            bin_acc.append(0.0)
            bin_conf.append((bins[b] + bins[b + 1]) / 2.0)
        else:
            bin_acc.append(float(labels[mask].mean()))
            bin_conf.append(float(probs[mask].mean()))

    plt.figure(figsize=(6, 6))
    plt.plot([0, 1], [0, 1], linestyle='--', linewidth=2)
    plt.bar(bin_conf, np.array(bin_acc) - np.array(bin_conf),
            width=1.0 / n_bins, bottom=bin_conf, align='center', alpha=0.7)
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.title('Reliability Diagram')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: Union[str, torch.device]
) -> Tuple[float, float]:
    """Train the model for one epoch.

    Parameters
    ----------
    model : nn.Module
        Model to train.
    loader : DataLoader
        Training DataLoader.
    criterion : nn.Module
        Loss function.
    optimizer : torch.optim.Optimizer
        Optimizer instance.
    device : str or torch.device
        Device.

    Returns
    -------
    (float, float)
        Tuple of (average_loss, accuracy) where accuracy is in [0, 1].
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')

    for batch in pbar:
        images = batch['image'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs, _ = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({
            'loss': running_loss / max(1, total),
            'acc': 100 * correct / max(1, total)
        })

    return running_loss / max(1, total), correct / max(1, total)


def validate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: Union[str, torch.device]
) -> Tuple[float, float, float]:
    """Validate the model and compute loss, accuracy, and F1.

    Parameters
    ----------
    model : nn.Module
        Model to evaluate.
    loader : DataLoader
        Validation DataLoader.
    criterion : nn.Module
        Loss function.
    device : str or torch.device
        Device.

    Returns
    -------
    (float, float, float)
        Tuple (avg_loss, accuracy, f1), accuracy and f1 in [0, 1].
    """
    model.eval()
    running_loss = 0.0
    all_labels: List[int] = []
    all_preds: List[int] = []

    with torch.no_grad():
        for batch in tqdm(loader, desc='Validation'):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            outputs, _ = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)

            _, predicted = outputs.max(1)
            all_labels.extend(labels.cpu().numpy().tolist())
            all_preds.extend(predicted.cpu().numpy().tolist())

    avg_loss = running_loss / len(loader.dataset)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    accuracy = (np.array(all_preds) == np.array(all_labels)).mean().item()
    return avg_loss, accuracy, f1


def evaluate_model(
    model: nn.Module,
    loader: DataLoader,
    device: Union[str, torch.device],
    scaler: Optional[TemperatureScaler] = None,
    threshold: Optional[float] = None
) -> Dict[str, Any]:
    """Evaluate model with optional calibration and decision threshold.

    Parameters
    ----------
    model : nn.Module
        Model to evaluate.
    loader : DataLoader
        DataLoader.
    device : str or torch.device
        Device.
    scaler : Optional[TemperatureScaler], default=None
        Temperature scaler for logits.
    threshold : Optional[float], default=None
        If provided, use this threshold on positive-class probability.
        If None, uses argmax over classes.

    Returns
    -------
    dict
        Metrics and raw outputs:
        - 'accuracy', 'precision', 'recall', 'f1', 'roc_auc', 'ece'
        - 'brier', 'nll'
        - 'confusion_matrix' (np.ndarray shape [2,2])
        - 'predictions', 'labels', 'probabilities' (lists)
    """
    model.eval()
    all_preds: List[int] = []
    all_labels: List[int] = []
    all_probs: List[float] = []

    with torch.no_grad():
        for batch in tqdm(loader, desc='Evaluating'):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            logits, _ = model(images)
            if scaler is not None:
                logits = scaler(logits)

            probs = torch.softmax(logits, dim=1)
            if threshold is None:
                _, preds = probs.max(1)
            else:
                preds = (probs[:, 1] >= float(threshold)).long()

            all_preds.extend(preds.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())
            all_probs.extend(probs[:, 1].cpu().numpy().tolist())

    all_preds_np = np.array(all_preds)
    all_labels_np = np.array(all_labels)
    all_probs_np = np.array(all_probs)

    accuracy = accuracy_score(all_labels_np, all_preds_np)
    precision = precision_score(all_labels_np, all_preds_np, zero_division=0)
    recall = recall_score(all_labels_np, all_preds_np, zero_division=0)
    f1 = f1_score(all_labels_np, all_preds_np, zero_division=0)
    roc_auc = roc_auc_score(all_labels_np, all_probs_np)
    cm = confusion_matrix(all_labels_np, all_preds_np)
    ece = compute_ece(all_probs_np, all_labels_np)

    brier = brier_score_loss(all_labels_np, all_probs_np)
    nll = log_loss(all_labels_np, np.vstack([1 - all_probs_np, all_probs_np]).T, labels=[0, 1])

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'ece': ece,
        'brier': brier,
        'nll': nll,
        'confusion_matrix': cm,
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs
    }


def best_threshold_for_f1(labels: ArrayLike, probs: ArrayLike) -> Tuple[float, float]:
    """Compute the probability threshold that maximizes F1 on validation data.

    Parameters
    ----------
    labels : ArrayLike
        Binary labels {0,1}.
    probs : ArrayLike
        Predicted positive-class probabilities.

    Returns
    -------
    (float, float)
        Tuple (threshold, best_f1).
    """
    labels = np.asarray(labels)
    probs = np.asarray(probs)
    prec, rec, thr = precision_recall_curve(labels, probs)
    f1 = 2 * prec[1:] * rec[1:] / (prec[1:] + rec[1:] + 1e-12)
    idx = int(np.argmax(f1))
    return float(thr[idx]), float(f1[idx])


def youden_j_threshold(labels: ArrayLike, probs: ArrayLike) -> float:
    """Compute Youden's J statistic threshold from the ROC curve.

    Parameters
    ----------
    labels : ArrayLike
        Binary labels {0,1}.
    probs : ArrayLike
        Predicted positive-class probabilities.

    Returns
    -------
    float
        Threshold that maximizes TPR - FPR.
    """
    fpr, tpr, thr = roc_curve(labels, probs)
    j = tpr - fpr
    idx = int(np.argmax(j))
    return float(thr[idx])


def plot_training_history(history: Dict[str, List[float]], save_path: Union[str, os.PathLike]) -> None:
    """Plot training/validation loss and accuracy curves.

    Parameters
    ----------
    history : dict
        History dict with keys: 'train_loss', 'val_loss', 'train_acc', 'val_acc'.
    save_path : str or os.PathLike
        Output path for the image file.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    axes[0].plot(history['train_loss'], label='Train Loss')
    axes[0].plot(history['val_loss'], label='Val Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True)

    axes[1].plot(history['train_acc'], label='Train Acc')
    axes[1].plot(history['val_acc'], label='Val Acc')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def plot_confusion_matrix(cm: np.ndarray, save_path: Union[str, os.PathLike]) -> None:
    """Plot and save a confusion matrix heatmap.

    Parameters
    ----------
    cm : np.ndarray
        Confusion matrix of shape [2, 2].
    save_path : str or os.PathLike
        Output image path.
    """
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Benign', 'Melanoma'],
                yticklabels=['Benign', 'Melanoma'])
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def plot_roc_curve(
    labels: ArrayLike,
    probs: ArrayLike,
    save_path: Union[str, os.PathLike]
) -> None:
    """Plot and save the ROC curve.

    Parameters
    ----------
    labels : ArrayLike
        Ground truth binary labels {0, 1}.
    probs : ArrayLike
        Predicted probabilities for the positive class.
    save_path : str or os.PathLike
        Path to save the ROC curve image.
    """
    fpr, tpr, _ = roc_curve(labels, probs)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def main(use_ham10000: bool = True, use_isic: bool = False) -> Tuple[nn.Module, Dict[str, Any], Dict[str, List[float]]]:
    """Main training, calibration, and evaluation pipeline.

    Parameters
    ----------
    use_ham10000 : bool, default=True
        If True, load HAM10000 dataset.
    use_isic : bool, default=False
        If True, load ISIC dataset (requires `load_isic_data`, not provided here).

    Returns
    -------
    (nn.Module, dict, dict)
        - Trained model
        - Metrics dict for calibrated evaluation on test set
        - Training history dict
    """
    config = Config()
    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(config.RESULTS_DIR, exist_ok=True)

    print('=' * 70)
    print("MELANOMA CLASSIFICATION - TRAINING PIPELINE")
    print('=' * 70)
    print(f"Device: {config.DEVICE}")
    print(f"Model: {config.MODEL_NAME}")
    print(f"Image size: {config.IMG_SIZE}")
    print(f"Batch size: {config.BATCH_SIZE}")
    print(f"Learning rate: {config.LEARNING_RATE}")
    print("=" * 70)

    if use_ham10000:
        df = load_ham10000_data(config.HAM10000_BASE)
    elif use_isic:
        # Placeholder: function not defined in this file.
        df = load_isic_data(config.ISIC_BASE)  # type: ignore[name-defined]
        df = df[df['split'] == 'train'].reset_index(drop=True)
    else:
        raise ValueError("Must specify either HAM10000 or ISIC dataset")

    print("\nSplitting dataset...")
    train_df, temp_df = train_test_split(
        df, test_size=config.TEST_SIZE + config.VAL_SIZE,
        random_state=config.RANDOM_STATE, stratify=df['binary_label']
    )
    val_df, test_df = train_test_split(
        temp_df, test_size=config.TEST_SIZE / (config.TEST_SIZE + config.VAL_SIZE),
        random_state=config.RANDOM_STATE, stratify=temp_df['binary_label']
    )
    print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")

    train_dataset = MelanomaDataset(train_df, get_train_transform(config.IMG_SIZE))
    val_dataset = MelanomaDataset(val_df, get_val_transform(config.IMG_SIZE))
    test_dataset = MelanomaDataset(test_df, get_val_transform(config.IMG_SIZE))

    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE,
                              shuffle=True, num_workers=config.NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE,
                            shuffle=False, num_workers=config.NUM_WORKERS)
    test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE,
                             shuffle=False, num_workers=config.NUM_WORKERS)

    class_counts = train_df['binary_label'].value_counts()
    total = len(train_df)
    class_weights = {
        0: total / (2 * class_counts[0]),
        1: total / (2 * class_counts[1]) * 1.0
    }
    weights = torch.FloatTensor([class_weights[0], class_weights[1]]).to(config.DEVICE)

    print(f"\nClass weights: {class_weights}")

    print(f"\nCreating {config.MODEL_NAME} model...")
    model = MelanomaClassifier(config.MODEL_NAME, config.NUM_CLASSES, config.PRETRAINED)
    model = model.to(config.DEVICE)

    criterion: nn.Module = FocalLoss(alpha=weights, gamma=2.0)

    set_backbone_trainable(model, trainable=False)
    optimizer: optim.Optimizer = build_optimizer_warmup(model, config)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    current_phase = "warmup"

    print("\nStarting training...")
    print("=" * 70)

    history: Dict[str, List[float]] = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    best_val_f1 = 0.0
    patience_counter = 0

    for epoch in range(config.NUM_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{config.NUM_EPOCHS}")
        print("-" * 70)

        if current_phase == "warmup" and epoch == config.FREEZE_EPOCHS:
            print("\nUnfreezing backbone and switching to discriminative LRs...")
            set_backbone_trainable(model, trainable=True)
            optimizer = build_optimizer_finetune(model, config)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
            current_phase = "finetune"

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, config.DEVICE)
        val_loss, val_acc, val_f1 = validate(model, val_loader, criterion, config.DEVICE)

        scheduler.step(val_loss)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_f1,
                'val_loss': val_loss,
            }, os.path.join(config.CHECKPOINT_DIR, 'best_model.pth'))
            print(f"✓ Best model saved! Val F1: {val_f1:.4f}")
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= config.EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping at epoch {epoch + 1}")
            break

    plot_training_history(history, os.path.join(config.RESULTS_DIR, 'training_history.png'))

    print("\n" + "=" * 70)
    print("CALIBRATING AND EVALUATING")
    print("=" * 70)

    checkpoint = torch.load(os.path.join(config.CHECKPOINT_DIR, 'best_model.pth'), map_location=config.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(config.DEVICE)

    scaler = fit_temperature(model, val_loader, config.DEVICE, verbose=True)

    print("\nSelecting threshold on VALIDATION (calibrated)...")
    val_metrics_cal = evaluate_model(model, val_loader, config.DEVICE, scaler=scaler, threshold=None)
    t_f1, val_f1_at_t = best_threshold_for_f1(val_metrics_cal['labels'], val_metrics_cal['probabilities'])
    print(f"Chosen threshold for F1: t={t_f1:.4f} (val F1 @ t = {val_f1_at_t:.4f})")

    print("\nEvaluating UNCALIBRATED on test...")
    metrics_raw = evaluate_model(model, test_loader, config.DEVICE, scaler=None, threshold=None)

    print("\nEvaluating CALIBRATED (default argmax) on test...")
    metrics_cal = evaluate_model(model, test_loader, config.DEVICE, scaler=scaler, threshold=None)

    print("\nEvaluating CALIBRATED + THRESHOLD-TUNED on test...")
    metrics_cal_t = evaluate_model(model, test_loader, config.DEVICE, scaler=scaler, threshold=t_f1)

    def _pretty(m: Dict[str, Any]) -> str:
        return (f"Acc {m['accuracy']:.4f} | Prec {m['precision']:.4f} | "
                f"Rec {m['recall']:.4f} | F1 {m['f1']:.4f} | "
                f"AUC {m['roc_auc']:.4f} | ECE {m['ece']:.4f} | "
                f"Brier {m['brier']:.4f} | NLL {m['nll']:.4f}")

    print("\nTest (uncalibrated, argmax):")
    print(_pretty(metrics_raw))
    print("\nTest (calibrated, argmax):")
    print(_pretty(metrics_cal))
    print("\nTest (calibrated, tuned t):")
    print(_pretty(metrics_cal_t))

    plot_reliability_diagram(metrics_raw['labels'], metrics_raw['probabilities'],
                             os.path.join(config.RESULTS_DIR, 'reliability_uncalibrated.png'))
    plot_reliability_diagram(metrics_cal['labels'], metrics_cal['probabilities'],
                             os.path.join(config.RESULTS_DIR, 'reliability_calibrated.png'))

    plot_confusion_matrix(metrics_cal['confusion_matrix'],
                          os.path.join(config.RESULTS_DIR, 'confusion_matrix_calibrated.png'))
    plot_roc_curve(metrics_cal['labels'], metrics_cal['probabilities'],
                   os.path.join(config.RESULTS_DIR, 'roc_curve_calibrated.png'))

    print(f"\nResults saved to: {config.RESULTS_DIR}")
    print(f"Temperature T: {scaler.T.item():.4f}")
    print("Training + calibration complete!")

    return model, metrics_cal, history


if __name__ == "__main__":
    # Run training
    model, metrics, history = main(use_ham10000=True, use_isic=True)



MELANOMA CLASSIFICATION - TRAINING PIPELINE
Device: cuda
Model: senet50
Image size: 224
Batch size: 32
Learning rate: 0.001
Loading HAM10000 dataset....
Loaded 10015 images
Melanoma: 1113
Benign: 8902

Class Distribution:
dx
nv       6705
mel      1113
bkl      1099
bcc       514
akiec     327
vasc      142
df        115
Name: count, dtype: int64

Splitting dataset...
Train: 5007 | Val: 2003 | Test: 3005

Class weights: {0: np.float64(0.5624578746349135), 1: np.float64(4.502697841726619)}

Creating senet50 model...

Starting training...

Epoch 1/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [00:59<00:00,  2.63it/s, loss=0.311, acc=25.2]
Validation: 100%|██████████| 63/63 [00:20<00:00,  3.10it/s]



Train Loss: 0.3108 | Train Acc: 0.2520
Val Loss: 0.2307 | Val Acc: 0.2317
LR: 0.001000
✓ Best model saved! Val F1: 0.2247

Epoch 2/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [00:48<00:00,  3.22it/s, loss=0.284, acc=20.8]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.25it/s]



Train Loss: 0.2841 | Train Acc: 0.2077
Val Loss: 0.2197 | Val Acc: 0.2062
LR: 0.001000

Epoch 3/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [00:48<00:00,  3.24it/s, loss=0.289, acc=21.9]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.26it/s]



Train Loss: 0.2889 | Train Acc: 0.2191
Val Loss: 0.2218 | Val Acc: 0.1628
LR: 0.001000

Epoch 4/250
----------------------------------------------------------------------

Unfreezing backbone and switching to discriminative LRs...


Training: 100%|██████████| 157/157 [01:03<00:00,  2.49it/s, loss=0.237, acc=27.9]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]



Train Loss: 0.2369 | Train Acc: 0.2792
Val Loss: 0.1987 | Val Acc: 0.4019
LR: 0.000100
✓ Best model saved! Val F1: 0.2686

Epoch 5/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.49it/s, loss=0.228, acc=36.9]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.19it/s]



Train Loss: 0.2276 | Train Acc: 0.3691
Val Loss: 0.1908 | Val Acc: 0.4169
LR: 0.000100
✓ Best model saved! Val F1: 0.2727

Epoch 6/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.222, acc=39.9]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.17it/s]



Train Loss: 0.2218 | Train Acc: 0.3992
Val Loss: 0.1853 | Val Acc: 0.4373
LR: 0.000100
✓ Best model saved! Val F1: 0.2790

Epoch 7/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.49it/s, loss=0.213, acc=41.5]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.20it/s]



Train Loss: 0.2134 | Train Acc: 0.4150
Val Loss: 0.1878 | Val Acc: 0.4403
LR: 0.000100
✓ Best model saved! Val F1: 0.2809

Epoch 8/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.49it/s, loss=0.208, acc=42.2]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.00it/s]



Train Loss: 0.2084 | Train Acc: 0.4216
Val Loss: 0.1835 | Val Acc: 0.4853
LR: 0.000100
✓ Best model saved! Val F1: 0.2982

Epoch 9/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.2, acc=43.5]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.00it/s]



Train Loss: 0.2001 | Train Acc: 0.4354
Val Loss: 0.1758 | Val Acc: 0.4433
LR: 0.000100

Epoch 10/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.199, acc=43.7]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]



Train Loss: 0.1989 | Train Acc: 0.4366
Val Loss: 0.1777 | Val Acc: 0.4753
LR: 0.000100

Epoch 11/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.49it/s, loss=0.193, acc=45]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.17it/s]



Train Loss: 0.1934 | Train Acc: 0.4504
Val Loss: 0.1745 | Val Acc: 0.4578
LR: 0.000100

Epoch 12/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.195, acc=46.1]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.17it/s]



Train Loss: 0.1951 | Train Acc: 0.4608
Val Loss: 0.1861 | Val Acc: 0.4893
LR: 0.000100

Epoch 13/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.49it/s, loss=0.194, acc=46.9]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.22it/s]



Train Loss: 0.1944 | Train Acc: 0.4691
Val Loss: 0.1722 | Val Acc: 0.5127
LR: 0.000100
✓ Best model saved! Val F1: 0.3088

Epoch 14/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.49it/s, loss=0.184, acc=48.3]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.09it/s]



Train Loss: 0.1836 | Train Acc: 0.4827
Val Loss: 0.1636 | Val Acc: 0.5607
LR: 0.000100
✓ Best model saved! Val F1: 0.3313

Epoch 15/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:04<00:00,  2.44it/s, loss=0.189, acc=49.5]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]



Train Loss: 0.1888 | Train Acc: 0.4953
Val Loss: 0.1667 | Val Acc: 0.5322
LR: 0.000100

Epoch 16/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.47it/s, loss=0.182, acc=49.5]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.15it/s]



Train Loss: 0.1824 | Train Acc: 0.4951
Val Loss: 0.1677 | Val Acc: 0.5182
LR: 0.000100

Epoch 17/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:04<00:00,  2.43it/s, loss=0.181, acc=50.6]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]



Train Loss: 0.1807 | Train Acc: 0.5065
Val Loss: 0.1665 | Val Acc: 0.5282
LR: 0.000100

Epoch 18/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.45it/s, loss=0.183, acc=51.2]
Validation: 100%|██████████| 63/63 [00:15<00:00,  3.98it/s]



Train Loss: 0.1830 | Train Acc: 0.5123
Val Loss: 0.1661 | Val Acc: 0.5242
LR: 0.000100

Epoch 19/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.47it/s, loss=0.182, acc=50.4]
Validation: 100%|██████████| 63/63 [00:16<00:00,  3.88it/s]



Train Loss: 0.1818 | Train Acc: 0.5045
Val Loss: 0.1649 | Val Acc: 0.5776
LR: 0.000100
✓ Best model saved! Val F1: 0.3370

Epoch 20/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.179, acc=53.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.12it/s]



Train Loss: 0.1791 | Train Acc: 0.5345
Val Loss: 0.1669 | Val Acc: 0.5237
LR: 0.000050

Epoch 21/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.174, acc=52.5]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.13it/s]



Train Loss: 0.1739 | Train Acc: 0.5249
Val Loss: 0.1648 | Val Acc: 0.5402
LR: 0.000050

Epoch 22/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.46it/s, loss=0.174, acc=53.5]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.14it/s]



Train Loss: 0.1742 | Train Acc: 0.5349
Val Loss: 0.1597 | Val Acc: 0.5826
LR: 0.000050
✓ Best model saved! Val F1: 0.3397

Epoch 23/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.47it/s, loss=0.166, acc=54.2]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.15it/s]



Train Loss: 0.1658 | Train Acc: 0.5418
Val Loss: 0.1639 | Val Acc: 0.5931
LR: 0.000050
✓ Best model saved! Val F1: 0.3454

Epoch 24/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.49it/s, loss=0.167, acc=54.3]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]



Train Loss: 0.1674 | Train Acc: 0.5430
Val Loss: 0.1653 | Val Acc: 0.6111
LR: 0.000050
✓ Best model saved! Val F1: 0.3525

Epoch 25/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.172, acc=55.7]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]



Train Loss: 0.1716 | Train Acc: 0.5570
Val Loss: 0.1684 | Val Acc: 0.5956
LR: 0.000050

Epoch 26/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.167, acc=53.4]
Validation: 100%|██████████| 63/63 [00:16<00:00,  3.91it/s]



Train Loss: 0.1673 | Train Acc: 0.5337
Val Loss: 0.1654 | Val Acc: 0.5297
LR: 0.000050

Epoch 27/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.49it/s, loss=0.165, acc=55.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]



Train Loss: 0.1649 | Train Acc: 0.5544
Val Loss: 0.1700 | Val Acc: 0.5577
LR: 0.000050

Epoch 28/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.49it/s, loss=0.169, acc=54.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.17it/s]



Train Loss: 0.1692 | Train Acc: 0.5440
Val Loss: 0.1702 | Val Acc: 0.6056
LR: 0.000025
✓ Best model saved! Val F1: 0.3535

Epoch 29/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.50it/s, loss=0.169, acc=54.3]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.14it/s]



Train Loss: 0.1688 | Train Acc: 0.5426
Val Loss: 0.1585 | Val Acc: 0.6016
LR: 0.000025

Epoch 30/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.47it/s, loss=0.166, acc=55.6]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]



Train Loss: 0.1658 | Train Acc: 0.5564
Val Loss: 0.1612 | Val Acc: 0.5921
LR: 0.000025

Epoch 31/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.47it/s, loss=0.165, acc=54.9]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]



Train Loss: 0.1645 | Train Acc: 0.5492
Val Loss: 0.1606 | Val Acc: 0.5691
LR: 0.000025

Epoch 32/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.49it/s, loss=0.164, acc=55]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.12it/s]



Train Loss: 0.1636 | Train Acc: 0.5502
Val Loss: 0.1650 | Val Acc: 0.5781
LR: 0.000025

Epoch 33/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.48it/s, loss=0.162, acc=56]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.14it/s]



Train Loss: 0.1621 | Train Acc: 0.5596
Val Loss: 0.1615 | Val Acc: 0.5916
LR: 0.000025

Epoch 34/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.47it/s, loss=0.164, acc=56.3]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.15it/s]



Train Loss: 0.1638 | Train Acc: 0.5634
Val Loss: 0.1664 | Val Acc: 0.5647
LR: 0.000025

Epoch 35/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:03<00:00,  2.46it/s, loss=0.166, acc=54.9]
Validation: 100%|██████████| 63/63 [00:15<00:00,  3.96it/s]



Train Loss: 0.1659 | Train Acc: 0.5486
Val Loss: 0.1642 | Val Acc: 0.6116
LR: 0.000013

Epoch 36/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.53it/s, loss=0.164, acc=57.2]
Validation: 100%|██████████| 63/63 [00:16<00:00,  3.93it/s]



Train Loss: 0.1644 | Train Acc: 0.5720
Val Loss: 0.1637 | Val Acc: 0.5991
LR: 0.000013

Epoch 37/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.52it/s, loss=0.16, acc=56.8]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.06it/s]



Train Loss: 0.1604 | Train Acc: 0.5684
Val Loss: 0.1636 | Val Acc: 0.5801
LR: 0.000013

Epoch 38/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.54it/s, loss=0.154, acc=57.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.03it/s]



Train Loss: 0.1544 | Train Acc: 0.5744
Val Loss: 0.1688 | Val Acc: 0.6316
LR: 0.000013
✓ Best model saved! Val F1: 0.3616

Epoch 39/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.51it/s, loss=0.158, acc=57.2]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.21it/s]



Train Loss: 0.1581 | Train Acc: 0.5718
Val Loss: 0.1679 | Val Acc: 0.6111
LR: 0.000013

Epoch 40/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.49it/s, loss=0.158, acc=58.1]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.20it/s]



Train Loss: 0.1582 | Train Acc: 0.5806
Val Loss: 0.1633 | Val Acc: 0.6051
LR: 0.000013

Epoch 41/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.55it/s, loss=0.161, acc=57.7]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.18it/s]



Train Loss: 0.1615 | Train Acc: 0.5772
Val Loss: 0.1637 | Val Acc: 0.6026
LR: 0.000006

Epoch 42/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.54it/s, loss=0.156, acc=57.8]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.24it/s]



Train Loss: 0.1565 | Train Acc: 0.5782
Val Loss: 0.1652 | Val Acc: 0.6016
LR: 0.000006

Epoch 43/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.52it/s, loss=0.154, acc=57.6]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.17it/s]



Train Loss: 0.1539 | Train Acc: 0.5762
Val Loss: 0.1658 | Val Acc: 0.6216
LR: 0.000006

Epoch 44/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.51it/s, loss=0.154, acc=57.1]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.22it/s]



Train Loss: 0.1536 | Train Acc: 0.5706
Val Loss: 0.1670 | Val Acc: 0.6355
LR: 0.000006
✓ Best model saved! Val F1: 0.3652

Epoch 45/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.53it/s, loss=0.155, acc=57.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.08it/s]



Train Loss: 0.1547 | Train Acc: 0.5738
Val Loss: 0.1665 | Val Acc: 0.6321
LR: 0.000006

Epoch 46/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.53it/s, loss=0.16, acc=57.3]
Validation: 100%|██████████| 63/63 [00:15<00:00,  3.96it/s]



Train Loss: 0.1601 | Train Acc: 0.5732
Val Loss: 0.1659 | Val Acc: 0.6256
LR: 0.000006

Epoch 47/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.60it/s, loss=0.158, acc=57.5]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.17it/s]



Train Loss: 0.1575 | Train Acc: 0.5750
Val Loss: 0.1633 | Val Acc: 0.6121
LR: 0.000003

Epoch 48/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.57it/s, loss=0.159, acc=58.2]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.21it/s]



Train Loss: 0.1586 | Train Acc: 0.5822
Val Loss: 0.1649 | Val Acc: 0.6146
LR: 0.000003

Epoch 49/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.55it/s, loss=0.157, acc=57.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.13it/s]



Train Loss: 0.1572 | Train Acc: 0.5736
Val Loss: 0.1637 | Val Acc: 0.6066
LR: 0.000003

Epoch 50/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.56it/s, loss=0.156, acc=57.2]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.16it/s]



Train Loss: 0.1558 | Train Acc: 0.5724
Val Loss: 0.1643 | Val Acc: 0.6071
LR: 0.000003

Epoch 51/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.59it/s, loss=0.159, acc=58.1]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.19it/s]



Train Loss: 0.1587 | Train Acc: 0.5814
Val Loss: 0.1697 | Val Acc: 0.6460
LR: 0.000003
✓ Best model saved! Val F1: 0.3698

Epoch 52/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.58it/s, loss=0.154, acc=57.8]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.19it/s]



Train Loss: 0.1539 | Train Acc: 0.5780
Val Loss: 0.1634 | Val Acc: 0.6141
LR: 0.000003

Epoch 53/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.55it/s, loss=0.154, acc=57.8]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.22it/s]



Train Loss: 0.1537 | Train Acc: 0.5782
Val Loss: 0.1635 | Val Acc: 0.6066
LR: 0.000002

Epoch 54/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.54it/s, loss=0.157, acc=58]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.24it/s]



Train Loss: 0.1570 | Train Acc: 0.5798
Val Loss: 0.1637 | Val Acc: 0.6146
LR: 0.000002

Epoch 55/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.56it/s, loss=0.16, acc=58.1]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.25it/s]



Train Loss: 0.1603 | Train Acc: 0.5808
Val Loss: 0.1639 | Val Acc: 0.6141
LR: 0.000002

Epoch 56/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.56it/s, loss=0.152, acc=58.4]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.28it/s]



Train Loss: 0.1519 | Train Acc: 0.5840
Val Loss: 0.1637 | Val Acc: 0.6056
LR: 0.000002

Epoch 57/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.58it/s, loss=0.152, acc=58.3]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.20it/s]



Train Loss: 0.1516 | Train Acc: 0.5828
Val Loss: 0.1620 | Val Acc: 0.6031
LR: 0.000002

Epoch 58/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.54it/s, loss=0.151, acc=58.2]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.22it/s]



Train Loss: 0.1506 | Train Acc: 0.5824
Val Loss: 0.1656 | Val Acc: 0.6181
LR: 0.000002

Epoch 59/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.57it/s, loss=0.153, acc=57.8]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.20it/s]



Train Loss: 0.1535 | Train Acc: 0.5784
Val Loss: 0.1659 | Val Acc: 0.6201
LR: 0.000001

Epoch 60/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.57it/s, loss=0.152, acc=58.2]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.24it/s]



Train Loss: 0.1516 | Train Acc: 0.5816
Val Loss: 0.1650 | Val Acc: 0.6186
LR: 0.000001

Epoch 61/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.57it/s, loss=0.156, acc=58.4]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.22it/s]



Train Loss: 0.1558 | Train Acc: 0.5840
Val Loss: 0.1651 | Val Acc: 0.6206
LR: 0.000001

Epoch 62/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.56it/s, loss=0.155, acc=58.8]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.23it/s]



Train Loss: 0.1546 | Train Acc: 0.5884
Val Loss: 0.1656 | Val Acc: 0.6266
LR: 0.000001

Epoch 63/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.59it/s, loss=0.151, acc=58.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.08it/s]



Train Loss: 0.1508 | Train Acc: 0.5836
Val Loss: 0.1634 | Val Acc: 0.6171
LR: 0.000001

Epoch 64/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.55it/s, loss=0.156, acc=58.3]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.14it/s]



Train Loss: 0.1556 | Train Acc: 0.5826
Val Loss: 0.1628 | Val Acc: 0.6241
LR: 0.000001

Epoch 65/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.59it/s, loss=0.15, acc=58.9]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.12it/s]



Train Loss: 0.1505 | Train Acc: 0.5886
Val Loss: 0.1671 | Val Acc: 0.6316
LR: 0.000000

Epoch 66/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.60it/s, loss=0.155, acc=58.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.00it/s]



Train Loss: 0.1549 | Train Acc: 0.5842
Val Loss: 0.1658 | Val Acc: 0.6231
LR: 0.000000

Epoch 67/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.60it/s, loss=0.159, acc=58.3]
Validation: 100%|██████████| 63/63 [00:15<00:00,  3.98it/s]



Train Loss: 0.1594 | Train Acc: 0.5828
Val Loss: 0.1665 | Val Acc: 0.6226
LR: 0.000000

Epoch 68/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.59it/s, loss=0.157, acc=58.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.05it/s]



Train Loss: 0.1570 | Train Acc: 0.5842
Val Loss: 0.1649 | Val Acc: 0.6241
LR: 0.000000

Epoch 69/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.59it/s, loss=0.157, acc=58.9]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.05it/s]



Train Loss: 0.1567 | Train Acc: 0.5886
Val Loss: 0.1649 | Val Acc: 0.6181
LR: 0.000000

Epoch 70/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.59it/s, loss=0.148, acc=58.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.09it/s]



Train Loss: 0.1483 | Train Acc: 0.5842
Val Loss: 0.1628 | Val Acc: 0.6026
LR: 0.000000

Epoch 71/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:00<00:00,  2.61it/s, loss=0.157, acc=58.7]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.26it/s]



Train Loss: 0.1567 | Train Acc: 0.5866
Val Loss: 0.1661 | Val Acc: 0.6256
LR: 0.000000

Epoch 72/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.56it/s, loss=0.153, acc=58.6]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.29it/s]



Train Loss: 0.1533 | Train Acc: 0.5862
Val Loss: 0.1668 | Val Acc: 0.6400
LR: 0.000000

Epoch 73/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.56it/s, loss=0.152, acc=58.2]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.26it/s]



Train Loss: 0.1524 | Train Acc: 0.5820
Val Loss: 0.1656 | Val Acc: 0.6306
LR: 0.000000

Epoch 74/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.56it/s, loss=0.152, acc=58.5]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.20it/s]



Train Loss: 0.1522 | Train Acc: 0.5852
Val Loss: 0.1642 | Val Acc: 0.6196
LR: 0.000000

Epoch 75/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.52it/s, loss=0.147, acc=59]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.25it/s]



Train Loss: 0.1466 | Train Acc: 0.5904
Val Loss: 0.1640 | Val Acc: 0.6176
LR: 0.000000

Epoch 76/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.54it/s, loss=0.149, acc=59.2]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.20it/s]



Train Loss: 0.1487 | Train Acc: 0.5922
Val Loss: 0.1638 | Val Acc: 0.6116
LR: 0.000000

Epoch 77/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:02<00:00,  2.53it/s, loss=0.157, acc=58.4]
Validation: 100%|██████████| 63/63 [00:15<00:00,  4.20it/s]



Train Loss: 0.1570 | Train Acc: 0.5838
Val Loss: 0.1652 | Val Acc: 0.6156
LR: 0.000000

Epoch 78/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.54it/s, loss=0.156, acc=58.2]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.21it/s]



Train Loss: 0.1559 | Train Acc: 0.5824
Val Loss: 0.1631 | Val Acc: 0.6071
LR: 0.000000

Epoch 79/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.56it/s, loss=0.165, acc=58.3]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.24it/s]



Train Loss: 0.1648 | Train Acc: 0.5830
Val Loss: 0.1637 | Val Acc: 0.6116
LR: 0.000000

Epoch 80/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.57it/s, loss=0.149, acc=58.8]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.22it/s]



Train Loss: 0.1489 | Train Acc: 0.5876
Val Loss: 0.1673 | Val Acc: 0.6266
LR: 0.000000

Epoch 81/250
----------------------------------------------------------------------


Training: 100%|██████████| 157/157 [01:01<00:00,  2.55it/s, loss=0.155, acc=58.8]
Validation: 100%|██████████| 63/63 [00:14<00:00,  4.26it/s]



Train Loss: 0.1554 | Train Acc: 0.5882
Val Loss: 0.1660 | Val Acc: 0.6326
LR: 0.000000

Early stopping at epoch 81

CALIBRATING AND EVALUATING


Collecting logits for calibration: 100%|██████████| 63/63 [00:15<00:00,  4.20it/s]


Temperature learned: T=1.2216 | NLL: 0.6041 -> 0.6005

Selecting threshold on VALIDATION (calibrated)...


Evaluating: 100%|██████████| 63/63 [00:14<00:00,  4.28it/s]


Chosen threshold for F1: t=0.6979 (val F1 @ t = 0.5027)

Evaluating UNCALIBRATED on test...


Evaluating: 100%|██████████| 94/94 [00:32<00:00,  2.85it/s]



Evaluating CALIBRATED (default argmax) on test...


Evaluating: 100%|██████████| 94/94 [00:22<00:00,  4.09it/s]



Evaluating CALIBRATED + THRESHOLD-TUNED on test...


Evaluating: 100%|██████████| 94/94 [00:22<00:00,  4.12it/s]



Test (uncalibrated, argmax):
Acc 0.6556 | Prec 0.2386 | Rec 0.9581 | F1 0.3821 | AUC 0.8898 | ECE 0.3358 | Brier 0.2038 | NLL 0.5899

Test (calibrated, argmax):
Acc 0.6556 | Prec 0.2386 | Rec 0.9581 | F1 0.3821 | AUC 0.8898 | ECE 0.3426 | Brier 0.2022 | NLL 0.5885

Test (calibrated, tuned t):
Acc 0.8682 | Prec 0.4370 | Rec 0.6437 | F1 0.5206 | AUC 0.8898 | ECE 0.3426 | Brier 0.2022 | NLL 0.5885

Results saved to: ./output/results
Temperature T: 1.2216
Training + calibration complete!
