<a href="https://colab.research.google.com/github/anon/ILCiteR/blob/main/rank_cites.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Step 1: Given a test case, load its pre-fetched candidate citation contexts and rank them
# Step 2: Rank the research articles cited by those candidate contexts

In [2]:
import json
import pickle
import numpy
from collections import defaultdict

In [3]:
# Root folder of the data set; every other location below is appended to it.
# NOTE(review): the 'drive/My Drive/' prefix presumably points at a mounted Google Drive — confirm.
base_location = 'drive/My Drive/cite_reco_s2orc/full/'
# Kept as a string because it is concatenated straight into file names.
test_size = '500'
# Rank given to a candidate that is missing from one of the rankings
# (see rank_candidate_contexts).
# NOTE(review): presumably also the number of candidates pre-fetched per ranker — confirm.
prefetched_size = 50

# Sub-folders (relative to base_location); the trailing '/' matters because
# file names are appended by plain string concatenation.
test_map_loc = 'maps/Test/'
db_map_loc = 'maps/Database/'
expts_loc = 'experiments/'
prefetched_loc = expts_loc + 'prefetch/'

In [4]:
# Experiment Parameters

# Domain names as they appear in the data file names.
# NOTE(review): presumably NER, summarization and machine translation — confirm.
domains = ['ner', 'summ', 'mt']
# Each entry is one combination of rankings to sum in normal mode:
# 0 = BM25Okapi, 1 = BM25Plus, 2 = SciBERT similarity.
indices_possibilities = [[2], [0, 2], [1, 2], [0, 1, 2], [0], [1], [0, 1]]
mode = 1 # 0 for normal mode, 1 for hybrid mode
# Compared against the number of spaces in the test context (see combine_ranks).
switch_threshold = 50 # Only used if mode == 1

In [5]:
# Pre-computed embeddings, looked up by context text (both test contexts and
# database contexts are used as keys by get_SciBERT_scores).
# NOTE: the name `context_to_emded` is a misspelling of "embed"; it is kept
# because other cells refer to it under this name.
# NOTE(security): pickle.load can execute arbitrary code while loading, so this
# file must come from a trusted source.
with open(base_location + expts_loc + 'context_to_embed_' + test_size + '.pkl', 'rb') as f:
  context_to_emded = pickle.load(f)

# Opened read-only: the file is only parsed, never written back, so there is
# no reason to require write permission on it.
# NOTE(review): `all_contexts` is not referenced by any other visible cell.
with open(base_location + expts_loc + 'all_contexts_' + test_size + '.json', 'r') as f:
  all_contexts = json.load(f)

In [6]:
def _read_json(path):
  """Parse the JSON file at `path` and return the decoded object."""
  # Read-only mode: nothing is written back, so write access is not required.
  with open(path, 'r') as f:
    return json.load(f)

def load_files(domain):
  """Load the test set, database and pre-fetched BM25 results of one domain.

  Args:
    domain: domain name used in the data file names (e.g. 'ner').

  Returns:
    A 4-tuple (test_set, database, BM25Okapi_fetched, BM25Plus_fetched):
      test_set: decoded content of the domain's test map file.
      database: list of [context, mapped_value] pairs, in the key order of the
        database map file; other functions index it by position.
      BM25Okapi_fetched, BM25Plus_fetched: decoded pre-fetch results of the two
        BM25 variants.

  Raises:
    OSError: if one of the files cannot be opened.
    json.JSONDecodeError: if one of the files is not valid JSON.
  """
  test_set = _read_json(base_location + test_map_loc + domain + '_' + test_size + '.json')

  database_dict = _read_json(base_location + db_map_loc + domain + '.json')
  # Turn the mapping into a list so that entries can be addressed by integer ID.
  database = [[context, mapped] for context, mapped in database_dict.items()]

  prefetch_prefix = base_location + prefetched_loc + domain + '_' + test_size
  BM25Okapi_fetched = _read_json(prefetch_prefix + '_BM25Okapi.json')
  BM25Plus_fetched = _read_json(prefetch_prefix + '_BM25Plus.json')

  return test_set, database, BM25Okapi_fetched, BM25Plus_fetched

