### 🧠 Interactive Prediction Explorer with Gradio

In this notebook, you can **explore how your trained image classifier behaves** on real validation images — interactively.

Use the dropdown to select a class (e.g. apple, kiwi, tomato), then click **Show predictions**. The app will then:

* Randomly sample up to 5 validation images from that class (fewer if the class has fewer images),
* Run predictions using your trained MobileNetV3 model,
* Display the images along with predicted labels.

This is a lightweight visual tool to **understand model behavior** and **spot patterns or mistakes** — especially useful when debugging, presenting results, or doing ad-hoc testing.

In [None]:
import gradio as gr
from data_utils import get_classes, GroceryDataset, transform, VAL_CSV
from model_utils import load_model, DEVICE
import torch
import random
import matplotlib.pyplot as plt

# Load class labels and the trained model, and put the model in evaluation mode.
classes = get_classes()
model = load_model(num_classes=len(classes))
model.eval()

# Load the validation split and bucket its (already transformed) images by class name.
# NOTE(review): this keeps every transformed validation image in memory at once.
dataset = GroceryDataset(csv_file=VAL_CSV, transform=transform)
images_by_class = {}  # class name -> list of image tensors, in dataset order
for image, label in dataset:
    class_name = classes[label]
    # setdefault creates the bucket the first time a class is seen.
    images_by_class.setdefault(class_name, []).append(image)

# Helper to get N images of a certain class
def get_random_images(class_name, n=5):
    """Return up to ``n`` distinct images of ``class_name``, sampled without replacement.

    Args:
        class_name: Key into the module-level ``images_by_class`` mapping.
        n: Maximum number of images to return. When the class holds fewer
            than ``n`` images, all of them are returned (in random order).

    Raises:
        KeyError: If ``class_name`` has no entry in ``images_by_class``.
    """
    pool = images_by_class[class_name]
    sample_size = min(len(pool), n)
    return random.sample(pool, sample_size)

# Inference function
def predict(class_name):
    """Sample validation images of ``class_name``, classify them, and plot the results.

    Args:
        class_name: The label selected in the dropdown (one of ``classes``'
            values). It may be empty/``None`` if nothing was selected yet.

    Returns:
        A matplotlib Figure with one subplot per sampled image (up to 5), each
        titled with the model's predicted label.

    Raises:
        gr.Error: If no class is selected or the class has no validation
            images, so Gradio shows a readable message instead of a traceback.
    """
    if not class_name:
        raise gr.Error("Please select a class first.")
    if not images_by_class.get(class_name):
        raise gr.Error(f"No validation images found for class '{class_name}'.")

    images = get_random_images(class_name, n=5)

    # Pure inference: no_grad avoids building an autograd graph on every click.
    with torch.no_grad():
        outputs = model(torch.stack(images).to(DEVICE))
    preds = outputs.argmax(dim=1).cpu().tolist()

    # squeeze=False always returns a 2-D array of Axes. Without it, a single
    # sampled image makes plt.subplots return a bare Axes, which is not iterable.
    fig, axs = plt.subplots(1, len(images), figsize=(15, 3), squeeze=False)
    for ax, image, pred in zip(axs[0], images, preds):
        img = image.permute(1, 2, 0).numpy()  # channels-first -> channels-last for imshow
        # Undo per-channel normalization for display.
        # NOTE(review): assumes `transform` normalizes with these ImageNet
        # mean/std values — confirm against data_utils.
        img = img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
        ax.imshow(img.clip(0, 1))
        ax.axis("off")
        ax.set_title(f"Pred: {classes[pred]}")
    return fig

# Launch Gradio inside notebook
# Components are laid out in the order they are created inside the Blocks
# context: heading, class dropdown, plot area, then the button.
with gr.Blocks() as demo:
    gr.Markdown("### 🛒 Grocery Classifier Explorer")
    # Dropdown choices are the values of `classes` (treated as a mapping here;
    # presumably index -> class name — TODO confirm against get_classes).
    class_dropdown = gr.Dropdown(choices=list(classes.values()), label="Select a class")
    # Displays the matplotlib figure returned by `predict`.
    output_plot = gr.Plot()
    run_btn = gr.Button("Show predictions")
    # On click, call `predict` with the dropdown's value and render its figure.
    run_btn.click(fn=predict, inputs=class_dropdown, outputs=output_plot)

# inline=True asks Gradio to render the app inside the notebook output.
demo.launch(inline=True)
