# B0 RESIZE BINOCULAR without AUGMENTATION extra weight

We explore a different approach of training per patient by loading both left and right image and train efficient net b0 

We also treat ODIR as multi-label problem instead of multi-class as originally it is officially a multi-label problem
from https://odir2019.grand-challenge.org/dataset/
> Note: one patient may contains one or multiple labels. 

We also want to explore binocular or siamese approach to train our model on both left and right fundus image pair. This has been researched in https://arxiv.org/html/2504.18046v3 DMS-Net:Dual-Modal Multi-Scale Siamese Network for Binocular: Fundus Image Classification Guohao Huo, Zibo Lin, Zitong Wang, Ruiting Dai, Hao Tang paper to work well for fundus disease classification 

There are 3 advantages of use both eyes images instead of one eye image :
- Symmetry: Diseases like Diabetes aren't "accidents" in one eye; they are systemic. If the AI sees it in both, it's a "confirmed" diagnosis.

- Comparison: The left eye acts as a "control" for the right eye. AI can spot a tiny change by noticing how much it differs from the other eye.

- Noise Reduction: Just like your two eyes help you see depth, two images help the AI ignore "camera blur" or "dust" on one lens that might look like a disease.

Install Dependencies

In [1]:
%%capture
!pip install -q  torch torchvision scikit-learn pandas opencv-python tqdm wandb torchinfo

Import python libraries

In [2]:

import os
import cv2
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.metrics import f1_score, classification_report, multilabel_confusion_matrix, accuracy_score
from tqdm import tqdm # tqdm for progress bars
import wandb


Download DataSet

In [3]:
IMAGE_PREP_NAME = "resize" # Name for this image pre-processing method, used for directory naming and logging
IMG_DIR = f"tmp/{IMAGE_PREP_NAME}_prep" # Directory where images are stored, adjust if needed
RUN_NAME = f"efficient-b0_{IMAGE_PREP_NAME}_naew" # Unique name for this run, used for saving models and logging

TRAIN_CSV_PATH = "train.csv"
VAL_CSV_PATH =  "val.csv"
TEST_CSV_PATH =  "test.csv"
train_df = pd.read_csv(TRAIN_CSV_PATH)
val_df = pd.read_csv(VAL_CSV_PATH)
test_df = pd.read_csv(TEST_CSV_PATH)

IMG_SIZE = 512
BATCH_SIZE = 4
ACCUMULATION_STEPS = 8
EPOCHS = 30
PATIENCE = 5 # Early stopping patience in epochs, stop if no improvement in F1 score for this many epochs
LEARNING_RATE = 1e-4
NUM_CLASSES = 8
NUM_WORKERS = 2

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVED_MODELS_DIR = "saved_models"
os.makedirs(SAVED_MODELS_DIR, exist_ok=True)
SAVED_MODEL_PATH = os.path.join(SAVED_MODELS_DIR, f"{RUN_NAME}_best.pth")
CHECKPOINT_PATH = os.path.join(SAVED_MODELS_DIR, f"{RUN_NAME}_checkpoint.pth")
CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Other']
CLASS_CODES = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']


In [4]:
wandb.init(project="odir-2019-b", name=RUN_NAME, config={
    "img_size": IMG_SIZE, "lr": 1e-4, 
    "batch_size": BATCH_SIZE, "accumulation_steps": ACCUMULATION_STEPS, 
    "epochs": EPOCHS, "patience": PATIENCE})

