# 3. Zero-Shot Classification with CLIP

Now that we understand how CLIP creates meaningful embeddings, we can use them for **zero-shot classification**: classifying images into new categories without any task-specific training. Because CLIP maps images and text into the same embedding space, we can classify an image by finding the text description whose embedding is most similar to the image's, an idea we explored briefly [in the previous notebook](https://github.com/andandandand/fiftyone/blob/develop/docs/source/getting_started_experiences/Classification/2_clip_embeddings.ipynb).

**Key concepts covered:**
*   Zero-shot classification principles
*   Text prompts for classification
*   [Applying a model](https://docs.voxel51.com/api/fiftyone.core.dataset.html#fiftyone.core.dataset.Dataset.apply_model) to a FiftyOne [dataset](https://docs.voxel51.com/api/fiftyone.core.dataset.html#fiftyone.core.dataset.Dataset)
*   Evaluating classification results including accuracy and confusion matrix
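
The zero-shot principle described above can be sketched in a few lines of NumPy. This is a minimal illustration, not FiftyOne's or CLIP's implementation: the embeddings are made-up 3-dimensional vectors, and `zero_shot_classify` is a hypothetical helper.

```python
import numpy as np

def zero_shot_classify(image_embedding, text_embeddings, labels):
    """Return the label whose text embedding is most similar to the image embedding."""
    # Normalize so that the dot product equals cosine similarity
    image = image_embedding / np.linalg.norm(image_embedding)
    texts = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
    similarities = texts @ image
    return labels[int(np.argmax(similarities))], similarities

# Toy embeddings (hypothetical values; real CLIP ViT-B/32 embeddings have 512 dimensions)
labels = ["0 - zero", "1 - one", "7 - seven"]
text_embeddings = np.array([
    [1.0, 0.1, 0.0],
    [0.0, 1.0, 0.2],
    [0.1, 0.3, 1.0],
])
image_embedding = np.array([0.2, 0.9, 0.3])

predicted_label, similarities = zero_shot_classify(image_embedding, text_embeddings, labels)
print(predicted_label)
```

No training step appears anywhere in this sketch: changing the set of categories only requires new text embeddings.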

## Setup

Let's import our libraries and load the test dataset, which should now contain the CLIP embeddings from the previous step.

In [1]:
import os
import numpy as np
import torch
import gc
import fiftyone as fo
import fiftyone.zoo as foz
from fiftyone import ViewField as F

# Ensure the test dataset exists
if "mnist-test-set" in fo.list_datasets():
    test_dataset = fo.load_dataset("mnist-test-set")
else:
    # Fail early: every later cell needs `test_dataset` to be defined
    raise RuntimeError("Test dataset not found. Please run '1_explore_mnist.ipynb' and '2_clip_embeddings.ipynb' first.")

session = fo.launch_app(test_dataset, auto=False)
print(session.url)

Session launched. Run `session.show()` to open the App in a cell output.

Welcome to FiftyOne v1.7.0

http://0.0.0.0:5151/


## Performing Zero-Shot Classification

First, we get the distinct class labels from our dataset. For MNIST, these are the ten digit classes, stored as strings such as `0 - zero`.

In [2]:
dataset_classes = sorted(test_dataset.distinct("ground_truth.label"))
dataset_classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

### The Effect of the Text Prompt

With CLIP, the `text_prompt` significantly affects accuracy. It is combined with each class label to form the text that CLIP embeds, so it provides context for the labels. We'll use a simple prompt, but feel free to experiment with others like "A grayscale image of the number" or "An MNIST digit" to see how performance changes.

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

clip_model = foz.load_zoo_model(
    "clip-vit-base32-torch",
    text_prompt="A photo of ",
    classes=dataset_classes,
    device=device
)

Downloading model from 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt'...
 100% |██████|    2.6Gb/2.6Gb [36.2s elapsed, 0s remaining, 123.6Mb/s]      
Downloading CLIP tokenizer...
 100% |█████|   10.4Mb/10.4Mb [156.6ms elapsed, 0s remaining, 66.1Mb/s]     
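
To see what text each image is compared against, it can help to print the full prompts. The sketch below assumes the prompt and each class name are joined with a single space; that joining rule is an assumption about FiftyOne's CLIP wrapper, and `build_prompts` is a hypothetical helper, not part of the FiftyOne API.

```python
def build_prompts(text_prompt, classes):
    """Join the prompt and each class name (assumed rule: prompt + space + class)."""
    return [f"{text_prompt} {c}" for c in classes]

# A subset of the dataset's class labels
classes = ["0 - zero", "1 - one", "2 - two"]

# The prompt used in this notebook plus the two alternatives suggested above
for prompt in ["A photo of", "A grayscale image of the number", "An MNIST digit"]:
    print(build_prompts(prompt, classes))
```

Reading the resulting sentences aloud is a quick sanity check: a prompt that produces awkward text for a human reader is likely a poor description for CLIP as well.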


Now we apply the model to our dataset. The `apply_model()` method will iterate through the samples, generate a prediction for each one, and store it in a new field. We set `store_logits=True` to save the model's raw output scores, which are useful for later analysis.

In [4]:
test_dataset.apply_model(
    model=clip_model,
    label_field="clip_zero_shot_classification",
    store_logits=True,
    batch_size=256,
    num_workers=os.cpu_count()
)

session.refresh()
print("CLIP predictions added to the dataset.")

 100% |█████████████| 10000/10000 [48.9s elapsed, 0s remaining, 199.1 samples/s]      
CLIP predictions added to the dataset.
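
With `store_logits=True`, each prediction carries the raw scores alongside the label. Here is a minimal sketch of turning such scores into probabilities with a softmax. The logit values are invented for illustration; in the dataset they would come from the `logits` attribute of each `clip_zero_shot_classification` label.

```python
import numpy as np

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

# Hypothetical logits for one image over three classes
classes = ["0 - zero", "1 - one", "2 - two"]
logits = np.array([24.1, 21.7, 19.3])

probs = softmax(logits)
top = int(np.argmax(probs))

# The gap between the two highest probabilities is a simple measure
# of how decisive the prediction is
sorted_probs = np.sort(probs)
margin = float(sorted_probs[-1] - sorted_probs[-2])

print(classes[top], probs.round(3), round(margin, 3))
```

Samples with a small margin are natural candidates for manual review in the App.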


## Evaluating CLIP's Performance

Now that we have predictions, we can evaluate them against the ground truth labels. FiftyOne's `evaluate_classifications()` method returns a results object from which we can print a classification report and plot a confusion matrix.

You can also view these results interactively in the FiftyOne App's [Model Evaluation Panel](https://docs.voxel51.com/user_guide/app.html#model-evaluation-panel-sub-new).

In [5]:
clip_evaluation_results = test_dataset.evaluate_classifications(
    "clip_zero_shot_classification",
    gt_field="ground_truth",
    eval_key="clip_zero_shot_eval")

session.refresh()

Let's print a classification report. Overall accuracy is 27.5%, well below the 88% reported in the original CLIP paper. Part of the gap is that we are using a smaller variant of the model; the choice of text prompt also plays a role.

In [6]:
clip_evaluation_results.print_report(classes=dataset_classes, digits=3)

              precision    recall  f1-score   support

    0 - zero      0.256     0.999     0.407       980
     1 - one      0.144     0.033     0.053      1135
     2 - two      1.000     0.016     0.031      1032
   3 - three      0.899     0.387     0.541      1010
    4 - four      0.515     0.178     0.265       982
    5 - five      0.600     0.161     0.254       892
     6 - six      0.076     0.254     0.117       958
   7 - seven      0.612     0.634     0.623      1028
   8 - eight      0.179     0.116     0.141       974
    9 - nine      0.000     0.000     0.000      1009

    accuracy                          0.275     10000
   macro avg      0.428     0.278     0.243     10000
weighted avg      0.427     0.275     0.241     10000
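
The report lists precision and recall but not how often each class was predicted. Since recall × support gives the true positives, and precision is true positives divided by the number of predictions, the prediction count per class can be estimated from the table. This is a back-of-the-envelope sketch using the rounded values printed above, so the estimates are approximate; `estimate_predicted_counts` is a hypothetical helper.

```python
# (precision, recall, support) copied from the report above
report = {
    "0 - zero": (0.256, 0.999, 980),
    "1 - one": (0.144, 0.033, 1135),
    "2 - two": (1.000, 0.016, 1032),
    "3 - three": (0.899, 0.387, 1010),
    "4 - four": (0.515, 0.178, 982),
    "5 - five": (0.600, 0.161, 892),
    "6 - six": (0.076, 0.254, 958),
    "7 - seven": (0.612, 0.634, 1028),
    "8 - eight": (0.179, 0.116, 974),
    "9 - nine": (0.000, 0.000, 1009),
}

def estimate_predicted_counts(report):
    """Estimate how many samples were assigned to each class."""
    counts = {}
    for label, (precision, recall, support) in report.items():
        true_positives = recall * support
        # With zero precision there are no true positives to scale from,
        # so the predicted count cannot be recovered; report it as None
        counts[label] = true_positives / precision if precision > 0 else None
    return counts

counts = estimate_predicted_counts(report)
for label, count in counts.items():
    print(label, "n/a" if count is None else round(count))
```

Under these estimates, `0 - zero` and `6 - six` together account for roughly 70% of all predictions, which explains their low precision.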



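The confusion matrix shows which classes are most often confused. A perfect model would have values only on the main diagonal. Look for confusion between visually similar digits (like 4s and 9s, or 3s and 8s) and for classes that absorb predictions from many others: in the report above, the near-perfect recall and low precision of `0 - zero` suggest the model assigns that label to many other digits.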

In [7]:
clip_evaluation_results.plot_confusion_matrix()

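If this cell fails with `ImportError: Please install anywidget to use the FigureWidget class`, install the missing package (`pip install anywidget`), restart the kernel, and rerun the cell.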
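
For reference, a confusion matrix can also be computed by hand from two label lists. The sketch below uses a handful of invented labels and a hypothetical `confusion_matrix` helper; rows are ground truth and columns are predictions, which is a convention assumed here.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, classes):
    """Count (ground truth, prediction) pairs; rows are truth, columns are predictions."""
    index = {c: i for i, c in enumerate(classes)}
    matrix = np.zeros((len(classes), len(classes)), dtype=int)
    for t, p in zip(y_true, y_pred):
        matrix[index[t], index[p]] += 1
    return matrix

# Invented labels for illustration only
classes = ["3", "4", "8", "9"]
y_true = ["3", "3", "4", "4", "8", "9", "9", "9"]
y_pred = ["3", "8", "9", "4", "8", "4", "9", "4"]

matrix = confusion_matrix(y_true, y_pred, classes)

# Find the most frequent off-diagonal (ground truth, prediction) pair
off_diagonal = matrix.copy()
np.fill_diagonal(off_diagonal, 0)
row, col = np.unravel_index(np.argmax(off_diagonal), off_diagonal.shape)

print(matrix)
print(f"Most common confusion: {classes[row]} predicted as {classes[col]}")
```

The diagonal sum divided by the total count gives the accuracy, which ties the matrix back to the classification report.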

### Cleaning Up

Before moving on, let's clear the CLIP model from memory to free up GPU resources for the next steps.

In [None]:
del clip_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("CLIP model cleared from memory.")

## Next Steps

We have performed zero-shot classification and evaluated the results. At 27.5% accuracy, this gives us a zero-shot baseline for comparison.

Next, we will build and train a traditional supervised learning model, LeNet-5, to see how a task-specific model compares.

Proceed to `4_lenet_training.ipynb`.