# MILK10k Skin Lesion Classification

This notebook trains the **Tone-Aware Multi-Scale Vision Transformer (TAM-ViT)** on the MILK10k dataset.

## 🚀 Setup

1.  **Enable GPU**: Go to `Runtime` -> `Change runtime type` -> `T4 GPU` (or better).
2.  **Upload Data**: You need to upload your `milk10k` dataset folder to your Drive or directly here.
    - Expected structure:
        ```
        /content/data/milk10k/
        ├── train/
        │   ├── lesion1_clin.jpg
        │   ├── lesion1_derm.jpg
        │   └── ...
        ├── train.csv
        └── val.csv
        ```

In [None]:
# 1. Install Dependencies
!pip install torch torchvision timm albumentations pandas numpy omegaconf pytorch-lightning wandb

In [None]:
# 2. Clone Repository (if running from Colab/Kaggle without local files)
# If you uploaded the code manually, skip this.
# !git clone https://github.com/your-username/derm-equity.git
# %cd derm-equity

In [None]:
# 3. Import Libraries & Set Path
import sys
import os
from pathlib import Path

# Make the current working directory importable so that `src.*` resolves.
# Assumes the notebook runs from the repository root (e.g. after `%cd derm-equity`).
# Guarded so that re-running this cell does not keep appending the same entry.
_project_root = os.getcwd()
if _project_root not in sys.path:
    sys.path.append(_project_root)

import torch
from src.models.tam_vit import create_tam_vit_base
from src.data.milk10k_dataset import MILK10kDataset
from src.data.datasets import get_train_transforms, get_val_transforms, create_dataloaders
from src.training.trainer import DermEquityModule, create_callbacks, create_loggers
import pytorch_lightning as pl
from omegaconf import OmegaConf

print("Libraries imported successfully!")

In [None]:
# 4. Configuration
# We define the config here for easy editing in the notebook.

# Values that appear in more than one config section are defined once, so the
# sections cannot silently drift apart when one of them is edited.
# (Plain Python values are used instead of OmegaConf interpolation so that the
# config stays fully concrete regardless of how it is later converted.)
CLASS_NAMES = ["AKIEC", "BCC", "BEN_OTH", "BKL", "DF", "INF", "MAL_OTH", "MEL", "NV", "SCCKA", "VASC"]
IMG_SIZE = 224

config = OmegaConf.create({
    "data": {
        "dataset": "milk10k",
        "train_data_dir": "/content/data/milk10k/train",  # UPDATE THIS PATH
        "val_data_dir": "/content/data/milk10k/val",      # UPDATE THIS PATH
        "img_size": IMG_SIZE,
        "batch_size": 32,
        "num_workers": 2,
        "classes": CLASS_NAMES
    },
    "model": {
        "architecture": "tam_vit_base",
        # Derived from the class list so the classifier head always matches it.
        "num_classes": len(CLASS_NAMES),
        # NOTE(review): presumably 3 clinical + 3 dermoscopic channels stacked —
        # confirm against MILK10kDataset.
        "in_chans": 6,
        "pretrained": True,
        "img_size": IMG_SIZE
    },
    "training": {
        "epochs": 30,
        "learning_rate": 1e-4,
        "weight_decay": 1e-4,
        "accumulate_grad_batches": 1,
        "gradient_clip_val": 1.0,
        "precision": "16-mixed",
        "use_wandb": False  # Set to True if you have a W&B account
    },
    "optimizer": {"name": "adamw", "betas": [0.9, 0.999]},
    "scheduler": {"name": "cosine_warmup", "warmup_epochs": 3},
    "loss": {"name": "focal", "gamma": 2.0},
    "logging": {"log_every_n_steps": 10, "wandb": {"project": "milk10k", "enabled": False}},
    "paths": {
        "output_dir": "outputs",
        "checkpoint_dir": "outputs/checkpoints",
        "log_dir": "outputs/logs"
    }
})

# Create every configured output directory (including log_dir) up front.
for _dir in (config.paths.output_dir, config.paths.checkpoint_dir, config.paths.log_dir):
    Path(_dir).mkdir(parents=True, exist_ok=True)
print("Configuration ready.")

In [None]:
# 5. Load Data
# Assumes train.csv and val.csv are in the parent directory of train_data_dir/val_data_dir
# Adjust paths as needed for your upload structure

train_transform = get_train_transforms(config.model.img_size)
val_transform = get_val_transforms(config.model.img_size)

try:
    # The CSV files are looked up next to (not inside) the image directories.
    train_csv = Path(config.data.train_data_dir).parent / 'train.csv'
    val_csv = Path(config.data.val_data_dir).parent / 'val.csv'

    train_dataset = MILK10kDataset(
        root_dir=config.data.train_data_dir,
        csv_file=str(train_csv),
        transform=train_transform,
        phase='train'
    )

    val_dataset = MILK10kDataset(
        root_dir=config.data.val_data_dir,
        csv_file=str(val_csv),
        transform=val_transform,
        phase='val'
    )

    print(f"Train size: {len(train_dataset)}")
    print(f"Val size: {len(val_dataset)}")

    # `dataloaders` is indexed with 'train' and 'val' by the training cell.
    dataloaders = create_dataloaders(
        train_dataset, val_dataset,
        batch_size=config.data.batch_size,
        num_workers=config.data.num_workers
    )

    # Dummy class weights for now (uniform: every class weighted equally)
    class_weights = torch.ones(config.model.num_classes)

except Exception as e:
    print(f"Error loading data: {e}")
    print("Make sure your paths in Config are correct!")
    # Re-raise so execution stops here with the real traceback, instead of the
    # training cell failing later with an unrelated NameError on
    # `dataloaders` / `class_weights`.
    raise

In [None]:
# 6. Training

# Fix the random seeds for reproducibility.
pl.seed_everything(42)

# Initialize model
# resolve=True hands the module plain, fully-resolved Python dicts even if the
# config is later edited to use OmegaConf interpolations.
model = DermEquityModule(
    model_config=OmegaConf.to_container(config.model, resolve=True),
    train_config=OmegaConf.to_container(config.training, resolve=True),
    class_weights=class_weights
)

# Callbacks
# Keeps the checkpoint with the highest `val_auc`.
# NOTE(review): assumes DermEquityModule logs a metric named `val_auc` — confirm.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=config.paths.checkpoint_dir,
    monitor='val_auc',
    mode='max',
    filename='milk10k-{epoch:02d}-{val_auc:.4f}'
)

trainer = pl.Trainer(
    max_epochs=config.training.epochs,
    accelerator='auto',
    devices=1,
    precision=config.training.precision,
    # These two are defined in the config; pass them through so editing the
    # config actually changes training instead of being silently ignored.
    accumulate_grad_batches=config.training.accumulate_grad_batches,
    gradient_clip_val=config.training.gradient_clip_val,
    callbacks=[checkpoint_callback],
    log_every_n_steps=config.logging.log_every_n_steps
)

# Start training
trainer.fit(model, dataloaders['train'], dataloaders['val'])