In [10]:
from torchvision.datasets import Imagenette
import numpy as np
from transformers import CLIPProcessor, CLIPModel
import torch
import tqdm
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from sklearn.metrics import accuracy_score

In [11]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

weights = ResNet50_Weights.DEFAULT
resnet_model = models.resnet50(weights=weights)
resnet_model.eval()
preprocess = weights.transforms()
imagenet_labels = weights.meta["categories"]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

CLIPModel LOAD REPORT from: openai/clip-vit-base-patch32
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
text_model.embeddings.position_ids   | UNEXPECTED |  | 
vision_model.embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


In [14]:
dataset = Imagenette(root = './data', split = 'val', download = True)

class_names = ['tench', 'English springer', 'cassette player', 'chain saw',
               'church', 'French horn', 'garbage truck', 'gas pump',
               'golf ball', 'parachute']
my_indices = []

for class_name in class_names:
    for index, label in enumerate(imagenet_labels):
        if class_name.lower() in label_text.lower():
            my_indices.append(index)
            break

In [15]:
prompts = [f"a photo of {name}" for name in class_names]

In [26]:
y_pred_clip = []
y_pred_resnet = []
y_true_clip = []
y_true_resnet = []

text_inputs = processor(text = prompts, return_tensors = 'pt', padding = True)
text_features = model.get_text_features(**text_inputs).pooler_output
text_features = text_features / text_features.norm(dim = 1, keepdim = True)

for image, label in tqdm.tqdm(dataset, desc = 'Total'):
  image_inputs = processor(images = image, return_tensors = 'pt', padding = True)
  image_features = model.get_image_features(**image_inputs).pooler_output
  image_features = image_features / image_features.norm(dim = 1, keepdim = True)

  similarity = image_features @ text_features.T
  clip_pred = similarity.argmax(dim = 1)

  y_pred_clip.append(clip_pred)
  y_true_clip.append(label)

  batch = preprocess(image).unsqueeze(0)
  prediction = resnet_model(batch)
  class_id = prediction.argmax().item()

  if class_id in my_indices:
    resnet_pred = my_indices.index(class_id)
    y_pred_resnet.append(resnet_pred)
    y_true_resnet.append(label)

clip_accuracy = accuracy_score(y_true_clip, y_pred_clip)
resnet_accuracy = accuracy_score(y_true_resnet, y_pred_resnet)
print(f"\nResNet Accuracy = {100 * resnet_accuracy:.2f}%\n")
print(f"Clip Accuracy = {100 * clip_accuracy:.2f}%")

Total: 100%|██████████| 3925/3925 [20:44<00:00,  3.15it/s]


ResNet Accuracy = 100.00%

Clip Accuracy = 98.80%



