<a href="https://colab.research.google.com/github/anon/ILCiteR/blob/main/refine_contexts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import json
import spacy
from spacy import displacy
import nltk
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display
import warnings
from collections import defaultdict
warnings.filterwarnings('ignore')

In [2]:
%%capture
# Download NLTK's 'punkt' resource; %%capture (which must stay the first line of the cell) suppresses the cell's output.
# NOTE(review): presumably this is the data nltk.word_tokenize needs further down; newer NLTK releases
# reportedly look for 'punkt_tab' instead - confirm against the installed NLTK version.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
# Initializing the processor: load the spaCy pipeline named 'en_core_web_sm'.
# `proc` is called on sentences further down to obtain parsed tokens.
proc = spacy.load('en_core_web_sm')

In [4]:
# Base folder of the data (a relative path, resolved against the current working directory).
location = 'drive/My Drive/cite_reco_s2orc/full/'
# Sub-folder of `location`; files are read from it as location + coarse_maps_loc + <code> + '.json'.
coarse_maps_loc = 'domain_coarse_mappings/'

# Domain codes used to build the JSON file names above.
# NOTE(review): presumably mt / ner / sa / summ stand for machine translation, named-entity recognition,
# sentiment analysis and summarization - confirm.
domain_codes = ['mt', 'ner', 'sa', 'summ']

In [5]:
# Load the coarse mappings of the first domain code ('mt').
code = domain_codes[0]
# The file is only read, so open it read-only ('r'); 'r+' would needlessly require write permission.
with open(location + coarse_maps_loc + code + '.json', 'r') as f:
  coarse = json.load(f)

In [6]:
# Sample sentence in the raw citation format: every citation is written as 'BIBREF<number>#'.
example = 'Neural machine translation (NMT), a new approach to solving machine translation, has achieved promising results BIBREF6# BIBREF15# BIBREF0# BIBREF5# BIBREF10# BIBREF11#'

In [7]:
def group_refs(sentence, tokenize = None):
  """Collapse each run of consecutive 'BIBREF<n>' citations into one 'BIBGROUP<k>' token.

  The sentence is tokenized, then scanned left to right. The first BIBREF of a
  run emits a 'BIBGROUP<k>' token (k counts runs from 0); every BIBREF in the
  run contributes its numeric ID to group k. A '#' token directly after a
  BIBREF is treated as the citation terminator and dropped. Any other token
  ends the current run and is copied to the output unchanged.

  Args:
    sentence: raw sentence string containing 'BIBREF<n>#' style citations.
    tokenize: optional callable mapping a string to a list of token strings.
      Defaults to nltk.word_tokenize.

  Returns:
    (tokens, groups): the processed token list, and a dict mapping each group
    index to the list of citation IDs (ints) it replaced, in sentence order.

  Raises:
    ValueError: if a token starts with 'BIBREF' but the remainder is not an integer.
  """
  if tokenize is None:
    tokenize = nltk.word_tokenize
  raw_tokens = tokenize(sentence)

  proc_toks = []
  group_to_IDs = defaultdict(list)
  curr_group = -1
  in_group = False
  total = len(raw_tokens)
  index = 0

  while index < total:
    token = raw_tokens[index]
    if token.startswith('BIBREF'):
      ref_ID = int(token[6:])
      if not in_group:
        # First citation of a new run: open a group and emit its placeholder
        curr_group += 1
        proc_toks.append('BIBGROUP' + str(curr_group))
        in_group = True
      group_to_IDs[curr_group].append(ref_ID)
      # Swallow only the '#' terminator. Skipping the next token unconditionally
      # would lose a real word (or another BIBREF) whenever the '#' is absent.
      if index + 1 < total and raw_tokens[index + 1] == '#':
        index += 1
    else:
      proc_toks.append(token)
      in_group = False
    index += 1

  return proc_toks, dict(group_to_IDs)

def tokens_to_str(tokens):
  """Join string tokens with single spaces and trim whitespace at both ends of the result."""
  joined = ' '.join(tokens)
  return joined.strip()

In [8]:
# Test sentences in the grouped format, i.e. with 'BIBGROUP<k>' markers.
# example_A is produced by running group_refs on the raw `example` sentence;
# example_B .. example_H are written by hand, directly in the grouped format.
example_A = tokens_to_str(group_refs(example)[0])
example_B = 'Typical LLMs include BERT BIBGROUP0 , BART BIBGROUP1 , GPT BIBGROUP2 .'
example_C = 'They used BERT BIBGROUP1, a popular Large Language Model BIBGROUP0, to generate context embeddings BIBGROUP2 .'
example_D = 'Context embeddings were generated using Sentence Transformers BIBGROUP0 .'
example_E = 'A large majority of work in the past few years has been focused on extractive summarization BIBGROUP0 , where a summary consists of key sentences from the source text .'
example_F = 'Therefore , we employ the primary encoder to generate coarse encoding using a GRU-based RNN BIBGROUP0 .'
example_G = 'They used a IEX parser BIBGROUP0 to encode the data streams .'
example_H = 'Traditionally , Recurrent Neural Networks (RNN) were used in encoders and decoders BIBGROUP3 , but other neural network architectures such as Convolutional Neural Networks (CNN) BIBGROUP0 and attention mechanism-based models BIBGROUP1 are also used.'

In [9]:
# Generating the dependency tree for the sentence:
# parse example_C with `proc` and render the parse inline with displaCy (style 'dep').

