[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1T6kUzlBPIi4V8RAl7m95LZMLUkJsnWhU)

Author:
- **Safouane El Ghazouali**,
- Ph.D. in AI,
- Senior data scientist and researcher at TOELT LLC,
- Lecturer at HSLU


# Hands-on: CLIP for Zero-Shot Image Classification

This hands-on notebook shows how to use CLIP (Contrastive Language-Image Pre-training) for zero-shot image classification. CLIP, developed by OpenAI, classifies images without task-specific training by comparing them against natural-language prompts.

![CLIP Example](https://cdn.prod.website-files.com/651c34ac817aad4a2e62ec1b/653694d75b5a654627f55a91_image%20(7).png)

### Why Use CLIP?
- **Zero-Shot Capability**: Classify images using arbitrary text labels without retraining.
- **Flexibility**: Handles diverse tasks like classification, retrieval, and more via prompts.
- **Robustness**: Trained on 400M image-text pairs, so it generalizes well to many new domains.
- **Open-Source Alternatives**: Projects like OpenCLIP extend CLIP with more models and datasets.

### What You'll Learn
- Loading and using OpenAI's CLIP and OpenCLIP models.
- Performing zero-shot classification on single images from URLs.
- Batch processing multiple images with custom prompts.
- Evaluating zero-shot performance on a small subset of CIFAR10.
- Prompt engineering, ensembles, top-k metrics, and visualization.
- Comparing models from different repositories.

# 🧰 Environment Setup

Install CLIP, OpenCLIP, and dependencies for data handling and visualization.

In [None]:
!pip install -q git+https://github.com/openai/CLIP.git requests pillow
!pip install -q open_clip_torch torchvision matplotlib scikit-learn tqdm

### Import Libraries

Import modules for model loading, image processing, and analysis.

In [None]:
import torch
import clip
import open_clip
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

# 📚 Understanding CLIP and Zero-Shot Classification

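CLIP learns visual concepts from natural language supervision. Zero-shot classification takes three steps:
1. Encode the image and each text prompt into embeddings in a shared space.
2. Compute the cosine similarity between the image embedding and every prompt embedding.
3. Predict the prompt with the highest similarity.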
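
The three steps can be sketched with plain NumPy. This is a minimal illustration, not CLIP itself: the 3-dimensional embeddings are made up, and `zero_shot_predict` is a hypothetical helper name. With a real model, the embeddings would come from `encode_image` and `encode_text`.

```python
import numpy as np

def zero_shot_predict(image_emb, text_embs):
    """Return (best_index, cosine_similarities) for one image embedding."""
    # Step 1 is assumed done: the inputs are already embeddings.
    # L2-normalize so that a dot product equals the cosine similarity.
    image_emb = image_emb / np.linalg.norm(image_emb)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=-1, keepdims=True)
    # Step 2: one cosine similarity per prompt.
    sims = text_embs @ image_emb
    # Step 3: the most similar prompt is the prediction.
    return int(np.argmax(sims)), sims

# Toy embeddings (made up for illustration, not real CLIP outputs)
labels = ['cat', 'dog', 'bird']
text_embs = np.array([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0]])
image_emb = np.array([0.1, 0.2, 0.9])

best, sims = zero_shot_predict(image_emb, text_embs)
print(labels[best], np.round(sims, 3))
```

The toy image vector points mostly along the third axis, so the prediction is `bird`.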

We'll start with single images, then batches, and evaluate on CIFAR10.

References:
- [OpenAI CLIP GitHub](https://github.com/openai/CLIP): original models, trained on 400M image-text pairs.
- [OpenCLIP GitHub](https://github.com/mlfoundations/open_clip): open-source implementation with models trained on LAION datasets.

# 📦 Loading OpenAI CLIP Model

Load a pre-trained CLIP model. ViT-B/32 offers a reasonable balance between speed and accuracy.

In [None]:
clip_model, clip_preprocess = clip.load('ViT-B/32', device=device)
print('OpenAI CLIP model loaded!')

# 🖼️ Zero-Shot Classification on a Single Image

Download an image from a URL, define text prompts, and compute predictions.

In [None]:
# Example image URL (a bird)
image_url = 'http://farm6.staticflickr.com/5517/9349775899_6a2d58ab9a_z.jpg'

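# Download and preprocess (fail early on HTTP errors, force 3-channel RGB)
response = requests.get(image_url, timeout=10)
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert('RGB')
img_preprocessed = clip_preprocess(img).unsqueeze(0).to(device)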

# Display image
plt.imshow(img)
plt.title('Sample Image')
plt.axis('off')
plt.show()

# Text prompts
prompts = ['a cat', 'a dog', 'a car', 'a bird']
text_inputs = clip.tokenize([f'a photo of {p}' for p in prompts]).to(device)

# Inference
with torch.no_grad():
    image_features = clip_model.encode_image(img_preprocessed)
    text_features = clip_model.encode_text(text_inputs)

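    # L2-normalize so the dot product below is a cosine similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Scale by 100 (acts as a temperature), then softmax over the prompts
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(len(prompts))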

# Results
print('Top Predictions:')
for value, index in zip(values, indices):
    print(f'{prompts[index]:>16s}: {100 * value.item():.2f}%')

# Explanation
# - Image downloaded via URL and preprocessed.
# - Prompts tokenized and encoded.
# - Similarities computed; highest is the prediction.
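
The factor `100.0` in the cell above deserves a short illustration. Cosine similarities between CLIP embeddings usually lie close together, so a softmax over the raw values is nearly uniform. Multiplying by a large constant sharpens the distribution without changing the ranking. The sketch below uses made-up similarity values and a hand-written `softmax` helper. It only illustrates the effect and says nothing about the exact values a real model produces.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    x = x - np.max(x)  # subtract the max to avoid overflow in exp
    e = np.exp(x)
    return e / e.sum()

# Made-up cosine similarities for four prompts
cos_sims = np.array([0.21, 0.23, 0.19, 0.30])

plain = softmax(cos_sims)            # nearly uniform
scaled = softmax(100.0 * cos_sims)   # sharply peaked on the best prompt

print(np.round(plain, 3))
print(np.round(scaled, 3))
```

Both versions pick the same prompt, so the scaling affects only how confident the printed percentages look, not which label wins.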

# 📸 Batch Processing Multiple Images

Process a list of image URLs with shared prompts.

In [None]:
# List of image URLs
image_urls = [
    'http://farm3.staticflickr.com/2607/4152809797_e86483de70_z.jpg',  # Cat
    'http://farm8.staticflickr.com/7276/7051382257_b45a38c0ec_z.jpg',  # Dog
    'http://farm6.staticflickr.com/5306/5653379279_49e1b67bc2_z.jpg'   # Car
]

# Prompts
prompts = ['a cat', 'a dog', 'a car', 'a bird']
text_inputs = clip.tokenize([f'a photo of {p}' for p in prompts]).to(device)

with torch.no_grad():
    text_features = clip_model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Process images
fig, axs = plt.subplots(1, len(image_urls), figsize=(15, 5))
for i, url in enumerate(image_urls):
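    response = requests.get(url, timeout=10)
    response.raise_for_status()  # stop on a failed download instead of decoding an error page
    img = Image.open(BytesIO(response.content)).convert('RGB')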
    img_preprocessed = clip_preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = clip_model.encode_image(img_preprocessed)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        pred_idx = similarity.argmax().item()

    axs[i].imshow(img)
    axs[i].set_title(f'Pred: {prompts[pred_idx]}')
    axs[i].axis('off')
plt.show()

# Explanation
# - Multiple images processed in a loop.
# - Text features computed once for efficiency.
# - Predictions visualized with images.
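
The loop above encodes one image at a time. Once the embeddings are available, all images can be scored against all prompts with a single matrix product. The sketch below shows this with toy NumPy arrays. `batch_predict` is a hypothetical helper, and the 4-dimensional embeddings are invented so that each image clearly matches one prompt.

```python
import numpy as np

def batch_predict(image_embs, text_embs, scale=100.0):
    """Classify N images against C prompts in one matrix product.

    image_embs: (N, D) array, text_embs: (C, D) array.
    Returns (pred_idx, probs) with shapes (N,) and (N, C).
    """
    image_embs = image_embs / np.linalg.norm(image_embs, axis=1, keepdims=True)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    logits = scale * image_embs @ text_embs.T              # (N, C) scaled cosine similarities
    logits = logits - logits.max(axis=1, keepdims=True)    # stabilize exp
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)              # row-wise softmax
    return probs.argmax(axis=1), probs

prompts = ['a cat', 'a dog', 'a car', 'a bird']
text_embs = np.eye(4)  # toy prompt embeddings, one axis per prompt
image_embs = np.array([
    [0.9, 0.1, 0.0, 0.1],   # closest to 'a cat'
    [0.2, 0.8, 0.1, 0.0],   # closest to 'a dog'
    [0.0, 0.1, 0.9, 0.2],   # closest to 'a car'
])

pred_idx, probs = batch_predict(image_embs, text_embs)
for i, idx in enumerate(pred_idx):
    print(f'image {i}: {prompts[idx]} ({100 * probs[i, idx]:.1f}%)')
```

With a real model, the same shape logic applies after stacking the preprocessed images into one tensor and calling the image encoder once.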

# 📊 Evaluating on a Small Dataset (CIFAR10 Subset)

Now apply zero-shot classification to a small subset of the CIFAR10 test set for a systematic evaluation.

In [None]:
# Define classes
cifar_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Load CIFAR10 test set
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Small subset (100 images)
subset_indices = list(range(100))
subset_dataset = Subset(test_dataset, subset_indices)
test_loader = DataLoader(subset_dataset, batch_size=32, shuffle=False)

print(f'Loaded CIFAR10 subset with {len(subset_dataset)} images.')

# ✍️ Preparing Text Prompts

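Build one prompt per CIFAR10 class from a template. The `ensemble_templates` list is defined for prompt ensembling (averaging several templates per class) and is not used by the single-template evaluation below.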

In [None]:
prompt_template = 'a photo of a {}'
text_inputs = clip.tokenize([prompt_template.format(c) for c in cifar_classes]).to(device)

ensemble_templates = [
    'a photo of a {}',
    'an image of a {}',
    'a picture of the {}'
]
print('Prompts ready!')
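
One common way to use several templates is prompt ensembling: embed every template for a class, average the normalized embeddings, and re-normalize the mean to get one embedding per class. The sketch below shows only that averaging logic. `fake_encode_text` is a hypothetical stand-in that returns a deterministic random vector per prompt. With real CLIP it would be replaced by tokenization followed by `clip_model.encode_text`.

```python
import numpy as np

def fake_encode_text(prompt, dim=8):
    """Hypothetical stand-in for a text encoder: one deterministic random vector per prompt."""
    seed = sum(ord(ch) for ch in prompt)
    return np.random.default_rng(seed).normal(size=dim)

def ensemble_class_embeddings(classes, templates, encode=fake_encode_text):
    """Average unit-normalized prompt embeddings per class, then re-normalize."""
    class_embs = []
    for name in classes:
        # One embedding per template for this class: shape (num_templates, dim)
        embs = np.stack([encode(t.format(name)) for t in templates])
        embs /= np.linalg.norm(embs, axis=1, keepdims=True)
        mean = embs.mean(axis=0)
        class_embs.append(mean / np.linalg.norm(mean))
    return np.stack(class_embs)  # (num_classes, dim)

templates = ['a photo of a {}', 'an image of a {}', 'a picture of the {}']
classes = ['airplane', 'bird', 'cat']

class_embs = ensemble_class_embeddings(classes, templates)
print(class_embs.shape)
```

The resulting matrix plays the same role as `text_features` in the single-template case, so the rest of the classification code stays unchanged.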

# 🔍 Zero-Shot on CIFAR10

In [None]:
with torch.no_grad():
    text_features = clip_model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

predictions = []
labels_list = []
# Disable autograd: this is inference only, so no computation graph is needed
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        # Convert each tensor back to PIL so CLIP's own preprocessing (resize, crop, normalize) applies
        images = torch.stack([clip_preprocess(transforms.ToPILImage()(img)) for img in images]).to(device)
        image_features = clip_model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        preds = similarity.argmax(dim=1).cpu().numpy()
        predictions.extend(preds)
        labels_list.extend(labels.numpy())

accuracy = np.mean(np.array(predictions) == np.array(labels_list)) * 100
print(f'Accuracy: {accuracy:.2f}%')
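
The accuracy above is top-1: only the single best prompt counts. Top-k accuracy counts a prediction as correct when the true label is among the k highest-scoring classes. The sketch below is one possible implementation, tried on a made-up score matrix. `topk_accuracy` is a hypothetical helper. To apply it to the evaluation above, the per-batch `similarity` matrices would need to be collected instead of only their argmax.

```python
import numpy as np

def topk_accuracy(scores, labels, k=1):
    """Fraction of rows whose true label is among the k highest scores."""
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    # Indices of the k largest scores per row
    topk = np.argsort(-scores, axis=1)[:, :k]
    hits = (topk == labels[:, None]).any(axis=1)
    return float(hits.mean())

# Made-up scores for 4 samples and 3 classes
scores = np.array([
    [0.7, 0.2, 0.1],   # true label 0: correct at top-1
    [0.4, 0.5, 0.1],   # true label 0: correct only at top-2
    [0.1, 0.3, 0.6],   # true label 2: correct at top-1
    [0.5, 0.3, 0.2],   # true label 2: missed even at top-2
])
labels = np.array([0, 0, 2, 2])

print(topk_accuracy(scores, labels, k=1))  # 0.5
print(topk_accuracy(scores, labels, k=2))  # 0.75
```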

# 🔄 OpenCLIP Models

Load an OpenCLIP model (ViT-B-32 trained on LAION-2B) and run it on the same bird image for comparison.

In [None]:
openclip_model, _, openclip_preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
openclip_model.to(device).eval()
openclip_tokenizer = open_clip.get_tokenizer('ViT-B-32')

image_url = 'http://farm6.staticflickr.com/5517/9349775899_6a2d58ab9a_z.jpg'

# Download and preprocess
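response = requests.get(image_url, timeout=10)
response.raise_for_status()  # stop on a failed download instead of decoding an error page
img = Image.open(BytesIO(response.content)).convert('RGB')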
img_preprocessed = openclip_preprocess(img).unsqueeze(0).to(device)

# Display image
plt.imshow(img)
plt.title('Sample Image')
plt.axis('off')
plt.show()

# Text prompts
prompts = ['a cat', 'a dog', 'a car', 'a bird']
text_inputs = openclip_tokenizer([f'a photo of {p}' for p in prompts]).to(device)

# Inference
with torch.no_grad():
    image_features = openclip_model.encode_image(img_preprocessed)
    text_features = openclip_model.encode_text(text_inputs)

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(len(prompts))

# Results
print('Top Predictions:')
for value, index in zip(values, indices):
    print(f'{prompts[index]:>16s}: {100 * value.item():.2f}%')

# Explanation
# - Same pipeline as with OpenAI CLIP, but using OpenCLIP's model, preprocessing, and tokenizer.
# - Compare these probabilities with the earlier OpenAI CLIP output for the same image.

# 💡 Student Task

1. Classify an image from a URL of your choice with your own prompts.
2. Test batch processing on more challenging images from the COCO dataset.
3. Evaluate on the full CIFAR10 test set.
4. Check out the OpenCLIP GitHub repo and try other models.