In [7]:
# Unit Test

# Set to False to skip the single-example sanity check.
# NOTE: this cell also binds the module-level names `domain`, `indices` and
# `database`, which combine_ranks, get_SciBERT_scores and rank_papers read as
# globals.
run_unit_test = True
if run_unit_test:
  domain = 'ner'
  # Combine all three rankings (only used by combine_ranks when mode == 0).
  indices = [0, 1, 2]
  test_set, database, BM25Okapi_fetched, BM25Plus_fetched = load_files(domain)

In [8]:
def similarity(x, y, sim_type = 0):
  """Return a similarity score between two vectors (higher means more similar).

  Args:
    x, y: vectors of equal length (numpy arrays or plain sequences).
    sim_type: 0 for cosine similarity (default, the behaviour used so far),
      1 for negative Euclidean distance, 2 for the dot product.

  Returns:
    The score as a numpy scalar, or 0.0 for cosine similarity when either
    vector has zero norm.

  Raises:
    ValueError: if `sim_type` is not 0, 1 or 2.
  """
  # Accept plain lists as well; needed for the element-wise `x - y` below.
  x = numpy.asarray(x)
  y = numpy.asarray(y)

  if sim_type == 0:
    # Cosine Similarity
    denominator = numpy.linalg.norm(x) * numpy.linalg.norm(y)
    if denominator == 0:
      # A zero vector has no direction; returning 0.0 avoids a NaN score,
      # which would make the later sort order unreliable.
      return 0.0
    return numpy.dot(x, y) / denominator
  if sim_type == 1:
    # Distance
    return - numpy.linalg.norm(x - y)
  if sim_type == 2:
    # Dot Product
    return numpy.dot(x, y)
  raise ValueError('Unknown sim_type: ' + str(sim_type))

def get_SciBERT_scores(candidate_IDs, test_context):
  """Score every candidate context against the test context by embedding similarity.

  Reads the module-level `database` (to turn a candidate ID into its context
  text) and `context_to_emded` (to turn a context text into its embedding).

  Args:
    candidate_IDs: positions into `database` of the candidates to score.
    test_context: text of the query context; must be a key of `context_to_emded`.

  Returns:
    A list of [score, ID] pairs sorted from the highest score to the lowest.
  """
  query_vector = context_to_emded[test_context]

  scored = [
    [similarity(context_to_emded[database[candidate][0]], query_vector), candidate]
    for candidate in candidate_IDs
  ]

  # Best match first; equal scores fall back to comparing the IDs.
  scored.sort(reverse = True)
  return scored

def combine_ranks(list_of_ranks, test_point):
  """Merge the three per-ranking positions of one candidate into a single value.

  Reads the module-level `mode`, `indices` and `switch_threshold`.

  Args:
    list_of_ranks: [BM25Okapi rank, BM25Plus rank, SciBERT rank] of a candidate.
    test_point: the test case; only its context text, test_point[0], is used.

  Returns:
    The sum of the selected ranks (callers sort ascending on it).
  """
  if mode == 0:
    # Normal mode: `indices` lists which of the three rankings to add up,
    # e.g. [0, 1, 2] for all of them.
    return sum(list_of_ranks[index] for index in indices)

  # Hybrid mode: pick the pair of rankings by the number of spaces in the
  # test context.
  if test_point[0].count(' ') > switch_threshold:
    first, second = 1, 2
  else:
    first, second = 0, 1
  return list_of_ranks[first] + list_of_ranks[second]

def is_same_paper(paper_A, paper_B):
  """Tell whether two paper records describe the same paper.

  Two records match when their titles are equal after lower-casing and
  trimming surrounding whitespace, and their 'year' values are equal.

  Args:
    paper_A, paper_B: dicts with at least the keys 'title' and 'year'.

  Returns:
    True if both title and year match, False otherwise.
  """
  title_a = paper_A['title'].lower().strip()
  title_b = paper_B['title'].lower().strip()
  if title_a != title_b:
    return False
  return bool(paper_A['year'] == paper_B['year'])

