#### Greedy (arc-standard) Transition Dependency Parser

We will now use the trained neural oracle to perform (arc-standard) dependency parsing. Given a sentence, we initialize a buffer containing the words and punctuation symbols of the sentence, a stack containing only `ROOT`, and an empty list of dependency relations (arcs). Starting from this initial state, we perform parse steps by applying the actions chosen by the oracle and updating the system state accordingly. When the terminal state is reached (i.e. the buffer is empty and the stack contains only `ROOT`), the complete dependency parse is contained in the dependency relations list.  

In [37]:
# Project helpers are star-imported; later cells use names that presumably come
# from parse_utils (e.g. torch, os, psutil, read_conllu, BERT_ORACLE,
# training_oracle, load_model_checkpoint) — TODO confirm.
from parse_utils import *
import wandb
import pickle 

# Log in to Weights & Biases.
wandb.login()
# Report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

# IPython autoreload: with mode 2, imported modules are reloaded before each
# cell executes, so edits to parse_utils are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

True
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


First, let's load the validation set data.

In [2]:
# Read the CoNLL-formatted development split from the data directory.
dev_path = os.path.join('data', 'dev.conll')
data_val = read_conllu(dev_path)
print(f"Number of sentences in the validation data: {len(data_val)}")

Number of sentences in the validation data: 1700


In [30]:
# Restore the pickled PyTorch validation dataset object from disk.
# NOTE: unpickling can execute arbitrary code — only load files you trust.
with open('val_dataset_pytorch.pkl', 'rb') as pkl_file:
    val_dataset = pickle.load(pkl_file)

# Reuse the dataset's action and label vocabularies for building the oracle.
action2idx, label2idx = val_dataset.action2idx, val_dataset.label2idx

Now, load the trained oracle model.

In [8]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA (previously hard-coded "cuda").
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The optimizer is only built so load_model_checkpoint can restore its state
# alongside the model weights.
learning_rate = 1e-5
model = BERT_ORACLE(num_actions=len(action2idx), num_labels=len(label2idx), unlabeled_arcs=False).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Loaded model from checkpoint!
Total number of parameters in transformer network: 66.959019 M
RAM used: 2707.07 MB


In [12]:
# Full validation pass over the dev set; disabled by default.
# Set RUN_VALIDATION = True to recompute the metrics.
RUN_VALIDATION = False
if RUN_VALIDATION:
    val_dataloader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)
    val_metrics = validation(model, val_dataloader, device=DEVICE)
    # Print explicitly: Jupyter does not auto-display values computed inside an `if`.
    print(val_metrics)

(0.2752455174922943, 0.9773752207663278, 0.9538105168222937)

Now let's implement the greedy arc-standard transition parser. This implementation is very strict: the parse fails (raising a `ValueError`) whenever the oracle predicts an action that is invalid in the current state.

In [50]:
def _apply_transition(stack, buffer, arcs, action, label):
    """Apply one arc-standard transition to the parser state, in place.

    Args:
        stack: list of (word, index) tuples; the top of the stack is stack[-1].
            ROOT is the tuple with index 0.
        buffer: list of (word, index) tuples not yet read; the front is buffer[0].
        arcs: list of (head, dependent, label) triples; extended in place.
        action: 'LEFTARC', 'RIGHTARC', or any other value (treated as SHIFT).
        label: dependency label attached to a newly created arc.

    Raises:
        ValueError: if the action is not permitted in the current state.
    """
    if action == 'LEFTARC':
        # stack[-1] becomes the head of stack[-2], which is removed from the stack.
        if len(stack) <= 1:
            raise ValueError("Cannot perform LEFTARC action with stack length <= 1. Parse failed.")
        # Identify ROOT by its reserved index 0 rather than by its surface string,
        # so a real token spelled "ROOT" can still be attached as a dependent.
        if stack[-2][1] == 0:
            raise ValueError("Cannot perform LEFTARC action with ROOT as dependent. Parse failed.")
        arcs.append((stack[-1], stack[-2], label))
        stack.pop(-2)

    elif action == 'RIGHTARC':
        # stack[-2] becomes the head of stack[-1], which is removed from the stack.
        if len(stack) <= 1:
            raise ValueError("Cannot perform RIGHTARC action with stack length <= 1. Parse failed.")
        arcs.append((stack[-2], stack[-1], label))
        stack.pop(-1)

    else:
        # SHIFT: move the front of the buffer onto the stack.
        if not buffer:
            raise ValueError("Cannot perform SHIFT action with buffer length <= 0. Parse failed.")
        stack.append(buffer.pop(0))


