#### Neural Transition-Based Parser

We will now train an `oracle` model for a greedy (arc-standard) transition based-parser. The model extracts features from the state of a parse and uses these features to predict a probability distribution over all possible next actions. For unlabeled arcs, we only have three possible actions: `LEFTARC`, `RIGHTARC` and `SHIFT`. The model architecture is shown below (diagram borrowed from Jurafsky-Martin textbook):

<img src="neural_oracle.png" width="550" height="320">

The most important features needed to predict the next action are usually the top few words on the stack and buffer (and dependents of these words), which is why for this model, we will construct a simple feature vector by concatenating the contextualized embeddings (from A BERT encoder) of the top two words from the stack and the top word from the buffer. We will designate the BERT `[CLS]` token as the `ROOT` and we will use the zero vector to represent `NULL` (e.g. if the buffer is empty or the stack con tains less than 2 words, we fill the empty positions with NULL).

The feature vector is then passed through a basic 2-layer `feed-forward network` with a softmax at the output to obtain the predicted probability distribution over actions. In order to accomodate labelled arcs, we have two options: 

1) Augment the action with the label, i.e. for each `label`, we have two actions: `LEFTARC-label`, `RIGHTARC-label`

2) Create a separate feed-forward network which predicts the label, independent from the action. Since the `SHIFT` action has no associated label, we will define a special `NULL label` that goes along with it.

