In [2]:
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

In [3]:
class Parsing:
  """Arc-standard transition system for unlabeled dependency parsing.

  A configuration is ``(stack, buffer, arcs)``: ``stack`` and ``buffer`` are
  lists of word positions (1-based, with 0 standing for the artificial ROOT),
  and ``arcs`` is a list of ``(head, dependent)`` tuples. Transitions are
  identified by their index into ``self.transitions``:
  0 = shift, 1 = left-arc, 2 = right-arc.
  """

  def __init__(self):
    self.transitions = ['shift', 'left-arc', 'right-arc']

  def init_config(self, sentence):
    """Return the initial configuration ``(stack, buffer, arcs)`` for ``sentence``.

    The stack holds only ROOT (0), the buffer holds every word position
    ``1..len(sentence)``, and no arcs have been built yet.
    """
    stack = [0]
    buffer = list(range(1, len(sentence) + 1))
    arcs = []  # (head, dependent) tuples, appended by apply_moves
    return stack, buffer, arcs

  def moves(self, stack, buffer, single_root=False):
    """Return ``[shift_ok, left_arc_ok, right_arc_ok]`` for a configuration.

    - shift needs a non-empty buffer;
    - left-arc needs at least two stack items and must not make ROOT (0)
      a dependent;
    - right-arc needs at least two stack items. With ``single_root=True`` a
      word may be attached to ROOT only once the buffer is empty, so ROOT
      receives at most one dependent. The default (False) leaves right-arc
      unrestricted.
    """
    can_shift = len(buffer) > 0
    can_reduce = len(stack) >= 2
    can_left = can_reduce and stack[-2] != 0
    can_right = can_reduce
    if single_root and can_right and stack[-2] == 0:
      # Attaching to ROOT early would let several words become root children.
      can_right = len(buffer) == 0
    return [can_shift, can_left, can_right]

  def apply_moves(self, stack, buffer, arcs, act):
    """Apply transition ``act`` to the configuration in place.

    ``stack``, ``buffer`` and ``arcs`` are mutated and also returned for
    convenience. Legality is not checked here (consult ``moves`` first); an
    illegal move on a too-short stack or empty buffer raises ``IndexError``.

    Raises:
      ValueError: if ``act`` is not 0 (shift), 1 (left-arc) or 2 (right-arc).
        An unknown action is rejected rather than ignored, since ignoring it
        would leave a parsing loop stuck on an unchanged configuration.
    """
    if act == 0:  # shift: move the front of the buffer onto the stack
      stack.append(buffer.pop(0))
    elif act == 1:  # left-arc: second-from-top becomes a dependent of the top
      dependent = stack.pop(-2)
      arcs.append((stack[-1], dependent))
    elif act == 2:  # right-arc: top becomes a dependent of the item below it
      dependent = stack.pop(-1)
      arcs.append((stack[-1], dependent))
    else:
      raise ValueError(
          f"unknown action {act!r}; expected 0 (shift), 1 (left-arc) or 2 (right-arc)")
    return stack, buffer, arcs

In [4]:
class ParsingModel(nn.Module):
  """Feed-forward transition classifier in the style of Chen & Manning (2014).

  Each parser configuration is represented by ``num_features`` integer ids.
  By default there are 3, matching the top two stack items and the first
  buffer item. The ids are embedded and concatenated, then passed through
  one hidden layer with a cube activation. A final linear layer maps the
  result to one logit per transition.

  Args:
    vocab_size: rows in the embedding table; every feature id must lie in
      ``[0, vocab_size)``.
    embed_dim: size of each feature embedding.
    hidden_dim: width of the hidden layer.
    num_classes: number of transitions scored (default 3).
    num_features: number of feature ids per configuration (default 3).
  """

  def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes = 3, num_features = 3):
    super().__init__()
    self.num_features = num_features
    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.input_size = embed_dim * num_features  # concatenated feature embeddings
    self.hidden = nn.Linear(self.input_size, hidden_dim)
    self.output = nn.Linear(hidden_dim, num_classes)

  def activation(self, x):
    """Cube non-linearity ``x ** 3`` from Chen & Manning (2014).

    Defined as a method rather than a lambda attribute so the whole module
    stays picklable (e.g. ``torch.save(model)``).
    """
    return torch.pow(x, 3)

  def forward(self, feature_indices):
    """Score every transition for a batch of configurations.

    Args:
      feature_indices: LongTensor of shape ``(batch, num_features)``, or a
        single unbatched configuration of shape ``(num_features,)``.

    Returns:
      Logits of shape ``(batch, num_classes)``; unbatched input gives batch 1.
    """
    if feature_indices.dim() == 1:
      # Single configuration: add the batch axis the linear layers expect.
      feature_indices = feature_indices.unsqueeze(0)
    embeds = self.embedding(feature_indices)       # (batch, num_features, embed_dim)
    embeds_flatten = embeds.flatten(start_dim=1)   # (batch, num_features * embed_dim)
    x = self.activation(self.hidden(embeds_flatten))  # hidden transform + cube activation
    logits = self.output(x)
    return logits

In [5]:
def feature_extract(stack, buffer, word_ids=None, null_idx=0):
  """Build the ``[[s1, s2, b1]]`` feature tensor for one configuration.

  ``s1``/``s2`` are the top and second-from-top stack items and ``b1`` is the
  front of the buffer; any missing slot is filled with ``null_idx``.

  By default the stack/buffer entries (sentence positions, 0 = ROOT) are
  returned as-is, which reproduces the original behaviour. Feeding raw
  positions into an ``nn.Embedding`` sized by vocabulary embeds *positions*,
  not words. Pass ``word_ids`` (a sequence indexed by position, with
  ``word_ids[0]`` the ROOT token id) to emit vocabulary ids instead. When
  doing so, choose a ``null_idx`` distinct from every real id so "empty slot"
  and ROOT get different embeddings.

  Returns:
    LongTensor of shape ``(1, 3)``.
  """
  def lookup(pos):
    # Map a sentence position to the id fed to the embedding layer.
    return pos if word_ids is None else word_ids[pos]

  s1 = lookup(stack[-1]) if len(stack) > 0 else null_idx
  s2 = lookup(stack[-2]) if len(stack) > 1 else null_idx
  b1 = lookup(buffer[0]) if len(buffer) > 0 else null_idx

  return torch.tensor([[s1, s2, b1]], dtype = torch.long)