In [1]:
!pip install gradio

Collecting gradio
  Downloading gradio-6.0.1-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting brotli>=1.1.0 (from gradio)
  Downloading brotli-1.2.0-cp311-cp311-macosx_10_9_universal2.whl.metadata (6.1 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.122.0-py3-none-any.whl.metadata (30 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-1.0.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==2.0.0 (from gradio)
  Downloading gradio_client-2.0.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting huggingface-hub<2.0,>=0.33.5 (from gradio)
  Downloading huggingface_hub-1.1.5-py3-none-any.whl.metadata (13 kB)
Collecting orjson~=3.0 (from gradio)
  Downloading orjson-3.11.4-cp311-cp311-macosx_15_0_arm64.whl.metadata (41 kB)
Collecting pydantic<=2.12.4,>=2

In [2]:
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np

# --- CONFIGURATION ---
MODEL_PATH = 'skin_cancer_resnet50_binary.pth'  # trained state_dict produced by the training notebook
IMAGE_SIZE = 224  # square input resolution; must match the size used during training/validation
# Output index i of the model is reported as CLASS_NAMES[i], so this order must
# match the label indices used in training.
CLASS_NAMES = ['Benign', 'Malignant']
# Derived rather than hard-coded so the classifier head and the label list can never disagree.
NUM_CLASSES = len(CLASS_NAMES)

# --- M4 GPU (MPS) Check ---
# Use the Apple-silicon GPU backend when PyTorch reports it available, else fall back to CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"✅ Prediction device set to: {DEVICE}")

# --- 1. MODEL SETUP FUNCTION (Must match training architecture) ---
def setup_model(num_classes, model_path=None, device=None):
    """Build the ResNet50 classifier and load the trained weights into it.

    The architecture must match the one used in training: a torchvision
    ResNet50 created without pretrained weights, whose final ``fc`` layer is
    replaced by a ``num_classes``-way linear layer.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.
        model_path: Path to the saved ``state_dict``. ``None`` (the default)
            means use the module-level ``MODEL_PATH``.
        device: Device the weights are mapped onto and the model is moved to.
            ``None`` (the default) means use the module-level ``DEVICE``.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the checkpoint's keys/shapes do not match this
            architecture (raised by ``load_state_dict``).
    """
    # Resolve defaults lazily so callers can override either without touching globals.
    model_path = MODEL_PATH if model_path is None else model_path
    device = DEVICE if device is None else device

    model = models.resnet50(weights=None)  # architecture only; weights come from the checkpoint
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # map_location lets a checkpoint saved on any device load onto `device`.
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint file cannot execute arbitrary code when loaded.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    # Module.to() and Module.eval() both return the module itself.
    return model.to(device).eval()

# Initialize the model once; predict_image reuses this global for every request.
try:
    model = setup_model(NUM_CLASSES)
    print("Trained model loaded successfully.")
except FileNotFoundError as err:
    # Raise instead of calling exit(): inside a Jupyter/IPython kernel exit()
    # asks the kernel itself to shut down (losing all state), whereas an
    # exception just stops this cell with a readable message.
    raise FileNotFoundError(
        f"Model file not found at {MODEL_PATH}. Please check the path."
    ) from err

# --- 2. TRANSFORMATION PIPELINE (Must match validation transforms) ---
# Preprocessing applied to every uploaded image before inference. The resize
# target and the normalisation statistics must be identical to the ones used
# during training/validation, otherwise the model sees shifted inputs.
transform = transforms.Compose(
    [
        # Fixed square input; the aspect ratio is not preserved.
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        # PIL image -> CHW float tensor.
        transforms.ToTensor(),
        # Per-channel (R, G, B) mean/std — the commonly used ImageNet statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# --- 3. PREDICTION FUNCTION (Core logic for Gradio) ---
def predict_image(input_img: "Image.Image | None"):
    """Classify one skin-lesion image and return per-class probabilities.

    Args:
        input_img: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping each name in ``CLASS_NAMES`` to its softmax probability
        (a Python float), the format ``gr.Label`` consumes. An empty dict is
        returned when ``input_img`` is ``None``.
    """
    # Nothing uploaded: return an empty result instead of crashing inside transform().
    if input_img is None:
        return {}

    # `transform` normalises with 3 per-channel statistics, so force 3 channels;
    # RGBA or grayscale uploads would otherwise yield a mismatched tensor.
    batch = transform(input_img.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
        # Select the single sample with [0] rather than squeeze(), which would
        # also collapse the class dimension if it ever had size 1.
        probabilities = torch.softmax(logits, dim=1)[0].cpu().tolist()

    # Output index i corresponds to CLASS_NAMES[i].
    return {name: float(prob) for name, prob in zip(CLASS_NAMES, probabilities)}

# --- 4. GRADIO INTERFACE SETUP ---
# Sample images offered as clickable examples. Only files that exist locally
# are handed to Gradio, so the app still starts when the data/ folder is absent.
example_images = [
    path
    for path in ("data/images/ISIC_0024310.jpg", "data/images/ISIC_0024468.jpg")
    if os.path.isfile(path)
]

# Define the interface components
input_image = gr.Image(label="Upload Skin Lesion Image (.jpg)", type="pil")
output_label = gr.Label(num_top_classes=NUM_CLASSES)  # show the probability of every class

# Create and launch the interface
iface = gr.Interface(
    fn=predict_image,
    inputs=input_image,
    outputs=output_label,
    title="Skin Cancer Binary Classifier (ResNet50)",
    description="Upload an image to classify it as Benign or Malignant. Trained using PyTorch on the HAM10000 dataset (M4 GPU accelerated).",
    examples=example_images or None,  # None disables the examples panel entirely
)

# Launch the app (set share=True to generate a public link if needed, but not required locally)
iface.launch(inbrowser=True)

✅ Prediction device set to: mps
Trained model loaded successfully.
* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


