In [None]:
from torchvision.datasets import Imagenette
import numpy as np
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel
import torch
from sklearn.metrics import accuracy_score, f1_score
import tqdm

In [None]:
# Pretrained CLIP ViT-B/16 plus its processor. The processor is used below both
# to tokenize the text prompts and to preprocess the images.
# Falls back to CPU so the notebook can still run (slowly) without a GPU.
MODEL_NAME = "openai/clip-vit-base-patch16"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()  # explicit inference mode
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

In [None]:
# Imagenette validation split, downloaded into ./data on first run.
# Each entry of `dataset.classes` is a sequence of names for one class; the
# first name is used as the human-readable label for prompting.
# NOTE(review): the evaluation loop assumes position i here matches label i.
dataset = Imagenette(root='./data', split='val', download=True)
class_names = [synonyms[0] for synonyms in dataset.classes]

print(class_names)

['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']

In [None]:
# Prompt ensemble for zero-shot classification: every class name is formatted
# into every template, and the resulting text embeddings are averaged per class.
# NOTE(review): the first three templates lack the trailing period the others
# have, and "a {}" yields e.g. "a English springer"; kept verbatim so results
# stay comparable.
prompt_templates = [
    "a photo of the {}",
    "a bad photo of a {}",
    "a low resolution photo of the {}",
    "a photo of a large {}.",
    "a photo of a small {}.",
    "a photo of the {} object.",
    "a photo of the {} item.",
    "a photo of my {}.",
    "this is a photo of a {}.",
    "there is a {} on the photo.",
    "i see a {}.",
]

# Class-major order: all templates for class 0, then all templates for class 1,
# and so on. The per-class averaging step relies on this layout.
all_prompts = [
    template.format(name)
    for name in class_names
    for template in prompt_templates
]

print(all_prompts)

