# Binocular

We explore a different approach of training per patient by loading both left and right image and train efficient net b0 

We also treat ODIR as multi-label problem instead of multi-class as originally it is officially a multi-label problem
from https://odir2019.grand-challenge.org/dataset/
> Note: one patient may contains one or multiple labels. 

We also want to explore binocular or siamese approach to train our model on both left and right fundus image pair. This has been researched in https://arxiv.org/html/2504.18046v3 DMS-Net:Dual-Modal Multi-Scale Siamese Network for Binocular: Fundus Image Classification Guohao Huo, Zibo Lin, Zitong Wang, Ruiting Dai, Hao Tang paper to work well for fundus disease classification 

There are 3 advantages of use both eyes images instead of one eye image :
- Symmetry: Diseases like Diabetes aren't "accidents" in one eye; they are systemic. If the AI sees it in both, it's a "confirmed" diagnosis.

- Comparison: The left eye acts as a "control" for the right eye. AI can spot a tiny change by noticing how much it differs from the other eye.

- Noise Reduction: Just like your two eyes help you see depth, two images help the AI ignore "camera blur" or "dust" on one lens that might look like a disease.

Install Dependencies

In [1]:
%%capture
!pip install -q kagglehub torch torchvision scikit-learn pandas opencv-python tqdm wandb

Import python libraries

In [2]:

import os
import cv2
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, multilabel_confusion_matrix
import kagglehub
from tqdm import tqdm # tqdm for progress bars
import wandb


Download DataSet

In [3]:
# 1. Download Dataset (Official ODIR-5K)
path = kagglehub.dataset_download("andrewmvd/ocular-disease-recognition-odir5k")
print("Dataset path:", path)
IMG_DIR = os.path.join(path, "ODIR-5K/ODIR-5K/Training Images")
CSV_PATH = os.path.join(path, "full_df.csv")
TRAIN_CSV_PATH = "train.csv"
VAL_CSV_PATH =  "val.csv"
TEST_CSV_PATH =  "test.csv"
IMG_SIZE = 512
BATCH_SIZE = 4
ACCUMULATION_STEPS = 8
EPOCHS = 30
PATIENCE = 5 # Early stopping patience in epochs, stop if no improvement in F1 score for this many epochs
LEARNING_RATE = 1e-4
NUM_CLASSES = 8
NUM_WORKERS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
df = pd.read_csv(CSV_PATH)
train_df = pd.read_csv(TRAIN_CSV_PATH)
val_df = pd.read_csv(VAL_CSV_PATH)
test_df = pd.read_csv(TEST_CSV_PATH)
FAST_IMG_DIR = f"tmp/processed_{IMG_SIZE}_images"
os.makedirs(FAST_IMG_DIR, exist_ok=True)
RUN_NAME = f"efficient-b0_binocular_ben_graham_{IMG_SIZE}"
SAVED_MODELS_DIR = "saved_models"
os.makedirs(SAVED_MODELS_DIR, exist_ok=True)
SAVED_MODEL_PATH = os.path.join(SAVED_MODELS_DIR, f"{RUN_NAME}_best.pth")
CHECKPOINT_PATH = os.path.join(SAVED_MODELS_DIR, f"{RUN_NAME}_checkpoint.pth")
CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Other']
CLASS_CODES = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']


Dataset path: /home/ray/.cache/kagglehub/datasets/andrewmvd/ocular-disease-recognition-odir5k/versions/2


In [4]:
wandb.init(project="odir-2019", name=RUN_NAME, config={
    "img_size": IMG_SIZE, "lr": 1e-4, 
    "batch_size": BATCH_SIZE, "accumulation_steps": ACCUMULATION_STEPS, 
    "epochs": EPOCHS, "patience": PATIENCE})

