In [None]:
# Setup and Imports
import sys
sys.path.insert(0, '..')

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from omegaconf import OmegaConf
import hydra

# Project imports
from src.datamodules import APTOSDataModule
from src.models import DRModel
from src.utils import quadratic_weighted_kappa, compute_confusion_matrix

# Report the library versions and the accelerator visible to this kernel.
cuda_ready = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Lightning version: {pl.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Only the first device is reported; training below uses a single device.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration Management

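The project uses Hydra for configuration management. In this notebook we load the underlying YAML files directly with OmegaConf and merge them by hand, so we can inspect the resulting config.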

In [None]:
# Load configuration manually (for notebook use)
from omegaconf import OmegaConf

# Read the main config and each config group straight from its YAML file.
config_path = Path('../conf')

main_cfg = OmegaConf.load(config_path.joinpath('config.yaml'))
dataset_cfg = OmegaConf.load(config_path.joinpath('dataset', 'aptos.yaml'))
model_cfg = OmegaConf.load(config_path.joinpath('model', 'efficientnet_b5.yaml'))
training_cfg = OmegaConf.load(config_path.joinpath('training', 'default.yaml'))
loss_cfg = OmegaConf.load(config_path.joinpath('loss', 'regression.yaml'))

# Nest every group under its own top-level key on top of the main config.
# The keys are disjoint, so a single mapping merges the same way as one
# mapping per group.
cfg = OmegaConf.merge(
    main_cfg,
    {
        'dataset': dataset_cfg,
        'model': model_cfg,
        'training': training_cfg,
        'loss': loss_cfg,
    },
)

print("Configuration:")
print(OmegaConf.to_yaml(cfg))

In [None]:
# Notebook-specific overrides, layered on top of the YAML configuration.
# Tune these to match the hardware you are running on.

OVERRIDES = dict(
    data_dir='../data/aptos',
    training=dict(
        batch_size=16,  # Reduce if OOM
        num_workers=4,
        epochs=30,
        accumulate_grad_batches=2,
    ),
    model=dict(
        pretrained=True,
        head_type='regression',
    ),
)

# OmegaConf.merge accepts plain dicts, so the overrides can be merged directly.
cfg = OmegaConf.merge(cfg, OVERRIDES)
print("Updated batch_size:", cfg.training.batch_size)
print("Updated epochs:", cfg.training.epochs)

## 3. Data Loading & Verification

In [None]:
# Initialize DataModule

# Seed the global RNGs (and dataloader workers) before any data or model is
# created, so shuffling, augmentation and head initialisation are repeatable
# across "Restart & Run All". The datamodule still receives the same seed
# explicitly for its own use.
pl.seed_everything(cfg.training.seed, workers=True)

data_dir = Path(cfg.data_dir)

datamodule = APTOSDataModule(
    data_dir=data_dir,
    batch_size=cfg.training.batch_size,
    num_workers=cfg.training.num_workers,
    image_size=cfg.model.input_size,
    use_processed=True,
    val_split=0.2,
    seed=cfg.training.seed
)

# Setup data (builds the train/validation datasets read below)
datamodule.setup()

print(f"Training samples: {len(datamodule.train_dataset)}")
print(f"Validation samples: {len(datamodule.val_dataset)}")

In [None]:
# Pull one batch from the training loader and show up to eight of its images.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
images, labels = batch

print(f"Batch shape: {images.shape}")
print(f"Labels: {labels.numpy()}")

# Undo the per-channel normalisation so the images display with natural colours.
# NOTE(review): these are the usual ImageNet statistics; confirm they match the
# normalisation applied inside the datamodule.
mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
images_denorm = images * std + mean

# 2 x 4 grid; zip stops at the shorter of the grid and the batch, so any
# surplus axes are simply left untouched.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative']

for ax, image, label in zip(axes.flat, images_denorm, labels):
    # CHW -> HWC for matplotlib, clipped to the displayable [0, 1] range.
    img = image.permute(1, 2, 0).numpy().clip(0, 1)
    ax.imshow(img)
    ax.set_title(f"{class_names[label]} (Grade {label})")
    ax.axis('off')

plt.suptitle('Training Batch Sample (with augmentation)', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Model Architecture Overview

In [None]:
# Build a standalone DRModel from the merged configuration for inspection.
model_kwargs = dict(
    backbone=cfg.model.backbone,
    num_classes=cfg.dataset.num_classes,
    head_type=cfg.model.head_type,
    pretrained=cfg.model.pretrained,
    dropout=cfg.model.dropout,
    pooling=cfg.model.pooling,
)
model = DRModel(**model_kwargs)

print(model)

In [None]:
# Layer-by-layer summary for a single RGB image at the configured resolution.
from torchinfo import summary

side = cfg.model.input_size
input_size = (1, 3, side, side)
summary_columns = ['input_size', 'output_size', 'num_params', 'trainable']
summary(model, input_size=input_size, col_names=summary_columns)

In [None]:
# Tally parameter counts, split by whether each tensor receives gradients.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
frozen_params = total_params - trainable_params

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {frozen_params:,}")

## 5. Lightning Module & Training Setup

In [None]:
# Lightning Module Definition
class DRLightningModule(pl.LightningModule):
    """Lightning wrapper that trains and validates a ``DRModel`` grader.

    The module supports two head types, selected by ``cfg.model.head_type``:
    ``'regression'`` (MSE loss, predictions obtained by rounding and clamping
    the scalar output) and anything else, which is treated as classification
    (class-weighted cross-entropy, predictions obtained by ``argmax``).

    Args:
        cfg: Merged configuration. The module reads ``cfg.model`` (backbone,
            head_type, pretrained, dropout, pooling), ``cfg.dataset``
            (num_classes, plus class_weights for non-regression heads) and
            ``cfg.training`` (lr, weight_decay, epochs).
    """

    def __init__(self, cfg):
        super().__init__()
        self.save_hyperparameters()
        self.cfg = cfg

        # Model
        self.model = DRModel(
            backbone=cfg.model.backbone,
            num_classes=cfg.dataset.num_classes,
            head_type=cfg.model.head_type,
            pretrained=cfg.model.pretrained,
            dropout=cfg.model.dropout,
            pooling=cfg.model.pooling
        )

        # Loss
        if cfg.model.head_type == 'regression':
            self.criterion = nn.MSELoss()
        else:
            # CrossEntropyLoss needs a floating-point weight tensor; convert the
            # config sequence explicitly rather than relying on type inference.
            weights = torch.tensor(list(cfg.dataset.class_weights), dtype=torch.float32)
            self.criterion = nn.CrossEntropyLoss(weight=weights)

        # Per-step results, accumulated for the epoch-level metrics and
        # cleared at the end of every epoch.
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, x):
        """Return the raw output of the wrapped model for a batch of images."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run the forward pass shared by training and validation.

        Returns:
            Tuple ``(loss, preds, outputs, labels)`` where ``preds`` are
            integer grades and ``outputs`` are the raw model outputs (flattened
            to 1-D for the regression head).
        """
        images, labels = batch
        outputs = self(images)

        if self.cfg.model.head_type == 'regression':
            # Flatten rather than squeeze(): squeeze() collapses a batch of one
            # to a 0-d tensor, which torch.cat cannot concatenate at epoch end.
            # NOTE(review): assumes the regression head emits one value per
            # image - confirm against DRModel.
            outputs = outputs.reshape(-1)
            loss = self.criterion(outputs, labels.float())
            # Clamp to the valid grade range instead of a hard-coded 4.
            max_grade = self.cfg.dataset.num_classes - 1
            preds = outputs.round().clamp(0, max_grade).long()
        else:
            loss = self.criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

        return loss, preds, outputs, labels

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; stash predictions for epoch QWK."""
        loss, preds, _, labels = self._shared_step(batch)

        self.training_step_outputs.append({
            'loss': loss.detach(),
            'preds': preds.detach(),
            'labels': labels.detach()
        })

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        """Log the epoch-level training QWK and release the stored step outputs.

        This must be named ``on_train_epoch_end`` for Lightning to call it;
        under the previous name the hook never ran, so ``train_qwk`` was never
        logged and the stored outputs grew for the whole run.
        """
        if not self.training_step_outputs:
            return

        all_preds = torch.cat([x['preds'] for x in self.training_step_outputs])
        all_labels = torch.cat([x['labels'] for x in self.training_step_outputs])

        qwk = quadratic_weighted_kappa(all_preds.cpu().numpy(), all_labels.cpu().numpy())
        self.log('train_qwk', qwk, prog_bar=True)

        self.training_step_outputs.clear()

    def on_training_epoch_end(self):
        """Backward-compatible alias for :meth:`on_train_epoch_end`."""
        self.on_train_epoch_end()

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; stash predictions for epoch metrics."""
        loss, preds, outputs, labels = self._shared_step(batch)

        self.validation_step_outputs.append({
            'loss': loss.detach(),
            'preds': preds.detach(),
            'labels': labels.detach(),
            'outputs': outputs.detach()
        })

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        """Log validation QWK and accuracy, then release the stored step outputs."""
        if not self.validation_step_outputs:
            return

        all_preds = torch.cat([x['preds'] for x in self.validation_step_outputs])
        all_labels = torch.cat([x['labels'] for x in self.validation_step_outputs])

        qwk = quadratic_weighted_kappa(all_preds.cpu().numpy(), all_labels.cpu().numpy())
        accuracy = (all_preds == all_labels).float().mean().item()

        # 'val_qwk' is the metric the checkpoint and early-stopping callbacks monitor.
        self.log('val_qwk', qwk, prog_bar=True)
        self.log('val_accuracy', accuracy, prog_bar=True)

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Return AdamW with a per-epoch cosine schedule decaying to 1% of the base LR."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.cfg.training.lr,
            weight_decay=self.cfg.training.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.cfg.training.epochs,
            eta_min=self.cfg.training.lr * 0.01
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'
            }
        }

In [None]:
# Create Lightning Module
# Instantiating the module builds a second DRModel from the same cfg; this is
# the object handed to the trainer and used as the fallback when no
# checkpoint exists.
lightning_model = DRLightningModule(cfg)
print("Lightning module created successfully!")

In [None]:
# Callbacks: keep the three best checkpoints by validation QWK (plus the most
# recent one), stop once QWK has not improved for five validation checks, and
# record the learning rate each epoch.
checkpoint_cb = ModelCheckpoint(
    dirpath='../checkpoints',
    filename='dr-{epoch:02d}-{val_qwk:.4f}',
    monitor='val_qwk',
    mode='max',
    save_top_k=3,
    save_last=True,
    verbose=True,
)
early_stop_cb = EarlyStopping(
    monitor='val_qwk',
    mode='max',
    patience=5,
    verbose=True,
)
lr_monitor_cb = LearningRateMonitor(logging_interval='epoch')

callbacks = [checkpoint_cb, early_stop_cb, lr_monitor_cb]

# TensorBoard event files are written under ../logs/dr_training.
logger = TensorBoardLogger(save_dir='../logs', name='dr_training')

In [None]:
# Assemble the Trainer arguments in one place so they are easy to inspect
# before the Trainer is built.
trainer_kwargs = dict(
    max_epochs=cfg.training.epochs,
    accelerator='auto',
    devices=1,
    precision='16-mixed',  # Mixed precision for faster training
    callbacks=callbacks,
    logger=logger,
    accumulate_grad_batches=cfg.training.accumulate_grad_batches,
    gradient_clip_val=cfg.training.gradient_clip_val,
    log_every_n_steps=10,
    deterministic=False,
    enable_progress_bar=True,
)
trainer = pl.Trainer(**trainer_kwargs)

print("Trainer configured successfully!")

## 6. Training

In [None]:
# Start training!
# Uncomment the line below to train

# trainer.fit(lightning_model, datamodule)

In [None]:
# View training logs with TensorBoard
# Run this in a terminal from the project root (the parent of this notebook's directory):
# tensorboard --logdir=logs/dr_training

# Or use the magic command in Jupyter:
# %load_ext tensorboard
# %tensorboard --logdir ../logs/dr_training

## 7. Threshold Optimization

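For regression output, we need to find optimal thresholds to convert continuous predictions to classes. Note that the thresholds below are tuned and then scored on the same validation set, so the optimized QWK is an optimistic estimate.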

In [None]:
from src.utils import ThresholdOptimizer

# Load the most recent checkpoint (written by ModelCheckpoint(save_last=True)).
# To evaluate the best model instead, point this at one of the top-k
# 'dr-<epoch>-<val_qwk>.ckpt' files in the same directory.
checkpoint_path = '../checkpoints/last.ckpt'

if os.path.exists(checkpoint_path):
    trained_model = DRLightningModule.load_from_checkpoint(checkpoint_path, cfg=cfg)
    trained_model.freeze()
    print("Model loaded from checkpoint!")
else:
    print("No checkpoint found. Using untrained model for demonstration.")
    trained_model = lightning_model

# Put the model in inference mode in BOTH branches: the fallback model would
# otherwise stay in train mode, leaving dropout active and letting batch-norm
# statistics change while predictions are collected.
trained_model.eval()

In [None]:
# Collect the raw (continuous) validation outputs and their ground-truth grades.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trained_model = trained_model.to(device)
# Guarantee inference behaviour (no dropout, frozen batch-norm statistics)
# regardless of how `trained_model` was obtained.
trained_model.eval()

pred_chunks = []
label_chunks = []

val_loader = datamodule.val_dataloader()

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        # reshape(-1) keeps a 1-D array even when the final batch holds a
        # single image; squeeze() would yield a 0-d array that cannot be
        # iterated or concatenated.
        # NOTE(review): assumes the regression head (one value per image).
        outputs = trained_model(images).reshape(-1).cpu().numpy()
        pred_chunks.append(outputs)
        label_chunks.append(labels.reshape(-1).numpy())

all_preds = np.concatenate(pred_chunks)
all_labels = np.concatenate(label_chunks)

print(f"Collected {len(all_preds)} predictions")
print(f"Prediction range: [{all_preds.min():.3f}, {all_preds.max():.3f}]")

In [None]:
# Optimize thresholds
# Fit the grade cut-points on the validation predictions and labels collected above.
# NOTE(review): ThresholdOptimizer is a project helper; its search method and
# the exact type of its return value are not visible here - confirm before
# relying on either.
optimizer = ThresholdOptimizer(num_classes=5)
optimal_thresholds = optimizer.optimize(all_preds, all_labels)

print(f"Optimal thresholds: {optimal_thresholds}")

In [None]:
# Baseline cut-points: equivalent to rounding the regression output to the
# nearest grade.
default_thresholds = [0.5, 1.5, 2.5, 3.5]


def apply_thresholds(preds, thresholds):
    """Convert continuous predictions to integer grades using cut-points.

    Thresholds are applied in the order given: a prediction receives grade
    ``k`` (1-based position) for the last threshold it strictly exceeds, and
    grade 0 if it exceeds none.

    Args:
        preds: NumPy array of continuous predictions.
        thresholds: Iterable of cut-points, normally in ascending order.

    Returns:
        Integer array with the same shape as ``preds``.
    """
    grades = np.zeros_like(preds, dtype=int)
    for grade, cutoff in enumerate(thresholds, start=1):
        # Later thresholds overwrite earlier assignments, as in a sequential
        # masked update.
        grades = np.where(preds > cutoff, grade, grades)
    return grades

# Grade the same validation predictions with both sets of cut-points ...
preds_default, preds_optimized = (
    apply_thresholds(all_preds, cutoffs)
    for cutoffs in (default_thresholds, optimal_thresholds)
)

# ... and score each grading against the ground truth.
qwk_default, qwk_optimized = (
    quadratic_weighted_kappa(graded, all_labels)
    for graded in (preds_default, preds_optimized)
)

# Difference in QWK, scaled by 100 for display.
qwk_gain = (qwk_optimized - qwk_default) * 100

print(f"QWK with default thresholds: {qwk_default:.4f}")
print(f"QWK with optimized thresholds: {qwk_optimized:.4f}")
print(f"Improvement: {qwk_gain:.2f}%")

## 8. Model Evaluation

In [None]:
# Confusion Matrix
from sklearn.metrics import confusion_matrix, classification_report

# Tick labels for the five grades, in grade order (0..4).
grade_labels = ['No DR', 'Mild', 'Moderate', 'Severe', 'PDR']

# Pass `labels` explicitly so the matrix is always 5x5. Without it, a grade
# that appears in neither the labels nor the predictions is dropped and the
# tick labels no longer line up with the rows and columns.
cm = confusion_matrix(
    all_labels,
    preds_optimized,
    labels=list(range(len(grade_labels))),
)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=grade_labels,
    yticklabels=grade_labels,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title(f'Confusion Matrix (QWK: {qwk_optimized:.4f})')
fig.tight_layout()
plt.show()

In [None]:
# Classification Report
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'PDR']

# `labels` pins the report to all five grades. Without it, scikit-learn infers
# the classes from the data and `target_names` no longer matches whenever a
# grade is missing from both the labels and the predictions.
print(classification_report(
    all_labels,
    preds_optimized,
    labels=list(range(len(class_names))),
    target_names=class_names,
))

In [None]:
# Per-class metrics visualization
# `labels` keeps one row per grade even if a grade is absent from the data.
report = classification_report(
    all_labels,
    preds_optimized,
    labels=list(range(len(class_names))),
    target_names=class_names,
    output_dict=True,
)

# Select the per-class rows by name instead of slicing off a fixed number of
# summary rows, so the frame does not depend on how many summary rows exist.
metrics_df = pd.DataFrame(report).T.loc[class_names]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

metrics = ['precision', 'recall', 'f1-score']
colors = ['steelblue', 'darkorange', 'forestgreen']

for ax, metric, color in zip(axes, metrics, colors):
    metrics_df[metric].plot(kind='bar', ax=ax, color=color, edgecolor='black')
    ax.set_title(f'{metric.capitalize()} by Class')
    ax.set_ylabel(metric.capitalize())
    ax.set_xlabel('DR Severity')
    ax.set_ylim(0, 1)
    ax.tick_params(axis='x', rotation=45)

    # Add value labels just above each bar
    for i, v in enumerate(metrics_df[metric]):
        ax.text(i, v + 0.02, f'{v:.2f}', ha='center', fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
# Save optimized thresholds
import json

thresholds_path = '../checkpoints/thresholds.json'

# The checkpoint directory only exists once training has written to it;
# create it so this cell also works when no training run has happened yet.
Path(thresholds_path).parent.mkdir(parents=True, exist_ok=True)

# Normalise to plain Python floats so the thresholds serialise regardless of
# whether the optimizer returned an array, a list, or NumPy scalars.
thresholds_payload = {
    'thresholds': np.asarray(optimal_thresholds, dtype=float).tolist(),
    'qwk': float(qwk_optimized),
    'num_samples': len(all_labels)
}

with open(thresholds_path, 'w') as f:
    json.dump(thresholds_payload, f, indent=2)

print(f"Thresholds saved to {thresholds_path}")

## 9. Export Model for Deployment

In [None]:
# Export to ONNX
# Export on CPU and in eval mode so the traced graph reflects inference behaviour.
trained_model.eval()
trained_model = trained_model.cpu()

# One dummy RGB image at the configured resolution drives the export trace.
dummy_input = torch.randn(1, 3, cfg.model.input_size, cfg.model.input_size)
onnx_path = '../checkpoints/dr_model.onnx'

# The checkpoint directory only exists after a training run; create it so the
# export does not fail when the untrained fallback model is being used.
Path(onnx_path).parent.mkdir(parents=True, exist_ok=True)

torch.onnx.export(
    trained_model.model,  # export the bare network, not the Lightning wrapper
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    # Leave the batch dimension symbolic so the exported model accepts any batch size.
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

print(f"Model exported to {onnx_path}")

## Summary

In this notebook, we:
1. ✅ Loaded and configured the training pipeline
2. ✅ Verified data loading and augmentation
3. ✅ Built the model architecture (EfficientNet-B5 + Regression head)
4. ✅ Set up PyTorch Lightning training with callbacks
5. ✅ Optimized regression-to-class thresholds
6. ✅ Evaluated model performance (QWK, confusion matrix)
7. ✅ Exported model for deployment

### Next Steps
- Run the training by uncommenting `trainer.fit(lightning_model, datamodule)` in Section 6
- Fine-tune hyperparameters based on validation QWK
- Try cross-validation for more robust evaluation
- Explore Grad-CAM for model interpretability