def search_paper(query_paper, ranked_paper_order):
  """Find the position of `query_paper` in a ranked list of papers.

  Args:
    query_paper: the paper record to look for.
    ranked_paper_order: list of paper records, best first.

  Returns:
    The 0-based index of the first record matching `query_paper` according to
    is_same_paper, or -1 when there is none.
  """
  matching_positions = (
    position
    for position, candidate in enumerate(ranked_paper_order)
    if is_same_paper(query_paper, candidate)
  )
  return next(matching_positions, -1)

def Reci_Rank(ranked_paper_order, ground_truth_papers):
  """Reciprocal rank of the best-placed ground-truth paper.

  Args:
    ranked_paper_order: list of paper records, best first.
    ground_truth_papers: paper records that count as correct.

  Returns:
    1 / (position + 1) of the ground-truth paper that appears earliest in the
    ranking, or 0 if none of them appears at all.
  """
  best = 0
  for truth in ground_truth_papers:
    position = search_paper(truth, ranked_paper_order)
    if position != -1:
      best = max(best, 1 / (position + 1))
  return best

def Rec_at_K(ranked_paper_order, ground_truth_papers, K = 10):
  """Report whether any ground-truth paper is among the first K ranked papers.

  Args:
    ranked_paper_order: list of paper records, best first.
    ground_truth_papers: paper records that count as correct.
    K: how many top-ranked papers to inspect (default 10).

  Returns:
    1 if at least one ground-truth paper is found in the top K, otherwise 0.
  """
  # The top-K window does not depend on the ground-truth paper, so build it
  # once instead of re-slicing the ranking on every loop iteration.
  top_k = ranked_paper_order[0 : min(K, len(ranked_paper_order))]
  for actual in ground_truth_papers:
    if search_paper(actual, top_k) != -1:
      return 1
  return 0

def rank_candidate_contexts(test_point, p_okapi, p_plus):
  """Rank the pre-fetched candidate contexts of one test case.

  Every candidate gets three positions — in the BM25Okapi list, in the
  BM25Plus list and in the SciBERT similarity ordering — which combine_ranks
  merges into one value. A candidate missing from a BM25 list gets the
  module-level `prefetched_size` as its position there.

  Args:
    test_point: the test case; test_point[0] is the query context text.
    p_okapi: [score, ID] pairs fetched by BM25Okapi, best first.
    p_plus: [score, ID] pairs fetched by BM25Plus, best first.

  Returns:
    A list of [combined_rank, ID] pairs sorted ascending.
  """
  test_context = test_point[0]

  # candidate ID -> [BM25Okapi rank, BM25Plus rank, SciBERT rank]
  ranks_by_id = {}

  for position, (_, candidate) in enumerate(p_okapi):
    ranks_by_id[candidate] = [position, prefetched_size, prefetched_size]

  for position, (_, candidate) in enumerate(p_plus):
    placeholder = [prefetched_size, prefetched_size, prefetched_size]
    ranks_by_id.setdefault(candidate, placeholder)[1] = position

  # Every candidate from either list is scored, so slot 2 is always overwritten.
  similarity_order = get_SciBERT_scores(list(ranks_by_id), test_context)
  for position, (_, candidate) in enumerate(similarity_order):
    ranks_by_id[candidate][2] = position

  combined = [
    [combine_ranks(ranks, test_point), candidate]
    for candidate, ranks in ranks_by_id.items()
  ]
  combined.sort()
  return combined

def format_year(year_string):
  """Convert a year value to an int, mapping any falsy value to 0.

  Args:
    year_string: the year as stored in a paper record (e.g. '2013', 2013,
      '' or None).

  Returns:
    int(year_string), or 0 when `year_string` is falsy (empty, None, 0).
  """
  return int(year_string) if year_string else 0

