#### Greedy (arc-standard) Transition Dependency Parser

We now use the trained neural oracle to perform arc-standard dependency parsing. Given a sentence, we initialize a buffer containing the sentence's words and punctuation symbols, a stack containing only `ROOT`, and an empty list of dependency relations. Starting from this initial state, we repeatedly apply the action chosen by the oracle and update the parser state. When the terminal state is reached (i.e., the buffer is empty and the stack contains only `ROOT`), the dependency relations list holds the complete parse.
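To make the state updates concrete, here is a minimal sketch that runs the three arc-standard actions on a toy sentence. It assumes a hand-written action sequence in place of the neural oracle. The helper name `run_arc_standard`, the example sentence, and the labels are illustrative and not part of the notebook's code. Arcs are stored as `(head, dependent, label)` triples, matching the convention used by the parser below.

```python
def run_arc_standard(words, actions):
    """Apply a scripted list of (action, label) pairs to the initial state."""
    stack = [("ROOT", 0)]
    buffer = [(word, i + 1) for i, word in enumerate(words)]
    arcs = []  # (head, dependent, label) triples

    for action, label in actions:
        if action == "SHIFT":
            # move the first buffer item onto the stack
            stack.append(buffer.pop(0))
        elif action == "LEFTARC":
            # the top of the stack becomes the head of the item beneath it
            dependent = stack.pop(-2)
            arcs.append((stack[-1], dependent, label))
        elif action == "RIGHTARC":
            # the second item on the stack becomes the head of the top item
            dependent = stack.pop(-1)
            arcs.append((stack[-1], dependent, label))
        else:
            raise ValueError(f"Unknown action: {action}")
    return stack, buffer, arcs


# Hypothetical example: a gold action sequence written by hand.
words = ["She", "reads", "books"]
actions = [
    ("SHIFT", None),
    ("SHIFT", None),
    ("LEFTARC", "nsubj"),   # reads -> She
    ("SHIFT", None),
    ("RIGHTARC", "obj"),    # reads -> books
    ("RIGHTARC", "root"),   # ROOT -> reads
]
stack, buffer, arcs = run_arc_standard(words, actions)
print(stack, buffer)
for head, dependent, label in arcs:
    print(f"{head[0]} --{label}--> {dependent[0]}")
```

After the last action the buffer is empty and the stack holds only `ROOT`, which is the terminal state described above.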

In [6]:
import os

import torch
import wandb

from parse_utils import *

wandb.login()
print(torch.cuda.is_available())

%load_ext autoreload
%autoreload 2

True
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


First, let's load the validation data.

In [7]:
data_val = read_conllu(os.path.join('data', 'dev.conll'))
print(f"Number of sentences in the validation data: {len(data_val)}")

Number of sentences in the validation data: 1700


Now let's implement the greedy arc-standard transition parser.

In [None]:
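def parser(sentence_words, oracle):
    """Greedily parse a list of words with the arc-standard transition system,
    applying the oracle's highest-scoring action at every step."""
    # pair each word with its 1-based position; position 0 is reserved for ROOT
    tokens = [(word, i + 1) for i, word in enumerate(sentence_words)]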

    # initialize the state
    stack = [('ROOT', 0)]
    buffer = tokens.copy()
    arcs = []

    # convert sentence to model input tensors
    input_idx, input_attn_mask, word_ids = dataset.tokenize_sentence(sentence_words)
    # set model to eval mode
    oracle.eval()
    # compute BERT encoding of sentence tokens
    with torch.no_grad():
        bert_output = oracle.get_bert_encoding(input_idx, input_attn_mask)

    idx2label = {v:k for k,v in dataset.label2idx.items()}    

    # begin parsing
    while len(buffer) > 0 or len(stack) > 1:

        # the state is the top two stack items plus the next buffer item (if any)
        next_token = buffer[0] if buffer else None
        state = [(stack[-2:], next_token)]
        state_idx = dataset.tokenize_state(state, word_ids)

        # get the oracle action and label scores
        action_logits, label_logits = oracle.predict(bert_output, state_idx)
        
        # pick highest scoring action and label
        best_action = torch.argmax(action_logits[0][0]).item()
        # convert the index tensor to a plain int before the dictionary lookup
        best_label = idx2label[torch.argmax(label_logits[0][0]).item()]

        # perform the action
        if best_action == 0:
            # LEFTARC
            if len(stack) > 1:
                if stack[-2][0] != 'ROOT':
                    arcs.append((stack[-1], stack[-2], best_label))
                    stack.pop(-2)
                else:
                    raise ValueError("Cannot perform LEFTARC action with ROOT as dependent. Parse failed.")    
            else:
                raise ValueError("Cannot perform LEFTARC action with stack length <= 1. Parse failed.")

        elif best_action == 1:
            # RIGHTARC
            if len(stack) > 1:
                arcs.append((stack[-2], stack[-1], best_label))
                stack.pop(-1) 
            else:
                raise ValueError("Cannot perform RIGHTARC action with stack length <= 1. Parse failed.")         

        else:
            # SHIFT
            if len(buffer) > 0:
                stack.append(buffer.pop(0))
            else:
                raise ValueError("Cannot perform SHIFT action with buffer length <= 0. Parse failed.")               