# 02 — Model Training

Interactive training notebook:
- Build dataloaders
- Configure and build the model
- Run training with live loss/accuracy/AUC tracking
- Plot training curves
- Save the best checkpoint (highest validation AUC)

In [None]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import torch
import matplotlib.pyplot as plt

from src.config import CONFIG, get_device
from src.utils.seed import set_seed
from src.data.dataset import get_dataloaders, SkinLesionDataset
from src.models.model import build_model
from src.models.loss import get_criterion
from src.train.train import train_one_epoch, get_optimizer, get_scheduler
from src.train.evaluate import evaluate
from src.utils.metrics import compute_binary_auc
from src.utils.visualization import plot_training_curves

%matplotlib inline

# Resolve the compute device once; later cells pass it to the model, the loss
# and the train/eval helpers.
DEVICE = get_device()
print('Device: {}'.format(DEVICE))

## 1. Configuration overrides

Adjust hyperparameters here before training.

In [None]:
# Per-run hyperparameter overrides, applied on top of the shared CONFIG.
run_overrides = {
    'epochs': 30,
    'learning_rate': 1e-4,
    'batch_size': 32,
    'model_name': 'efficientnet_b0',
    'loss': 'bce',  # binary classification → BCEWithLogitsLoss
    'scheduler': 'cosine',
}
for option, setting in run_overrides.items():
    CONFIG[option] = setting

# Apply the configured seed, then echo the effective configuration for the record.
set_seed(CONFIG['seed'])
print('Config:', CONFIG)

## 2. Dataloaders

In [None]:
# Build the dataloaders and print a one-line size summary per entry.
loaders = get_dataloaders()

for split, split_loader in loaders.items():
    if split_loader is None:
        # An entry may be absent; report it rather than failing on len().
        print(f'{split:>5}: None')
        continue
    n_samples = len(split_loader.dataset)
    n_batches = len(split_loader)
    print(f'{split:>5}: {n_samples} samples, {n_batches} batches')

In [None]:
# Sanity check: pull the first training batch and show its shape and a few labels.
first_batch = next(iter(loaders['train']))
images, labels = first_batch
label_preview = labels[:8]
print(f'Batch shape: {images.shape}, Labels: {label_preview}')

## 3. Build model, loss, optimiser

In [None]:
# Instantiate the network and move it to the selected device.
model = build_model()
model = model.to(DEVICE)

# Class-imbalance weight, computed by the training dataset and handed to the loss.
train_ds: SkinLesionDataset = loaders['train'].dataset
pos_weight = train_ds.compute_pos_weight()
print('pos_weight (neg/pos):', pos_weight.item())

# Loss, optimiser and LR scheduler come from the project factories.
criterion = get_criterion(pos_weight=pos_weight, device=DEVICE)
optimizer = get_optimizer(model)
scheduler = get_scheduler(optimizer)

# Walk the parameters once, recording (element count, trainable?) per tensor.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable = sum(size for size, needs_grad in param_sizes if needs_grad)
print(f'Parameters: {total_params:,} total, {trainable:,} trainable')

## 4. Training loop

In [None]:
from src.config import MODELS_DIR
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Training driver: for each epoch, train, validate, step the LR scheduler,
# record metrics in `history`, checkpoint the model whenever validation AUC
# improves, and stop early after too many epochs without improvement.

MODELS_DIR.mkdir(parents=True, exist_ok=True)
checkpoint_path = MODELS_DIR / 'best_model.pth'

# Read loop-invariant settings up front so that a missing key fails
# immediately rather than after a full train + validate pass.
num_epochs = CONFIG['epochs']
patience = CONFIG['early_stopping_patience']
grad_clip = CONFIG.get('grad_clip_max_norm')  # optional; None when unset

history = {
    'train_loss': [], 'train_acc': [], 'train_auc': [],
    'val_loss': [], 'val_acc': [], 'val_auc': [],
}
# Start below every real AUC so the first comparable validation score always
# yields a checkpoint written by *this* run (a 0.0 start would skip AUC == 0.0).
best_val_auc = float('-inf')
patience_counter = 0

for epoch in range(num_epochs):
    lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch {epoch+1}/{num_epochs}  lr={lr:.2e}")

    # Train
    train_metrics = train_one_epoch(
        model, loaders['train'], criterion, optimizer, DEVICE,
        grad_clip=grad_clip,
    )

    # Validate
    val_loss, val_acc, val_preds, val_labels, val_probs = evaluate(
        model, loaders['val'], criterion, DEVICE,
    )
    val_auc = compute_binary_auc(val_labels, val_probs)

    # Scheduler: ReduceLROnPlateau needs the monitored metric, others do not.
    if scheduler is not None:
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(val_loss)
        else:
            scheduler.step()

    print(f"  train  loss={train_metrics['loss']:.4f}  acc={train_metrics['acc']:.4f}  auc={train_metrics['auc']:.4f}")
    print(f"  val    loss={val_loss:.4f}  acc={val_acc:.4f}  auc={val_auc:.4f}")

    history['train_loss'].append(train_metrics['loss'])
    history['train_acc'].append(train_metrics['acc'])
    history['train_auc'].append(train_metrics['auc'])
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_auc'].append(val_auc)

    # Save best model by validation AUC (consistent with train.py).
    # A NaN AUC compares False here, so it counts as "no improvement".
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        patience_counter = 0
        torch.save({
            'epoch': epoch,  # 0-based index of the epoch that produced this checkpoint
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_auc': best_val_auc,
            'val_acc': val_acc,
            'val_loss': val_loss,
            'config': CONFIG,
        }, checkpoint_path)
        print(f'  -> saved best model (val_auc={best_val_auc:.4f})')
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f'\nEarly stopping at epoch {epoch+1}')
        break

if best_val_auc == float('-inf'):
    # Nothing was written this run; any best_model.pth on disk is from an
    # earlier run and must not be mistaken for this run's result.
    print('\nDone. No checkpoint was saved during this run '
          '(no epoch produced a comparable validation AUC).')
else:
    print(f'\nDone. Best val AUC = {best_val_auc:.4f}')

## 5. Training curves

In [None]:
# Plot the per-epoch metrics collected in `history` by the training loop
# (train/val loss, accuracy and AUC lists).
plot_training_curves(history)