## ParCR evaluation

This notebook walks through the steps needed to evaluate an encoder on Paragraph-level Citation Recommendation (ParCR). Before running it, you need two JSON files: one with query embeddings and one with paper embeddings. Both are produced by `embed.py`, so run that script first and point the path constants below at its output.

In [2]:
import json
import numpy as np
import faiss

In [None]:
# Path to the JSON file containing data about papers that make up the pool
PAPERS_POOL_PATH = 'data/test_pool_papers.json'

# Path to the JSON file containing data about paragraphs that make up the queries that are being evaluated
PARAGRAPH_LABELS_PATH = 'data/test_paragraphs.json'

# Path to the JSON file containing query embeddings (output of `embed.py`)
QUERY_EMBEDDINGS_PATH = 'queries_embeddings.json'  # TODO replace with the output of `embed.py`

# Path to the JSON file containing paper embeddings (output of `embed.py`)
PAPER_EMBEDDINGS_PATH = 'papers_embeddings.json'  # TODO replace with the output of `embed.py`

In [15]:
def load_json(path):
    with open(path) as f:  # close the file handle once the data is read
        return json.load(f)

pool = load_json(PAPERS_POOL_PATH)
par_labels = load_json(PARAGRAPH_LABELS_PATH)
query_embs_map = load_json(QUERY_EMBEDDINGS_PATH)
paper_embs_map = load_json(PAPER_EMBEDDINGS_PATH)

In [28]:
print(f'Loaded {len(query_embs_map)} query embeddings and {len(paper_embs_map)} paper embeddings.')

Loaded 2148 query embeddings and 94129 paper embeddings.
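
Before building the index, it can help to confirm that the two files are compatible. The sketch below assumes each file maps an id to a flat list of floats, which is what the loading code implies. `summarize_embeddings` is a hypothetical helper, not part of `embed.py`, and the toy dictionaries stand in for the real data.

```python
def summarize_embeddings(embs_map):
    """Return (count, dimension); raise if the vectors differ in length."""
    dims = {len(vec) for vec in embs_map.values()}
    if len(dims) != 1:
        raise ValueError(f"expected a single embedding dimension, found {sorted(dims)}")
    return len(embs_map), dims.pop()


# Toy data standing in for the real embedding files (ids are made up)
toy_queries = {"p1_0": [0.1, 0.2, 0.3], "p1_1": [0.0, 0.5, 0.5]}
toy_papers = {"p1": [0.1, 0.2, 0.3], "p2": [0.9, 0.1, 0.0]}

n_queries, query_dim = summarize_embeddings(toy_queries)
n_papers, paper_dim = summarize_embeddings(toy_papers)

# Queries and papers must live in the same vector space to be searchable
assert query_dim == paper_dim
print(f"{n_queries} queries and {n_papers} papers, dimension {query_dim}")
```

On the real data, the same check would be run on `query_embs_map` and `paper_embs_map`.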


Next, we convert the embeddings into NumPy arrays (`float32`, as FAISS expects) and keep the query/paper ids in separate lists, so that row `i` of each array corresponds to the id at position `i`.

In [5]:
query_ids, query_embs = [], []
for qid in query_embs_map:
    query_ids.append(qid)
    query_embs.append(np.array(query_embs_map[qid]).astype('float32'))

paper_ids, paper_embs = [], []
for pid in paper_embs_map:
    paper_ids.append(pid)
    paper_embs.append(np.array(paper_embs_map[pid]).astype('float32'))
    
paper_embs = np.array(paper_embs)
query_embs = np.array(query_embs)

We're now ready to create an index and add all the paper embeddings to it. `IndexFlatL2` performs exact (brute-force) search under L2 distance.

In [6]:
# Create index on GPU
res = faiss.StandardGpuResources()
index_flat = faiss.IndexFlatL2(paper_embs.shape[1])
gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat)

# Index paper embeddings
gpu_index_flat.add(paper_embs)

Once we've created the index, we can perform search and retrieve nearest neighbours for the queries.

In [7]:
# Perform search for all the query embeddings, retrieve top 1024 neighbours for each
D, I = gpu_index_flat.search(query_embs, 1024)
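
`search` returns two arrays with one row per query and one column per neighbour: `D` holds the distances and `I` holds row indices into the indexed paper embeddings, both ordered from nearest to farthest. For `IndexFlatL2` the distances are squared L2 distances. The NumPy sketch below mirrors that behaviour on toy data; `brute_force_l2_search` is a hypothetical helper, not a FAISS function, and it materializes every pairwise difference, so it is only meant for small inputs.

```python
import numpy as np


def brute_force_l2_search(queries, papers, k):
    """Exact k-nearest-neighbour search under squared L2 distance."""
    # Pairwise differences, shape (n_queries, n_papers, dim)
    diffs = queries[:, None, :] - papers[None, :, :]
    dists = (diffs ** 2).sum(axis=2)
    # Indices of the k smallest distances per query, nearest first
    order = np.argsort(dists, axis=1, kind="stable")[:, :k]
    return np.take_along_axis(dists, order, axis=1), order


# Toy vectors standing in for the real embeddings
toy_papers = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]], dtype="float32")
toy_queries = np.array([[0.0, 0.0], [3.0, 3.0]], dtype="float32")

toy_D, toy_I = brute_force_l2_search(toy_queries, toy_papers, k=2)
print(toy_I)  # row i lists the nearest paper rows for query i
print(toy_D)  # matching squared L2 distances
```

The row indices in `I` are positions, not paper ids, which is why the notebook keeps `paper_ids` in the same order as `paper_embs`.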

Next, we need a function that builds, for each query, the list of nearest-neighbour paper ids. We'll use these lists to calculate metrics that tell us how well the encoder performs, i.e. whether it places relevant articles close to the query.

In [9]:
def get_query_candidates_map(query_ids, D, I, keep_only_older_than_citing=False):
    # Relies on the notebook-level `paper_ids` list and `pool` dict
    query_candidates = {}
    for qid, scores, neighbours in zip(query_ids, D, I):
        # Query ids start with the id of the citing paper, followed by "_"
        citing_id = qid.split("_")[0]
        # Pair each neighbour with its own distance before filtering,
        # so ids and scores stay aligned after candidates are dropped
        candidates = [(paper_ids[i], float(s)) for i, s in zip(neighbours, scores)]
        if keep_only_older_than_citing:
            candidates = [
                (pid, s)
                for pid, s in candidates
                if pool[pid]["year"] is None or pool[pid]["year"] < pool[citing_id]["year"]
            ]
        # Drop the citing paper itself and keep at most 1000 candidates
        query_candidates[qid] = [(pid, s) for pid, s in candidates if pid != citing_id][:1000]

    return query_candidates

In [10]:
query_candidates = get_query_candidates_map(query_ids, D, I, keep_only_older_than_citing=True)

Now we can calculate all the metrics we're interested in using the lists of nearest neighbours and articles actually cited in the query (data loaded from `PARAGRAPH_LABELS_PATH`).

In [23]:
from metrics import recall, reciprocal_rank, average_precision, ndcg

qid_data = []

for qid in query_candidates:
    cands = [i[0] for i in query_candidates[qid]]
    citations = par_labels[qid]['citations']
    if not citations:
        # No ground truth for this query: skip it instead of reusing `true` from a previous iteration
        continue
    if isinstance(citations[0], str):
        true = list(set(citations))
    else:
        true = list(set(i['ref_id'] for i in citations))
    qid_data.append({
        'qid': qid,
        'r@1': recall(true, cands, k=1),
        'r@5': recall(true, cands, k=5),
        'r@10': recall(true, cands, k=10),
        'r@100': recall(true, cands, k=100),
        'r-precision': recall(true, cands, k=len(true)),
        'rec_rank': reciprocal_rank(true, cands, k=1000),
        'avg_prec': average_precision(true, cands),
        'ndcg': ndcg(true, cands)
    })
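
The `metrics` module itself is not shown in this notebook. As a rough guide to what the imported functions compute, here is a minimal sketch with binary relevance (a candidate is relevant if it is among the cited papers). The `*_sketch` functions are assumptions made for illustration and may differ from the real implementations, for example in how nDCG is normalized. They assume `true` is non-empty and `cands` contains no duplicates.

```python
import math


def recall_sketch(true, cands, k):
    """Fraction of the cited papers that appear among the top-k candidates."""
    return len(set(true) & set(cands[:k])) / len(true)


def reciprocal_rank_sketch(true, cands, k):
    """1 / rank of the first relevant candidate within the top k, else 0."""
    relevant = set(true)
    for rank, pid in enumerate(cands[:k], start=1):
        if pid in relevant:
            return 1.0 / rank
    return 0.0


def average_precision_sketch(true, cands):
    """Mean of precision@rank taken at the rank of each relevant candidate."""
    relevant, hits, total = set(true), 0, 0.0
    for rank, pid in enumerate(cands, start=1):
        if pid in relevant:
            hits += 1
            total += hits / rank
    return total / len(relevant)


def ndcg_sketch(true, cands):
    """DCG with binary gains, divided by the DCG of an ideal ranking."""
    relevant = set(true)
    dcg = sum(
        1.0 / math.log2(rank + 1)
        for rank, pid in enumerate(cands, start=1)
        if pid in relevant
    )
    n_ideal = min(len(relevant), len(cands))
    ideal = sum(1.0 / math.log2(rank + 1) for rank in range(1, n_ideal + 1))
    return dcg / ideal if ideal > 0 else 0.0


# Toy query: two cited papers, found at ranks 2 and 4
toy_true = ["a", "b"]
toy_cands = ["x", "a", "y", "b"]
print(recall_sketch(toy_true, toy_cands, k=2))
print(reciprocal_rank_sketch(toy_true, toy_cands, k=1000))
print(average_precision_sketch(toy_true, toy_cands))
```

Note that with binary relevance, recall at `k=len(true)` equals precision at the same cutoff, which is why the cell above can report it as `r-precision`.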

Below we print the average results across all the queries for metrics of interest.

In [26]:
for metric in ['r@1', 'r@5', 'r@10', 'r@100', 'r-precision', 'rec_rank', 'avg_prec', 'ndcg']:
    print(metric, f"{sum(i[metric] for i in qid_data) / len(qid_data):.2}")

r@1 0.032
r@5 0.11
r@10 0.16
r@100 0.41
r-precision 0.084
rec_rank 0.21
avg_prec 0.093
ndcg 0.25
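
The numbers above are unweighted means over queries, so every query counts equally regardless of how many papers it cites. The format spec `.2` keeps two significant digits rather than two decimal places, which is why `r@1` prints as `0.032` while `r@5` prints as `0.11`. The sketch below reproduces the averaging on toy rows; `mean_per_metric` and the toy values are made up for illustration.

```python
def mean_per_metric(rows, metrics):
    """Unweighted mean of each metric over per-query result rows."""
    return {m: sum(row[m] for row in rows) / len(rows) for m in metrics}


# Toy per-query rows in the same shape as `qid_data` (values are made up)
toy_rows = [
    {"qid": "p1_0", "r@1": 0.0, "ndcg": 0.5},
    {"qid": "p1_1", "r@1": 1.0, "ndcg": 0.25},
]

toy_means = mean_per_metric(toy_rows, ["r@1", "ndcg"])
for name, value in toy_means.items():
    print(name, f"{value:.2}")
```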
