#### BERT augmented CKY variant

Previously we looked at implementing the CKY algorithm to generate parse trees for sentences conforming to a context-free grammar in Chomsky-normal-form. We found that sentences generated by a CFG suffer from syntax ambiguities, i.e. a sentence can have multiple valid parse trees, each with a different meaning. However, we know that out of the different possible parses, typically only one of them is the "correct" one, i.e. captures the intended meaning of the sentence. 

In this notebook, we will look an a variant of the CKY algorithm in which each possible span of a sentence is assigned a score, and these scores can be used to arrive at the correct parse tree. These scores can be computed with the help of a neural network trained on an annotated treebank dataset. The training task simply involves predicting a distribution of scores over all possible non-terminal labels for each valid span/constituent of a sentence. After training this neural model, it is expected to assign large score to the correct label for each constituents of any given sentence, which will then help with the downstream task of disambiguating the correct parse tree (which will be done using a slightly modified CKY algorithm). A BERT model is powerful and well-suited for the score prediction task. The diagram below (borrowed from the Jurafsky-Martin textbook) summarizes the model architecture:

<img src="neural_parser.png" width="600" height="450">

We define the spans in the same way as we did for the vanilla CKY parser, i.e. using the "fencepost" positions. We also use the same upper-triangular matrix that which we used previously. Instead of using a pre-defined CFG in CNF to assign non-terminal labels to each element in this matrix (which represent the different possible spans), this time we will instead use BERT model to compute a distribution of scores over all possible terminals for each possible span. First, we create a fixed-size vector representation of the span and then feed it into an MLP classifier, as shown in the diagram. We outline the steps in more detail:

1) Convert words to subword tokens
2) Get `BERT embeddings for subwords`
3) Compute the `embeddings for full words` (many ways to do this, e.g. we could just assign the BERT em,bedding for the first subword of that word, or we could take the sum.average of the embedding of all the subwords or we could take element-wise max across all the subword embeddings, etc.)  
4) Compute `embeddings of fence posts` (shown as 0,1,2,3.. in the diagram above). Since each fence-post can represent the beginning or end of a span, we will create two separate representations. We first split the embedding vector $y_t$ of the $t$-th word in the sentence into two halves, $\overleftarrow{y_t}$ and $\overrightarrow{y_t}$ such that the concatentaion $[\overleftarrow{y_t}; \overrightarrow{y_t}] = y_t$. Then the `start-of-span representation` of the fencepost at position $i$ is defined as $\overrightarrow{y_i}$ and the `end-of-span representation` is defined as $\overleftarrow{y}_{i+1}$
5) Construct `embedding for a span` `(i,j)` using the fencepost embeddings as the following concatenation between the difference in start-of span and end-of-span embeddings for the bounding fenceposts: $v(i,j) = [\overrightarrow{y_j}-\overrightarrow{y_i}; \overleftarrow{y}_{j+1}-\overleftarrow{y}_{i+1}]$
6) Pass $v(i,j)$ through the MLP to get a distribution of scores over all possible non-terminal labels.

One really important thing to note here is that we are no longer using a pre-defined context-free grammar. The supervised training of the neural network model will implicitly induce/"learn" the grammar. Also, one downside is that the model may sometimes fail to get a grammatically correct parse of a sentence, because it does not have access to the "true grammar", only some form of statistical approximation of it.

After we implement and train this model, we will iplement the CKY variant that will perform the actual parsing using the scores computed by the BERT model.

In [1]:
from nltk.corpus import treebank
from nltk import Tree
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast, DistilBertModel, get_linear_schedule_with_warmup
from tqdm import tqdm
import random
random.seed(10)
import psutil
print(torch.cuda.is_available())

True


Firstm let's load the data. We will use the NLTK treebank, which is a subset of the original Penn Treebank dataset.

In [2]:
# get all parsed sentences acrsoo all the files
sentences = treebank.sents()
parse_trees = treebank.parsed_sents()

# only keep sentences that are at most 100 words
sentences_parses = zip(sentences, parse_trees)
sentences_parses = [(s,p) for s,p in sentences_parses if len(s) <= 100]

sentences = [s for s,p in sentences_parses]
parse_trees = [p for s,p in sentences_parses]

print(f"Number of parsed sentences: {len(parse_trees)}")

Number of parsed sentences: 3910


In [113]:
sentences_lengths = [len(s) for s in sentences]
print(min(sentences_lengths), max(sentences_lengths), sum(sentences_lengths)/len(sentences_lengths))

