In [1]:
# Path of the checkpoint file that is later handed to torch.load.
input_model = "models/model_checkpoint_1a.pt"

In [2]:
import torch
import clip

# Prefer the first CUDA device when one is available, otherwise use the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Stock CLIP ViT-B/32 plus its image transform; `preprocess` is what the
# dataset class applies to every image. `base_model` is not referenced again
# in the visible cells.
base_model, preprocess = clip.load("ViT-B/32", device=device)

# Load the checkpoint that the rest of the notebook evaluates.
# NOTE(review): no map_location is passed and `device` is not applied to
# `model` here -- confirm the checkpoint ends up on the intended device.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = torch.load(input_model)

In [3]:
# Directory prefix that is concatenated with each record's "image_path".
images_directory = 'dataset/product_images/test/'
# File read line by line, one JSON object per line.
text_file_location = 'dataset/product_titles_test.json'

In [4]:
import json

# Read the test records: the file holds one JSON object per line.
# The encoding is pinned to UTF-8 so non-ASCII titles decode the same way on
# every platform (the default would follow the machine's locale), and blank
# lines (for example a trailing newline at the end of the file) are skipped
# instead of crashing json.loads.
with open(text_file_location, 'r', encoding='utf-8') as text_file:
    input_text = [json.loads(line) for line in text_file if line.strip()]

In [5]:
# Split the parsed records into two parallel lists: full image paths
# (directory prefix + relative path) and the matching product titles.
image_path_list = [images_directory + record["image_path"] for record in input_text]
text_list = [record['product_title'] for record in input_text]

# Both counts should agree: one title per image.
print(len(image_path_list))
print(len(text_list))

210
210


In [6]:
from PIL import Image

class image_title_dataset:
  """Index-addressable collection of (preprocessed image, title) pairs.

  Args:
    image_path_list: paths of the image files, one per sample.
    text_list: titles, aligned index by index with ``image_path_list``.

  Raises:
    ValueError: if the two lists have different lengths.
  """

  def __init__(self, image_path_list, text_list):
    # Misaligned inputs would otherwise silently shorten iteration or pair
    # images with the wrong titles.
    if len(image_path_list) != len(text_list):
      raise ValueError(
        "image_path_list and text_list must have the same length, got "
        f"{len(image_path_list)} and {len(text_list)}"
      )
    self.image_path = image_path_list
    self.title = text_list

  def __len__(self):
    """Return the number of samples (the number of titles)."""
    return len(self.title)

  def __getitem__(self, idx):
    """Return ``(image, title)`` for sample ``idx``.

    The image is opened from disk and passed through the module-level
    ``preprocess`` transform (the one returned by ``clip.load``).
    """
    # The context manager releases the file handle as soon as the image has
    # been transformed, instead of leaving it open until garbage collection.
    with Image.open(self.image_path[idx]) as raw_image:
      image = preprocess(raw_image)
    title = self.title[idx]
    return image, title

# Test-set dataset: indexing it yields a (preprocessed image, title) pair.
dataset = image_title_dataset(image_path_list, text_list)

In [7]:
# Candidate labels and prompt templates for the three attributes. Each label
# is substituted into every template of its group ("{}" placeholder) by
# zeroshot_weight_calculator.

# Brand labels and the sentences used to describe a brand.
brand_list = ["Ray-Ban", "Carrera", "Gucci", "Versace", "Prada", "Tommy Hilfiger", "Lacoste", "U.S. Polo Assn.", "DKNY", "Polo Ralph Lauren", "Nike", "Adidas", "Puma", "Calvin Klein", "Reebok", "Under Armour", "Brooks Brothers", "Haimont", "ASICS", "Saucony", "FitVille", "Brooks", "Skechers", "Red Tape", "Little Donkey Andy", "33,000ft", "Columbia", "Carhartt", "MAGCOMSEN", "The North Face", "Darn Tough", "VRD", "G Gradual", "Fila", "BROKIG", "Champion", "NORTHYARD", "Mizuno", "Hurley", "Timberland"]
brand_templates = [
    'the brand is {}.',
    'the manufacturer is {}.',
    'the item is made by {}.',
    'the product is manufactured by {}.'
]

# Product category labels and the sentences used to describe a category.
category_list = ["Sunglasses", "T-Shirt", "Shoes", "Jacket", "Socks", "Track Pant", "Shorts", "Cap", "Bag", "Beanie"]
category_templates = [
    'the product is {}.',
    'the product is called {}.',
    'the item is identified as {}.',
    'the item is sold as {}.'
]