[34m[1mwandb[0m: Currently logged in as: [33mraymond-samalo[0m ([33msamalo[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Data Loader 

Previously we load the data and then performed preprocessing on the fly
Given we did the preprocessing offline, we can now simply load the image 

In [5]:
class FastODIRDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        # Target classes: Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other
        self.labels = df[['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        l_img_path = os.path.join(self.img_dir, row['Left-Fundus'])
        r_img_path = os.path.join(self.img_dir, row['Right-Fundus'])
        
        # Load and Preprocess
        l_img = cv2.cvtColor(cv2.imread(l_img_path), cv2.COLOR_BGR2RGB)
        r_img = cv2.cvtColor(cv2.imread(r_img_path), cv2.COLOR_BGR2RGB )
        
        if self.transform:
            l_img = self.transform(l_img)
            r_img = self.transform(r_img)
            
        return l_img, r_img, torch.tensor(self.labels[idx], dtype=torch.float32)

In [6]:
from torch.utils.checkpoint import checkpoint
class ODIRDualNet(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        # Using B0 for efficiency, upgrade to B4 for better accuracy
        self.backbone = models.efficientnet_b0(weights='DEFAULT') # use pretrained weights for better feature extraction
        # ref https://docs.pytorch.org/vision/main/models/generated/torchvision.models.efficientnet_b0.html#torchvision.models.EfficientNet_B0_Weights
        self.feature_dim = self.backbone.classifier[1].in_features # Get feature dimension before classifier
        self.backbone.classifier = nn.Identity() # Remove top layer
        self.features = self.backbone.features # Extract feature extractor part for checkpointing
        self.classifier = nn.Sequential( # replace classifier with a custom head that combines features from both eyes
            nn.Linear(self.feature_dim * 2, IMG_SIZE), # Combine features from both eyes
            nn.ReLU(), # Non-linearity for better learning relu f(x) = max(0, x)
            nn.Dropout(0.3), # Regularization to prevent overfitting
            nn.Linear(IMG_SIZE, num_classes) # Final output layer for multi-label classification
        )

    def forward(self, left, right):
        # manually checkpoint the feature extraction part to save memory, since EfficientNet can be quite large, especially B4
        l_feat = checkpoint(self.features, left, use_reentrant=False)
        r_feat = checkpoint(self.features, right, use_reentrant=False)
        
        # Global Average Pooling to get (Batch, Feat_Dim)
        l_feat = torch.flatten(nn.functional.adaptive_avg_pool2d(l_feat, 1), 1)
        r_feat = torch.flatten(nn.functional.adaptive_avg_pool2d(r_feat, 1), 1)
        combined = torch.cat((l_feat, r_feat), dim=1) # Combine features from both eyes
        return self.classifier(combined) # Pass through classifier to get final predictions

## Thresholds

Instead of using 1 thresholds 0.5 for all labels, we find the best or optimise threshold for each label individually to maximise our F1-score

In [7]:
def find_best_thresholds(y_true, y_probs):
    thresholds = np.linspace(0.1, 0.9, 81) # Test thresholds from 0.1 to 0.9 with fine granularity
    best_ts = np.zeros(NUM_CLASSES) 
    for i in range(NUM_CLASSES):
        best_f1 = 0 # Initialize best F1 score for this class
        for t in thresholds: # Test each threshold and calculate F1 score 
            score = f1_score(y_true[:, i], (y_probs[:, i] > t).astype(int), zero_division=0) # zero_division=0 to handle cases where there are no positive predictions
            if score > best_f1: 
                best_f1 = score # Update best F1 score for this class
                best_ts[i] = t # Update best threshold for this class
    return best_ts # Return array of best thresholds for each class

Data Loader with ImageNet Transformation


"All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]."
Ref:
- https://docs.pytorch.org/vision/0.9/models.html.

model definition and training

*We increase POS_WEIGHTS for hypertension and Others*

In [8]:
# --- CALCULATE POSITIVE WEIGHTS ---
def get_pos_weights(df, class_names):
    weights = []
    for col in class_names:
        num_pos = df[col].sum()
        num_neg = len(df) - num_pos
        # Weight = Count of Negatives / Count of Positives
        # We add a small epsilon to avoid division by zero
        weight = num_neg / (num_pos + 1e-6)
        if col == 'H': 
            weight = weight * 2.0  # Give extra weight to Hypertension class due to its lower prevalence and importance in this dataset 
        if col == 'O': 
            weight = weight * 1.5  # Give extra weight to Other class due to its lower prevalence and importance in this dataset
        weights.append(weight)
    return torch.tensor(weights, dtype=torch.float32).to(DEVICE)

pos_weights = get_pos_weights(train_df, CLASS_CODES)

Smoothing for loss

In [9]:
from torch import autocast
from torch.amp.grad_scaler import  GradScaler 

# val and test transforms (no augmentation, just normalization)
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Data augmentation for training set 
train_transform = transforms.Compose([
        transforms.ToPILImage(),
        # Rotation and Color Jitter can help the model generalize better by simulating real-world variations in the images
        # transforms.RandomRotation(degrees=20),
        # transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

train_loader = DataLoader(FastODIRDataset(train_df, IMG_DIR, train_transform), 
                            batch_size=BATCH_SIZE, 
                            shuffle=True,
                            num_workers=NUM_WORKERS,
                            pin_memory=True)
val_loader = DataLoader(FastODIRDataset(val_df, IMG_DIR, transform), 
                        batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS, 
                        pin_memory=True)
test_loader = DataLoader(FastODIRDataset(test_df, IMG_DIR, transform), 
                         batch_size=BATCH_SIZE,
                         num_workers=NUM_WORKERS, 
                         pin_memory=True)
model = ODIRDualNet().to(DEVICE)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights) # Use pos_weights to handle class imbalance

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2) # reduce learning rate when a metric stop improving
scaler = GradScaler(DEVICE)
# ‚ö° ACCELERATION: Compile the model (Requires PyTorch 2.0+)
# This can provide a 10-20% speedup in training time
if hasattr(torch, 'compile'):
    model = torch.compile(model)
    print("‚úÖ Model Compiled for speed.")

start_epoch, best_f1, counter = 0, 0, 0

if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, weights_only=False)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    best_ts = checkpoint['thresholds']
    counter = checkpoint.get('counter', 0)
    start_epoch, best_f1, counter = checkpoint['epoch'] + 1, checkpoint['best_f1'], checkpoint.get('counter', 0)
    print(f"Resuming from epoch {start_epoch}")

for epoch in range(start_epoch, EPOCHS):
        model.train()
        train_loss = 0
        tr_preds, tr_true = [], []
        optimizer.zero_grad()

        for i, (l, r, y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            l, r, y = l.to(DEVICE), r.to(DEVICE), y.to(DEVICE)
            
            with autocast(device_type=DEVICE):
                preds = model(l, r) # logits output from the model
                loss = criterion(preds, y) / ACCUMULATION_STEPS
            
            scaler.scale(loss).backward()
            
            if (i + 1) % ACCUMULATION_STEPS == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            
            train_loss += loss.item() * ACCUMULATION_STEPS
            tr_preds.append(torch.sigmoid(preds).detach().cpu().numpy())
            tr_true.append(y.cpu().numpy())
        # --- Validation ---
        model.eval()
        val_preds, val_true = [], []
        with torch.no_grad():
            for l, r, y in val_loader:
                out = torch.sigmoid(model(l.to(DEVICE), r.to(DEVICE)))
                val_preds.append(out.cpu().numpy())
                val_true.append(y.numpy())
        torch.cuda.empty_cache() # Explicitly free memory after validation to prevent fragmentation        
        val_probs = np.vstack(val_preds)
        val_true = np.vstack(val_true)
        best_ts = find_best_thresholds(val_true, val_probs)
        tr_probs, tr_true = np.vstack(tr_preds), np.vstack(tr_true)
        # Calculate Macro F1 with optimized thresholds
        val_preds_binary = (val_probs > best_ts).astype(int)
        val_f1 = f1_score(val_true, val_preds_binary, average='macro', zero_division=0)
        val_acc = accuracy_score(val_true, val_preds_binary)*100. # Convert to percentage for better interpretability
        # For training metrics, we can use a fixed threshold of 0.5 
        # since we're mainly interested in validation performance for threshold optimization

        tr_preds_binary = (tr_probs > 0.5).astype(int)
        train_f1 = f1_score(tr_true, tr_preds_binary, average='macro', zero_division=0)
        train_acc = accuracy_score(tr_true, tr_preds_binary)*100. # Convert to percentage for better interpretability
        # NEW: Step the scheduler based on Validation F1
        scheduler.step(val_f1) # pass the metric to scheduler to monitor
        current_lr = optimizer.param_groups[0]['lr']

        # Log per-class F1 for visibility
        per_class_f1 = f1_score(val_true, (val_probs > best_ts).astype(int), average=None)
        metrics_dict = {f"val_f1_{name}": f for name, f in zip(CLASS_NAMES, per_class_f1)}
        metrics_dict.update({"epoch": epoch+1, "val_f1": val_f1, "train_f1": train_f1, "lr": current_lr, "train_loss": train_loss / len(train_loader),
                             "train_acc": train_acc, "val_acc": val_acc})
        wandb.log(metrics_dict)
        print(f"Epoch {epoch+1} | Loss: {train_loss/len(train_loader):.4f} | train_f1: {train_f1:.4f} | val_f1: {val_f1:.4f}")
        if val_f1 > best_f1:
            best_f1 = val_f1
            counter = 0 # reset counter on improvement
            torch.save({'model': model.state_dict(), 'thresholds': best_ts}, SAVED_MODEL_PATH)
            print(f"üöÄ New Best Model Saved! F1: {val_f1:.4f}")
        # Save checkpoint every 5 epochs or if no improvement for 3 epochs
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_f1': best_f1,
                'counter': counter,
                'thresholds': best_ts
            }, CHECKPOINT_PATH)
            print(f"üíæ Checkpoint saved at epoch {epoch+1}")
        if  PATIENCE>0 and counter >= PATIENCE: # early stopping if no improvement for 3 epochs
            print("‚èπÔ∏è Early stopping triggered.")
            break

‚úÖ Model Compiled for speed.


Epoch 1:   0%|          | 0/600 [00:00<?, ?it/s]W0224 09:27:54.569000 8334 site-packages/torch/_inductor/utils.py:1613] [0/0] Not enough SMs to use max_autotune_gemm mode
Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [04:54<00:00,  2.03it/s] 


Epoch 1 | Loss: 1.1507 | train_f1: 0.2868 | val_f1: 0.4395
üöÄ New Best Model Saved! F1: 0.4395


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [02:04<00:00,  4.83it/s]


Epoch 2 | Loss: 0.9756 | train_f1: 0.3644 | val_f1: 0.5385
üöÄ New Best Model Saved! F1: 0.5385


Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:55<00:00,  5.21it/s]


Epoch 3 | Loss: 0.8696 | train_f1: 0.4259 | val_f1: 0.5855
üöÄ New Best Model Saved! F1: 0.5855


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.30it/s]


Epoch 4 | Loss: 0.7787 | train_f1: 0.4671 | val_f1: 0.5500


Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:51<00:00,  5.39it/s]


Epoch 5 | Loss: 0.7108 | train_f1: 0.4982 | val_f1: 0.5443
üíæ Checkpoint saved at epoch 5


Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.28it/s]


Epoch 6 | Loss: 0.6476 | train_f1: 0.5298 | val_f1: 0.5777


Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:51<00:00,  5.36it/s]


Epoch 7 | Loss: 0.5808 | train_f1: 0.5770 | val_f1: 0.5895
üöÄ New Best Model Saved! F1: 0.5895


Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:50<00:00,  5.45it/s]


Epoch 8 | Loss: 0.5381 | train_f1: 0.5919 | val_f1: 0.6116
üöÄ New Best Model Saved! F1: 0.6116


Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:51<00:00,  5.37it/s]


Epoch 9 | Loss: 0.5160 | train_f1: 0.6174 | val_f1: 0.6181
üöÄ New Best Model Saved! F1: 0.6181


Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.31it/s]