# plot a histogram of sentence lengths
#import matplotlib.pyplot as plt
#plt.hist(sentences_lengths, bins=30)
#plt.show()

1 92 25.58618925831202


In [142]:
# print example of a parsed sentence
example_tree = parse_trees[0]
print(example_tree)

(S
  (NP-SBJ
    (NP (NNP Pierre) (NNP Vinken))
    (, ,)
    (ADJP (NP (CD 61) (NNS years)) (JJ old))
    (, ,))
  (VP
    (MD will)
    (VP
      (VB join)
      (NP (DT the) (NN board))
      (PP-CLR (IN as) (NP (DT a) (JJ nonexecutive) (NN director)))
      (NP-TMP (NNP Nov.) (CD 29))))
  (. .))


In [143]:
# function for extracting all the non-terminal labels for all possible spans from each sentence
def get_span_labels(parse_tree, verbose=False):
    span_labels = {}
    leaves = parse_tree.leaves()
    if verbose:
        parse_tree.pretty_print()
    # iterate over all subtrees in level-order traversal
    for subtree in parse_tree.subtrees():
        subtree_leaves = subtree.leaves()
        start_index = leaves.index(subtree_leaves[0])
        end_index = start_index + len(subtree_leaves)
        span = (start_index, end_index)
        if span[1] < span[0]:
            parse_tree.pretty_print()
            raise ValueError(f"Span start index is greater than or equal to span end index: {span}")
        span_labels[span] = subtree.label()
        if verbose:
            print(f"\nsubtree label: {subtree.label()}")
            print(f"subtree leaves: {subtree.leaves()}")
            print(f"Start fence-post: {start_index}")
            print(f"End fence-post: {end_index}")
            print(f"span: {span}")
    return span_labels           

In [144]:
get_span_labels(example_tree, verbose=True)

                                                     S                                                                         
                         ____________________________|_______________________________________________________________________   
                        |                                               VP                                                   | 
                        |                        _______________________|___                                                 |  
                      NP-SBJ                    |                           VP                                               | 
         _______________|___________________    |     ______________________|______________________________________          |  
        |          |              ADJP      |   |    |        |                PP-CLR                              |         | 
        |          |           ____|____    |   |    |        |          ________|_________          

{(0, 18): 'S',
 (0, 7): 'NP-SBJ',
 (0, 2): 'NP',
 (0, 1): 'NNP',
 (1, 2): 'NNP',
 (2, 3): ',',
 (3, 6): 'ADJP',
 (3, 5): 'NP',
 (3, 4): 'CD',
 (4, 5): 'NNS',
 (5, 6): 'JJ',
 (7, 17): 'VP',
 (7, 8): 'MD',
 (8, 17): 'VP',
 (8, 9): 'VB',
 (9, 11): 'NP',
 (9, 10): 'DT',
 (10, 11): 'NN',
 (11, 15): 'PP-CLR',
 (11, 12): 'IN',
 (12, 15): 'NP',
 (12, 13): 'DT',
 (13, 14): 'JJ',
 (14, 15): 'NN',
 (15, 17): 'NP-TMP',
 (15, 16): 'NNP',
 (16, 17): 'CD',
 (17, 18): '.'}

#### Because CKY requires the parse trees to be binary, we need to convert all n-ary trees from our treebank dataset to binary form. We use a simple scheme for binarization, we recursively traverse down the tree starting from the root, and everytime we find a node with more than two children, we create a new node with the 'Null' label and make it the parent of all children except for the leftmost child.

In [27]:
# recursively binarize a parse tree
def binarize_tree(parse_tree, empty_label='Null'):
    if isinstance(parse_tree, Tree):
        if len(parse_tree) == 1:
            return Tree(parse_tree.label(), [parse_tree[0]])
        elif len(parse_tree) == 2:
            return Tree(parse_tree.label(), [binarize_tree(parse_tree[0]), binarize_tree(parse_tree[1])])
        elif len(parse_tree) > 2:
            return Tree(parse_tree.label(), [parse_tree[0], binarize_tree(Tree(empty_label, parse_tree[1:]))])        

In [33]:
example_tree_binarized = binarize_tree(example_tree)

In [127]:
example_tree_binarized.pretty_print()
get_span_labels(example_tree_binarized, verbose=False)

                                                     S                                                                               
                         ____________________________|______________________                                                          
                        |                                                  Null                                                      
                        |                             ______________________|______________________________________________________   
                        |                            VP                                                                            | 
                        |                        ____|_____________                                                                |  
                        |                       |                  VP                                                              | 
                        |                       |     _____

