In [1]:
import os
import csv
import pathlib
import json
import gzip
import logging
import pickle
import time
from typing import List, Tuple, Dict, Iterator

import numpy as np
import torch
from torch import Tensor as T
from torch import nn

from dpr.data.qa_validation import calculate_matches_by_id
from dpr.models import init_biencoder_components
from dpr.options import (
    add_encoder_params, 
    setup_args_gpu, 
    print_args, 
    set_encoder_params_from_state, 
    add_tokenizer_params, 
    add_cuda_params
)
from dpr.utils.data_utils import Tensorizer
from dpr.utils.model_utils import (
    setup_for_distributed_mode, 
    get_model_obj, 
    load_states_from_checkpoint, 
    move_to_device
)
from sklearn.metrics.pairwise import cosine_similarity
import nltk 
from tqdm.notebook import tqdm
import math
import argparse
import copy

In [2]:
api_lists = json.load(open("data/api_list.json"))

def get_test_data(api_lists):
    examples = []
    fail_count = 0
    example_id = 1
    for i in range(1, 4):
        example_file = f"../25_K_Examples/part-{i}-output/taken_answers_with_all_details.json"
        data = json.load(open(example_file))
        for e in data:
            try:
                ques_id = e['question_id']
                qtitle = e['formatted_input']['question']['title']
                qdesc = e['formatted_input']['question']['ques_desc']
                codes = e['formatted_input']['answer']['code']
                apis = set()
                for c in codes:
                    tokens = nltk.wordpunct_tokenize(c)
                    for tidx, token in enumerate(tokens):
                        token = token.strip()
                        if tidx >= 0:
                            prev_token = tokens[tidx - 1].strip()[-1]
                            if (token in api_lists and prev_token == ".") or token == "DataFrame":
                                apis.add(token)
                api_seq = list(sorted(apis))
                if len(api_seq) <= 0:
                    continue
                examples.append({
                    'id': ques_id,
                    'query': qtitle.strip().lower() + " " + qdesc.strip().lower(),
                    "apis": api_seq,
                    'link': e['link'],
                    "example": e['formatted_input']
                })
            except Exception as ex:
                print(ex)
                fail_count += 1
    return examples

test_examples = get_test_data(list(api_lists.keys()))
print(len(test_examples))


608


In [3]:
class RetrieverModel:
    def __init__(self, model_path, batch_size=64, quiet=False, no_cuda=False):
        parser = argparse.ArgumentParser()
        add_encoder_params(parser)
        add_tokenizer_params(parser)
        add_cuda_params(parser)
        parser.add_argument(
            '--shard_size', 
            type=int, 
            default=50000, 
            help="Total amount of data in 1 shard"
        )
        parser.add_argument(
            '--batch_size', 
            type=int, 
            default=32, 
            help="Batch size for the passage encoder forward pass"
        )
        parser.add_argument(
            '--dataset', 
            type=str, 
            default=None, 
            help=' to build correct dataset parser '
        )

        self.args = parser.parse_args({})
        self.quiet = quiet
        self.args.model_file = model_path
        setup_args_gpu(self.args)
        if no_cuda:
            self.args.device = torch.device("cpu")
        saved_state = load_states_from_checkpoint(self.args.model_file)
        set_encoder_params_from_state(
            saved_state.encoder_params, 
            self.args,
            quiet=self.quiet
        )
        self.batch_size = batch_size
        
        self.tensorizer, self.encoder, _ = init_biencoder_components(
            self.args.encoder_model_type, 
            self.args, 
            inference_only=True
        )
        self.encoder.load_state_dict(saved_state.model_dict)
        self.query_model = self.encoder.question_model
        self.document_model = self.encoder.ctx_model
        
        self.api_lists = json.load(open("data/api_list.json"))

        self.apis = list(sorted(self.api_lists.keys()))
        self.api_docs = [self.api_lists[a] for a in self.apis]

        _, _, _, self.doc_vectors = self.generate_query_vectors()
    
    def generate_query_vectors(self):
        return self.generate_vectors(
            model=self.document_model, 
            sentences=self.api_docs,
            batch_size=self.batch_size,
            task='"API_VECTORS"'
        )
    
    def generate_vectors(self, model, sentences, batch_size, task):
        if not self.quiet:
            print(
                "Generating vectors for %d sentences using %s task model" % (
                    len(sentences), 
                    task
                )
            )
        tensors = []
        for ex in sentences:
            tensor = self.tensorizer.text_to_tensor(ex)
            tensors.append(tensor)
        ids = torch.stack(tensors, dim=0)
        seg_batch = torch.zeros_like(ids)
        attn_mask = self.tensorizer.get_attn_mask(ids)
        model.to(self.args.device)
        l = ids.size(0)
        start_idx = 0
        vectors = [] * l
        num_batches = math.ceil(l / batch_size)
        with torch.no_grad():
            batches = range(num_batches) if self.quiet else tqdm(range(num_batches))
            for _ in batches:
                end_idx = start_idx + batch_size
                if end_idx > l:
                    end_idx = l
                _ids = move_to_device(ids[start_idx:end_idx, :], self.args.device)
                _seg_batch = move_to_device(seg_batch[start_idx:end_idx, :], self.args.device)
                _attn_mask = move_to_device(attn_mask[start_idx:end_idx, :], self.args.device)
                _, _vectors, _ = model(_ids, _seg_batch, _attn_mask)
                vectors.append(_vectors)
                start_idx = end_idx
        vectors = torch.cat(vectors, dim=0)
        return ids, seg_batch, attn_mask, vectors
    
    def retrieve_apis(self, examples, top_k=10):
        query_sentences = [ex["query"] for ex in examples]
        _, _, _, query_vectors = self.generate_vectors(
            model=self.query_model, 
            sentences=query_sentences, 
            batch_size=self.batch_size,
            task='"QUESTION_VECTORS"'
        )
        similarity_results = cosine_similarity(
            query_vectors.cpu().numpy(), 
            self.doc_vectors.cpu().numpy()
        )
        singled_out = []
        return_examples = []
        
        torch.save(
            (query_sentences, query_vectors.cpu(), self.doc_vectors.cpu(), similarity_results), 
            "from_class.pt"
        )
        
        for exid, ex in enumerate(examples):
            example = copy.deepcopy(ex)
            pred_similaroty = [(a, s) for a, s in zip(self.apis, similarity_results[exid, :].tolist())]
            sorted_apis = sorted(pred_similaroty, key=lambda x: x[1])[::-1]
            example["expected"] = example["apis"]
            example["predicted"] = sorted_apis
            predictions = set([e[0] for e in sorted_apis[:top_k]])
            if len(set(example["expected"]).difference(predictions)) == 0:
                singled_out.append(example)
                pass
            return_examples.append(example)
            pass
        return return_examples, singled_out
        

