# Fundus Transfer‑Learning Pipeline

This notebook‑style script:
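1. Installs & imports dependencies and sets a global random seed
2. Clones RETFound_MAE & adds it to PYTHONPATH
3. Downloads pretrained weights
4. Detects your fundus dataset structure
5. Instantiates & loads the pretrained ViT
6. Inserts adapters for method=`'adapters'`
7. Builds the transforms & DataLoaders
8. Defines four transfer‑learning modes
9. Defines the single‑experiment runner (`run_single`), which logs each result to CSV
10. Runs one experiment per transfer‑learning mode
11. Plots the logged validation accuracy per method
12. Prints a classification report & plots a confusion matrix for a chosen method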

In [1]:
# 1) Install & import dependencies
import importlib
import importlib.util
import os, sys, random
import subprocess

# Map each third-party import name to the pip distribution that provides it.
# Every dependency is checked *before* any of them is imported, so a missing
# torch/numpy is installed too instead of failing ahead of the fallback.
_REQUIREMENTS = {
    "torch": "torch",
    "torchvision": "torchvision",
    "timm": "timm",
    "gdown": "gdown",
    "pandas": "pandas",
    "numpy": "numpy",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
    "sklearn": "scikit-learn",
    "tqdm": "tqdm",
}
_missing = [
    pkg for mod, pkg in _REQUIREMENTS.items()
    if importlib.util.find_spec(mod) is None
]
if _missing:
    # Install into the interpreter running this kernel; check_call raises if
    # pip fails so the error is not hidden behind a later ImportError.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--quiet", *_missing]
    )
    importlib.invalidate_caches()

import numpy as np
import torch
import timm, gdown, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, classification_report

# Seed every RNG the pipeline touches.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# cuDNN auto-tuning speeds up fixed-size inputs; per the PyTorch
# reproducibility notes it can pick non-deterministic kernels, so GPU runs
# may not be bit-for-bit repeatable despite the seed above.
torch.backends.cudnn.benchmark = True

In [2]:
# 2) Clone RETFound_MAE & fix PYTHONPATH
def ensure_repo(name, url):
    """Clone a git repository unless its target directory already exists.

    Args:
        name: Directory to clone into (also used to detect an existing clone).
        url:  Git URL to clone from.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with a non-zero
            status.
        FileNotFoundError: if the ``git`` executable cannot be found.
    """
    import subprocess  # local import keeps this helper self-contained

    if os.path.isdir(name):
        return
    print(f"Cloning {name} from {url}…")
    # Argument list (no shell) so names/URLs containing spaces or shell
    # metacharacters are passed through verbatim; check=True surfaces failures.
    subprocess.run(["git", "clone", url, name], check=True)

ensure_repo("RETFound_MAE", "https://huggingface.co/open-eye/RETFound_MAE")
repo_path = os.path.join(os.getcwd(), "RETFound_MAE")
if not os.path.isdir(repo_path):
    raise FileNotFoundError(f"RETFound_MAE not found at {repo_path}")
# Drop any earlier entry first so re-running this cell keeps exactly one copy
# of the path, still at the front of the import search order.
sys.path[:] = [entry for entry in sys.path if entry != repo_path]
sys.path.insert(0, repo_path)
print("✔ Added RETFound_MAE to PYTHONPATH")

✔ Added RETFound_MAE to PYTHONPATH


In [3]:
# 3) Download pretrained fundus weights
WEIGHTS = "RETFound_cfp_weights.pth"
# Google Drive file id of the pretrained colour-fundus checkpoint.
WEIGHTS_FILE_ID = "1l62zbWUFTlp214SvK6eMwPQZAzcwoeBE"
if not os.path.isfile(WEIGHTS):
    print("Downloading pretrained fundus weights…")
    # Use the gdown Python API (already imported) rather than shelling out
    # with the deprecated `--id` flag.
    gdown.download(
        f"https://drive.google.com/uc?id={WEIGHTS_FILE_ID}",
        WEIGHTS,
        quiet=True,
    )
# Fail here, with a clear message, rather than later inside torch.load.
if not os.path.isfile(WEIGHTS):
    raise FileNotFoundError(
        f"Download failed: {WEIGHTS} is missing. "
        f"Fetch Google Drive file {WEIGHTS_FILE_ID} manually and place it here."
    )
print("✔ Model weights ready:", WEIGHTS)

✔ Model weights ready: RETFound_cfp_weights.pth


In [4]:
# 4) Detect dataset folder (handles nested color_fundus_eye/color_fundus_eye)
import os

# top‐level folder name (change if yours differs)
base = "color_fundus_eye"

# Look for either color_fundus_eye/train or color_fundus_eye/color_fundus_eye/train
if os.path.isdir(os.path.join(base, "train")):
    DATA_DIR = base
elif os.path.isdir(os.path.join(base, base, "train")):
    DATA_DIR = os.path.join(base, base)
else:
    raise FileNotFoundError(f"Cannot find 'train' under {base} or {base}/{base}")

train_dir = os.path.join(DATA_DIR, "train")
test_dir  = os.path.join(DATA_DIR, "test")
# Fail early: os.walk on a missing folder silently yields nothing, which
# would otherwise be reported as "0 images".
if not os.path.isdir(test_dir):
    raise FileNotFoundError(f"Cannot find 'test' under {DATA_DIR}")

# Auto‐discover your classes from the folders.
# Only sub-directories count as classes (the same rule ImageFolder applies),
# so stray files such as .DS_Store cannot inflate num_classes.
classes     = sorted(entry.name for entry in os.scandir(train_dir) if entry.is_dir())
num_classes = len(classes)
if num_classes == 0:
    raise FileNotFoundError(f"No class sub-folders found in {train_dir}")

# Quick sanity‐print (counts every file under each split, recursively)
n_train = sum(len(files) for _,_,files in os.walk(train_dir))
n_test  = sum(len(files) for _,_,files in os.walk(test_dir))
print(f"DATA_DIR    = {DATA_DIR}")
print(f"Detected {num_classes} classes: {classes}")
print(f"Train dir   = {train_dir} ({n_train} images)")
print(f"Test dir    = {test_dir} ({n_test} images)")

DATA_DIR    = color_fundus_eye/color_fundus_eye
Detected 10 classes: ['Central Serous Chorioretinopathy [Color Fundus]', 'Diabetic Retinopathy', 'Disc Edema', 'Glaucoma', 'Healthy', 'Macular Scar', 'Myopia', 'Pterygium', 'Retinal Detachment', 'Retinitis Pigmentosa']
Train dir   = color_fundus_eye/color_fundus_eye/train (12989 images)
Test dir    = color_fundus_eye/color_fundus_eye/test (3253 images)


In [5]:
# 5) Instantiate & load pretrained ViT
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.layers import trunc_normal_

# Load the downloaded checkpoint onto the CPU.
# NOTE(review): weights_only=False allows full unpickling of the file, which
# can execute arbitrary code — only use with a checkpoint you trust.
ckpt = torch.load(WEIGHTS, map_location="cpu", weights_only=False)
# The pretrained parameters live under the 'model' key of the checkpoint.
ckpt_model = ckpt['model']

# Build the classifier with one output per detected dataset class.
model = models_vit.__dict__['vit_large_patch16'](
    num_classes=num_classes,
    drop_path_rate=0.2
)
# prune mismatched head & load
# (the deletion mutates ckpt_model in place, so later cells that reuse
# ckpt_model see it without those keys)
state = model.state_dict()
for k in ['head.weight','head.bias']:
    if k in ckpt_model and ckpt_model[k].shape != state[k].shape:
        del ckpt_model[k]
# Presumably adapts the checkpoint's position embedding to this model's
# patch grid — confirm against util/pos_embed.py in the cloned repo.
interpolate_pos_embed(model, ckpt_model)
# strict=False tolerates keys that are missing from the checkpoint
# (e.g. the classification head).
msg = model.load_state_dict(ckpt_model, strict=False)
print("Loaded pretrained ViT, missing keys:", msg.missing_keys)
# If the head was not loaded from the checkpoint, re-initialise its weight
# with a small truncated-normal (std=2e-5).
if 'head.weight' in msg.missing_keys:
    trunc_normal_(model.head.weight, std=2e-5)

Loaded pretrained ViT, missing keys: ['head.weight', 'head.bias']


In [6]:
# 6) Insert adapters (for method='adapters')
from torch import nn
D = model.embed_dim
adapter_dim = D // 4  # 1024→256 by default
# Bottleneck adapter: project down, non-linearity, dropout, project back up.
model.adapter = nn.Sequential(
    nn.Linear(D, adapter_dim),
    nn.ReLU(inplace=True),
    nn.Dropout(0.1),
    nn.Linear(adapter_dim, D),
)
# Remember the un-adapted forward_features on the model itself.  Without
# this, re-running the cell would wrap the already wrapped function and the
# adapter would be applied twice (three times, …) per forward pass.
_orig_forward = getattr(model, "_forward_features_without_adapter", None)
if _orig_forward is None:
    _orig_forward = model.forward_features
    model._forward_features_without_adapter = _orig_forward
def forward_with_adapter(x):
    """Run the original feature extractor and add the adapter output residually."""
    feat = _orig_forward(x)
    return feat + model.adapter(feat)
model.forward_features = forward_with_adapter
print(f"✅ Adapter inserted: {list(model.adapter)}")
model.eval()

✅ Adapter inserted: [Linear(in_features=1024, out_features=256, bias=True), ReLU(inplace=True), Dropout(p=0.1, inplace=False), Linear(in_features=256, out_features=1024, bias=True)]


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Id

In [7]:
# 7) DataLoaders & Transforms
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

# ── a) Transforms ───────────────────────────────────────────────────────────
# Training: light augmentation, then resize + ImageNet mean/std normalisation.
train_tf = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.ColorJitter(0.2,0.2,0.1,0.05),
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])
# Evaluation: deterministic resize + the same normalisation, no augmentation.
val_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

# ── b) Load full TRAIN folder ────────────────────────────────────────────────
full_train = ImageFolder(train_dir, transform=train_tf)
print(f"Full TRAIN set size: {len(full_train)} images")

# ── c) Optionally sub‑sample exactly SUBSET_SIZE images ──────────────────────
SUBSET_SIZE = None   # ← set to None or len(full_train) to use all images
if SUBSET_SIZE and SUBSET_SIZE < len(full_train):
    idxs    = list(range(len(full_train)))
    labels  = list(full_train.targets)
    # Stratified draw: the subset keeps the class proportions of the full
    # set; it does NOT equalise the class counts.
    sub_idxs, _ = train_test_split(
        idxs,
        train_size=SUBSET_SIZE,
        random_state=SEED,
        stratify=labels
    )
    train_ds = Subset(full_train, sub_idxs)
    print(f"Subsampled TRAIN → {len(train_ds)} images (stratified).")
else:
    train_ds = full_train

# ── d) Load TEST folder as your validation set ──────────────────────────────
# NOTE: the test split doubles as the validation set here, so the reported
# accuracies are not a held-out estimate if they are used to pick a model.
val_ds = ImageFolder(test_dir, transform=val_tf)
print(f"Full TEST set size:  {len(val_ds)} images")

# ── e) Build DataLoaders ───────────────────────────────────────────────────
batch_size = 16
# Pinned host memory only helps host→GPU copies; skip it on CPU-only machines.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True,
    num_workers=4, pin_memory=pin_memory
)
val_loader   = DataLoader(
    val_ds,   batch_size=batch_size, shuffle=False,
    num_workers=4, pin_memory=pin_memory
)
print(f"DataLoaders → Train: {len(train_ds)},  Val: {len(val_ds)}")

Full TRAIN set size: 12989 images
Full TEST set size:  3253 images
DataLoaders → Train: 12989,  Val: 3253


In [8]:
# 8) Transfer‐learning config helper
def apply_transfer_config(model, method, partial_blocks=4):
    """Set ``requires_grad`` on ``model``'s parameters for one transfer-learning mode.

    All parameters are frozen first, then a subset is unfrozen:

    * ``'linear_probe'`` – only ``model.head``.
    * ``'partial_ft'``   – the last ``partial_blocks`` entries of
      ``model.blocks`` plus ``model.head``.
    * ``'full_ft'``      – every parameter.
    * ``'adapters'``     – parameters whose name starts with ``adapter`` or
      ``head``.

    Args:
        model: Module exposing ``head`` (and ``blocks`` for ``'partial_ft'``).
        method: One of the four mode names above.
        partial_blocks: Number of trailing blocks to unfreeze for
            ``'partial_ft'``; ``0`` unfreezes no block, values larger than
            ``len(model.blocks)`` unfreeze all of them.

    Raises:
        ValueError: for an unknown ``method`` or (with ``'partial_ft'``) a
            negative ``partial_blocks``.  The model is left untouched.
    """
    # Validate before touching the model so a typo cannot leave it fully frozen.
    if method not in ('linear_probe', 'partial_ft', 'full_ft', 'adapters'):
        raise ValueError(f"Unknown method '{method}'")
    if method == 'partial_ft' and partial_blocks < 0:
        raise ValueError(f"partial_blocks must be >= 0, got {partial_blocks}")

    # freeze all
    for p in model.parameters():
        p.requires_grad = False

    # unfreeze per method
    if method == 'linear_probe':
        for p in model.head.parameters():
            p.requires_grad = True
    elif method == 'partial_ft':
        blocks = list(model.blocks)
        # Explicit start index: a plain blocks[-partial_blocks:] would select
        # *every* block when partial_blocks == 0, because -0 == 0.
        first_trainable = max(len(blocks) - partial_blocks, 0)
        for blk in blocks[first_trainable:]:
            for p in blk.parameters():
                p.requires_grad = True
        for p in model.head.parameters():
            p.requires_grad = True
    elif method == 'full_ft':
        for p in model.parameters():
            p.requires_grad = True
    else:  # 'adapters'
        for n, p in model.named_parameters():
            if n.startswith('adapter') or n.startswith('head'):
                p.requires_grad = True

In [9]:
# 9) Single‐run experiment & CSV logging
from datetime import datetime
import csv, time

# Path of the CSV log; run_single opens it in append mode and adds one row
# per experiment.
CSV = "fundus_transfer_experiments.csv"
# Column order of each logged row; must match the keys of the row dict
# built in run_single.
FIELDNAMES = ['exp_no','method','val_acc','train_loss','time_s','timestamp']

def run_single(exp_no, method, epochs=5, lr=3e-5, weight_decay=1e-4):
    """Fine-tune a fresh pretrained ViT with one transfer-learning method and log it.

    Builds a new ``vit_large_patch16`` from the notebook-level checkpoint
    (``ckpt_model``), configures which parameters train via
    ``apply_transfer_config``, trains on ``train_loader``, evaluates on
    ``val_loader`` after every epoch, and appends one summary row to ``CSV``.

    Relies on notebook globals defined in other cells: ``models_vit``,
    ``interpolate_pos_embed``, ``ckpt_model``, ``num_classes``, ``model``
    (adapter template), ``device``, ``train_loader``, ``val_loader``, ``nn``,
    ``tqdm``, ``CSV`` and ``FIELDNAMES``.

    Args:
        exp_no: Experiment identifier written to the log.
        method: ``'linear_probe'``, ``'partial_ft'``, ``'full_ft'`` or
            ``'adapters'``.
        epochs: Number of training epochs (default 5).
        lr: Adam learning rate (default 3e-5).
        weight_decay: Adam weight decay (default 1e-4).

    Returns:
        ``(model, history)`` where ``history`` holds the per-epoch lists
        ``'loss'`` and ``'val_acc'``.

    Raises:
        ValueError: if ``epochs < 1`` or ``method`` is unknown.
    """
    import copy  # local import: only needed to clone the adapter template

    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # 1) Reload the bare ViT and pretrained weights
    tmp = models_vit.__dict__['vit_large_patch16'](
        num_classes=num_classes, drop_path_rate=0.2
    )
    interpolate_pos_embed(tmp, ckpt_model)
    known_keys = tmp.state_dict().keys()  # built once, not once per checkpoint key
    tmp.load_state_dict(
        {k: v for k, v in ckpt_model.items() if k in known_keys},
        strict=False
    )

    # 2) If using adapters, give this run its OWN copy of the adapter so that
    #    training never mutates the global template and repeated runs start
    #    from the same adapter weights.
    if method == 'adapters':
        tmp.adapter = copy.deepcopy(model.adapter)
        orig_forward = tmp.forward_features
        def forward_with_tmp_adapter(x):
            feat = orig_forward(x)
            return feat + tmp.adapter(feat)
        tmp.forward_features = forward_with_tmp_adapter

    # 3) Freeze / unfreeze per transfer‐learning method
    apply_transfer_config(tmp, method)

    # 4) Move everything (backbone + adapter) onto device
    tmp = tmp.to(device)

    print(f"\n▶ Starting experiment #{exp_no} — method={method}")
    # Only the unfrozen parameters are handed to the optimizer.
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, tmp.parameters()),
        lr=lr,
        weight_decay=weight_decay
    )
    crit = nn.CrossEntropyLoss()

    # 5) Time‑stamp start
    start_time = time.time()

    history = {'loss': [], 'val_acc': []}
    for epoch in range(1, epochs + 1):
        # — Training epoch —
        tmp.train()
        running_loss = 0.0
        for imgs, labels in tqdm(train_loader,
                                 desc=f"[{method}] Epoch {epoch}/{epochs} (train)",
                                 leave=False):
            imgs   = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            loss = crit(tmp(imgs), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is per-sample.
            running_loss += loss.item() * imgs.size(0)
        avg_loss = running_loss / len(train_loader.dataset)
        history['loss'].append(avg_loss)
        print(f"  Epoch {epoch}/{epochs} ▶ Train Loss: {avg_loss:.4f}")

        # — Validation epoch —
        tmp.eval()
        correct = total = 0
        # No autograd graph is needed for evaluation; skipping it saves
        # memory and time on a ViT-Large.
        with torch.no_grad():
            for imgs, labels in tqdm(val_loader,
                                     desc=f"[{method}] Epoch {epoch}/{epochs} (val)",
                                     leave=False):
                imgs   = imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                preds = tmp(imgs).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / total
        history['val_acc'].append(val_acc)
        print(f"  Epoch {epoch}/{epochs} ▶ Val Acc: {val_acc:.4%}")

    # 6) Measure elapsed time
    elapsed = time.time() - start_time

    # 7) Build and append the CSV row
    row = {
        'exp_no':     exp_no,
        'method':     method,
        'val_acc':    history['val_acc'][-1],
        'train_loss': history['loss'][-1],
        'time_s':     round(elapsed, 2),
        'timestamp':  datetime.now().strftime('%Y%m%d_%H%M%S'),
    }
    # A brand-new (or empty) log needs a header row; otherwise pd.read_csv
    # would consume the first experiment as column names.
    needs_header = not os.path.isfile(CSV) or os.path.getsize(CSV) == 0
    with open(CSV, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
        if needs_header:
            writer.writeheader()
        writer.writerow(row)
    print(f">>> Logged exp#{exp_no} [{method}] → acc={row['val_acc']:.4%}")

    # 8) Return the fine‑tuned model and its history
    return tmp, history



In [None]:
# 10) Run your experiment by changing exp_no & method here:
import torch.nn as nn

# Train on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One experiment per transfer-learning mode, numbered from 1 in this order.
methods = ['linear_probe','partial_ft','full_ft','adapters']

# Keep every fine-tuned model and its per-epoch history, keyed by method
# name, so later cells can inspect any of them.
all_models, all_histories = {}, {}

for exp_no, method in enumerate(methods, start=1):
    print(f"\n=== Experiment #{exp_no}: {method} ===")
    model_ft, history = run_single(exp_no=exp_no, method=method)
    all_models[method], all_histories[method] = model_ft, history

    # Summarise the run by its last-epoch metrics.
    final_acc, final_loss = history['val_acc'][-1], history['loss'][-1]
    print(f">>> {method}: val_acc={final_acc:.4%}, loss={final_loss:.4f}")

print("\n✅ All 4 experiments complete!  Check fundus_transfer_experiments.csv for results.")


=== Experiment #1: linear_probe ===

▶ Starting experiment #1 — method=linear_probe


                                                                                   

  Epoch 1/5 ▶ Train Loss: 2.0649


                                                                                 

  Epoch 1/5 ▶ Val Acc: 27.7897%


                                                                                   

  Epoch 2/5 ▶ Train Loss: 2.0079


                                                                                 

  Epoch 2/5 ▶ Val Acc: 28.4968%


                                                                                   

  Epoch 3/5 ▶ Train Loss: 1.9855


                                                                                 

  Epoch 3/5 ▶ Val Acc: 30.8331%


                                                                                   

  Epoch 4/5 ▶ Train Loss: 1.9683


                                                                                 

  Epoch 4/5 ▶ Val Acc: 30.7716%


                                                                                   

  Epoch 5/5 ▶ Train Loss: 1.9570


                                                                                 

  Epoch 5/5 ▶ Val Acc: 30.6794%
>>> Logged exp#1 [linear_probe] → acc=30.6794%
>>> linear_probe: val_acc=30.6794%, loss=1.9570

=== Experiment #2: partial_ft ===

▶ Starting experiment #2 — method=partial_ft


                                                                                 

  Epoch 1/5 ▶ Train Loss: 1.6113


                                                                               

  Epoch 1/5 ▶ Val Acc: 66.2465%


                                                                                 

  Epoch 2/5 ▶ Train Loss: 0.9940


                                                                               

  Epoch 2/5 ▶ Val Acc: 76.0221%


                                                                                 

  Epoch 3/5 ▶ Train Loss: 0.8063


                                                                               

  Epoch 3/5 ▶ Val Acc: 78.6658%


                                                                                 

  Epoch 4/5 ▶ Train Loss: 0.7245


                                                                               

  Epoch 4/5 ▶ Val Acc: 79.8955%


                                                                                 

  Epoch 5/5 ▶ Train Loss: 0.6581


                                                                               

  Epoch 5/5 ▶ Val Acc: 81.4940%
>>> Logged exp#2 [partial_ft] → acc=81.4940%
>>> partial_ft: val_acc=81.4940%, loss=0.6581

=== Experiment #3: full_ft ===

▶ Starting experiment #3 — method=full_ft


                                                                              

  Epoch 1/5 ▶ Train Loss: 1.0971


                                                                            

  Epoch 1/5 ▶ Val Acc: 76.9751%


                                                                              

  Epoch 2/5 ▶ Train Loss: 0.6025


                                                                            

  Epoch 2/5 ▶ Val Acc: 80.1107%


                                                                              

  Epoch 3/5 ▶ Train Loss: 0.4898


                                                                            

  Epoch 3/5 ▶ Val Acc: 82.0166%


[full_ft] Epoch 4/5 (train):  73%|███████▎  | 593/812 [08:56<02:11,  1.66it/s]

In [None]:
# 11) Compare the logged final validation accuracy of each method
import pandas as pd

# The log is appended row by row and may or may not begin with a header
# line.  Read it positionally with the known column names, then discard any
# header line(s) so both layouts load correctly.
df = pd.read_csv(CSV, header=None, names=FIELDNAMES)
df = df[df['exp_no'].astype(str) != 'exp_no'].copy()
df['exp_no']  = pd.to_numeric(df['exp_no'])
df['val_acc'] = pd.to_numeric(df['val_acc'])

plt.figure(figsize=(8,4))
for method in methods:
    sub = df[df['method'] == method]
    plt.plot(sub['exp_no'], sub['val_acc'], marker='o', label=method)
plt.xlabel('Experiment #')
plt.ylabel('Validation Accuracy')
plt.title('Val Acc by Method')
plt.legend()
plt.show()

In [None]:
# 12) Classification report & confusion matrix for one chosen method
from sklearn.metrics import confusion_matrix, classification_report

chosen = 'linear_probe'   # any key of all_models: linear_probe, partial_ft, full_ft, adapters
print(f"←←← Confusion matrix for {chosen} →→→")

model_to_plot = all_models[chosen]
model_to_plot.eval()

# Collect predictions and ground truth over the whole validation loader.
all_preds, all_labels = [], []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs = imgs.to(device)
        preds = model_to_plot(imgs).argmax(1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

# Pin the label set to every class index so the report rows and the matrix
# axes always line up with `classes`, even when a class is never predicted
# or never occurs.  Assumes the test folder has the same class sub-folders
# as the train folder, so the indices mean the same thing in both.
class_ids = list(range(len(classes)))

# zero_division=0 reports 0 for classes without predictions instead of
# emitting an undefined-metric warning for each of them.
print(classification_report(all_labels, all_preds,
                            labels=class_ids, target_names=classes,
                            zero_division=0))

cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d',
            xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted'); plt.ylabel('Actual')
plt.title(f'Confusion Matrix — {chosen}')
plt.show()