<a href="https://colab.research.google.com/github/anon/ILCiteR/blob/main/collect_contexts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Step 1: Group all contexts from
# (a) The test set
# (b) The fetched candidate contexts corresponding to each test datapoint

# Store them in an indexed list and a set
# Fetch their Embeddings using SciBERT
# Create a dictionary mapping each text context to its SciBERT embedding

# Consider all 4 domains together

In [2]:
import json
import tqdm

In [3]:
# Base folder prepended to every path opened by the helper functions below.
# NOTE(review): looks like a Google Drive mount path (Colab); no visible cell
# mounts the drive — confirm it is mounted before running.
location = 'drive/My Drive/cite_reco_s2orc/full/'
# Sub-folder (under `location`) into which dump_all_contexts writes its file.
expts_loc = 'experiments/'
# Sub-folder (under `location`) from which load_prefetch reads its files.
prefetch_loc = expts_loc + 'prefetch/'
# Sub-folder (under `location`) holding the test-set and database JSON maps.
maps_loc = 'maps/'
# Index 0 is used by get_test_set, index 1 by get_database.
map_types = ['Test/', 'Database/']

# Domain names; each one is used as a file-name stem by the loaders.
domains = ['ner', 'sa', 'summ', 'mt']
# Algorithm names; each one is a suffix of the prefetch file names.
algos = ['BM25Okapi', 'BM25Plus']
# Embedded in the test-set, prefetch and output file names.
# Presumably the number of test datapoints per domain — TODO confirm.
test_count = 500

In [4]:
def get_database(domain):
  """Load the database map of one domain from its JSON file.

  The file read is ``location + maps_loc + map_types[1] + domain + '.json'``,
  built from the notebook-level configuration values.

  Args:
    domain: Domain name used as the file-name stem (e.g. 'ner').

  Returns:
    The parsed JSON content. The cells of this notebook treat it as a dict
    and use its keys as the database contexts.

  Raises:
    OSError: If the file does not exist or cannot be read.
    json.JSONDecodeError: If the file does not contain valid JSON.
  """
  path = location + maps_loc + map_types[1] + domain + '.json'

  # Plain read mode: 'r+' also requests write access, which made loading fail
  # on read-only files although nothing is ever written here. The explicit
  # encoding avoids depending on the platform default.
  with open(path, 'r', encoding='utf-8') as f:
    return json.load(f)

def get_test_set(domain):
  """Load the test set of one domain from its JSON file.

  The file read is
  ``location + maps_loc + map_types[0] + domain + '_' + str(test_count) + '.json'``,
  built from the notebook-level configuration values.

  Args:
    domain: Domain name used as the file-name prefix (e.g. 'ner').

  Returns:
    The parsed JSON content. The cells of this notebook iterate over it and
    take element 0 of every datapoint as the test context.

  Raises:
    OSError: If the file does not exist or cannot be read.
    json.JSONDecodeError: If the file does not contain valid JSON.
  """
  path = location + maps_loc + map_types[0] + domain + '_' + str(test_count) + '.json'

  # Plain read mode: 'r+' also requests write access, which made loading fail
  # on read-only files although nothing is ever written here. The explicit
  # encoding avoids depending on the platform default.
  with open(path, 'r', encoding='utf-8') as f:
    return json.load(f)

def dump_all_contexts(all_contexts):
  """Write the collected contexts to a JSON file.

  The file written is
  ``location + expts_loc + 'all_contexts_' + str(test_count) + '.json'``;
  an existing file at that path is overwritten.

  Args:
    all_contexts: JSON-serialisable collection of contexts (the notebook
      passes a list).

  Returns:
    None.

  Raises:
    OSError: If the target file cannot be opened for writing.
    TypeError: If `all_contexts` is not JSON-serialisable.
  """
  path = location + expts_loc + 'all_contexts_' + str(test_count) + '.json'

  # Write-only mode is enough (the file is never read back here); the
  # explicit encoding keeps the output independent of the platform default.
  with open(path, 'w', encoding='utf-8') as f:
    json.dump(all_contexts, f)