['a photo of the tench', 'a bad photo of a tench', 'a low resolution photo of the tench', 'a photo of a large tench.', 'a photo of a small tench.', 'a photo of the tench object.', 'a photo of the tench item.', 'a photo of my tench.', 'this is a photo of a tench.', 'there is a tench on the photo.', 'i see a tench.', 'a photo of the English springer', 'a bad photo of a English springer', 'a low resolution photo of the English springer', 'a photo of a large English springer.', 'a photo of a small English springer.', 'a photo of the English springer object.', 'a photo of the English springer item.', 'a photo of my English springer.', 'this is a photo of a English springer.', 'there is a English springer on the photo.', 'i see a English springer.', 'a photo of the cassette player', 'a bad photo of a cassette player', 'a low resolution photo of the cassette player', 'a photo of a large cassette player.', 'a photo of a small cassette player.', 'a photo of the cassette player object.', 'a photo of the cassette player item.', 'a photo of my cassette player.', 'this is a photo of a cassette player.', 'there is a cassette player on the photo.', 'i see a cassette player.', 'a photo of the chain saw', 'a bad photo of a chain saw', 'a low resolution photo of the chain saw', 'a photo of a large chain saw.', 'a photo of a small chain saw.', 'a photo of the chain saw object.', 'a photo of the chain saw item.', 'a photo of my chain saw.', 'this is a photo of a chain saw.', 'there is a chain saw on the photo.', 'i see a chain saw.', 'a photo of the church', 'a bad photo of a church', 'a low resolution photo of the church', 'a photo of a large church.', 'a photo of a small church.', 'a photo of the church object.', 'a photo of the church item.', 'a photo of my church.', 'this is a photo of a church.', 'there is a church on the photo.', 'i see a church.', 'a photo of the French horn', 'a bad photo of a French horn', 'a low resolution photo of the French horn', 'a photo of a large 
French horn.', 'a photo of a small French horn.', 'a photo of the French horn object.', 'a photo of the French horn item.', 'a photo of my French horn.', 'this is a photo of a French horn.', 'there is a French horn on the photo.', 'i see a French horn.', 'a photo of the garbage truck', 'a bad photo of a garbage truck', 'a low resolution photo of the garbage truck', 'a photo of a large garbage truck.', 'a photo of a small garbage truck.', 'a photo of the garbage truck object.', 'a photo of the garbage truck item.', 'a photo of my garbage truck.', 'this is a photo of a garbage truck.', 'there is a garbage truck on the photo.', 'i see a garbage truck.', 'a photo of the gas pump', 'a bad photo of a gas pump', 'a low resolution photo of the gas pump', 'a photo of a large gas pump.', 'a photo of a small gas pump.', 'a photo of the gas pump object.', 'a photo of the gas pump item.', 'a photo of my gas pump.', 'this is a photo of a gas pump.', 'there is a gas pump on the photo.', 'i see a gas pump.', 'a photo of the golf ball', 'a bad photo of a golf ball', 'a low resolution photo of the golf ball', 'a photo of a large golf ball.', 'a photo of a small golf ball.', 'a photo of the golf ball object.', 'a photo of the golf ball item.', 'a photo of my golf ball.', 'this is a photo of a golf ball.', 'there is a golf ball on the photo.', 'i see a golf ball.', 'a photo of the parachute', 'a bad photo of a parachute', 'a low resolution photo of the parachute', 'a photo of a large parachute.', 'a photo of a small parachute.', 'a photo of the parachute object.', 'a photo of the parachute item.', 'a photo of my parachute.', 'this is a photo of a parachute.', 'there is a parachute on the photo.', 'i see a parachute.']

In [None]:
# Encode every prompt once and L2-normalise, so that later dot products with
# (normalised) image features are cosine similarities.
with torch.no_grad():
  text_inputs = processor(text=all_prompts, return_tensors='pt', padding=True).to(model.device)
  text_outputs = model.get_text_features(**text_inputs)
  # Newer transformers releases return a model-output object whose
  # `.pooler_output` holds the projected embeddings; older releases return that
  # tensor directly. Accept either form.
  text_features_all = getattr(text_outputs, 'pooler_output', text_outputs)
  text_features_all = text_features_all / text_features_all.norm(dim=1, keepdim=True)

In [None]:
num_classes = len(class_names)
num_templates = len(prompt_templates)

# all_prompts is class-major, so row c * num_templates + t is class c rendered
# with template t. Guard the layout before reshaping.
assert text_features_all.shape[0] == num_classes * num_templates, (
    f"expected {num_classes * num_templates} prompt embeddings, "
    f"got {text_features_all.shape[0]}"
)

# One embedding per class: average over its templates. This stays on the
# device/dtype of text_features_all (no CPU buffer or float32 round-trip).
text_features = text_features_all.reshape(num_classes, num_templates, -1).mean(dim=1)
# The mean of unit vectors is generally shorter than 1, so re-normalise.
text_features = text_features / text_features.norm(dim=1, keepdim=True)

In [None]:
# Images per CLIP forward pass; encoding one image at a time under-uses the GPU.
BATCH_SIZE = 64

# The dataset yields (image, label) pairs. Collate into plain tuples so the CLIP
# processor receives the raw images and does its own resizing/normalisation.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=lambda batch: tuple(zip(*batch)),
)

y_true = []
y_pred = []

with torch.no_grad():
  for images, labels in tqdm.tqdm(loader, desc='Total'):
    image_inputs = processor(images=list(images), return_tensors='pt').to(model.device)
    image_outputs = model.get_image_features(**image_inputs)
    # Newer transformers wrap the projected embeddings in `.pooler_output`;
    # older releases return the tensor directly.
    image_features = getattr(image_outputs, 'pooler_output', image_outputs)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)

    # Cosine similarity against each class embedding; highest score wins.
    similarity = image_features @ text_features.T
    y_pred.extend(similarity.argmax(dim=1).tolist())
    y_true.extend(labels)

accuracy = accuracy_score(y_true, y_pred)
# Micro-averaged F1 equals accuracy for single-label multiclass data, so it
# adds no information; macro F1 (unweighted mean of per-class F1) can differ.
f1 = f1_score(y_true, y_pred, average='macro')

print(f"\nAccuracy = {100 * accuracy:.2f}")
print(f"Macro F1 = {100 * f1:.2f}")

Total: 100%|██████████| 3925/3925 [01:31<00:00, 42.96it/s]

Accuracy = 99.57

F1 = 1.00