# Assignment 6: Context Free Grammar (CKY)

## Question 1: Chomsky Normal Form

In [62]:
from collections import defaultdict

In [124]:
def convert_rules_simplied(grammar_rules):
    """
    Binarize a list of context-free grammar rules (simplified CNF conversion).

    Each input rule is a string of the form ``"A -> B C D"``. The arrow is
    dropped and the rule is split on whitespace, giving ``["A", "B", "C", "D"]``.
    Right-hand sides with more than two symbols are shortened by repeatedly
    folding the two left-most symbols into a freshly named non-terminal
    (``A0``, ``A1``, ...).

    This simplified variant only binarizes: unit rules (``A -> B``) and
    terminals inside longer rules are passed through untouched.

    :param grammar_rules: iterable of rule strings.
    :return: list of rules, each a list ``[lhs, rhs1, ...]``. Every shortened
             rule is immediately followed by the helper rules created for it.
    """
    formatted_rules = []

    # Running index used to name fresh non-terminals. It must be initialised
    # before the loop, otherwise the first long rule raises UnboundLocalError.
    counter = 0
    for rule in grammar_rules:
        rule = rule.replace("->", "").split()

        # Blank lines carry no rule; an empty entry would break the parser later
        if not rule:
            continue

        new_rules = []

        # Rule is in form A -> B C D [...]: fold the first two symbols together
        while len(rule) > 3:
            fresh_symbol = f"{rule[0]}{counter}"
            new_rules.append([fresh_symbol, rule[1], rule[2]])
            rule = [rule[0], fresh_symbol] + rule[3:]
            counter += 1

        formatted_rules.append(rule)
        formatted_rules.extend(new_rules)

    return formatted_rules


def convert_rules(grammar_rules):
    """
    Convert a list of context-free grammar rules into Chomsky Normal Form.

    Each input rule is a string such as ``"A -> B 'c' D"``; terminals are
    wrapped in single quotation marks. Three transformations are applied:

    1. Terminals inside rules with two or more right-hand symbols are replaced
       by fresh non-terminals (one per occurrence) that rewrite to the terminal.
    2. Right-hand sides longer than two symbols are binarized by folding the two
       left-most symbols into a fresh non-terminal until only two remain.
    3. Unit rules ``A -> B`` are removed: ``A`` receives every non-unit
       right-hand side reachable from ``B`` through chains of unit rules.

    Fresh non-terminals are named ``<lhs><n>`` with a single running counter.
    Epsilon rules are not handled.

    :param grammar_rules: iterable of rule strings.
    :return: list of rules, each a list ``[lhs, rhs1]`` or ``[lhs, rhs1, rhs2]``.
             Rules keep the order of the input; the expansions of a unit rule
             take the position of that unit rule, so the left-hand side of the
             first input rule stays first in the result.
    """

    # Non-unit right-hand sides of every non-terminal (explicit and generated)
    rule_dict = defaultdict(list)

    # Direct unit targets: "A -> B" is stored as unit_targets["A"] == ["B"]
    unit_targets = defaultdict(list)

    # (is_unit, rule) pairs in input order; unit rules are expanded afterwards,
    # once every rule of the grammar is known
    pending = []

    counter = 0
    for rule in grammar_rules:

        rule = rule.replace("->", "").split()

        # Blank lines carry no rule
        if not rule:
            continue

        new_rules = []

        # Rule is in form A -> B: remember it and expand it after the loop
        if len(rule) == 2 and not rule[1].startswith("'"):
            unit_targets[rule[0]].append(rule[1])
            pending.append((True, rule))
            continue

        # Rule is in form A -> B C D [...] or A -> B 'a'
        if len(rule) > 2:

            for idx in range(1, len(rule)):
                item = rule[idx]
                if item.startswith("'"):
                    # Each terminal gets its own fresh non-terminal, so the
                    # counter advances per terminal rather than per rule
                    fresh_symbol = f"{rule[0]}{counter}"
                    counter += 1
                    rule[idx] = fresh_symbol
                    new_rules.append([fresh_symbol, item])

            while len(rule) > 3:
                # Fold the two left-most symbols into a fresh non-terminal
                fresh_symbol = f"{rule[0]}{counter}"
                counter += 1
                new_rules.append([fresh_symbol, rule[1], rule[2]])
                rule = [rule[0], fresh_symbol] + rule[3:]

        for concrete in [rule] + new_rules:
            rule_dict[concrete[0]].append(concrete[1:])
            pending.append((False, concrete))

    # Rules already present must not be emitted a second time by unit expansion,
    # otherwise the parser would report duplicate trees
    seen = {tuple(rule) for is_unit, rule in pending if not is_unit}

    formatted_rules = []
    for is_unit, rule in pending:
        if not is_unit:
            formatted_rules.append(rule)
            continue

        head = rule[0]

        # Depth-first walk over the unit chain starting at the rule's target;
        # the visited set keeps unit cycles (A -> B, B -> A) from looping forever
        stack, visited = [rule[1]], set()
        while stack:
            symbol = stack.pop()
            if symbol in visited:
                continue
            visited.add(symbol)

            for body in rule_dict.get(symbol, ()):
                candidate = [head] + body
                key = tuple(candidate)
                if key not in seen:
                    seen.add(key)
                    formatted_rules.append(candidate)

            stack.extend(reversed(unit_targets.get(symbol, ())))

    return formatted_rules

## Question 2: CKY Parsing

