<a href="https://colab.research.google.com/github/sayanarajasekhar/GenerativeAiApplications/blob/main/Image_Classification_PyTorch_Gradio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Image Classification

Image classification is a central task in computer vision. Building better classifiers, which identify the object present in a picture, is an active area of research with applications ranging from autonomous vehicles to medical imaging.

### Step 1: Setting up the image classification model

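```python
import torch

# Load a ResNet-18 pretrained on ImageNet from the PyTorch Hub and switch it to evaluation mode
model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()
```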

```
Downloading: "https://github.com/pytorch/vision/zipball/v0.6.0" to /root/.cache/torch/hub/v0.6.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 154MB/s]
```


### Step 2: Defining a predict function

Next, we define a function that takes the user input (in this case, an image) and returns the prediction. The prediction is returned as a dictionary whose keys are class names and whose values are confidence probabilities. The 1,000 ImageNet class names are loaded from a plain-text file with one label per line.
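To make the expected return value concrete, here is a minimal plain-Python sketch of how softmax turns raw model scores (logits) into a dictionary of class probabilities that sum to 1. The helper name `softmax_to_confidences`, the three logits, and the three labels are hypothetical; the real model emits 1,000 logits, and `torch.nn.functional.softmax` performs the same computation on a tensor.

```python
import math

def softmax_to_confidences(logits, labels):
    """Map raw scores (logits) to a {label: probability} dict via softmax."""
    # Subtracting the max logit avoids overflow in exp() and does not change the result
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return {label: e / total for label, e in zip(labels, exps)}

# Hypothetical 3-class example; the real ResNet-18 produces 1,000 logits
conf = softmax_to_confidences([2.0, 1.0, 0.1], ["lion", "cheetah", "tabby cat"])
print({label: round(p, 3) for label, p in conf.items()})
# prints {'lion': 0.659, 'cheetah': 0.242, 'tabby cat': 0.099}
```

The largest logit always receives the largest probability, so the ranking of classes is unchanged; softmax only rescales the scores into a distribution.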



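```python
import requests
import torch
from torchvision import transforms

# Download human-readable labels for ImageNet (one class name per line)
response = requests.get("https://git.io/JJkYN")
response.raise_for_status()
labels = response.text.split("\n")

# Resize/crop to 224x224 and normalize with the ImageNet statistics
# used to train torchvision's pretrained models
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def predict(image):
    """Classify a PIL image and return a {class name: probability} dict."""
    batch = preprocess(image.convert("RGB")).unsqueeze(0)  # add a batch dimension
    with torch.inference_mode():
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)
    return {labels[i]: float(probabilities[i]) for i in range(1000)}
```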
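`transforms.ToTensor()` converts an 8-bit PIL image to a float tensor with values scaled to [0, 1]. torchvision's ImageNet-pretrained models are conventionally fed inputs that are additionally normalized per channel as (x - mean) / std using the ImageNet statistics below, which is what `transforms.Normalize` computes. The plain-Python sketch below shows that arithmetic for a single RGB pixel; `normalize_pixel` is a hypothetical helper for illustration only, not part of torchvision.

```python
# Per-channel ImageNet statistics commonly used with torchvision's pretrained models
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Normalize one RGB pixel (values already in [0, 1]) channel by channel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))

# A mid-gray pixel lands slightly above zero in every channel
gray = normalize_pixel((0.5, 0.5, 0.5))
print(tuple(round(v, 3) for v in gray))
# prints (0.066, 0.196, 0.418)
```

A pixel equal to the dataset mean maps to zero in every channel, which centers the inputs the way the network saw them during training.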

### Step 3: Creating a Gradio interface


Now that we have our predictive function set up, we can create a Gradio Interface around it.

In this case, the input component is a drag-and-drop image component. We create it with `gr.Image(type="pil")`, which builds the component and handles the preprocessing that converts the uploaded image to a PIL image before it is passed to `predict`.

The output component is a `gr.Label`, which displays the top predicted class names along with their confidence scores. Since we don't want to show all 1,000 class labels, we limit it to the top 3 classes by constructing it as `gr.Label(num_top_classes=3)`.
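`gr.Label(num_top_classes=3)` does this selection for you. The sketch below only illustrates the idea in plain Python (it is not Gradio's implementation): pick the three highest-confidence entries from a dictionary shaped like the one `predict` returns. The helper `top_k` and the confidence values are hypothetical.

```python
import heapq

def top_k(confidences, k=3):
    """Return the k (label, confidence) pairs with the highest confidence, best first."""
    return heapq.nlargest(k, confidences.items(), key=lambda item: item[1])

# Hypothetical output of predict(); the real dictionary has 1,000 entries
confidences = {"lion": 0.62, "cheetah": 0.21, "leopard": 0.09, "tabby": 0.05, "tiger": 0.03}
for label, p in top_k(confidences, k=3):
    print(f"{label}: {p:.0%}")
# prints lion: 62%, cheetah: 21%, leopard: 9% (one per line)
```

`heapq.nlargest` avoids sorting all 1,000 entries when only a handful are needed, although for a dictionary this small a full `sorted(...)[:k]` would work just as well.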

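```python
import gradio as gr

# Wrap predict() in a web UI: an image upload goes in, the top-3 labels come out.
# The example images must exist at these paths (/content is Colab's working directory).
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    examples=["/content/lion.jpg", "/content/cheetah.jpg"],
).launch()
```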
