In [None]:
from torchvision.datasets import Imagenette
import numpy as np
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel
import torch
import tqdm
from sklearn.metrics import accuracy_score, f1_score

In [None]:
# Checkpoint shared by the model and its processor so the two cannot drift apart.
MODEL_NAME = "openai/clip-vit-base-patch32"

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA instead of failing on the device move.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

In [None]:
# Load the validation split from ./data when it is already there and download
# only if the loader reports it missing. Some torchvision releases raise when
# download = True is requested for a directory that already exists, which would
# break re-running this cell.
# NOTE(review): assumes a missing dataset is signalled with RuntimeError —
# confirm against the installed torchvision version.
try:
  dataset = Imagenette(root = './data', split = 'val', download = False)
except RuntimeError:
  dataset = Imagenette(root = './data', split = 'val', download = True)

# Element 0 of each entry in dataset.classes is used as the label text
# (presumably the first of several synonyms for the class).
class_names = [names[0] for names in dataset.classes]

In [None]:
# One text prompt per class, in the same order as class_names, so a prompt's
# position in the list is the class index it stands for.
PROMPT_TEMPLATE = "a photo of {}"

prompts = []
for class_name in class_names:
  prompts.append(PROMPT_TEMPLATE.format(class_name))

print(prompts)

['a photo of tench', 'a photo of English springer', 'a photo of cassette player', 'a photo of chain saw', 'a photo of church', 'a photo of French horn', 'a photo of garbage truck', 'a photo of gas pump', 'a photo of golf ball', 'a photo of parachute']

In [None]:
# Zero-shot evaluation: every image is assigned the class whose text prompt has
# the highest scaled cosine similarity with the image embedding.
y_true = []
y_pred = []

# Send inputs to whichever device the model lives on rather than assuming CUDA.
model_device = model.device

with torch.no_grad():
  # The logit scale is a fixed property of the loaded model, so read it once
  # instead of once per image. The model stores its log, hence the exp().
  logit_scale = model.logit_scale.exp().item()
  temperature = 1.0 / logit_scale

  # The prompts are the same for every image: embed and L2-normalise them once.
  # NOTE(review): relies on a transformers version where get_text_features /
  # get_image_features return an object exposing .pooler_output, as the
  # recorded run of this notebook did.
  text_inputs = processor(text = prompts, return_tensors = 'pt', padding = True).to(model_device)
  text_features = model.get_text_features(**text_inputs).pooler_output
  text_features = text_features / text_features.norm(dim = 1, keepdim = True)

for image, label in tqdm.tqdm(dataset, desc = 'Total'):
  with torch.no_grad():
    image_inputs = processor(images = image, return_tensors = 'pt').to(model_device)
    image_features = model.get_image_features(**image_inputs).pooler_output
    image_features = image_features / image_features.norm(dim = 1, keepdim = True)

    # Both sides are unit-norm, so the matrix product is cosine similarity.
    # Dividing by the temperature (i.e. multiplying by the logit scale) gives
    # CLIP-style logits; the positive factor leaves the argmax unchanged.
    similarity = (image_features @ text_features.T) / temperature
    pred = similarity.argmax(dim = 1)

  # One image per step, so pred holds a single class index.
  y_pred.append(pred.cpu().item())
  y_true.append(label)

accuracy = accuracy_score(y_true, y_pred)
# NOTE(review): with exactly one label per image, micro-averaged F1 equals
# accuracy; average = 'macro' would give a per-class-balanced view instead.
f1 = f1_score(y_true, y_pred, average = 'micro')

print(f"\nAccuracy = {100 * accuracy:.2f}%")
print(f"F1 = {f1:.2f}")

Total: 100%|██████████| 3925/3925 [01:14<00:00, 52.79it/s]

Accuracy = 98.80%

F1 = 0.99