# RETFound: Fine-Tuning and Evaluation for Diabetic Retinopathy

This notebook provides a complete, end-to-end workflow to reproduce the fine-tuning of the RETFound model on the IDRiD dataset for diabetic retinopathy classification.

# ==============================================================================
# ## 1. Environment Setup
#
# This section clones the original RETFound_MAE repository, installs all
# necessary dependencies, and applies essential patches to the scripts to ensure
# they run correctly in the Colab environment.
# ==============================================================================

In [None]:
import os

# --- Clone Repository & Set Working Directory ---
repo_dir = 'RETFound_MAE'

if not os.path.isdir(repo_dir):
  print("Cloning repository...")
  !git clone https://github.com/rmaphoh/RETFound_MAE/
else:
  print("Repository already exists. Skipping clone.")

%cd {repo_dir}

# --- Install Dependencies ---
print("\n⏳ Installing dependencies...")
!pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 -q
!pip install timm==0.9.16 pandas==2.2.2 scikit-learn -q
print("✅ Dependencies installed.")

# --- Apply Initial Patches ---
# Note: Further patches are applied programmatically in the next section.
print("\n⚙️ Patching script to bypass errors...")
!sed -i "s/if True:  # args.distributed:/if args.distributed:/g" main_finetune.py
print("✅ Script patched.")

print("\n🎉 Setup complete. Ready for the next step.")

# ==============================================================================
# ## 2. Generate Core Scripts
#
# We programmatically overwrite `main_finetune.py` and `engine_finetune.py`
# to incorporate all the fixes and improvements we developed. This includes
# robust metric calculation (AUROC, Specificity), error handling, and disabling
# conflicting libraries.
# ==============================================================================

