In [1]:
%%capture
!pip install gradio torch efficientnet-pytorch plotly pillow matplotlib ipywidgets

In [2]:
import sys

# Make modules in the current working directory importable (presumably efficient_net.py
# and preprocessor.py, imported below). Guarded so re-running this cell does not keep
# appending duplicate '.' entries to sys.path.
if '.' not in sys.path:
    sys.path.append('.')

from efficient_net import EfficientNetTrainer
from preprocessor import ODIRImageProcessor

In [3]:
# Checkpoint to restore — change this to point at your own trained weights.
MODEL_PATH = "ODIR-2019/results/efficientnet-b4-odir-2019-pmg.pth"

# Architecture settings passed to the trainer; presumably these must match the
# configuration the checkpoint above was trained with.
config = dict(
    model_name='efficientnet-b4',
    num_classes=8,
    image_size=512,
)

trainer = EfficientNetTrainer(config)
trainer.load_model(path=MODEL_PATH)

# Preprocessor used later for images that have not been processed yet.
processor = ODIRImageProcessor()

Using device: cuda
Loading efficientnet-b4 with 8 classes...
Loaded pretrained weights for efficientnet-b4
Model loaded from ODIR-2019/results/efficientnet-b4-odir-2019-pmg.pth


## SANITY CHECK
Here we check that the loaded model predicts correctly on an image whose label we already know (a test image from the `diabetes` folder).

In [4]:
# The image lives in the test/diabetes folder, so the expected label is 'diabetes'.
# Assert it so a wrong checkpoint or class mapping fails loudly instead of only being
# visible to someone reading the output; the full result is still displayed below.
sanity_processed = trainer.predict_single_image_path(
    image_path="ODIR-2019/YOLO/processed_512g_merged/test/diabetes/26_left.jpg"
)
assert sanity_processed["predicted_class"] == "diabetes", sanity_processed["predicted_class"]
sanity_processed


üîç Predicting: 26_left.jpg
(Load: 6.74 ms)
Prediction completed in 297.48 ms Transform: 8.12 ms, Inference: 262.21 ms, Process: 27.14 ms)


{'predicted_class': 'diabetes',
 'confidence': 0.937221348285675,
 'class_index': 1,
 'predictions': [('diabetes', 0.937221348285675),
  ('ageing', 0.0272287055850029),
  ('hypertension', 0.009323053993284702),
  ('glaucoma', 0.008425207808613777),
  ('myopia', 0.006647056899964809),
  ('cataract', 0.004811877850443125),
  ('normal', 0.0032602217979729176),
  ('other', 0.003082534298300743)],
 'total_time_ms': 297.47700691223145,
 'inference_time_ms': 262.2108459472656}

Next, check an image that has not been processed yet, letting our preprocessor (`processor.load_image`) process it before prediction.

In [5]:
# Same eye, read from the `preprocessed` folder, with `processor.load_image` passed as the
# preprocessor so our pipeline is applied before inference. The label folder is
# test/diabetes, so assert the prediction rather than only eyeballing the output.
sanity_with_preprocessor = trainer.predict_single_image_path(
    image_path="ODIR-2019/YOLO/preprocessed/test/diabetes/26_left.jpg",
    preprocessor=processor.load_image,
)
assert sanity_with_preprocessor["predicted_class"] == "diabetes", sanity_with_preprocessor["predicted_class"]
sanity_with_preprocessor



üîç Predicting: 26_left.jpg
(Load: 27.76 ms)
Prediction completed in 54.78 ms Transform: 3.84 ms, Inference: 15.13 ms, Process: 35.81 ms)


{'predicted_class': 'diabetes',
 'confidence': 0.9442383646965027,
 'class_index': 1,
 'predictions': [('diabetes', 0.9442383646965027),
  ('ageing', 0.02547771856188774),
  ('glaucoma', 0.007623407524079084),
  ('hypertension', 0.007071923930197954),
  ('myopia', 0.006035467144101858),
  ('cataract', 0.004399948753416538),
  ('other', 0.0033553375396877527),
  ('normal', 0.0017977905226871371)],
 'total_time_ms': 54.780006408691406,
 'inference_time_ms': 15.128135681152344}

In [None]:
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import time
# Human-readable description for each class name the model can predict. The keys are the
# class names seen in the prediction output above; the matching text is appended to the
# Gradio summary under the predicted label.
ODIR_CLASS_DESCRIPTIONS = dict(
    normal='Healthy fundus with no apparent pathologies',
    diabetes='Diabetic Retinopathy - presence of microaneurysms, hemorrhages, hard exudates, cotton wool spots, or neovascularization',
    glaucoma='Glaucoma - optic nerve damage with increased cup-to-disc ratio, retinal nerve fiber layer defects',
    cataract='Cataract - lens opacity visible through fundus image',
    ageing='Age-related Macular Degeneration - drusen, pigmentary changes, geographic atrophy, or neovascularization in macular region',
    hypertension='Hypertensive Retinopathy - arteriovenous nicking, copper/silver wiring, flame-shaped hemorrhages',
    myopia='Pathological Myopia - tessellated fundus, peripapillary atrophy, posterior staphyloma',
    other='Other retinal conditions including vein occlusion, retinal detachment, tumors, etc.',
)