{(0, 18): 'S',
 (0, 3): 'NP-SBJ',
 (0, 2): 'NP',
 (0, 1): 'NNP',
 (1, 2): 'NNP',
 (2, 3): ',',
 (3, 6): 'ADJP',
 (3, 5): 'NP',
 (3, 4): 'CD',
 (4, 5): 'NNS',
 (5, 6): 'JJ',
 (7, 18): 'Null',
 (7, 17): 'VP',
 (7, 8): 'MD',
 (8, 17): 'VP',
 (8, 9): 'VB',
 (9, 17): 'Null',
 (9, 11): 'NP',
 (9, 10): 'DT',
 (10, 11): 'NN',
 (11, 17): 'Null',
 (11, 15): 'PP-CLR',
 (11, 12): 'IN',
 (12, 15): 'NP',
 (12, 13): 'DT',
 (13, 15): 'Null',
 (13, 14): 'JJ',
 (14, 15): 'NN',
 (15, 17): 'NP-TMP',
 (15, 16): 'NNP',
 (16, 17): 'CD',
 (17, 18): '.'}

#### Score-based CKY variant: For the Vanilla CKY, we constructed parse trees by filling an upper triangular matrix in a left to right bottom-up fashion, then traced the back pointers to recover the parse. We will use a variant of CKY which uses span span scores and the following recursion relation for computing the highest scoring subtree rooted at a given span.

Let $s_{best}(i,j)$ denote the score of the best sub-tree for the span `(i,j)`. Let $score(i,j,l)$ denote the score of the span `(i,j)` for label $l$. Then the base case of the recusion relation, i.e. for the spans of length 1, is given by:

$s_{best}(i,i+1) = \max_l score(i,i+1,l)$

and for the general case:

$s_{best}(i,j) = \max_l score(i,j,l) + \max_k (s_{best}(i,k) + s_{best}(k,j))$

Now, we can fill the upper triangular matrix by computing the value of $s_{best}(i,j)$ inside each cell using this equation (initializing the super-diagonal cell values using the base case, and initializing the remaining cell values as -$\infty$).

At each split, we also store the label $l$ and the split position $k$ in our back pointers, which will then allow us to recover the parse tree.

Then the score of the entire tree is computed by summing up the scores for each node as follows:

$S(T) = \sum_{(i,j,l) \in T} score(i,j,l)$


In [110]:
"""
# lets implement this CKY variant and demonstrate with a simple example

# define a score function (can be anything) and assume we have only 3 different labels and 4 words in our sentence
L = 3
n = 4
scores = torch.rand((n+1,n+1,L)) 

spans = [(i,j+1) for i in range(n) for j in range(i,n)]
s_best = {span: float('-inf') for span in spans}
back = {}

# base case initialization
for i in range(n):
    s_best[(i,i+1)] = scores[i,i+1,:].max().item()
    best_label = scores[i,i+1,:].argmax().item()
    back[(i,i+1)] = (i,best_label)

# gerenal case bottom up
for j in range(1, n+1):
    for i in range(j-2, -1, -1):
        print(f"span: ({i},{j})")
        # get best label and score
        best_label_score = scores[i,j,:].max().item()
        best_label = scores[i,j,:].argmax().item()
        # get best split
        best_split_score = float('-inf') 
        for k in range(i+1, j):
            split_score = s_best[(i,k)] + s_best[(k,j)]
            if split_score > best_split_score:
                best_split_score = split_score
                best_k = k
        # score of best subtree for this span
        s_best[(i,j)] = best_label_score + best_split_score 
        # back pointer
        back[(i,j)] = (best_k,best_label) 
        print(f"Best label: {best_label}, Best label score: {best_label_score}, Best split score: {best_split_score}, Best split: {best_k}")       


# now let's recursively retreive the best tree
def get_tree(i, j, back):
    if i == j-1:
        return Tree(f"w{i}", [f"w{i}"])
    else:
        k, label = back[(i,j)]
        return Tree(f"nt{label}", [get_tree(i,k,back), get_tree(k,j,back)])

best_tree = get_tree(0, n, back)
best_tree.pretty_print()        

# recursively retriev the (span, label) tuples for all nodes in the best tree
def get_tree_spans_labels(i, j, back):
    if i == j-1:
        k, label = back[(i,j)]
        return [(i,j,label)]
    else:
        k, label = back[(i,j)]
        return [(i,j,label)] + get_tree_spans_labels(i, k, back) + get_tree_spans_labels(k, j, back)   

span_labels = get_tree_spans_labels(0, n, back) 
print(span_labels)

# compute total score of the tree
tot_score = 0.0
for (i,j,label) in span_labels:
    tot_score += scores[i,j,label].item()
print(f"Total score: {tot_score}")    
"""

