#### Greedy (arc-standard) Transition Dependency Parser

We will now use the trained neural oracle to perform (arc-standard) dependency parsing. Given a sentence, we initialize a buffer containing the words and punctuation symbols of the sentence, a stack containing the `ROOT` and an empty dependency relations list. Starting from this initial state, we perform parse steps by applying actions chosen by the oracle and updating the system state. When the terminal state is reached (i.e. the buffer is empty and the stack only contains the `ROOT`), the complete dependency parse is contained in the dependency relations list.  

In [37]:
from parse_utils import *
import wandb
import pickle 

# Authenticate with Weights & Biases (wandb).
wandb.login()
# Sanity check that a GPU is visible: DEVICE is hard-coded to "cuda" later on.
print(torch.cuda.is_available())

# IPython autoreload: re-import edited modules (e.g. parse_utils) before each cell runs.
%load_ext autoreload
%autoreload 2

True
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


First, let's load the validation data.

In [2]:
# Read the CoNLL-formatted dev split; len(data_val) is reported as the sentence count.
data_val = read_conllu(os.path.join('data', 'dev.conll'))
print(f"Number of sentences in the validation data: {len(data_val)}")

Number of sentences in the validation data: 1700


In [30]:
# load pytorch dataset object from file
# NOTE(review): pickle.load can execute arbitrary code -- only load files you created/trust.
with open('val_dataset_pytorch.pkl', 'rb') as f:
    val_dataset = pickle.load(f)

# action -> index and label -> index maps; their sizes configure the oracle's output heads below.
action2idx = val_dataset.action2idx
label2idx = val_dataset.label2idx

Now load the trained oracle model.

In [8]:
DEVICE = "cuda"
learning_rate = 1e-5
model = BERT_ORACLE(num_actions=len(action2idx), num_labels=len(label2idx), unlabeled_arcs=False).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Loaded model from checkpoint!
Total number of parameters in transformer network: 66.959019 M
RAM used: 2707.07 MB


In [12]:
# Disabled: the validation pass below is wrapped in a string literal so it does not run.
# The printed tuple under this cell comes from an earlier run; the meaning of its three
# values is defined by validation() -- TODO confirm (presumably loss and accuracies).
"""
val_dataloader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)
validation(model, val_dataloader, device=DEVICE)
"""

(0.2752455174922943, 0.9773752207663278, 0.9538105168222937)

Now let's implement the greedy arc-standard transition parser. This implementation is strict: the parse fails (a `ValueError` is raised) whenever the oracle predicts an action that is invalid in the current configuration.

