In [None]:
!pip install -U sentence-transformers

In [None]:
!pip install jsonlines

In [1]:
import json
import jsonlines
import numpy as np
import math
import os
import multiprocessing
import argparse
import statistics
import codecs
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
import operator

In [2]:
# os.environ["TOKENIZERS_PARALLELISM"] = "false" # only if multithreading used

In [3]:
def r_precision(r):
    r = np.asarray(r) != 0
    z = r.nonzero()[0]
    if not z.size:
        return 0.
    return np.mean(r[:z[-1] + 1])

def precision_at_k(r, k):
    assert k >= 1
    r = np.asarray(r)[:k] != 0
    return np.mean(r)

def average_precision(r):
    r = np.asarray(r) != 0
    out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]
    if not out:
        return 0.
    return np.mean(out)

def mean_average_precision(rs):
    return np.mean([average_precision(r) for r in rs])

def dcg_at_k(r, k, method=0):
    r = np.asfarray(r)[:k]
    if r.size:
        if method == 0:
            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
        elif method == 1:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        else:
            raise ValueError('method must be 0 or 1.')
    return 0.

def ndcg_at_k(r, k, method=0):
    dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k, method) / dcg_max

def recall_at_k(ranked_rel, atk, max_total_relevant):
    total_relevant = sum(ranked_rel)
    total_relevant = min(max_total_relevant, total_relevant)
    relatk = sum(ranked_rel[:atk])
    if total_relevant > 0:
        recall_atk = float(relatk)/total_relevant
    else:
        recall_atk = 0.0
    return recall_atk

In [4]:
def dot_product2(v1, v2):
    return sum(map(operator.mul, v1, v2))


def cosine_sim(v1, v2):
    prod = dot_product2(v1, v2)
    len1 = math.sqrt(dot_product2(v1, v1))
    len2 = math.sqrt(dot_product2(v2, v2))
    return prod / (len1 * len2)

In [5]:
facet2folds = {
    "background": {"fold1_dev": ["3264891_background", "1936997_background", "11844559_background",
                                 "52194540_background", "1791179_background", "6431039_background",
                                 "6173686_background", "7898033_background"],
                   "fold2_dev": ["5764728_background", "10014168_background", "10695055_background",
                                 "929877_background", "1587_background", "51977123_background",
                                 "8781666_background", "189897839_background"],
                   "fold1_test": ["5764728_background", "10014168_background", "10695055_background",
                                  "929877_background", "1587_background", "51977123_background",
                                  "8781666_background", "189897839_background"],
                   "fold2_test": ["3264891_background", "1936997_background", "11844559_background",
                                  "52194540_background", "1791179_background", "6431039_background",
                                  "6173686_background", "7898033_background"]},
    "method": {"fold1_dev": ["189897839_method", "1791179_method", "11310392_method", "2468783_method",
                             "13949438_method", "5270848_method", "52194540_method", "929877_method"],
               "fold2_dev": ["5052952_method", "10010426_method", "102353905_method", "174799296_method",
                             "1198964_method", "53080736_method", "1936997_method", "80628431_method",
                             "53082542_method"],
               "fold1_test": ["5052952_method", "10010426_method", "102353905_method", "174799296_method",
                              "1198964_method", "53080736_method", "1936997_method", "80628431_method",
                              "53082542_method"],
               "fold2_test": ["189897839_method", "1791179_method", "11310392_method", "2468783_method",
                              "13949438_method", "5270848_method", "52194540_method", "929877_method"]},
    "result": {"fold1_dev": ["2090262_result", "174799296_result", "11844559_result", "2468783_result",
                             "1306065_result", "5052952_result", "3264891_result", "8781666_result"],
               "fold2_dev": ["2865563_result", "10052042_result", "11629674_result", "1587_result",
                             "1198964_result", "53080736_result", "2360770_result", "80628431_result",
                             "6431039_result"],
               "fold1_test": ["2865563_result", "10052042_result", "11629674_result", "1587_result",
                              "1198964_result", "53080736_result", "2360770_result", "80628431_result",
                              "6431039_result"],
               "fold2_test": ["2090262_result", "174799296_result", "11844559_result", "2468783_result",
                              "1306065_result", "5052952_result", "3264891_result", "8781666_result"]},
    "all": {"fold1_dev": ["3264891_background", "1936997_background", "11844559_background",
                          "52194540_background", "1791179_background", "6431039_background",
                          "6173686_background", "7898033_background", "189897839_method",
                          "1791179_method", "11310392_method", "2468783_method", "13949438_method",
                          "5270848_method", "52194540_method", "929877_method", "2090262_result",
                          "174799296_result", "11844559_result", "2468783_result", "1306065_result",
                          "5052952_result", "3264891_result", "8781666_result"],
            "fold2_dev": ["5764728_background", "10014168_background", "10695055_background",
                          "929877_background", "1587_background", "51977123_background",
                          "8781666_background", "189897839_background", "5052952_method", "10010426_method",
                          "102353905_method", "174799296_method", "1198964_method", "53080736_method",
                          "1936997_method", "80628431_method", "53082542_method", "2865563_result",
                          "10052042_result", "11629674_result", "1587_result", "1198964_result",
                          "53080736_result", "2360770_result", "80628431_result", "6431039_result"],
            "fold1_test": ["5764728_background", "10014168_background", "10695055_background",
                           "929877_background", "1587_background", "51977123_background", "8781666_background",
                           "189897839_background", "5052952_method", "10010426_method", "102353905_method",
                           "174799296_method", "1198964_method", "53080736_method", "1936997_method",
                           "80628431_method", "53082542_method", "2865563_result", "10052042_result",
                           "11629674_result", "1587_result", "1198964_result", "53080736_result",
                           "2360770_result", "80628431_result", "6431039_result"],
            "fold2_test": ["3264891_background", "1936997_background", "11844559_background",
                           "52194540_background", "1791179_background", "6431039_background",
                           "6173686_background", "7898033_background", "189897839_method", "1791179_method",
                           "11310392_method", "2468783_method", "13949438_method", "5270848_method",
                           "52194540_method", "929877_method", "2090262_result", "174799296_result",
                           "11844559_result", "2468783_result", "1306065_result", "5052952_result",
                           "3264891_result", "8781666_result"]
            }
}

