In [None]:
from functools import partial
from datasets import load_dataset
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [82]:
class ArcEager:
    def __init__(self, sentence):

        # sentence is the input for which we want to build our Arc-Standard
        self.sentence = sentence

        # here we create the buffer having an array of indexes with the same length as the sentence
        # basically, each word has its own index in this buffer
        # we have initialized the buffer having all the words in the sentence

        self.buffer = [i for i in range(len(self.sentence))] 

        # initialize the stack empty 
        
        self.stack = []

        # representation of the tree
        # every word will have a -1 assigned -> no father has been assigned yet

        self.arcs = [-1 for _ in range(len(self.sentence))]

        # three shift moves to initialize the stack

        # means that in the stack now is the ROOT
        # self.shift() it calls a method that implements this operation; we will look at it after 

        self.shift() 


    def shift(self):
       
     b1 = self.buffer[0]
     self.buffer = self.buffer[1:]
     self.stack.append(b1)
    
    def left_arc(self): 

     o1 = self.stack.pop()
     o2 = self.buffer[0]
     self.arcs[o1] = o2
 
    def right_arc(self):
     o1 = self.buffer[0]
     o2 = self.stack[-1]
     self.arcs[o1] = o2
     self.stack.append(o1)
     self.buffer = self.buffer[1:]

    def reduce(self):
      self.stack.pop()

    
    def is_tree_final(self):
     return len(self.stack) == 1 and len(self.buffer) == 0
    

    def print_configuration(self):

      s = [self.sentence[i] for i in self.stack]
      b = [self.sentence[i] for i in self.buffer]
      print(s,b)
      print(self.arcs)
      

In [83]:
sentence = ["<ROOT>", "He","wrote","her","a","letter","."]
gold = [-1, 2, 0, 2, 5, 2, 2]

parser = ArcEager(sentence)
parser.print_configuration()

['<ROOT>'] ['He', 'wrote', 'her', 'a', 'letter', '.']
[-1, -1, -1, -1, -1, -1, -1]


In [12]:
class Oracle:
 
 def __init__(self, parser, gold_tree):
  self.parser = parser
  self.gold = gold_tree

 
 def has_head(self, top_stack):
  if self.parser.arcs[top_stack] != -1:
   return True
  else:
   return False

 def has_all_children(self, top_stack):
  for i, arc in enumerate(self.gold):
   if arc == top_stack and self.parser.arcs[i] != top_stack:
    return False
  return True
 
 def is_right_arc_gold(self):
  if len(self.parser.buffer)==0:
   return False
  o1 = self.parser.buffer[0]
  o2 = self.parser.stack[len(self.parser.stack)-1]
  if self.gold[o1] != o2:
   return False
  
  return True
   
 def is_shift_gold(self):
  if len(self.parser.buffer) == 0:
   return False
  

 


 
 
 
 def is_left_arc_gold(self):
  
  # we can do the left arc if sigma2 is the children of the sigma1
  # this means that u have assigned to sigma2 all of its children and is ok to assing its parent -> static oracle
  
  # here we get our sigma1 and sigma2

  o1 = self.parser.stack[len(self.parser.stack)-1]
  o2 = self.parser.stack[len(self.parser.stack)-2]


  if self.gold[o2] == o1:
   return True
  return False

 def is_shift_gold(self):
 
  if len(self.parser.buffer) == 0:
   return False
 
  if (self.is_left_arc_gold() or self.is_right_arc_gold()):
   return False
 
  return True

 def is_right_arc_gold(self):
   o1 = self.parser.stack[len(self.parser.stack)-1]
   o2 = self.parser.stack[len(self.parser.stack)-2]

   if self.gold[o1] != o2:
    return False
   
   # we need to check that no children of sigma1 are in the rest of the buffer
   
   for i in self.parser.buffer:
    if self.gold[i] == o1:
     return False
   return True

In [13]:
sentence = ["<ROOT>", "He","began","to","write","again","."]
gold = [-1, 2, 0, 4, 2, 4, 2]

parser = ArcEager(sentence)
oracle = Oracle(parser, gold)

parser.print_configuration()

['<ROOT>', 'He', 'began'] ['to', 'write', 'again', '.']
[-1, -1, -1, -1, -1, -1, -1]


In [14]:
print("Left Arc: " ,oracle.is_left_arc_gold())
print("Right Arc: ",oracle.is_right_arc_gold())
print("Shift: ",oracle.is_shift_gold())


Left Arc:  True
Right Arc:  False
Shift:  False


In [15]:
# oracle tells us that hte next move is the left_arc wqe do it and ask again the Oracle
parser.left_arc()
parser.print_configuration()

['<ROOT>', 'to'] ['write', 'again', '.']
[-1, -1, 3, -1, -1, -1, -1]


In [16]:
print("Left Arc: " ,oracle.is_left_arc_gold())
print("Right Arc: ",oracle.is_right_arc_gold())
print("Shift: ",oracle.is_shift_gold())

Left Arc:  False
Right Arc:  False
Shift:  True


In [17]:
while not parser.is_tree_final():
    if oracle.is_shift_gold():
        parser.shift()
    elif oracle.is_left_arc_gold():
        parser.left_arc()
    elif oracle.is_right_arc_gold():
        parser.right_arc()
        
print(parser.arcs)
print(gold)
    

KeyboardInterrupt: 