# MILK10k Skin Lesion Classification (Standalone)

This notebook trains the **Tone-Aware Multi-Scale Vision Transformer (TAM-ViT)** on the MILK10k dataset.
It is self-contained and includes all necessary model and data classes.

## 🚀 Setup

1.  **Enable GPU**: Go to `Runtime` -> `Change runtime type` -> `T4 GPU` (or better).
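2.  **Upload Data**: Upload your `milk10k` dataset folder either to Google Drive (then mount it) or directly to the Colab runtime, and update the data paths in the configuration cell if they differ from the defaults.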
    - Expected structure:
        ```
        /content/data/milk10k/
        ├── train/
        │   ├── lesion1_clin.jpg
        │   ├── lesion1_derm.jpg
        │   └── ...
        ├── val/
        │   └── ...
        ├── train.csv
        └── val.csv
        ```

In [None]:
# 1. Install Dependencies
!pip install torch torchvision timm albumentations pandas numpy omegaconf pytorch-lightning wandb

In [None]:
# 2. Imports
import warnings
# Suppress non-critical warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import math
import sys
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple, Callable
from datetime import datetime

import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR

import albumentations as A
from albumentations.pytorch import ToTensorV2
from einops import rearrange, repeat
# Updated import for newer timm versions to suppress warnings
try:
    from timm.layers import DropPath, trunc_normal_
except ImportError:
    from timm.models.layers import DropPath, trunc_normal_

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, RichProgressBar
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
from omegaconf import OmegaConf

print("Libraries imported successfully!")

## 3. Data Classes

In [None]:
# =============================================================================
# Transforms
# =============================================================================
def get_train_transforms(img_size: int = 224) -> A.Compose:
    """Build the stochastic training augmentation pipeline for one RGB image.

    The pipeline applies these steps in order:
    - a random resized crop (80-100% of the area, near-square aspect ratio);
    - horizontal/vertical flips and 90-degree rotations;
    - mild colour jitter;
    - occasional noise or blur;
    - cutout-style coarse dropout;
    - normalisation with the standard ImageNet channel mean/std;
    - conversion to a CHW tensor.

    Albumentations 2.x renamed the arguments of three of these transforms:
    - ``RandomResizedCrop`` takes ``size=(h, w)``;
    - ``GaussNoise`` takes ``std_range`` instead of ``var_limit``;
    - ``CoarseDropout`` takes ``*_range`` tuples instead of
      ``max_holes``/``max_height``/``max_width``.

    Under 2.x the old positional crop call fails and the old noise/dropout
    keywords are not valid arguments. This function therefore checks the
    installed major version and passes equivalent arguments for each API.

    Args:
        img_size: side length of the square output crop.

    Returns:
        An ``A.Compose`` whose ``(image=...)['image']`` output is a float CHW tensor.
    """
    if int(A.__version__.split('.')[0]) >= 2:
        crop = A.RandomResizedCrop(size=(img_size, img_size), scale=(0.8, 1.0), ratio=(0.9, 1.1))
        # 1.x var_limit=(10, 50) is a pixel *variance* range on the 0-255 scale.
        # 2.x expects the noise *std* as a fraction of the max value (255 for uint8).
        noise = A.GaussNoise(std_range=(math.sqrt(10) / 255.0, math.sqrt(50) / 255.0))
        # 1.x defaults min_holes/min_height/min_width to their max_* values, i.e.
        # exactly 8 holes of 16x16 px. Integer ranges are pixel sizes in 2.x.
        dropout = A.CoarseDropout(
            num_holes_range=(8, 8),
            hole_height_range=(16, 16),
            hole_width_range=(16, 16),
            p=0.2,
        )
    else:
        crop = A.RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0), ratio=(0.9, 1.1))
        noise = A.GaussNoise(var_limit=(10, 50))
        dropout = A.CoarseDropout(max_holes=8, max_height=16, max_width=16, p=0.2)

    return A.Compose([
        crop,
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.OneOf([
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20),
        ], p=0.5),
        A.OneOf([
            noise,
            A.GaussianBlur(blur_limit=(3, 5)),
        ], p=0.3),
        dropout,
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

def get_val_transforms(img_size: int = 224) -> A.Compose:
    """Deterministic evaluation pipeline: enlarge, centre-crop, normalise, tensorise.

    The image is resized to ``int(img_size * 1.14)`` on both sides and then
    centre-cropped back to ``img_size``, trimming a thin border. Normalisation
    uses the standard ImageNet channel mean/std, matching the training pipeline.
    """
    enlarged = int(img_size * 1.14)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        A.Resize(enlarged, enlarged),
        A.CenterCrop(img_size, img_size),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# =============================================================================
# MILK10k Dataset
# =============================================================================
class MILK10kDataset(Dataset):
    """MILK10k lesions as paired clinical + dermoscopic images stacked into one tensor.

    Each CSV row describes one lesion.

    Image files: names come from the ``clinical_image_name`` /
    ``dermoscopic_image_name`` columns when those cells hold a non-empty string.
    Otherwise ``{lesion_id}_clin.jpg`` / ``{lesion_id}_derm.jpg`` is used. A
    path that does not exist is retried with a ``.png`` suffix. If either image
    still cannot be opened, both are replaced by black 224x224 images and a
    warning is printed.

    Label: the ``CLASS_NAMES`` index of the first column in ``LABEL_COLUMNS``
    whose value is a known class name. It is -1 if there is none, and for every
    row when ``phase == 'test'``.

    Item keys:
    - ``image``: clinical channels first, then dermoscopic. (6, H, W) with the
      RGB transforms.
    - ``label``: long.
    - ``fitzpatrick``: long, from ``skin_tone``; -1 when the column is absent or NaN.
    - ``lesion_id``.
    """

    CLASS_NAMES = [
        'AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF', 'INF',
        'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC'
    ]
    # Columns searched, in order, for the diagnosis string.
    LABEL_COLUMNS = ('diagnosis', 'pathology', 'dx')

    def __init__(
        self,
        root_dir: str,
        csv_file: str,
        transform: Optional[A.Compose] = None,
        phase: str = 'train',
    ):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.phase = phase
        self.df = pd.read_csv(csv_file)
        self.label_map = {name: idx for idx, name in enumerate(self.CLASS_NAMES)}

    def __len__(self) -> int:
        return len(self.df)

    def get_labels(self) -> np.ndarray:
        """Return every row's integer label (see class doc) without decoding any image."""
        return np.array([self._label_for(row) for _, row in self.df.iterrows()], dtype=np.int64)

    def _label_for(self, row: pd.Series) -> int:
        """Map one CSV row to its class index, or -1 when unknown / in test phase."""
        if self.phase == 'test':
            return -1
        for col in self.LABEL_COLUMNS:
            if col in row and row[col] in self.label_map:
                return self.label_map[row[col]]
        return -1

    def _image_path(self, row: pd.Series, column: str, fallback_name: str) -> Path:
        """Resolve an image path from ``column`` (or ``fallback_name``), trying ``.png`` if absent."""
        name = row.get(column)
        # Missing column or a blank/NaN cell: fall back to the lesion_id naming scheme.
        if not isinstance(name, str) or not name.strip():
            name = fallback_name
        path = self.root_dir / name
        return path if path.exists() else path.with_suffix('.png')

    @staticmethod
    def _load_rgb(path: Path) -> np.ndarray:
        """Read an image file as an HWC uint8 RGB array, closing the file handle."""
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    @staticmethod
    def _to_chw_tensor(img) -> torch.Tensor:
        """Pass tensors through; turn an untransformed HWC uint8 array into CHW float in [0, 1]."""
        if isinstance(img, torch.Tensor):
            return img
        return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.df.iloc[idx]
        lesion_id = row['lesion_id']

        clinical_path = self._image_path(row, 'clinical_image_name', f"{lesion_id}_clin.jpg")
        dermoscopic_path = self._image_path(row, 'dermoscopic_image_name', f"{lesion_id}_derm.jpg")

        try:
            img_clin = self._load_rgb(clinical_path)
            img_derm = self._load_rgb(dermoscopic_path)
        except FileNotFoundError:
            img_clin = np.zeros((224, 224, 3), dtype=np.uint8)
            img_derm = np.zeros((224, 224, 3), dtype=np.uint8)
            print(f"Warning: Missing images for {lesion_id}")

        if self.transform is not None:
            img_clin = self.transform(image=img_clin)['image']
            img_derm = self.transform(image=img_derm)['image']

        # Stack along channels. Without a resizing transform, both images must
        # already share the same H x W for this concat to succeed.
        image_stacked = torch.cat([self._to_chw_tensor(img_clin), self._to_chw_tensor(img_derm)], dim=0)

        # int(NaN) raises, so a blank skin_tone cell is mapped to -1 like a missing column.
        skin_tone = row.get('skin_tone', -1)
        fitzpatrick = -1 if pd.isna(skin_tone) else int(skin_tone)

        return {
            'image': image_stacked,
            'label': torch.tensor(self._label_for(row), dtype=torch.long),
            'fitzpatrick': torch.tensor(fitzpatrick, dtype=torch.long),
            'lesion_id': lesion_id,
        }

def create_weighted_sampler(dataset: Dataset) -> WeightedRandomSampler:
    """Build a sampler that draws every class with roughly equal probability.

    Each labelled sample is weighted by ``1 / count(its class)``. Samples with a
    negative label (MILK10kDataset uses -1 for an unknown label) get weight 0
    and are never drawn. Sampling is with replacement, ``len(dataset)`` draws
    per epoch.

    Labels are taken from ``dataset.get_labels()`` when the dataset provides it,
    which avoids decoding and augmenting every image just to read its label.
    Otherwise every item is fetched and its ``'label'`` entry read; the entry
    may be a 0-d tensor or a plain int.

    Raises:
        ValueError: if no sample has a non-negative label (including an empty dataset).
    """
    if hasattr(dataset, 'get_labels'):
        labels = np.asarray(dataset.get_labels(), dtype=np.int64)
    else:
        labels = np.array([int(dataset[i]['label']) for i in range(len(dataset))], dtype=np.int64)

    labelled = labels >= 0
    if not labelled.any():
        raise ValueError("create_weighted_sampler: no sample has a valid (non-negative) label")

    # np.bincount rejects negative values, so count only the labelled samples.
    class_counts = np.bincount(labels[labelled])
    class_weights = 1.0 / (class_counts + 1e-8)
    sample_weights = np.zeros(len(labels), dtype=np.float64)
    sample_weights[labelled] = class_weights[labels[labelled]]
    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

def create_dataloaders(train_dataset, val_dataset, batch_size=32, num_workers=4):
    """Build the training and validation DataLoaders.

    Training batches are drawn through ``create_weighted_sampler`` (so no
    ``shuffle`` flag is passed). Validation batches keep dataset order.

    Returns:
        dict with keys ``'train'`` and ``'val'`` mapping to the two loaders.
    """
    balanced_sampler = create_weighted_sampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=balanced_sampler,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return {'train': train_loader, 'val': val_loader}

## 4. Model Classes (TAM-ViT)

In [None]:
# =============================================================================
# TAM-ViT Modules
# =============================================================================
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each as a token.

    A convolution with kernel = stride = ``patch_size`` projects every patch to
    ``embed_dim`` channels. The patch grid is flattened row-major into a token
    sequence, which is then LayerNorm-ed.

    Input ``(B, in_chans, H, W)`` becomes
    ``(B, (H // patch_size) * (W // patch_size), embed_dim)``.
    ``num_patches`` assumes a square ``img_size x img_size`` input.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        grid_side = img_size // patch_size
        self.num_patches = grid_side * grid_side
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # (B, C, h, w) -> (B, C, h*w) -> (B, h*w, C): same token order as 'b c h w -> b (h w) c'.
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(tokens)

class SkinToneEstimator(nn.Module):
    """Small CNN that predicts a distribution over ``num_tones`` skin-tone classes.

    Architecture:
    - three stride-2 Conv/BatchNorm/ReLU stages (3 -> 32 -> 64 -> 128 channels);
    - global average pooling;
    - a two-layer MLP head with dropout 0.2.

    The input must have 3 channels.

    Returns:
        ``(tone_logits, tone_probs)``, both ``(B, num_tones)``. The probs are the
        softmax of the logits over the last dimension.
    """

    def __init__(self, num_tones=6, hidden_dim=256):
        super().__init__()
        stages = []
        channels_in = 3
        for channels_out in (32, 64, 128):
            stages += [
                nn.Conv2d(channels_in, channels_out, 3, stride=2, padding=1),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(inplace=True),
            ]
            channels_in = channels_out
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_tones),
        )

    def forward(self, x):
        pooled = torch.flatten(self.features(x), start_dim=1)
        logits = self.classifier(pooled)
        return logits, logits.softmax(dim=-1)

class ToneAdaptiveLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a tone embedding.

    out = (1 + gamma(tone_embed)) * LayerNorm(x) + beta(tone_embed)

    The final Linear of both projection MLPs is zero-initialised, so a freshly
    constructed module behaves exactly like a plain LayerNorm. gamma/beta are
    broadcast over dim 1 via ``unsqueeze(1)``. Presumably x is (B, N, dim) and
    tone_embed is (B, tone_dim); confirm against callers.
    """

    def __init__(self, dim, tone_dim=768):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.gamma_proj = self._make_projection(tone_dim, dim)
        self.beta_proj = self._make_projection(tone_dim, dim)
        for projection in (self.gamma_proj, self.beta_proj):
            nn.init.zeros_(projection[-1].weight)
            nn.init.zeros_(projection[-1].bias)

    @staticmethod
    def _make_projection(tone_dim, dim):
        """tone_dim -> dim // 2 -> dim MLP used for both gamma and beta."""
        return nn.Sequential(
            nn.Linear(tone_dim, dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(dim // 2, dim),
        )

    def forward(self, x, tone_embed):
        normed = self.norm(x)
        scale = self.gamma_proj(tone_embed).unsqueeze(1) + 1
        shift = self.beta_proj(tone_embed).unsqueeze(1)
        return scale * normed + shift

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multi-head self-attention that also returns its weights.

    ``dim`` is split into ``num_heads`` heads of ``dim // num_heads`` channels.
    The code assumes ``dim`` is divisible by ``num_heads``.

    Returns:
        ``(out, attn)``:
        - ``out`` has x's shape ``(B, N, dim)``;
        - ``attn`` is ``(B, num_heads, N, N)`` and is taken *after* attention dropout.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim); iterating dim 0 yields q, k, v.
        q, k, v = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        context = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(context)), weights

class ToneModulatedMLP(nn.Module):
    """Two-layer GELU MLP whose output is gated channel-wise by the tone embedding.

    out = drop(fc2(drop(gelu(fc1(x)))) * sigmoid_gate(tone_embed))

    The gate MLP maps ``tone_dim -> hidden_features // 2 -> out_features`` and
    ends in a Sigmoid. Its output is broadcast over dim 1 via ``unsqueeze(1)``.
    Defaults: ``hidden_features = 4 * in_features`` and
    ``out_features = in_features``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, tone_dim=768, drop=0.0):
        super().__init__()
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features * 4
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)
        self.gate = nn.Sequential(
            nn.Linear(tone_dim, width_hidden // 2),
            nn.ReLU(inplace=True),
            nn.Linear(width_hidden // 2, width_out),
            nn.Sigmoid(),
        )

    def forward(self, x, tone_embed):
        hidden = self.drop(self.act(self.fc1(x)))
        gate = self.gate(tone_embed).unsqueeze(1)
        return self.drop(self.fc2(hidden) * gate)

class ToneConditionedBlock(nn.Module):
    """Pre-norm transformer block with tone-conditioned norms and MLP.

    x = x + DropPath(Attn(TALN1(x, tone)))
    x = x + DropPath(MLP(TALN2(x, tone), tone))

    Returns the updated tokens together with this block's attention weights.
    DropPath is used only when ``drop_path > 0``; otherwise it is an identity.
    """

    def __init__(self, dim, num_heads, tone_dim=768, mlp_ratio=4.0, qkv_bias=True, drop=0.0, attn_drop=0.0, drop_path=0.0):
        super().__init__()
        mlp_hidden = int(dim * mlp_ratio)
        self.norm1 = ToneAdaptiveLayerNorm(dim, tone_dim)
        self.attn = MultiHeadSelfAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
        )
        self.norm2 = ToneAdaptiveLayerNorm(dim, tone_dim)
        self.mlp = ToneModulatedMLP(
            in_features=dim, hidden_features=mlp_hidden, tone_dim=tone_dim, drop=drop,
        )
        self.drop_path = nn.Identity() if not drop_path > 0.0 else DropPath(drop_path)

    def forward(self, x, tone_embed):
        attended, weights = self.attn(self.norm1(x, tone_embed))
        x = x + self.drop_path(attended)
        mlp_input = self.norm2(x, tone_embed)
        x = x + self.drop_path(self.mlp(mlp_input, tone_embed))
        return x, weights

class MultiScalePatchMerger(nn.Module):
    """Fuse fine-scale patch tokens into the coarse token sequence by cross-attention.

    Each coarse token (query, LayerNorm-ed) attends over the fine tokens. The
    attended vector is concatenated with the original coarse token and linearly
    projected back to ``embed_dim``. The output keeps the coarse sequence
    length: ``(B, N_coarse, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.proj = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, coarse, fine):
        query = self.norm1(coarse)
        key = self.norm2(fine)
        # NOTE(review): keys are normalised but values are the raw fine tokens;
        # confirm this asymmetry is intentional.
        attended = self.cross_attn(query, key, fine)[0]
        return self.proj(torch.cat((coarse, attended), dim=-1))

class UncertaintyHead(nn.Module):
    """Predict a strictly positive per-class variance from pooled features.

    The head is a two-layer MLP (ReLU, dropout 0.1) with a Softplus output.
    A 1e-6 floor is added so every entry is > 0. Output shape is
    ``(..., num_classes)``.
    """

    def __init__(self, embed_dim, num_classes, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_classes),
            nn.Softplus(),
        ]
        self.variance_head = nn.Sequential(*layers)

    def forward(self, x):
        variance = self.variance_head(x)
        return variance + 1e-6

class TAMViT(nn.Module):
    """Tone-Aware Multi-Scale Vision Transformer.

    Forward pipeline:
    1. Tone estimation. A SkinToneEstimator predicts tone probabilities from a
       3-channel view of the input: channels 3-5 when the input has 6 channels
       (the dermoscopic image for MILK10kDataset stacks), otherwise the whole
       input, which must then be 3-channel. The probabilities are embedded to
       ``embed_dim``.
    2. Patch embedding. The image is embedded at the patch sizes in
       ``patch_sizes``. With two or more entries, the second (finer) scale is
       fused into the first through MultiScalePatchMerger. Only the first two
       scales are used.
    3. Transformer. A CLS token and a learned positional embedding are added,
       and the tokens pass through ``depth`` tone-conditioned blocks.
    4. Heads. The final-normed CLS token feeds the classification head and the
       uncertainty head.

    ``patch_sizes[0]`` fixes the token count and hence the positional-embedding size.
    """

    def __init__(self, img_size=224, patch_sizes=(16, 8), in_chans=3, num_classes=9, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, num_tones=6):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.tone_estimator = SkinToneEstimator(num_tones=num_tones)
        self.tone_embed = nn.Sequential(nn.Linear(num_tones, embed_dim // 2), nn.ReLU(inplace=True), nn.Linear(embed_dim // 2, embed_dim))
        self.patch_embeds = nn.ModuleList([PatchEmbed(img_size, ps, in_chans, embed_dim) for ps in patch_sizes])
        self.num_patches = self.patch_embeds[0].num_patches
        self.patch_merger = MultiScalePatchMerger(embed_dim, num_heads)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic-depth rate grows linearly from 0 to drop_path_rate across blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            ToneConditionedBlock(dim=embed_dim, num_heads=num_heads, tone_dim=embed_dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.cls_head = nn.Sequential(nn.Linear(embed_dim, 256), nn.GELU(), nn.Dropout(drop_rate), nn.Linear(256, num_classes))
        self.uncertainty_head = UncertaintyHead(embed_dim, num_classes)
        self._init_weights()

    def _init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights_module)
        # The blanket Linear init above overwrites the zero-init that
        # ToneAdaptiveLayerNorm gives its gamma/beta output layers. That zero-init
        # is what makes every modulated norm start out as a plain LayerNorm, so
        # restore it here.
        for module in self.modules():
            if isinstance(module, ToneAdaptiveLayerNorm):
                for projection in (module.gamma_proj, module.beta_proj):
                    nn.init.zeros_(projection[-1].weight)
                    nn.init.zeros_(projection[-1].bias)

    def _init_weights_module(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x, return_uncertainty=True, return_attention=False):
        """Classify a batch of images.

        Args:
            x: image batch ``(B, in_chans, img_size, img_size)``.
            return_uncertainty: also return ``variance`` ``(B, num_classes)`` and
                its per-sample sum ``uncertainty`` ``(B,)``.
            return_attention: also return every block's attention, stacked as
                ``attention`` ``(B, depth, num_heads, N + 1, N + 1)``.

        Returns:
            dict with ``logits``, ``probs`` (softmax of logits), ``tone_probs`` and
            the optional keys above.
        """
        B = x.shape[0]
        # 6-channel inputs: estimate tone from the second 3-channel image only.
        img_for_tone = x[:, 3:, :, :] if x.shape[1] == 6 else x
        tone_logits, tone_probs = self.tone_estimator(img_for_tone)
        tone_embedding = self.tone_embed(tone_probs)

        coarse_patches = self.patch_embeds[0](x)
        if len(self.patch_embeds) > 1:
            fine_patches = self.patch_embeds[1](x)
            patches = self.patch_merger(coarse_patches, fine_patches)
        else:
            patches = coarse_patches

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, patches], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        attentions = []
        for block in self.blocks:
            x, attn = block(x, tone_embedding)
            if return_attention:
                attentions.append(attn)

        x = self.norm(x)
        features = x[:, 0]
        logits = self.cls_head(features)
        probs = F.softmax(logits, dim=-1)

        result = {'logits': logits, 'probs': probs, 'tone_probs': tone_probs}
        if return_uncertainty:
            variance = self.uncertainty_head(features)
            result['variance'] = variance
            result['uncertainty'] = variance.sum(dim=-1)
        if return_attention:
            result['attention'] = torch.stack(attentions, dim=1)
        return result

    @torch.no_grad()
    def predict_with_mc_dropout(self, x, n_samples=30):
        """Monte-Carlo dropout prediction: average ``n_samples`` stochastic forward passes.

        Dropout and DropPath are active during sampling. BatchNorm layers are
        kept in inference mode: in train mode they would use per-batch
        statistics and overwrite their running statistics during what should be
        a read-only prediction. The model's original train/eval mode is restored
        afterwards, even if a forward pass raises.

        Returns:
            dict with:
            - ``mean_probs`` ``(B, C)``;
            - ``epistemic_uncertainty``: summed per-class variance across samples, ``(B,)``;
            - ``aleatoric_uncertainty``: entropy of ``mean_probs``, ``(B,)``.
              NOTE(review): the entropy of the mean prediction is usually
              regarded as *total* predictive uncertainty; the key name is kept
              for compatibility.

        Raises:
            ValueError: if ``n_samples < 1``.
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        was_training = self.training
        self.train()
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)):
                module.eval()
        try:
            preds = torch.stack([
                self.forward(x, return_uncertainty=False)['probs'] for _ in range(n_samples)
            ])
        finally:
            # Restores every submodule (including the BatchNorms) to the caller's mode.
            self.train(was_training)
        mean_pred = preds.mean(dim=0)
        epistemic = preds.var(dim=0).sum(dim=-1)
        aleatoric = -torch.sum(mean_pred * torch.log(mean_pred + 1e-8), dim=-1)
        return {'mean_probs': mean_pred, 'epistemic_uncertainty': epistemic, 'aleatoric_uncertainty': aleatoric}

## 5. Loss Functions

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss, ``(1 - p_t) ** gamma * CE``, optionally class-weighted.

    Args:
        gamma: focusing parameter; ``gamma=0`` reduces to (weighted) cross-entropy.
        alpha: optional per-class weights, indexed by the target class. Either a
            tensor or any sequence accepted by ``torch.as_tensor``.
        reduction: ``'mean'`` (default), ``'sum'``, or ``'none'`` (per-sample losses).

    Raises:
        ValueError: for any other ``reduction``. Previously every value other
            than ``'mean'`` was silently treated as ``'sum'``, including ``'none'``.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        if alpha is not None and not torch.is_tensor(alpha):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # p_t, the predicted probability of the true class, recovered from CE = -log(p_t).
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.alpha is not None:
            focal_loss = self.alpha.to(inputs.device)[targets] * focal_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

class UncertaintyAwareLoss(nn.Module):
    """Heteroscedastic NLL on cross-entropy: ``0.5 * (log s + CE / s)``.

    ``s`` is the predicted variance of the *target* class, taken from
    ``variance`` of shape ``(B, C)``. A large variance down-weights the CE term
    but pays a ``log s`` penalty.

    Args:
        reduction: ``'mean'`` (default), ``'sum'``, or ``'none'`` (per-sample values).

    Raises:
        ValueError: for any other ``reduction``. Previously ``'none'`` was
            silently treated as ``'sum'``.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.reduction = reduction

    def forward(self, logits, targets, variance):
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        # gather picks variance[i, targets[i]] on whatever device the inputs live on.
        target_variance = variance.gather(1, targets.unsqueeze(1)).squeeze(1)
        nll = 0.5 * (torch.log(target_variance + 1e-8) + ce_loss / (target_variance + 1e-8))
        if self.reduction == 'mean':
            return nll.mean()
        if self.reduction == 'sum':
            return nll.sum()
        return nll

class DermEquityLoss(nn.Module):
    """Combined objective: focal loss plus a weighted uncertainty NLL term.

    total = focal(logits, targets) + lambda_unc * uncertainty_nll

    The second term is added only when ``outputs`` has a ``'variance'`` entry.
    ``num_classes`` is accepted but unused. ``lambda_fair`` is stored, but no
    fairness term is computed here.

    Returns:
        ``(total_loss, components)``, where ``components`` maps ``'focal'`` (and
        ``'uncertainty'`` when present) to Python floats.
    """

    def __init__(self, num_classes=9, gamma=2.0, lambda_unc=0.1, lambda_fair=0.5, class_weights=None):
        super().__init__()
        self.focal_loss = FocalLoss(gamma=gamma, alpha=class_weights)
        self.uncertainty_loss = UncertaintyAwareLoss()
        self.lambda_unc = lambda_unc
        self.lambda_fair = lambda_fair

    def forward(self, outputs, targets):
        logits = outputs['logits']
        focal = self.focal_loss(logits, targets)
        components = {'focal': focal.item()}
        if 'variance' not in outputs:
            return focal, components
        unc = self.uncertainty_loss(logits, targets, outputs['variance'])
        components['uncertainty'] = unc.item()
        return focal + self.lambda_unc * unc, components

def compute_class_weights(labels, num_classes):
    """Inverse-frequency class weights, rescaled so they sum to ``num_classes``.

    A class that never occurs in ``labels`` gets weight 0. Previously it got
    ``1 / 1e-8``, and after normalisation that pushed every class that *does*
    occur to ~0. If no label is present at all, uniform weights (ones) are
    returned.

    Args:
        labels: 1-D integer class labels (tensor or sequence), values in ``[0, num_classes)``.
        num_classes: minimum number of output weights.

    Returns:
        float tensor of length ``max(num_classes, max(labels) + 1)``.
    """
    labels = torch.as_tensor(labels, dtype=torch.long)
    counts = torch.bincount(labels, minlength=num_classes).float()
    present = counts > 0
    weights = torch.zeros_like(counts)
    weights[present] = 1.0 / counts[present]
    total = weights.sum()
    if total == 0:
        return torch.ones_like(counts)
    return weights / total * num_classes

## 6. Training Module

In [None]:
class DermEquityModule(pl.LightningModule):
    """Lightning wrapper that trains TAMViT with DermEquityLoss.

    ``model_config`` and ``train_config`` are plain dicts read with ``.get`` and
    fall back to defaults. Validation logs ``val/loss``, macro one-vs-rest
    ``val/auc_roc`` and ``val/f1_macro``. The optimiser is AdamW without a
    scheduler.
    """

    def __init__(self, model_config, train_config, class_weights=None):
        super().__init__()
        self.save_hyperparameters()
        self.model = TAMViT(
            img_size=model_config.get('img_size', 224),
            patch_sizes=model_config.get('patch_sizes', [16, 8]),
            num_classes=model_config.get('num_classes', 9),
            embed_dim=model_config.get('embed_dim', 768),
            depth=model_config.get('depth', 12),
            num_heads=model_config.get('num_heads', 12),
            mlp_ratio=model_config.get('mlp_ratio', 4.0),
            drop_rate=model_config.get('dropout', 0.1),
            drop_path_rate=model_config.get('drop_path', 0.1),
            in_chans=model_config.get('in_chans', 3)  # Use 6 for MILK10k
        )
        self.criterion = DermEquityLoss(
            num_classes=model_config.get('num_classes', 9),
            gamma=train_config.get('focal_gamma', 2.0),
            lambda_unc=train_config.get('lambda_unc', 0.1),
            lambda_fair=train_config.get('lambda_fair', 0.5),
            class_weights=class_weights,
        )
        self.train_config = train_config
        self.model_config = model_config
        # Per-epoch buffers of validation probabilities / labels (on CPU).
        self.val_preds = []
        self.val_labels = []

    def forward(self, x):
        return self.model(x, return_uncertainty=True)

    def training_step(self, batch, batch_idx):
        images, labels = batch['image'], batch['label']
        outputs = self.model(images, return_uncertainty=True)
        loss, loss_dict = self.criterion(outputs, labels)
        self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch['image'], batch['label']
        outputs = self.model(images, return_uncertainty=True)
        loss, loss_dict = self.criterion(outputs, labels)
        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_preds.append(outputs['probs'].detach().cpu())
        self.val_labels.append(labels.cpu())

    @staticmethod
    def _macro_ovr_auc(labels, probs):
        """Mean one-vs-rest ROC AUC over the classes that have both positives and negatives.

        When every class occurs in ``labels``, this equals sklearn's
        ``roc_auc_score(..., multi_class='ovr', average='macro')``. That call
        raises when some class is absent, which is common for rare classes and
        during the sanity-check pass. Instead, this averages the classes whose
        AUC is defined, and returns 0.0 if there are none.
        """
        per_class = []
        for cls in range(probs.shape[1]):
            is_positive = labels == cls
            if is_positive.any() and not is_positive.all():
                per_class.append(roc_auc_score(is_positive, probs[:, cls]))
        return float(np.mean(per_class)) if per_class else 0.0

    def on_validation_epoch_end(self):
        if not self.val_preds:
            return
        # .float() guards against half-precision probs, which numpy cannot hold as bf16.
        all_probs = torch.cat(self.val_preds, dim=0).float().numpy()
        all_labels = torch.cat(self.val_labels, dim=0).numpy()
        # Clear first so a failure below cannot leak stale predictions into the next epoch.
        self.val_preds.clear()
        self.val_labels.clear()
        all_preds = all_probs.argmax(axis=1)
        auc = self._macro_ovr_auc(all_labels, all_probs)
        f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        self.log('val/auc_roc', auc, prog_bar=True)
        self.log('val/f1_macro', f1, prog_bar=True)

    def configure_optimizers(self):
        optimizer = AdamW(
            self.parameters(),
            lr=self.train_config.get('lr', 1e-4),
            weight_decay=self.train_config.get('weight_decay', 0.05)
        )
        return optimizer

## 7. Execution

In [None]:
# Configuration
config = OmegaConf.create({
    "data": {
        "train_data_dir": "/content/data/milk10k/train",  # UPDATE THIS
        "val_data_dir": "/content/data/milk10k/val",      # UPDATE THIS
        "img_size": 224,
        "batch_size": 32,
        "num_workers": 2,
    },
    "model": {
        "num_classes": 11,
        "in_chans": 6,  # 6 Channels for stacked image
        "img_size": 224,
        "patch_sizes": [16, 8],
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12
    },
    "training": {
        "epochs": 30,
        "lr": 1e-4,
        "weight_decay": 1e-4,
        "precision": "16-mixed",
        "focal_gamma": 2.0
    },
    "logging": {"log_every_n_steps": 10},
    "paths": { "checkpoint_dir": "outputs/checkpoints" }
})

# Set up data
train_transform = get_train_transforms(config.model.img_size)
val_transform = get_val_transforms(config.model.img_size)

# The CSVs are expected next to (one level above) the image folders.
train_dir = Path(config.data.train_data_dir)
val_dir = Path(config.data.val_data_dir)
train_csv = train_dir.parent / 'train.csv'
val_csv = val_dir.parent / 'val.csv'

# Check every required input up front. The previous blanket `except Exception`
# around the whole run also swallowed genuine training failures (OOM, shape
# errors, ...) and reported them as a misleading "verify your data paths".
missing_paths = [p for p in (train_dir, val_dir, train_csv, val_csv) if not p.exists()]

if missing_paths:
    for missing in missing_paths:
        print(f"Warning: {missing} not found. Please upload data.")
    print("Please verify your data paths in the config section!")
else:
    train_dataset = MILK10kDataset(
        root_dir=config.data.train_data_dir,
        csv_file=str(train_csv),
        transform=train_transform,
        phase='train'
    )

    val_dataset = MILK10kDataset(
        root_dir=config.data.val_data_dir,
        csv_file=str(val_csv),
        transform=val_transform,
        phase='val'
    )

    dataloaders = create_dataloaders(
        train_dataset,
        val_dataset,
        batch_size=config.data.batch_size,
        # Previously not passed, so the function default (4) was used instead of the config value.
        num_workers=config.data.num_workers,
    )
    class_weights = torch.ones(config.model.num_classes)

    # Training
    pl.seed_everything(42)
    model = DermEquityModule(
        model_config=OmegaConf.to_container(config.model),
        train_config=OmegaConf.to_container(config.training),
        class_weights=class_weights
    )

    trainer = pl.Trainer(
        max_epochs=config.training.epochs,
        accelerator='auto',
        devices=1,
        precision=config.training.precision,
        callbacks=[ModelCheckpoint(dirpath=config.paths.checkpoint_dir, monitor='val/auc_roc', mode='max')],
        log_every_n_steps=config.logging.log_every_n_steps
    )

    trainer.fit(model, dataloaders['train'], dataloaders['val'])