In [1]:
import torch
import torchvision.models as models

# Load ResNet-18 with its ImageNet-1k weights. The boolean `pretrained=True`
# flag is deprecated in torchvision in favour of the explicit weights enum;
# IMAGENET1K_V1 is the same checkpoint (resnet18-f37072fd.pth) that
# `pretrained=True` downloaded, and it matches what the classification cell uses.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Save the model locally.
# NOTE(review): torch.save(model, ...) pickles the whole nn.Module, so loading the
# file later requires the same torchvision class importable at the same path.
# Saving `model.state_dict()` is the more portable option. Nothing visible in
# this notebook reads this file back, so confirm whether it is needed at all.
torch.save(model, 'resnet18_model.pth')

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 139MB/s]


In [2]:
!pip install gradio

Collecting gradio
  Downloading gradio-4.39.0-py3-none-any.whl (12.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m52.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Collecting fastapi (from gradio)
  Downloading fastapi-0.111.1-py3-none-any.whl (92 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.2/92.2 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpy (from gradio)
  Downloading ffmpy-0.3.2.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gradio-client==1.1.1 (from gradio)
  Downloading gradio_client-1.1.1-py3-none-any.whl (318 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.2/318.2 kB[0m [31m38.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting httpx>=0.24.1 (from gradio)
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━

In [3]:
import torch
import torchvision
from torchvision.models import resnet18
import gradio as gr
import requests

# Download human-readable ImageNet class labels: a JSON list that predict_image
# indexes by the model's predicted class id.
# - timeout: without it, requests.get can block this cell forever on a stalled
#   connection.
# - raise_for_status(): turns an HTTP error (404, 5xx) into a clear exception
#   instead of a confusing JSON decode error or silently wrong data.
url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
response = requests.get(url, timeout=30)
response.raise_for_status()
imagenet_labels = response.json()

# The IMAGENET1K classifier scores 1000 classes, so every predicted index must map
# to a label; fail here rather than with an IndexError inside the web app.
assert len(imagenet_labels) == 1000, f"expected 1000 labels, got {len(imagenet_labels)}"

# Cached ResNet-18 shared by every request. Previously predict_image rebuilt the
# network and reloaded its weights on each call, which made every prediction slow.
_model = None

# Standard ImageNet preprocessing: resize the short side to 256, center-crop to
# 224x224, convert to a float tensor and normalize with the ImageNet channel
# mean/std. It is built once here instead of on every request.
_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def _get_model():
    """Return the ImageNet-pretrained ResNet-18 in eval mode, loading it on first use."""
    global _model
    if _model is None:
        model = resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
        model.eval()
        # Only cache a fully loaded model, so a failed load is retried on the next call.
        _model = model
    return _model


def predict_image(image):
    """Classify an image with ResNet-18 and describe the top-1 ImageNet class.

    Args:
        image: a PIL image (the Gradio input is ``gr.Image(type="pil")``), or
            ``None`` when no image has been provided.

    Returns:
        ``"The image is a <label>."`` on success. Any failure is reported as a
        string starting with ``"Error: "`` instead of being raised, so the web UI
        shows the message in the output box.
    """
    if image is None:
        return "Error: no image provided."
    try:
        model = _get_model()

        # Normalize expects exactly 3 channels. Grayscale ("L") or RGBA uploads
        # would otherwise fail the mean/std broadcast, so force RGB first.
        image_tensor = _transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model(image_tensor)

        # Index of the highest-scoring class for the single image in the batch.
        predicted_index = output.argmax(dim=1)[0].item()
        predicted_label = imagenet_labels[predicted_index]

        return f"The image is a {predicted_label}."
    except Exception as e:
        return f"Error: {str(e)}"

# Custom CSS injected into the Gradio page through `gr.Interface(css=css)`. It sets
# a dark theme (grey backgrounds, white/light text) and blue buttons and
# file-picker buttons with a darker hover colour.
# NOTE(review): `.gradio-container` is styled by two separate rules below (colours
# and fonts in the first, flex column layout in the second); consider merging them.
# NOTE(review): selectors such as `.gr-button`, `.gr-text-output` and `.container`
# may not match any element in the installed Gradio 4.x DOM. Confirm in the
# browser dev tools that these rules actually apply.
# NOTE(review): the `.container` box-shadow uses 10% black on a #444 background,
# so it is likely barely visible.
css = """
    body, .gradio-container {
        font-family: 'Arial', sans-serif;
        background-color: #333;
        color: white;
    }
    .container {
        max-width: 800px;
        margin: 0 auto;
        padding: 20px;
        background: #444;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        border-radius: 8px;
    }
    h1 {
        color: #eee;
    }
    .gradio-container {
        display: flex;
        flex-direction: column;
        align-items: center;
    }
    .gr-button {
        background-color: #007bff;
        color: white;
        border: none;
        padding: 10px 20px;
        border-radius: 5px;
        cursor: pointer;
        font-size: 16px;
    }
    .gr-button:hover {
        background-color: #0056b3;
    }
    .gr-text-output {
        font-size: 18px;
        color: #eee;
        margin-top: 10px;
    }
    input[type=file]::file-selector-button {
        background-color: #007bff;
        color: white;
        border: none;
        padding: 10px 20px;
        border-radius: 5px;
        cursor: pointer;
    }
    input[type=file]::file-selector-button:hover {
        background-color: #0056b3;
    }
"""

# Wire the classifier into a simple web UI: one PIL image in, one text box out,
# styled with the custom CSS defined above.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Prediction"),
    title="Image Classification",
    description="Upload an image and I'll classify it using a pre-trained ResNet18 model.",
    css=css,
    theme="default",
)

# Start the server. Keeping `demo` bound lets later cells call demo.close().
# In Colab, Gradio turns on share=True automatically (see the cell output), which
# publishes the app at a public *.gradio.live URL. Pass share=False to opt out.
demo.launch()


Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://f1a1aa1f7d1f75ecb7.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


