In [None]:
pip install gradio

Collecting gradio
  Downloading gradio-5.29.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.0 (from gradio)
  Downloading gradio_client-1.10.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6

In [None]:
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu, median
from skimage.morphology import opening, dilation, square
from skimage.measure import label, regionprops
from google.colab import drive

# Mount Google Drive so files under /content/drive/MyDrive (model checkpoint,
# banner image) can be read by the cells below.
drive.mount('/content/drive')

# ---- Model / runtime -------------------------------------------------------
NUM_CLASSES = 4   # width of the classifier's final Linear layer
DROP_PROB = 0.5   # dropout probability used in the classifier head
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Output index -> display label. The order must match the class indices used
# when the checkpoint was trained (presumably alphabetical folder order — TODO confirm).
class_names = ['Abnormal Heartbeat', 'History of MI', 'MI', 'Normal']

# ---- Preprocessing ---------------------------------------------------------
CROP_TOP, CROP_BOTTOM = 10, 10   # rows trimmed from the top / bottom of the image
MEDIAN_KERNEL_SIZE = 3           # median-filter window side; NOTE(review): verify preprocess_image passes it to `median`
DILATION_SIZE = 2                # side of the square footprint used for dilation

# ---- Heuristics for ECG detection ------------------------------------------
# Allowed range for the fraction of foreground (True) pixels in the binary mask.
MIN_WHITE_RATIO, MAX_WHITE_RATIO = 0.005, 0.5
MIN_LINE_LENGTH = 100    # minimum major-axis length (px) of a candidate trace component
MIN_REGION_AREA = 300    # minimum pixel area of a candidate trace component
MIN_ASPECT_RATIO = 2.0   # width / height of the longest candidate's bounding box

# Load and modify model
# Rebuild the fine-tuned VGG16 architecture and load the trained weights from Drive.
CHECKPOINT_PATH = "/content/drive/MyDrive/best_vgg16_(processed)ecg.pth"

# weights=None: every parameter is overwritten by load_state_dict below (strict
# by default), so downloading the ~528 MB ImageNet checkpoint first is wasted work.
# The requires_grad freezing used during fine-tuning is omitted: it does not
# change forward-pass outputs, and this notebook only runs inference.
model = models.vgg16(weights=None)

# Replace the stock classifier with the head used at training time. The layer
# order must stay exactly as-is so the state_dict keys (classifier.1.*,
# classifier.4.*) line up with the checkpoint.
in_features = model.classifier[0].in_features
model.classifier = nn.Sequential(
    nn.Dropout(DROP_PROB),
    nn.Linear(in_features, 256),
    nn.ReLU(inplace=True),
    nn.Dropout(DROP_PROB),
    nn.Linear(256, NUM_CLASSES)
)

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code (requires torch >= 1.13).
state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval().to(DEVICE)

# Image preprocessing
# Side of the square footprint for the morphological opening that removes
# small specks before dilation.
OPENING_SIZE = 2


def preprocess_image(np_image):
    """Convert an ECG image into a binary mask of (presumably) the trace.

    Pipeline: crop CROP_TOP / CROP_BOTTOM rows -> grayscale -> median filter
    with a MEDIAN_KERNEL_SIZE x MEDIAN_KERNEL_SIZE window -> Otsu threshold
    (pixels darker than the threshold become True) -> opening -> dilation.

    Args:
        np_image: H x W x 3 RGB array, as produced by ``gr.Image(type="numpy")``.
            A 2-D grayscale array or an H x W x 4 RGBA array is also accepted;
            the alpha channel is dropped.

    Returns:
        2-D boolean mask of the cropped image.

    Raises:
        ValueError: if the array shape is unsupported or the image is not
            taller than CROP_TOP + CROP_BOTTOM rows.
    """
    image = np.asarray(np_image)
    if image.ndim == 2:
        # Grayscale: replicate into 3 identical channels so rgb2gray accepts it
        # (its weights sum to 1, so the gray values are unchanged).
        image = np.stack([image, image, image], axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        image = image[..., :3]
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(f"Unsupported image shape {image.shape}; expected an RGB image.")

    height = image.shape[0]
    if height <= CROP_TOP + CROP_BOTTOM:
        raise ValueError(
            f"Image is too small: height {height} px must exceed {CROP_TOP + CROP_BOTTOM} px."
        )

    cropped = image[CROP_TOP:height - CROP_BOTTOM]
    gray = rgb2gray(cropped)
    # Explicit square footprint; with MEDIAN_KERNEL_SIZE == 3 this matches
    # skimage's default 3x3 neighbourhood.
    footprint = np.ones((MEDIAN_KERNEL_SIZE, MEDIAN_KERNEL_SIZE), dtype=bool)
    filtered = median(gray, footprint=footprint, behavior='ndimage')
    thresh = threshold_otsu(filtered)
    binary = filtered < thresh
    binary = opening(binary, square(OPENING_SIZE))
    binary = dilation(binary, square(DILATION_SIZE))
    return binary

# Resize + normalisation applied to every mask before it reaches the model.
# The mean/std values are the standard ImageNet statistics (presumably the same
# preprocessing used when the checkpoint was trained — TODO confirm).
# Built once here instead of on every prediction.
_ECG_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])