'\n# lets implement this CKY variant and demonstrate with a simple example\n\n# define a score function (can be anything) and assume we have only 3 different labels and 4 words in our sentence\nL = 3\nn = 4\nscores = torch.rand((n+1,n+1,L)) \n\nspans = [(i,j+1) for i in range(n) for j in range(i,n)]\ns_best = {span: float(\'-inf\') for span in spans}\nback = {}\n\n# base case initialization\nfor i in range(n):\n    s_best[(i,i+1)] = scores[i,i+1,:].max().item()\n    best_label = scores[i,i+1,:].argmax().item()\n    back[(i,i+1)] = (i,best_label)\n\n# gerenal case bottom up\nfor j in range(1, n+1):\n    for i in range(j-2, -1, -1):\n        print(f"span: ({i},{j})")\n        # get best label and score\n        best_label_score = scores[i,j,:].max().item()\n        best_label = scores[i,j,:].argmax().item()\n        # get best split\n        best_split_score = float(\'-inf\') \n        for k in range(i+1, j):\n            split_score = s_best[(i,k)] + s_best[(k,j)]\n            if split_

Now lets create a pytorch dataset for creating the (input, target) instances for our BERT span score prediction model.

In [131]:
# first, let's binarize all the parse trees
parse_trees_binarized = [binarize_tree(t) for t in parse_trees]

# get span labels for all sentences
#span_labels = [get_span_labels(p) for p in parse_trees_binarized]
for i, p in enumerate(parse_trees_binarized):
    print(i)
    span_labels = get_span_labels(p)
    
# create mapping of span labels to unique ids
unique_span_labels = list(set([label for span in span_labels for label in span.values()]))
label2idx = {label: i for i, label in enumerate(unique_span_labels)}    

# create train-val splits
n_train = int(0.9 * len(sentences))
sentences_train = sentences[:n_train]
parse_trees_train = parse_trees_binarized[:n_train]
span_labels_train = span_labels[:n_train]
sentences_val = sentences[n_train:]
parse_trees_val = parse_trees_binarized[n_train:]
span_labels_val = span_labels[n_train:]

print(f"Number of training sentences: {len(sentences_train)}")
print(f"Number of validation sentences: {len(sentences_val)}")

0
1
2
3
4
                                                                                                                                    S                                                                                                                                     
                                            ________________________________________________________________________________________|___________________________________________________________________________________________________________                           
                                        S-TPC-2                                                                                                                                                                                                 |                         
                ___________________________|________________________________________________                                                                                                

ValueError: Span start index is greater than or equal to span end index: (16, 11)

In [135]:
parse_trees[4].pretty_print()

                                                                                                                                    S                                                                                                                                     
                                            ________________________________________________________________________________________|___________________________________________________________________________________________________________________________________   
                                        S-TPC-2                                                                                                                                                                                 |       |             |                 | 
                ___________________________|________________________________________________                                                                                                          

In [134]:
parse_trees_binarized[4].pretty_print()

                                                                                                                                    S                                                                                                                                     
                                            ________________________________________________________________________________________|___________________________________________________________________________________________________________                           
                                        S-TPC-2                                                                                                                                                                                                 |                         
                ___________________________|________________________________________________                                                                                                          

In [114]:
class ParseTreeDataset(Dataset):
    def __init__(self, sentences, span_labels, label2idx, block_size=256, max_spans=1024):
        self.sentences = sentences
        self.span_labels = span_labels
        self.label2idx = label2idx
        self.block_size = block_size
        self.max_spans = max_spans
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        # get sentence and span labels
        sentence = self.sentences[idx]
        span_labels = self.span_labels[idx]
        # convert span labels to indices
        span_labels_idx = {span:self.label2idx[label] for span, label in span_labels.items()}

        # tokenize the sentence
        input_encoding = self.tokenizer.encode_plus(sentence, is_split_into_words=True, return_offsets_mapping=False, padding=False, truncation=False, add_special_tokens=True)
        input_idx = input_encoding['input_ids']
        word_ids = input_encoding.word_ids()

        if len(input_idx) > self.block_size:
            raise ValueError(f"Tokenized sentence {idx} is too long: {len(input_idx)}. Truncation unsupported.")

        # add padding 
        input_idx = input_idx + [self.tokenizer.pad_token_id] * (self.block_size - len(input_idx))    
        # create attention mask 
        input_attn_mask = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in input_idx]

        # convert to tensors
        input_idx = torch.tensor(input_idx)
        input_attn_mask = torch.tensor(input_attn_mask) 
        
        return input_idx, input_attn_mask, word_ids, span_labels_idx   


