<a href="https://colab.research.google.com/github/anon/ILCiteR/blob/main/prefetch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture

!pip install rank_bm25

import heapq
import json
import os
import tqdm
from rank_bm25 import BM25Okapi, BM25Plus

In [2]:
# NOTE(review): these paths are relative to the working directory and assume
# Google Drive is already mounted at ./drive — nothing in this notebook mounts it.
location = 'drive/My Drive/cite_reco_s2orc/full/'  # dataset root
maps_loc = 'maps/'                                  # map files, under `location`
map_types = ['Test/', 'Database/']                  # [0]: test sets, [1]: databases
dump_loc = 'experiments/prefetch/'                  # prefetch output, under `location`

domains = ['ner', 'sa', 'summ', 'mt']               # domain keys used in file names
test_counts = [200, 500, 1000]                      # available test-set sizes

In [3]:
# Prefetch settings.
select_algo = 1                                 # 0 -> BM25Okapi, otherwise BM25Plus
algo = ('BM25Okapi', 'BM25Plus')[select_algo]   # variant name, used in output file names
test_count_ID = 1                               # index into test_counts
prefetch_count = 50                             # candidates kept per query

def get_database(domain):
  """Load the candidate database for one domain.

  Reads ``location + maps_loc + map_types[1] + domain + '.json'`` and returns
  the decoded JSON. Presumably it is a dict whose keys are citation contexts
  and whose values are the cited papers (TODO confirm schema). The rest of the
  notebook iterates its keys and looks values up by context.

  Args:
    domain: domain key, e.g. one of ``domains``.

  Returns:
    The parsed JSON object.

  Raises:
    OSError: if the map file cannot be opened (e.g. FileNotFoundError).
  """
  path = location + maps_loc + map_types[1] + domain + '.json'
  # Read-only mode: the old 'r+' demanded write permission even though
  # nothing is written, which breaks on read-only mounts.
  with open(path, 'r', encoding='utf-8') as f:
    return json.load(f)

def get_test_set(domain):
  """Load the test set for one domain at the configured size.

  Reads ``location + maps_loc + map_types[0] + domain + '_' + N + '.json'``,
  where ``N`` is ``test_counts[test_count_ID]``. run_prefetch uses the first
  element of each datapoint as the query context (TODO confirm the rest of
  the datapoint schema).

  Args:
    domain: domain key, e.g. one of ``domains``.

  Returns:
    The parsed JSON object (a sequence of datapoints).

  Raises:
    OSError: if the test file cannot be opened (e.g. FileNotFoundError).
  """
  path = (location + maps_loc + map_types[0] + domain + '_'
          + str(test_counts[test_count_ID]) + '.json')
  # Read-only mode: the old 'r+' demanded write permission even though
  # nothing is written, which breaks on read-only mounts.
  with open(path, 'r', encoding='utf-8') as f:
    return json.load(f)

def dump_prefetch(domain, testcase_prefetched):
  """Save one domain's prefetch results as JSON under ``location + dump_loc``.

  The file name is ``<domain>_<N>_<algo>.json`` with
  ``N = test_counts[test_count_ID]``. Runs with a different test-set size or
  BM25 variant therefore write to different files.

  Args:
    domain: domain key used in the file name.
    testcase_prefetched: JSON-serializable results, e.g. the list returned
      by run_prefetch.
  """
  out_dir = location + dump_loc
  # Create the output folder if needed so a missing directory does not throw
  # away the minutes-long BM25 scoring at the very last step.
  os.makedirs(out_dir, exist_ok=True)
  out_path = (out_dir + domain + '_' + str(test_counts[test_count_ID])
              + '_' + algo + '.json')
  # Plain 'w' (not 'w+'): the file is only written, never read back here.
  with open(out_path, 'w', encoding='utf-8') as f:
    json.dump(testcase_prefetched, f)

def tokenize(text):
  """Split already-tokenized text back into its tokens.

  Input text was already tokenized using nltk.word_tokenize and (presumably)
  re-joined with spaces. ``str.split()`` with no separator splits on any run
  of whitespace and ignores leading/trailing whitespace. It therefore never
  yields empty ``''`` tokens, which ``split(' ')`` did on repeated spaces and
  which BM25 would index as a real term.

  Args:
    text: whitespace-separated token string.

  Returns:
    List of non-empty tokens (empty list for blank input).
  """
  return text.split()

def tokenize_texts(texts):
  """Tokenize each text with tokenize(), returning a list in input order."""
  tokenized = []
  for text in texts:
    tokenized.append(tokenize(text))
  return tokenized

def make_candidate_texts(database):
  """Return the database's keys (the candidate contexts) as a list, in iteration order.

  The position of each context in this list is the index that BM25 scores
  refer to later on.
  """
  return list(database.keys())

def create_BM25_model(candidate_texts):
  """Build a BM25 index over the tokenized candidate texts.

  The variant comes from the module-level ``select_algo``: 0 picks BM25Okapi
  and any other value picks BM25Plus.
  """
  model_class = BM25Okapi if select_algo == 0 else BM25Plus
  return model_class(candidate_texts)

