# 🛒 Grocery Store – Training an Image Classifier

In this notebook, we’ll build and train a simple image classification model to recognize grocery items from photos. The dataset contains labeled images of various food products, each assigned to a coarse category like "vegetables", "dairy", or "snacks".

We’ll walk through a full training workflow, including:

- Loading and preprocessing image data
- Defining a PyTorch dataset and data loaders
- Fine-tuning a pretrained MobileNet model
- Evaluating the model's performance on a validation set
- Visualizing sample predictions with images

By the end, you'll have a basic image classifier running locally, ready to be served or deployed in later steps of the workshop.

In [None]:
from data_utils import *
from model_utils import *
from train_utils import *

# Get class names and dataloaders
classes = get_classes()
train_loader, val_loader = get_loaders()

# Create a model and train it
model = create_model(len(classes))
train_model(model, train_loader)

# Evaluate model accuracy
accuracy = evaluate_model(model, val_loader)
print(f"\nValidation Accuracy: {accuracy:.2%}")
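
The body of `evaluate_model` lives in `train_utils` and is not shown in this notebook. As a rough illustration of what a validation accuracy usually measures, the sketch below takes the highest-scoring class per sample and compares it with the true label. The helper name `accuracy_from_logits` and the toy scores are hypothetical, and the real helper may work differently.

```python
import numpy as np

def accuracy_from_logits(logits, labels):
    """Fraction of samples whose highest-scoring class matches the true label."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    predicted = logits.argmax(axis=1)  # index of the largest score in each row
    return float((predicted == labels).mean())

# Toy batch with made-up scores: 4 samples, 3 classes
logits = [[2.0, 0.1, 0.3],
          [0.2, 1.5, 0.1],
          [0.9, 0.8, 0.1],
          [0.1, 0.2, 3.0]]
labels = [0, 1, 1, 2]
print(f"Accuracy: {accuracy_from_logits(logits, labels):.2%}")
```

In this toy batch the third sample is misclassified (class 0 is predicted, class 1 is true), so three of the four predictions are correct.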

**🖼️ Visualizing Predictions**

To better understand the model's behavior, we display **10 sample predictions** from the validation set. For each image, we show the predicted and true class names. Before plotting, we undo the input normalization (multiply by the per-channel standard deviation, then add the per-channel mean) so the images appear in their natural colors. This kind of qualitative inspection is useful for spotting obvious misclassifications or biases in the model.
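
The denormalization step can be sketched on its own with NumPy. The mean and standard deviation below are the ImageNet statistics that the plotting cell uses. The helper name `denormalize` and the random test image are illustrative assumptions, not part of the workshop utilities.

```python
import numpy as np

# Per-channel ImageNet statistics (the same values used in the plotting cell)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def denormalize(chw_image):
    """Turn a normalized (C, H, W) array into a displayable (H, W, C) array in [0, 1]."""
    hwc = np.transpose(chw_image, (1, 2, 0))       # channels last, as imshow expects
    restored = hwc * IMAGENET_STD + IMAGENET_MEAN  # invert (x - mean) / std
    return np.clip(restored, 0.0, 1.0)             # guard against rounding overshoot

# Round trip on a random image: normalize, then denormalize
original = np.random.default_rng(0).random((4, 4, 3))
normalized = ((original - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)
restored = denormalize(normalized)
print(np.allclose(restored, original))
```

If the round trip works, the restored image matches the original up to floating-point error.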

In [None]:
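import matplotlib.pyplot as plt
import torch

# Show predictions with images
print("\nSample predictions with images:")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

# Collect the first 10 validation images with their predictions and labels
model.eval()  # inference mode: use stored batch-norm statistics, disable dropout
shown = 0
for images, labels in val_loader:
    images_cpu = images.cpu()
    with torch.no_grad():  # gradients are not needed for inference
        # DEVICE is assumed to be provided by the wildcard imports above
        outputs = model(images.to(DEVICE))
    _, predicted = torch.max(outputs, 1)
    predicted = predicted.cpu()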

    for i in range(images.size(0)):
        if shown >= 10:
            break
        img = images_cpu[i].permute(1, 2, 0).numpy()
        img = img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]  # undo normalization: x * std + mean (ImageNet statistics)
        img = img.clip(0, 1)

        axes[shown].imshow(img)
        axes[shown].axis("off")
        axes[shown].set_title(f"Predicted: {classes[predicted[i].item()]}\nTrue: {classes[labels[i].item()]}")
        shown += 1

    if shown >= 10:
        break

plt.tight_layout()
plt.show()