In [None]:
from torchvision.datasets import Imagenette
import numpy as np
from transformers import CLIPProcessor, CLIPModel
import torch
import tqdm
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from sklearn.metrics import accuracy_score

In [None]:
# ImageNet-1k ResNet-50 baseline, together with the preprocessing pipeline and
# category names that ship with its weights (imagenet_labels[i] names logit i).
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
imagenet_labels = weights.meta["categories"]
resnet_model = models.resnet50(weights=weights).to("cuda").eval()

# Zero-shot CLIP ViT-B/16 and its paired text/image processor.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to("cuda")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

In [None]:
# Imagenette validation split (a 10-class subset of ImageNet).
# torchvision's Imagenette raises RuntimeError when download=True and the
# extracted folder already exists, so on re-runs fall back to the local copy;
# this keeps "Restart Kernel -> Run All" working after the first download.
try:
    dataset = Imagenette(root='./data', split='val', download=True)
except RuntimeError:
    dataset = Imagenette(root='./data', split='val', download=False)

# dataset.classes holds a tuple of names per class; keep the first one as the
# canonical class name (used for prompts and for the ImageNet index lookup).
class_names = [names[0] for names in dataset.classes]

# Map each Imagenette class to its index in the ResNet's ImageNet label list,
# so that my_indices[i] is the ImageNet logit for Imagenette label i.
# An exact (case-insensitive) match is preferred; the first substring match is
# only a fallback, so a longer label that merely contains the class name cannot
# shadow the real class. A class with no match raises instead of being skipped:
# skipping would shift every later entry and silently misalign the labels.
imagenet_labels_lower = [label_text.lower() for label_text in imagenet_labels]

my_indices = []
for class_name in class_names:
    wanted = class_name.lower()
    if wanted in imagenet_labels_lower:
        match = imagenet_labels_lower.index(wanted)
    else:
        match = next(
            (index for index, label_text in enumerate(imagenet_labels_lower)
             if wanted in label_text),
            None,
        )
    if match is None:
        raise ValueError(f"No ImageNet label matches Imagenette class {class_name!r}")
    my_indices.append(match)

assert len(my_indices) == len(class_names), "incomplete ImageNet index mapping"
assert len(set(my_indices)) == len(my_indices), "two classes mapped to the same ImageNet index"

In [None]:
# Zero-shot CLIP text prompts, one per class, in the same order as the labels.
prompts = ["a photo of " + name for name in class_names]

In [None]:
# Zero-shot CLIP vs. ImageNet-pretrained ResNet-50 on the Imagenette val split.
# Both models choose among the same 10 candidate classes and every image is
# scored for both, so the two accuracies are computed on identical data and
# are directly comparable.
assert len(my_indices) == len(class_names), "ImageNet index mapping is incomplete"

y_pred_clip = []
y_pred_resnet = []
y_true_clip = []
y_true_resnet = []

with torch.no_grad():
    # Text embeddings depend only on the prompts: encode and L2-normalise once.
    # NOTE(review): relies on get_*_features returning an output object with
    # .pooler_output (as in the transformers version that produced the output
    # below); some versions return the embedding tensor directly.
    text_inputs = processor(text=prompts, return_tensors='pt', padding=True).to("cuda")
    text_features = model.get_text_features(**text_inputs).pooler_output
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # CLIP's learned logit_scale is 1 / temperature, so logits are the cosine
    # similarities divided by the temperature. Loop-invariant: compute once.
    logit_scale = model.logit_scale.exp().item()
    temperature = 1.0 / logit_scale

    for image, label in tqdm.tqdm(dataset, desc='Total'):
        # CLIP: pick the prompt with the highest scaled cosine similarity.
        image_inputs = processor(images=image, return_tensors='pt').to("cuda")
        image_features = model.get_image_features(**image_inputs).pooler_output
        image_features = image_features / image_features.norm(dim=1, keepdim=True)

        similarity = (image_features @ text_features.T) / temperature
        clip_pred = similarity.argmax(dim=1)
        y_pred_clip.append(clip_pred.cpu().item())
        y_true_clip.append(label)

        # ResNet: argmax over the 10 Imagenette logits only. An argmax over all
        # 1000 classes that discarded images whose top-1 fell outside the 10
        # would silently drop the hardest images and inflate the accuracy.
        batch = preprocess(image).unsqueeze(0).to("cuda")
        prediction = resnet_model(batch)
        resnet_pred = prediction[0, my_indices].argmax().item()
        y_pred_resnet.append(resnet_pred)
        y_true_resnet.append(label)

clip_accuracy = accuracy_score(y_true_clip, y_pred_clip)
resnet_accuracy = accuracy_score(y_true_resnet, y_pred_resnet)
print(f"\nResNet Accuracy = {100 * resnet_accuracy:.2f}%\n")
print(f"Clip Accuracy = {100 * clip_accuracy:.2f}%")

Total: 100%|██████████| 3925/3925 [02:29<00:00, 26.32it/s]

ResNet Accuracy = 99.97%

Clip Accuracy = 99.26%