We will use Option 2 because it makes the learning task easier (because there are fewer class labels to predict) and is more efficient and could potentially lead to better performance (since it requires fewer parameters and so there's less chance of overfitting).

We will set up our training oracle such that it returns the state and action pairs needed for training.


In [32]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast, DistilBertModel, get_linear_schedule_with_warmup
from tqdm import tqdm
import random
random.seed(10)
import psutil
import wandb

wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True


Load the trainind and vaidation data.

In [86]:
# function for reading CoNLL parse files
def read_conllu(file_path):
    """
    Read a CoNLL-U file and return a list of sentences, where each sentence is a list of dictionaries, one for each token.
    """
    with open(file_path, 'r') as f:
        sentences = f.read().strip().split('\n\n')
        examples = []
        for sentence in sentences:
            token_dicts = []
            for line in sentence.split('\n'):
                if line[0] == '#':
                    continue    
                token_dict = list(zip(['id', 'form', 'lemma', 'upostag', 'xpostag' , 'feats', 'head', 'deprel', 'deps', 'misc'], line.split('\t')))
                # only keep form, xpostag, head, and deprel
                token_dicts.append(dict([token_dict[1], token_dict[4], token_dict[6], token_dict[7]]))
            examples.append(token_dicts)
        return examples
    

# function for extracting all the tokens and labelled head-dependency relations from the data
def get_tokens_relations(data_instance):
    """
    Extract all the labeled dependency relations from the data.
    """
    tokens = []
    relations = []
    for token_id, token in enumerate(data_instance):
        head_id = int(token['head'])
        if head_id == 0:
            head = 'ROOT'
        else:
            head = data_instance[head_id - 1]['form']
        dependent = token['form']
        tokens.append((dependent, token_id+1))
        relation = token['deprel']
        relations.append(((head, head_id), (dependent, token_id+1), relation))
    return tokens, relations


# training oracle returns the state-action pairs from every step of the parsing process
# the state only consists of the top two words on the stack and top word on the buffer
def training_oracle(data_instance, return_states=False, max_iters=100, verbose=False):
    # get the tokens and relations for the refenrence parse 
    tokens, Rp = get_tokens_relations(data_instance)
    sentence_words = [t[0] for t in tokens]
    if verbose: 
        print(f"Sentence: {sentence_words}")
        print(f"Reference parse: {Rp}")

    head_dep = [(r[0], r[1]) for r in Rp]

    # intialize the stack and buffer
    stack = [('ROOT', 0), tokens[0]]
    buffer = tokens[1:]
    Rc = []
    states = None
    if return_states:
        states = [([('ROOT', 0)], tokens[0])]
    actions = ['SHIFT']
    labels = ['null']
    # parse the sentence to get the sequence of states and actions
    niters = 0
    
    if verbose: 
        print(f"\nStack: {stack}")
        print(f"Buffer: {buffer}")    

    while (buffer or len(stack) > 1) and niters < max_iters:
        # get top two elements of stack
        S1 = stack[-1]
        S2 = stack[-2] 
        niters += 1

        if return_states:
            if len(buffer) > 0:
                states.append((stack[-2:] , buffer[0]))
            else:
                states.append((stack[-2:], None))

        # check if LEFTARC possible
        if (S1, S2) in head_dep:
            # remove second element of stack
            stack.pop(-2)
            rel = Rp[head_dep.index((S1, S2))]
            Rc.append(rel)
            next_action = 'LEFTARC' 
            next_label = rel[2]
            arc = (S1, S2, rel[2])

        # check if RIGHTARC possible
        elif (S2, S1) in head_dep:
            # get all head-dependent relations with S1 as head
            S1_rels = [r for r in Rp if r[0] == S1]
            # check if all dependents of S1 are in Rc
            if all([r in Rc for r in S1_rels]):
                stack.pop(-1)
                rel = Rp[head_dep.index((S2, S1))]
                Rc.append(rel)
                next_action = 'RIGHTARC' 
                next_label = rel[2]
                arc = (S2, S1, rel[2])
            else:
                if len(buffer)==0:
                    if verbose: print(f"Error! Parse failed, no valid action available!")
                    return None, None, None, None
                stack.append(buffer.pop(0))
                next_action = 'SHIFT'
                next_label = 'null'
                arc = None

        # otherwise SHIFT    
        else:
            if len(buffer)==0:
                    if verbose: print(f"Error! Parse failed, no valid action available!")
                    return None, None, None, None
            stack.append(buffer.pop(0))
            next_action = 'SHIFT'
            next_label = 'null'
            arc = None

        actions.append(next_action)
        labels.append(next_label)
        if verbose:
            print(f"Action: {next_action}, Arc: {arc}")
            print(f"\nStack: {stack}")
            print(f"Buffer: {buffer}")
            print(f"Rc: {Rc}")      

    # make sure Rc and Rp are consistent
    assert all([r in Rc for r in Rp]) and len(Rc)==len(Rp), "Rc not consistent with Rp"

    if niters == max_iters:
        print("Maximum number of iterations reached!")  

    return states, actions, labels, sentence_words    


In [81]:
data_train = read_conllu(os.path.join('data', 'train.conll'))
data_val = read_conllu(os.path.join('data', 'dev.conll'))

print(f"Number of sentences in the training data: {len(data_train)}")
print(f"Number of sentences in the validation data: {len(data_val)}")

Number of sentences in the training data: 39832
Number of sentences in the validation data: 1700


Now use training oracle to get the sequence of state-action pairs for each sentence. Note that a small number of parses will fail, probably due to non-projectivity of the sentence. 

In [93]:
# lets get the state-action pairs for all sentences
state_action_label_train = []
sentence_words_train = []
failed_train = []
pbar = tqdm(data_train, desc="Sentences parsed")
for i, example in enumerate(pbar):
    #print("Parsing sentence", i, "of", len(data_train))
    states, actions, labels, sentence_words  = training_oracle(example, return_states=True, max_iters=100000)
    if actions is None:
        failed_train.append(i)
    else:
        state_action_label_train.append((states, actions, labels))
        sentence_words_train.append(sentence_words)
print(f"Number of failed parses: {len(failed_train)}")

state_action_label_val = []
sentence_words_val = []
failed_val = []
pbar = tqdm(data_val, desc="Sentences parsed")
for i, example in enumerate(pbar):
    #print("Parsing sentence", i, "of", len(data_val))
    states, actions, labels, sentence_words = training_oracle(example, return_states=True, max_iters=100000)
    if actions is None:
        failed_val.append(i)
    else:
        state_action_label_val.append((states, actions, labels))
        sentence_words_val.append(sentence_words)   
print(f"Number of failed parses: {len(failed_val)}")


Sentences parsed: 100%|██████████| 39832/39832 [00:06<00:00, 6179.18it/s] 


Number of failed parses: 120


Sentences parsed: 100%|██████████| 1700/1700 [00:00<00:00, 11617.78it/s]

Number of failed parses: 5





In [96]:
# map labels and actions to indices
action2idx = {'LEFTARC': 0, 'RIGHTARC': 1, 'SHIFT': 2}
labels = list(set([l for item in (state_action_label_train+state_action_label_val) for l in item[2]]))
label2idx = {l: i for i, l in enumerate(labels)}

In [99]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

In [102]:
sentence = sentence_words_train[0]
input_encoding = tokenizer.encode_plus(sentence, is_split_into_words=True, return_offsets_mapping=False, padding=False, truncation=False, add_special_tokens=True)
input_idx = input_encoding['input_ids']
word_ids = input_encoding.word_ids()
print(tokenizer.convert_ids_to_tokens(input_idx))
print(word_ids)

['[CLS]', 'in', 'an', 'oct', '.', '19', 'review', 'of', '`', '`', 'the', 'mis', '##ant', '##hr', '##ope', "'", "'", 'at', 'chicago', "'", 's', 'goodman', 'theatre', '-', 'l', '##rb', '-', '`', '`', 'rev', '##ital', '##ized', 'classics', 'take', 'the', 'stage', 'in', 'windy', 'city', ',', "'", "'", 'leisure', '&', 'arts', '-', 'rr', '##b', '-', ',', 'the', 'role', 'of', 'ce', '##lim', '##ene', ',', 'played', 'by', 'kim', 'cat', '##tral', '##l', ',', 'was', 'mistakenly', 'attributed', 'to', 'christina', 'ha', '##ag', '.', '[SEP]']
[None, 0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 8, 8, 8, 9, 9, 10, 11, 12, 12, 13, 14, 15, 15, 15, 15, 16, 16, 17, 17, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 26, 27, 28, 29, 30, 30, 30, 30, 31, 32, 33, 34, 35, 35, 35, 36, 37, 38, 39, 40, 40, 40, 41, 42, 43, 44, 45, 46, 47, 47, 48, None]


In [103]:
state, action, label = state_action_label_train[0]
stack_words, buffer_word = state[0]
print(stack_words, buffer_word)

[('ROOT', 0)] ('In', 1)


In [109]:
state_words_idx = [tokenizer.pad_token_id] * 3
print(state_words_idx)
for i in range(len(stack_words)):
    if stack_words[i][0] == 'ROOT':
        state_words_idx[i] = tokenizer.cls_token_id
    else:
        state_words_idx[i] = word_ids.index(stack_words[i][1]-1)

if buffer_word is not None:
    state_words_idx[2] = word_ids.index(buffer_word[1]-1)
print(state_words_idx)    

[0, 0, 0]
[101, 0, 1]


Now that we have our training data, we can set up a pytorch dataset.

In [110]:
class DependencyParseDataset(Dataset):
    def __init__(self, sentences, state_action_label, action2idx, label2idx, block_size=256):
        self.sentences = sentences
        self.state_action_label = state_action_label
        self.action2idx = action2idx
        self.label2idx = label2idx
        self.block_size = block_size
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        # get sentence 
        sentence = self.sentences[idx]
        # get states, actions, and labels
        states, actions, labels = self.state_action_label[idx]

        # tokenize the sentence
        input_encoding = self.tokenizer.encode_plus(sentence, is_split_into_words=True, return_offsets_mapping=False, padding=False, truncation=False, add_special_tokens=True)
        input_idx = input_encoding['input_ids']
        word_ids = input_encoding.word_ids()

        if len(input_idx) > self.block_size:
            raise ValueError(f"Tokenized sentence {idx} is too long: {len(input_idx)}. Truncation unsupported.")

        # map state words to index of first subword token
        state_idx = []
        for stack_words, buffer_word in states:
            state_words_idx = [self.tokenizer.pad_token_id] * 3  # missing words are filled with PAD token
            for i in range(len(stack_words)):
                if stack_words[i][0] == 'ROOT':
                    state_words_idx[i] = self.tokenizer.cls_token_id  # ROOT is represented by CLS token
                else:
                    state_words_idx[i] = word_ids.index(stack_words[i][1]-1)
            
            if buffer_word is not None:
                state_words_idx[2] = word_ids.index(buffer_word[1]-1)
            
            state_idx.append(state_words_idx)

        # map actions and labels to indices
        action_idx = [self.action2idx[a] for a in actions]
        label_idx = [self.label2idx[l] for l in labels]    

        # add padding 
        input_idx = input_idx + [self.tokenizer.pad_token_id] * (self.block_size - len(input_idx))    
        # create attention mask 
        input_attn_mask = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in input_idx]

        # convert to tensors
        input_idx = torch.tensor(input_idx)
        input_attn_mask = torch.tensor(input_attn_mask) 
        
        return input_idx, input_attn_mask, state_idx, action_idx, label_idx   


def collate_fn(batch):
    # Separate the tensors and the dictionaries
    input_idxs, input_attn_masks, state_idx, action_idx, label_idx = zip(*batch)

    # Default collate the tensors
    input_idxs = torch.stack(input_idxs)
    input_attn_masks = torch.stack(input_attn_masks)

    return input_idxs, input_attn_masks, state_idx, action_idx, label_idx 


In [111]:
train_dataset = DependencyParseDataset(sentence_words_train, state_action_label_train, action2idx, label2idx)
val_dataset = DependencyParseDataset(sentence_words_val, state_action_label_val, action2idx, label2idx)
