In [None]:
import os
# Restrict this process to physical GPU 6. The assignment is placed before the
# imports below so it is already in effect when they run.
# NOTE(review): the GPU index is machine-specific — adjust for your host.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# NOTE(review): names used throughout this notebook that are never imported
# explicitly (hparams, tfms, torch, F, DataLoader, ImageNet, seed_everything,
# compute_label_encodings, ...) appear to come from this star import — confirm
# against load.py.
from load import *
import torchmetrics
from tqdm import tqdm
import clip

# Seed with the configured value before any data loading or model work.
# presumably seeds the Python/NumPy/torch RNGs — verify in load.py
seed_everything(hparams['seed'])

In [2]:
# Evaluation data. Keep exactly one `dataset = ...` line active; the
# commented-out lines below select the alternative benchmarks.
dataset = ImageNet(hparams['data_dir'], split='val', transform=tfms)
# dataset = CUBDataset(hparams['data_dir'], train=False, transform=tfms)
# dataset = torchvision.datasets.OxfordIIITPet(root=hparams['data_dir'], transform=tfms, split='test')
# dataset = test_set # EuroSAT
# dataset = torchvision.datasets.Food101(root=hparams['data_dir'], transform=tfms, split='test')

# Batch size comes from the shared hyper-parameter dict.
bs = hparams['batch_size']

# Batches are drawn in shuffled order; the accuracy metrics are accumulated
# over every batch, so the order does not change the final totals.
dataloader = DataLoader(
    dataset,
    batch_size=bs,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
print("Loading model...")

# Target device for the model and every tensor used during evaluation.
device = torch.device(hparams['device'])

# Build the CLIP model of the configured size directly on `device`
# (non-JIT variant, so its weights can be replaced below).
model, preprocess = clip.load(hparams['model_size'], device=device, jit=False)

# Replace the weights with the ones stored in the checkpoint.
# map_location=device remaps the stored tensors onto the evaluation device;
# without it torch.load restores them to the device recorded in the file,
# which fails when that CUDA index is not visible to this process.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load("/path/to/your/model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Inference only: switch to eval mode and freeze all parameters.
model.eval()
model.requires_grad_(False)

In [None]:
print("Encoding descriptions...")

# Encode the class labels once, up front, with the loaded model. The result is
# the matrix that each batch of image encodings is multiplied against in the
# evaluation loop (`image_encodings @ label_encodings.T`).
# presumably one row per class, in dataset label order — verify in load.py
label_encodings = compute_label_encodings(model)

In [None]:
print("Evaluating...")

# Running top-1 accuracy, accumulated across batches on the evaluation device.
clip_accuracy_metric = torchmetrics.Accuracy()
clip_accuracy_metric = clip_accuracy_metric.to(device)

# Running top-5 accuracy: a sample counts as correct when the true label is
# among the five highest-scoring classes.
clip_accuracy_metric_top5 = torchmetrics.Accuracy(top_k=5)
clip_accuracy_metric_top5 = clip_accuracy_metric_top5.to(device)

In [None]:
# Evaluation loop: encode every batch of images, score it against the
# pre-computed label encodings and accumulate top-1 / top-5 accuracy.
# Autograd is disabled for the whole pass since nothing here is trained.
with torch.no_grad():
    for images, labels in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)

        # Unit-normalise each image embedding (F.normalize defaults: L2 norm
        # along dim=1).
        image_encodings = model.encode_image(images)
        image_encodings = F.normalize(image_encodings)

        # One score per (image, label) pair.
        # NOTE(review): this is a cosine similarity only if the rows of
        # label_encodings are unit-norm as well — confirm in
        # compute_label_encodings.
        image_labels_similarity = image_encodings @ label_encodings.T

        # Only the running totals are needed, so use update() rather than
        # calling the metric, which would also compute a per-batch value that
        # is never used.
        clip_accuracy_metric.update(image_labels_similarity, labels)
        clip_accuracy_metric_top5.update(image_labels_similarity, labels)


print("\n")

# Final accuracies over everything seen in the loop, as percentages.
accuracy_logs = {}
accuracy_logs["Total CLIP-Standard Top-1 Accuracy: "] = 100*clip_accuracy_metric.compute().item()
accuracy_logs["Total CLIP-Standard Top-5 Accuracy: "] = 100*clip_accuracy_metric_top5.compute().item()

# print the dictionary
print("\n")
for key, value in accuracy_logs.items():
    print(key, value)