In [None]:
# Install Gradio for the web UI. The other libraries imported in the next cell
# (torch, torchvision, opencv, huggingface_hub) are presumably preinstalled in
# the Colab runtime this notebook ran in — verify when running elsewhere.
!pip install gradio

In [None]:
import gradio as gr
import torch
import cv2
from PIL import Image
from torchvision import transforms, models
from huggingface_hub import hf_hub_download

# Download the fine-tuned weights from the Hugging Face Hub; the returned path
# points into the local Hugging Face cache.
repo_id = "potguy/efficientnet_clahe_fracture_classification"
filename = "efficientnet_clahe_hf.pth"
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
print(f"✅ Model downloaded to: {model_path}")

# Build an EfficientNet-B0 (no pretrained weights) with its final linear layer
# replaced by a 2-class head, then load the fine-tuned state dict on CPU.
model = models.efficientnet_b0(weights=None)
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, 2)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint fetched over the network cannot execute arbitrary code
# while loading. A plain state dict loads unchanged under this mode, and recent
# torch versions warn when the flag is left unset.
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
# Inference mode: dropout off, batch-norm uses its running statistics.
model.eval()

# Per-image preprocessing applied to the CLAHE-enhanced PIL image before
# inference: resize to 224x224, convert to a float tensor in [0, 1], then
# normalize with the standard ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def apply_clahe(image, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply CLAHE (contrast-limited adaptive histogram equalization) to an image.

    The image is reduced to a single grayscale channel, equalized, and then
    replicated back to three identical channels so downstream code that expects
    an RGB image keeps working.

    Args:
        image: numpy image array, e.g. from ``gr.Image(type="numpy")``. Accepts
            grayscale ``(H, W)`` / ``(H, W, 1)`` arrays as well as RGB
            ``(H, W, 3)`` or RGBA ``(H, W, 4)`` arrays. OpenCV's CLAHE only
            accepts 8- or 16-bit unsigned data.
        clip_limit: Contrast clip limit passed to ``cv2.createCLAHE``
            (default 2.0, the value this app was built with).
        tile_grid_size: Tile grid passed to ``cv2.createCLAHE``
            (default ``(8, 8)``, the value this app was built with).

    Returns:
        ``(H, W, 3)`` array with the equalized intensity in every channel.
    """
    if image.ndim == 2:
        # Already single-channel; RGB2GRAY would reject it.
        gray = image
    elif image.ndim == 3 and image.shape[2] == 1:
        gray = image.reshape(image.shape[:2])
    else:
        # RGB2GRAY accepts 3- or 4-channel input; an alpha channel is ignored.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    equalized = clahe.apply(gray)
    # Back to 3 channels for the RGB preprocessing / model input.
    return cv2.cvtColor(equalized, cv2.COLOR_GRAY2RGB)

def predict(image):
    """Classify an uploaded X-ray image as fractured or not fractured.

    Args:
        image: numpy array from the ``gr.Image(type="numpy")`` component, or
            ``None`` when "Analyze" is clicked before anything is uploaded.

    Returns:
        "Fractured" if class index 1 has the highest logit, otherwise
        "Not Fractured"; a prompt to upload an image when ``image`` is None.
    """
    # Gradio passes None for an empty image component; without this guard the
    # call below would raise inside OpenCV and the UI would only show an error.
    if image is None:
        return "Please upload an X-ray image."

    enhanced = apply_clahe(image)
    # CLAHE output -> PIL -> resize/ToTensor/normalize, plus a batch dim of 1.
    batch = transform(Image.fromarray(enhanced)).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
    # argmax returns the first maximal index on ties, same as torch.max(..., 1).
    predicted = logits.argmax(dim=1).item()

    # NOTE(review): index 1 == "fractured" presumably mirrors the training
    # label order — confirm against the training code.
    return "Fractured" if predicted == 1 else "Not Fractured"

# Gradio Blocks UI: centred title and description, an image upload with a
# text box for the prediction, and an "Analyze" button that runs `predict`.
with gr.Blocks(css="""
    #title { text-align: center; font-size: 24px; }
    #desc { text-align: center; font-style: italic; }
    #image-container { display: flex; justify-content: center; } /* Centers image */
""") as interface:
    gr.Markdown("## 🦴 AI Fracture Detection", elem_id="title")
    gr.Markdown("*Upload an X-ray image to check for fractures.*", elem_id="desc")

    with gr.Column(elem_id="image-container"):  # Centers the whole section
        image = gr.Image(type="numpy", label="Upload X-ray Image")
        output = gr.Textbox(label="Prediction")

    btn = gr.Button("Analyze")  # Button below output
    btn.click(fn=predict, inputs=image, outputs=output)

# share=True additionally publishes a temporary public *.gradio.live URL;
# anyone with that link can upload images to this running session.
interface.launch(share=True)

✅ Model downloaded to: /root/.cache/huggingface/hub/models--potguy--efficientnet_clahe_fracture_classification/snapshots/6b74521c73d092f53b83dda9f0bea659bd01d543/efficientnet_clahe_hf.pth


  model.load_state_dict(torch.load(model_path, map_location="cpu"))


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://0a240a8e2417be618a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


