In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import clip

In [3]:
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = [
    "a photo of an airplane",
    "a photo of an automobile",
    "a photo of a bird",
    "a photo of a cat",
    "a photo of a deer",
    "a photo of a dog",
    "a photo of a frog",
    "a photo of a horse",
    "a photo of a ship",
    "a photo of a truck",
]

100%|██████████| 170M/170M [00:48<00:00, 3.48MB/s] 


In [6]:
dataiter = iter(trainloader)
images, labels = next(dataiter)

print(images[0].shape)

torch.Size([3, 224, 224])


In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Tokenize prompts
text_inputs = clip.tokenize(classes).to(device)

correct = 0
total = 0

with torch.no_grad():
    text_features = model.encode_text(text_inputs)

    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)

        image_features = model.encode_image(images)

        # Compute similarity
        logits_per_image = image_features @ text_features.T
        preds = torch.argmax(logits_per_image, dim=1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Zero-shot CLIP accuracy on CIFAR-10: {100 * correct / total:.2f}%")


Zero-shot CLIP accuracy on CIFAR-10: 76.77%
