In [11]:
from pathlib import Path
import torch
import gradio as gr
from torch import nn

In [13]:
# Class names, one per line; line i is the name for model output index i.
# Read explicitly as UTF-8 so the result does not depend on the OS locale
# (on Windows the default encoding is a legacy code page, not UTF-8).
LABELS = Path("class_names.txt").read_text(encoding="utf-8").splitlines()

In [14]:
# Small CNN for single-channel 28x28 sketches. Each MaxPool2d(2) halves the
# spatial size with floor division (28 -> 14 -> 7 -> 3), so Flatten produces
# 128 channels * 3 * 3 = 1152 features. The final layer emits one logit per
# class name in LABELS. Layer order is what the checkpoint keys ("0.weight",
# "3.weight", ...) index into, so do not insert or remove layers.
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1),  # -> 32 x 28 x 28
    nn.ReLU(),
    nn.MaxPool2d(2),  # -> 32 x 14 x 14
    nn.Conv2d(32, 64, 3, padding=1),  # -> 64 x 14 x 14
    nn.ReLU(),
    nn.MaxPool2d(2),  # -> 64 x 7 x 7
    nn.Conv2d(64, 128, 3, padding=1),  # -> 128 x 7 x 7
    nn.ReLU(),
    nn.MaxPool2d(2),  # -> 128 x 3 x 3
    nn.Flatten(),  # -> 1152 (= 128 * 3 * 3)
    nn.Linear(128 * 3 * 3, 256),  # -> 256
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),  # -> len(LABELS) logits, one per class name
)

In [15]:
# Load the trained weights onto the CPU and switch the model to inference mode.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
state_dict = torch.load(
    "pytorch_model.bin", map_location=torch.device("cpu"), weights_only=True
)
# strict=False tolerates extra keys in the checkpoint, but a *missing* key would
# leave that layer at its random initialisation and the demo would silently
# return meaningless predictions, so fail loudly in that case.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint is missing weights for: {load_result.missing_keys}"
    )
model.eval()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU()
  (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (9): Flatten(start_dim=1, end_dim=-1)
  (10): Linear(in_features=1152, out_features=256, bias=True)
  (11): ReLU()
  (12): Linear(in_features=256, out_features=100, bias=True)
)

In [16]:
# Define the predict function
def predict(image):
    """Classify a sketchpad drawing and return the most likely class names.

    Args:
        image: The value delivered by the Gradio sketchpad input, or None when
            there is no drawing (e.g. the canvas was cleared while live=True).
            Assumed to be a 2-D grayscale array with values in 0-255, matching
            the model's 1-channel 28x28 input -- TODO confirm for the installed
            Gradio version.

    Returns:
        A dict mapping class name -> softmax probability for the top classes
        (at most 5), or None when there is no image to classify.
    """
    # torch.tensor(None) raises, so an empty canvas must be handled up front.
    if image is None:
        return None
    # Add batch and channel dims (assumes image is (H, W) -> (1, 1, H, W))
    # and scale pixel values to [0, 1].
    x = torch.tensor(image, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
    with torch.no_grad():
        logits = model(x)
    probs = torch.nn.functional.softmax(logits[0], dim=0)
    # topk raises if k exceeds the number of classes.
    k = min(5, probs.numel())
    values, indices = torch.topk(probs, k)
    return {LABELS[i]: v for i, v in zip(indices.tolist(), values.tolist())}

In [19]:
# Define the Gradio interface: a sketchpad input feeds `predict`, and the
# returned {class name: probability} dict is rendered by a label output.
# live=True re-runs `predict` automatically whenever the drawing changes.
interface = gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs="label",
    live=True,
    theme="default",
    # Page text shown around the widget.
    title="Sketch Recognition",
    description="Play pictionary? Draw a common object like a cat, dog, or car and see if the model can guess what it is.",
    article="<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
)

In [20]:
# Launch the interface.
# Close any Gradio servers left running by earlier executions of this cell;
# otherwise every re-run claims a new local port and the old apps keep running.
gr.close_all()
# share=True additionally requests a temporary public gradio.live URL, so anyone
# holding that link can use the demo while it is live.
interface.launch(share=True)

Running on local URL:  http://127.0.0.1:7864
Running on public URL: https://938c7df53dd4cb9ead.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Traceback (most recent call last):
  File "c:\Users\Raj\repos\hf-nlp\.venv\Lib\site-packages\gradio\routes.py", line 437, in run_predict
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Raj\repos\hf-nlp\.venv\Lib\site-packages\gradio\blocks.py", line 1352, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Raj\repos\hf-nlp\.venv\Lib\site-packages\gradio\blocks.py", line 1077, in call_function
    prediction = await anyio.to_thread.run_sync(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Raj\repos\hf-nlp\.venv\Lib\site-packages\anyio\to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Raj\repos\hf-nlp\.venv\Lib\site-packages\anyio\_backends\_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
           ^