<a href="https://colab.research.google.com/github/anon/ILCiteR/blob/main/create_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import json
import nltk
import random
import string
import tqdm
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

random.seed('1')

In [2]:
%%capture
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
location = 'drive/My Drive/cite_reco_s2orc/full/'
maps_loc = 'maps/'
map_types = ['Database/', 'Eval/']
test_loc = 'Test/'

domains = ['ner', 'sa', 'summ', 'mt']

In [4]:
test_count = 200

def get_mappings_from_file(domain, map_type = 0):
  # map_type: 0 for 'Database', 1 for 'Eval'
  global location, maps_loc, map_types

  with open(location + maps_loc + map_types[map_type] + domain + '.json', 'r+') as f:
    map = json.load(f)

  return map

def dump_test_set(domain, test_set):
  global location, maps_loc, text_loc, test_count

  with open(location + maps_loc + test_loc + domain + '_' + str(test_count) + '.json', 'w+') as f:
    json.dump(test_set, f)

  return

In [5]:
run_unit_test = False

# Unit test using 'ner' domain
if run_unit_test:
  database = get_mappings_from_file('ner', 0)
  eval = get_mappings_from_file('ner', 1)

In [6]:
# Validate contexts
with open('./stopwords.txt', 'r') as f:
  stopwords = f.read().split('\n')

bad_tokens = []
for code in range(128):
  bad_tokens.append(chr(code))

for punct in string.punctuation:
  bad_tokens.append(punct)

bad_tokens += stopwords
bad_tokens = set(bad_tokens)

def validate_context(text):
  # If all tokens within context are in bad_tokens (stopword OR punctuation or random characters)
  # Then Invalidate
  global bad_tokens

  tokens = nltk.word_tokenize(text.lower())
  for token in tokens:
    if token not in bad_tokens:
      return True

  return False

In [7]:
def construct_ref_papers_list(database):
  all_papers = []

  for context in tqdm.tqdm(database.keys()):
    papers = database[context]
    for paper in papers:
      paper_dict = paper[0]
      support = paper[1]

      found_at = -1
      for index, candidate in all_papers:
        if candidate == paper_dict:
          found_at = index
          break

      if found_at != -1:
        all_papers[found_at][1] += support
      else:
        all_papers.append([paper_dict, support])

  tqdm.tqdm._instances.clear()
  return all_papers

if run_unit_test:
  database_papers = construct_ref_papers_list(database)

In [8]:
def get_paper_list_doability_score(paper_list):
  score = 0
  global database_papers

  for candi_paper in paper_list:
    for exists_paper in database_papers:
      if exists_paper[0] == candi_paper[0]:
        score = max(score, exists_paper[1])
        break

  return score

def get_eval_point_scores():
  global eval
  eval_datapoints_with_scores = []
  for context in tqdm.tqdm(list(eval.keys())):
    paper_list = eval[context]
    if not validate_context(context):
      continue
    score = get_paper_list_doability_score(paper_list)
    eval_datapoints_with_scores.append([context, paper_list, score])

  tqdm.tqdm._instances.clear()
  return eval_datapoints_with_scores

def make_test_set():
  global test_count
  eval_datapoints_with_scores = get_eval_point_scores()
  sorted_datapoints = sorted(eval_datapoints_with_scores, key = lambda item: item[2], reverse = True)
  return [item[0 : 2] for item in sorted_datapoints[0 : test_count]]

In [9]:
# Unit Test

if run_unit_test:
  test_set = make_test_set()
  print('', flush = True)
  print('test_set[0]:', flush = True)
  # Highest doability score
  print(test_set[0], flush = True)
  dump_test_set('ner', test_set)

In [10]:
# Create test sets for each domain

run_domains = ['mt']

for domain in run_domains:
  print('Domain: ' + str(domain), flush = True)
  database = get_mappings_from_file(domain, 0)
  eval = get_mappings_from_file(domain, 1)

  print('- Creating list of papers within the Database', flush = True)
  database_papers = construct_ref_papers_list(database)

  print('', flush = True)
  print('- Fetching doability scores for contexts in the Evaluation set', end = '', flush = True)
  test_set = make_test_set()
  random.shuffle(test_set)
  dump_test_set(domain, test_set)
  print('', flush = True)

Domain: mt
- Creating list of papers within the Database


100%|██████████| 108692/108692 [45:28<00:00, 39.83it/s]


- Fetching doability scores for contexts in the Evaluation set


100%|██████████| 6139/6139 [08:34<00:00, 11.94it/s]







In [11]:
# That's it