def get_top_BM25_scores(query_context, bm25_model, top_k=None):
  """Return the best-scoring candidates for one tokenized query.

  Args:
    query_context: list of query tokens.
    bm25_model: object with ``get_scores(tokens)`` returning one score per
      indexed candidate (e.g. a rank_bm25 model).
    top_k: number of results to keep; defaults to the module-level
      ``prefetch_count``.

  Returns:
    Up to ``top_k`` ``[score, index]`` pairs in descending order, where
    ``index`` is the candidate's position in the corpus the model was built
    on. Ties on score are ordered by larger index first, exactly as
    ``sorted(pairs, reverse=True)`` would order them.
  """
  if top_k is None:
    top_k = prefetch_count
  doc_scores = bm25_model.get_scores(query_context)
  # heapq.nlargest(k, it) is documented as equivalent to
  # sorted(it, reverse=True)[:k]. It only keeps k items in a heap instead of
  # sorting the whole database for every query.
  return heapq.nlargest(
      top_k, ([score, index] for index, score in enumerate(doc_scores)))

def run_prefetch(test_set, candidate_texts):
  """Score every test query against the candidate corpus with BM25.

  Args:
    test_set: sequence of datapoints whose first element is the
      space-tokenized query context string.
    candidate_texts: tokenized candidate contexts (one token list each) to
      index.

  Returns:
    One list of ``[score, index]`` pairs (from get_top_BM25_scores) per test
    datapoint, in test-set order.
  """
  model = create_BM25_model(candidate_texts)
  results = [
      get_top_BM25_scores(tokenize(datapoint[0]), model)
      for datapoint in tqdm.tqdm(test_set)
  ]
  # Blank line plus forgetting finished bars, presumably so the next progress
  # bar renders cleanly in the notebook. Note: tqdm._instances is private API.
  tqdm.tqdm.write('')
  tqdm.tqdm._instances.clear()
  return results

In [4]:
# Smoke test on a single domain: run the full pipeline, save the results, and
# show the top-ranked evidence for the first test case.
run_unit_test = True

if run_unit_test:
  unit_domain = 'ner'
  database = get_database(unit_domain)
  test_set = get_test_set(unit_domain)
  candidate_texts = make_candidate_texts(database)
  tokenized_candidate_texts = tokenize_texts(candidate_texts)
  test_set_prefetched = run_prefetch(test_set, tokenized_candidate_texts)
  dump_prefetch(unit_domain, test_set_prefetched)

  print('Testcase[0]:')
  print(test_set[0])
  print('Extracted_Evidence[0]:')
  index = test_set_prefetched[0][0][1]  # corpus position of test case 0's top hit
  top_context = candidate_texts[index]
  print(top_context)
  print('Suggested_Papers[0]:')
  print(database[top_context])

100%|██████████| 500/500 [01:52<00:00,  4.44it/s]


Testcase[0]:
['and word embeddings from word2vec', [[{'title': 'Efficient estimation of word representations in vector space', 'authors': [{'first': 'T', 'middle': [], 'last': 'Mikolov', 'suffix': ''}, {'first': 'K', 'middle': [], 'last': 'Chen', 'suffix': ''}, {'first': 'G', 'middle': [], 'last': 'Corrado', 'suffix': ''}, {'first': 'J', 'middle': [], 'last': 'Dean', 'suffix': ''}], 'year': 2013, 'venue': '', 'link': '5959482'}, 1]]]
Extracted_Evidence[0]:
Character embeddings , character bigram embeddings and word embeddings are pretrained separately using word2vec
Suggested_Papers[0]:
[[{'title': 'Efficient estimation of word representations in vector space', 'authors': [{'first': 'Tomas', 'middle': [], 'last': 'Mikolov', 'suffix': ''}, {'first': 'Kai', 'middle': [], 'last': 'Chen', 'suffix': ''}, {'first': 'Greg', 'middle': [], 'last': 'Corrado', 'suffix': ''}, {'first': 'Jeffrey', 'middle': [], 'last': 'Dean', 'suffix': ''}], 'year': 2013, 'venue': '', 'link': '5959482'}, 1]]





In [5]:
# The unit-test cell above already prefetches 'ner' (domains[0]) when
# run_unit_test is True. Skip it only in that case; otherwise turning the
# unit test off would mean 'ner' is never processed.
run_domains = domains[1:] if run_unit_test else domains

for domain in run_domains:
  tqdm.tqdm.write('Domain: ' + str(domain))
  tqdm.tqdm._instances.clear()
  database = get_database(domain)
  test_set = get_test_set(domain)
  candidate_texts = make_candidate_texts(database)
  tokenized_candidate_texts = tokenize_texts(candidate_texts)
  test_set_prefetched = run_prefetch(test_set, tokenized_candidate_texts)
  dump_prefetch(domain, test_set_prefetched)

Domain: sa


100%|██████████| 500/500 [05:00<00:00,  1.67it/s]



Domain: summ


100%|██████████| 500/500 [06:50<00:00,  1.22it/s]



Domain: mt


100%|██████████| 500/500 [09:52<00:00,  1.18s/it]





In [6]:
# That's it