In [1]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

Collecting ftfy
  Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.2.0
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-mt99mizh
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-mt99mizh
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->clip==1.0)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->clip==1.0)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==

In [2]:
import os
import numpy as np
import torch
import torch.utils.data as data
from torch.multiprocessing import Pool, set_start_method
from concurrent.futures import ThreadPoolExecutor

In [3]:
import clip

print(clip.available_models())

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [4]:
model, preprocess = clip.load("ViT-B/32")  ## YOU ARE FREE TO CHOOSE ANY MODEL
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

100%|███████████████████████████████████████| 338M/338M [00:03<00:00, 94.1MiB/s]


In [5]:
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


In [6]:
from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /root/.cache/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:03<00:00, 48069707.33it/s]


Extracting /root/.cache/cifar-100-python.tar.gz to /root/.cache


In [7]:
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

In [8]:
print("Text Features Shape:", text_features.shape)
print("Sample Text Feature:", text_features[0][:5])

Text Features Shape: torch.Size([100, 512])
Sample Text Feature: tensor([-0.0048,  0.0590, -0.0039,  0.0091, -0.0080], device='cuda:0')


In [10]:
## WRITE YOUR CODE HERE - THIS CODE SHOULD PRINT THE FINAL ZERO-SHOT CLASSIFICATION ACCURACY

correct = 0
batch_size = 256

# Dataset class for streaming CIFAR-100
class StreamingCIFAR100(data.Dataset):
    def __init__(self, cifar_dataset):
        self.cifar100 = cifar_dataset

    def __len__(self):
        return len(self.cifar100)

    def __getitem__(self, index):
        return self.cifar100[index]

# Create DataLoader for CIFAR-100
streaming_cifar100 = StreamingCIFAR100(cifar100)
dataloader = torch.utils.data.DataLoader(streaming_cifar100, batch_size=batch_size, shuffle=False)

# Compute image features and classify
with torch.no_grad():
    for images, labels in dataloader:
        images = images.cuda()
        labels = labels.cuda()

        # Compute image features
        image_features = model.encode_image(images).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)



        # Compute logits and predictions
        logits = image_features @ text_features.t()
        predicted_labels = logits.argmax(dim=-1)

        # Count correct predictions
        correct += (predicted_labels == labels).sum().item()

# Compute accuracy
accuracy = correct / len(streaming_cifar100)
print('Zero-shot Classification Accuracy = {:.2f}%'.format(accuracy * 100))


Zero-shot Classification Accuracy = 58.85%
