#### BERT-augmented CKY variant

Previously we looked at implementing the CKY algorithm to generate parse trees for sentences conforming to a context-free grammar in Chomsky normal form. We found that sentences generated by a CFG suffer from syntactic ambiguity, i.e. a sentence can have multiple valid parse trees, each with a different meaning. However, of the possible parses, typically only one is the "correct" one, i.e. captures the intended meaning of the sentence.

In this notebook, we will look at a variant of the CKY algorithm in which each possible span of a sentence is assigned a score, and these scores are used to select the best parse tree. The scores are computed by a neural network trained on an annotated treebank dataset. BERT's contextual word embeddings make it well suited to this task. The diagram below (borrowed from the Jurafsky-Martin textbook) summarizes the model architecture:

<img src="neural_parser.png" width="600" height="450">

We define the spans in the same way as we did for the vanilla CKY parser, i.e. using the "fencepost" positions. We also use the same upper-triangular matrix that we used previously, in which each element represents one possible span. Instead of using a pre-defined CFG in CNF to assign non-terminal labels to the elements of this matrix, this time we will use a BERT model to compute a distribution of scores over all possible non-terminal labels for each possible span. First, we create a fixed-size vector representation of the span and then feed it into an MLP classifier, as shown in the diagram. We outline the steps in more detail:

1) Convert words to subword tokens
2) Get `BERT embeddings for subwords`
3) Compute the `embeddings for full words` (there are many ways to do this, e.g. we could use the BERT embedding of the first subword of the word, take the sum or average of the embeddings of all its subwords, or take the element-wise max across all its subword embeddings, etc.)  
4) Compute `embeddings of fence posts` (shown as 0,1,2,3.. in the diagram above). Since each fencepost can represent the beginning or end of a span, we create two separate representations. We first split the embedding vector $y_t$ of the $t$-th word in the sentence into two halves, $\overleftarrow{y_t}$ and $\overrightarrow{y_t}$, such that the concatenation $[\overleftarrow{y_t}; \overrightarrow{y_t}] = y_t$. Then the `start-of-span representation` of the fencepost at position $i$ is defined as $\overrightarrow{y_i}$ and the `end-of-span representation` is defined as $\overleftarrow{y}_{i+1}$
5) Construct the `embedding for a span` `(i,j)` from the fencepost embeddings by concatenating the differences of the start-of-span and end-of-span embeddings of the bounding fenceposts: $v(i,j) = [\overrightarrow{y_j}-\overrightarrow{y_i}; \overleftarrow{y}_{j+1}-\overleftarrow{y}_{i+1}]$
6) Pass $v(i,j)$ through the MLP to get a distribution of scores over all possible non-terminal labels.
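
Steps 3 to 6 can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not the model built later in this notebook: the helper names (`pool_subwords`, `span_embedding`, `label_scores`), the `word_ids` input format, the single hidden layer with ReLU, and the convention that rows `0` and `T+1` of `y` hold boundary tokens (e.g. `[CLS]` and `[SEP]`) are all hypothetical choices.

```python
import numpy as np

def pool_subwords(subword_emb, word_ids, how="mean"):
    """Step 3: collapse subword embeddings (n_subwords, d) into word embeddings.

    word_ids[k] is the index of the word that subword k belongs to
    (an assumed input format, chosen for this sketch).
    """
    pooled = []
    for w in range(max(word_ids) + 1):
        rows = subword_emb[[k for k, wid in enumerate(word_ids) if wid == w]]
        if how == "first":
            pooled.append(rows[0])            # embedding of the first subword
        elif how == "max":
            pooled.append(rows.max(axis=0))   # element-wise max
        else:
            pooled.append(rows.mean(axis=0))  # average
    return np.stack(pooled)

def span_embedding(y, i, j):
    """Steps 4-5: v(i, j) for fenceposts i < j.

    y has shape (T + 2, d) with d even; rows 0 and T + 1 are assumed to be
    boundary tokens, rows 1..T are the T words of the sentence.
    """
    half = y.shape[1] // 2
    backward, forward = y[:, :half], y[:, half:]  # y_t = [backward_t ; forward_t]
    return np.concatenate([forward[j] - forward[i],
                           backward[j + 1] - backward[i + 1]])

def label_scores(v, W1, b1, W2, b2):
    """Step 6: one-hidden-layer MLP mapping v(i, j) to one score per label."""
    hidden = np.maximum(0.0, W1 @ v + b1)  # ReLU
    return W2 @ hidden + b2

# toy input: 6 subwords, embedding size 4, grouped into 4 "words"
# (boundary token, 2 real words, boundary token)
subword_emb = np.arange(24, dtype=float).reshape(6, 4)
word_ids = [0, 1, 1, 2, 3, 3]
y = pool_subwords(subword_emb, word_ids)   # shape (4, 4)
v = span_embedding(y, 0, 2)                # span covering both real words

rng = np.random.default_rng(0)
n_labels = 3                               # toy label set size
scores = label_scores(v, rng.normal(size=(8, 4)), np.zeros(8),
                      rng.normal(size=(n_labels, 8)), np.zeros(n_labels))
```

Because the span vector is built from differences of fencepost representations, its size is fixed no matter how many words the span covers, which is what lets a single MLP score every cell of the chart.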

An important point to note here is that we are no longer using a pre-defined context-free grammar. The supervised training of the neural network model implicitly induces ("learns") the grammar. One downside is that the model may sometimes fail to produce a grammatically correct parse of a sentence, because it does not have access to the "true grammar", only a statistical approximation of it.

After we implement and train this model, we will implement the CKY variant that performs the actual parsing using the scores computed by the BERT model.
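
To make the role of the span scores concrete, here is a minimal sketch of the chart recursion such a parser can use, assuming the score of a tree is the sum of its span scores. The function name `best_parse_score` and the dense `scores` array layout are hypothetical, and the toy scores are made up for illustration; this is not the implementation developed later.

```python
import numpy as np

def best_parse_score(scores):
    """scores[i, j, l] is the score of label l for span (i, j), used only for i < j.

    Returns the chart of best subtree scores and the chosen split points.
    """
    n = scores.shape[0] - 1              # number of words; fenceposts are 0..n
    best = np.zeros((n + 1, n + 1))
    split = {}
    for length in range(1, n + 1):       # fill the chart from short spans to long ones
        for i in range(0, n - length + 1):
            j = i + length
            label_score = scores[i, j].max()   # best label for this span
            if length == 1:
                best[i, j] = label_score
            else:
                # best split point k, given the already-filled shorter spans
                k = max(range(i + 1, j), key=lambda m: best[i, m] + best[m, j])
                split[(i, j)] = k
                best[i, j] = label_score + best[i, k] + best[k, j]
    return best, split

# toy example: 3 words, 2 labels, hand-picked scores
toy = np.zeros((4, 4, 2))
toy[0, 1] = toy[1, 2] = toy[2, 3] = [1, 0]
toy[0, 2] = [0, 5]
toy[1, 3] = [0, 1]
toy[0, 3] = [2, 0]
best, split = best_parse_score(toy)
```

In this toy chart the span `(0, 2)` scores much higher than `(1, 3)`, so the best tree for the whole sentence splits at fencepost 2.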

In [1]:
from nltk.corpus import treebank

First, let's load the data. We will use the NLTK treebank, which is a subset of the original Penn Treebank dataset.

In [40]:
# get all parsed sentences across all the files
parsed_sentences = treebank.parsed_sents()
print(f"Number of parsed sentences: {len(parsed_sentences)}")

# print the first parsed sentence
example_tree = parsed_sentences[0]
print(example_tree)

Number of parsed sentences: 3914
(S
  (NP-SBJ
    (NP (NNP Pierre) (NNP Vinken))
    (, ,)
    (ADJP (NP (CD 61) (NNS years)) (JJ old))
    (, ,))
  (VP
    (MD will)
    (VP
      (VB join)
      (NP (DT the) (NN board))
      (PP-CLR (IN as) (NP (DT a) (JJ nonexecutive) (NN director)))
      (NP-TMP (NNP Nov.) (CD 29))))
  (. .))


In [43]:
# function for extracting the non-terminal label of every constituent span in a parse tree
def get_span_labels(parse_tree, verbose=False):
    span_labels = {}
    if verbose:
        parse_tree.pretty_print()

    # visit subtrees in pre-order (same order as parse_tree.subtrees()), tracking the
    # fence-post offset of each subtree; looking words up with list.index() would return
    # the first occurrence of a repeated token (e.g. the second ",") and give wrong spans
    def visit(subtree, start):
        leaves = subtree.leaves()
        span = (start, start + len(leaves))
        # nested subtrees covering the same span (unary chains) keep the innermost label
        span_labels[span] = subtree.label()
        if verbose:
            print(f"\nsubtree label: {subtree.label()}")
            print(f"subtree leaves: {leaves}")
            print(f"Start fence-post: {span[0]}")
            print(f"End fence-post: {span[1]}")
            print(f"span: {span}")
        offset = start
        for child in subtree:
            if isinstance(child, str):
                offset += 1  # a leaf (word) occupies one fence-post interval
            else:
                visit(child, offset)
                offset += len(child.leaves())

    visit(parse_tree, 0)
    return span_labels

In [44]:
get_span_labels(example_tree, verbose=True)

                                                     S                                                                         
                         ____________________________|_______________________________________________________________________   
                        |                                               VP                                                   | 
                        |                        _______________________|___                                                 |  
                      NP-SBJ                    |                           VP                                               | 
         _______________|___________________    |     ______________________|______________________________________          |  
        |          |              ADJP      |   |    |        |                PP-CLR                              |         | 
        |          |           ____|____    |   |    |        |          ________|_________          

{(0, 18): 'S',
 (0, 7): 'NP-SBJ',
 (0, 2): 'NP',
 (0, 1): 'NNP',
 (1, 2): 'NNP',
 (2, 3): ',',
 (3, 6): 'ADJP',
 (3, 5): 'NP',
 (3, 4): 'CD',
 (4, 5): 'NNS',
 (5, 6): 'JJ',
 (6, 7): ',',
 (7, 17): 'VP',
 (7, 8): 'MD',
 (8, 17): 'VP',
 (8, 9): 'VB',
 (9, 11): 'NP',
 (9, 10): 'DT',
 (10, 11): 'NN',
 (11, 15): 'PP-CLR',
 (11, 12): 'IN',
 (12, 15): 'NP',
 (12, 13): 'DT',
 (13, 14): 'JJ',
 (14, 15): 'NN',
 (15, 17): 'NP-TMP',
 (15, 16): 'NNP',
 (16, 17): 'CD',
 (17, 18): '.'}
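
A dictionary like this only lists the spans that are constituents, whereas the classifier assigns scores to every cell of the upper-triangular chart. A minimal sketch of turning such a dictionary into a full chart of target labels, assuming non-constituent spans receive a placeholder label (the helper name `label_chart`, the placeholder `"<none>"` and the toy dictionary are hypothetical):

```python
def label_chart(span_labels, n_words, empty="<none>"):
    """Build an (n_words + 1) x (n_words + 1) chart of target labels.

    Cell [i][j] with i < j holds the label of span (i, j), or `empty` if the
    span is not a constituent; cells on or below the diagonal stay None.
    """
    chart = [[None] * (n_words + 1) for _ in range(n_words + 1)]
    for i in range(n_words):
        for j in range(i + 1, n_words + 1):
            chart[i][j] = span_labels.get((i, j), empty)
    return chart

# toy span labels for a three-word sentence
toy_labels = {(0, 3): "S", (0, 1): "NNP", (1, 3): "VP", (1, 2): "VB", (2, 3): "NN"}
chart = label_chart(toy_labels, n_words=3)
for row in chart:
    print(row)
```

Here the span `(0, 2)` is not a constituent in the toy tree, so its cell holds the placeholder label.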