In [None]:
# --- Create main_finetune.py ---
# The script below is written verbatim to `main_finetune.py` in the current
# working directory. It is kept in a raw string so backslashes survive as-is;
# it must therefore not contain a triple double-quote.
main_finetune_content = r"""
import argparse, datetime, json, numpy as np, os, time
from pathlib import Path
import torch, torch.backends.cudnn as cudnn
from timm.data.mixup import Mixup
import models_vit as models, util.lr_decay as lrd, util.misc as misc
from util.datasets import build_dataset
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download
from engine_finetune import train_one_epoch, evaluate
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

def get_args_parser():
    # Build the command-line parser shared by training and evaluation runs.
    parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--accum_iter', default=1, type=int)
    parser.add_argument('--model', default='vit_large_patch16', type=str)
    parser.add_argument('--input_size', default=256, type=int)
    parser.add_argument('--drop_path', type=float, default=0.2)
    parser.add_argument('--clip_grad', type=float, default=None)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--lr', type=float, default=None)
    parser.add_argument('--blr', type=float, default=5e-3)
    parser.add_argument('--layer_decay', type=float, default=0.65)
    parser.add_argument('--min_lr', type=float, default=1e-6)
    parser.add_argument('--warmup_epochs', type=int, default=10)
    # Augmentation settings. NOTE(review): upstream util.datasets.build_transform
    # appears to read these attributes for the 'train' split; without them the
    # training dataset cannot be built. Defaults follow the upstream parser -
    # TODO confirm against the cloned repository.
    parser.add_argument('--color_jitter', type=float, default=None)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    parser.add_argument('--finetune', default='', type=str)
    parser.add_argument('--task', default='', type=str)
    parser.add_argument('--global_pool', action='store_true', default=True)
    parser.add_argument('--data_path', default='./data/', type=str)
    parser.add_argument('--nb_classes', default=5, type=int)
    parser.add_argument('--output_dir', default='./output_dir')
    parser.add_argument('--log_dir', default='./output_logs')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='')
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=2, type=int)
    parser.add_argument('--pin_mem', action='store_true', default=True)
    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://')
    return parser

def _resize_pos_embed(pos_embed, target_len):
    # Bicubically resample a (1, tokens, dim) position-embedding table so that
    # it has `target_len` tokens. The leading 1 (or 0) tokens are treated as
    # non-spatial (e.g. a class token) and copied unchanged; the remainder must
    # form a square grid. Returns None when no such layout fits.
    for n_extra in (1, 0):
        old_tokens = pos_embed.shape[1] - n_extra
        new_tokens = target_len - n_extra
        old_side = round(old_tokens ** 0.5)
        new_side = round(new_tokens ** 0.5)
        if old_side > 0 and new_side > 0 and old_side ** 2 == old_tokens and new_side ** 2 == new_tokens:
            extra = pos_embed[:, :n_extra].float()
            grid = pos_embed[:, n_extra:].float()
            grid = grid.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)
            grid = torch.nn.functional.interpolate(grid, size=(new_side, new_side), mode='bicubic', align_corners=False)
            grid = grid.permute(0, 2, 3, 1).reshape(1, new_side * new_side, -1)
            return torch.cat((extra, grid), dim=1)
    return None

def _adapt_state_dict(model, state_dict):
    # Return a copy of `state_dict` that can be loaded into `model`.
    # load_state_dict(strict=False) tolerates missing/unexpected keys but still
    # raises on a shape mismatch, so position embeddings are resampled to the
    # model's token count and any other mismatched tensor is dropped (and
    # reported) so that the model keeps its own initialisation for it.
    reference = model.state_dict()
    adapted = dict(state_dict)
    src, dst = adapted.get('pos_embed'), reference.get('pos_embed')
    if (src is not None and dst is not None and src.shape != dst.shape
            and src.dim() == 3 and dst.dim() == 3 and src.shape[-1] == dst.shape[-1]):
        resized = _resize_pos_embed(src, dst.shape[1])
        if resized is not None:
            print(f"Interpolating pos_embed from {src.shape[1]} to {dst.shape[1]} tokens")
            adapted['pos_embed'] = resized
    for key in list(adapted):
        if key in reference and adapted[key].shape != reference[key].shape:
            print(f"Skipping '{key}': checkpoint shape {tuple(adapted[key].shape)} != model shape {tuple(reference[key].shape)}")
            del adapted[key]
    return adapted

def main(args):
    # Train (default) or evaluate (--eval) a classifier on the dataset at
    # args.data_path. Only the single-process path is implemented.
    misc.init_distributed_mode(args)
    device = torch.device(args.device)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    dataset_test = build_dataset(is_train='test', args=args)
    if not args.eval:
        dataset_train = build_dataset(is_train='train', args=args)
        dataset_val = build_dataset(is_train='val', args=args)
    else:
        dataset_train, dataset_val = None, None

    if args.distributed:
        # The distributed samplers were removed from this Colab version; fail
        # here with a clear message instead of a NameError further down.
        raise NotImplementedError("Distributed training is not supported by this Colab version of main_finetune.py")
    sampler_test = torch.utils.data.SequentialSampler(dataset_test)
    if not args.eval:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    log_writer = None # Disabled for Colab

    data_loader_test = torch.utils.data.DataLoader(dataset_test, sampler=sampler_test, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
    if not args.eval:
        data_loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True)
        data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)

    model = models.__dict__[args.model](img_size=args.input_size, num_classes=args.nb_classes, drop_path_rate=args.drop_path, global_pool=args.global_pool)

    # Checkpoints are pickles: torch.load can execute code from the file, so
    # only load checkpoints from sources you trust. weights_only=False is
    # spelled out so the behaviour does not depend on the installed default.
    if args.finetune and not args.eval:
        print(f"Downloading pre-trained weights from Hugging Face: {args.finetune}")
        checkpoint_path = hf_hub_download(repo_id=f'YukunZhou/{args.finetune}', filename=f'{args.finetune}.pth')
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)['model']
        checkpoint = _adapt_state_dict(model, checkpoint)
        msg = model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded pre-trained checkpoint from {args.finetune} with message: {msg}")

    if args.resume:
        print(f"Resuming from checkpoint: {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)['model']
        model.load_state_dict(checkpoint, strict=False)

    model.to(device)
    print(f'Number of model params (M): {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1.e6:.2f}')

    if args.eval:
        evaluate(data_loader_test, model, device, args, 0, 'test', args.nb_classes, log_writer)
        return

    # Learning rate follows the linear scaling rule unless --lr is given.
    eff_batch_size = args.batch_size * misc.get_world_size()
    if args.lr is None: args.lr = args.blr * eff_batch_size / 256
    print(f"Actual lr: {args.lr:.2e}")
    param_groups = lrd.param_groups_lrd(model, args.weight_decay, no_weight_decay_list=model.no_weight_decay(), layer_decay=args.layer_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
    loss_scaler = NativeScaler()
    criterion = torch.nn.CrossEntropyLoss()

    print(f"--- Starting Training for {args.epochs} epochs ---")
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args=args)
        val_stats, _ = evaluate(data_loader_val, model, device, args, epoch, 'val', args.nb_classes, log_writer=log_writer)
        print(f"EPOCH:{epoch} | Val Acc: {val_stats['acc1']:.1f}%")
        # Keep only the checkpoint with the best validation accuracy so far.
        if max_accuracy < val_stats["acc1"]:
            max_accuracy = val_stats["acc1"]
            misc.save_model(args=args, model=model, model_without_ddp=model, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode='best')
        print(f'Max accuracy: {max_accuracy:.2f}%')

if __name__ == '__main__':
    args = get_args_parser().parse_args()
    if args.output_dir: Path(os.path.join(args.output_dir, args.task)).mkdir(parents=True, exist_ok=True)
    main(args)
"""
with open("main_finetune.py", "w") as f:
    f.write(main_finetune_content)