In [6]:
def read_facet_specific_relevances(data_path, run_path, dataset, facet, method_name):
    """
    Read the gold data and the model rankings and the relevances for the
    model.
    :param data_path: string; directory with gold citations for test pids and rankings
        from baseline methods in subdirectories.
    :param run_path: string; directory with ranked candidates for baselines a subdir of
        data_path else is a model run.
    :param method_name: string; method with which ranks were created.
    :param dataset: string; eval dataset.
    :param facet: string; facet for eval.
    :return: qpid2rankedcand_relevances: dict('qpid_facet': [relevances]);
        candidate gold relevances for the candidates in order ranked by the
        model.
    """
    gold_fname = os.path.join(data_path, 'test-pid2anns-{:s}-{:s}.json'.format(dataset, facet))
    ranked_fname = os.path.join(run_path, method_name, 'test-pid2pool-{:s}-{:s}-{:s}-ranked.json'.format(dataset, method_name, facet))
    # Load gold test data (citations).
    with codecs.open(gold_fname, 'r', 'utf-8') as fp:
        pid2pool_source = json.load(fp)
        num_query = len(pid2pool_source)
        print('Gold query pids: {:d}'.format(num_query))
        pid2rels_gold = {}
        for qpid, pool_rel in pid2pool_source.items():
            pool = pool_rel['cands']
            cands_rels = pool_rel['relevance_adju']
            pid2rels_gold['{:s}_{:s}'.format(qpid, facet)] = dict([(pid, rel) for pid, rel in zip(pool, cands_rels)])
    # Load ranked predictions on test data with methods.
    with codecs.open(ranked_fname, 'r', 'utf-8') as fp:
        pid2ranks = json.load(fp)
        print('Valid ranked query pids: {:d}'.format(len(pid2ranks)))
        qpid2rankedcand_relevances = {}
        for qpid, citranks in pid2ranks.items():
            candpids = [pid_score[0] for pid_score in citranks]
            cand_relevances = [pid2rels_gold['{:s}_{:s}'.format(qpid, facet)][pid] for pid in candpids]
            qpid2rankedcand_relevances['{:s}_{:s}'.format(qpid, facet)] = cand_relevances
    return qpid2rankedcand_relevances


