In [None]:
import torch
from torchvision import transforms, models
from PIL import Image
import os
import cv2
import time

# Hyperparameters and device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the trained .pth state_dict — REPLACE WITH OWN MODEL PATH
MODEL_PATH = r""

# Define the image transformation (same as used during training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the saved model.
# The architecture must match the one used in training: ResNet-18 with a
# single-output linear head. No ImageNet weights are downloaded here because
# load_state_dict below overwrites every parameter and buffer anyway.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 1)

# map_location remaps tensors onto DEVICE, so a checkpoint saved on a GPU can be
# loaded on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted
# sources (consider weights_only=True on torch versions that support it).
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model = model.to(DEVICE)
model.eval()  # inference mode for dropout/batch-norm layers

# Function for preprocessing the fundus image
def preprocess_fundus_image(image_path, output_size=(512, 512), threshold=10):
    """
    Preprocess a fundus retina image for inference.

    Crops away the dark background around the retina (bounding box of the
    largest bright region), pads the crop with black borders to a square
    aspect ratio, and resizes it to ``output_size``.

    Args:
        image_path: Path to the image file, read with ``cv2.imread`` (default
            colour mode, so the array is in BGR channel order).
        output_size: Target size passed to ``cv2.resize`` as ``(width, height)``.
            Defaults to 512x512.
        threshold: Grayscale intensity; pixels strictly above it are treated as
            foreground when building the retina mask. Defaults to 10.

    Returns:
        The preprocessed BGR image as a NumPy array, or ``None`` if the image
        could not be read or no foreground region was found (a message is
        printed in both cases).
    """
    # Step 1: Load image (cv2.imread returns None instead of raising on failure)
    original_image = cv2.imread(image_path)
    if original_image is None:
        print(f"Error: Image not found at specified path: {image_path}")
        return None

    # Step 2-3: Grayscale + binary threshold to separate retina from dark background
    gray_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
    _, binary_mask = cv2.threshold(gray_image, threshold, 255, cv2.THRESH_BINARY)

    # Step 4: Find the outer contours of the foreground.
    # findContours returns (image, contours, hierarchy) on OpenCV 3.x but
    # (contours, hierarchy) on 4.x; [-2] selects the contours on both.
    contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        print(f"Warning: No retina region detected in {image_path}")
        return None

    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    cropped_image = original_image[y:y + h, x:x + w]

    # Step 5: Pad symmetrically with black to a square; any odd leftover pixel
    # goes to the bottom/right border.
    height, width = cropped_image.shape[:2]
    side = max(height, width)
    pad_top = (side - height) // 2
    pad_bottom = side - height - pad_top
    pad_left = (side - width) // 2
    pad_right = side - width - pad_left
    padded_image = cv2.copyMakeBorder(
        cropped_image, pad_top, pad_bottom, pad_left, pad_right,
        cv2.BORDER_CONSTANT, value=[0, 0, 0],
    )

    # Step 6: Resize to the standardized size (INTER_AREA is suited to downscaling)
    return cv2.resize(padded_image, output_size, interpolation=cv2.INTER_AREA)

# Function to make predictions on a single image
def predict(image_path):
    """
    Run the loaded model on a single fundus image and return its raw output.

    The image is cropped/padded/resized by ``preprocess_fundus_image``,
    converted from OpenCV's BGR order to RGB, passed through the module-level
    ``transform``, and fed as a batch of one to the module-level ``model`` on
    ``DEVICE``.

    Args:
        image_path: Path to the image file.

    Returns:
        The model's single output value as a Python float, or ``None`` if
        preprocessing failed.
    """
    fundus_bgr = preprocess_fundus_image(image_path)
    if fundus_bgr is None:
        return None

    # PIL/torchvision expect RGB; OpenCV produced BGR
    fundus_pil = Image.fromarray(cv2.cvtColor(fundus_bgr, cv2.COLOR_BGR2RGB))
    batch = transform(fundus_pil).unsqueeze(0).to(DEVICE)  # add batch dimension

    # No gradients are needed for inference
    with torch.no_grad():
        score = model(batch).squeeze().item()

    return score

In [None]:
# Predict for an individual image
img_path = r"" # REPLACE WITH OWN IMAGE PATH

# Time only the prediction itself (preprocessing + forward pass), not the printing.
# time.perf_counter is a monotonic, high-resolution clock meant for measuring
# durations; time.time is wall-clock and can jump if the system clock changes.
# NOTE: on a fresh kernel the first call may also include one-off warm-up costs.
start_time1 = time.perf_counter()
predicted_score = predict(img_path)
end_time1 = time.perf_counter()
elapsed_time1 = end_time1 - start_time1

if predicted_score is not None:
    print(f"Predicted score: {predicted_score}")
else:
    print(f"Prediction failed for {img_path}")

print("Time used: ", elapsed_time1)

print("Device used:", DEVICE)

Predicted score: 0.20264814794063568
Time used:  0.2736399173736572
Device used: cpu