print("✅ `main_finetune.py` generated.")


# --- Create engine_finetune.py ---
# The script below is written verbatim to `engine_finetune.py` in the current
# working directory. It is kept in a raw string so the `\n` escapes reach the
# file unchanged; it must therefore not contain a triple double-quote.
engine_finetune_content = r"""
import math, sys
from typing import Iterable, Optional
import torch, torch.nn.functional as F
import numpy as np
from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix
from timm.utils import accuracy
import util.misc as misc, util.lr_sched as lr_sched

def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, args=None):
    # Run one pass over `data_loader` under mixed precision and return the
    # epoch-averaged meters (loss, lr) as a dict.
    # NOTE: the --accum_iter and --clip_grad options defined by the parser in
    # main_finetune.py are not applied here; every batch triggers one update.
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f'Epoch: [{epoch}]'
    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, 20, header)):
        # The schedule is advanced per iteration using a fractional epoch.
        lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            loss = criterion(model(samples), targets)
        loss_value = loss.item()
        if not math.isfinite(loss_value): sys.exit(f"Loss is {loss_value}, stopping training")
        loss_scaler(loss, optimizer, parameters=model.parameters())
        optimizer.zero_grad()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer=None):
    # Evaluate `model` on `data_loader`, print accuracy, macro AUROC, macro
    # specificity and a per-class report, and return
    # ({meter name: global average}, 'results').
    # `args`, `epoch`, `mode` and `log_writer` are accepted for interface
    # compatibility with the caller but are not used.
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    for batch in metric_logger.log_every(data_loader, 10, header):
        images, target = batch[0].to(device, non_blocking=True), batch[-1].to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)
        preds = torch.argmax(output, dim=1)
        # Under autocast the logits may be half precision; take the softmax in
        # float32 so each row sums to 1 within the tolerance roc_auc_score
        # requires for multiclass scores.
        probs = F.softmax(output.float(), dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(target.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(accuracy(output, target, topk=(1,))[0].item(), n=images.shape[0])
    print(f'* Acc@1 {metric_logger.acc1.global_avg:.3f} loss {metric_logger.loss.global_avg:.3f}')
    all_labels, all_preds, all_probs = np.array(all_labels), np.array(all_preds), np.array(all_probs)
    print("\n--- Performance Metrics ---")
    if len(np.unique(all_labels)) > 1 and len(np.unique(all_preds)) > 1:
        try:
            auroc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')
            print(f"AUROC (Macro): {auroc:.4f}")
        except Exception as e: print(f"Could not calculate AUROC: {e}")
    else: print("Skipping AUROC: not enough classes in labels or predictions.")
    # One-vs-rest counts per class from the confusion matrix
    # (rows = true labels, columns = predictions).
    cm = confusion_matrix(all_labels, all_preds)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)
    print(f"Specificity (Macro): {np.mean(tn / (tn + fp)):.4f}")
    print("\n--- Classification Report ---")
    # `labels` pins the report to all `num_class` classes so that target_names
    # always matches, even when a class is absent from this split.
    class_ids = list(range(num_class))
    print(classification_report(all_labels, all_preds, labels=class_ids, target_names=[f'Class {i}' for i in class_ids], digits=4, zero_division=0))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, 'results'
"""
with open("engine_finetune.py", "w") as f:
    f.write(engine_finetune_content)
print("✅ `engine_finetune.py` generated.")

# ==============================================================================
# ## 3. Data Preparation
#
# Download and unzip the IDRiD dataset for 5-class DR grading.
# ==============================================================================

In [None]:
data_folder_path = '../IDRiD_data'

if not os.path.isdir(data_folder_path):
    print("⏳ Dataset folder not found. Downloading and unzipping...")
    !gdown --id 1c6zexA705z-ANEBNXJOBsk6uCvRnzmr3 -O ../IDRiD_data.zip
    !unzip -q -o ../IDRiD_data.zip -d ../
    print("✅ Dataset downloaded and ready.")
