In [None]:
from torchvision.datasets import Imagenette
import numpy as np
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel
import torch
from sklearn.metrics import accuracy_score, f1_score
import tqdm

In [None]:
# Load the pretrained CLIP ViT-B/32 checkpoint. The processor (tokenizer +
# image preprocessing) must come from the same checkpoint as the model.
checkpoint = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint).to("cuda")

In [None]:
# Load the Imagenette validation split and derive one readable name per class.
DATA_ROOT = './data'

# torchvision's Imagenette raises RuntimeError when download=True and the
# dataset directory already exists under `root`, which breaks Restart & Run All
# after the first run. Fall back to the local copy in that case. If the data is
# genuinely missing, the download=False call raises its own "not found" error.
try:
    dataset = Imagenette(root=DATA_ROOT, split='val', download=True)
except RuntimeError:
    dataset = Imagenette(root=DATA_ROOT, split='val', download=False)

# Entries of `dataset.classes` are indexable (presumably tuples of synonyms --
# confirm against the installed torchvision); the first element is used as the
# class name substituted into the prompts. List order matches the label indices.
class_names = [synonyms[0] for synonyms in dataset.classes]

print(class_names)

In [None]:
# Prompt-ensembling templates: every class name is substituted into every
# template, and the per-class text embeddings are averaged in a later cell.
# NOTE(review): the first three templates lack the trailing '.' that the rest
# have -- confirm whether that is intentional.
prompt_templates = [
    "a photo of the {}",
    "a bad photo of a {}",
    "a low resolution photo of the {}",
    "a photo of a large {}.",
    "a photo of a small {}.",
    "a photo of the {} object.",
    "a photo of the {} item.",
    "a photo of my {}.",
    "this is a photo of a {}.",
    "there is a {} on the photo.",
    "i see a {}.",
]

# Class-major order: all templates for class 0, then all templates for class 1, ...
# The averaging cell relies on this layout.
all_prompts = [
    template.format(class_name)
    for class_name in class_names
    for template in prompt_templates
]

print(all_prompts)

In [None]:
# Encode every prompt once and L2-normalise each embedding, so later dot
# products are cosine similarities.
with torch.no_grad():
    text_inputs = processor(text=all_prompts, return_tensors='pt', padding=True).to("cuda")
    text_outputs = model.get_text_features(**text_inputs)
    # NOTE(review): depending on the installed transformers version,
    # get_text_features returns either the projected embedding tensor itself or
    # an output object exposing it as `.pooler_output` (what this cell originally
    # assumed). Accept both rather than crash on one of them.
    text_features_all = getattr(text_outputs, "pooler_output", text_outputs)
    text_features_all = text_features_all / text_features_all.norm(dim=1, keepdim=True)

In [None]:
# Prompt ensembling: average the normalised embeddings of all templates for each
# class, then re-normalise the mean to unit length.
num_classes = len(class_names)
num_templates = len(prompt_templates)

# all_prompts is built class-major (every template for class 0, then class 1, ...),
# so viewing the rows as (num_classes, num_templates, dim) groups each class's
# templates together and the mean over dim=1 is the per-class average.
assert text_features_all.shape[0] == num_classes * num_templates, \
    "number of encoded prompts does not equal classes x templates"
text_features = text_features_all.reshape(num_classes, num_templates, -1).mean(dim=1)

text_features = text_features / text_features.norm(dim=1, keepdim=True)
text_features = text_features.to("cuda")

In [None]:
# Zero-shot evaluation: embed images in batches and predict the class whose
# ensembled text embedding has the highest cosine similarity.
BATCH_SIZE = 64  # images per forward pass; lower this if GPU memory is tight

# The dataset yields (PIL image, int label) pairs. The default collate cannot
# stack PIL images, so keep each batch as a tuple of images plus a tuple of labels
# and let the CLIP processor do resizing/normalisation.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=lambda batch: tuple(zip(*batch)),
)

y_true = []
y_pred = []

with torch.no_grad():
    for images, labels in tqdm.tqdm(loader, desc='Total'):
        # `padding` only applies to text, so it is not passed for images.
        image_inputs = processor(images=list(images), return_tensors='pt').to("cuda")
        image_outputs = model.get_image_features(**image_inputs)
        # NOTE(review): depending on the transformers version this is either the
        # projected embedding tensor or an output object exposing it as
        # `.pooler_output`; accept both.
        image_features = getattr(image_outputs, "pooler_output", image_outputs)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)

        # Both sides are unit-normalised, so this is cosine similarity per class.
        similarity = image_features @ text_features.T
        y_pred.extend(similarity.argmax(dim=1).cpu().tolist())
        y_true.extend(labels)

assert len(y_true) == len(y_pred) == len(dataset), "not every sample was scored"

accuracy = accuracy_score(y_true, y_pred)
# With exactly one label per sample, micro-averaged F1 equals accuracy, so it
# would add nothing. Macro F1 (unweighted mean of per-class F1) also reflects
# how evenly the classifier performs across classes.
f1 = f1_score(y_true, y_pred, average='macro')

print(f"\nAccuracy = {100 * accuracy:.2f}")
print(f"Macro F1 = {f1:.2f}")