def read_all_facet_relevances(data_path, run_path, dataset, method_name, facets):
    """
    Read the gold data and the model rankings and the relevances for the
    model.
    :param data_path: string; directory with gold citations for test pids and rankings
        from baseline methods in subdirectories.
    :param run_path: string; directory with ranked candidates for baselines a subdir of
        data_path else is a model run.
    :param method_name: string; method with which ranks were created.
    :param dataset: string; eval dataset.
    :param facets: list(string); what facets to read/what counts as "all".
    :return: qpid2rankedcand_relevances: dict('qpid_facet': [relevances]);
        candidate gold relevances for the candidates in order ranked by the
        model.
    """
    qpid2rankedcand_relevances = {}
    for facet in facets:
        print('Reading facet: {:s}'.format(facet))
        gold_fname = os.path.join(data_path, 'test-pid2anns-{:s}-{:s}.json'.format(dataset, facet))
        ranked_fname = os.path.join(run_path, method_name, 'test-pid2pool-{:s}-{:s}-{:s}-ranked.json'.format(dataset, method_name, facet))
        # Load gold test data (citations).
        with codecs.open(gold_fname, 'r', 'utf-8') as fp:
            pid2pool_source = json.load(fp)
            num_query = len(pid2pool_source)
            print('Gold query pids: {:d}'.format(num_query))
            pid2rels_gold = {}
            for qpid, pool_rel in pid2pool_source.items():
                pool = pool_rel['cands']
                cands_rels = pool_rel['relevance_adju']
                pid2rels_gold['{:s}_{:s}'.format(qpid, facet)] = \
                    dict([(pid, rel) for pid, rel in zip(pool, cands_rels)])
        # Load ranked predictions on test data with methods.
        with codecs.open(ranked_fname, 'r', 'utf-8') as fp:
            pid2ranks = json.load(fp)
            print('Valid ranked query pids: {:d}'.format(len(pid2ranks)))
            for qpid, citranks in pid2ranks.items():
                candpids = [pid_score[0] for pid_score in citranks]
                cand_relevances = [pid2rels_gold['{:s}_{:s}'.format(qpid, facet)][pid] for pid in candpids]
                qpid2rankedcand_relevances['{:s}_{:s}'.format(qpid, facet)] = cand_relevances
    print('Total queries: {:d}'.format(len(qpid2rankedcand_relevances)))
    return qpid2rankedcand_relevances

In [7]:
def compute_metrics(ranked_judgements, pr_atk, threshold_grade):
    """
    Given the ranked judgements compute the metrics for a query.
    :param ranked_judgements: list(int); graded or binary relevances in rank order.
    :param pr_atk: int; the @K value to use for computing precision and recall.
    :param threshold_grade: int; Assuming 0-3 graded relevances, threshold at some point
        and convert graded to binary relevance.
    :return:
    """
    graded_judgements = ranked_judgements
    ranked_judgements = [1 if rel >= threshold_grade else 0 for rel in graded_judgements]
    # Use the full set of candidate not the pr_atk.
    ndcg = ndcg_at_k(graded_judgements, len(ranked_judgements))
    ndcg_pr = ndcg_at_k(graded_judgements, int(0.20*len(ranked_judgements)))
    ndcg_20 = ndcg_at_k(graded_judgements, 20)
    max_total_relevant = sum(ranked_judgements)
    recall = recall_at_k(ranked_rel=ranked_judgements,
                         atk=pr_atk, max_total_relevant=max_total_relevant)
    precision = precision_at_k(r=ranked_judgements, k=pr_atk)
    precision_r = r_precision(r=ranked_judgements)
    av_precision = average_precision(r=ranked_judgements)
    metrics = {
        'recall': float(recall),
        'precision': float(precision),
        'r_precision': float(precision_r),
        'av_precision': float(av_precision),
        'ndcg': ndcg,
        'ndcg@20': ndcg_20,
        'ndcg%20': ndcg_pr
    }
    return metrics

