In [None]:
from torchvision.datasets import Imagenette
import numpy as np
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel
import torch
import tqdm

In [None]:
# Checkpoint shared by the model weights and the matching preprocessor, named
# once so the two can never drift apart.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Prefer the GPU when one is present, but fall back to the CPU instead of
# failing at load time on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(device)
model.eval()  # the model is only used for inference in this notebook
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [None]:
# Load the Imagenette validation split, downloading it only when it is not
# already on disk. Some torchvision releases raise RuntimeError when
# download=True is passed and the extracted directory already exists, which
# breaks re-running this cell; trying the local copy first avoids that.
try:
    dataset = Imagenette(root = './data', split = 'val', download = False)
except RuntimeError:
    # Raised by torchvision when the dataset is not found under `root`.
    dataset = Imagenette(root = './data', split = 'val', download = True)

# Human-readable class names used to build the text prompts.
# NOTE(review): the order is assumed to match the dataset's integer labels
# (index i <-> label i) — confirm against `dataset.classes`.
class_names = ['tench', 'English springer', 'cassette player', 'chain saw',
               'church', 'French horn', 'garbage truck', 'gas pump',
               'golf ball', 'parachute']

In [None]:
# Build one text prompt per class, preserving the order of `class_names`.
# The evaluation compares the index of the best-matching prompt with the
# dataset label, so the ordering here matters.
prompts = []
for name in class_names:
    prompts.append(f"a photo of {name}")

In [None]:
# Zero-shot evaluation: each image is scored against every text prompt exactly
# once, and the predicted class is the prompt with the highest image-text
# logit. Predictions are tallied in a confusion matrix, from which accuracy and
# the pooled (micro-averaged) precision / recall / F1 are derived.

def safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


num_classes = len(prompts)

# confusion[t, p] = number of images with true label t that were predicted as p.
confusion = np.zeros((num_classes, num_classes), dtype=np.int64)

# Send the inputs to wherever the model's weights actually live rather than
# assuming a specific device.
model_device = next(model.parameters()).device

# Inference only: disabling autograd avoids building a graph for every image.
with torch.no_grad():
    for image, label in tqdm.tqdm(dataset, desc = 'Total'):
        inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True).to(model_device)

        outputs = model(**inputs)
        # A single image is passed in, so the argmax over all logits is the
        # index of the best-matching prompt.
        predicted = outputs.logits_per_image.argmax().item()

        confusion[label, predicted] += 1

total = int(confusion.sum())
correct = int(np.trace(confusion))

# One-vs-rest counts pooled over all classes:
#   TP: diagonal entries (prediction == label)
#   FP: column sums minus the diagonal (predicted as class c, label differs)
#   FN: row sums minus the diagonal (label is class c, prediction differs)
per_class_tp = np.diag(confusion)
TP = int(per_class_tp.sum())
FP = int((confusion.sum(axis=0) - per_class_tp).sum())
FN = int((confusion.sum(axis=1) - per_class_tp).sum())

accuracy = safe_ratio(correct, total) * 100
print(f"\nAccuracy = {accuracy:.2f}%\n")

# Because the counts are pooled over classes, every misclassified image adds
# one FP and one FN, so these micro-averaged values coincide with accuracy for
# single-label classification.
Precision = safe_ratio(TP, TP + FP)
Recall = safe_ratio(TP, TP + FN)

F1 = safe_ratio(2 * (Precision * Recall), Precision + Recall)

print(f"Precision = {Precision:.2f}\n")
print(f"Recall = {Recall:.2f}\n")
print(f"F1 = {F1:.2f}\n")