def parser(sentence_words, oracle, dataset, verbose=False, device=None):
    """Greedy arc-standard transition parser.

    Starts from stack=[ROOT], buffer=all sentence tokens and no arcs. At each step
    it applies the oracle's highest-scoring action and label. It stops when the
    buffer is empty and only ROOT remains on the stack.

    Args:
        sentence_words: list of word/punctuation strings of the sentence.
        oracle: model exposing ``eval()``, ``get_bert_encoding(input_idx, attn_mask)``
            and ``predict(encoding, state_idx) -> (action_logits, label_logits)``.
        dataset: object exposing ``tokenize_sentence``, ``tokenize_state``,
            ``action2idx`` and ``label2idx``.
        verbose: if True, print the parser state after every transition.
        device: device for the input tensors; when None, the notebook-level
            ``DEVICE`` is used.

    Returns:
        List of ``(head, dependent, label)`` triples. Heads and dependents are
        ``(word, index)`` tuples with 1-based word indices; ROOT is ``('ROOT', 0)``.

    Raises:
        ValueError: if the oracle picks an action that is invalid in the current state.
    """
    if device is None:
        device = DEVICE

    tokens = [(word, i + 1) for i, word in enumerate(sentence_words)]

    # initialize the state
    stack = [('ROOT', 0)]
    buffer = tokens.copy()
    arcs = []

    # Invert the vocabularies so an argmax position maps back to its name
    # regardless of the dicts' insertion order.
    idx2action = {idx: name for name, idx in dataset.action2idx.items()}
    idx2label = {idx: name for name, idx in dataset.label2idx.items()}

    # convert sentence to model input tensors (a batch of one)
    input_idx, input_attn_mask, word_ids = dataset.tokenize_sentence(sentence_words)
    input_idx = input_idx.unsqueeze(0).to(device)
    input_attn_mask = input_attn_mask.unsqueeze(0).to(device)

    oracle.eval()
    # Inference only: disable autograd for the encoder pass AND every oracle step,
    # so no computation graph is built up while parsing.
    with torch.no_grad():
        bert_output = oracle.get_bert_encoding(input_idx, input_attn_mask)

        if verbose:
            print(f"\nStack: {stack}")
            print(f"Buffer: {buffer}")

        while buffer or len(stack) > 1:
            # The oracle sees the top two stack items and the buffer front (if any).
            front = buffer[0] if buffer else None
            state_idx = [dataset.tokenize_state([(stack[-2:], front)], word_ids)]

            action_logits, label_logits = oracle.predict(bert_output, state_idx)
            best_action = idx2action[torch.argmax(action_logits[0][0]).item()]
            best_label = idx2label[torch.argmax(label_logits[0][0]).item()]

            _apply_transition(stack, buffer, arcs, best_action, best_label)

            if verbose:
                print(f"Best action: {best_action}, best label: {best_label}")
                print(f"\nStack: {stack}")
                print(f"Buffer: {buffer}")
                print(f"Arcs: {arcs}")

    return arcs


def evaluate(gold_arcs, predicted_arcs):
    """Compute unlabeled and labeled attachment scores (UAS, LAS).

    Each arc is a ``(head, dependent, label)`` triple.

    Args:
        gold_arcs: gold-standard arcs.
        predicted_arcs: arcs produced by the parser.

    Returns:
        ``(uas, las)``. UAS is the number of predicted arcs whose
        (head, dependent) pair appears in the gold arcs. LAS is the number of
        predicted arcs whose full triple appears in the gold arcs. Both counts
        are divided by ``len(gold_arcs)``. Both scores are 0.0 when
        ``gold_arcs`` is empty (previously a ZeroDivisionError).
    """
    if not gold_arcs:
        return 0.0, 0.0

    # Sets give O(1) membership tests instead of scanning the gold list per arc.
    gold_labeled = set(gold_arcs)
    gold_unlabeled = {(arc[0], arc[1]) for arc in gold_arcs}

    unlabeled_hits = 0
    labeled_hits = 0
    for arc in predicted_arcs:
        if (arc[0], arc[1]) in gold_unlabeled:
            unlabeled_hits += 1
            if arc in gold_labeled:
                labeled_hits += 1

    total = len(gold_arcs)
    return unlabeled_hits / total, labeled_hits / total

In [53]:
# Take one validation sentence and recover its gold-standard transitions and arcs.
test_data_instance = data_val[10]
oracle_output = training_oracle(test_data_instance, return_states=True, max_iters=100000)
gold_states, gold_actions, gold_labels, sentence_words, gold_arcs = oracle_output

# Parse the same sentence greedily with the trained neural oracle.
# NOTE: the `precicted_arcs` spelling is kept because the scoring cell reads it.
precicted_arcs = parser(sentence_words, model, val_dataset, verbose=False)

In [54]:
# Score the predicted parse against the gold one (unlabeled / labeled attachment).
scores = evaluate(gold_arcs, precicted_arcs)
uas, las = scores
print(f"UAS: {uas}, LAS: {las}")

UAS: 1.0, LAS: 0.9666666666666667