Epoch 10 | Loss: 0.4794 | train_f1: 0.6552 | val_f1: 0.6160
üíæ Checkpoint saved at epoch 10


Epoch 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.31it/s]


Epoch 11 | Loss: 0.4639 | train_f1: 0.6606 | val_f1: 0.6113


Epoch 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:54<00:00,  5.25it/s]


Epoch 12 | Loss: 0.4556 | train_f1: 0.6624 | val_f1: 0.6160


Epoch 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.31it/s]


Epoch 13 | Loss: 0.4537 | train_f1: 0.6703 | val_f1: 0.6111


Epoch 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.31it/s]


Epoch 14 | Loss: 0.4248 | train_f1: 0.6901 | val_f1: 0.6219
üöÄ New Best Model Saved! F1: 0.6219


Epoch 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.30it/s]


Epoch 15 | Loss: 0.4267 | train_f1: 0.7004 | val_f1: 0.6353
üöÄ New Best Model Saved! F1: 0.6353
üíæ Checkpoint saved at epoch 15


Epoch 16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.30it/s]


Epoch 16 | Loss: 0.4075 | train_f1: 0.7107 | val_f1: 0.6415
üöÄ New Best Model Saved! F1: 0.6415


Epoch 17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.31it/s]


Epoch 17 | Loss: 0.4108 | train_f1: 0.7257 | val_f1: 0.6183


