In [1]:
%%capture
!pip install -q  torch torchvision scikit-learn pandas opencv-python torchinfo gradio

In [2]:
import torch 
import torch.nn as nn
from torchvision import models
from torch.utils.checkpoint import checkpoint
NUM_CLASSES = 8  # default number of output logits produced by the classifier head
IMG_SIZE = 512  # NOTE(review): presumably the training image side length; no longer used by the model head
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class ODIRDualNet(nn.Module):
    """Two-input classifier that shares one EfficientNet-B0 feature extractor between both eyes.

    Each image is passed through the same convolutional feature extractor, global-average
    pooled to a ``(batch, feature_dim)`` vector, and the two vectors are concatenated and fed
    to a small MLP head that returns ``num_classes`` raw logits (no sigmoid is applied here).

    Args:
        num_classes: Number of output logits.
        hidden_dim: Width of the hidden layer in the head. The default (512) equals the value
            the head previously borrowed from ``IMG_SIZE``, so existing checkpoints still load.
        pretrained: When True (default) the backbone is created with torchvision's ``'DEFAULT'``
            weights. Pass False to skip that download when a checkpoint is loaded right after.
    """

    def __init__(self, num_classes=NUM_CLASSES, hidden_dim=512, pretrained=True):
        super().__init__()
        # Using B0 for efficiency, upgrade to B4 for better accuracy.
        # ref https://docs.pytorch.org/vision/main/models/generated/torchvision.models.efficientnet_b0.html#torchvision.models.EfficientNet_B0_Weights
        self.backbone = models.efficientnet_b0(weights='DEFAULT' if pretrained else None)
        # Input width of the stock classifier's Linear layer == pooled feature dimension.
        self.feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()  # drop the stock classification layer
        # `features` aliases `backbone.features`, so the same parameters appear in the
        # state dict under both `backbone.features.*` and `features.*`. Both attributes are
        # kept on purpose: removing either would change the checkpoint key set.
        self.features = self.backbone.features
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim * 2, hidden_dim),  # left + right features concatenated
            nn.ReLU(),
            nn.Dropout(0.3),  # regularization; only active in train() mode
            nn.Linear(hidden_dim, num_classes),
        )

    def _extract(self, image):
        """Return globally average-pooled features of shape (batch, feature_dim) for one eye."""
        if torch.is_grad_enabled():
            # Activation checkpointing trades recomputation for memory during backprop.
            fmap = checkpoint(self.features, image, use_reentrant=False)
        else:
            # Nothing is stored for backward when autograd is off, so checkpointing is pure overhead.
            fmap = self.features(image)
        return torch.flatten(nn.functional.adaptive_avg_pool2d(fmap, 1), 1)

    def forward(self, left, right):
        """Compute logits of shape (batch, num_classes) from a left/right image pair."""
        combined = torch.cat((self._extract(left), self._extract(right)), dim=1)
        return self.classifier(combined)

In [3]:
import os
import torch
SAVED_MODELS_DIR = "saved_models"
os.makedirs(SAVED_MODELS_DIR, exist_ok=True)
IMAGE_PREP_NAME = "gamma" # Name for this image pre-processing method, used for directory naming and logging
RUN_NAME = f"efficient-b0_{IMAGE_PREP_NAME}" # Unique name for this run, used for saving models and logging
SAVED_MODEL_PATH = os.path.join(SAVED_MODELS_DIR, f"{RUN_NAME}_best.pth")
# Index i of this list labels logit i of the model output and entry i of `thresholds`.
CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Other']

# SECURITY: weights_only=False unpickles arbitrary Python objects and can execute code.
# Only load checkpoints you produced yourself or otherwise trust.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
best_model = torch.load(SAVED_MODEL_PATH, map_location=DEVICE, weights_only=False)
thresholds = best_model['thresholds']

# A state dict saved from a torch.compile()-wrapped model carries an "_orig_mod." key prefix.
# Strip it so the weights load into the plain module whether or not the training run
# (or this environment) used torch.compile.
_COMPILE_PREFIX = "_orig_mod."
state_dict = {
    (key[len(_COMPILE_PREFIX):] if key.startswith(_COMPILE_PREFIX) else key): value
    for key, value in best_model['model'].items()
}

model = ODIRDualNet().to(DEVICE)
model.load_state_dict(state_dict)
# Inference mode: disables Dropout and makes BatchNorm use its running statistics.
# Without this, predictions are random and depend on the single-image batch.
model.eval()
if hasattr(torch, 'compile'):
    model = torch.compile(model)
    print("✅ Model Compiled for speed.")


✅ Model Compiled for speed.


<All keys matched successfully>

In [7]:
import sys
sys.path.append('.')
from preprocessing import custom_gamma
import cv2
import numpy as np
from PIL import Image