def rank_papers(candidate_ID_ranks):
  """Turn ranked candidate contexts into a ranked list of cited papers.

  Each candidate context maps (through the module-level `database`) to
  [paper, support] pairs. Papers recognised as the same by is_same_paper are
  merged into one entry.

  Args:
    candidate_ID_ranks: [rank, ID] pairs as returned by rank_candidate_contexts.

  Returns:
    The distinct paper records, ordered by ascending lowest context rank, then
    descending total support, then descending year.
  """
  # One entry per distinct paper:
  # 0. lowest observed context rank of the paper
  # 1. negated total support of the paper
  # 2. negated year, taken from the first occurrence of the paper
  # 3. insertion index — unique, so sorting never has to compare two dicts
  # 4. the paper record itself
  entries = []

  for context_rank, context_id in candidate_ID_ranks:
    for paper, support in database[context_id][1]:
      for entry in entries:
        if is_same_paper(entry[4], paper):
          entry[0] = min(entry[0], context_rank)
          entry[1] = entry[1] - support
          break
      else:
        # No existing entry matched: this is the first time the paper is seen.
        entries.append([context_rank, - support, - format_year(paper['year']), len(entries), paper])

  entries.sort()
  return [entry[4] for entry in entries]

In [9]:
# Sanity check on a single test case: rank its candidates and papers, then
# print the top suggestions next to the ground truth and the metrics.
if run_unit_test:
  test_index = 0
  test_point = test_set[test_index]
  context_ranks = rank_candidate_contexts(test_point, BM25Okapi_fetched[test_index], BM25Plus_fetched[test_index])
  ranked_papers = rank_papers(context_ranks)
  # Ground truth is stored as [paper, support] pairs; the metrics only need
  # the papers, so extract them once.
  true_papers = [paper for [paper, support] in test_point[1]]
  print('Suggested Papers:')
  print(ranked_papers[0 : min(3, len(ranked_papers))])
  print('Ground Truth Citations:')
  # Follows test_index, so the printed ground truth always belongs to the
  # test case that was actually ranked.
  print(test_point[1])
  print('Reciprocal Rank: ' + str(Reci_Rank(ranked_papers, true_papers)))
  for K in (1, 3, 5):
    print('Recall at ' + str(K) + ': ' + str(Rec_at_K(ranked_papers, true_papers, K)))