[34m[1mwandb[0m: Currently logged in as: [33mraymond-samalo[0m ([33msamalo[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Ben Graham's Preprocessing

This function implements the Ben Graham Preprocessing 
ref : 
- https://scholar.google.com/citations?view_op=view_citation&hl=en&user=jQkkhlkAAAAJ&citation_for_view=jQkkhlkAAAAJ:sNmaIFBj_lkC
- https://scholar.google.com/citations?user=jQkkhlkAAAAJ&hl=en


From https://medium.com/@astronomer.abdurrehman/enhancing-image-quality-for-machine-learning-ben-grahams-preprocessing-e795ad982abe
the method described as followed
<blockquote>

The cv2.GauissanBlur takes an image, (0, 0) tuple automatically chooses a gaussian filter size based on sigmaX value which specifies the intensity of blur. Goal of using gaussian blur here is to reduce the noise and smooth out the fine details.

The addWeighted function blends two images together using specified weights, the -4 here is the beta value which subtracts the blurred image from the original image and 128 is the gamma value that adjusts the brightness so that the image does not become too dark after subtraction.


</blockquote>


In [5]:
def ben_graham_prep(img, sigmaX=10):
    """Enhances vessels and normalizes lighting."""
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Circular Crop: Find non-black pixels and crop
    mask = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) > 10
    if np.any(mask):
        coords = np.argwhere(mask)
        y0, x0 = coords.min(axis=0)
        y1, x1 = coords.max(axis=0)
        img = img[y0:y1, x0:x1]
    
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    blurred = cv2.GaussianBlur(img, (0, 0), sigmaX)
    enhanced = cv2.addWeighted(img, 4, blurred, -4, 128)
    return enhanced

On the fly image prep caused the training slowdown given the image need to be preprocessed repeatedly each time it is loaded. We speed up the process by performing preprocessing offline once and cache it

In [6]:
def run_offline_prep(df, raw_dir, img_prep_func, save_dir):
    print("üöÄ Starting Offline Pre-processing (Ben Graham)...")
    all_images = pd.concat([df['Left-Fundus'], df['Right-Fundus']]).unique()
    for img_name in tqdm(all_images):
        save_path = os.path.join(save_dir, img_name)
        if not os.path.exists(save_path):
            img = cv2.imread(os.path.join(raw_dir, img_name))
            # Ben Graham Logic
            enhanced = img_prep_func(img)
            cv2.imwrite(save_path, cv2.cvtColor(enhanced, cv2.COLOR_RGB2BGR))

In [7]:
run_offline_prep(df, IMG_DIR, ben_graham_prep, FAST_IMG_DIR)

üöÄ Starting Offline Pre-processing (Ben Graham)...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6716/6716 [00:00<00:00, 447339.14it/s]


## Data Loader 

Previously we load the data and then performed preprocessing on the fly
Given we did the preprocessing offline, we can now simply load the image 

In [8]:
class FastODIRDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        # Target classes: Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other
        self.labels = df[['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        l_img_path = os.path.join(self.img_dir, row['Left-Fundus'])
        r_img_path = os.path.join(self.img_dir, row['Right-Fundus'])
        
        # Load and Preprocess
        l_img = cv2.cvtColor(cv2.imread(l_img_path), cv2.COLOR_BGR2RGB)
        r_img = cv2.cvtColor(cv2.imread(r_img_path), cv2.COLOR_BGR2RGB )
        
        if self.transform:
            l_img = self.transform(l_img)
            r_img = self.transform(r_img)
            
        return l_img, r_img, torch.tensor(self.labels[idx], dtype=torch.float32)

In [9]:
class ODIRDualNet(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        # Using B0 for efficiency, upgrade to B4 for better accuracy
        self.backbone = models.efficientnet_b0(weights='DEFAULT') # use pretrained weights for better feature extraction
        # ref https://docs.pytorch.org/vision/main/models/generated/torchvision.models.efficientnet_b0.html#torchvision.models.EfficientNet_B0_Weights
        self.feature_dim = self.backbone.classifier[1].in_features # Get feature dimension before classifier
        self.backbone.classifier = nn.Identity() # Remove top layer

        self.classifier = nn.Sequential( # replace classifier with a custom head that combines features from both eyes
            nn.Linear(self.feature_dim * 2, IMG_SIZE), # Combine features from both eyes
            nn.ReLU(), # Non-linearity for better learning relu f(x) = max(0, x)
            nn.Dropout(0.3), # Regularization to prevent overfitting
            nn.Linear(IMG_SIZE, num_classes) # Final output layer for multi-label classification
        )

    def forward(self, left, right):
        l_feat = self.backbone(left) # Extract features from left eye
        r_feat = self.backbone(right) # Extract features from right eye
        combined = torch.cat((l_feat, r_feat), dim=1) # Combine features from both eyes
        return self.classifier(combined) # Pass through classifier to get final predictions

## Thresholds

Instead of using 1 thresholds 0.5 for all labels, we find the best or optimise threshold for each label individually to maximise our F1-score

In [10]:
def find_best_thresholds(y_true, y_probs):
    thresholds = np.linspace(0.1, 0.9, 81) # Test thresholds from 0.1 to 0.9 with fine granularity
    best_ts = np.zeros(NUM_CLASSES) 
    for i in range(NUM_CLASSES):
        best_f1 = 0 # Initialize best F1 score for this class
        for t in thresholds: # Test each threshold and calculate F1 score 
            score = f1_score(y_true[:, i], (y_probs[:, i] > t).astype(int), zero_division=0) # zero_division=0 to handle cases where there are no positive predictions
            if score > best_f1: 
                best_f1 = score # Update best F1 score for this class
                best_ts[i] = t # Update best threshold for this class
    return best_ts # Return array of best thresholds for each class

Data Loader with ImageNet Transformation


"All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]."
Ref:
- https://docs.pytorch.org/vision/0.9/models.html.

model definition and training

In [11]:
from torch import autocast
from torch.amp.grad_scaler import  GradScaler 

# val and test transforms (no augmentation, just normalization)
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Data augmentation for training set 
train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=20),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

train_loader = DataLoader(FastODIRDataset(train_df, FAST_IMG_DIR, train_transform), 
                            batch_size=BATCH_SIZE, 
                            shuffle=True,
                            num_workers=NUM_WORKERS,
                            pin_memory=True)
val_loader = DataLoader(FastODIRDataset(val_df, FAST_IMG_DIR, transform), 
                        batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS, 
                        pin_memory=True)
test_loader = DataLoader(FastODIRDataset(test_df, FAST_IMG_DIR, transform), 
                         batch_size=BATCH_SIZE,
                         num_workers=NUM_WORKERS, 
                         pin_memory=True)
model = ODIRDualNet().to(DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler(DEVICE)
# ‚ö° ACCELERATION: Compile the model (Requires PyTorch 2.0+)
# This can provide a 10-20% speedup in training time
if hasattr(torch, 'compile'):
    model = torch.compile(model)
    print("‚úÖ Model Compiled for speed.")

start_epoch, best_f1, counter = 0, 0, 0

if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    best_ts = checkpoint['thresholds']
    counter = checkpoint.get('counter', 0)
    start_epoch, best_f1, counter = checkpoint['epoch'] + 1, checkpoint['best_f1'], checkpoint.get('counter', 0)
    print(f"Resuming from epoch {start_epoch}")

for epoch in range(start_epoch, EPOCHS):
        model.train()
        train_loss = 0
        tr_preds, tr_true = [], []
        optimizer.zero_grad()

        for i, (l, r, y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            l, r, y = l.to(DEVICE), r.to(DEVICE), y.to(DEVICE)
            
            with autocast(device_type=DEVICE):
                preds = model(l, r) # logits output from the model
                loss = criterion(preds, y) / ACCUMULATION_STEPS
            
            scaler.scale(loss).backward()
            
            if (i + 1) % ACCUMULATION_STEPS == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            
            train_loss += loss.item() * ACCUMULATION_STEPS
            tr_preds.append(torch.sigmoid(preds).detach().cpu().numpy())
            tr_true.append(y.cpu().numpy())
        # --- Validation ---
        model.eval()
        val_preds, val_true = [], []
        with torch.no_grad():
            for l, r, y in val_loader:
                out = torch.sigmoid(model(l.to(DEVICE), r.to(DEVICE)))
                val_preds.append(out.cpu().numpy())
                val_true.append(y.numpy())
        
        val_probs = np.vstack(val_preds)
        val_true = np.vstack(val_true)
        best_ts = find_best_thresholds(val_true, val_probs)
        tr_probs, tr_true = np.vstack(tr_preds), np.vstack(tr_true)
        # Calculate Macro F1 with optimized thresholds
        val_f1 = f1_score(val_true, (val_probs > best_ts).astype(int), average='macro')
        train_f1 = f1_score(tr_true, (tr_probs > best_ts).astype(int), average='macro')
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss / len(train_loader),
            "val_f1": val_f1,
            "train_f1": train_f1
        })
        print(f"Epoch {epoch+1} | Loss: {train_loss/len(train_loader):.4f} | train_f1: {train_f1:.4f} | val_f1: {val_f1:.4f}")
        if val_f1 > best_f1:
            best_f1 = val_f1
            counter = 0 # reset counter on improvement
            torch.save({'model': model.state_dict(), 'thresholds': best_ts}, SAVED_MODEL_PATH)
            print(f"üöÄ New Best Model Saved! F1: {val_f1:.4f}")
        # Save checkpoint every 5 epochs or if no improvement for 3 epochs
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_f1': best_f1,
                'counter': counter,
                'thresholds': best_ts
            }, CHECKPOINT_PATH)
            print(f"üíæ Checkpoint saved at epoch {epoch+1}")
        if  PATIENCE>0 and counter >= PATIENCE: # early stopping if no improvement for 3 epochs
            print("‚èπÔ∏è Early stopping triggered.")
            break

‚úÖ Model Compiled for speed.


Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [02:02<00:00,  5.51it/s]


Epoch 1 | Loss: 0.3814 | train_f1: 0.2265 | val_f1: 0.4186
üöÄ New Best Model Saved! F1: 0.4186


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:45<00:00,  6.40it/s]


Epoch 2 | Loss: 0.3166 | train_f1: 0.3279 | val_f1: 0.5100
üöÄ New Best Model Saved! F1: 0.5100


Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.94it/s]


Epoch 3 | Loss: 0.2970 | train_f1: 0.3621 | val_f1: 0.5398
üöÄ New Best Model Saved! F1: 0.5398


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 4 | Loss: 0.2747 | train_f1: 0.4245 | val_f1: 0.5600
üöÄ New Best Model Saved! F1: 0.5600


Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.93it/s]


Epoch 5 | Loss: 0.2628 | train_f1: 0.4584 | val_f1: 0.6194
üöÄ New Best Model Saved! F1: 0.6194
üíæ Checkpoint saved at epoch 5


Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.95it/s]


Epoch 6 | Loss: 0.2528 | train_f1: 0.5233 | val_f1: 0.6056


Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 7 | Loss: 0.2389 | train_f1: 0.5807 | val_f1: 0.6269
üöÄ New Best Model Saved! F1: 0.6269


Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.94it/s]


Epoch 8 | Loss: 0.2310 | train_f1: 0.5394 | val_f1: 0.6113


Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 9 | Loss: 0.2194 | train_f1: 0.6259 | val_f1: 0.6392
üöÄ New Best Model Saved! F1: 0.6392


Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 10 | Loss: 0.2156 | train_f1: 0.6569 | val_f1: 0.6403
üöÄ New Best Model Saved! F1: 0.6403
üíæ Checkpoint saved at epoch 10


Epoch 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.94it/s]


Epoch 11 | Loss: 0.2042 | train_f1: 0.6363 | val_f1: 0.6431
üöÄ New Best Model Saved! F1: 0.6431


Epoch 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 12 | Loss: 0.1947 | train_f1: 0.6507 | val_f1: 0.6578
üöÄ New Best Model Saved! F1: 0.6578


Epoch 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.95it/s]


