In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
import cv2
import tempfile
import torch.nn.functional as F

# Model output index -> human-readable label (looked up with the arg-max index
# of the 3-way classifier head).
class_map = dict(enumerate(['NORMALROP', 'AROP1', 'AROP2']))
# Labels that are reported as "Retinopathy Detected"; everything except normal.
retinopathy_classes = [label for label in class_map.values() if label != 'NORMALROP']

# Preprocessing applied to every input image before inference: resize to a
# fixed 224x224 and convert the PIL image to a tensor.
# NOTE(review): described as matching training, but no Normalize step is
# applied here — confirm the training pipeline also skipped normalization.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Build a ResNet-18 with a 3-way classification head, restore the trained
# weights from rop_model.pth, and put it in inference mode on the best device.
# NOTE(review): torch.load unpickles the checkpoint; if the .pth file could come
# from an untrusted source, consider weights_only=True (torch >= 1.13).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = models.resnet18(weights=None)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=3)
model.load_state_dict(
    torch.load("rop_model.pth", map_location=device)
)
# Module.to() and Module.eval() both return the module itself.
model = model.to(device).eval()

# Load image from path or npy
def load_image(file_path):
    """Load an image file or a ``.npy`` array and prepare it for the model.

    Args:
        file_path: Path to a ``.npy`` array or to any image format PIL can open.

    Returns:
        ``(img_tensor, original)`` where ``img_tensor`` is ``transform(original)``
        with a leading batch dimension of 1 added, and ``original`` is the RGB
        PIL image before transformation.

    Notes:
        ``.npy`` arrays are assumed to be channels-last, ``(H, W)`` or
        ``(H, W, C)`` — TODO confirm against whatever produces these files.
        2-D (or single-channel) arrays are replicated to 3 channels; arrays whose
        maximum is <= 1 are treated as [0, 1] data and scaled to [0, 255].
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext == '.npy':
        # Work in float so the *255 scaling cannot overflow small integer dtypes.
        arr = np.load(file_path).astype(np.float64)
        # A trailing singleton channel (H, W, 1) is handled like plain grayscale.
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = arr[..., 0]
        if arr.ndim == 2:
            arr = np.repeat(arr[..., np.newaxis], 3, axis=2)
        if arr.max() <= 1:
            arr = arr * 255
        # Clip before the uint8 cast so out-of-range values saturate instead of
        # wrapping around; convert("RGB") drops e.g. an alpha channel so the
        # model always receives 3 channels.
        original = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)).convert("RGB")
    else:
        # The context manager closes the underlying file handle; convert()
        # returns a fully loaded, independent image.
        with Image.open(file_path) as src:
            original = src.convert("RGB")

    img_tensor = transform(original).unsqueeze(0)
    return img_tensor, original

# Plot brightness histogram
def plot_brightness_histogram(image):
    """Display a histogram of grayscale pixel intensities for a PIL image.

    The image is converted to 8-bit grayscale ("L" mode), so intensities are
    integers in 0..255 and each one gets its own unit-width bin.

    Args:
        image: A PIL image in any mode PIL can convert to "L".
    """
    pixel_values = np.asarray(image.convert("L")).ravel()

    fig, ax = plt.subplots(figsize=(6, 4))
    # 256 bins over [0, 256) -> bin k covers [k, k+1), so bars line up with the
    # integer intensities (a (0, 255) range would give 255/256-wide bins).
    ax.hist(pixel_values, bins=256, range=(0, 256), color='gray')
    ax.set_title("🧪 Brightness Histogram")
    ax.set_xlabel("Pixel Intensity (0=Dark, 255=Bright)")
    ax.set_ylabel("Frequency")
    ax.grid(True)
    fig.tight_layout()
    plt.show()

# Capture image from webcam
def capture_image_from_camera(camera_index=0):
    """Show a live camera preview and save one frame to a temporary JPEG.

    Press 'c' to capture the current frame or 'q' to quit without capturing.

    Args:
        camera_index: OpenCV device index passed to ``cv2.VideoCapture``
            (defaults to 0, the first camera).

    Returns:
        The path of the saved ``.jpg`` file, or None if the camera could not be
        opened, a frame could not be read, saving failed, or the user quit.
        The file is not deleted automatically; the caller owns it.
    """
    cap = cv2.VideoCapture(camera_index)
    captured_path = None
    try:
        if not cap.isOpened():
            print("❌ Failed to open camera.")
            return None

        print("📷 Press 'c' to capture the image, or 'q' to quit.")
        while True:
            ret, frame = cap.read()
            if not ret:
                print("❌ Failed to read frame from camera.")
                break

            cv2.imshow("Live Camera - Press 'c' to capture", frame)
            key = cv2.waitKey(1) & 0xFF

            if key == ord('c'):
                # Reserve a unique path, close our handle, then let OpenCV write it.
                fd, path = tempfile.mkstemp(suffix='.jpg')
                os.close(fd)
                if cv2.imwrite(path, frame):
                    captured_path = path
                    print(f"✅ Image captured and saved at: {captured_path}")
                else:
                    os.remove(path)
                    print("❌ Failed to save captured image.")
                break
            if key == ord('q'):
                print("🚪 Exiting without capturing.")
                break
    finally:
        # Always free the device and close preview windows, even on errors or
        # KeyboardInterrupt.
        cap.release()
        cv2.destroyAllWindows()
    return captured_path

# Confidence Trust Report
def get_trust_score(probability):
    """Translate a prediction confidence into a trust label.

    Thresholds are inclusive: >= 0.9 is high trust, >= 0.7 is moderate trust,
    anything lower is low trust.
    """
    tiers = (
        (0.9, "✅ High Trust"),
        (0.7, "⚠️ Moderate Trust"),
    )
    for cutoff, label in tiers:
        if probability >= cutoff:
            return label
    return "❌ Low Trust"

# Inference
def _prompt_for_input_path():
    """Ask the user for an image source and return its file path, or None to abort."""
    print("📌 Select input type:\n1. Provide image from folder\n2. Capture image from camera")
    option = input("Enter 1 or 2: ").strip()

    if option == '1':
        # Drop surrounding quotes that terminals / drag-and-drop commonly add.
        return input("📂 Enter the full path of the image or .npy file: ").strip().strip("\"'")
    if option == '2':
        # May be None when the capture is aborted or fails.
        return capture_image_from_camera()

    print("❌ Invalid option.")
    return None


def _classify(img_tensor):
    """Run the model on one batched image tensor; return (label, confidence)."""
    with torch.no_grad():
        logits = model(img_tensor.to(device))
        probabilities = F.softmax(logits, dim=1).cpu().numpy()[0]

    # Plain int so the class_map lookup uses a native key, not a NumPy scalar.
    pred_class = int(np.argmax(probabilities))
    return class_map[pred_class], float(probabilities[pred_class])


def predict_image():
    """Interactively classify a single image and print the diagnosis.

    The user either types a path to an image / ``.npy`` file or captures a frame
    from the camera. The prediction, its softmax confidence, a trust label and a
    retinopathy verdict are printed, followed by a brightness histogram of the
    input. Problems (bad option, missing file, unreadable image) are reported
    with a message and the function returns None without raising.
    """
    file_path = _prompt_for_input_path()
    if file_path is None:
        return

    # isfile (not exists) so a directory is rejected here rather than crashing
    # inside the loader.
    if not os.path.isfile(file_path):
        print("❌ File not found.")
        return

    try:
        img_tensor, original_img = load_image(file_path)
    except (OSError, ValueError, TypeError) as err:
        # Corrupt, unsupported or malformed inputs: report instead of a traceback.
        print(f"❌ Could not load image: {err}")
        return

    pred_label, confidence = _classify(img_tensor)

    print(f"\n🔍 Prediction: {pred_label}")
    print(f"🔐 Confidence: {confidence:.4f}")
    print(f"📊 Trust Score: {get_trust_score(confidence)}")

    if pred_label in retinopathy_classes:
        print("🧠 Diagnosis: ✅ Retinopathy Detected")
    else:
        print("🧠 Diagnosis: ❌ No Retinopathy (Normal)")

    # Histogram
    print("\n📊 Displaying brightness histogram...")
    plot_brightness_histogram(original_img)

# 🔁 Run prediction — only when executed as a script/notebook (where __name__ is
# "__main__"), so importing this module does not block on input().
if __name__ == "__main__":
    predict_image()