Suggested Papers:
[{'title': 'Efficient Estimation of Word Representations in Vector Space', 'authors': [{'first': 'Tomas', 'middle': [], 'last': 'Mikolov', 'suffix': ''}, {'first': 'Kai', 'middle': [], 'last': 'Chen', 'suffix': ''}, {'first': 'Greg', 'middle': [], 'last': 'Corrado', 'suffix': ''}, {'first': 'Jeffrey', 'middle': [], 'last': 'Dean', 'suffix': ''}], 'year': 2013, 'venue': '', 'link': '5959482'}, {'title': 'Toward mention detection robustness with recurrent neural networks', 'authors': [{'first': 'T', 'middle': ['H'], 'last': 'Nguyen', 'suffix': ''}, {'first': 'A', 'middle': [], 'last': 'Sil', 'suffix': ''}, {'first': 'G', 'middle': [], 'last': 'Dinu', 'suffix': ''}, {'first': 'R', 'middle': [], 'last': 'Florian', 'suffix': ''}], 'year': 2016, 'venue': '', 'link': '6228859'}, {'title': 'Distributed representations of words and phrases and their compositionality', 'authors': [{'first': 'T', 'middle': [], 'last': 'Mikolov', 'suffix': ''}, {'first': 'I', 'middle': [], 'last'

In [10]:
def get_metrics():
  """Evaluate the current configuration on every test case of `domain`.

  Reads the module-level `domain`; the ranking behaviour also depends on the
  module-level `mode`, `indices` and `switch_threshold` used by combine_ranks.

  Side effect: rebinds the module-level `database` to the one loaded for
  `domain`.

  Returns:
    A dict with the key 'MRR' (mean reciprocal rank) and the keys 'rec_1',
    'rec_3', 'rec_5', 'rec_10' (mean of Rec_at_K over the test set).
  """
  # get_SciBERT_scores and rank_papers look `database` up as a global, so the
  # freshly loaded one must be published there. Keeping it local would let the
  # helpers silently use whatever database an earlier cell left behind.
  global database
  test_set, database, BM25Okapi_fetched, BM25Plus_fetched = load_files(domain)
  reci_ranks = []
  recalls = {1: [], 3: [], 5: [], 10: []}

  for test_index, test_point in enumerate(test_set):
    context_ranks = rank_candidate_contexts(test_point, BM25Okapi_fetched[test_index], BM25Plus_fetched[test_index])
    ranked_papers = rank_papers(context_ranks)
    # Ground truth is stored as [paper, support] pairs; extract the papers once.
    true_papers = [paper for [paper, support] in test_point[1]]
    reci_ranks.append(Reci_Rank(ranked_papers, true_papers))
    for K, hits in recalls.items():
      hits.append(Rec_at_K(ranked_papers, true_papers, K))

  metrics = {'MRR': numpy.mean(reci_ranks)}
  for K, hits in recalls.items():
    metrics['rec_' + str(K)] = numpy.mean(hits)
  return metrics

In [11]:
# Full Experiments
# NOTE: this cell rebinds the module-level `domain`, `mode`, `indices` and
# `switch_threshold`; after it finishes they hold the values of the last
# iteration, not the ones set in the parameter cell.

for domain in domains:
  print('Domain: ' + str(domain) + '\n')
  # Load this domain's data into the module-level names; the ranking helpers
  # read `database` as a global.
  test_set, database, BM25Okapi_fetched, BM25Plus_fetched = load_files(domain)

  # Normal mode: evaluate every combination of the three rankings.
  mode = 0
  for consider_rankings in indices_possibilities:
    indices = consider_rankings
    print('Regular mode with considered rankings: ' + str(indices))
    print(get_metrics())
  print()

  # Hybrid mode: sweep the threshold that is compared against the number of
  # spaces in the test context.
  mode = 1
  for threshold in [45, 47.5, 50, 52.5, 55]:
    switch_threshold = threshold
    print('Hybrid mode with threshold: ' + str(switch_threshold))
    print(get_metrics())
  print('\n')

Domain: ner

Regular mode with considered rankings: [2]
{'MRR': 0.2624626371891884, 'rec_1': 0.188, 'rec_3': 0.268, 'rec_5': 0.342, 'rec_10': 0.428}
Regular mode with considered rankings: [0, 2]
{'MRR': 0.3127211553217219, 'rec_1': 0.228, 'rec_3': 0.334, 'rec_5': 0.404, 'rec_10': 0.506}
Regular mode with considered rankings: [1, 2]
{'MRR': 0.30772668490841415, 'rec_1': 0.22, 'rec_3': 0.338, 'rec_5': 0.4, 'rec_10': 0.496}
Regular mode with considered rankings: [0, 1, 2]
{'MRR': 0.31502140618088964, 'rec_1': 0.222, 'rec_3': 0.344, 'rec_5': 0.418, 'rec_10': 0.51}
Regular mode with considered rankings: [0]
{'MRR': 0.3455846275576505, 'rec_1': 0.26, 'rec_3': 0.38, 'rec_5': 0.43, 'rec_10': 0.512}
Regular mode with considered rankings: [1]
{'MRR': 0.34735650741711965, 'rec_1': 0.262, 'rec_3': 0.386, 'rec_5': 0.436, 'rec_10': 0.51}
Regular mode with considered rankings: [0, 1]
{'MRR': 0.35013264412075396, 'rec_1': 0.264, 'rec_3': 0.386, 'rec_5': 0.438, 'rec_10': 0.514}

Hybrid mode with thresh

In [12]:
# That's it