# Color labels and the sentences used to describe a color.
color_list = ["Black", "White", "Grey", "Brown", "Red", "Green", "Blue", "Orange", "Yellow", "Pink", "Violet", "Purple"]
color_templates = [
    'the color is {}.',
    'the item is {} in hue.',
    'the product is {} in color.',
    'the shade of the product is {}.'
]

In [8]:
from tqdm.notebook import tqdm

def zeroshot_weight_calculator(classification_list, templates):
    """Build a zero-shot classifier weight matrix from prompt templates.

    Every label is substituted into every template, the resulting sentences
    are encoded with ``model.encode_text``, each embedding is L2-normalised,
    the embeddings of one label are averaged, and the average is normalised
    again.

    Args:
        classification_list: labels to classify against.
        templates: format strings with a single ``{}`` placeholder.

    Returns:
        Tensor with one column per label, i.e. shape
        ``(embedding_dim, len(classification_list))``.
    """
    # Tokens are created on the CPU by clip.tokenize; move them to wherever
    # the model's parameters live so encode_text also works on a GPU model.
    model_device = next(model.parameters()).device

    with torch.no_grad():
        weights = []
        for element in tqdm(classification_list):
            texts = [template.format(element) for template in templates]
            texts = clip.tokenize(texts).to(model_device)
            text_embeddings = model.encode_text(texts)
            # Normalise each prompt embedding before averaging so every
            # template contributes equally.
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            # The mean of unit vectors is not unit length; normalise again.
            text_embedding /= text_embedding.norm()
            weights.append(text_embedding)
        # Stack label vectors as columns so `features @ weights` gives logits.
        nn_weights = torch.stack(weights, dim=1)
    return nn_weights

# One weight matrix per attribute, each built from its label list and templates.
zeroshot_brand_weights = zeroshot_weight_calculator(brand_list, brand_templates)
zeroshot_category_weights = zeroshot_weight_calculator(category_list, category_templates)
zeroshot_color_weights = zeroshot_weight_calculator(color_list, color_templates)

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

In [9]:
# Show the shape of each attribute's weight matrix (brand, category, color).
for weight_matrix in (zeroshot_brand_weights, zeroshot_category_weights, zeroshot_color_weights):
    print(weight_matrix.shape)

torch.Size([512, 40])
torch.Size([512, 10])
torch.Size([512, 12])


In [10]:
def top_elemets(attribute, indexes):
    """Map the first row of ``indexes`` to the matching entries of ``attribute``.

    Args:
        attribute: sequence of labels.
        indexes: nested sequence (or 2-D tensor) of positions; only row 0 is read.

    Returns:
        List of labels, in the same order as the positions in ``indexes[0]``.
    """
    first_row = indexes[0]
    return [attribute[position] for position in first_row]

In [11]:
def accuracy(output, target, topk=(1,)):
    """Count top-k hits for a batch of logits.

    Args:
        output: 2-D tensor of scores, one row per sample and one column per class.
        target: correct class index (int), or a sequence/tensor of indices with
            one entry per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        List of floats, one per entry of ``topk``: the number of samples whose
        target is among the k highest-scoring classes.
    """
    # Never ask for more classes than exist; topk would raise otherwise.
    max_k = min(max(topk), output.size(1))
    pred = output.topk(max_k, 1, True, True)[1].t()
    # Build the target on the same device as the predictions so the
    # comparison also works when the logits live on a GPU.
    target = torch.as_tensor(target, device=output.device)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() yields a Python float directly, avoiding the deprecated
    # float() conversion of a one-element NumPy array.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]

In [12]:
def zeroshot_weight_calculator_without_templ(title_batch):
    """Encode raw titles (no prompt templates) into classifier weights.

    Args:
        title_batch: list of title strings.

    Returns:
        Tensor with one L2-normalised column per title, i.e. the transpose of
        the ``model.encode_text`` output.
    """
    # clip.tokenize returns CPU tensors; move them to the model's device so
    # encode_text also works when the model sits on a GPU.
    model_device = next(model.parameters()).device

    tokens = clip.tokenize(title_batch).to(model_device)
    class_embeddings = model.encode_text(tokens)
    class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
    # Transpose so titles become columns and `features @ weights` gives logits.
    nn_weights = class_embeddings.T

    return nn_weights

In [13]:
import numpy as np

