In [None]:
import clip
import json
import numpy as np
import torch
from torchvision import datasets

: 

In [2]:
model, preprocess = clip.load("ViT-B/32", device="cpu")

In [4]:
dataset = datasets.ImageFolder(
    root='dataset/mini-ImageNet',
    transform=preprocess,
)

with open('dataset/mini-ImageNet/imagenet_class_index.json') as f:
    class_index = json.load(f)

folder_to_class = {value[0]: value[1] for key, value in class_index.items()}
all_classes = [folder_to_class[folder] for folder in dataset.classes]
all_classes

['house_finch',
 'robin',
 'triceratops',
 'green_mamba',
 'harvestman',
 'toucan',
 'goose',
 'jellyfish',
 'nematode',
 'king_crab',
 'dugong',
 'Walker_hound',
 'Ibizan_hound',
 'Saluki',
 'golden_retriever',
 'Gordon_setter',
 'komondor',
 'boxer',
 'Tibetan_mastiff',
 'French_bulldog',
 'malamute',
 'dalmatian',
 'Newfoundland',
 'miniature_poodle',
 'white_wolf',
 'African_hunting_dog',
 'Arctic_fox',
 'lion',
 'meerkat',
 'ladybug',
 'rhinoceros_beetle',
 'ant',
 'black-footed_ferret',
 'three-toed_sloth',
 'rock_beauty',
 'aircraft_carrier',
 'ashcan',
 'barrel',
 'beer_bottle',
 'bookshop',
 'cannon',
 'carousel',
 'carton',
 'catamaran',
 'chime',
 'clog',
 'cocktail_shaker',
 'combination_lock',
 'crate',
 'cuirass',
 'dishrag',
 'dome',
 'electric_guitar',
 'file',
 'fire_screen',
 'frying_pan',
 'garbage_truck',
 'hair_slide',
 'holster',
 'horizontal_bar',
 'hourglass',
 'iPod',
 'lipstick',
 'miniskirt',
 'missile',
 'mixing_bowl',
 'oboe',
 'organ',
 'parallel_bars',
 '

In [5]:
adjusted_labels = []
for label in all_classes:
    adjusted_labels.append(f'the image of a {label}')

token_text = clip.tokenize(adjusted_labels).to("cpu")

In [6]:
true_labels = []
pred_labels = []
checkpoint_interval = 100

with torch.no_grad():
    text_features = model.encode_text(token_text)

    for i in range(0, 10000):
        image, label = dataset[i]
        image = image.unsqueeze(0)
        image = image.to("cpu")
        
        with torch.no_grad():
            image_features = model.encode_image(image)

        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        prediction = int(similarity.argmax(dim=-1).cpu().numpy()[0])

        # print(f"Label: {label} \nPrediction: {prediction}")
        true_labels.append(label)
        pred_labels.append(prediction)

        print(f'{i} | True: {all_classes[label]} | Predicted: {all_classes[prediction]}')

        if i % checkpoint_interval == 0:
            # Save the true and predicted labels
            np.save(f'results/true_labels.npy', true_labels)
            np.save(f'results/pred_labels.npy', pred_labels)

0 | True: house_finch | Predicted: house_finch
1 | True: house_finch | Predicted: house_finch
2 | True: house_finch | Predicted: house_finch
3 | True: house_finch | Predicted: house_finch
4 | True: house_finch | Predicted: house_finch
5 | True: house_finch | Predicted: house_finch
6 | True: house_finch | Predicted: house_finch
7 | True: house_finch | Predicted: house_finch
8 | True: house_finch | Predicted: house_finch
9 | True: house_finch | Predicted: yawl
10 | True: house_finch | Predicted: house_finch
11 | True: house_finch | Predicted: house_finch
12 | True: house_finch | Predicted: yawl
13 | True: house_finch | Predicted: house_finch
14 | True: house_finch | Predicted: house_finch
15 | True: house_finch | Predicted: house_finch
16 | True: house_finch | Predicted: house_finch
17 | True: house_finch | Predicted: house_finch
18 | True: house_finch | Predicted: house_finch
19 | True: house_finch | Predicted: house_finch
20 | True: house_finch | Predicted: house_finch
21 | True: house

In [8]:
np.save(f'results/true_labels.npy', true_labels)
np.save(f'results/pred_labels.npy', pred_labels)