In [50]:
def parser(sentence_words, oracle, dataset, verbose=False):
    """Greedy arc-standard transition parser.

    Starting from the initial configuration (stack ``[('ROOT', 0)]``, buffer
    holding every word, no arcs), repeatedly ask the oracle to score the
    current configuration and apply the single highest-scoring action and
    label, until the buffer is empty and only ROOT remains on the stack.

    The parser is strict: an action that cannot be applied in the current
    configuration aborts the parse.

    Args:
        sentence_words: the words/punctuation tokens of the sentence, in order.
        oracle: trained oracle model; this function calls its ``eval``,
            ``get_bert_encoding`` and ``predict`` methods.
        dataset: dataset object providing ``tokenize_sentence``,
            ``tokenize_state``, ``action2idx`` and ``label2idx``.
        verbose: if True, print the configuration after every step.

    Returns:
        List of ``(head, dependent, label)`` arcs. Head and dependent are
        ``(word, position)`` tuples with positions starting at 1;
        ``('ROOT', 0)`` denotes the root.

    Raises:
        ValueError: if the oracle predicts an invalid action (LEFTARC/RIGHTARC
            with fewer than two stack items, LEFTARC that would make ROOT a
            dependent, or SHIFT on an empty buffer).
    """
    tokens = [(word, i + 1) for i, word in enumerate(sentence_words)]

    # initialize the configuration
    stack = [('ROOT', 0)]
    buffer = tokens.copy()
    arcs = []

    # convert sentence to model input tensors (batch of one sentence)
    input_idx, input_attn_mask, word_ids = dataset.tokenize_sentence(sentence_words)
    input_idx = input_idx.unsqueeze(0).to(DEVICE)
    input_attn_mask = input_attn_mask.unsqueeze(0).to(DEVICE)

    # inference only: eval mode, and encode the sentence once up front
    oracle.eval()
    with torch.no_grad():
        bert_output = oracle.get_bert_encoding(input_idx, input_attn_mask)

    # Map predicted indices back to names through the dataset's own index maps
    # rather than assuming dict insertion order matches the index values.
    idx2label = {idx: label for label, idx in dataset.label2idx.items()}
    idx2action = {idx: action for action, idx in dataset.action2idx.items()}

    if verbose:
        print(f"\nStack: {stack}")
        print(f"Buffer: {buffer}")

    # begin parsing; stop at the terminal configuration
    while len(buffer) > 0 or len(stack) > 1:
        # the oracle sees the top two stack items and the front of the buffer
        front = buffer[0] if len(buffer) > 0 else None
        state_idx = [dataset.tokenize_state([(stack[-2:], front)], word_ids)]

        # prediction only: no autograd graph is needed
        with torch.no_grad():
            action_logits, label_logits = oracle.predict(bert_output, state_idx)

        # pick highest scoring action and label
        best_action = idx2action[int(torch.argmax(action_logits[0][0]))]
        best_label = idx2label[int(torch.argmax(label_logits[0][0]))]

        if best_action == 'LEFTARC':
            # stack[-1] becomes head of stack[-2]; ROOT may never be a dependent
            if len(stack) <= 1:
                raise ValueError("Cannot perform LEFTARC action with stack length <= 1. Parse failed.")
            if stack[-2][0] == 'ROOT':
                raise ValueError("Cannot perform LEFTARC action with ROOT as dependent. Parse failed.")
            arcs.append((stack[-1], stack[-2], best_label))
            stack.pop(-2)

        elif best_action == 'RIGHTARC':
            # stack[-2] becomes head of stack[-1]
            if len(stack) <= 1:
                raise ValueError("Cannot perform RIGHTARC action with stack length <= 1. Parse failed.")
            arcs.append((stack[-2], stack[-1], best_label))
            stack.pop(-1)

        else:
            # SHIFT: move the front of the buffer onto the stack
            if len(buffer) == 0:
                raise ValueError("Cannot perform SHIFT action with buffer length <= 0. Parse failed.")
            stack.append(buffer.pop(0))

        if verbose:
            print(f"Best action: {best_action}, best label: {best_label}")
            print(f"\nStack: {stack}")
            print(f"Buffer: {buffer}")
            print(f"Arcs: {arcs}")

    return arcs


def evaluate(gold_arcs, predicted_arcs):
    """Compute unlabeled and labeled attachment scores (UAS, LAS).

    Args:
        gold_arcs: gold ``(head, dependent, label)`` arcs.
        predicted_arcs: predicted ``(head, dependent, label)`` arcs.

    Returns:
        ``(uas, las)``. UAS is the number of predicted arcs whose
        ``(head, dependent)`` pair appears in the gold arcs, divided by the
        number of gold arcs. LAS is the same, but the full
        ``(head, dependent, label)`` triple must match. Both are 0.0 when
        ``gold_arcs`` is empty, instead of dividing by zero.
    """
    if not gold_arcs:
        return 0.0, 0.0

    # Sets give O(1) membership tests. This assumes arcs are hashable tuples,
    # which matches the tuples built by the parsers in this notebook.
    gold_head_deps = {(arc[0], arc[1]) for arc in gold_arcs}
    gold_labeled = set(gold_arcs)

    # A labeled match implies an unlabeled match, so the counts can be taken independently.
    uas_hits = sum(1 for arc in predicted_arcs if (arc[0], arc[1]) in gold_head_deps)
    las_hits = sum(1 for arc in predicted_arcs if arc in gold_labeled)

    return uas_hits / len(gold_arcs), las_hits / len(gold_arcs)

In [53]:
# get a test sentence from the validation set and its gold standard parse
# (index 10 is an arbitrary example sentence)
test_data_instance = data_val[10]
# training_oracle yields the gold transition sequence plus the sentence words and gold arcs;
# max_iters presumably caps the number of oracle steps -- TODO confirm in parse_utils.
gold_states, gold_actions, gold_labels, sentence_words, gold_arcs  = training_oracle(test_data_instance, return_states=True, max_iters=100000)

