In [None]:
!pip install -U FlagEmbedding

import numpy as np
from tqdm import tqdm
from data_utils import create_dataset, create_loader

from FlagEmbedding import FlagModel

In [None]:
def get_feats(model, data_loader, desc='Get feats'):
    """Encode every batch produced by ``data_loader`` into one flat list of embeddings.

    The encoder method is selected from ``desc``:

    * ``'Get text feats'`` -> ``model.encode_queries`` (query-side encoder).
    * anything else, including ``'Get code feats'`` and the default label,
      -> ``model.encode``.

    Previously an unrecognised ``desc`` (such as the default) matched no branch
    and the function silently returned an empty list.

    Args:
        model: Object exposing ``encode_queries(batch)`` and ``encode(batch)``;
            each call must return an iterable with one embedding per input item.
        data_loader: Sized iterable of batches; every batch is passed to the
            encoder unchanged.
        desc: Progress-bar label, which also selects the encoder as described
            above.

    Returns:
        list: One embedding per input item, in iteration order.
    """
    # The choice of encoder does not depend on the batch, so resolve it once.
    if desc == 'Get text feats':
        encode_batch = model.encode_queries
    else:
        encode_batch = model.encode

    embeds = []
    for text in tqdm(data_loader, total=len(data_loader), desc=desc):
        # Flatten the per-batch result so callers receive one entry per item.
        embeds.extend(encode_batch(text))

    return embeds


def contrast_evaluation(text_embeds, code_embeds, img2txt):
    """Compute retrieval metrics from dot-product scores between two embedding sets.

    Each row of ``text_embeds`` is scored against every row of ``code_embeds``
    with a dot product. Candidates are ordered by descending score and the
    zero-based position of the ground-truth candidate is recorded as the rank.

    Args:
        text_embeds: Sequence of query embedding vectors (converted via ``np.array``).
        code_embeds: Sequence of candidate embedding vectors (converted via ``np.array``).
        img2txt: Indexable by query position; ``img2txt[i]`` is the index of the
            correct candidate for query ``i``.

    Returns:
        dict: ``'r1'``, ``'r5'`` and ``'r10'`` hold the percentage of queries
        whose ground truth is ranked inside the top 1/5/10; ``'mrr'`` holds the
        mean reciprocal rank multiplied by 100.
    """
    similarity = np.array(text_embeds) @ np.array(code_embeds).T

    # Every slot starts at -1 and is overwritten with the ground-truth position.
    ranks = np.full(similarity.shape[0], -1.0)
    for query_idx, row in enumerate(similarity):
        descending = np.argsort(row)[::-1]
        ranks[query_idx] = np.flatnonzero(descending == img2txt[query_idx])[0]

    total = len(ranks)

    def _pct_within(k):
        # Percentage of queries whose ground truth sits in the top-k.
        return 100.0 * np.count_nonzero(ranks < k) / total

    eval_result = {}
    eval_result['r1'] = _pct_within(1)
    eval_result['r5'] = _pct_within(5)
    eval_result['r10'] = _pct_within(10)
    # Ranks are zero-based, hence the +1 before taking the reciprocal.
    eval_result['mrr'] = 100.0 * np.mean(1 / (ranks + 1))
    return eval_result

In [None]:
print("\nCreating retrieval dataset")
# Change the dataset path and the programming language here.
_, _, test_dataset, code_dataset = create_dataset('dataset/CSN', 'ruby')

test_loader, code_loader = create_loader(
    [test_dataset, code_dataset], [None, None],
    batch_size=[256, 256],
    num_workers=[4, 4],
    is_trains=[False, False],
    collate_fns=[None, None],
)

# 'bge-large-en-v1.5' is the English BGE model, so it must be paired with the
# English query instruction from the BGE model card. The Chinese instruction
# belongs to the '-zh' models.
model = FlagModel('BAAI/bge-large-en-v1.5',
                  query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
                  use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation

print('\nStart zero-shot evaluation...')

# Queries and code snippets are encoded separately, then scored against each other.
text_embeds = get_feats(model, test_loader, desc='Get text feats')
code_embeds = get_feats(model, code_loader, desc='Get code feats')
test_result = contrast_evaluation(text_embeds, code_embeds, test_loader.dataset.text2code)
print('\n====> zero-shot test result: ', test_result)