In [1]:
from clip import clip
from clip import model as c_model
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from datasets import clip_coco_retrieval_train, clip_coco_retrieval_eval, flickr_dataset
device = 'cuda:1'
model, preprocess = clip.load("ViT-B/32",device=device,jit=False) #Must set jit=False for training

In [2]:
val_ann_root = '/ltstorage/home/2pan/dataset/COCO/coco_karpathy_val.json'
image_root = '/ltstorage/home/2pan/dataset/COCO/'
val_dataset = clip_coco_retrieval_eval(image_root, val_ann_root, preprocess)
val_dataloader = DataLoader(val_dataset,batch_size = 64, num_workers=4, shuffle=False)

In [45]:
import numpy as np
import torch
@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    
    #Images->Text 
    ranks = np.zeros(scores_i2t.shape[0])
    for index,score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        # Score
        rank = 1e20
        # print(inds)
        for i in img2txt[index]:
            # print(i)
            # print(np.where(inds == i))
            tmp = np.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
  
    #Text->Images 
    ranks = np.zeros(scores_t2i.shape[0])

    print("t2i shape",scores_t2i.shape)
    for index,score in enumerate(scores_t2i): 
        print("index is",index)
        inds = np.argsort(score)[::-1]
        if index == 320:
            print("inds shape",inds)
            print("score is ",score.shape)
        print("txt2img ",txt2img[index]) #64
        print("where",np.where(inds == txt2img[index])) # txt2img能找到64号下标，但是inds只到63
        ranks[index] = np.where(inds == txt2img[index])[0][0]

    # Compute metrics
    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)        

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result =  {'txt_r1': tr1,
                    'txt_r5': tr5,
                    'txt_r10': tr10,
                    'txt_r_mean': tr_mean,
                    'img_r1': ir1,
                    'img_r5': ir5,
                    'img_r10': ir10,
                    'img_r_mean': ir_mean,
                    'r_mean': r_mean}
    return eval_result


In [46]:
import torch
@torch.no_grad()
def small_batch_eval(model, data_loader, device):
    img2txt = data_loader.dataset.img2txt
    txt2img = data_loader.dataset.txt2img
    header = 'Evaluation loss'
    print_freq = 50
    step=0
    eval_result =  {'txt_r1': 0,
                    'txt_r5': 0,
                    'txt_r10': 0,
                    'txt_r_mean': 0,
                    'img_r1': 0,
                    'img_r5': 0,
                    'img_r10': 0,
                    'img_r_mean': 0,
                    'r_mean': 0}
    every_batch_num = []
    for i,(images, index) in enumerate(data_loader):
        step+=1
        captions = []
        for idx in index:
            text_ids = img2txt[idx.item()]
            caption = [data_loader.dataset.text_feat[i] for i in text_ids]
            captions.extend(caption)
            
        every_batch_num.append(len(captions))
        
        text_embeds = torch.cat(captions,dim=0).to(device)
        image_embeds = images.to(device)
        logits_per_image, logits_per_text = model(image_embeds, text_embeds)
        # return logits_per_image.cpu().numpy(), logits_per_text.cpu().numpy()
        i2t = logits_per_image.cpu().numpy()
        t2i = logits_per_text.cpu().numpy()
        # print(t2i.shape)
        # print(t2i[320,63])
        result = itm_eval(i2t, t2i, txt2img, img2txt)
        for key, value in result.items():
            eval_result[key] += value

    # for key, value in eval_result.items():
    #     if key.startswith('txt'):
    #         eval_result[key] = value/ len(img2txt.keys())
    #     if key.startswith('img'):
    #         eval_result[key] = value/ len(txt2img.keys())
    # eval_result['r_mean'] = (eval_result['txt_r_mean'] + eval_result['img_r_mean']) /2
    # print(eval_result)
    print(every_batch_num)
    
small_batch_eval(model, val_dataloader, device)
# i2t, t2i=small_batch_eval(model, val_dataloader, device)
# result = itm_eval(i2t, t2i, val_dataloader.dataset.txt2img, val_dataloader.dataset.img2txt)
# print(result)
len(val_dataloader.dataset.txt2img.keys())
len(val_dataloader.dataset.img2txt.keys())

t2i shape (320, 64)
index is 0
txt2img  0
where (array([0]),)
index is 1
txt2img  0
where (array([0]),)
index is 2
txt2img  0
where (array([0]),)
index is 3
txt2img  0
where (array([0]),)
index is 4
txt2img  0
where (array([0]),)
index is 5
txt2img  1
where (array([1]),)
index is 6
txt2img  1
where (array([1]),)
index is 7
txt2img  1
where (array([0]),)
index is 8
txt2img  1
where (array([0]),)
index is 9
txt2img  1
where (array([0]),)
index is 10
txt2img  2
where (array([0]),)
index is 11
txt2img  2
where (array([0]),)
index is 12
txt2img  2
where (array([0]),)
index is 13
txt2img  2
where (array([0]),)
index is 14
txt2img  2
where (array([0]),)
index is 15
txt2img  3
where (array([1]),)
index is 16
txt2img  3
where (array([0]),)
index is 17
txt2img  3
where (array([0]),)
index is 18
txt2img  3
where (array([0]),)
index is 19
txt2img  3
where (array([0]),)
index is 20
txt2img  4
where (array([1]),)
index is 21
txt2img  4
where (array([2]),)
index is 22
txt2img  4
where (array([0]),)
i

IndexError: index 0 is out of bounds for axis 0 with size 0