In [None]:
!pip install transformers

import numpy as np
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
from data_utils import create_dataset, create_loader

from torch import Tensor

In [None]:
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over dim 1, skipping masked-out positions.

    Args:
        last_hidden_states: Per-position hidden states. The pooling below
            reduces dim 1, so a (batch, seq_len, hidden) layout is expected
            -- TODO confirm against the model in use.
        attention_mask: Mask whose dims line up with the first two dims of
            ``last_hidden_states``; non-zero entries mark positions to keep.

    Returns:
        The sum of the kept positions divided by the number of kept
        positions, with dim 1 reduced away. A row whose mask is entirely
        zero yields an all-zero vector instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    # Clamp the count: a fully masked row would otherwise compute 0 / 0 = NaN,
    # which would then contaminate every similarity score built from it.
    kept_count = attention_mask.sum(dim=1).clamp(min=1)[..., None]
    return summed / kept_count

@torch.no_grad()
def get_feats(model, tokenizer, data_loader, device, desc='Get feats'):
    """Embed every batch produced by ``data_loader`` and stack the results.

    Each batch is handed to ``tokenizer`` as-is (padded, truncated to 512
    tokens, returned as PyTorch tensors), moved to ``device``, run through
    ``model``, and reduced to one vector per input with ``average_pool``.

    Args:
        model: Callable whose output exposes ``last_hidden_state``.
        tokenizer: Callable returning an object that supports ``.to(device)``,
            ``**`` unpacking and an ``'attention_mask'`` entry.
        data_loader: Sized iterable of batches; presumably lists of strings
            -- verify against ``create_loader``.
        device: Target device for the tokenized inputs.
        desc: Label shown on the progress bar.

    Returns:
        All per-batch embeddings concatenated along dim 0.
    """
    def _embed_batch(batch):
        encoded = tokenizer(batch, max_length=512, padding=True,
                            truncation=True, return_tensors='pt').to(device)
        hidden = model(**encoded).last_hidden_state
        return average_pool(hidden, encoded['attention_mask'])

    progress = tqdm(data_loader, total=len(data_loader), desc=desc)
    return torch.cat([_embed_batch(batch) for batch in progress], dim=0)


@torch.no_grad()
def contrast_evaluation(text_embeds, code_embeds, img2txt):
    """Rank candidates for every query and report recall@k and MRR in percent.

    Similarity is the plain dot product ``text_embeds @ code_embeds.T``;
    no normalization is applied here.

    Args:
        text_embeds: Query embeddings, one row per query.
        code_embeds: Candidate embeddings, one row per candidate.
        img2txt: Indexable by query position; ``img2txt[i]`` is the candidate
            index treated as the correct match for query ``i``.

    Returns:
        dict with keys ``'r1'``, ``'r5'``, ``'r10'`` (share of queries whose
        match is ranked in the top 1 / 5 / 10) and ``'mrr'`` (mean reciprocal
        rank), each scaled by 100.
    """
    similarity = (text_embeds @ code_embeds.t()).cpu().numpy()
    num_queries = similarity.shape[0]

    # ranks[i] = 0-based position of the correct candidate in query i's
    # descending-score ordering.
    ranks = np.full(num_queries, -1.0)
    for query_idx in range(num_queries):
        best_first = np.argsort(similarity[query_idx])[::-1]
        ranks[query_idx] = np.flatnonzero(best_first == img2txt[query_idx])[0]

    def _recall_at(k):
        return 100.0 * np.count_nonzero(ranks < k) / num_queries

    return {
        'r1': _recall_at(1),
        'r5': _recall_at(5),
        'r10': _recall_at(10),
        'mrr': 100.0 * np.mean(1 / (ranks + 1)),
    }

In [None]:
# --- Configuration: edit these to evaluate another dataset, language or checkpoint ---
DATASET_ROOT = 'dataset/CSN'
LANGUAGE = 'ruby'
MODEL_NAME = "thenlper/gte-base"
BATCH_SIZE = 256
NUM_WORKERS = 4

print("\nCreating retrieval dataset")
# Only the last two datasets returned by create_dataset are used for this evaluation.
_, _, test_dataset, code_dataset = create_dataset(DATASET_ROOT, LANGUAGE)

test_loader, code_loader = create_loader([test_dataset, code_dataset], [None, None],
                                         batch_size=[BATCH_SIZE, BATCH_SIZE],
                                         num_workers=[NUM_WORKERS, NUM_WORKERS],
                                         is_trains=[False, False],
                                         collate_fns=[None, None])

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

print('\nStart zero-shot evaluation...')
# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA
# instead of failing when the model is moved to the device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

text_embeds = get_feats(model, tokenizer, test_loader, device, desc='Get text feats')
code_embeds = get_feats(model, tokenizer, code_loader, device, desc='Get code feats')
# text2code supplies, per test query, the index of its matching entry in the code set
# (presumably built by create_dataset -- verify in data_utils).
test_result = contrast_evaluation(text_embeds, code_embeds, test_loader.dataset.text2code)
print('\n====> zero-shot test result: ', test_result)