In [None]:
import torch
from torchvision import models
import gradio as gr
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn

# Load the fine-tuned ResNet-18 classifier for CPU inference.
MODEL_PATH = "resnet18_1.pth"

# `weights=None` builds the bare architecture without downloading ImageNet
# weights (it replaces the deprecated `pretrained=False`); the trained weights
# come from the local checkpoint loaded below.
resnet18 = models.resnet18(weights=None)
# Swap the default 1000-way head for a 2-way head so the layer shapes match the
# checkpoint (predict() maps index 1 to "Malignant" and index 0 to "Benign").
resnet18.fc = nn.Linear(resnet18.fc.in_features, 2)
# `weights_only=True` restricts unpickling to tensors and plain containers,
# which is all a state_dict needs. This avoids executing arbitrary pickled code
# and the FutureWarning newer torch versions emit when `weights_only` is omitted.
resnet18.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=True)
)
resnet18.eval()  # inference mode: batch-norm layers use their running statistics

# Preprocessing pipeline applied to every uploaded image before inference.
# NOTE(review): these mean/std values are the ImageNet statistics torchvision
# uses for its pretrained models; presumably the checkpoint was trained with
# the same normalisation — confirm against the training code.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # Fixed 224x224 input size; the aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel standardisation: (x - mean) / std.
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

def predict(image):
    """Classify an uploaded scan as "Benign" or "Malignant".

    Args:
        image: The value delivered by the ``gr.Image(type="numpy")`` input: a
            NumPy image array, or ``None`` when "Classify" is clicked without
            an uploaded image.

    Returns:
        The predicted label, ``"Malignant"`` (class index 1) or ``"Benign"``
        (any other index), or a prompt asking for an upload when ``image``
        is ``None``.
    """
    # An empty image component yields None; Image.fromarray(None) would raise,
    # so answer with a readable message in the output box instead.
    if image is None:
        return "Please upload an image."

    # Force 3 channels: `transform` normalises with 3 per-channel means, so a
    # grayscale (H, W) or RGBA (H, W, 4) array would otherwise fail in Normalize.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = resnet18(batch)

    # Highest-scoring class over the class dimension (raw logits, not probabilities).
    predicted = logits.argmax(dim=1).item()
    return "Malignant" if predicted == 1 else "Benign"

# Custom CSS passed to gr.Blocks(css=...) below. It provides:
# - an animated blue gradient page background;
# - centred h1/h2 headings, matching the HTML emitted by the gr.Markdown calls;
# - a translucent, rounded main container;
# - blue buttons.
# NOTE(review): `.gr-button`, `.gr-input` and `.gr-output` are Gradio-internal
# class names that have changed between Gradio versions. Confirm they still
# match elements in the installed version; otherwise those rules have no effect.
css = """
body {
    background: linear-gradient(-45deg, #e3f2fd, #bbdefb, #90caf9, #64b5f6);
    background-size: 400% 400%;
    animation: gradientBG 15s ease infinite;
    font-family: 'Arial', sans-serif;
    margin: 0;
    padding: 0;
}

@keyframes gradientBG {
    0% {background-position: 0% 50%;}
    50% {background-position: 100% 50%;}
    100% {background-position: 0% 50%;}
}

h1 {
    color: #0d47a1;
    text-align: center;
    font-size: 3em;
    margin-top: 20px;
}

h2 {
    color: #1565c0;
    text-align: center;
    font-size: 1.5em;
    margin-bottom: 50px;
}

.gradio-container {
    max-width: 800px;
    margin: auto;
    padding: 20px;
    background-color: rgba(255, 255, 255, 0.8);
    border-radius: 10px;
}

.gr-button {
    background-color: #1976d2;
    color: white;
    border: none;
    font-size: 1em;
    padding: 10px 20px;
    margin-top: 20px;
}

.gr-button:hover {
    background-color: #1565c0;
}

.gr-input, .gr-output {
    margin-top: 20px;
}

"""

# Build the web UI: the image upload and "Classify" button in one column, the
# predicted label in a second column beside it.
with gr.Blocks(css=css) as demo:
    gr.Markdown("<h1>TumorTrace: MRI-Based AI for Breast Cancer Detection</h1>")
    gr.Markdown("<h2>Upload an MRI scan to classify it as Benign or Malignant</h2>")
    with gr.Row():
        with gr.Column():
            # type="numpy" makes Gradio pass predict() a NumPy array.
            image_input = gr.Image(type="numpy", label="Input Image")
            classify_button = gr.Button("Classify")
        with gr.Column():
            output_label = gr.Textbox(label="Classification Result")
    # Run predict() on the current image each time the button is clicked.
    classify_button.click(fn=predict, inputs=image_input, outputs=output_label)

# Start the server only when this code runs as the main program (a script or a
# notebook kernel). Importing the module, e.g. to reuse predict(), should not
# start a web server; as a script, launch() also blocks.
if __name__ == "__main__":
    demo.launch()

  resnet18.load_state_dict(torch.load(r"c:\Users\shrey\Desktop\resnet18_1.pth", map_location=torch.device('cpu')))


* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


