In [1]:
# Path of the checkpoint file that the next cell deserializes with torch.load.
input_model = "models/model_checkpoint_1b.pt"

In [2]:
import torch
import clip

# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# clip.load returns the pretrained model together with its image transform.
# In the cells visible here only `preprocess` is used afterwards.
base_model, preprocess = clip.load("ViT-B/32", device=device)

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints that come from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# refuses to unpickle a whole module object -- confirm the installed version
# and pass weights_only=False explicitly only for trusted files.
#
# With CUDA available, map_location stays None, so tensors are restored to the
# device they were saved from (same as before). Without CUDA, storages are
# remapped to the CPU so that a checkpoint saved on a GPU can still be loaded.
model = torch.load(
    input_model,
    map_location=None if torch.cuda.is_available() else "cpu",
)

# This notebook only evaluates the model, so switch it to inference mode.
# Assigning the result keeps the module repr out of the cell output.
model = model.eval()

In [3]:
# Directory prefix that is prepended to every "image_path" entry of the test
# records (kept with its trailing slash because paths are built by plain
# string concatenation).
images_directory = "dataset/product_images/test/"

# Annotation file that is read line by line, one JSON document per line.
text_file_location = "dataset/product_brands_test.json"

In [4]:
import json

# The annotation file is parsed as JSON Lines: every non-empty line is decoded
# on its own with json.loads and collected into `input_text`.
# The encoding is pinned to UTF-8 so the result does not depend on the
# platform's locale default.
with open(text_file_location, 'r', encoding='utf-8') as text_file:
    # Whitespace-only lines are skipped; json.loads would raise on them.
    input_text = [json.loads(line) for line in text_file if line.strip()]

In [5]:
# Two parallel lists, one entry per annotation record:
#   image_path_list -> images_directory joined (by concatenation) with the
#                      record's "image_path" value
#   text_list       -> the record's 'product_title' value
image_path_list = [images_directory + record["image_path"] for record in input_text]
text_list = [record['product_title'] for record in input_text]

In [6]:
# Sanity check: both lists should report the same number of entries.
print(len(image_path_list), len(text_list), sep="\n")

210
210


In [7]:
from PIL import Image