Epoch 18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.30it/s]


Epoch 18 | Loss: 0.3831 | train_f1: 0.7368 | val_f1: 0.6379


Epoch 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.32it/s]


Epoch 19 | Loss: 0.3838 | train_f1: 0.7329 | val_f1: 0.6155


Epoch 20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.27it/s]


Epoch 20 | Loss: 0.3780 | train_f1: 0.7430 | val_f1: 0.6211
üíæ Checkpoint saved at epoch 20


Epoch 21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.31it/s]


Epoch 21 | Loss: 0.3779 | train_f1: 0.7481 | val_f1: 0.6319


Epoch 22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.31it/s]


Epoch 22 | Loss: 0.3662 | train_f1: 0.7521 | val_f1: 0.6238


Epoch 23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.32it/s]


Epoch 23 | Loss: 0.3559 | train_f1: 0.7590 | val_f1: 0.6275


Epoch 24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.32it/s]


Epoch 24 | Loss: 0.3608 | train_f1: 0.7577 | val_f1: 0.6248


Epoch 25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.33it/s]


Epoch 25 | Loss: 0.3511 | train_f1: 0.7726 | val_f1: 0.6219
üíæ Checkpoint saved at epoch 25


Epoch 26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.30it/s]


Epoch 26 | Loss: 0.3574 | train_f1: 0.7645 | val_f1: 0.6261