In [4]:
model_path = "models/bert/pandas_0/checkpoint_best.pt"
for i in tqdm([1]):
    retriever = RetrieverModel(
        model_path=f"models/bert/pandas_{i}/checkpoint_best.pt", 
        batch_size=128, 
        quiet=False,
        no_cuda=False
    )
    _, singles = retriever.retrieve_apis(test_examples)
    print(i, len(singles))
    del retriever

  0%|          | 0/1 [00:00<?, ?it/s]

Overriding args parameter value from checkpoint state. Param = pretrained_model_cfg, value = google/bert_uncased_L-6_H-512_A-8
Overriding args parameter value from checkpoint state. Param = encoder_model_type, value = hf_bert
Overriding args parameter value from checkpoint state. Param = sequence_length, value = 512


Generating vectors for 132 sentences using "API_VECTORS" task model


  0%|          | 0/2 [00:00<?, ?it/s]

Generating vectors for 608 sentences using "QUESTION_VECTORS" task model


  0%|          | 0/5 [00:00<?, ?it/s]

1 97


In [5]:
def process_file_num(num_str):
    if "." in num_str:
        parts = num_str.split(".")
        full = parts[0].strip()
        frac = parts[1].strip()
        if len(frac) < 4:
            frac = ('0' * (4-len(frac))) + frac
        elif len(frac) > 4:
            return None
        num_str = full + "." + frac
    return float(num_str)
# print(process_file_num("dpr_biencoder.2.108"[14:]))

In [12]:
import os 

directories = [5]
all_results = {}


output_file = open("all_outputs.tsv", 'a')

for d in directories:
    results = {}
    files = os.listdir(os.path.join("models/bert", "pandas_" + str(d)))
    taken_files = [f for f in files if f.startswith("dpr_biencoder")]
    points = []
    for f in taken_files:
        v = process_file_num(f[14:])
        if v is not None:
            points.append((f, v))
    points = sorted(points, key=lambda x: x[1])
    for i, (f, e) in enumerate(tqdm(points, total=len(points))):
        model_path = os.path.join("models/bert/", "pandas_" + str(d), f)
        retriever = RetrieverModel(
            model_path=model_path, 
            batch_size=128, 
            quiet=True
        )
        _, singles = retriever.retrieve_apis(test_examples)
        results[f] = len(singles)
        print(
            d, e, len(singles), os.path.join("models/bert/", "pandas_" + str(d), f), 
            sep="\t", 
            file=output_file, 
            flush=True
        )
        if i % 1 == 0:
            print(
                d, e, len(singles), os.path.join("models/bert/", "pandas_" + str(d), f), 
                sep="\t", 
            )
        pass
        del retriever
    print("=" * 100)
    all_results[d] = results


  0%|          | 0/130 [00:00<?, ?it/s]