class image_title_dataset:
    """Indexable dataset that pairs an image file with its text label.

    Images are opened and run through the notebook-level ``preprocess``
    transform lazily, one at a time, when an item is requested.

    Args:
        image_path_list: sequence of image file paths.
        text_list: sequence of labels, aligned index-by-index with
            ``image_path_list``.

    Raises:
        ValueError: if the two sequences differ in length.
    """

    def __init__(self, image_path_list, text_list):
        # Misaligned inputs would silently pair images with the wrong labels
        # (or truncate iteration), so reject them up front.
        if len(image_path_list) != len(text_list):
            raise ValueError(
                "image_path_list and text_list must have the same length "
                f"(got {len(image_path_list)} and {len(text_list)})"
            )
        self.image_path = image_path_list
        self.title = text_list

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.title)

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, label)`` for position ``idx``.

        An out-of-range ``idx`` raises IndexError from the underlying
        sequence, which is also what terminates plain ``for`` iteration over
        this object.
        """
        # The context manager closes the underlying file as soon as the
        # transform has produced its output, instead of leaking one open
        # handle per sample.
        with Image.open(self.image_path[idx]) as raw_image:
            image = preprocess(raw_image)
        title = self.title[idx]
        return image, title

# Pair the test image paths with their labels; each image is only opened and
# preprocessed when the dataset is indexed.
dataset = image_title_dataset(image_path_list, text_list)

In [8]:
# Candidate class names for zero-shot classification. The position of a name
# in this list is the class index used when scoring predictions.
brands = [
    "Ray-Ban",
    "Carrera",
    "Gucci",
    "Versace",
    "Prada",
    "Tommy Hilfiger",
    "Lacoste",
    "U.S. Polo Assn.",
    "DKNY",
    "Polo Ralph Lauren",
    "Nike",
    "Adidas",
    "Puma",
    "Calvin Klein",
    "Reebok",
    "Under Armour",
    "Brooks Brothers",
    "Haimont",
    "ASICS",
    "Saucony",
    "FitVille",
    "Brooks",
    "Skechers",
    "Red Tape",
    "Little Donkey Andy",
    "33,000ft",
    "Columbia",
    "Carhartt",
    "MAGCOMSEN",
    "The North Face",
    "Darn Tough",
    "VRD",
    "G Gradual",
    "Fila",
    "BROKIG",
    "Champion",
    "NORTHYARD",
    "Mizuno",
    "Hurley",
    "Timberland",
]

# Prompt templates; "{}" is replaced by a brand name before tokenization.
templates = [
    "the brand is {}.",
    "the manufacturer is {}.",
    "the item is made by {}.",
    "the product is manufactured by {}.",
]

In [9]:
from tqdm.notebook import tqdm

def zeroshot_weight_calculator(classification_list, templates):
    """Build a zero-shot classifier weight matrix from text prompts.

    For every entry of ``classification_list`` each template is filled in with
    that entry, the prompts are tokenized and encoded with the notebook-level
    ``model``, every prompt embedding is L2-normalized, the embeddings are
    averaged over the templates, and the mean is normalized again.

    Args:
        classification_list: iterable of class names.
        templates: sequence of format strings with a single ``{}`` placeholder.

    Returns:
        A tensor with one column per class (the per-class vectors are stacked
        along ``dim=1``).
    """
    # Tokenization is done without any device argument, so move the tokens
    # explicitly to wherever the model's weights live; otherwise encode_text
    # fails with a device mismatch when the model is on a GPU.
    model_device = next(model.parameters()).device

    with torch.no_grad():
        weights = []
        for element in tqdm(classification_list):
            prompts = [template.format(element) for template in templates]
            tokens = clip.tokenize(prompts).to(model_device)
            text_embeddings = model.encode_text(tokens)
            # Normalize each prompt embedding before averaging ...
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            # ... and re-normalize the mean so every class vector has unit length.
            text_embedding /= text_embedding.norm()
            weights.append(text_embedding)
        nn_weights = torch.stack(weights, dim=1)
    return nn_weights

# One averaged, normalized text embedding per brand, stacked as columns.
zeroshot_brand_weights = zeroshot_weight_calculator(brands, templates)

  0%|          | 0/40 [00:00<?, ?it/s]

In [10]:
# The second dimension should equal the number of brands, because the
# per-class vectors were stacked along dim=1.
print(zeroshot_brand_weights.shape)

torch.Size([512, 40])


In [11]:
def accuracy(output, target, topk=(1,)):
    """Count top-k hits for a batch of class scores.

    Args:
        output: 2-D tensor of scores, one row per sample and one column per
            class.
        target: ground-truth class index/indices (int, sequence of ints, or
            tensor), one per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk``: the number of
        samples whose target is among the ``k`` highest-scoring classes.
    """
    # Never request more candidates than there are classes; topk would raise.
    max_k = min(max(topk), output.size(1))
    # pred has shape (max_k, batch): row r holds every sample's rank-r class.
    pred = output.topk(max_k, dim=1, largest=True, sorted=True)[1].t()
    # Create the target on the same device as the scores so the comparison
    # also works when `output` lives on a GPU; as_tensor avoids the copy
    # warning when a tensor is passed in.
    target = torch.as_tensor(target, device=output.device)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() yields a plain Python number without the deprecated
    # float(<1-element ndarray>) conversion.
    return [float(correct[:k].reshape(-1).float().sum().item()) for k in topk]

In [14]:
import numpy as np

# Inputs have to live on the same device as the model's weights.
eval_device = next(model.parameters()).device

with torch.no_grad():
    # Running hit counts and number of evaluated samples.
    top1, top5, n = 0., 0., 0.

    for images, target in tqdm(dataset):
        # The dataset yields a single preprocessed image; add the batch
        # dimension directly instead of a tensor -> numpy -> tensor round trip.
        image = images.unsqueeze(0).to(eval_device)
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Scaled similarity between the image and every brand column.
        logits = 100. * image_features @ zeroshot_brand_weights
        # NOTE(review): this assumes each dataset label is spelled exactly like
        # an entry of `brands`; list.index raises ValueError otherwise.
        target_index = brands.index(target)

        acc1, acc5 = accuracy(logits, target_index, topk=(1, 5))
        top1 += acc1
        top5 += acc5
        n += image.size(0)

# Fail with a clear message instead of a bare ZeroDivisionError.
if n == 0:
    raise ValueError("dataset is empty; cannot compute accuracy")

# Convert hit counts to percentages.
top1 = (top1 / n) * 100
top5 = (top5 / n) * 100

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")

  0%|          | 0/210 [00:00<?, ?it/s]

Top-1 accuracy: 22.38
Top-5 accuracy: 51.90