Epoch 13 | Loss: 0.1893 | train_f1: 0.6960 | val_f1: 0.6598
üöÄ New Best Model Saved! F1: 0.6598


Epoch 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.93it/s]


Epoch 14 | Loss: 0.1784 | train_f1: 0.7180 | val_f1: 0.6631
üöÄ New Best Model Saved! F1: 0.6631


Epoch 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 15 | Loss: 0.1713 | train_f1: 0.6856 | val_f1: 0.6452
üíæ Checkpoint saved at epoch 15


Epoch 16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.95it/s]


Epoch 16 | Loss: 0.1634 | train_f1: 0.7415 | val_f1: 0.6604


Epoch 17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 17 | Loss: 0.1608 | train_f1: 0.7204 | val_f1: 0.6803
üöÄ New Best Model Saved! F1: 0.6803


Epoch 18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.94it/s]


Epoch 18 | Loss: 0.1486 | train_f1: 0.7257 | val_f1: 0.6719


Epoch 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 19 | Loss: 0.1448 | train_f1: 0.7681 | val_f1: 0.6836
üöÄ New Best Model Saved! F1: 0.6836


Epoch 20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.93it/s]


Epoch 20 | Loss: 0.1363 | train_f1: 0.7379 | val_f1: 0.6831
üíæ Checkpoint saved at epoch 20


