#### Training an Oracle for a Neural Greedy Transition-Based Parser

We will now train an `oracle` model for a greedy (arc-standard) transition-based parser. The model extracts (local) features from the state of a parse and uses these features to predict a probability distribution over all possible next actions. For unlabeled arcs, there are only three possible actions: `LEFTARC`, `RIGHTARC` and `SHIFT`. The model architecture is shown below (diagram borrowed from the Jurafsky-Martin textbook):

<img src="neural_oracle.png" width="550" height="320">
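
To make the three transitions concrete, here is a minimal sketch of how each one updates a parser configuration. The helper name `apply_transition`, the toy sentence and the `(head, dependent)` arc format are illustrative assumptions and are not part of the notebook's code.

```python
def apply_transition(stack, buffer, arcs, action):
    """Apply one arc-standard transition in place (illustrative helper)."""
    if action == "SHIFT":
        # move the first buffer word onto the stack
        stack.append(buffer.pop(0))
    elif action == "LEFTARC":
        # the top of the stack becomes the head of the second item, which is removed
        dependent = stack.pop(-2)
        arcs.append((stack[-1], dependent))
    elif action == "RIGHTARC":
        # the second item becomes the head of the top of the stack, which is removed
        dependent = stack.pop()
        arcs.append((stack[-1], dependent))
    else:
        raise ValueError(f"unknown action: {action}")


stack, buffer, arcs = ["ROOT"], ["book", "the", "flight"], []
for action in ["SHIFT", "SHIFT", "SHIFT", "LEFTARC", "RIGHTARC", "RIGHTARC"]:
    apply_transition(stack, buffer, arcs, action)

print(arcs)  # [('flight', 'the'), ('book', 'flight'), ('ROOT', 'book')]
```

The parse is complete when the buffer is empty and only `ROOT` remains on the stack.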

The most important features for predicting the next action are usually the top few words on the stack and buffer (and the dependents of these words). For this model, we therefore construct a simple feature vector by concatenating the contextualized embeddings (from a BERT encoder) of the top two words on the stack and the top word on the buffer. We designate the BERT `[CLS]` token as the `ROOT` and use the zero vector to represent `NULL` (e.g. if the buffer is empty or the stack contains fewer than two words, we fill the empty positions with `NULL`).
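
As a rough illustration of this feature construction, the sketch below concatenates per-word vectors and fills missing positions with zeros. It uses plain Python lists and a made-up 2-dimensional embedding table instead of BERT outputs; `build_features` and `embeddings` are hypothetical names.

```python
def build_features(embeddings, stack, buffer, dim):
    """Concatenate the vectors of the top two stack words and the first buffer word."""
    null = [0.0] * dim  # NULL is represented by the zero vector
    # up to two items from the top of the stack, kept in stack order
    slots = [embeddings[w] for w in stack[-2:]]
    # pad if the stack holds fewer than two words
    slots += [null] * (2 - len(slots))
    # first word of the buffer, or NULL if the buffer is empty
    slots.append(embeddings[buffer[0]] if buffer else null)
    # flatten into a single vector of length 3 * dim
    return [x for vec in slots for x in vec]


embeddings = {"ROOT": [1.0, 1.0], "book": [0.5, -0.5], "flight": [0.2, 0.3]}
features = build_features(embeddings, ["ROOT", "book"], [], dim=2)
print(features)  # [1.0, 1.0, 0.5, -0.5, 0.0, 0.0]
```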

The feature vector is then passed through a basic 2-layer `feed-forward network` with a softmax at the output to obtain the predicted probability distribution over actions. To accommodate labelled arcs, we have two options:

1) Augment the action with the label, i.e. for each `label`, we have two actions: `LEFTARC-label`, `RIGHTARC-label`

2) Create a separate feed-forward network which predicts the label, independent from the action. Since the `SHIFT` action has no associated label, we will define a special `NULL label` that goes along with it.

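We will use Option 2. It makes the learning task easier because each head predicts over a small set of classes (three actions in one head, the arc labels plus `NULL` in the other) instead of one joint classifier over two actions per label plus `SHIFT`. Fewer output classes also means more training examples per class, which may reduce the risk of overfitting and could lead to better performance.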
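
The difference in output-layer size between the two options can be worked out directly. The snippet below is a small sketch; the label count of 40 is an arbitrary placeholder, not the size of the label set in this data, and `num_output_classes` is a hypothetical helper.

```python
def num_output_classes(num_arc_labels):
    """Compare output-layer sizes for the two labelling strategies."""
    # Option 1: one LEFTARC and one RIGHTARC action per label, plus SHIFT
    joint = 2 * num_arc_labels + 1
    # Option 2: three actions in one head; the labels plus a NULL label in the other
    action_head = 3
    label_head = num_arc_labels + 1
    return joint, action_head, label_head


print(num_output_classes(40))  # (81, 3, 41)
```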

We will set up our training oracle such that it returns the state and action pairs needed for training.


In [1]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast, DistilBertModel, get_linear_schedule_with_warmup
from tqdm import tqdm
import random
random.seed(10)
import psutil
import wandb

wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: tanzids. Use `wandb login --relogin` to force relogin


True


Load the training and validation data.

In [2]:
# function for reading CoNLL parse files
def read_conllu(file_path):
    """
    Read a CoNLL-U file and return a list of sentences, where each sentence is a list of dictionaries, one for each token.
    """
    with open(file_path, 'r') as f:
        sentences = f.read().strip().split('\n\n')
        examples = []
        for sentence in sentences:
            token_dicts = []
            for line in sentence.split('\n'):
                # skip blank lines and comment lines
                if not line or line.startswith('#'):
                    continue
                token_dict = list(zip(['id', 'form', 'lemma', 'upostag', 'xpostag' , 'feats', 'head', 'deprel', 'deps', 'misc'], line.split('\t')))
                # only keep form, xpostag, head, and deprel
                token_dicts.append(dict([token_dict[1], token_dict[4], token_dict[6], token_dict[7]]))
            examples.append(token_dicts)
        return examples
    

# function for extracting all the tokens and labelled head-dependency relations from the data
def get_tokens_relations(data_instance):
    """
    Extract all the labeled dependency relations from the data.
    """
    tokens = []
    relations = []
    for token_id, token in enumerate(data_instance):
        head_id = int(token['head'])
        if head_id == 0:
            head = 'ROOT'
        else:
            head = data_instance[head_id - 1]['form']
        dependent = token['form']
        tokens.append((dependent, token_id+1))
        relation = token['deprel']
        relations.append(((head, head_id), (dependent, token_id+1), relation))
    return tokens, relations


# training oracle returns the state-action pairs from every step of the parsing process
# the state only consists of the top two words on the stack and top word on the buffer
def training_oracle(data_instance, return_states=False, max_iters=100, verbose=False):
    # get the tokens and relations for the reference parse
    tokens, Rp = get_tokens_relations(data_instance)
    sentence_words = [t[0] for t in tokens]
    if verbose: 
        print(f"Sentence: {sentence_words}")
        print(f"Reference parse: {Rp}")

    head_dep = [(r[0], r[1]) for r in Rp]

    # initialize the stack and buffer
    stack = [('ROOT', 0), tokens[0]]
    buffer = tokens[1:]
    Rc = []
    states = None
    if return_states:
        states = [([('ROOT', 0)], tokens[0])]
    actions = ['SHIFT']
    labels = ['null']
    # parse the sentence to get the sequence of states and actions
    niters = 0
    
    if verbose: 
        print(f"\nStack: {stack}")
        print(f"Buffer: {buffer}")    

    while (buffer or len(stack) > 1) and niters < max_iters:
        # get top two elements of stack
        S1 = stack[-1]
        S2 = stack[-2] 
        niters += 1

        if return_states:
            if len(buffer) > 0:
                states.append((stack[-2:] , buffer[0]))
            else:
                states.append((stack[-2:], None))

        # check if LEFTARC possible
        if (S1, S2) in head_dep:
            # remove second element of stack
            stack.pop(-2)
            rel = Rp[head_dep.index((S1, S2))]
            Rc.append(rel)
            next_action = 'LEFTARC' 
            next_label = rel[2]
            arc = (S1, S2, rel[2])

        # check if RIGHTARC possible
        elif (S2, S1) in head_dep:
            # get all head-dependent relations with S1 as head
            S1_rels = [r for r in Rp if r[0] == S1]
            # check if all dependents of S1 are in Rc
            if all([r in Rc for r in S1_rels]):
                stack.pop(-1)
                rel = Rp[head_dep.index((S2, S1))]
                Rc.append(rel)
                next_action = 'RIGHTARC' 
                next_label = rel[2]
                arc = (S2, S1, rel[2])
            else:
                if len(buffer)==0:
                    if verbose: print(f"Error! Parse failed, no valid action available!")
                    return None, None, None, None
                stack.append(buffer.pop(0))
                next_action = 'SHIFT'
                next_label = 'null'
                arc = None

        # otherwise SHIFT    
        else:
            if len(buffer)==0:
                if verbose: print(f"Error! Parse failed, no valid action available!")
                return None, None, None, None
            stack.append(buffer.pop(0))
            next_action = 'SHIFT'
            next_label = 'null'
            arc = None

        actions.append(next_action)
        labels.append(next_label)
        if verbose:
            print(f"Action: {next_action}, Arc: {arc}")
            print(f"\nStack: {stack}")
            print(f"Buffer: {buffer}")
            print(f"Rc: {Rc}")      

    # make sure Rc and Rp are consistent
    assert all([r in Rc for r in Rp]) and len(Rc)==len(Rp), "Rc not consistent with Rp"

    if niters == max_iters:
        print("Maximum number of iterations reached!")  

    return states, actions, labels, sentence_words    


In [3]:
data_train = read_conllu(os.path.join('data', 'train.conll'))
data_val = read_conllu(os.path.join('data', 'dev.conll'))

print(f"Number of sentences in the training data: {len(data_train)}")
print(f"Number of sentences in the validation data: {len(data_val)}")

Number of sentences in the training data: 39832
Number of sentences in the validation data: 1700
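
For reference, `read_conllu` splits each token line into ten tab-separated columns and keeps four of them. The sketch below shows the split on a single made-up line (not taken from the data files); `parse_token_line` is a hypothetical helper. Note that `head` is a 1-based token index stored as a string, with `0` standing for `ROOT`.

```python
COLUMNS = ['id', 'form', 'lemma', 'upostag', 'xpostag', 'feats', 'head', 'deprel', 'deps', 'misc']


def parse_token_line(line):
    """Split one tab-separated token line into a dict keyed by column name."""
    return dict(zip(COLUMNS, line.split('\t')))


# made-up example line for illustration only
line = "2\tflight\tflight\tNOUN\tNN\t_\t1\tdobj\t_\t_"
token = parse_token_line(line)
print(token['form'], token['head'], token['deprel'])  # flight 1 dobj
```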


Now we use the training oracle to get the sequence of state-action pairs at each parser step for every sentence. Note that a small number of parses will fail, probably because those sentences are non-projective: the arc-standard system can only produce projective trees.

In [4]:
# let's get the state-action pairs for all sentences
state_action_label_train = []
sentence_words_train = []
failed_train = []
pbar = tqdm(data_train, desc="Sentences parsed")
for i, example in enumerate(pbar):
    #print("Parsing sentence", i, "of", len(data_train))
    states, actions, labels, sentence_words  = training_oracle(example, return_states=True, max_iters=100000)
    if actions is None:
        failed_train.append(i)
    else:
        state_action_label_train.append((states, actions, labels))
        sentence_words_train.append(sentence_words)
print(f"Number of failed parses: {len(failed_train)}")

state_action_label_val = []
sentence_words_val = []
failed_val = []
pbar = tqdm(data_val, desc="Sentences parsed")
for i, example in enumerate(pbar):
    #print("Parsing sentence", i, "of", len(data_val))
    states, actions, labels, sentence_words = training_oracle(example, return_states=True, max_iters=100000)
    if actions is None:
        failed_val.append(i)
    else:
        state_action_label_val.append((states, actions, labels))
        sentence_words_val.append(sentence_words)   
print(f"Number of failed parses: {len(failed_val)}")


Sentences parsed: 100%|██████████| 39832/39832 [00:05<00:00, 6769.73it/s]


Number of failed parses: 120


Sentences parsed: 100%|██████████| 1700/1700 [00:00<00:00, 11476.84it/s]

Number of failed parses: 5
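
One way to check whether non-projectivity explains the failed parses is to test each failed sentence for crossing arcs. The helper below is a sketch of such a test; `is_projective` and its head-list input format are assumptions and are not part of the notebook.

```python
def is_projective(heads):
    """Return True if no two dependency arcs cross.

    heads[i] is the 1-based head index of token i+1; 0 denotes ROOT.
    """
    # represent every arc by its (left, right) endpoints
    arcs = [(min(h, d), max(h, d)) for d, h in enumerate(heads, start=1)]
    for l1, r1 in arcs:
        for l2, r2 in arcs:
            # the second arc starts strictly inside the first and ends strictly outside it
            if l1 < l2 < r1 < r2:
                return False
    return True


print(is_projective([2, 0, 2]))     # True
print(is_projective([3, 4, 0, 3]))  # False: arcs 3->1 and 4->2 cross
```

With the data loaded above, the head list of a sentence could be built as `[int(tok['head']) for tok in data_train[i]]`.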





In [5]:
# map labels and actions to indices
action2idx = {'LEFTARC': 0, 'RIGHTARC': 1, 'SHIFT': 2}
# sort the label set so that the label-to-index mapping is the same on every run
labels = sorted(set(l for item in (state_action_label_train+state_action_label_val) for l in item[2]))
label2idx = {l: i for i, l in enumerate(labels)}

In [6]:
"""
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

sentence = sentence_words_train[0]
input_encoding = tokenizer.encode_plus(sentence, is_split_into_words=True, return_offsets_mapping=False, padding=False, truncation=False, add_special_tokens=True)
input_idx = input_encoding['input_ids']
word_ids = input_encoding.word_ids()
print(tokenizer.convert_ids_to_tokens(input_idx))
print(word_ids)

state, action, label = state_action_label_train[0]
stack_words, buffer_word = state[0]
print(stack_words, buffer_word)

state_words_idx = [tokenizer.pad_token_id] * 3
print(state_words_idx)
for i in range(len(stack_words)):
    if stack_words[i][0] == 'ROOT':
        state_words_idx[i] = tokenizer.cls_token_id
    else:
        state_words_idx[i] = word_ids.index(stack_words[i][1]-1)

if buffer_word is not None:
    state_words_idx[2] = word_ids.index(buffer_word[1]-1)
print(state_words_idx)   """


Now that we have our training data, we can set up a PyTorch dataset.

In [7]:
class DependencyParseDataset(Dataset):
    def __init__(self, sentences, state_action_label, action2idx, label2idx, block_size=256):
        self.sentences = sentences
        self.state_action_label = state_action_label
        self.action2idx = action2idx
        self.label2idx = label2idx
        self.block_size = block_size
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        # get sentence 
        sentence = self.sentences[idx]
        # get states, actions, and labels
        states, actions, labels = self.state_action_label[idx]

        assert len(states) == len(actions) == len(labels), "Lengths of states, actions, and labels do not match."

        # tokenize the sentence
        input_encoding = self.tokenizer.encode_plus(sentence, is_split_into_words=True, return_offsets_mapping=False, padding=False, truncation=False, add_special_tokens=True)
        input_idx = input_encoding['input_ids']
        word_ids = input_encoding.word_ids()

        if len(input_idx) > self.block_size:
            raise ValueError(f"Tokenized sentence {idx} is too long: {len(input_idx)}. Truncation unsupported.")

        # map state words to index of first subword token
        state_idx = []
        for stack_words, buffer_word in states:
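            # slots for missing words default to pad_token_id
            # NOTE: get_features uses these values as positions in the BERT output, so a missing
            # word picks up the embedding at that position rather than an explicit zero vector
            state_words_idx = [self.tokenizer.pad_token_id] * 3
            for i in range(len(stack_words)):
                if stack_words[i][0] == 'ROOT':
                    # NOTE: this stores the [CLS] token id; the [CLS] token itself sits at position 0
                    state_words_idx[i] = self.tokenizer.cls_token_id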
                else:
                    state_words_idx[i] = word_ids.index(stack_words[i][1]-1)
            
            if buffer_word is not None:
                state_words_idx[2] = word_ids.index(buffer_word[1]-1)
            
            state_idx.append(state_words_idx)

        # map actions and labels to indices
        action_idx = [self.action2idx[a] for a in actions]
        label_idx = [self.label2idx[l] for l in labels]    

        # add padding 
        input_idx = input_idx + [self.tokenizer.pad_token_id] * (self.block_size - len(input_idx))    
        # create attention mask 
        input_attn_mask = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in input_idx]

        # convert to tensors
        input_idx = torch.tensor(input_idx)
        input_attn_mask = torch.tensor(input_attn_mask) 
        
        return input_idx, input_attn_mask, state_idx, action_idx, label_idx   

def collate_fn(batch):
    # Unzip the batch into one tuple per field
    input_idxs, input_attn_masks, state_idx, action_idx, label_idx = zip(*batch)

    # Stack the fixed-size tensors; the variable-length state/action/label lists stay as tuples
    input_idxs = torch.stack(input_idxs)
    input_attn_masks = torch.stack(input_attn_masks)

    return input_idxs, input_attn_masks, state_idx, action_idx, label_idx 
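
The dataset represents each word by the position of its first subword token, which it finds with `word_ids.index(...)`. The sketch below shows that lookup on a hand-written `word_ids` list of the kind a fast tokenizer reports; the list and the helper `first_subword_positions` are illustrative assumptions.

```python
def first_subword_positions(word_ids):
    """Map each word index to the position of its first subword token."""
    positions = {}
    for pos, wid in enumerate(word_ids):
        # special tokens have word id None; keep only the first position per word
        if wid is not None and wid not in positions:
            positions[wid] = pos
    return positions


# a 3-word sentence whose second word is split into two subwords:
# [CLS] w0 w1a w1b w2 [SEP]
word_ids = [None, 0, 1, 1, 2, None]
positions = first_subword_positions(word_ids)
print(positions)          # {0: 1, 1: 2, 2: 4}
print(word_ids.index(1))  # 2, the same lookup the dataset performs
```

Building the dictionary once per sentence avoids a linear scan for every parser state, which may matter for long sentences.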


#### Now define the oracle model. The model consists of a pre-trained BERT encoder and two MLP classification heads, one head for classifying the unlabeled action and the other head for classifying the arc label.

In [8]:
class BERT_ORACLE(torch.nn.Module):
    
    def __init__(self, num_actions, num_labels, num_features=3, unlabeled_arcs=True, dropout_rate=0.1, mlp_hidden_size=128):
        super().__init__()
        self.unlabeled_arcs = unlabeled_arcs
        # load pretrained BERT model
        self.bert_encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)
        # define action classifier head (2 layer MLP)
        self.classifier_head_action = torch.nn.Sequential(
            torch.nn.Linear(num_features * self.bert_encoder.config.hidden_size, mlp_hidden_size),
            torch.nn.LayerNorm(mlp_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_hidden_size, num_actions)
        )
        if not self.unlabeled_arcs:
            # define arc-label classifier head (2 layer MLP)
            self.classifier_head_label = torch.nn.Sequential(
                torch.nn.Linear(num_features * self.bert_encoder.config.hidden_size, mlp_hidden_size),
                torch.nn.LayerNorm(mlp_hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(mlp_hidden_size, num_labels)
            )

        # make sure BERT parameters are trainable
        for param in self.bert_encoder.parameters():
            param.requires_grad = True

    def get_features(self, bert_output, state_idx, batch_idx):
        features_states = []
        for i in range(len(state_idx[batch_idx])):
            # get BERT embeddings for the top two words on the stack and the top word on the buffer
            # (each word is represented by the embedding of its first subword token)
            stack1 = bert_output[batch_idx, state_idx[batch_idx][i][0], :] # shape: (hidden_size,)
            stack2 = bert_output[batch_idx, state_idx[batch_idx][i][1], :] # shape: (hidden_size,)
            buffer = bert_output[batch_idx, state_idx[batch_idx][i][2], :] # shape: (hidden_size,)
            # concatenate the embeddings
            features = torch.cat([stack1, stack2, buffer], dim=0) # shape: (3*hidden_size,)
            features_states.append(features) 
        # stack up the features for all states into a single tensor
        features = torch.stack(features_states) # shape: (num_states, 3*hidden_size)   
        return features

    def forward(self, input_idx, input_attn_mask, state_idx, target_action_idx=None, target_label_idx=None):
        # compute BERT embeddings for input tokens
        bert_output = self.bert_encoder(input_idx, attention_mask=input_attn_mask)
        bert_output = self.dropout(bert_output.last_hidden_state) # shape: (batch_size, block_size, hidden_size)

        loss = 0.0
        batch_action_logits = []
        batch_label_logits = []
        # iterate over each sentence in the batch
        for batch_idx in range(len(input_idx)):  
            # get the features for all parse states
            features = self.get_features(bert_output, state_idx, batch_idx) # shape: (num_states, 3*hidden_size)
            # compute action logits and cross-entropy loss
            action_logits = self.classifier_head_action(features) # shape: (num_states, num_actions)
            batch_action_logits.append(action_logits)
            if target_action_idx is not None:
                action_targets = torch.tensor(target_action_idx[batch_idx], dtype=torch.long, device=input_idx.device)
                loss += F.cross_entropy(action_logits, action_targets)
            if not self.unlabeled_arcs:
                # compute arc-label logits and cross-entropy loss 
                label_logits = self.classifier_head_label(features) # shape: (num_states, num_labels)
                batch_label_logits.append(label_logits)
                if target_label_idx is not None:
                    label_targets = torch.tensor(target_label_idx[batch_idx], dtype=torch.long, device=input_idx.device)
                    loss += F.cross_entropy(label_logits, label_targets)

        # average loss over the batch
        loss = loss/len(input_idx)    

        return loss, batch_action_logits, batch_label_logits    

# training loop
def train(model, optimizer, train_dataloader, val_dataloader, scheduler=None, device="cpu", num_epochs=10, val_every=100, save_every=None, log_metrics=None):
    avg_loss = 0
    model.train()
    # reset gradients
    optimizer.zero_grad()
    for epoch in range(num_epochs):
        train_uas = 0
        train_las = 0
        val_loss = 0
        val_uas = 0
        val_las = 0
        num_instances = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for step, batch in enumerate(pbar):
            input_idx, input_attn_mask, state_idx, target_action_idx, target_label_idx = batch
            # move tensors to device
            input_idx, input_attn_mask = input_idx.to(device), input_attn_mask.to(device)
            # forward pass
            loss, batch_action_logits, batch_label_logits = model(input_idx, input_attn_mask, state_idx, target_action_idx, target_label_idx)
            # reset gradients
            optimizer.zero_grad()
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()

            if scheduler is not None:
                scheduler.step()

            avg_loss = 0.9* avg_loss + 0.1*loss.item()

            # compute unlabeled and labeled attachment scores
            for batch_idx in range(len(input_idx)):
                action_logits = batch_action_logits[batch_idx]
                action_idx = target_action_idx[batch_idx]
                if not model.unlabeled_arcs:
                    label_logits = batch_label_logits[batch_idx]
                    label_idx = target_label_idx[batch_idx]
                # compute UAS and LAS
                sentence_uas = 0
                sentence_las = 0
                for i in range(len(action_idx)):
                    if action_idx[i] == torch.argmax(action_logits[i]):
                        sentence_uas += 1
                        if not model.unlabeled_arcs:
                            if label_idx[i] == torch.argmax(label_logits[i]):
                                sentence_las += 1                
                sentence_uas = sentence_uas/len(action_idx)
                train_uas += sentence_uas
                if not model.unlabeled_arcs:
                    sentence_las = sentence_las/len(action_idx)
                    train_las += sentence_las
                num_instances += 1    

            if val_every is not None:
                if (step+1)%val_every == 0:
                    # compute validation loss
                    val_loss, val_uas, val_las = validation(model, val_dataloader, device=device)
                    pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train UAS, LAS: ({train_uas/num_instances: .3f}, {train_las/num_instances: .3f}), Val Loss: {val_loss: .3f}, Val UAS, LAS: ({val_uas: .3f}, {val_las: .3f})")  

            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train UAS, LAS: ({train_uas/num_instances: .3f}, {train_las/num_instances: .3f}), Val Loss: {val_loss: .3f}, Val UAS, LAS: ({val_uas: .3f}, {val_las: .3f})")

            if log_metrics:
                metrics = {"Batch loss":loss.item(), "Moving Avg Loss":avg_loss, "Train UAS":train_uas/num_instances, "Train LAS":train_las/num_instances,"Val Loss": val_loss, "Val UAS":val_uas, "Val LAS":val_las}   
                log_metrics(metrics)

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_model_checkpoint(model, optimizer, epoch, avg_loss)


def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        val_uas = 0
        val_las = 0
        num_instances = 0
        for step,batch in enumerate(val_dataloader):
            input_idx, input_attn_mask, state_idx, target_action_idx, target_label_idx = batch
            input_idx, input_attn_mask = input_idx.to(device), input_attn_mask.to(device)
            loss, batch_action_logits, batch_label_logits = model(input_idx, input_attn_mask, state_idx, target_action_idx, target_label_idx)
            
            # compute unlabeled and labeled attachment scores
            for batch_idx in range(len(input_idx)):
                action_logits = batch_action_logits[batch_idx]
                action_idx = target_action_idx[batch_idx]
                if not model.unlabeled_arcs:
                    label_logits = batch_label_logits[batch_idx]
                    label_idx = target_label_idx[batch_idx]
                # compute UAS and LAS
                sentence_uas = 0
                sentence_las = 0
                for i in range(len(action_idx)):
                    if action_idx[i] == torch.argmax(action_logits[i]):
                        sentence_uas += 1
                        if not model.unlabeled_arcs:
                            if label_idx[i] == torch.argmax(label_logits[i]):
                                sentence_las += 1                
                sentence_uas = sentence_uas/len(action_idx)
                val_uas += sentence_uas
                if not model.unlabeled_arcs:
                    sentence_las = sentence_las/len(action_idx)
                    val_las += sentence_las
                num_instances += 1  

            val_losses[step] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_uas = val_uas/num_instances
    val_las = val_las/num_instances
    return val_loss, val_uas, val_las


def save_model_checkpoint(model, optimizer, epoch=None, loss=None, filename='BERT_TRANSITION_PARSER_checkpoint.pth'):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # Save the checkpoint to a file
    torch.save(checkpoint, filename)
    print(f"Saved model checkpoint!")


def load_model_checkpoint(model, optimizer=None,  filename='BERT_TRANSITION_PARSER_checkpoint.pth'):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded model from checkpoint!")
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.train()
        return model, optimizer          
    else:
        return model        
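
The quantities that `train` and `validation` report as UAS and LAS are computed per parser step against the oracle's transitions: a step counts towards UAS if the predicted action matches, and towards LAS only if both the action and the label match. The sketch below restates that computation on plain lists; `step_accuracy` and the toy index sequences are hypothetical.

```python
def step_accuracy(gold_actions, pred_actions, gold_labels=None, pred_labels=None):
    """Per-sentence fraction of parser steps predicted correctly."""
    n = len(gold_actions)
    action_hits = 0
    label_hits = 0
    for i in range(n):
        if gold_actions[i] == pred_actions[i]:
            action_hits += 1
            # a label only counts when the action is also correct
            if gold_labels is not None and gold_labels[i] == pred_labels[i]:
                label_hits += 1
    return action_hits / n, label_hits / n


gold_a = [2, 2, 0, 1]
pred_a = [2, 2, 1, 1]
gold_l = [0, 0, 5, 7]
pred_l = [0, 0, 5, 3]
print(step_accuracy(gold_a, pred_a, gold_l, pred_l))  # (0.75, 0.5)
```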

Now let's train the model. First, we will train without the arc labels.

In [9]:
B = 8
DEVICE = "cuda"
learning_rate = 1e-5
epochs = 3

train_dataset = DependencyParseDataset(sentence_words_train, state_action_label_train, action2idx, label2idx)
val_dataset = DependencyParseDataset(sentence_words_val, state_action_label_val, action2idx, label2idx)

train_dataloader = DataLoader(train_dataset, batch_size=B, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=B, collate_fn=collate_fn)

print(f"Num training batches: {len(train_dataloader)}")
print(f"Num validation batches: {len(val_dataloader)}")

Num training batches: 4964
Num validation batches: 212


In [10]:
model = BERT_ORACLE(num_actions=len(action2idx), num_labels=len(label2idx), unlabeled_arcs=True).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * epochs 
warmup_steps = int(len(train_dataloader) * 0.1 *  epochs) 
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
#model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 66.658563 M
RAM used: 3229.03 MB
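
The scheduler set up above ramps the learning rate linearly from zero during warmup and then decays it linearly back to zero. The function below is a plain-Python sketch of that multiplier, re-implemented here from the documented behavior of `get_linear_schedule_with_warmup` rather than copied from the library; `linear_warmup_decay` is a hypothetical name.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """Learning-rate multiplier: linear ramp to 1, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


total_steps = 4964 * 3              # batches per epoch * epochs, as in the cell above
warmup_steps = int(4964 * 0.1 * 3)  # 10% of all steps
for step in (0, warmup_steps, total_steps):
    print(step, linear_warmup_decay(step, warmup_steps, total_steps))
```

The effective learning rate at a given step is the base rate (`1e-5` here) times this multiplier.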


In [11]:
run = wandb.init(
    project="BERT Transition Dependency Parser", 
    config={
        "model": "DistilBERT",
        "learning_rate": learning_rate, 
        "epochs": epochs,
        "batch_size": B, 
        "corpus": "Penn Treebank"},)   

def log_metrics(metrics):
    wandb.log(metrics)

In [14]:
train(model, optimizer, train_dataloader, val_dataloader, device=DEVICE, num_epochs=epochs, scheduler=scheduler, save_every=None, val_every=200, log_metrics=log_metrics) 

Epoch 1, EMA Train Loss: 1.477, Train UAS, LAS: ( 0.908,  0.000), Val Loss:  1.451, Val UAS, LAS: ( 0.972,  0.000): 100%|██████████| 2482/2482 [25:00<00:00,  1.65it/s]   
Epoch 2, EMA Train Loss: 1.115, Train UAS, LAS: ( 0.977,  0.000), Val Loss:  1.251, Val UAS, LAS: ( 0.976,  0.000): 100%|██████████| 2482/2482 [25:09<00:00,  1.64it/s]  
Epoch 3, EMA Train Loss: 0.978, Train UAS, LAS: ( 0.980,  0.000), Val Loss:  1.174, Val UAS, LAS: ( 0.978,  0.000): 100%|██████████| 2482/2482 [25:06<00:00,  1.65it/s]  


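Not bad: the unlabelled attachment score on the validation set is over 97%. (As computed in `validation`, this is the per-step accuracy of the predicted actions against the oracle's transitions.) Now, let's try training the model with arc labelling as well.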

In [10]:
model = BERT_ORACLE(num_actions=len(action2idx), num_labels=len(label2idx), unlabeled_arcs=False).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * epochs 
warmup_steps = int(len(train_dataloader) * 0.1 *  epochs) 
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
#model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 66.959019 M
RAM used: 3212.10 MB


In [11]:
run = wandb.init(
    project="BERT Transition Dependency Parser with Arc Labels", 
    config={
        "model": "DistilBERT",
        "learning_rate": learning_rate, 
        "epochs": epochs,
        "batch_size": B, 
        "corpus": "Penn Treebank"},)   

def log_metrics(metrics):
    wandb.log(metrics)

In [12]:
train(model, optimizer, train_dataloader, val_dataloader, device=DEVICE, num_epochs=epochs, scheduler=scheduler, save_every=None, val_every=200, log_metrics=log_metrics) 

Epoch 1, EMA Train Loss: 3.751, Train UAS, LAS: ( 0.914,  0.794), Val Loss:  3.512, Val UAS, LAS: ( 0.973,  0.934): 100%|██████████| 4964/4964 [26:49<00:00,  3.08it/s]    
Epoch 2, EMA Train Loss: 2.521, Train UAS, LAS: ( 0.976,  0.943), Val Loss:  2.390, Val UAS, LAS: ( 0.976,  0.946): 100%|██████████| 4964/4964 [26:54<00:00,  3.07it/s]  
Epoch 3, EMA Train Loss: 2.313, Train UAS, LAS: ( 0.978,  0.949), Val Loss:  2.230, Val UAS, LAS: ( 0.976,  0.948): 100%|██████████| 4964/4964 [26:37<00:00,  3.11it/s]  


#### The arc-labeled model also achieved high UAS (97.6%) and LAS (94.8%) on the validation set.

In [13]:
#save_model_checkpoint(model, optimizer)
#wandb.finish()

Saved model checkpoint!





Run history: (wandb sparkline plots for Batch loss, Moving Avg Loss, Train LAS, Train UAS, Val LAS, Val Loss and Val UAS omitted)

Run summary:
Batch loss: 1.94826
Moving Avg Loss: 2.31301
Train LAS: 0.94861
Train UAS: 0.97752
Val LAS: 0.94843
Val Loss: 2.23043
Val UAS: 0.97592
