# Dual Eyes Training


Install Dependencies

In [1]:
%%capture
!pip install -q kagglehub torch torchvision scikit-learn pandas opencv-python tqdm

Import Python libraries

In [2]:

import os
import cv2
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import kagglehub
from tqdm import tqdm # tqdm for progress bars


Download the dataset and set the training configuration

In [3]:
# 1. Download Dataset (Official ODIR-5K)
path = kagglehub.dataset_download("andrewmvd/ocular-disease-recognition-odir5k")
print("Dataset path:", path)

# --- Paths ---
IMG_DIR = os.path.join(path, "ODIR-5K/ODIR-5K/Training Images")
CSV_PATH = os.path.join(path, "full_df.csv")
# Cache directory for the pre-processed images. The "512" in the name is not
# the stored resolution: ben_graham_prep resizes to IMG_SIZE x IMG_SIZE.
FAST_IMG_DIR = "tmp/processed_512_images"
os.makedirs(FAST_IMG_DIR, exist_ok=True)

# --- Hyper-parameters ---
SEED = 42
IMG_SIZE = 224
BATCH_SIZE = 4
ACCUMULATION_STEPS = 8  # effective batch size = BATCH_SIZE * ACCUMULATION_STEPS
EPOCHS = 10
LEARNING_RATE = 1e-4
NUM_CLASSES = 8
NUM_WORKERS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs so weight initialisation, dropout and shuffling are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Train / validation split ---
df = pd.read_csv(CSV_PATH)
# NOTE(review): full_df.csv appears to contain one row per eye image, i.e. the
# same (left, right) pair can occur more than once - confirm against the data.
# Keeping one row per pair guarantees that an identical sample can never end up
# in both the training and the validation split; it is a no-op if the pairs are
# already unique.
df = df.drop_duplicates(subset=["Left-Fundus", "Right-Fundus"]).reset_index(drop=True)
train_df, val_df = train_test_split(df, test_size=0.15, random_state=SEED)


Dataset path: /home/ray/.cache/kagglehub/datasets/andrewmvd/ocular-disease-recognition-odir5k/versions/2


Dataset with Ben Graham's Preprocessing
Ben Graham's preprocessing is applied once, offline, and the enhanced images are cached on disk; the dataset class then loads both (already pre-processed) eyes for one patient.
References:
- https://scholar.google.com/citations?view_op=view_citation&hl=en&user=jQkkhlkAAAAJ&citation_for_view=jQkkhlkAAAAJ:sNmaIFBj_lkC
- https://scholar.google.com/citations?user=jQkkhlkAAAAJ&hl=en


In [4]:
def ben_graham_prep(img, sigmaX=10):
    """Crop a fundus photo to its non-black area and enhance local contrast.

    Args:
        img: Image array in BGR channel order, as returned by ``cv2.imread``.
        sigmaX: Standard deviation of the Gaussian blur that is subtracted
            from the image.

    Returns:
        The enhanced image in RGB channel order, resized to
        ``IMG_SIZE`` x ``IMG_SIZE``.

    Raises:
        ValueError: If ``img`` is ``None`` (e.g. ``cv2.imread`` failed).
    """
    if img is None:
        raise ValueError("ben_graham_prep received None instead of an image")

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Bounding-box crop around every pixel brighter than the black border.
    mask = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) > 10
    if np.any(mask):
        rows = np.flatnonzero(mask.any(axis=1))
        cols = np.flatnonzero(mask.any(axis=0))
        y0, y1 = rows[0], rows[-1]
        x0, x1 = cols[0], cols[-1]
        # Slice ends are exclusive: +1 keeps the last bright row/column and
        # prevents an empty crop when the bright area is one pixel thick.
        img = img[y0:y1 + 1, x0:x1 + 1]

    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    blurred = cv2.GaussianBlur(img, (0, 0), sigmaX)
    # 4 * (img - blurred) + 128: remove the local average colour and re-centre
    # the result on mid-grey.
    enhanced = cv2.addWeighted(img, 4, blurred, -4, 128)
    return enhanced