5	0.0422	12	models/bert/pandas_5/dpr_biencoder.0.422
5	0.0844	12	models/bert/pandas_5/dpr_biencoder.0.844
5	0.1266	21	models/bert/pandas_5/dpr_biencoder.0.1266
5	0.1688	19	models/bert/pandas_5/dpr_biencoder.0.1688
5	0.211	17	models/bert/pandas_5/dpr_biencoder.0.2110
5	0.2532	25	models/bert/pandas_5/dpr_biencoder.0.2532
5	0.2954	56	models/bert/pandas_5/dpr_biencoder.0.2954
5	0.3376	37	models/bert/pandas_5/dpr_biencoder.0.3376
5	0.3798	46	models/bert/pandas_5/dpr_biencoder.0.3798
5	0.422	79	models/bert/pandas_5/dpr_biencoder.0.4220
5	0.4642	87	models/bert/pandas_5/dpr_biencoder.0.4642
5	0.5064	68	models/bert/pandas_5/dpr_biencoder.0.5064
5	0.5067	68	models/bert/pandas_5/dpr_biencoder.0.5067
5	1.0422	66	models/bert/pandas_5/dpr_biencoder.1.422
5	1.0844	68	models/bert/pandas_5/dpr_biencoder.1.844
5	1.1266	82	models/bert/pandas_5/dpr_biencoder.1.1266
5	1.1688	68	models/bert/pandas_5/dpr_biencoder.1.1688
5	1.211	73	models/bert/pandas_5/dpr_biencoder.1.2110
5	1.2532	77	models/bert/pandas_5/dp

In [11]:
output_file.close()

In [7]:
for p in [0, 1, 2, 5]:
    retriever = RetrieverModel(
        model_path=f"models/bert/pandas_{p}/checkpoint_best.pt", 
        batch_size=128, 
        quiet=False
    )
    _, singles = retriever.retrieve_apis(test_examples, top_k=10)
    print(p, len(singles))

Overriding args parameter value from checkpoint state. Param = pretrained_model_cfg, value = google/bert_uncased_L-6_H-512_A-8
Overriding args parameter value from checkpoint state. Param = encoder_model_type, value = hf_bert
Overriding args parameter value from checkpoint state. Param = sequence_length, value = 512


Generating vectors for 132 sentences using "API_VECTORS" task model


  0%|          | 0/2 [00:00<?, ?it/s]

Generating vectors for 608 sentences using "QUESTION_VECTORS" task model


  0%|          | 0/5 [00:00<?, ?it/s]

Overriding args parameter value from checkpoint state. Param = pretrained_model_cfg, value = google/bert_uncased_L-6_H-512_A-8
Overriding args parameter value from checkpoint state. Param = encoder_model_type, value = hf_bert
Overriding args parameter value from checkpoint state. Param = sequence_length, value = 512


0 74
Generating vectors for 132 sentences using "API_VECTORS" task model


  0%|          | 0/2 [00:00<?, ?it/s]

Generating vectors for 608 sentences using "QUESTION_VECTORS" task model


  0%|          | 0/5 [00:00<?, ?it/s]

Overriding args parameter value from checkpoint state. Param = pretrained_model_cfg, value = google/bert_uncased_L-6_H-512_A-8
Overriding args parameter value from checkpoint state. Param = encoder_model_type, value = hf_bert
Overriding args parameter value from checkpoint state. Param = sequence_length, value = 512


1 97
Generating vectors for 132 sentences using "API_VECTORS" task model


  0%|          | 0/2 [00:00<?, ?it/s]

Generating vectors for 608 sentences using "QUESTION_VECTORS" task model


  0%|          | 0/5 [00:00<?, ?it/s]

2 79


Overriding args parameter value from checkpoint state. Param = pretrained_model_cfg, value = google/bert_uncased_L-6_H-512_A-8
Overriding args parameter value from checkpoint state. Param = encoder_model_type, value = hf_bert
Overriding args parameter value from checkpoint state. Param = sequence_length, value = 512


Generating vectors for 132 sentences using "API_VECTORS" task model


  0%|          | 0/2 [00:00<?, ?it/s]

Generating vectors for 608 sentences using "QUESTION_VECTORS" task model


  0%|          | 0/5 [00:00<?, ?it/s]

5 87
