<a href="https://colab.research.google.com/github/anon/ILCiteR/blob/main/cite_reco_s2orc_mapping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Utility functions to create refined context to BIBREF mappings from PDF parses

In [2]:
import re
import json
import nltk
import tqdm
import string
import spacy
import warnings
from collections import defaultdict
warnings.filterwarnings('ignore')

# Initializing spaCy's text features extractor
proc = spacy.load('en_core_web_sm')

In [3]:
%%capture
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
location = 'drive/My Drive/cite_reco_s2orc/full/'
parses_loc = 'domain_parses/'
mappings_loc = 'mappings/'
split_locs = ['Database/', 'Eval/']

all_domains = ['ner', 'sa', 'summ', 'mt']
domain_codes = [all_domains[3]]

In [5]:
def normalize_text(text):
  tokens = nltk.word_tokenize(text.lower())
  return ' '.join(tokens).strip()

In [6]:
# Extraction Procedure:
# Create a list of section texts where cite span text has been replaced with the BIBREF#
# Tokenize texts into sentences
# If the sentence contains atleast one BIBREF:
#   Group consecutive BIBREFs to create BIBGROUPS
#   Fetch all mappings from coarse and fine contexts to specific BIBGROUPS (using functions in refine_context)
#   Add all mappings to a database

In [7]:
# Loading parses
domain = domain_codes[0]

def get_parses_from_file(domain, split_type):
  # split_type is either 0 for Database or 1 for Eval
  global location, parses_loc, split_locs

  with open(location + parses_loc + split_locs[split_type] + domain + '.json', 'r+') as f:
    parses = json.load(f)

  return parses

def dump_mappings(domain, split_type, mappings):
  # split_type is either 0 for Database or 1 for Eval
  global location, parses_loc, split_locs

  with open(location + mappings_loc + split_locs[split_type] + domain + '.json', 'w+') as f:
    json.dump(mappings, f)

  return

parses = get_parses_from_file(domain, 1)

In [8]:
# Creating the list of section texts

def get_section_texts(paper):
  section_texts = []

  for section in paper['body_text']:
    if not section['cite_spans'] or section['cite_spans'] == []:
      continue
    text = section['text']
    spans = sorted(section['cite_spans'], key = lambda unit: unit['start'], reverse = True)
    for cite_block in spans:
      text = text[: cite_block['start']] + cite_block['ref_id'] + '#' + text[cite_block['end'] :]
    section_texts.append(text)

  return section_texts

In [9]:
# Example paper parse
paper = parses[0]
section_texts = get_section_texts(paper)
print('Number of sections having extractable datapoints:', len(section_texts))

Number of sections having extractable datapoints: 16


In [10]:
def extract_refs(sentence):
  refs = [matched[6: -1] for matched in re.findall('BIBREF[0-9]+#', sentence)]
  return refs

In [11]:
# Unit test

sentence = 'Importance of effective contract enforcement BIBREF5# to economic performance BIBREF15# BIBREF21#.'
extract_refs(sentence)

['5', '15', '21']

In [12]:
# Creating mappings from sentences to containing BIBREFs

def get_sentence_entries(section_texts):
  raw_entries = []

  for section in section_texts:
    sentences = nltk.sent_tokenize(section)
    for sentence in sentences:
      if 'BIBREF' in sentence:
        refs = extract_refs(sentence)
        if refs == []:
          continue
        raw_entries.append(sentence)

  return raw_entries

raw_entries = get_sentence_entries(section_texts)

In [13]:
# Functions for Context Refinement

def tokens_to_str(tokens):
  return ' '.join(tokens).strip()

def group_refs(sentence):
  raw_tokens = nltk.word_tokenize(sentence)
  curr_group = -1
  in_group = False
  proc_toks = []
  l = len(raw_tokens)
  index = 0
  group_to_IDs = defaultdict(list)

  while index < l:
    token = raw_tokens[index]
    if token.startswith('BIBREF'):
      ID = int(token[6:])
      index += 1
      if in_group == False:
        curr_group += 1
        proc_toks.append('BIBGROUP' + str(curr_group))
      else:
        # Remove commas and other punctuations between BIBREFs
        while len(proc_toks) > 0 and proc_toks[-1].startswith('BIBGROUP') == False:
          proc_toks.pop()
      in_group = True
      group_to_IDs[curr_group].append(ID)
    else:
      proc_toks.append(token)
      if token not in string.punctuation:
        in_group = False
    index += 1

  if in_group:
    # Remove Punctuations at the end
    while len(proc_toks) > 0 and proc_toks[-1].startswith('BIBGROUP') == False:
      proc_toks.pop()

  return proc_toks, dict(group_to_IDs)

def drop_REF_tags(tokens):
  regular_tokens = []
  for token in tokens:
    if token.text.startswith('BIBGROUP'):
      continue
    regular_tokens.append(token.text)

  return tokens_to_str(regular_tokens)

def drop_particular_tags(tokens, drop_tags_set):
  processed = []
  for token in tokens:
    if token in drop_tags_set:
      continue
    processed.append(token)
  return processed

valid_prev_deps = set(['compound', 'amod'])