def diagnosis_wrapper(input_array):
    """Preprocess one uploaded fundus image and classify it; used as the Gradio callback.

    Args:
        input_array: Image as a NumPy array in RGB channel order (what
            ``gr.Image(type="numpy")`` delivers), or ``None`` when the input is cleared.
            Grayscale arrays (2-D or single-channel) are also accepted and expanded to
            3-channel BGR before preprocessing.

    Returns:
        ``(processed_rgb, confidences, summary)``:
        - processed_rgb: the preprocessed image as an RGB array, displayed back to the user;
        - confidences: ``{class_name: probability}`` for every class, for ``gr.Label``;
        - summary: text with the top class, its rounded confidence, its description and
          the end-to-end latency (preprocessing + inference).
        When ``input_array`` is ``None`` returns ``(None, {}, <message>)`` so the outputs
        are reset instead of the callback raising.
    """
    if input_array is None:
        return None, {}, "No image provided. Please upload a scan."

    start_time = time.perf_counter()

    # Gradio supplies RGB while the preprocessor is fed OpenCV's BGR order.
    # COLOR_RGB2BGR rejects single-channel input, so grayscale is expanded explicitly.
    if input_array.ndim == 2 or input_array.shape[2] == 1:
        img_bgr = cv2.cvtColor(input_array, cv2.COLOR_GRAY2BGR)
    else:
        img_bgr = cv2.cvtColor(input_array, cv2.COLOR_RGB2BGR)
    img_bgr = processor.preprocess_image(img_bgr)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    results = trainer.predict_single_image(Image.fromarray(img_rgb))

    # gr.Label expects a {label: confidence} mapping.
    confidences = {label: float(conf) for label, conf in results['predictions']}
    predicted_class = results["predicted_class"]
    predicted_conf = round(results["confidence"] * 100)
    comment = ODIR_CLASS_DESCRIPTIONS.get(predicted_class, "")
    elapsed_time = time.perf_counter() - start_time

    summary = (
        f"The model predicted in {elapsed_time:.4f} s with {predicted_conf}% confidence "
        f"that the eye is likely {predicted_class}\n{comment}"
    )
    return img_rgb, confidences, summary

# Gradio UI: upload a fundus image on the left; the summary, top class probabilities and
# the preprocessed image update automatically whenever the input changes.
#
# NOTE(review): share=True publishes the app on a public *.gradio.live URL that stays
# valid for up to a week (see the launch output). Uploaded fundus images may be patient
# data — set SHARE_PUBLIC_LINK = False unless a public link is genuinely needed.
SHARE_PUBLIC_LINK = True

with gr.Blocks() as demo:
    gr.Markdown("## 👁️ Rapid Eye Analysis")

    with gr.Row():
        with gr.Column():
            # type="numpy" hands the callback the pixel array instead of a file path.
            img_input = gr.Image(type="numpy", label="Input Image", width=512, height=512)
        with gr.Column():
            summary = gr.Textbox(label="Summary")
            label_output = gr.Label(num_top_classes=5)
        with gr.Column():
            img_output = gr.Image(label="Processed Image", width=512, height=512)

    # Run the diagnosis on every input change for a real-time feel; a cleared input
    # arrives as None, which diagnosis_wrapper turns into empty outputs.
    img_input.change(
        fn=diagnosis_wrapper,
        inputs=img_input,
        outputs=[img_output, label_output, summary],
    )

demo.launch(share=SHARE_PUBLIC_LINK)

* Running on local URL:  http://127.0.0.1:7861
* Running on public URL: https://5b4120fc5f26eea598.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




Prediction completed in 80.08 ms Transform: 21.65 ms, Inference: 55.38 ms, Process: 3.04 ms)
Prediction completed in 70.35 ms Transform: 19.62 ms, Inference: 10.54 ms, Process: 40.19 ms)
Prediction completed in 63.94 ms Transform: 13.04 ms, Inference: 10.28 ms, Process: 40.61 ms)
Prediction completed in 71.40 ms Transform: 15.36 ms, Inference: 54.18 ms, Process: 1.86 ms)
Prediction completed in 67.49 ms Transform: 16.26 ms, Inference: 10.88 ms, Process: 40.35 ms)
Prediction completed in 55.10 ms Transform: 4.42 ms, Inference: 12.42 ms, Process: 38.25 ms)
Prediction completed in 54.92 ms Transform: 4.10 ms, Inference: 11.60 ms, Process: 39.22 ms)
Prediction completed in 65.65 ms Transform: 15.13 ms, Inference: 10.96 ms, Process: 39.56 ms)
Prediction completed in 55.20 ms Transform: 4.73 ms, Inference: 13.28 ms, Process: 37.19 ms)
Prediction completed in 65.62 ms Transform: 15.29 ms, Inference: 12.02 ms, Process: 38.31 ms)
Prediction completed in 55.10 ms Transform: 4.54 ms, Inference: 1