Adapted from [OpenAI's Zero-Shot Prediction example](https://github.com/openai/CLIP?tab=readme-ov-file#zero-shot-prediction)

### Import model

In [1]:
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load('ViT-B/32', device)

### Import the Dataset

In [2]:
from modules import ImageNetA
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from PIL import Image

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

dataset = ImageNetA("dataset/imagenet-a", _transform, n_px=model.visual.input_resolution)

### Split the dataset into training, testing and evaluation

In [3]:
num_samples = len(dataset)
training_sample = int(num_samples * 0.5 + 1)
validation_sample = int(num_samples * 0.25)
test_sample = num_samples - training_sample - validation_sample

training_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(dataset, [training_sample, validation_sample, test_sample])

# Create a DataLoader
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=32, shuffle=True, num_workers=8)
val_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=32, shuffle=False, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=8)

# Single Image

In [5]:
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

image, class_id = dataset[0]
image_input = image.unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in dataset.classes.values()]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
print(f"Ground truch: {dataset.idx_to_class(class_id.item())}")
for value, index in zip(values, indices):
    print(f"{list(dataset.classes.values())[index]:>16s}: {100 * value.item():.2f}%")


Top predictions:

Ground truch: {'stingray'}
        stingray: 16.53%
american bullfrog: 11.72%
     cowboy boot: 9.27%
            newt: 7.11%
   manhole cover: 5.80%


### One image at a time (SLOW - don't run it.)

In [7]:

import torch
from torch.utils.data import DataLoader

# Set up device and load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Assuming test_dataset is already defined and a DataLoader is ready
batch_size = 32  # Set the batch size according to your hardware capabilities
dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for images, class_ids in dataloader:
    # Preprocess images and move them to the appropriate device
    image_inputs = images.to(device)
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in dataset.classes.values()]).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_inputs)
        text_features = model.encode_text(text_inputs)

    # Normalize features
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Compute similarities between images and text features
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Pick the top 5 most similar labels for each image in the batch
    values, indices = similarity.topk(1, dim=-1)

    for i in range(len(images)):
        # Print the result for each image in the batch
        print("\nTop predictions:\n")
        print(f"Ground truth: {dataset.idx_to_class(class_ids[i].item())}")
        for value, index in zip(values[i], indices[i]):
            print(f"{list(dataset.classes.values())[index]:>16s}: {100 * value.item():.2f}%")


## It works (Batches: FAST)

In [24]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Set up device and load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Assuming test_dataset is already defined and a DataLoader is ready
batch_size = 256  # Set the batch size according to your hardware capabilities
dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Prepare text inputs once outside the loop
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in dataset.classes.values()]).to(device)

# Create a tensor that maps the indices to class IDs
class_keys_tensor = torch.tensor(list(dataset.classes.keys())).to(device)

# Variables to track the number of correct predictions and total predictions
correct_predictions = 0
total_predictions = 0

for images, class_ids in dataloader:
    # Preprocess images and move them to the appropriate device
    image_inputs = images.to(device)

    with torch.no_grad():
        # Encode image and text inputs
        image_features = model.encode_image(image_inputs)
        text_features = model.encode_text(text_inputs)

        # Normalize features
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Compute similarities between images and text features
        similarity = image_features @ text_features.T

        # Get the index of the top 1 most similar label for each image in the batch
        predicted_indices = similarity.argmax(dim=-1)

    # Map predicted indices to actual class IDs
    predicted_class_ids = class_keys_tensor[predicted_indices]

    # Calculate the number of correct predictions in the batch
    correct_predictions += (predicted_class_ids == class_ids.to(device)).sum().item()
    total_predictions += class_ids.size(0)

# Calculate and print the overall accuracy
accuracy = correct_predictions / total_predictions * 100
print(f"\nAccuracy: {accuracy:.2f}% ({correct_predictions}/{total_predictions} correct predictions)")



Accuracy: 29.62% (555/1874 correct predictions)