def binary_to_tensor(binary_mask):
    """Render a binary mask as a white-on-black RGB image and turn it into a model input.

    Args:
        binary_mask: 2-D array; truthy entries become white (255), others black.
            Non-bool masks (e.g. uint8 0/1) are coerced to bool so they are not
            misinterpreted as integer indices.

    Returns:
        Float tensor of shape (1, 3, 224, 224).
    """
    mask = np.asarray(binary_mask, dtype=bool)
    channel = np.where(mask, 255, 0).astype(np.uint8)
    rgb = np.stack([channel, channel, channel], axis=-1)
    # uint8 H x W x 3 -> PIL mode "RGB" directly, no convert() needed.
    return _ECG_TRANSFORM(Image.fromarray(rgb)).unsqueeze(0)

# Prediction with improved validation
def _ecg_mask_problem(binary_mask):
    """Return a user-facing error string if ``binary_mask`` does not look like an ECG, else None.

    Heuristics, checked in order:
      1. fraction of True pixels lies within [MIN_WHITE_RATIO, MAX_WHITE_RATIO]
         (an empty mask gives NaN and also fails this check);
      2. at least one connected component has major-axis length >= MIN_LINE_LENGTH
         and area >= MIN_REGION_AREA;
      3. the bounding box of the component with the longest major axis is at
         least MIN_ASPECT_RATIO times wider than it is tall.
    """
    white_ratio = binary_mask.sum() / binary_mask.size
    if not (MIN_WHITE_RATIO <= white_ratio <= MAX_WHITE_RATIO):
        return "❌ Error: Image doesn't appear to contain a valid ECG signal (abnormal white pixel ratio)."

    long_regions = [
        region for region in regionprops(label(binary_mask))
        if region.major_axis_length >= MIN_LINE_LENGTH and region.area >= MIN_REGION_AREA
    ]
    if not long_regions:
        return "❌ Error: No valid ECG waveform detected!"

    longest = max(long_regions, key=lambda region: region.major_axis_length)
    min_row, min_col, max_row, max_col = longest.bbox
    if (max_col - min_col) / (max_row - min_row) < MIN_ASPECT_RATIO:
        return "❌ Error: No valid ECG waveform detected!"
    return None


def predict_ecg(image_np):
    """Gradio callback: validate an uploaded ECG image and classify it.

    Args:
        image_np: image array from ``gr.Image(type="numpy")``, or None when the
            user clicks Classify without uploading anything.

    Returns:
        A user-facing string: "✅ Prediction: <class>" on success, otherwise an
        "❌ Error: ..." message. Any ``Exception`` raised while processing is
        caught and reported in the returned string.
    """
    if image_np is None:
        return "❌ Error: Please upload an ECG image first."
    try:
        binary_mask = preprocess_image(image_np)

        problem = _ecg_mask_problem(binary_mask)
        if problem is not None:
            return problem

        input_tensor = binary_to_tensor(binary_mask).to(DEVICE)
        with torch.no_grad():
            predicted_class = model(input_tensor).argmax(1).item()
        return f"✅ Prediction: {class_names[predicted_class]}"
    except Exception as e:
        return f"❌ Error: Failed to process image. Please upload a valid ECG image.\nDetails: {str(e)}"

# UI
import os

# Decorative banner shown beside the description; optional.
img2_path = "/content/drive/MyDrive/docteam.jpg"

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🩺 HeartGuard: ECG Image Classifier")
    gr.Markdown("Is your heart healthy? Upload your ECG image to know!")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Upload ECG Image", type="numpy")
            submit_btn = gr.Button("Classify")
            result = gr.Label(label="Result")
        with gr.Column():
            gr.Markdown("""### 💡 How It Works:
- HeartGuard preprocesses ECG images and uses a deep learning model to classify them
- Classifies images into:
  - **Abnormal Heartbeat (HB)** 🚨
  - **History of Myocardial Infarction (PMI)** 🫣
  - **Myocardial Infarction (MI)** 💔
  - **Normal** 👍
- Powered by a fine-tuned VGG16 deep learning model 🧠📊

⚠️ *For educational purposes only. This is not a medical diagnosis; consult a qualified clinician.*""")
            # Skip the banner rather than failing to build the whole app when
            # the file is missing from Drive.
            if os.path.exists(img2_path):
                gr.Image(value=img2_path, type="filepath", show_label=False)

    submit_btn.click(fn=predict_ecg, inputs=input_img, outputs=result)

# share=True publishes a public *.gradio.live URL: anyone with the link can
# upload images to this app while the notebook is running.
demo.launch(share=True)

Mounted at /content/drive


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 149MB/s]


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://96f8ecb6ea0f0f558d.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


