In [None]:
!pip install openai[embeddings]==0.27.7

import numpy as np
from tqdm import tqdm
from data_utils import create_dataset, create_loader

import openai
# Read the OpenAI API key from the OPENAI_API_KEY environment variable instead
# of pasting it into the notebook, so the secret is never saved or committed.
# Falls back to an empty string (the previous placeholder) when it is not set.
import os
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

from openai.embeddings_utils import get_embedding

Defaulting to user installation because normal site-packages is not writeable


In [6]:
def get_feats(data_loader, desc='Get feats', engine='text-embedding-ada-002',
              max_chars=29120):
    """Embed every text yielded by ``data_loader``, one API call per text.

    Args:
        data_loader: Sized iterable of batches (``len()`` is used for the
            progress bar); each batch is an iterable of strings.
        desc: Label shown on the tqdm progress bar.
        engine: Embedding model name forwarded to ``get_embedding``.
        max_chars: Maximum number of leading *characters* kept from each text
            before it is embedded. Note this is a character cap, not a token
            count -- presumably a coarse guard against the model's input
            limit (TODO confirm the value against the chosen engine).

    Returns:
        list: One embedding per text, in iteration order across all batches.
    """
    embeds = []

    for batch in tqdm(data_loader, total=len(data_loader), desc=desc):
        for txt in batch:
            # Truncate before sending so over-long inputs do not reach the API.
            embeds.append(get_embedding(txt[:max_chars], engine=engine))

    return embeds


def contrast_evaluation(text_embeds, code_embeds, img2txt):
    """Score text-to-code retrieval from two sets of embeddings.

    Each text embedding is compared with every code embedding by dot product,
    candidates are ordered from highest to lowest score, and the 0-based
    position of the expected candidate is recorded as that query's rank.

    Args:
        text_embeds: Sequence of query embedding vectors (one per query).
        code_embeds: Sequence of candidate embedding vectors.
        img2txt: Indexable mapping from query position to the index of the
            expected candidate in ``code_embeds``.

    Returns:
        dict: ``'r1'``, ``'r5'``, ``'r10'`` (percentage of queries whose rank
        is below 1 / 5 / 10) and ``'mrr'`` (mean reciprocal rank, times 100).
    """
    queries = np.array(text_embeds)
    candidates = np.array(code_embeds)
    similarity = queries @ candidates.T

    # Start every rank at -1 and overwrite it row by row below.
    ranks = np.full(similarity.shape[0], -1.0)
    for row, row_scores in enumerate(similarity):
        best_first = np.argsort(row_scores)[::-1]
        ranks[row] = np.nonzero(best_first == img2txt[row])[0][0]

    def recall_at(k):
        # Percentage of queries whose expected candidate is in the top k.
        return 100.0 * int((ranks < k).sum()) / len(ranks)

    r1 = recall_at(1)
    r5 = recall_at(5)
    r10 = recall_at(10)
    reciprocal_ranks = 1 / (ranks + 1)

    return {
        'r1': r1,
        'r5': r5,
        'r10': r10,
        'mrr': 100.0 * np.mean(reciprocal_ranks),
    }

In [None]:
# Driver cell: build the retrieval loaders, embed both sides with the OpenAI
# API, and print the retrieval metrics.
print("\nCreating retrieval dataset")
# Change the dataset root path and the language here.
# create_dataset returns four values; the first two are not used in this cell.
_, _, test_dataset, code_dataset = create_dataset('dataset/CSN', 'ruby')

# One loader for the test dataset and one for the code dataset, both built
# with is_trains=False, batch size 128 and no custom collate function.
test_loader, code_loader = create_loader([test_dataset, code_dataset], [None, None],
                                         batch_size=[128, 128],
                                         num_workers=[4, 4], is_trains=[False, False], collate_fns=[None, None])


print('\nStart zero-shot evaluation...')

# get_feats makes one embedding request per text, so these two calls dominate
# the runtime of the cell.
text_embeds = get_feats(test_loader, desc='Get text feats')
code_embeds = get_feats(code_loader, desc='Get code feats')
# NOTE(review): text2code is presumably the mapping from each query index to
# the index of its ground-truth snippet in code_dataset -- confirm in data_utils.
test_result = contrast_evaluation(text_embeds, code_embeds, test_loader.dataset.text2code)

print(f'\n====> zero-shot test result: ', test_result)


Creating retrieval dataset
Read 24927 data from ./dataset/CSN/ruby/train.jsonl
Read 1400 data from ./dataset/CSN/ruby/valid.jsonl
Read 4360 data from ./dataset/CSN/ruby/codebase.jsonl
Read 1261 data from ./dataset/CSN/ruby/test.jsonl
Read 4360 data from ./dataset/CSN/ruby/codebase.jsonl
Read 4360 data from ./dataset/CSN/ruby/codebase.jsonl

Start zero-shot evaluation...


Get text feats: 100%|██████████████████████████████████████████████████████████████████| 10/10 [04:44<00:00, 28.48s/it]
Get code feats:   3%|█▉                                                                 | 1/35 [00:47<26:39, 47.04s/it]