def load_prefetch(domain, algo):
  """Load the prefetched candidates of one domain / algorithm pair.

  The file read is
  ``location + prefetch_loc + domain + '_' + str(test_count) + '_' + algo + '.json'``,
  built from the notebook-level configuration values.

  Args:
    domain: Domain name used as the file-name prefix (e.g. 'ner').
    algo: Algorithm name used as the file-name suffix (e.g. 'BM25Plus').

  Returns:
    The parsed JSON content. get_contexts consumes it as a sequence with one
    entry per test datapoint.

  Raises:
    OSError: If the file does not exist or cannot be read.
    json.JSONDecodeError: If the file does not contain valid JSON.
  """
  path = location + prefetch_loc + domain + '_' + str(test_count) + '_' + algo + '.json'

  # Plain read mode: 'r+' also requests write access, which made loading fail
  # on read-only files although nothing is ever written here. The explicit
  # encoding avoids depending on the platform default.
  with open(path, 'r', encoding='utf-8') as f:
    return json.load(f)

def get_contexts(prefetched, database_contexts, test_set_contexts):
  """Collect every context referenced by a prefetch result.

  Args:
    prefetched: Sequence with one entry per test datapoint; each entry is an
      iterable of units whose element at position 1 is an index into
      `database_contexts`.
    database_contexts: Indexable collection of database contexts.
    test_set_contexts: Indexable collection of test contexts, aligned by
      position with `prefetched`.

  Returns:
    A set holding the test context of every entry in `prefetched` together
    with all database contexts its candidates point to.
  """
  collected = set()

  for position, scored_candidates in enumerate(prefetched):
    # Resolve the candidate indices first, then record the query context
    # followed by each candidate context.
    candidate_ids = [entry[1] for entry in scored_candidates]
    collected.add(test_set_contexts[position])
    collected.update(database_contexts[idx] for idx in candidate_ids)

  return collected

In [5]:
# Smoke test: run the whole pipeline on one domain and one algorithm only.
# Disabled by default; flip the flag to True to execute it.

run_unit_test = False

if run_unit_test:
  # Database contexts are the keys of the loaded database map.
  database = get_database('ner')
  database_contexts = list(database.keys())

  # Test contexts are the first element of every test datapoint.
  test_set = get_test_set('ner')
  test_set_contexts = [record[0] for record in test_set]

  prefetched = load_prefetch('ner', 'BM25Plus')
  contexts_set = get_contexts(prefetched, database_contexts, test_set_contexts)

  # Note: this writes to the same output file as the full run.
  dump_all_contexts(list(contexts_set))

In [6]:
# Considering all domains and algorithms

# Accumulate in a set so that a context reached through several domains or
# algorithms is kept only once.
unique_contexts = set()

for domain in tqdm.tqdm(domains):
  # Database contexts are the keys of the database map, in file order.
  # NOTE(review): presumably the indices stored in the prefetch files refer
  # to positions in this key order — verify against the prefetch notebook.
  database = get_database(domain)
  database_contexts = list(database.keys())

  # Test contexts are the first element of every test datapoint.
  test_set = get_test_set(domain)
  test_set_contexts = [datapoint[0] for datapoint in test_set]

  for algo in algos:
    prefetched = load_prefetch(domain, algo)
    contexts_set = get_contexts(prefetched, database_contexts, test_set_contexts)

    # In-place update instead of rebuilding the whole set with union().
    unique_contexts.update(contexts_set)

# Sort before dumping: the iteration order of a set of strings changes from
# one interpreter run to the next, so list(set) would give every context a
# different position each time the file is regenerated.
# Assumes all contexts are strings (mutually comparable).
all_contexts = sorted(unique_contexts)
dump_all_contexts(all_contexts)

tqdm.tqdm.write('')

print('Number of contexts fetched: ' + str(len(all_contexts)))

100%|██████████| 4/4 [00:12<00:00,  3.15s/it]



Number of contexts fetched: 64765


In [7]:
# That's it