In [8]:
def aggregate_metrics_crossval(query_metrics, split_str, facet_str):
    """
    Given metrics over individual queries aggregate over different
    queries.
    :param query_metrics: dict(query_id: metrics_dict from compute_metrics)
    :param split_str: string; {dev, test}
    :param facet_str: string; {background, method, result}
    :return:
    """
    aggmetrics = {
        'precision': [],
        'recall': [],
        'r_precision': [],
        'mean_av_precision': [],
        'ndcg': [],
        'ndcg@20': [],
        'ndcg%20': []
    }
    # For dev only use a part of the fold - using both makes it identical to test.
    if split_str == 'dev':
        folds = ['fold1_{:s}'.format(split_str)]
    elif split_str == 'test':
        folds = ['fold1_{:s}'.format(split_str), 'fold2_{:s}'.format(split_str)]
    for fold_str in folds:
        fold_pids = facet2folds[facet_str][fold_str]
        precision, recall, f1, av_precision, mrr, ndcg, r_precision = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        ndcg_20, ndcg_pr = 0.0, 0.0
        for query_id in fold_pids:
            # Aggregate across paper types in the fold.
            metrics = query_metrics[query_id]
            # Aggregate across all papers in the fold
            precision += metrics['precision']
            recall += metrics['recall']
            av_precision += metrics['av_precision']
            r_precision += metrics['r_precision']
            ndcg += metrics['ndcg']
            ndcg_20 += metrics['ndcg@20']
            ndcg_pr += metrics['ndcg%20']
        # Average all folds
        num_queries = len(fold_pids)
        precision, recall, f1 = precision/num_queries, recall/num_queries, f1/num_queries
        av_precision = av_precision/num_queries
        mrr, ndcg, r_precision = mrr/num_queries, ndcg/num_queries, r_precision/num_queries
        ndcg_20, ndcg_pr = ndcg_20/num_queries, ndcg_pr/num_queries
        # Save the averaged metric for every fold.
        aggmetrics['precision'].append(precision)
        aggmetrics['recall'].append(recall)
        aggmetrics['r_precision'].append(r_precision)
        aggmetrics['mean_av_precision'].append(av_precision)
        aggmetrics['ndcg'].append(ndcg)
        aggmetrics['ndcg@20'].append(ndcg_20)
        aggmetrics['ndcg%20'].append(ndcg_pr)

    aggmetrics = {
        'precision': statistics.mean(aggmetrics['precision']),
        'recall': statistics.mean(aggmetrics['recall']),
        'r_precision': statistics.mean(aggmetrics['r_precision']),
        'mean_av_precision': statistics.mean(aggmetrics['mean_av_precision']),
        'ndcg': statistics.mean(aggmetrics['ndcg']),
        'ndcg@20': statistics.mean(aggmetrics['ndcg@20']),
        'ndcg%20': statistics.mean(aggmetrics['ndcg%20'])
    }
    return aggmetrics

In [9]:
def graded_eval_pool_rerank(data_path, run_path, method_name, dataset, facet, split, ATK):
    """
    Evaluate the re-ranked pool for the faceted data. Anns use graded relevance scores.
    :param data_path: string; directory with gold citations for test pids and rankings
        from baseline methods in subdirectories.
    :param run_path: string; directory with ranked candidates for baselines a subdir of
        data_path else is a model run.
    :param method_name: string; method with which ranks were created.
    :param dataset: string; eval dataset.
    :param facet: string; facet for eval.
    :param split: strong; {dev, test}
    :return:
    """
    print(f'EVAL SPLIT: {split}')
    if facet == 'all':
        qpid2rankedcand_relevances = read_all_facet_relevances(data_path=data_path, run_path=run_path,
                                                               dataset=dataset, method_name=method_name,
                                                               facets=['background', 'method', 'result'])
    else:
        qpid2rankedcand_relevances = read_facet_specific_relevances(data_path=data_path, run_path=run_path,
                                                                    dataset=dataset, facet=facet,
                                                                    method_name=method_name)
    # Go over test papers and compute metrics.
    all_metrics = {}
    num_cands = 0.0
    num_queries = 0.0
    print('Precision and recall at rank: {:d}'.format(ATK))
    for qpid_facet, qranked_judgements in qpid2rankedcand_relevances.items():
        all_metrics[qpid_facet] = compute_metrics(qranked_judgements, pr_atk=ATK,
                                                  threshold_grade=2)
        num_cands += len(qranked_judgements)
        num_queries += 1
    aggmetrics = aggregate_metrics_crossval(query_metrics=all_metrics, facet_str=facet, split_str=split)
    print('Total queries: {:d}; Total candidates: {:d}'.format(int(num_queries), int(num_cands)))
    print('R-Precision: {:.4f}'.format(aggmetrics['r_precision']))
    print('Precision@{:d}: {:.4f}'.format(ATK, aggmetrics['precision']))
    print('Recall@{:d}: {:.4f}'.format(ATK, aggmetrics['recall']))
    print('NDCG: {:.4f}'.format(aggmetrics['ndcg']))
    print('NDCG@20: {:.4f}'.format(aggmetrics['ndcg@20']))
    print('NDCG%20: {:.4f}'.format(aggmetrics['ndcg%20']))
    return aggmetrics

