In [1]:
pip install gradio torch torchvision pillow kagglehub

Collecting gradio
  Downloading gradio-5.7.0-py3-none-any.whl.metadata (16 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.5-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.4.0-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.5.0 (from gradio)
  Downloading gradio_client-1.5.0-py3-none-any.whl.metadata (7.1 kB)
Collecting python-multipart==0.0.12 (from gradio)
  Downloading python_multipart-0.0.12-py3-none-any.whl.metadata (1.9 kB)
Collecting ruff>=0.2.2 (from gradio)
  Downloading ruff-0.8.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<1.0,>=0.1.1 (from gradio)
  Downloading safehttpx-0.1.1-py3-none-any.whl.metadata (4.1 kB)
Collecting semantic-version~=2.0 (from gradio)
  Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)
Collecting starlette<1.0,>=0.40.0 (from gradio)
  Downloading starlette-0.41.3-py3-none-any.whl.metadata (6.0

In [2]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import gradio as gr

## Loading Models

In [3]:
def load_model_from_path(model_name, file_path=None, device=None, num_classes=2, strict=False):
    """Build a torchvision classifier with a replaced output head and optionally load weights.

    Args:
        model_name: One of ``"VGG16"``, ``"ResNet18"`` or ``"ResNet50"``.
        file_path: Optional path to a saved ``state_dict``. When falsy, the model keeps
            its random initialisation (it is built with ``weights=None``).
        device: Device to map the checkpoint onto and move the model to. Defaults to
            CUDA when available, otherwise CPU (the same rule the notebook's global
            ``device`` uses), so this function no longer depends on a later cell.
        num_classes: Output size of the replaced final Linear layer (2 = benign/malignant
            in this notebook).
        strict: Forwarded to ``load_state_dict``. With the default ``False`` mismatched
            keys are tolerated, but they are now reported with a ``RuntimeWarning``
            instead of being silently ignored.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_name`` is not supported.
        RuntimeError: If reading or applying the checkpoint fails.
    """
    import warnings

    # Constructors are accessed through the imported `models` module; the bare names
    # vgg16/resnet18/resnet50 were never imported in this notebook.
    if model_name == "VGG16":
        model = models.vgg16(weights=None)
        # Replace the final Linear layer of the VGG classifier head.
        model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)
    elif model_name == "ResNet18":
        model = models.resnet18(weights=None)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    elif model_name == "ResNet50":
        model = models.resnet50(weights=None)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if file_path:
        try:
            # weights_only=True restricts unpickling to tensors/containers, which is all a
            # state_dict needs, and refuses arbitrary pickled objects.
            state_dict = torch.load(file_path, map_location=device, weights_only=True)
            incompatible = model.load_state_dict(state_dict, strict=strict)
        except Exception as e:
            raise RuntimeError(f"Error loading model weights from {file_path}: {e}") from e
        # With strict=False a wrong or wrapped checkpoint "loads" without error but leaves
        # layers randomly initialised; surface that instead of predicting with noise.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Checkpoint {file_path} does not fully match {model_name}: "
                f"{len(incompatible.missing_keys)} missing keys, "
                f"{len(incompatible.unexpected_keys)} unexpected keys. "
                "Layers with missing keys keep their random initialisation.",
                RuntimeWarning,
                stacklevel=2,
            )

    model = model.to(device)
    model.eval()
    return model


## Tumor classification

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inference-time preprocessing: must be deterministic so the same upload always gets the
# same prediction. Training-style random augmentations (RandomAffine, ColorJitter,
# RandomResizedCrop, random flips) do not belong here. RandomResizedCrop's default scale
# range can keep as little as 8% of the image, and ColorJitter(brightness=2) can scale
# brightness anywhere in [0, 3].
transform = transforms.Compose([
    transforms.Resize(224),       # shorter side -> 224, aspect ratio preserved
    transforms.CenterCrop(224),   # fixed 224x224 input
    transforms.ToTensor(),
    # NOTE(review): assumes the checkpoints were trained with this same (0.5, 0.5, 0.5)
    # normalisation — confirm against the training notebook.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Checkpoint locations inside the attached Kaggle dataset, keyed by the UI model name.
MODEL_PATHS = {
    "VGG16": "/kaggle/input/tumortrace-models/VGG16.pth",
    "ResNet18": "/kaggle/input/tumortrace-models/Resnet18.pth",
    "ResNet50": "/kaggle/input/tumortrace-models/Resnet50.pth",
}

# Loaded models keyed by model name, so each checkpoint is read from disk once per kernel
# instead of on every Gradio request. Models stay resident on `device` once loaded.
_MODEL_CACHE = {}


def predict(image, model_name):
    """Classify an uploaded image as Benign or Malignant with the chosen model.

    Args:
        image: Path to the uploaded image (the Gradio input uses ``type="filepath"``),
            or ``None`` when nothing was uploaded.
        model_name: Key into ``MODEL_PATHS``.

    Returns:
        A 4-tuple of strings: ``(prediction, benign %, malignant %, confidence %)``.
        The confidence is the margin ``|benign % - malignant %|``, not the max
        probability. On missing input or an unknown model, the first element is an
        error message and the rest are empty strings.
    """
    if image is None:
        return "No image provided", "", "", ""

    model_file = MODEL_PATHS.get(model_name)
    if not model_file:
        return f"Model {model_name} not found", "", "", ""

    model = _MODEL_CACHE.get(model_name)
    if model is None:
        model = load_model_from_path(model_name, model_file)
        _MODEL_CACHE[model_name] = model

    # The context manager closes the file handle; convert() returns an independent,
    # fully loaded copy, so it stays usable after the file is closed.
    with Image.open(image) as pil_image:
        rgb_image = pil_image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # NOTE(review): assumes output index 0 = benign and 1 = malignant (e.g. alphabetical
    # ImageFolder order at training time) — confirm against the training code.
    benign_score = float(probabilities[0]) * 100
    malignant_score = float(probabilities[1]) * 100
    pred_class = "Malignant" if malignant_score > benign_score else "Benign"
    confidence = abs(benign_score - malignant_score)

    return pred_class, f"{benign_score:.2f}", f"{malignant_score:.2f}", f"{confidence:.2f}"


## Gradio interface

In [10]:
title = "TumorTrace : MRI-Based AI for Breast Cancer Detection"
description = "Upload an image of a tumor and select a model to classify it as Benign or Malignant. "

# Inputs are passed positionally to predict(image, model_name).
input_components = [
    gr.Image(type="filepath", label="Upload Image"),
    gr.Radio(choices=["VGG16", "ResNet18", "ResNet50"], label="Choose Model"),
]

# One textbox per element of the 4-tuple predict() returns, in the same order.
output_components = [
    gr.Textbox(label=label)
    for label in (
        "Prediction",
        "Benign Probability (%)",
        "Malignant Probability (%)",
        "Confidence (%)",
    )
]

interface = gr.Interface(
    fn=predict,
    inputs=input_components,
    outputs=output_components,
    title=title,
    description=description,
)

interface.launch(share=True)


* Running on local URL:  http://127.0.0.1:7862

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


2024/11/28 11:40:46 [W] [service.go:132] login to server failed: tls: failed to verify certificate: x509: certificate has expired or is not yet valid: current time 2024-11-28T11:40:46Z is after 2024-11-28T06:24:31Z