extract = proc(example_C)
displacy.render(extract, style = 'dep', jupyter = True, options = {'distance': 120})

In [10]:
# Refined contexts for picking out entities.
# Walk backwards from each citation marker, collecting tokens for as long as the previous token is a child of the token that follows it and is attached with the 'compound' or 'amod' dependency label.

def drop_REF_tags(tokens):
  """Return the sentence text with every citation marker removed.

  tokens: iterable of objects exposing a `.text` string.
  Tokens whose text starts with 'BIBGROUP' are left out; the remaining texts
  are joined into one string via tokens_to_str.
  """
  kept_texts = [tok.text for tok in tokens if not tok.text.startswith('BIBGROUP')]
  return tokens_to_str(kept_texts)

def drop_particular_tags(tokens, drop_tags_set):
  """Return a new list holding the items of `tokens` that are not in `drop_tags_set`.

  Order is preserved and the input is not modified.
  """
  return [tok for tok in tokens if tok not in drop_tags_set]

# Dependency labels that allow a token to be absorbed into the phrase in front of a citation marker.
valid_prev_deps = {'compound', 'amod'}

def make_fine_mappings(sentence, display = False):
  """Map every 'BIBGROUP<k>' marker in `sentence` to its refined context(s).

  Three rules are applied in order:
    1. Phrase rule - walk backwards from the marker, collecting tokens while
       each one has a dependency label in `valid_prev_deps` and is a child of
       the token right after it. The collected phrase becomes a context.
    2. End-of-sentence rule - if the last alphanumeric token of the sentence
       is a marker, the whole sentence with all markers removed is added as a
       context for that marker.
    3. Fallback - a marker that received nothing from rules 1-2 gets the plain
       text accumulated since the previous such marker. Nothing is stored if
       no text has accumulated.

  Args:
    sentence: string in grouped format (contains 'BIBGROUP<k>' tokens).
    display: when True, print diagnostics. These are printed after rule 1 and
      before rules 2-3 run. (Inside this function the parameter shadows
      IPython's `display`.)

  Returns:
    dict mapping group ID (int) to a list of context strings. Groups for which
    no rule produced a context are absent.

  Raises:
    ValueError: if a token starts with 'BIBGROUP' but the remainder is not an integer.
  """
  extract = proc(sentence)
  mappings = defaultdict(list)
  all_groups = []
  last_alnum_token = None

  for index, node in enumerate(extract):
    if node.text.isalnum():
      last_alnum_token = node
    if not node.text.startswith('BIBGROUP'):
      continue

    group_ID = int(node.text[8:])
    all_groups.append(group_ID)

    context_nodes = []
    prev_index = index - 1
    # '>= 0' (not '> 0') so the first token of the sentence can join the phrase
    while prev_index >= 0:
      candidate = extract[prev_index]
      if candidate.dep_ in valid_prev_deps and candidate in set(extract[prev_index + 1].children):
        context_nodes.append(candidate)
        prev_index -= 1
      else:
        break
    if context_nodes:
      # Nodes were gathered right-to-left; restore reading order
      mappings[group_ID].append(tokens_to_str([tok.text for tok in reversed(context_nodes)]))

  accounted_IDs = set(mappings)
  all_groups_accounted = (len(all_groups) == len(accounted_IDs))
  has_eos_ref = last_alnum_token is not None and last_alnum_token.text.startswith('BIBGROUP')

  if display:
    print('Group ID to Refined Contexts: ' + str(dict(mappings)))
    print('All Groups IDs: ' + str(all_groups))
    print('all_groups_accounted: ' + str(all_groups_accounted))
    print('Has REF at End Of Sentence : ' + str(has_eos_ref))

  if has_eos_ref:
    # The sentence ends on a citation: the entire marker-free sentence is its context
    last_ID = int(last_alnum_token.text[8:])
    mappings[last_ID].append(drop_REF_tags(extract))
    accounted_IDs.add(last_ID)

  add_context_IDs = set(ID for ID in all_groups if ID not in accounted_IDs)

  if add_context_IDs:
    # Remove markers that already have a context. The helper compares against
    # token text, so it needs the tag strings rather than the integer IDs.
    accounted_tags = {'BIBGROUP' + str(ID) for ID in accounted_IDs}
    tokens = drop_particular_tags([unit.text for unit in extract], accounted_tags)

    running_tokens = []
    for token in tokens:
      if token.startswith('BIBGROUP'):
        ID = int(token[8:])
        if ID in add_context_IDs and running_tokens:
          # Append (not assign) so every value in the result is a list of strings
          mappings[ID].append(tokens_to_str(running_tokens))
          running_tokens = []
      else:
        running_tokens.append(token)

  return dict(mappings)

In [11]:
# Usage: with display = True the function prints its own diagnostics first;
# the returned mapping is then printed here.
gen_mappings = make_fine_mappings(example_C, display = True)
print('Generated Mappings:')
print(gen_mappings)

Group ID to Refined Contexts: {1: ['BERT'], 0: ['Large Language Model']}
All Groups IDs: [1, 0, 2]
all_groups_accounted: False
Has REF at End Of Sentence : True
Generated Mappings:
{1: ['BERT'], 0: ['Large Language Model'], 2: ['They used BERT , a popular Large Language Model , to generate context embeddings .']}


In [12]:
# That's it