In [116]:
class Node:
    """
    Minimal parse-tree node: a grammar symbol together with up to two children.

    ``child2`` defaults to ``None`` for nodes that are created with a single child.
    """

    def __init__(self, symbol, child1, child2=None):
        self.symbol, self.child1, self.child2 = symbol, child1, child2

    def __repr__(self):
        # A node is displayed as its bare symbol
        return self.symbol



def cky_parse(text, rules):
    """
    Run the CKY algorithm over a whitespace-tokenized sentence.

    :param text: the sentence; it is split on whitespace into tokens.
    :param rules: list of rules, each a list ``[lhs, rhs1]`` or
                  ``[lhs, rhs1, rhs2]``; terminals are written as ``'token'``
                  (including the single quotation marks).
    :return: the parse triangle as a list of rows. The cell at ``[r][c]`` is a
             list of ``Node`` objects covering the ``r + 1`` tokens that start
             at token ``c``.
    """
    words = text.split()
    n_words = len(words)

    # Row r has one cell per span of r + 1 consecutive tokens
    chart = [[[] for _ in range(n_words - row)] for row in range(n_words)]

    # Bottom row: every rule whose first right-hand symbol is the quoted token
    # contributes a node; a token may be produced by several non-terminals
    for position, word in enumerate(words):
        quoted = f"'{word}'"
        for rule in rules:
            if quoted == rule[1]:
                chart[0][position].append(Node(rule[0], word))

    # Higher rows: combine two shorter, adjacent spans
    for height in range(1, n_words):
        for start in range(n_words - height):
            target = chart[height][start]

            # Every way of cutting the span into a left and a right part
            for split in range(height):
                left_cell = chart[split][start]
                right_cell = chart[height - split - 1][start + split + 1]

                for rule in rules:
                    # Only binary rules can join two sub-spans
                    if len(rule) != 3:
                        continue

                    parent, left_symbol, right_symbol = rule
                    left_matches = [n for n in left_cell if n.symbol == left_symbol]
                    right_matches = [n for n in right_cell if n.symbol == right_symbol]

                    for left in left_matches:
                        for right in right_matches:
                            target.append(Node(parent, left, right))

    return chart
    

## Question 3: Parse Sentences

Ask the students to use the functions above to print some parse trees for sample input sentences.

They need to add the terminals to the list of rules and run the pipeline.

Find 3 sentences that are contained in the language generated by our grammar, and 3 sentences that are not.

In [125]:
def generate_tree(node):
    """
    Render a parse tree as a bracketed string.
    :param node: the root node of the (sub)tree.
    :return: ``[SYMBOL 'child1']`` for a node without a second child, otherwise
             ``[SYMBOL <left subtree> <right subtree>]``.
    """
    label = node.symbol
    if node.child2 is not None:
        # Two children: render both subtrees recursively, left first
        left_part = generate_tree(node.child1)
        right_part = generate_tree(node.child2)
        return f"[{label} {left_part} {right_part}]"
    # Single child: it is shown in single quotation marks
    return f"[{label} '{node.child1}']"

def print_tree(parse_triangle, rules, output=True):
    """
    Report the parses of a sentence that are rooted in the start symbol.

    The start symbol is the left-hand side of the first rule. Only nodes in the
    top cell of the triangle (the cell spanning the whole sentence) that carry
    the start symbol count as complete parses.

    :param parse_triangle: the triangle returned by ``cky_parse``; an empty
                           triangle (empty sentence) is treated as "no parse".
    :param rules: list of rules; an empty list is treated as "no parse".
    :param output: if true, print the result and return ``None``; if false,
                   print nothing and return the trees instead.
    :return: ``None`` when ``output`` is true, otherwise the list of bracketed
             tree strings (empty when the sentence is not in the language).
    """
    # The last row holds the single cell that spans the whole sentence; it does
    # not exist when the sentence had no tokens
    top_cell = parse_triangle[-1][0] if parse_triangle and parse_triangle[-1] else []
    start_symbol = rules[0][0] if rules and rules[0] else None

    final_nodes = [n for n in top_cell if n.symbol == start_symbol]
    trees = [generate_tree(node) for node in final_nodes]

    # Silent mode: hand the trees back and leave reporting to the caller
    if not output:
        return trees

    if trees:
        print("The given sentence is contained in the language produced by the given grammar!")
        print("\nPossible parse(s):")
        for tree in trees:
            print(tree)
    else:
        print("The given sentence is not contained in the language produced by the given grammar!")
    return None

In [None]:
# Toy grammar for Question 3. Every rule is a string "LHS -> RHS" whose symbols
# are separated by whitespace; terminals are wrapped in single quotation marks.
# The left-hand side of the first rule is used as the start symbol when printing.
# Instructor note: for this part we can remove some rules (e.g. the terminals)
# so that the code does not run until the students add them.
grammar_rules = [
    "S -> NP VP",
    "PP -> P NP",
    "NP -> Det N",
    "NP -> Det N PP",
    "VP -> V NP",
    "VP -> VP PP",
    
    # Example of a terminal (use single quotation marks)
    "N -> 'school'",
    
    # Your rules here ...
    "NP -> 'I'",
    "Det -> 'an'",
    "Det -> 'my'",
    "N -> 'elephant'",
    "N -> 'pajamas'",
    "V -> 'shot'",
    "P -> 'in'",
]

In [None]:
### Solutions for Q3
# Pipeline: normalize the grammar, fill the CKY triangle, print the parses.
# The normalized rules (not the raw rule strings) must be passed to both the
# parser and the printer, and the printer receives the triangle built here.
normalized_rules = convert_rules_simplied(grammar_rules)
parsed_triangle = cky_parse("I shot an elephant in my pajamas", normalized_rules)
print_tree(parsed_triangle, normalized_rules)