# CYK

In this assignment, we'll implement the CYK (Cocke-Younger-Kasami) algorithm, also known as CKY.  See the lecture notes and async material for a detailed discussion of the algorithm.

We'll write code in three parts:
1.  Initial preprocessing of the treebank
2.  Calculation of production rule probabilities
3.  CYK itself

We provide the code for part 1 and much of the framework for parts 2 and 3.

In [None]:
# Import some useful libraries...
import collections
import copy
import math
import nltk
from nltk.tree import Tree
import types

## 1. Preprocessing

The first preprocessing step strips the cross-references out of the treebank.  In the treebank, NPs are wrapped by special nodes that assign index numbers to them so that coreference can be indicated.  Unfortunately, this annotation also injects an NP-SBJ-# node between nodes you'd expect to produce one another.  Since the # changes throughout the corpus, the rules involving these nodes each end up with a count of 1, which makes them useless.
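To see why a changing index hurts, here is a tiny illustration on made-up labels (not taken from the treebank), in plain Python so it runs without NLTK. Counting the raw labels gives every variant a count of 1, while dropping the trailing index pools them into one label. The notebook's own fix, described below, goes further and removes the NP- wrapper nodes altogether.

```python
import re
from collections import Counter

# Made-up node labels: the same kind of subject NP, indexed differently
# in different sentences.
labels = ['NP-SBJ-1', 'NP-SBJ-2', 'NP-SBJ-3']

# Counting raw labels: every variant is seen exactly once.
raw_counts = Counter(labels)

# Stripping the trailing "-<digits>" index pools them into a single label.
pooled_counts = Counter(re.sub(r'-\d+$', '', label) for label in labels)

print(raw_counts)     # each label has count 1
print(pooled_counts)  # Counter({'NP-SBJ': 3})
```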

See NP-SBJ-1 in the tree below.  Note that there is also a later NP-SBJ whose only child is a -NONE- subtree (the empty element *-1); it is the cross-reference back to NP-SBJ-1.

In [27]:
nltk.corpus.treebank.parsed_sents()[2]

In the code below, we skip over nodes whose labels start with NP-, attaching their children directly to the NP- node's parent.  We also snip out any subtrees rooted at a -NONE- node.  The cell prints the same tree again afterwards to illustrate the effect of this code.

In [32]:
# Preprocess the treebank.
def get_real_child(node):
    if not isinstance(node, Tree):  # leaves are plain strings (words)
        return [node]
    if 'NONE' in node.label():
        return []
   
    real_children = []
    if node.label().startswith('NP-'):
        for child in node:
            real_children.extend(get_real_child(child))
    else:
        real_children.append(node)
    return [copy_strip_np_sbj(x) for x in real_children]

def copy_strip_np_sbj(sentence):
    children = []
    for child in sentence:
        children.extend(get_real_child(child))
    return Tree(sentence.label(), children)

pre_chomsky = []
sentences = []
for sentence in nltk.corpus.treebank.parsed_sents():
    # Filter out NP-* nodes.
    filtered_sentence = copy_strip_np_sbj(sentence)
    pre_chomsky.append(filtered_sentence)
    
    # Convert sentence to Chomsky normal form.
    transformed_sentence = copy.deepcopy(filtered_sentence)
    nltk.treetransforms.chomsky_normal_form(transformed_sentence, horzMarkov=2)
    
    # Add final sentence to list.
    sentences.append(transformed_sentence)
    
pre_chomsky[2]

The last step of the preprocessing cell above converts each tree to [Chomsky normal form](https://en.wikipedia.org/wiki/Chomsky_normal_form), because CYK assumes that the trees were generated by a grammar in that form.

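This means that the grammar consists only of rules of the form:
- A -> BC
- A -> a
- S -> $\epsilon$

where A, B, and C are non-terminals, a is a terminal, S is the start symbol, and $\epsilon$ is the empty string.  Neither B nor C may be the start symbol, and S -> $\epsilon$ is allowed only if the language contains the empty string.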
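As a quick sanity check of the definition, the hypothetical helper below tests whether a single rule has one of the three CNF shapes: two non-terminals (neither being the start symbol), one terminal, or the empty string produced by the start symbol. It is a plain-Python sketch, assuming a rule is an LHS string plus an RHS tuple and that any symbol not in the given non-terminal set is a terminal; it isn't part of NLTK.

```python
def is_cnf_rule(lhs, rhs, nonterminals, start='S'):
    """Return True if lhs -> rhs has one of the three CNF shapes.

    rhs is a tuple of symbols; symbols not in `nonterminals` are terminals.
    """
    if lhs not in nonterminals:
        return False
    if len(rhs) == 2:
        # A -> B C: both children are non-terminals other than the start symbol.
        return all(s in nonterminals and s != start for s in rhs)
    if len(rhs) == 1:
        # A -> a: a single terminal.
        return rhs[0] not in nonterminals
    if len(rhs) == 0:
        # S -> epsilon: only the start symbol may produce the empty string.
        return lhs == start
    return False  # three or more symbols on the RHS


nonterminals = {'S', 'NP', 'VP', 'DT', 'NN', 'V'}
print(is_cnf_rule('NP', ('DT', 'NN'), nonterminals))       # True
print(is_cnf_rule('NP', ('dog',), nonterminals))           # True
print(is_cnf_rule('VP', ('V',), nonterminals))             # False: unary non-terminal rule
print(is_cnf_rule('VP', ('V', 'NP', 'NP'), nonterminals))  # False: RHS too long
```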

To get a grammar into this form, we add new non-terminals and use them to split long right-hand sides into chains of binary rules.  Concretely,
- A -> BCD

becomes
- A -> BE
- E -> CD

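Adding all these non-terminals with opaque names gets confusing, so a popular notation names the new non-terminal "A|<C-D>" instead of "E": the original parent, a pipe, and the symbols the new node expands to.  This is the convention NLTK uses (e.g. the S|<VP-.> node in the counts further down).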

This works pretty well until you have grammar rules like:
- A -> BCDEFGHIJKL

In that case, you'd induce a new node A|<C-D-E-F-G-H-...>, and week 2's sparsity concerns should come to mind: a node with such a specific name is rarely seen more than once.  To keep our counts relatively large, we can pick a hyperparameter (analogous to the n of an n-gram model) called the horizontal markovization parameter.  It does just what you'd expect: it caps the number of symbols after the pipe in each new node's name.  Rules that differ only in their more distant children then share new nodes, so evidence can accumulate across more examples that are similar in structure.

Take a minute to play with the `horzMarkov` parameter in the preprocessing cell above to see how this works.
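The factoring and markovization described above can be sketched in plain Python. The hypothetical `binarize` function right-factors one rule and names each new non-terminal `A|<...>`, keeping at most `horz_markov` symbols after the pipe. It is written to mimic the labels NLTK's `chomsky_normal_form` produces with its default right factoring (compare `S|<VP-.>` in the expected output further down), but it is an illustration, not NLTK's implementation.

```python
def binarize(lhs, rhs, horz_markov=None):
    """Right-factor lhs -> rhs into binary rules.

    New non-terminals are named lhs|<X-Y-...> and keep at most horz_markov
    of the remaining RHS symbols after the pipe (all of them if None).
    """
    rhs = tuple(rhs)
    if len(rhs) <= 2:
        return [(lhs, rhs)]  # already short enough
    if horz_markov is None:
        horz_markov = len(rhs)
    rules = []
    head = lhs
    for i in range(1, len(rhs) - 1):
        # The new node stands for rhs[i:], but its name only records
        # the next horz_markov symbols.
        new_head = '%s|<%s>' % (lhs, '-'.join(rhs[i:i + horz_markov]))
        rules.append((head, (rhs[i - 1], new_head)))
        head = new_head
    rules.append((head, rhs[-2:]))
    return rules


for rule in binarize('S', ['NP', 'VP', '.'], horz_markov=2):
    print(rule)
# ('S', ('NP', 'S|<VP-.>'))
# ('S|<VP-.>', ('VP', '.'))

# Two long rules that differ only in their last child.
rules_g = binarize('A', 'B C D E F G'.split(), horz_markov=2)
rules_h = binarize('A', 'B C D E F H'.split(), horz_markov=2)
shared = {l for l, _ in rules_g} & {l for l, _ in rules_h}
# With horz_markov=2 they share A, A|<C-D>, A|<D-E> and A|<E-F>;
# without markovization (horz_markov=None) they share only A.
```

Capping the name length is what lets the two rules reuse most of their intermediate nodes, so their counts pool instead of staying at 1.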

In [33]:
sentences[2]

## 2. Production rule probabilities

In this next section, you'll compute production rule probabilities.

Remember that a production rule now looks like this:
- A -> BC; or,
- A -> a

The left-hand side (LHS) of each rule is always a single non-terminal.  The right-hand side (RHS) consists of either two non-terminals or one terminal.

We'll do this in two stages:
- Count each LHS, and each (LHS, RHS) pair, in separate dicts
- Calculate $P(\text{RHS} \mid \text{LHS}) = \frac{\text{count}(\text{LHS}, \text{RHS})}{\text{count}(\text{LHS})}$
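Here is that estimate on a handful of made-up counts, using `collections.Counter` as the notebook does. For simplicity this sketch keys counts by plain `(LHS, RHS)` tuples of strings. Because every occurrence of an LHS is the LHS of exactly one production, each LHS count equals the sum of its rules' counts, so the probabilities for each LHS sum to 1.

```python
from collections import Counter

# Made-up (LHS, RHS) counts; RHS tuples hold child labels or a single word.
production_counts = Counter({
    ('NP', ('DT', 'NN')): 6,
    ('NP', ('NP', 'PP')): 3,
    ('NP', ('I',)): 1,
    ('PP', ('IN', 'NP')): 4,
})

# count(LHS) is the total count of all rules sharing that LHS.
lhs_counts = Counter()
for (lhs, rhs), count in production_counts.items():
    lhs_counts[lhs] += count

# P(RHS | LHS) = count(LHS, RHS) / count(LHS).
# float() avoids integer division under Python 2.
probabilities = {
    (lhs, rhs): float(count) / lhs_counts[lhs]
    for (lhs, rhs), count in production_counts.items()
}
print(probabilities[('NP', ('DT', 'NN'))])  # 0.6
```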

Before we get started though, let's play a bit with the NLTK API.

### TODO
In the next cell, display sentences[0] the same way we displayed sentences[2] above.

In [37]:
# YOUR CODE HERE
# END YOUR CODE

### TODO

In the next cell, print the label of the sentence's root node and the labels of all its children (there should be only two children, thanks to the CNF conversion done above).

Hint:  The "sentence" object is a [Tree](http://www.nltk.org/_modules/nltk/tree.html).  See the iteration and label methods.

In [39]:
# YOUR CODE HERE
# END YOUR CODE

### TODO

Print:
- all the production rules found in this sentence.
- the left-hand side of the first production.
- the right-hand side of the second production.

Hint:  Each of these has a one-line solution.  See the Tree API.

In [41]:
# YOUR CODE HERE
# END YOUR CODE

### TODO

With that API fun out of the way, loop over all the sentences and fill production_counts with a count of each LHS, RHS pair and lhs_counts with the number of times you've seen each non-terminal on the LHS.

If everything works, you should see this output from the cell below:

```
[(, -> ',', 4885),
 (PP -> IN NP, 4045),
 (DT -> 'the', 4038),
 (. -> '.', 3828),
 (S|<VP-.> -> VP ., 3018),
 (IN -> 'of', 2319),
 (NP -> NP PP, 2188),
 (TO -> 'to', 2161),
 (NP -> DT NN, 2020),
 (DT -> 'a', 1874)]
```

In [44]:
production_counts = collections.Counter()
lhs_counts = collections.Counter()
# YOUR CODE HERE
# END YOUR CODE
production_counts.most_common(10)
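If the NLTK objects get in the way, it can help to see the same bookkeeping on a toy representation. This sketch assumes a hypothetical nested-tuple tree format, `(label, child, ...)` with words as plain strings, instead of `nltk.Tree`. The `productions` generator walks the tree in preorder, the same order `Tree.productions()` uses, and the loop fills two Counters the way the cell above expects. One simplification: unlike NLTK, it doesn't distinguish terminal from non-terminal symbols on the RHS.

```python
from collections import Counter

def productions(tree):
    """Yield (lhs, rhs) pairs for a (label, child, ...) tuple tree, in preorder."""
    label, children = tree[0], tree[1:]
    # A child is either a subtree (use its label) or a word (use it as-is).
    rhs = tuple(c[0] if isinstance(c, tuple) else c for c in children)
    yield (label, rhs)
    for child in children:
        if isinstance(child, tuple):
            for prod in productions(child):
                yield prod

# Made-up toy trees, already in CNF.
toy_trees = [
    ('S', ('NP', 'I'), ('VP', ('V', 'eat'), ('NP', 'fish'))),
    ('S', ('NP', 'cats'), ('VP', ('V', 'eat'), ('NP', 'fish'))),
]

production_counts = Counter()
lhs_counts = Counter()
for tree in toy_trees:
    for lhs, rhs in productions(tree):
        production_counts[(lhs, rhs)] += 1
        lhs_counts[lhs] += 1

print(production_counts.most_common(3))
```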

### TODO
Compute the probability of each potential RHS given the LHS.

Hint: As usual, multiplying many probabilities runs into numerical problems (the product quickly underflows to zero).  Take the usual approach: compute each log-probability as math.log(numerator) - math.log(denominator), and add log-probabilities together instead of multiplying probabilities.
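A quick illustration of why this matters, with made-up numbers: the product of many small probabilities underflows to exactly 0.0 in double precision, while the sum of their logs stays comfortably in range.

```python
import math

# 100 made-up rule probabilities of 1e-5 each.
probs = [1e-5] * 100

product = 1.0
for p in probs:
    product *= p
print(product)  # 0.0: the true value, 1e-500, is below the smallest float

# In log space, the same quantity is an ordinary negative number.
log_total = sum(math.log(p) for p in probs)
print(log_total)  # about -1151.29

# A ratio becomes a difference of logs: log(3/10) = log(3) - log(10).
log_ratio = math.log(3) - math.log(10)
```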

The final result of this cell, scored_productions, should be a dict mapping each RHS (a tuple, e.g. ('food',)) to a list [(LHS_1, log_probability_1), (LHS_2, log_probability_2), ...]

Each LHS in the list is the left-hand side of a production rule that can create the RHS, paired with the log-probability of that rule.  We key this table by RHS instead of LHS because CYK builds its chart from the bottom up: it looks up RHSs and tries to combine them into LHSs.

If everything went well, you should see:
```
food [(NN, -6.71280430578804)]
a [(IN, -9.19593714166544), (DT, -1.4717815426061982), (LS, -2.5649493574615367), (JJ, -7.978310969867721)]
I [(NNP, -8.45638105201948), (PRP, -2.720363461335567)]
```

In [50]:
scored_productions = collections.defaultdict(list)
# YOUR CODE HERE
# END YOUR CODE
for w in ['food', 'a', 'I']:
    print w, scored_productions[(w,)]
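The inversion from LHS-keyed counts to an RHS-keyed table can be sketched on toy counts (made up, not the treebank's). This sketch assumes `production_counts` is keyed by `(LHS, RHS)` tuples of strings; if your counts are keyed by NLTK `Production` objects instead, their `lhs()` and `rhs()` methods give the two parts.

```python
import math
from collections import Counter, defaultdict

# Made-up counts keyed by (LHS, RHS), with RHS as a tuple.
production_counts = Counter({
    ('DT', ('a',)): 8, ('DT', ('the',)): 12,
    ('LS', ('a',)): 1, ('LS', ('b',)): 3,
    ('NN', ('food',)): 5, ('NN', ('dog',)): 15,
})
lhs_counts = Counter()
for (lhs, rhs), count in production_counts.items():
    lhs_counts[lhs] += count

# Key by RHS so a bottom-up parser can ask "what could have produced this?"
scored_productions = defaultdict(list)
for (lhs, rhs), count in production_counts.items():
    log_prob = math.log(count) - math.log(lhs_counts[lhs])
    scored_productions[rhs].append((lhs, log_prob))

for lhs, log_prob in scored_productions[('a',)]:
    print('%s %.3f' % (lhs, log_prob))
# DT -0.916
# LS -1.386
```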

You don't need to do anything with this next cell except run it.

It prints a few entries from each of the tables above, which is a handy reference if you lose track of what each variable contains.

In [19]:
print 'Productions:'
for production, count in [x for x in production_counts.iteritems()][0:5]:
    print production, count, type(production)

print '\n\nLHS counts:'
for lhs, count in sorted(lhs_counts.iteritems())[:5]:
    print lhs, count
    
print '\n\nLog Probabilities:'
print '\n'.join([str(x) for x in scored_productions.iteritems()][0:10])

## 3. Implement CYK!

After that bit of preamble, you only have one more cell to go!  It's a big one though, so do take your time and get things right.

We've set up the chart for you.  Concretely, `chart` is a dict that you index first by cell position and then by non-terminal, like this:

```
chart[(0, 1)][NN]
```

The value is a tuple (log_probability, back_trace_tree).  Entries you haven't filled in yet default to (-inf, Tree('unknown', [])).

Construct the back_trace_tree by calling Tree(non_terminal, [its, children]).

### TODO: Implement CYK.

HINT: It isn't strictly necessary, but it can be convenient to split the task in two: first map each word to its possible pre-terminals (i.e. fill the bottom row of the chart), then build the rest of the chart, working from shorter spans to longer ones.
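For reference, here is a minimal sketch of that two-phase approach on a toy grammar. It makes assumptions that differ from the notebook: trees are nested tuples `(label, child, ...)` instead of `nltk.Tree`, labels are plain strings, chart cells are indexed by `(start, end)` spans with `end` exclusive, and the rule table has the same RHS-keyed layout as `scored_productions`. The grammar and its probabilities are made up.

```python
import math
from collections import defaultdict

NEG_INF = float('-inf')

def cyk(words, rules_by_rhs, start='S'):
    """Return (log_prob, tree) for the best parse of words rooted at start.

    rules_by_rhs maps an RHS tuple, (word,) or (B, C), to a list of
    (lhs, log_prob) pairs.  Trees are nested tuples (label, child, ...).
    Returns (-inf, None) if there is no parse.
    """
    n = len(words)
    # chart[(i, j)][A] = (best log prob, best tree) for A spanning words[i:j]
    chart = defaultdict(lambda: defaultdict(lambda: (NEG_INF, None)))

    # Phase 1: bottom row, mapping each word to its possible pre-terminals.
    for i, word in enumerate(words):
        cell = chart[(i, i + 1)]
        for lhs, log_prob in rules_by_rhs.get((word,), []):
            if log_prob > cell[lhs][0]:
                cell[lhs] = (log_prob, (lhs, word))

    # Phase 2: longer spans, shortest first, so every sub-span is finished.
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length
            cell = chart[(i, j)]
            for k in range(i + 1, j):  # split into words[i:k] + words[k:j]
                for b, (b_score, b_tree) in chart[(i, k)].items():
                    for c, (c_score, c_tree) in chart[(k, j)].items():
                        for lhs, log_prob in rules_by_rhs.get((b, c), []):
                            score = log_prob + b_score + c_score
                            if score > cell[lhs][0]:  # keep only the best
                                cell[lhs] = (score, (lhs, b_tree, c_tree))
    return chart[(0, n)][start]


# A made-up toy grammar in CNF; each LHS's probabilities sum to 1.
toy_rules = [
    ('S', ('NP', 'VP'), 1.0),
    ('VP', ('V', 'NP'), 0.6),
    ('VP', ('VP', 'PP'), 0.4),
    ('NP', ('NP', 'PP'), 0.3),
    ('NP', ('I',), 0.2),
    ('NP', ('fish',), 0.3),
    ('NP', ('forks',), 0.2),
    ('PP', ('P', 'NP'), 1.0),
    ('V', ('eat',), 1.0),
    ('P', ('with',), 1.0),
]
rules_by_rhs = defaultdict(list)
for lhs, rhs, p in toy_rules:
    rules_by_rhs[rhs].append((lhs, math.log(p)))

score, tree = cyk('I eat fish with forks'.split(), rules_by_rhs)
print(round(math.exp(score), 5))  # 0.00288
print(tree)
```

On this sentence the best parse attaches "with forks" to the verb phrase (probability 0.00288); the competing parse, which attaches it to "fish", has probability 0.00216 and loses in the max over split points for the span (1, 5).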

In [61]:
def CYK(words):
    '''Accept a list of words and return a tuple (score, Tree) where Tree contains the parse with score score.'''
    cell_creator = lambda: collections.defaultdict(lambda:(-float('inf'), Tree('unknown', [])))
    chart = collections.defaultdict(cell_creator)
    
    # YOUR CODE HERE
        
    # END YOUR CODE

In [60]:
score_tree = CYK('I eat red hot food with a knife'.split())
assert round(score_tree[0], 2) == -64.89
score_tree[1]

### TODO

Try a few more sentences.  Do you notice any patterns with your results?  Any common types of errors?  Are these an artifact of CYK, or of how you did the markovization/counting?