Epoch 21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.94it/s]


Epoch 21 | Loss: 0.1337 | train_f1: 0.7531 | val_f1: 0.6723


Epoch 22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 22 | Loss: 0.1263 | train_f1: 0.7548 | val_f1: 0.6739


Epoch 23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:37<00:00,  6.92it/s]


Epoch 23 | Loss: 0.1196 | train_f1: 0.7695 | val_f1: 0.6649


Epoch 24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:36<00:00,  6.94it/s]


Epoch 24 | Loss: 0.1130 | train_f1: 0.8124 | val_f1: 0.6730


Epoch 25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:35<00:00,  7.03it/s]


Epoch 25 | Loss: 0.1134 | train_f1: 0.8064 | val_f1: 0.6730
üíæ Checkpoint saved at epoch 25


Epoch 26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:33<00:00,  7.20it/s]


Epoch 26 | Loss: 0.1071 | train_f1: 0.8189 | val_f1: 0.6643


Epoch 27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:33<00:00,  7.19it/s]


Epoch 27 | Loss: 0.1042 | train_f1: 0.8223 | val_f1: 0.6628


Epoch 28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:33<00:00,  7.18it/s]


Epoch 28 | Loss: 0.0997 | train_f1: 0.8197 | val_f1: 0.6573


Epoch 29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:35<00:00,  7.08it/s]