Epoch 27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:52<00:00,  5.32it/s]


Epoch 27 | Loss: 0.3487 | train_f1: 0.7685 | val_f1: 0.6344


Epoch 28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.31it/s]


Epoch 28 | Loss: 0.3575 | train_f1: 0.7597 | val_f1: 0.6365


Epoch 29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:54<00:00,  5.26it/s]


Epoch 29 | Loss: 0.3518 | train_f1: 0.7755 | val_f1: 0.6184


Epoch 30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 600/600 [01:53<00:00,  5.30it/s]


Epoch 30 | Loss: 0.3609 | train_f1: 0.7756 | val_f1: 0.6286
üíæ Checkpoint saved at epoch 30


## Test Evaluation

In [11]:
best_model = torch.load(SAVED_MODEL_PATH, weights_only=False)
thresholds = best_model['thresholds']
model.load_state_dict(best_model['model'])
print("Label Thresholds and Macro F1 Score of the Best Model:")
print("Labels: Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other")
print("Best Thresholds:", thresholds)
print("Best Macro F1 Score:", best_f1)
print("‚úÖ Training Complete. Best model and thresholds saved.")

Label Thresholds and Macro F1 Score of the Best Model:
Labels: Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other
Best Thresholds: [0.45 0.36 0.88 0.76 0.79 0.81 0.56 0.45]
Best Macro F1 Score: 0.6415411400679418
‚úÖ Training Complete. Best model and thresholds saved.


In [12]:
model.eval()
t_probs, t_true = [], []
with torch.no_grad():
    for l, r, y in test_loader:
        out = torch.sigmoid(model(l.to(DEVICE), r.to(DEVICE)))
        t_probs.append(out.cpu().numpy()); t_true.append(y.numpy())

t_p, t_t = np.vstack(t_probs), np.vstack(t_true)
t_preds = (t_p > thresholds).astype(int)

mcm = multilabel_confusion_matrix(t_t, t_preds)
print("Multilabel Confusion Matrix:")
for i, class_name in enumerate(CLASS_NAMES):
    tn, fp, fn, tp = mcm[i].ravel()
    print(f"{class_name}: TP={tp}, FP={fp}, TN={tn}, FN={fn}")
    wandb.log({f"CM_{class_name}": wandb.plot.confusion_matrix(
            probs=None,
            y_true=t_t[:, i],
            preds=t_preds[:, i],
            title=f"Confusion Matrix for {class_name}",
            class_names=["Absent", "Present"]
        )})