def make_fine_mappings(sentence):
  global valid_prev_deps
  extract = proc(sentence)
  mappings = defaultdict(list)
  all_groups = []
  last_alnum_token = None

  for index, node in enumerate(extract):
    if node.text.isalnum():
      last_alnum_token = node
    if node.text.startswith('BIBGROUP'):
      group_ID = int(node.text[8:])
      all_groups.append(group_ID)
      prev_index = index - 1
      context_nodes = []
      while prev_index > 0:
        next_children_set = set([child for child in extract[prev_index + 1].children])
        if extract[prev_index].dep_ in valid_prev_deps and extract[prev_index] in next_children_set:
          context_nodes.append(extract[prev_index])
          prev_index -= 1
        else:
          break
      if context_nodes != []:
        mappings[group_ID].append(tokens_to_str([node.text for node in context_nodes][:: -1]))

  accounted_IDs = set([key for key in mappings.keys()])
  all_groups_accounted = (len(all_groups) == len(accounted_IDs))
  has_eos_ref = last_alnum_token is not None and last_alnum_token.text.startswith('BIBGROUP')

  if not all_groups_accounted or has_eos_ref:
    # Split the sequence of sentence tokens into a sequence of refined contexts

    if has_eos_ref or len(all_groups) == 1:
      if has_eos_ref:
        use_ID = int(last_alnum_token.text[8:])
      else:
        # If the sentence only has one BIBGROUP
        use_ID = all_groups[0]
      clean_sent = drop_REF_tags(extract)
      mappings[use_ID].append(clean_sent)
      accounted_IDs.add(use_ID)

  add_context_IDs = set([ID for ID in all_groups if ID not in accounted_IDs])

  if len(add_context_IDs) != 0:
    tokens = drop_particular_tags([unit.text for unit in extract], accounted_IDs)

    running_tokens = []
    for token in tokens:
      if token.startswith('BIBGROUP'):
        ID = int(token[8:])
        if ID in add_context_IDs:
          if running_tokens != []:
            mappings[ID].append(tokens_to_str(running_tokens))
            running_tokens = []
      else:
        running_tokens.append(token)

  return dict(mappings)

In [14]:
def process_paper(paper):
  section_texts = get_section_texts(paper)
  sentences = get_sentence_entries(section_texts)
  paper_ID = paper['paper_id']
  datapoints = []

  for sentence in sentences:
    grouped_ref_tokens, group_to_refs = group_refs(sentence)
    grouped_ref_sent = tokens_to_str(grouped_ref_tokens)
    group_ID_to_contexts = make_fine_mappings(grouped_ref_sent)

    for group_ID in group_to_refs.keys():
      if group_ID not in group_ID_to_contexts.keys():
        continue
      contexts = group_ID_to_contexts[group_ID]
      for context in contexts:
        for ref_ID in group_to_refs[group_ID]:
          ref_key = 'BIBREF' + str(ref_ID)
          if ref_key not in paper['bib_entries']:
            continue
          if context and len(context) > 1:
            datapoints.append([paper_ID, context, ref_ID, paper['bib_entries'][ref_key]])

  return datapoints

In [15]:
# View an example datapoint
created_datapoints = process_paper(paper)
created_datapoints[15]

['214802675',
 'We use the transformer based encoder - decoder architecture by casting data - totext as a seq2seq problem , where the structured data is flattened into a plain string consisting of a series of intents and slot key - value pairs .',
 29,
 {'title': 'Attention is all you need',
  'authors': [{'first': 'Ashish',
    'middle': [],
    'last': 'Vaswani',
    'suffix': ''},
   {'first': 'Noam', 'middle': [], 'last': 'Shazeer', 'suffix': ''},
   {'first': 'Niki', 'middle': [], 'last': 'Parmar', 'suffix': ''},
   {'first': 'Jakob', 'middle': [], 'last': 'Uszkoreit', 'suffix': ''},
   {'first': 'Llion', 'middle': [], 'last': 'Jones', 'suffix': ''},
   {'first': 'Aidan', 'middle': ['N'], 'last': 'Gomez', 'suffix': ''},
   {'first': 'Łukasz', 'middle': [], 'last': 'Kaiser', 'suffix': ''},
   {'first': 'Illia', 'middle': [], 'last': 'Polosukhin', 'suffix': ''}],
  'year': 2017,
  'venue': 'Advances in neural information processing systems',
  'link': '13756489'}]

In [16]:
# Create context to citation mappings for each domain and split type

def create_mappings(domain, split_type):
  parses = get_parses_from_file(domain, split_type)
  mappings = []
  for paper in tqdm.tqdm(parses):
    datapoints = process_paper(paper)
    mappings += datapoints

  dump_mappings(domain, split_type, mappings)
  return mappings

In [17]:
# Test Run on a single domain and split type
# mappings = create_mappings('ner', 1)

In [18]:
# For domains in domain_codes

for domain in domain_codes:
  for split_type in range(2):
    print('Domain: ' + str(domain) + ', Split: ' + str(split_type), flush = True)
    mappings = create_mappings(domain, split_type)
    print('Size: ' + str(len(mappings)))

Domain: mt, Split: 0


100%|██████████| 9322/9322 [21:03<00:00,  7.38it/s]


Size: 161698
Domain: mt, Split: 1


100%|██████████| 200/200 [01:04<00:00,  3.10it/s]


Size: 8648


In [19]:
# That's it