In [10]:
model_list = [['alberta', 2.5], ['all_mpnet_base_v2', 5], ['bert_nli', 0.5], ['bert_pp', 0.5], ['distilbert_nli', 5], ['allenai_specter', 3.5]]

In [11]:
def gen_metric_scores(facet, ATK, dataset='csfcube', data_path='./data', run_path='./Results'):
    ens = dict()
    data_all = dict()
    for model in model_list:
        data_all[model[0]] = json.load(open(f'{run_path}/{model[0]}/test-pid2pool-{dataset}-{model[0]}-{facet}-ranked.json'))
    for model in model_list:
        method = model[0]
        weight = model[1]

        data = data_all[method]
        for qid in data:
            if qid not in ens:
                ens[qid] = dict()
            query = data[qid]
            for id, score in query:
                if id not in ens[qid]:
                    ens[qid][id] = 0
                ens[qid][id] += score * weight

        # data = json.load(open(f'{data_path}/test-pid2anns-{dataset}-{facet}.json'))
        # corpus = json.load(open(f'{run_path}/{method}/all.json'))
        # queries = json.load(open(f'{run_path}/{method}/{facet}.json'))
        # for qid in data:
        #     if qid not in ens:
        #         ens[qid] = dict()
        #     query = queries[qid]
        #     candidates = data[qid]['cands']
        #     for id in candidates:
        #         if id not in ens[qid]:
        #             ens[qid][id] = 0
        #         if id not in corpus:
        #             continue
        #         ens[qid][id] += cosine_similarity(query, corpus[id]) * weight
        
    for qid in ens:
        sorted_results = sorted(ens[qid].items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
        ens[qid] = sorted_results
    with open(f'{run_path}/ensemble/test-pid2pool-{dataset}-ensemble-{facet}-ranked.json', 'w') as outfile:
      json.dump(ens, outfile)

In [12]:
def get_metric_scores(facet, ATK, dataset='csfcube', data_path='./data', run_path='./Results'):
    aggmetrics1 = graded_eval_pool_rerank(data_path=data_path, run_path=run_path, method_name='ensemble', facet=facet, dataset=dataset, split='dev', ATK=ATK)
    print()
    aggmetrics2 = graded_eval_pool_rerank(data_path=data_path, run_path=run_path, method_name='ensemble', facet=facet, dataset=dataset, split='test', ATK=ATK)
    
    print('\nAVERAGE METRICS')
    print('R-Precision: {:.4f}'.format((aggmetrics1['r_precision'] + aggmetrics2['r_precision']) / 2))
    print('Precision@{:d}: {:.4f}'.format(ATK, (aggmetrics1['precision'] + aggmetrics2['precision']) / 2))
    print('Recall@{:d}: {:.4f}'.format(ATK, (aggmetrics1['recall'] + aggmetrics2['recall']) / 2))
    print('NDCG: {:.4f}'.format((aggmetrics1['ndcg'] + aggmetrics2['ndcg']) / 2))
    print('NDCG@20: {:.4f}'.format((aggmetrics1['ndcg@20'] + aggmetrics2['ndcg@20']) / 2))
    print('NDCG%20: {:.4f}'.format((aggmetrics1['ndcg%20'] + aggmetrics2['ndcg%20']) / 2))

In [13]:
gen_metric_scores('background', 20)

In [14]:
get_metric_scores('background', 20)

EVAL SPLIT: dev
Gold query pids: 16
Valid ranked query pids: 16
Precision and recall at rank: 20
Total queries: 16; Total candidates: 1877
R-Precision: 0.3103
Precision@20: 0.3625
Recall@20: 0.5650
NDCG: 0.8474
NDCG@20: 0.7030
NDCG%20: 0.7020

EVAL SPLIT: test
Gold query pids: 16
Valid ranked query pids: 16
Precision and recall at rank: 20
Total queries: 16; Total candidates: 1877
R-Precision: 0.2811
Precision@20: 0.3531
Recall@20: 0.5617
NDCG: 0.8597
NDCG@20: 0.7157
NDCG%20: 0.7354

AVERAGE METRICS
R-Precision: 0.2957
Precision@20: 0.3578
Recall@20: 0.5634
NDCG: 0.8535
NDCG@20: 0.7093
NDCG%20: 0.7187


In [None]:
gen_metric_scores('method', 20)

In [None]:
get_metric_scores('method', 20)

In [None]:
gen_metric_scores('result', 20)

In [None]:
get_metric_scores('result', 20)

In [None]:
get_metric_scores('all', 20)

In [15]:
model_sent_bert_nli = SentenceTransformer('nli-roberta-base-v2')

model_sent_bert_pp = SentenceTransformer('paraphrase-TinyBERT-L6-v2')

model_all_mpnet_base_v2 = SentenceTransformer('all-mpnet-base-v2')

model_sent_distbert_nli = SentenceTransformer('all-distilroberta-v1')

model_alberta = SentenceTransformer('paraphrase-albert-small-v2')

specter_tokenize = AutoTokenizer.from_pretrained('allenai/specter')
specter_model = AutoModel.from_pretrained('allenai/specter')

In [16]:
def get_bert_nli_embedding(sentence):
    return model_sent_bert_nli.encode(sentence)

def get_bert_pp_embedding(sentence):
    return model_sent_bert_pp.encode(sentence)

def get_all_mpnet_base_v2_embedding(sentence):
    return model_all_mpnet_base_v2.encode(sentence)

def get_distilbert_base_v2_embedding(sentence):
    return model_sent_distbert_nli.encode(sentence)

def get_alberta_embedding(sentence):
    return model_alberta.encode(sentence)

def get_allenai_specter_embedding(sentence):
    inputs = specter_tokenize(sentence, padding=True, truncation=True, return_tensors="pt", max_length=5000)
    return specter_model(**inputs).last_hidden_state[:, 0, :]

In [17]:
data_all = dict()
for model in model_list:
    data_all[model[0]] = dict()
    data_all[model[0]]['all'] = json.load(open(f'./Results/{model[0]}/all.json'))
    data_all[model[0]]['background'] = json.load(open(f'./Results/{model[0]}/background.json'))
    data_all[model[0]]['method'] = json.load(open(f'./Results/{model[0]}/method.json'))
    data_all[model[0]]['result'] = json.load(open(f'./Results/{model[0]}/result.json'))

In [18]:
def custom_doc(inp):
    facet = inp[0]
    query_embedding = inp[1]
    id = inp[2]
    temp = 0
    for i in query_embedding:
        data = data_all[model_list[i][0]][facet][id]
        weight = model_list[i][1]
        temp += cosine_similarity(query_embedding[i], data) * weight
    return (id, temp)

    # for normal execution
def custom_docs_single_core(facet, ATK, user_query):
    ens = dict()
    for model in model_list:
        try:
            method = model[0]
            weight = model[1]
            query_embedding = []
            if method == 'bert_nli':
                query_embedding = np.array(get_bert_nli_embedding(user_query)).tolist()
            elif method == 'bert_pp':
                query_embedding = np.array(get_bert_pp_embedding(user_query)).tolist()
            elif method == 'all_mpnet_base_v2':
                query_embedding = np.array(get_all_mpnet_base_v2_embedding(user_query)).tolist()
            elif method == 'distilbert_nli':
                query_embedding = np.array(get_distilbert_base_v2_embedding(user_query)).tolist()
            elif method == 'alberta':
                query_embedding = np.array(get_alberta_embedding(user_query)).tolist()
            else:
                query_embedding = get_allenai_specter_embedding(" ".join(user_query)).detach().numpy().tolist()[0]
            data = data_all[model[0]][facet]
            for id in data:
                if id not in ens:
                    ens[id] = 0
                ens[id] += cosine_sim(query_embedding, data[id]) * weight
        except:
            pass
    sorted_results = sorted(ens.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
    return sorted_results[:ATK]

In [19]:
# import getscores_multicore

In [20]:
# ens = dict()
# def custom_docs(facet, ATK, user_query):
#     query_embedding = dict()
#     for model in model_list:
#         try:
#             method = model[0]
#             if method == 'bert_nli':
#                 query_embedding[method] = np.array(get_bert_nli_embedding(user_query)).tolist()
#             elif method == 'bert_pp':
#                 query_embedding[method] = np.array(get_bert_pp_embedding(user_query)).tolist()
#             elif method == 'all_mpnet_base_v2':
#                 query_embedding[method] = np.array(get_all_mpnet_base_v2_embedding(user_query)).tolist()
#             elif method == 'distilbert_nli':
#                 query_embedding[method] = np.array(get_distilbert_base_v2_embedding(user_query)).tolist()
#             elif method == 'alberta':
#                 query_embedding[method] = np.array(get_alberta_embedding(user_query)).tolist()
#             else:
#                 query_embedding[method] = get_allenai_specter_embedding(" ".join(user_query)).detach().numpy().tolist()[0]
#             data = data_all[model[0]][facet]
#             for id in data:
#                 ens[id] = 0
#         except:
#             pass
#     print('Embedding gen')
#     return query_embedding
    
# def custom_docs_multicore(facet, ATK, user_query):
#     query_embedding = custom_docs(facet, ATK, user_query)
#     res = getscores_multicore.main(data_all, model_list, facet, query_embedding, ens)
#     print(res)
#     sorted_results = sorted(res.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
#     return sorted_results[:ATK]

In [22]:
class get_doc:
    def __init__(self, paper_id, metadata, title, abstract, pred_labels_truncated, pred_labels):
      self.paper_id = paper_id
      self.metadata = metadata
      self.title = title
      self.abstract = abstract
      self.pred_labels_truncated = pred_labels_truncated
      self.pred_labels = pred_labels

docs = {}

with jsonlines.open('./data/abstracts-csfcube-preds.jsonl') as doc:
  for section in doc:
    docs[section['paper_id']] = get_doc(section['paper_id'], section['metadata'], section['title'], section['abstract'], section['pred_labels_truncated'], section['pred_labels'])

In [26]:
# user_query = input()
# user_method = input()
# ATK = int(input())
if __name__ == '__main__':
    user_query = 'Graph Theory'
    user_method = 'all'
    ATK = 5
    mapping = {
        'background':['background_label', 'objective_label'],
        'method':['method_label'],
        'result':['result_label']
    }
    result = custom_docs_single_core(user_method, ATK, user_query)

In [27]:
    for i in range(len(result)):
        print('\nTitle: ', docs[result[i][0]].title,  f' \nScore: {result[i][1]}')
        print('Authors: ', end="")
        # print(type(docs[result[i][0]].metadata['authors']))
        for author in docs[result[i][0]].metadata['authors']:
          # print(author)
          name = []
          name.append(author['first'])
          name = name + author['middle']
          name.append(author['last'])
          auth = " ".join(name)
          print(auth, end=", ")
        print()
        print('Year:', docs[result[i][0]].metadata['year'], 'DOI:', docs[result[i][0]].metadata['doi'], 'Venue:', docs[result[i][0]].metadata['venue'])
        print('Abstract:')
        if user_method == 'all':
            for j in range(len(docs[result[i][0]].abstract)):
                print(docs[result[i][0]].abstract[j])
        else:
            for j in range(len(docs[result[i][0]].abstract)):
                for l in mapping[user_method]:
                  if docs[result[i][0]].pred_labels[j] == l:
                      print(docs[result[i][0]].abstract[j])
        print('*************************************************************\n')


Title:  Linear expected-time algorithms for connectivity problems (Extended Abstract)  
Score: 9.64643113983034
Authors: Richard M. Karp, Robert Endre Tarjan, 
Year: 1980 DOI: 10.1145/800141.804686 Venue: STOC '80
Abstract:
Researchers in recent years have developed many graph algorithms that are fast in the worst case, but little work has been done on graph algorithms that are fast on the average.
(Exceptions include the work of Angluin and Valiant [1], Karp [7], and Schnorr [9].)
In this paper we analyze the expected running time of four algorithms for solving graph connectivity problems.
Our goal is to exhibit algorithms whose expected time is within a constant factor of optimum and to shed light on the properties of random graphs.
In Section 2 we develop and analyze a simple algorithm that finds the connected components of an undirected graph with n vertices in O(n) expected time.
In Sections 3 and 4 we describe algorithms for finding the strong components of a directed graph and 