# Evaluate the checkpoint one image at a time:
#   1. score the image against the brand / category / color weight matrices,
#   2. keep the top 5 brands, top 2 categories and top 2 colors,
#   3. combine them into 5 * 2 * 2 = 20 candidate titles,
#   4. re-rank the candidates against the image and check where the true
#      title lands (top-1 / top-5).
with torch.no_grad():
    top1, top5, n = 0., 0., 0.
    # Inputs have to live on the same device as the model's parameters.
    model_device = next(model.parameters()).device

    for i, (images, target) in enumerate(tqdm(dataset)):

        # The dataset yields a single preprocessed image; add the batch
        # dimension directly instead of round-tripping through NumPy.
        image = torch.as_tensor(images).unsqueeze(0).to(model_device)
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        brand_logits = 100. * image_features @ zeroshot_brand_weights
        category_logits = 100. * image_features @ zeroshot_category_weights
        color_logits = 100. * image_features @ zeroshot_color_weights

        brand_probs = brand_logits.softmax(dim=-1)
        category_probs = category_logits.softmax(dim=-1)
        color_probs = color_logits.softmax(dim=-1)

        top_brand_probs, top_brand_indexes = brand_probs.cpu().topk(5, dim=-1)
        top_category_probs, top_category_indexes = category_probs.cpu().topk(2, dim=-1)
        top_color_probs, top_color_indexes = color_probs.cpu().topk(2, dim=-1)

        predicted_brands = top_elemets(brand_list, top_brand_indexes)
        predicted_categories = top_elemets(category_list, top_category_indexes)
        predicted_colors = top_elemets(color_list, top_color_indexes)

        # Every brand/category/color combination, brand varying slowest.
        predicted_titles = [
            brand + " | " + category + " | " + color
            for brand in predicted_brands
            for category in predicted_categories
            for color in predicted_colors
        ]

        zeroshot_title_weights = zeroshot_weight_calculator_without_templ(predicted_titles)
        predicted_title_logits = 100. * image_features @ zeroshot_title_weights

        # A true title that is not among the candidates counts as a miss:
        # n still grows below, top1/top5 do not. Only that specific lookup
        # failure is handled here, so a genuine error inside accuracy() is
        # no longer swallowed and reported as a miss.
        try:
            target_index = predicted_titles.index(target)
        except ValueError as e:
            print(f"loop number {i+1}  ","**** ",e," ****")
        else:
            # Score on the CPU in float32 so the comparison inside accuracy()
            # does not depend on the device or precision the model runs in.
            acc1, acc5 = accuracy(predicted_title_logits.float().cpu(), target_index, topk=(1,5))
            top1 += acc1
            top5 += acc5

        n += image.size(0)

print("**** value of n: ",n," ****")

# Convert hit counts into percentages over all evaluated images.
top1 = (top1 / n) * 100
top5 = (top5 / n) * 100

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")

  0%|          | 0/210 [00:00<?, ?it/s]

loop number 1   ****  'Ray-Ban | Sunglasses | Blue' is not in list  ****
loop number 2   ****  'Ray-Ban | Sunglasses | Black' is not in list  ****
loop number 4   ****  'Carrera | Sunglasses | Blue' is not in list  ****
loop number 5   ****  'Gucci | Sunglasses | Grey' is not in list  ****
loop number 6   ****  'Gucci | Sunglasses | Brown' is not in list  ****
loop number 7   ****  'Versace | Sunglasses | Black' is not in list  ****
loop number 8   ****  'Versace | Sunglasses | Grey' is not in list  ****
loop number 9   ****  'Prada | Sunglasses | Black' is not in list  ****
loop number 10   ****  'Prada | Sunglasses | Brown' is not in list  ****
loop number 11   ****  'Tommy Hilfiger | Sunglasses | Grey' is not in list  ****
loop number 12   ****  'Tommy Hilfiger | Sunglasses | Blue' is not in list  ****
loop number 13   ****  'Lacoste | Sunglasses | Blue' is not in list  ****
loop number 14   ****  'Lacoste | Sunglasses | Black' is not in list  ****
loop number 15   ****  'U.S. Polo 

loop number 131   ****  'Puma | Shorts | Yellow' is not in list  ****
loop number 132   ****  'Puma | Shorts | Pink' is not in list  ****
loop number 133   ****  'Reebok | Shorts | Black' is not in list  ****
loop number 134   ****  'Reebok | Shorts | Red' is not in list  ****
loop number 135   ****  'Under Armour | Shorts | Green' is not in list  ****
loop number 136   ****  'Under Armour | Shorts | Grey' is not in list  ****
loop number 137   ****  'Haimont | Shorts | Orange' is not in list  ****
loop number 139   ****  'G Gradual | Shorts | Red' is not in list  ****
loop number 140   ****  'G Gradual | Shorts | White' is not in list  ****
loop number 144   ****  'NORTHYARD | Shorts | Yellow' is not in list  ****
loop number 145   ****  'Mizuno | Shorts | Red' is not in list  ****
loop number 146   ****  'Mizuno | Shorts | Blue' is not in list  ****
loop number 147   ****  'Mizuno | Shorts | Grey' is not in list  ****
loop number 150   ****  'Nike | Cap | Grey' is not in list  ****
l