### 🧠 Interactive Prediction Explorer with Gradio

In this notebook, you’ll create a simple **interactive tool** that helps you visually explore how well your trained image classifier performs.

The idea is straightforward but powerful:

- **Pick a class label** (e.g. apple, kiwi, tomato) from a dropdown menu.
- The tool randomly selects **up to 5 validation images** from that class (fewer if the class has fewer than 5).
- It runs them through your **trained MobileNet model**.
- It displays the images side by side, along with the model’s **predicted class** for each.

This interface is built using [Gradio](https://gradio.app/) — a Python library that lets you build UI components quickly and easily, directly from your code. It’s great for **debugging**, **demonstrations**, or just getting a better feel for how your model behaves.

✅ Make sure Gradio is installed: `pip install gradio`


In [None]:
import gradio as gr
from data_utils import get_classes, GroceryDataset, transform, VAL_CSV
from model_utils import load_model, DEVICE
import torch
import random
import matplotlib.pyplot as plt

# Class labels, plus the model sized to that many classes and put in eval mode
classes = get_classes()
model = load_model(num_classes=len(classes))
model.eval()

# Validation split, read from VAL_CSV with `transform` applied
dataset = GroceryDataset(csv_file=VAL_CSV, transform=transform)

# ✅ Bucket every validation image under its class name
images_by_class = {}  # class name => list of images belonging to that class

for image, label in dataset:
    # setdefault creates the bucket on first sight of a class, then we append
    images_by_class.setdefault(classes[label], []).append(image)

# ✅ This function is complete — no changes needed
def get_random_images(class_name, n=5):
    """Draw up to ``n`` distinct images at random from one class.

    Looks up ``class_name`` in the module-level ``images_by_class`` mapping
    (a missing key raises ``KeyError``) and samples without replacement.
    When the class holds fewer than ``n`` images, all of them are returned,
    in random order.
    """
    pool = images_by_class[class_name]
    sample_size = min(n, len(pool))
    return random.sample(pool, sample_size)


def predict(class_name):
    """Plot up to five random validation images of a class with predictions.

    Args:
        class_name: Class name chosen in the UI. May be ``None`` when the
            button is clicked before anything has been selected.

    Returns:
        A matplotlib ``Figure`` with one panel per sampled image, each titled
        ``"Pred: <class>"`` using the model's top-scoring class. If the class
        is unknown or has no images, the figure carries a short notice
        instead of image panels.
    """
    # Guard the lookup: get_random_images raises KeyError for a key that is
    # not in images_by_class (which includes None from an empty selection).
    if class_name in images_by_class:
        images = get_random_images(class_name, n=5)
    else:
        images = []

    if not images:
        # torch.stack([]) would raise, so report the situation in the plot.
        fig, ax = plt.subplots(figsize=(15, 3))
        ax.axis("off")
        ax.text(
            0.5,
            0.5,
            "No validation images available for the selected class.",
            ha="center",
            va="center",
        )
        plt.close(fig)
        return fig

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(torch.stack(images).to(DEVICE))
    _, preds = torch.max(outputs, 1)
    preds = preds.cpu().tolist()  # plain Python ints for indexing `classes`

    # squeeze=False keeps a 2-D axes array even when only one image was
    # sampled; the default would hand back a bare Axes that cannot be iterated.
    fig, axs = plt.subplots(1, len(images), figsize=(15, 3), squeeze=False)
    for ax, image, pred in zip(axs[0], images, preds):
        # Move the first axis last for imshow
        # (assumes channel-first image tensors — TODO confirm against GroceryDataset).
        img = image.permute(1, 2, 0).numpy()
        # unnormalize (assumes `transform` normalized with this std/mean — TODO confirm)
        img = img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
        img = img.clip(0, 1)
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"Pred: {classes[pred]}")

    # Unregister the figure from pyplot so repeated clicks do not pile up open
    # figures; the Figure object itself is still returned to the caller.
    plt.close(fig)
    return fig


# ✅ Launch the Gradio UI
# Components are created inside the Blocks context in the order written below:
# heading, class dropdown, plot output, then the button that triggers predict().
with gr.Blocks() as demo:
    gr.Markdown("### 🛒 Grocery Classifier Explorer")
    # NOTE(review): `.values()` requires `classes` to be a dict-like mapping — confirm against get_classes()
    class_dropdown = gr.Dropdown(choices=list(classes.values()), label="Select a class")
    output_plot = gr.Plot()
    run_btn = gr.Button("Show predictions")
    # On click, the dropdown's current value is passed to predict() and its return value is sent to the plot.
    run_btn.click(fn=predict, inputs=class_dropdown, outputs=output_plot)

# NOTE: "10.26.26.x" is a placeholder, not a valid address — replace it with the
# IP of the interface the app should listen on before running this cell.
demo.launch(server_name="10.26.26.x")  # TODO set your IP address