### Initialize clip

In [1]:
import evaluation

model = evaluation.model
preprocess = evaluation.preprocess

### Get images and titles for evaluation

In [2]:
images_path = evaluation.images_path
texts_file = evaluation.text_file_titles

texts_list = evaluation.read_text(texts_file)
images, titles = evaluation.get_image_title(texts_list)

print(len(images), "|", len(titles))

210 | 210


### Create dataset

In [3]:
dataset = evaluation.image_title_dataset(images, titles)

### brand weight

In [5]:
brands = evaluation.brands
brands_template = evaluation.brands_template

categories = evaluation.categories
categories_template = evaluation.categories_template

colors = evaluation.colors
colors_template = evaluation.colors_template

zeroshot_brands_weights = evaluation.zeroshot_weight_calculator_tmpl(brands, brands_template)
zeroshot_categories_weights = evaluation.zeroshot_weight_calculator_tmpl(categories, categories_template)
zeroshot_colors_weights = evaluation.zeroshot_weight_calculator_tmpl(colors, colors_template)


print(zeroshot_brands_weights.shape)
print(zeroshot_categories_weights.shape)
print(zeroshot_colors_weights.shape)

torch.Size([512, 40])
torch.Size([512, 10])
torch.Size([512, 12])


### calculate accuracy

In [7]:
import numpy as np
import torch
from tqdm.notebook import tqdm

with torch.no_grad():
    top1, top5, n = 0., 0., 0.
    
    for i, (img, target) in enumerate(tqdm(dataset)):        
        
        image = []
        image.append(img)
        image = torch.tensor(np.stack(image))
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        
        brand_logits = 100. * image_features @ zeroshot_brands_weights
        category_logits = 100. * image_features @ zeroshot_categories_weights
        color_logits = 100. * image_features @ zeroshot_colors_weights
        
        brand_probs = brand_logits.softmax(dim=-1)        
        category_probs = category_logits.softmax(dim=-1)        
        color_probs = color_logits.softmax(dim=-1)
                
        
        top_brand_probs, top_brand_indexes = brand_probs.cpu().topk(5, dim=-1)
        top_category_probs, top_category_indexes = category_probs.cpu().topk(2, dim=-1)
        top_color_probs, top_color_indexes = color_probs.cpu().topk(2, dim=-1)
                
        predicted_brands = evaluation.top_elemets(brands, top_brand_indexes)
        predicted_categories = evaluation.top_elemets(categories, top_category_indexes)
        predicted_colors = evaluation.top_elemets(colors, top_color_indexes)
        
        predicted_titles = []
        l=j=k=0

        for l in range(len(predicted_brands)):
            for j in range(len(predicted_categories)):
                for k in range(len(predicted_colors)):
                    predicted_titles.append(predicted_brands[l] + " | " + predicted_categories[j] + " | " + predicted_colors[k])
        
        zeroshot_title_weights = evaluation.zeroshot_weight_calculator(predicted_titles)        
        predicted_title_logits = 100. * image_features @ zeroshot_title_weights        
        
        try:
            target_index = predicted_titles.index(target)
            acc1, acc5 = evaluation.accuracy(predicted_title_logits, target_index, topk=(1,4))
            top1 += acc1
            top5 += acc5
            
        except Exception as e:
            print(f"loop number {i+1}  ","**** ",e," ****")
        
        n += image.size(0)
        
print("**** value of n: ",n," ****")    
        
top1 = (top1 / n) * 100
top5 = (top5 / n) * 100 

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")

  0%|          | 0/210 [00:00<?, ?it/s]

loop number 6   ****  'Gucci | Sunglasses | Violet' is not in list  ****
loop number 7   ****  'Versace | Sunglasses | Brown' is not in list  ****
loop number 8   ****  'Versace | Sunglasses | Green' is not in list  ****
loop number 11   ****  'Tommy Hilfiger | Sunglasses | Grey' is not in list  ****
loop number 14   ****  'Lacoste | Sunglasses | Blue' is not in list  ****
loop number 15   ****  'U.S. Polo Assn. | Sunglasses | Brown' is not in list  ****
loop number 17   ****  'DKNY | Sunglasses | Blue' is not in list  ****
loop number 18   ****  'DKNY | Sunglasses | Green' is not in list  ****
loop number 21   ****  'Polo Ralph Lauren | Sunglasses | Brown' is not in list  ****
loop number 30   ****  'Puma | T-Shirt | Grey' is not in list  ****
loop number 33   ****  'Calvin Klein | T-Shirt | Grey' is not in list  ****
loop number 35   ****  'Reebok | T-Shirt | Orange' is not in list  ****
loop number 38   ****  'Brooks Brothers | T-Shirt | Black' is not in list  ****
loop number 40   