print("\nClassification Report:")
print(classification_report(t_t, t_preds, target_names=CLASS_NAMES))
wandb.finish()

Multilabel Confusion Matrix:
Normal: TP=85, FP=90, TN=121, FN=8
Diabetes: TP=82, FP=65, TN=134, FN=23
Glaucoma: TP=8, FP=6, TN=282, FN=8
Cataract: TP=11, FP=3, TN=285, FN=5
AMD: TP=7, FP=15, TN=273, FN=9
Hypertension: TP=8, FP=10, TN=284, FN=2
Myopia: TP=14, FP=2, TN=286, FN=2
Other: TP=73, FP=165, TN=58, FN=8

Classification Report:
              precision    recall  f1-score   support

      Normal       0.49      0.91      0.63        93
    Diabetes       0.56      0.78      0.65       105
    Glaucoma       0.57      0.50      0.53        16
    Cataract       0.79      0.69      0.73        16
         AMD       0.32      0.44      0.37        16
Hypertension       0.44      0.80      0.57        10
      Myopia       0.88      0.88      0.88        16
       Other       0.31      0.90      0.46        81

   micro avg       0.45      0.82      0.58       353
   macro avg       0.54      0.74      0.60       353
weighted avg       0.49      0.82      0.60       353
 samples avg  

0,1
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
lr,‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train_acc,‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
train_f1,‚ñÅ‚ñÇ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
train_loss,‚ñà‚ñÜ‚ñÜ‚ñÖ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
val_acc,‚ñÅ‚ñÉ‚ñÑ‚ñÉ‚ñÇ‚ñÉ‚ñÑ‚ñÉ‚ñÜ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÑ‚ñÑ‚ñÑ‚ñÜ‚ñÉ‚ñÑ‚ñà‚ñÑ‚ñà‚ñÖ‚ñá‚ñÖ‚ñà‚ñá‚ñÜ‚ñá
val_f1,‚ñÅ‚ñÑ‚ñÜ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñá‚ñà‚ñá‚ñá‚ñà‚ñá‚ñà‚ñá‚ñá‚ñá‚ñà‚ñà‚ñá‚ñà
val_f1_AMD,‚ñÅ‚ñÉ‚ñÉ‚ñÉ‚ñÖ‚ñÑ‚ñÑ‚ñá‚ñÜ‚ñá‚ñÜ‚ñá‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñà‚ñá‚ñá‚ñÜ‚ñá‚ñá‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñÜ‚ñá
val_f1_Cataract,‚ñÅ‚ñÑ‚ñà‚ñÜ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÑ‚ñÉ‚ñÑ‚ñÖ‚ñÑ‚ñÜ‚ñÖ‚ñÖ‚ñÑ‚ñÉ‚ñÖ‚ñÖ‚ñÑ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñà‚ñÑ‚ñÑ
val_f1_Diabetes,‚ñÅ‚ñÇ‚ñÉ‚ñÑ‚ñÇ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÖ‚ñÜ‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñÜ‚ñÜ‚ñá‚ñá‚ñà‚ñá‚ñá‚ñá‚ñà‚ñá‚ñá‚ñá

0,1
epoch,30
lr,0.0
train_acc,33.01376
train_f1,0.77558
train_loss,0.36088
val_acc,27.98635
val_f1,0.62856
val_f1_AMD,0.59259
val_f1_Cataract,0.66667
val_f1_Diabetes,0.67337


## Document Model

In [13]:
from torchinfo import summary
summary(model, input_data=[
    torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(DEVICE), 
    torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(DEVICE)], 
    col_names=["input_size", "output_size", "num_params", "trainable"])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
OptimizedModule                          [1, 3, 512, 512]          [1, 8]                    --                        True
‚îú‚îÄODIRDualNet: 1-1                       [1, 3, 512, 512]          [1, 8]                    5,322,884                 True
Total params: 5,322,884
Trainable params: 5,322,884
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0
Input size (MB): 6.29
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 6.29