def preprocessing(img_pil_rgb):
    """Run the project's `custom_gamma` step on a PIL image and return a new PIL image.

    The pixels are converted RGB -> BGR before `custom_gamma` is called and BGR -> RGB
    afterwards, so `custom_gamma` always receives OpenCV-ordered channels.

    Args:
        img_pil_rgb: PIL image; assumed to be 3-channel RGB (TODO confirm callers never
            pass RGBA or grayscale images).

    Returns:
        PIL.Image built from the array returned by `custom_gamma`, with channels back in RGB order.
    """
    rgb_pixels = np.array(img_pil_rgb)
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    corrected_bgr = custom_gamma(bgr_pixels)
    corrected_rgb = cv2.cvtColor(corrected_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(corrected_rgb)


In [11]:
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
from torchvision import models, transforms



# 2. PREDICTION LOGIC
def predict(left_img, right_img):
    """Run the dual-eye model on one or two fundus images and summarise the result.

    Args:
        left_img: PIL image of the left eye, or None if not uploaded.
        right_img: PIL image of the right eye, or None if not uploaded.

    Returns:
        Tuple ``(processed_left, processed_right, scores, message)`` where the first two
        items are the preprocessed PIL images, ``scores`` maps each entry of ``CLASS_NAMES``
        to its sigmoid probability, and ``message`` is a Markdown summary. When both inputs
        are None the first three items are None and ``message`` asks for an upload.
    """
    if left_img is None and right_img is None:
        return None, None, None, "❌ Please upload at least one image."

    # Handle single upload: substitute the missing eye with a mirrored copy of the other one,
    # and remember to tell the user about it in the final message.
    notice = None
    if left_img is None:
        left_img = right_img.transpose(Image.FLIP_LEFT_RIGHT)
        notice = "⚠️ Using flipped Right eye for missing Left eye."
    elif right_img is None:
        right_img = left_img.transpose(Image.FLIP_LEFT_RIGHT)
        notice = "⚠️ Using flipped Left eye for missing Right eye."

    # Process images for the model
    proc_l_pil = preprocessing(left_img)
    proc_r_pil = preprocessing(right_img)

    # Convert to normalized tensors (ImageNet mean / std).
    # NOTE(review): there is no explicit resize to IMG_SIZE here; this assumes `preprocessing`
    # already returns images at the resolution used during training — confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    l_tensor = transform(proc_l_pil).unsqueeze(0).to(DEVICE)
    r_tensor = transform(proc_r_pil).unsqueeze(0).to(DEVICE)

    # Dropout / BatchNorm must be in inference mode; calling eval() again is cheap and idempotent.
    model.eval()
    with torch.no_grad():
        logits = model(l_tensor, r_tensor)
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    results = {}
    diagnoses = []
    for i, name in enumerate(CLASS_NAMES):
        p = float(probs[i])
        t = float(thresholds[i])
        results[name] = p
        if p >= t:
            # Bullet items: single newlines alone do not produce line breaks in Markdown.
            diagnoses.append(f"- **{name}** (Prob: {p:.2f} >= Thr: {t:.2f})")

    if diagnoses:
        info_msg = "### 🚩 Detected Conditions:\n" + "\n".join(diagnoses)
    else:
        info_msg = "### Summary: No diseases detected."
    if notice is not None:
        info_msg += "\n\n" + notice

    # Return processed PIL images for display, results, and msg
    return proc_l_pil, proc_r_pil, results, info_msg

# 3. INTERFACE DESIGN
# Three-column dashboard: raw uploads | preprocessed images | scores and summary.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# 👁️ ODIR Diagnostic Dashboard")
    # Sample file names with their annotated findings, shown as a fenced block.
    gr.Markdown(''' 
                ```
                Some images that we can try : 
                    - 112_left.jpg,112_right.jpg,normal fundus,cataract
                    - 43_left.jpg,43_right.jpg,wet age-related macular degeneration,dry age-related macular degeneration，glaucoma
                    - 32_left.jpg,32_right.jpg,hypertensive retinopathy
                    - 71_left.jpg,71_right.jpg,diabetic retinopathy,wet age-related macular degeneration，diabetic retinopathy
                ```                
                ''')    
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Upload Raw Images")
            # type="pil" makes Gradio hand PIL images (or None when empty) to `predict`.
            in_l = gr.Image(label="Raw Left Eye", type="pil", width=256, height=256)
            in_r = gr.Image(label="Raw Right Eye", type="pil", width=256, height=256)
            btn = gr.Button("Process & Diagnose", variant="primary")

        with gr.Column():
            gr.Markdown("### 2. Model's View (Processed)")
            out_l = gr.Image(label="Processed Left", interactive=False, width=256, height=256)
            out_r = gr.Image(label="Processed Right", interactive=False, width=256, height=256)

        with gr.Column():
            gr.Markdown("### 3. Diagnostic Results")
            out_label = gr.Label(num_top_classes=8)
            status = gr.Markdown("Status: Waiting for input...")

    # The four outputs line up with the 4-tuple returned by `predict`.
    btn.click(
        fn=predict,
        inputs=[in_l, in_r],
        outputs=[out_l, out_r, out_label, status]
    )

demo.launch()

  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:


* Running on local URL:  http://127.0.0.1:7865
* To create a public link, set `share=True` in `launch()`.




images 
images 
images 