# predict the parse using the oracle
# NOTE: the variable name keeps its original spelling because the next cell reads it.
precicted_arcs = parser(sentence_words, model, val_dataset, verbose=False)

In [54]:
# compare the gold standard and predicted arcs
# UAS: fraction of correct (head, dependent) pairs; LAS additionally requires the correct label.
uas, las = evaluate(gold_arcs, precicted_arcs)
print(f"UAS: {uas}, LAS: {las}")

UAS: 1.0, LAS: 0.9666666666666667


#### Beam Search

To improve the accuracy of the parser, we can use beam search instead of greedily choosing the best possible action at each step. 

In beam search, we fix a beam width $k$ and explore a search tree breadth-first. The root of the tree is the initial state. We expand the root by applying every valid action, which generates the resulting successor states. Each state gets a score: the initial state has score $0$, and a newly generated state's score is its predecessor's score plus the score of the action that produced it:

$StateScore(s_0) = 0$

$StateScore(s_i) = StateScore(s_{i-1}) + Score(s_{i-1}, a)$

We then expand each of these successor states and prune the tree, keeping only the top-$k$ successor states with the highest state scores. We keep expanding every state in the beam until all of them have reached a terminal state. The best parse is then given by the terminal state with the highest state score.

The beam search parsing algorithm is shown below (from Jurafsky & Martin's textbook, *Speech and Language Processing*):

<img src="beam_search.png" width="600" height="500">


In [None]:
def beam_contains_non_final_states(beam):
    """Tell whether any configuration in the beam still has work left.

    Each beam entry is a ``(stack, buffer, arcs, score)`` tuple. An entry is
    final once its buffer is empty and its stack holds only the ROOT entry.
    Returns True as soon as a non-final entry is found, otherwise False.
    """
    return any(
        len(pending) > 0 or len(pile) > 1
        for pile, pending, _arcs, _score in beam
    )


def add_state_to_beam(beam, state, k):
    """Insert ``state`` into ``beam`` in place, keeping at most ``k`` entries.

    Scores are read from index 3 of each state tuple. While the beam has fewer
    than ``k`` entries the state is appended. Once it is full, the state
    replaces the first lowest-scoring entry, but only if its own score is
    strictly higher. Returns the same (mutated) beam list.
    """
    if len(beam) >= k:
        scores = [entry[3] for entry in beam]
        worst = min(scores)
        if state[3] > worst:
            beam[scores.index(worst)] = state
        return beam

    beam.append(state)
    return beam


def valid_actions(stack, buffer):
    """Return the arc-standard transitions applicable to a configuration.

    Rules:
      - SHIFT needs a non-empty buffer.
      - RIGHTARC needs at least two stack items (stack[-2] becomes the head).
      - LEFTARC needs at least two stack items, and stack[-2] must not be
        ROOT, because ROOT can never become a dependent. The greedy
        ``parser`` rejects this case too.

    Returns:
        A list drawn from ``['SHIFT', 'LEFTARC', 'RIGHTARC']``, in that order.
    """
    actions = []
    if len(buffer) > 0:
        actions.append('SHIFT')
    if len(stack) > 1:
        if stack[-2][0] != 'ROOT':
            actions.append('LEFTARC')
        actions.append('RIGHTARC')
    return actions


def generate_successor_state(state, action, action_score, best_label):
    """Apply ``action`` to ``state`` and return the resulting successor state.

    ``state`` is a ``(stack, buffer, arcs, score)`` tuple. The parent is left
    untouched: stack, buffer and arcs are copied before the transition is
    applied. This lets beam search derive several successors from the same
    parent without them corrupting one another.

    Args:
        state: parent configuration ``(stack, buffer, arcs, score)``.
        action: 'LEFTARC', 'RIGHTARC'; any other value is treated as SHIFT.
        action_score: score added to the parent's score.
        best_label: dependency label attached to a newly created arc.

    Returns:
        New ``(stack, buffer, arcs, score + action_score)`` tuple.
    """
    stack, buffer, arcs, score = state
    # shallow copies are enough: the entries themselves are never mutated
    stack, buffer, arcs = list(stack), list(buffer), list(arcs)

    if action == 'LEFTARC':
        # stack[-1] becomes head of stack[-2]
        arcs.append((stack[-1], stack[-2], best_label))
        stack.pop(-2)
    elif action == 'RIGHTARC':
        # stack[-2] becomes head of stack[-1]
        arcs.append((stack[-2], stack[-1], best_label))
        stack.pop(-1)
    else:
        # SHIFT
        stack.append(buffer.pop(0))

    return (stack, buffer, arcs, score + action_score)


def get_best_state(beam):
    """Return the beam entry with the highest score (index 3 of each tuple).

    Ties go to the earliest entry. Returns None when the beam is empty or when
    no entry's score is strictly greater than negative infinity.
    """
    winner = None
    winner_score = float('-inf')
    for candidate in beam:
        candidate_score = candidate[3]
        if candidate_score > winner_score:
            winner, winner_score = candidate, candidate_score
    return winner


def beam_parser(sentence_words, oracle, dataset, k=10, verbose=False):
    """Parse a sentence with arc-standard beam search.

    Each beam entry is a configuration tuple ``(stack, buffer, arcs, score)``.
    At every step, each non-terminal configuration is expanded with all of
    its valid actions. An action's score is the oracle's log-probability for
    that action. A successor's score is its parent's score plus the action
    score. Only the ``k`` highest-scoring successors are kept. Once every
    configuration is terminal, the arcs of the highest-scoring one are
    returned. Each arc action uses the oracle's single best label for the
    configuration it is applied to.

    Args:
        sentence_words: the words/punctuation tokens of the sentence, in order.
        oracle: trained oracle model; this function calls its ``eval``,
            ``get_bert_encoding`` and ``predict`` methods.
        dataset: dataset object providing ``tokenize_sentence``,
            ``tokenize_state``, ``action2idx`` and ``label2idx``.
        k: beam width.
        verbose: if True, print the initial configuration.

    Returns:
        List of ``(head, dependent, label)`` arcs of the best parse.

    Raises:
        ValueError: if no state in the final beam has a score above -inf.
    """
    tokens = [(word, i + 1) for i, word in enumerate(sentence_words)]

    # initial configuration; the beam holds (stack, buffer, arcs, score) tuples directly
    stack = [('ROOT', 0)]
    buffer = tokens.copy()
    arcs = []
    beam = [(stack, buffer, arcs, 0.0)]

    # convert sentence to model input tensors (batch of one sentence)
    input_idx, input_attn_mask, word_ids = dataset.tokenize_sentence(sentence_words)
    input_idx = input_idx.unsqueeze(0).to(DEVICE)
    input_attn_mask = input_attn_mask.unsqueeze(0).to(DEVICE)

    # inference only: eval mode, and encode the sentence once up front
    oracle.eval()
    with torch.no_grad():
        bert_output = oracle.get_bert_encoding(input_idx, input_attn_mask)

    # map predicted label indices back to names through the dataset's own map
    idx2label = {idx: label for label, idx in dataset.label2idx.items()}

    if verbose:
        print(f"\nStack: {stack}")
        print(f"Buffer: {buffer}")

    # begin beam search
    while beam_contains_non_final_states(beam):
        beam_successors = []
        for state in beam:
            stack, buffer, _arcs, _score = state
            actions = valid_actions(stack, buffer)
            if not actions:
                # terminal configuration: carry it forward so it stays a candidate
                add_state_to_beam(beam_successors, state, k)
                continue

            # the oracle sees the top two stack items and the front of the buffer
            front = buffer[0] if len(buffer) > 0 else None
            state_idx = [dataset.tokenize_state([(stack[-2:], front)], word_ids)]
            with torch.no_grad():
                action_logits, label_logits = oracle.predict(bert_output, state_idx)

            # Normalise to log-probabilities: raw logits have an arbitrary
            # per-state offset, so their sums are not comparable across states.
            action_log_probs = torch.log_softmax(action_logits[0][0], dim=-1)
            best_label = idx2label[int(torch.argmax(label_logits[0][0]))]

            # expand the state using each valid action
            for action in actions:
                action_score = action_log_probs[dataset.action2idx[action]].item()
                successor_state = generate_successor_state(state, action, action_score, best_label)
                add_state_to_beam(beam_successors, successor_state, k)

        beam = beam_successors

    # the best state is chosen once, after the search (also covers empty sentences)
    best_state = get_best_state(beam)
    if best_state is None:
        raise ValueError("Beam search did not produce a parse with a finite score.")
    return best_state[2]