Epoch 29 | Loss: 0.0915 | train_f1: 0.8325 | val_f1: 0.6510


Epoch 30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 673/673 [01:34<00:00,  7.15it/s]


Epoch 30 | Loss: 0.0877 | train_f1: 0.8629 | val_f1: 0.6509
üíæ Checkpoint saved at epoch 30


## Test Evaluation

In [12]:
best_model = torch.load(SAVED_MODEL_PATH, weights_only=False)
thresholds = best_model['thresholds']
model.load_state_dict(best_model['model'])
print("Label Thresholds and Macro F1 Score of the Best Model:")
print("Labels: Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other")
print("Best Thresholds:", thresholds)
print("Best Macro F1 Score:", best_f1)
print("‚úÖ Training Complete. Best model and thresholds saved.")

Label Thresholds and Macro F1 Score of the Best Model:
Labels: Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other
Best Thresholds: [0.37 0.45 0.13 0.29 0.3  0.13 0.49 0.59]
Best Macro F1 Score: 0.683612948855622
‚úÖ Training Complete. Best model and thresholds saved.


In [None]:
model.eval()
t_probs, t_true = [], []
with torch.no_grad():
    for l, r, y in test_loader:
        out = torch.sigmoid(model(l.to(DEVICE), r.to(DEVICE)))
        t_probs.append(out.cpu().numpy()); t_true.append(y.numpy())

t_p, t_t = np.vstack(t_probs), np.vstack(t_true)
t_preds = (t_p > thresholds).astype(int)

mcm = multilabel_confusion_matrix(t_t, t_preds)
print("Multilabel Confusion Matrix:")
for i, class_name in enumerate(CLASS_NAMES):
    tn, fp, fn, tp = mcm[i].ravel()
    print(f"{class_name}: TP={tp}, FP={fp}, TN={tn}, FN={fn}")
    wandb.log({f"CM_{class_name}": wandb.plot.confusion_matrix(
            probs=None,
            y_true=t_t[:, i],
            preds=t_preds[:, i],
            title=f"Confusion Matrix for {class_name}",
            class_names=["Absent", "Present"]
        )})
print("\nClassification Report:")
print(classification_report(t_t, t_preds, target_names=CLASS_NAMES))
wandb.finish()

Multilabel Confusion Matrix:
Normal: TP=91, FP=56, TN=168, FN=17
Diabetes: TP=77, FP=11, TN=210, FN=34
Glaucoma: TP=11, FP=7, TN=305, FN=9
Cataract: TP=14, FP=1, TN=310, FN=7
AMD: TP=9, FP=11, TN=305, FN=7
Hypertension: TP=5, FP=16, TN=306, FN=5
Myopia: TP=13, FP=3, TN=312, FN=4
Other: TP=49, FP=23, TN=219, FN=41

Classification Report:
              precision    recall  f1-score   support

      Normal       0.62      0.84      0.71       108
    Diabetes       0.88      0.69      0.77       111
    Glaucoma       0.61      0.55      0.58        20
    Cataract       0.93      0.67      0.78        21
         AMD       0.45      0.56      0.50        16
Hypertension       0.24      0.50      0.32        10
      Myopia       0.81      0.76      0.79        17
       Other       0.68      0.54      0.60        90

   micro avg       0.68      0.68      0.68       393
   macro avg       0.65      0.64      0.63       393
weighted avg       0.71      0.68      0.69       393
 samples av

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


0,1
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
train_f1,‚ñÅ‚ñÇ‚ñÇ‚ñÉ‚ñÑ‚ñÑ‚ñÖ‚ñÑ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà
train_loss,‚ñà‚ñÜ‚ñÜ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
val_f1,‚ñÅ‚ñÉ‚ñÑ‚ñÖ‚ñÜ‚ñÜ‚ñá‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñá‚ñá‚ñá‚ñá‚ñá

0,1
epoch,30.0
train_f1,0.86286
train_loss,0.0877
val_f1,0.6509


: 