# 3. Zero-Shot Classification with CLIP

Now that we understand how CLIP creates meaningful embeddings, we can use them for **zero-shot classification**: classifying images into categories without any task-specific training. Because CLIP maps images and text into the same embedding space, we can classify an image by embedding a text description of each candidate class and choosing the class whose description is most similar to the image embedding.
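To make the matching step concrete, here is a minimal NumPy sketch. It assumes we already have one embedding for the image and one per class description. The toy 3-dimensional vectors and the `zero_shot_predict` helper are hypothetical stand-ins, not CLIP outputs or a FiftyOne API (real ViT-B/32 embeddings are 512-dimensional).

```python
import numpy as np

def zero_shot_predict(image_embedding, text_embeddings, class_names):
    """Return the class whose text embedding is most similar to the image embedding."""
    # L2-normalize so that a dot product equals cosine similarity
    image = image_embedding / np.linalg.norm(image_embedding)
    texts = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
    similarities = texts @ image  # one cosine similarity per class
    best = int(np.argmax(similarities))
    return class_names[best], similarities

# Toy "embeddings" standing in for real CLIP vectors (hypothetical values)
class_names = ["0", "1", "2"]
text_embeddings = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
image_embedding = np.array([0.2, 0.9, 0.1])

label, sims = zero_shot_predict(image_embedding, text_embeddings, class_names)
print(label)  # → 1
```

No class-specific training is involved: adding a new class only requires embedding one more text description.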

**Key concepts covered:**
*   Zero-shot classification principles
*   Text prompts for classification
*   Applying a model to a FiftyOne dataset
*   Evaluating classification results, including accuracy and the confusion matrix

## Setup

Let's import our libraries and load the test dataset, which should now contain the CLIP embeddings from the previous step.

In [None]:
import os
import numpy as np
import torch
import gc
import fiftyone as fo
import fiftyone.zoo as foz
from fiftyone import ViewField as F

# Load the test dataset created in the previous notebooks; stop early if it is missing
if "mnist-test-set" in fo.list_datasets():
    test_dataset = fo.load_dataset("mnist-test-set")
else:
    raise ValueError(
        "Test dataset not found. Please run '1_explore_mnist.ipynb' and "
        "'2_clip_embeddings.ipynb' first."
    )

session = fo.launch_app(test_dataset, auto=False)
print(session.url)

## Performing Zero-Shot Classification

First, we get the distinct class labels from our dataset. For MNIST, these are the digits 0-9.

In [None]:
dataset_classes = sorted(test_dataset.distinct("ground_truth.label"))
dataset_classes

### The Effect of the Text Prompt

With CLIP, the `text_prompt` can significantly affect accuracy. The prompt is prepended to each class label to form the text that CLIP embeds, giving the bare label some context. We'll use a simple prompt, but feel free to experiment with others, such as "A grayscale image of the number" or "An MNIST digit", to see how performance changes.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

clip_model = foz.load_zoo_model(
    "clip-vit-base32-torch",
    text_prompt="A photo of ",
    classes=dataset_classes,
    device=device
)
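The sketch below shows how a prompt prefix and the class labels might combine into the text queries that CLIP embeds. It assumes the prefix and each label are joined with a single space, which is our assumption about the zoo model's internals. `build_prompts` is a hypothetical helper and the labels are placeholders; substitute `dataset_classes` to inspect the queries for your dataset.

```python
def build_prompts(prefix, classes):
    """Join a prompt prefix with each class label (assumed single-space join)."""
    # strip() avoids a double space when the prefix already ends with one
    return [f"{prefix.strip()} {label}" for label in classes]

classes = ["0", "1", "2"]  # placeholder labels for illustration
for prefix in ["A photo of ", "A grayscale image of the number", "An MNIST digit"]:
    print(build_prompts(prefix, classes))
```

Each prefix produces a different set of text embeddings, which is why changing only the prompt can shift accuracy.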

Now we apply the model to our dataset. The `apply_model()` method will iterate through the samples, generate a prediction for each one, and store it in a new field. We set `store_logits=True` to save the model's raw output scores, which are useful for later analysis.

In [None]:
test_dataset.apply_model(
    model=clip_model,
    label_field="clip_zero_shot_classification",
    store_logits=True,
    batch_size=256,
    num_workers=os.cpu_count()
)

session.refresh()
print("CLIP predictions added to the dataset.")
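Because we passed `store_logits=True`, each prediction also keeps its per-class scores. The sketch below shows how such scores can be turned into a predicted label and a confidence with a softmax. It assumes CLIP-style logits, that is, cosine similarities multiplied by a learned scale (100 for the released CLIP models); the similarity values are made up for illustration.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

# Hypothetical cosine similarities between one image and three class prompts
cosine_similarities = np.array([0.21, 0.25, 0.22])
logit_scale = 100.0  # assumed CLIP temperature scaling
logits = logit_scale * cosine_similarities

probs = softmax(logits)
predicted = int(np.argmax(probs))
print(predicted, probs.round(3))
```

Note how a gap of only 0.03 in cosine similarity becomes a probability above 0.9 once the logits are scaled.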

## Evaluating CLIP's Performance

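Now that we have predictions, we can evaluate them against the ground truth labels. FiftyOne's `evaluate_classifications()` method returns a results object that can print a classification report and plot a confusion matrix. Passing `eval_key` also records on each sample whether its prediction was correct.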

You can also view these results interactively in the FiftyOne App's [Model Evaluation Panel](https://docs.voxel51.com/user_guide/app.html#model-evaluation-panel-sub-new).

In [None]:
clip_evaluation_results = test_dataset.evaluate_classifications(
    "clip_zero_shot_classification",
    gt_field="ground_truth",
    eval_key="clip_zero_shot_eval")

session.refresh()

Let's print a classification report, which lists per-class precision, recall, F1-score, and support alongside the overall accuracy. In our run, the overall accuracy was around 88%.

In [None]:
clip_evaluation_results.print_report(classes=dataset_classes, digits=3)
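Overall accuracy is the fraction of samples whose predicted label matches the ground truth, while per-class precision and recall look at one label at a time. Here is a minimal NumPy sketch of those definitions; the helper names are hypothetical and the labels are toy values, not results from this notebook.

```python
import numpy as np

def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label matches the ground truth."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))

def precision_recall(y_true, y_pred, label):
    """Precision and recall for a single class label."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    true_pos = np.sum((y_pred == label) & (y_true == label))
    predicted_pos = np.sum(y_pred == label)  # everything predicted as `label`
    actual_pos = np.sum(y_true == label)     # everything that truly is `label`
    precision = true_pos / predicted_pos if predicted_pos else 0.0
    recall = true_pos / actual_pos if actual_pos else 0.0
    return float(precision), float(recall)

# Toy labels, not results from this notebook
y_true = ["4", "4", "9", "9", "3", "8"]
y_pred = ["4", "9", "9", "9", "3", "3"]
print(accuracy(y_true, y_pred))
print(precision_recall(y_true, y_pred, "9"))
```

In this toy example every true 9 is found (recall 1.0), but one 4 is also predicted as 9, which lowers the precision for class 9.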

The confusion matrix shows which classes are most often confused. A perfect model would have values only on the main diagonal. Here, we see significant confusion between certain digits (like 4s and 9s, or 3s and 8s).

In [None]:
clip_evaluation_results.plot_confusion_matrix()
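For intuition about what the plot encodes, here is a minimal sketch that builds a confusion matrix by hand. It follows the common convention (used by scikit-learn) of rows for ground-truth classes and columns for predictions. The `confusion_matrix` helper is hypothetical and the labels are toy values, not this notebook's results.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, classes):
    """Count predictions: rows are ground-truth classes, columns are predicted classes."""
    index = {label: i for i, label in enumerate(classes)}
    matrix = np.zeros((len(classes), len(classes)), dtype=int)
    for truth, pred in zip(y_true, y_pred):
        matrix[index[truth], index[pred]] += 1
    return matrix

classes = ["3", "4", "8", "9"]
y_true = ["4", "4", "9", "9", "3", "8"]
y_pred = ["4", "9", "9", "9", "3", "3"]
cm = confusion_matrix(y_true, y_pred, classes)
print(cm)
```

Counts on the diagonal are correct predictions, so `np.trace(cm) / cm.sum()` recovers the overall accuracy; off-diagonal cells such as `cm[1, 3]` (a true 4 predicted as 9) show exactly which digits get confused.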

### Cleaning Up

Before moving on, let's clear the CLIP model from memory to free up GPU resources for the next steps.

In [None]:
del clip_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("CLIP model cleared from memory.")

## Next Steps

We have successfully performed zero-shot classification and evaluated the results. This gives us a training-free baseline for comparison.

Next, we will build and train a traditional supervised learning model, LeNet-5, to see how a task-specific model compares.

Proceed to `4_lenet_training.ipynb`.