In [5]:
def run_offline_prep(df, raw_dir, img_prep_func, save_dir):
    """Pre-process every fundus image referenced by ``df`` and cache it on disk.

    Images that already exist in ``save_dir`` are skipped, so the function can
    be re-run after an interruption.

    Args:
        df: DataFrame with ``'Left-Fundus'`` and ``'Right-Fundus'`` columns
            holding image file names.
        raw_dir: Directory containing the raw images.
        img_prep_func: Callable that takes the BGR image returned by
            ``cv2.imread`` and returns an RGB image.
        save_dir: Directory the processed images are written to (created if
            it does not exist).

    Raises:
        OSError: If a raw image cannot be read or a processed image cannot be
            written.
    """
    print("🚀 Starting Offline Pre-processing (Ben Graham)...")
    os.makedirs(save_dir, exist_ok=True)
    # dropna(): a missing file name would otherwise reach os.path.join as NaN.
    all_images = pd.concat([df['Left-Fundus'], df['Right-Fundus']]).dropna().unique()
    for img_name in tqdm(all_images):
        save_path = os.path.join(save_dir, img_name)
        if os.path.exists(save_path):
            continue  # already processed in a previous run

        src_path = os.path.join(raw_dir, img_name)
        # cv2.imread does not raise on failure, it returns None.
        img = cv2.imread(src_path)
        if img is None:
            raise OSError(f"Could not read image: {src_path}")

        # Ben Graham Logic
        enhanced = img_prep_func(img)
        # cv2.imwrite expects BGR and reports failure through its return value.
        if not cv2.imwrite(save_path, cv2.cvtColor(enhanced, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Could not write image: {save_path}")

In [None]:
# Pre-process every image referenced in df once and cache it in FAST_IMG_DIR.
# Files that already exist there are skipped, so re-running this cell is cheap.
run_offline_prep(df, IMG_DIR, ben_graham_prep, FAST_IMG_DIR)

🚀 Starting Offline Pre-processing (Ben Graham)...


 11%|█         | 754/6716 [00:32<04:25, 22.48it/s]

In [None]:
class FastODIRDataset(Dataset):
    """Dataset that yields both eyes of one DataFrame row plus its label vector.

    Images are read from ``img_dir`` using the file names stored in the
    ``'Left-Fundus'`` and ``'Right-Fundus'`` columns.

    Args:
        df: DataFrame with the two file-name columns and the eight label
            columns listed in ``LABEL_COLUMNS``.
        img_dir: Directory the images are loaded from.
        transform: Optional callable applied to each RGB image array.
    """

    # Target classes: Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other
    LABEL_COLUMNS = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.labels = df[self.LABEL_COLUMNS].values

    def __len__(self):
        """Return the number of rows (left/right pairs) in the DataFrame."""
        return len(self.df)

    def _load_rgb(self, img_name):
        """Read ``img_name`` from ``img_dir`` and return it as an RGB array."""
        img_path = os.path.join(self.img_dir, img_name)
        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable; fail here with the offending path.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``(left_image, right_image, labels)`` for row ``idx``.

        ``labels`` is a float32 tensor built from the label columns.
        """
        row = self.df.iloc[idx]

        # Load and Preprocess
        l_img = self._load_rgb(row['Left-Fundus'])
        r_img = self._load_rgb(row['Right-Fundus'])

        if self.transform is not None:
            l_img = self.transform(l_img)
            r_img = self.transform(r_img)

        return l_img, r_img, torch.tensor(self.labels[idx], dtype=torch.float32)

In [None]:
class ODIRDualNet(nn.Module):
    """Two-eye classifier: one shared EfficientNet-B0 backbone and an MLP head.

    The same backbone encodes the left and the right image; the two feature
    vectors are concatenated and passed to the classification head.

    Args:
        num_classes: Number of output logits.
        hidden_dim: Width of the hidden layer of the head. The default (224)
            matches the width used so far, which was taken from IMG_SIZE;
            it is now independent of the input resolution.
        dropout: Dropout probability applied after the hidden layer.
    """

    def __init__(self, num_classes=NUM_CLASSES, hidden_dim=224, dropout=0.3):
        super().__init__()
        # Using B0 for efficiency, upgrade to B4 for better accuracy
        self.backbone = models.efficientnet_b0(weights='DEFAULT')
        self.feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity() # Remove top layer

        self.classifier = nn.Sequential(
            # * 2: left and right features are concatenated in forward().
            nn.Linear(self.feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, left, right):
        """Return raw logits (no sigmoid applied) for a batch of eye pairs."""
        l_feat = self.backbone(left)
        r_feat = self.backbone(right)
        combined = torch.cat((l_feat, r_feat), dim=1)
        return self.classifier(combined)

In [None]:
def find_best_thresholds(y_true, y_probs, default_threshold=0.5):
    """Pick, per class, the decision threshold that maximises binary F1.

    Thresholds are searched on a grid of 81 values from 0.1 to 0.9. When
    several thresholds tie, the lowest one is kept.

    Args:
        y_true: 2-D array of binary ground-truth labels, one column per class.
        y_probs: 2-D array of predicted probabilities, same layout as
            ``y_true``.
        default_threshold: Threshold used for a class whose F1 never rises
            above 0 (for example when it has no positive sample). Previously
            such classes were left at 0.0, which marks every sample positive.

    Returns:
        1-D float array with one threshold per column of ``y_probs``.
    """
    thresholds = np.linspace(0.1, 0.9, 81)
    # Take the class count from the data rather than from a global constant.
    num_classes = y_probs.shape[1]
    best_ts = np.full(num_classes, default_threshold, dtype=float)
    for i in range(num_classes):
        best_f1 = 0
        for t in thresholds:
            score = f1_score(y_true[:, i], (y_probs[:, i] > t).astype(int), zero_division=0)
            if score > best_f1:
                best_f1 = score
                best_ts[i] = t
    return best_ts

Data Loader with ImageNet Transformation


"All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]."
Ref:
- https://docs.pytorch.org/vision/0.9/models.html

Model definition and training

In [None]:
from torch import autocast
from torch.amp.grad_scaler import  GradScaler 

# Convert the RGB arrays to tensors in [0, 1] and apply ImageNet normalisation.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_loader = DataLoader(FastODIRDataset(train_df, FAST_IMG_DIR, transform),
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=NUM_WORKERS,
                          pin_memory=True)
val_loader = DataLoader(FastODIRDataset(val_df, FAST_IMG_DIR, transform),
                        batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS,
                        pin_memory=True)

model = ODIRDualNet().to(DEVICE)
# Keep a handle on the uncompiled module: torch.compile returns a wrapper whose
# state_dict keys carry an "_orig_mod." prefix, which a plain ODIRDualNet
# cannot load. Checkpoints are therefore saved from raw_model.
raw_model = model
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler(DEVICE)
best_overall_f1 = 0
# ⚡ ACCELERATION: Compile the model (Requires PyTorch 2.0+)
# This can provide a 10-20% speedup in training time
if hasattr(torch, 'compile'):
    model = torch.compile(model)
    print("✅ Model Compiled for speed.")

num_batches = len(train_loader)
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    optimizer.zero_grad()

    for i, (left, right, target) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        left, right, target = left.to(DEVICE), right.to(DEVICE), target.to(DEVICE)

        with autocast(device_type=DEVICE):
            preds = model(left, right)
            # Divide so the gradients summed over one accumulation window
            # correspond to the mean loss of the window.
            loss = criterion(preds, target) / ACCUMULATION_STEPS

        scaler.scale(loss).backward()

        # Step at the end of every accumulation window AND on the last batch.
        # Without the second condition the gradients of a final, partial
        # window would be wiped by the next epoch's zero_grad() without ever
        # being applied.
        if (i + 1) % ACCUMULATION_STEPS == 0 or (i + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Undo the division above so the reported value is the per-batch loss.
        train_loss += loss.item() * ACCUMULATION_STEPS

    # --- Validation ---
    model.eval()
    val_preds, val_true = [], []
    with torch.no_grad():
        for left, right, target in val_loader:
            out = torch.sigmoid(model(left.to(DEVICE), right.to(DEVICE)))
            val_preds.append(out.cpu().numpy())
            val_true.append(target.numpy())

    val_probs = np.vstack(val_preds)
    val_true = np.vstack(val_true)
    best_ts = find_best_thresholds(val_true, val_probs)

    # Calculate Macro F1 with optimized thresholds
    # NOTE: the thresholds are tuned on the same validation data they are
    # scored on, so this figure is optimistic.
    # zero_division=0 matches find_best_thresholds and avoids warnings for
    # classes with no positive labels or predictions.
    f1 = f1_score(val_true, (val_probs > best_ts).astype(int), average='macro', zero_division=0)
    print(f"Loss: {train_loss/len(train_loader):.4f} | Val Macro F1: {f1:.4f}")

    if f1 > best_overall_f1:
        best_overall_f1 = f1
        # Save the weights of the uncompiled module (keys without "_orig_mod.").
        torch.save({'model': raw_model.state_dict(), 'thresholds': best_ts}, "best_odir_f1.pth")
        print("🚀 New Best Model Saved!")