else:
    print(f"✅ Dataset already exists at '{data_folder_path}'. Skipping download.")

# ==============================================================================
# ## 4. Workflow A: Fine-Tune and Evaluate
#
# This is the primary workflow. It involves authenticating with Hugging Face,
# running the fine-tuning process on the IDRiD dataset, and then evaluating
# the best model produced during that run.
# ==============================================================================

In [None]:
# --- Step 4.1: Authenticate with Hugging Face ---
# Opens the interactive login prompt; the fine-tuning step downloads the base
# weights from the Hugging Face Hub.
import huggingface_hub

print("🚀 Please log in to Hugging Face to download the base model.")
huggingface_hub.notebook_login()

In [None]:
# --- Step 4.2: Run Fine-Tuning ---
import torch
from pathlib import Path
from main_finetune import get_args_parser, main as finetune_main

print("\n🚀 Starting fine-tuning...")

# Options are passed as a command-line style list so the generated script's
# own parser supplies every default that is not overridden here.
finetune_argv = [
    '--model', 'RETFound_mae',
    '--epochs', '20',
    '--blr', '5e-3',
    '--data_path', '../IDRiD_data',
    '--task', 'RETFound_finetune_IDRiD',
    '--output_dir', './output_dir',
    '--log_dir', './log_dir',
    '--finetune', 'RETFound_mae_meh',
]
ft_parser = get_args_parser()
args_ft = ft_parser.parse_args(finetune_argv)

# Create the per-task output folder before training starts.
task_output_dir = Path(args_ft.output_dir) / args_ft.task
task_output_dir.mkdir(parents=True, exist_ok=True)

finetune_main(args_ft)
print("\n✅ Fine-tuning finished.")

In [None]:
# --- Step 4.3: Evaluate the Fine-Tuned Model ---
# Requires `args_ft`, `get_args_parser` and `finetune_main` from Step 4.2.
print("\n" + "="*50)
print("\n🚀 Starting inference on the BEST model from fine-tuning...")

# Derive the checkpoint location from the arguments used for training so the
# two cells cannot drift apart if the output directory is changed.
best_model_path = os.path.join(args_ft.output_dir, args_ft.task, 'checkpoint-best.pth')

if not os.path.exists(best_model_path):
    print(f"❌ ERROR: The expected best model was not found at {best_model_path}")
else:
    print(f"Found best model at: {best_model_path}")
    eval_parser = get_args_parser()
    args_eval_ft = eval_parser.parse_args([
        '--model', 'RETFound_mae',
        '--eval',
        '--data_path', '../IDRiD_data',
        '--resume', best_model_path
    ])
    finetune_main(args_eval_ft)
    # Report success only when an evaluation actually ran.
    print("\n✅ Fine-tuned model evaluation finished.")

# ==============================================================================
# ## 5. Workflow B: Evaluate a Pre-Trained Classifier
#
# This workflow allows you to skip fine-tuning and directly evaluate a
# model that has already been trained for this specific task.
# ==============================================================================

In [None]:
# --- Step 5.1: Download Pre-Trained Classifier ---
checkpoint_path = './RETFound_IDRiD_Classifier.pth'
if not os.path.isfile(checkpoint_path):
    print("⏳ Model checkpoint not found. Downloading from Google Drive...")
    # Note: Using the gdown ID for the classifier model
    !gdown --id 1b0grTwARX1cXnYnMB3ZJZES26aMkgkvZ -O {checkpoint_path}
    print("✅ Model download complete.")
else:
    print(f"✅ Pre-trained classifier already exists at '{checkpoint_path}'.")

In [None]:
# --- Step 5.2: Run Inference ---
# Imported here as well as in Step 4.2 so that Workflow B runs on its own when
# the fine-tuning workflow is skipped.
from main_finetune import get_args_parser, main as finetune_main

print("\n" + "="*50)
print("\n🚀 Starting inference on the PRE-TRAINED classifier...")
eval_parser_pretrained = get_args_parser()
# `checkpoint_path` comes from Step 5.1.
args_eval_pretrained = eval_parser_pretrained.parse_args([
    '--model', 'RETFound_mae',
    '--eval',
    '--data_path', '../IDRiD_data',
    '--resume', checkpoint_path
])
finetune_main(args_eval_pretrained)
print("\n✅ Pre-trained classifier evaluation finished.")
print("\n✅ Pre-trained classifier evaluation finished.")