def collate_fn(batch):
    # Separate the tensors and the dictionaries
    input_idxs, input_attn_masks, word_ids, span_labels_idx = zip(*batch)

    # Default collate the tensors
    input_idxs = torch.stack(input_idxs)
    input_attn_masks = torch.stack(input_attn_masks)

    # Handle the dictionaries
    spans_labels = []
    for d in span_labels_idx:
        # Convert the dictionary to two lists: one for spans and one for labels
        spans, labels = zip(*d.items())
        spans_labels.append((spans, labels))

    return input_idxs, input_attn_masks, word_ids, spans_labels


Now lets implement the BERT model for span score prediction, which uses `margin-based training`. For margin-based training, we use an `SVM/hinge loss` function of the following form:

$\text{SVM Loss} = \max (0, Hamming(T, T^*) + S(T) - S(T^*))$ 

where $S(T)$ and $S(T^*)$ are the total scores of the predicted and ground-truth parese trees respectively and we use the `Hamming distance` between $T$ and $T^*$ (i.e. the proportion of labeled spans that are different between the predicted and actual trees).

In [119]:
class BERT_CKY(torch.nn.Module):
    def __init__(self, num_classes, dropout_rate=0.1, mlp_hidden_size=128):
        super().__init__()
        # load pretrained BERT model
        self.bert_encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)
        # define classifier head (2 layer MLP)
        self.classifier_head = torch.nn.Sequential(
            torch.nn.Linear(self.bert_encoder.config.hidden_size, mlp_hidden_size),
            torch.nn.LayerNorm(mlp_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_hidden_size, num_classes)
        )
        
        # make sure BERT parameters are trainable
        for param in self.bert_encoder.parameters():
            param.requires_grad = True


    def get_word_embeddings(self, bert_output, word_ids, i):
        word_embeddings = [bert_output[i,0,:]] # add CLS token embedding
        # construct embeddings for each word in the original (untokenized) sequence
        for j in range(word_ids[-2]+1):
            y_j = bert_output[i,word_ids.index(j),:] # gets first subword token embedding for word j
            word_embeddings.append(y_j)        
        word_embeddings.append(bert_output[i,len(word_ids)-1,:]) # add SEP token embedding
        return word_embeddings

    def compute_span_scores(self, word_embeddings, spans):
        span_scores = {}
        for i,j in spans:
            # get right and left fencepost embeddings for the bounding fenceposts of this span
            y_j_right = word_embeddings[j][768//2:]
            y_i_right = word_embeddings[i][768//2:]
            y_jplus1_left = word_embeddings[j+1][:768//2]
            y_iplus1_left = word_embeddings[i+1][:768//2]
            # concatenate difference vectors to get span embedding
            span_embedding = torch.cat([y_j_right-y_i_right, y_jplus1_left-y_iplus1_left], dim=0) # shape: (hidden_size,)
            # compute logits for this span
            span_scores[(i,j)] = self.classifier_head(span_embedding) # shape: (num_classes,)
        return span_scores
    
    def get_tree_spans_labels(self, i, j, back):
        if i == j-1:
            _, label = back[(i,j)]
            return [(i,j,label)]
        else:
            k, label = back[(i,j)]
            return [(i,j,label)] + self.get_tree_spans_labels(i, k, back) + self.get_tree_spans_labels(k, j, back)  

    def cky(self, spans, span_scores, num_words):
        s_best = {span: float('-inf') for span in spans}
        back = {}
        # base case initialization
        for i in range(num_words):
            s_best[(i,i+1)] = span_scores[(i,i+1)].max()
            best_label = span_scores[(i,i+1)].argmax().item()
            back[(i,i+1)] = (i,best_label)

        # gerenal case bottom up
        for j in range(1, num_words+1):
            for i in range(j-2, -1, -1):
                #print(f"span: ({i},{j})")
                # get best label and score
                best_label_score = span_scores[(i,j)].max()
                best_label = span_scores[(i,j)].argmax().item()
                # get best split
                best_split_score = float('-inf') 
                for k in range(i+1, j):
                    split_score = s_best[(i,k)] + s_best[(k,j)]
                    if split_score.item() > best_split_score:
                        best_split_score = split_score
                        best_k = k
                # score of best subtree for this span
                s_best[(i,j)] = best_label_score + best_split_score 
                # back pointer
                back[(i,j)] = (best_k,best_label) 
                #print(f"Best label: {best_label}, Best label score: {best_label_score}, Best split score: {best_split_score}, Best split: {best_k}")   
        return s_best, back
    

    def forward(self, input_idx, input_attn_mask, batch_word_ids, targets=None):
        # compute BERT embeddings for input tokens
        bert_output = self.bert_encoder(input_idx, attention_mask=input_attn_mask)
        bert_output = self.dropout(bert_output.last_hidden_state) # shape: (batch_size, block_size, hidden_size)

        loss = 0.0
        batch_back_pointers = []
        batch_span_labels = []
        # for each sequence in batch, get word embeddings for each word by using the BERT embeddings for first subword token for that word
        for batch_idx, word_ids in enumerate(batch_word_ids):
            word_embeddings = self.get_word_embeddings(bert_output, word_ids, batch_idx)
            n = len(word_embeddings)-2
            # now that we have the word embeddings, we need to construct span embeddings and compute scores for all possible spans
            all_spans = [(i,j+1) for i in range(n) for j in range(i,n)]
            #print(f"Lenght of word embeddings list: {len(word_embeddings)}")
            #print(f"num fenceposts: {n}, num spans: {len(all_spans)}")
            #print(f"spans: {all_spans}")
            span_scores = self.compute_span_scores(word_embeddings, all_spans)

            # now apply CKY algorithm to get the best tree for this sequence
            s_best , back = self.cky(all_spans, span_scores, n)
            batch_back_pointers.append(back)
              
            # recursively retrieve the (span, label) tuples for all nodes in the best predicted parse tree
            span_labels = self.get_tree_spans_labels(0, n, back) 
            batch_span_labels.append(span_labels)
            #print(span_labels)

            if targets is None:
                continue

            # compute total score of the predicted parse tree by summing up scores of all the constituent spans
            tree_score = 0.0
            for (i,j,label) in span_labels:
                tree_score += span_scores[(i,j)][label]
            #print(f"Total score: {tot_score}")   
                
            # compute the total score of the gold standard parse tree and Hamming loss
            gold_spans, gold_labels = targets[batch_idx]
            print(f"Gold spans: {gold_spans}")
            print(f"Gold labels: {gold_labels}")
            gold_score = 0.0    
            hamming = 0.0
            for (i,j),label in zip(gold_spans, gold_labels):
                gold_score += span_scores[(i,j)][label]
                if (i,j,label) not in span_labels:
                    hamming += 1
            hamming = hamming/len(gold_spans) # normalize by number of spans

            # accumulate maximum-margin/hinge loss for this sequence
            loss += max(0, hamming + tree_score - gold_score)

        loss = loss / len(input_idx) # average loss over batch
        return loss, batch_span_labels, batch_back_pointers
  

    """    
    def forward(self, input_idx, input_attn_mask, batch_word_ids, targets=None):
        # compute BERT embeddings for input tokens
        bert_output = self.bert_encoder(input_idx, attention_mask=input_attn_mask)
        bert_output = self.dropout(bert_output.last_hidden_state) # shape: (batch_size, block_size, hidden_size)

        if targets is not None:
            # for each sequence, get word embeddings for each word by using the BERT embeddings for first subword token for that word
            logits = []
            loss = 0.0
            for i, word_ids in enumerate(batch_word_ids):
                word_embeddings = [bert_output[i,0,:]] # add CLS token embedding
                # construct embeddings for each word in the original (untokenized) sequence
                for j in range(word_ids[-2]+1):
                    y_j = bert_output[i,word_ids.index(j),:] # gets first subword token embedding for word j
                    word_embeddings.append(y_j)        
                word_embeddings.append(bert_output[i,len(word_ids)-1,:]) # add SEP token embedding

                # now that we have the word embeddings, we need to construct the span embeddings
                spans, labels = targets[i]
                span_logits = []
                for i,j in spans:
                    # get right and left fencepost embeddings for the bounding fenceposts of this span
                    y_j_right = word_embeddings[j][768//2:]
                    y_i_right = word_embeddings[i][768//2:]
                    y_jplus1_left = word_embeddings[j+1][:768//2]
                    y_iplus1_left = word_embeddings[i+1][:768//2]
                    # concatenate difference vectors to get span embedding
                    span_embedding = torch.cat([y_j_right-y_i_right, y_jplus1_left-y_iplus1_left], dim=0) # shape: (hidden_size,)
                    # compute logits for this span
                    span_logits.append(self.classifier_head(span_embedding)) # shape: (num_classes,)

                span_logits = torch.stack(span_logits, dim=0) # shape: (num_spans, num_classes)
                logits.append(span_logits)
                # accumulate loss for the spans in this sequence (note that we only compute losses for labeled spans)
                loss += F.cross_entropy(span_logits, torch.tensor(labels, device=input_idx.device))

            loss = loss / len(input_idx) # average loss over batch
            return logits, loss

        else:
            loss = None
            logits = []
            batch_spans = []
            for i, word_ids in enumerate(batch_word_ids):
                word_embeddings = [bert_output[i,0,:]] # add CLS token embedding
                # construct embeddings for each word in the original (untokenized) sequence
                for j in range(word_ids[-2]+1):
                    y_j = bert_output[i,word_ids.index(j),:] # gets first subword token embedding for word j
                    word_embeddings.append(y_j)        
                word_embeddings.append(bert_output[i,len(word_ids)-1,:]) # add SEP token embedding

                # in inference mode, we will compute scores for all possible spans
                spans = [(i,j+1) for i in range(len(word_embeddings)) for j in range(i, len(word_embeddings))]
                # compute logits
                span_logits = []
                for i,j in spans:
                    # get right and left fencepost embeddings for the bounding fenceposts of this span
                    y_j_right = word_embeddings[j][768//2:]
                    y_i_right = word_embeddings[i][768//2:]
                    y_jplus1_left = word_embeddings[j+1][:768//2]
                    y_iplus1_left = word_embeddings[i+1][:768//2]
                    span_embedding = torch.cat([y_j_right-y_i_right, y_jplus1_left-y_iplus1_left], dim=0) # shape: (hidden_size,)                    
                    # compute logits for each span
                    span_logits.append(self.classifier_head(span_embedding)) # shape: (num_classes,)  
                span_logits = torch.stack(span_logits, dim=0) # shape: (num_spans, num_classes)
                logits.append(span_logits)
                batch_spans.append(spans)
            return logits, batch_spans
            """
        

# training loop
def train(model, optimizer, train_dataloader, val_dataloader, scheduler=None, device="cpu", num_epochs=10, val_every=100, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    model.train()
    # reset gradients
    optimizer.zero_grad()
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for i, batch in enumerate(pbar):
            input_idx, input_attn_mask, word_ids, targets = batch
            # move batch to device
            input_idx, input_attn_mask = input_idx.to(device), input_attn_mask.to(device)
            # forward pass
            loss, batch_span_labels, batch_back_pointers = model(input_idx, input_attn_mask, word_ids, targets)
            # reset gradients
            optimizer.zero_grad()
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()

            if scheduler is not None:
                    scheduler.step()

            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            for i, (gold_spans, gold_labels) in enumerate(targets):
                span_labels = batch_span_labels[i]
                for (i,j),label in zip(gold_spans, gold_labels):
                    if (i,j,label) in span_labels:
                        num_correct += 1
                num_total += len(gold_spans)
            train_acc = num_correct / num_total        

            if val_every is not None:
                if i%val_every == 0:
                    # compute validation loss
                    val_loss, val_acc = validation(model, val_dataloader, device=device)
                    pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

            if log_metrics:
                metrics = {"Batch loss":loss.item(), "Moving Avg Loss":avg_loss, "Train Accuracy":train_acc, "Val Loss": val_loss, "Val Accuracy":val_acc}
                log_metrics(metrics)

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_model_checkpoint(model, optimizer, epoch, avg_loss)


def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            input_idx, input_attn_mask, word_ids, targets = batch
            input_idx, input_attn_mask = input_idx.to(device), input_attn_mask.to(device)
            loss, batch_span_labels, batch_back_pointers = model(input_idx, input_attn_mask, word_ids, targets)
            for i, (gold_spans, gold_labels) in enumerate(targets):
                span_labels = batch_span_labels[i]
                for (i,j),label in zip(gold_spans, gold_labels):
                    if (i,j,label) in span_labels:
                        num_correct += 1
                num_total += len(gold_spans)
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    return val_loss, val_accuracy


def save_model_checkpoint(model, optimizer, epoch=None, loss=None, filename='BERT_CKY_checkpoint.pth'):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # Save the checkpoint to a file
    torch.save(checkpoint, filename)
    print(f"Saved model checkpoint!")


def load_model_checkpoint(model, optimizer=None,  filename='BERT_CKY_checkpoint.pth'):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded model from checkpoint!")
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.train()
        return model, optimizer          
    else:
        return model        

In [120]:
B = 16
DEVICE = "cuda"
learning_rate = 1e-5
epochs = 1

train_dataset = ParseTreeDataset(sentences_train, span_labels_train, label2idx)
val_dataset = ParseTreeDataset(sentences_val, span_labels_val, label2idx)

train_dataloader = DataLoader(train_dataset, batch_size=B, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=B, collate_fn=collate_fn)

In [121]:
# model with finetuning disabled
model = BERT_CKY(num_classes=len(label2idx)).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * epochs 
warmup_steps = int(len(train_dataloader) * 0.1 *  epochs) 
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
#model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 66.513813 M
RAM used: 5867.46 MB


In [122]:
train(model, optimizer, train_dataloader, val_dataloader, device=DEVICE, num_epochs=epochs, scheduler=scheduler, save_every=None, val_every=30) 

Epochs:   0%|          | 0/220 [00:00<?, ?it/s]

Gold spans: ((0, 18), (0, 3), (0, 2), (0, 1), (1, 2), (2, 3), (3, 6), (3, 5), (3, 4), (4, 5), (5, 6), (7, 18), (7, 17), (7, 8), (8, 17), (8, 9), (9, 17), (9, 11), (9, 10), (10, 11), (11, 17), (11, 15), (11, 12), (12, 15), (12, 13), (13, 15), (13, 14), (14, 15), (15, 17), (15, 16), (16, 17), (17, 18))
Gold labels: (292, 161, 244, 348, 348, 145, 149, 244, 164, 309, 275, 312, 217, 213, 217, 306, 312, 244, 367, 178, 312, 93, 163, 244, 367, 312, 275, 178, 49, 348, 164, 203)
Gold spans: ((0, 13), (0, 2), (0, 1), (1, 2), (2, 13), (2, 12), (2, 3), (3, 12), (3, 4), (4, 12), (4, 5), (5, 12), (5, 7), (5, 6), (6, 7), (7, 12), (7, 8), (8, 12), (8, 9), (9, 12), (9, 10), (10, 12), (10, 11), (11, 12), (12, 13))
Gold labels: (292, 161, 348, 348, 312, 217, 27, 283, 178, 109, 163, 244, 244, 348, 348, 312, 145, 244, 367, 312, 348, 312, 100, 178, 203)
Gold spans: ((0, 27), (0, 3), (0, 2), (0, 1), (1, 2), (2, 3), (3, 14), (3, 6), (3, 5), (3, 4), (4, 5), (5, 6), (6, 7), (7, 14), (7, 9), (7, 8), (8, 9), (9, 1

Epochs:   0%|          | 0/220 [00:10<?, ?it/s]

Gold spans: ((0, 35), (0, 29), (0, 4), (0, 3), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 29), (6, 7), (7, 9), (7, 8), (8, 9), (9, 14), (9, 10), (10, 14), (10, 11), (11, 14), (11, 12), (12, 14), (12, 13), (13, 14), (15, 29), (15, 16), (16, 29), (16, 11), (16, 19), (16, 17), (17, 18), (18, 19), (19, 11), (19, 20), (21, 29), (21, 22), (22, 29), (22, 23), (23, 29), (23, 24), (24, 29), (24, 25), (25, 29), (25, 26), (26, 27), (27, 29), (27, 28), (28, 29), (3, 35), (30, 35), (30, 31), (31, 35), (31, 34), (31, 32), (32, 34), (32, 33), (33, 34), (34, 35))
Gold labels: (292, 220, 161, 244, 367, 178, 178, 145, 178, 217, 27, 187, 398, 275, 377, 163, 292, 300, 217, 27, 244, 367, 309, 109, 163, 226, 161, 244, 398, 275, 309, 109, 392, 217, 100, 244, 309, 113, 64, 292, 54, 217, 79, 126, 130, 309, 275, 312, 312, 309, 312, 217, 84, 113, 54, 54, 203)





KeyError: (16, 11)