In [72]:
import nltk
from nltk import ProbabilisticProduction, PCFG, Nonterminal
from tqdm.notebook import tqdm
import random
from typing import List, Tuple, Dict, Any
import itertools
from collections import Counter, defaultdict
import pickle
import os
from nltk import PCFG as nltk_PCFG


### Dark mode tqdm notebook progress bar

In [20]:
%%html
<style>
/* Make the background of ipywidget cell output transparent. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Map the Jupyter widget color and font size onto the VS Code editor theme variables. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

## Generate raw pcfg and helper functions

In [21]:
def read_treebank(filename="treebanks/original_treebank_500k.txt"):
    """
    Load a treebank from a text file (the google books treebank by default).

    Returns a list with one entry per line of the file, each stripped of
    leading and trailing whitespace.
    """
    with open(filename) as treebank_file:
        return [raw_line.strip() for raw_line in treebank_file]

def get_productions(treebank):
    """
    Count how often each production occurs in the treebank.

    Every entry of the treebank is parsed with nltk.Tree.fromstring and the productions of the resulting tree are tallied.
    Returns a defaultdict that maps each lhs (a Nonterminal) to a Counter over the rhs tuples (Nonterminals and terminal strings) seen for it.
    """
    production_counts = defaultdict(Counter)

    for bracketed in tqdm(treebank[:]):
        for production in nltk.Tree.fromstring(bracketed).productions():
            rhs_counts = production_counts[production.lhs()]
            rhs_counts[production.rhs()] += 1

    return production_counts

def calculate_prob_productions(prod_dict):
    """
    Turn production counts into probabilistic productions.

    For every lhs in prod_dict, each rhs count is divided by the total count of that lhs (relative frequency).
    Returns a flat list of ProbabilisticProduction objects.
    """
    prob_productions = []

    for lhs, rhs_counter in tqdm(prod_dict.items()):
        lhs_total = sum(rhs_counter.values())
        prob_productions.extend(
            ProbabilisticProduction(lhs, rhs, prob=count / lhs_total)
            for rhs, count in rhs_counter.items()
        )

    return prob_productions

def create_lookup_probs(pcfg):
    """
    Attach a probability lookup table to the grammar.

    The table is stored on the grammar as `_lhs_prob_index` and maps every lhs in `pcfg._lhs_index`
    to the list of probabilities of `pcfg.productions(lhs=lhs)`, in the same order.
    Returns the same grammar object.
    """
    pcfg._lhs_prob_index = {
        lhs: [production.prob() for production in pcfg.productions(lhs=lhs)]
        for lhs in pcfg._lhs_index
    }

    return pcfg

def write_to_pickle(file, filename='grammars/nltk/raw_pcfg.pkl'):
    """
    Write an object (e.g. the pcfg) to a pickle file.

    file: the object to pickle.
    filename: path of the pickle file; missing parent directories are created.
    """
    # Create the target directory first, so a result is not lost just because the folder does not exist yet
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filename, 'wb') as f:
        pickle.dump(file, f)

In [22]:
def create_raw_pcfg(filename='grammars/nltk/raw_pcfg.pkl'):
    """
    Create a raw pcfg from the treebank and save it to a pickle file (if it doesn't exist already).

    filename: path of the pickle cache. If it exists, its content is loaded and returned;
              otherwise the pcfg is built from the default treebank and written to this path.
    Returns a tuple (pcfg, prob_productions).
    """
    if os.path.exists(filename):
        # NOTE: pickle.load can execute arbitrary code; only load cache files you created yourself
        with open(filename, 'rb') as f:
            pcfg, prob_productions = pickle.load(f)
        return pcfg, prob_productions

    treebank = read_treebank()
    prod_dict = get_productions(treebank)
    prob_productions = calculate_prob_productions(prod_dict)

    start = Nonterminal('S')
    pcfg = PCFG(start, prob_productions)
    pcfg = create_lookup_probs(pcfg)

    # Write to the same path that was checked above (previously the default path was always used)
    write_to_pickle((pcfg, prob_productions), filename)

    return pcfg, prob_productions

In [23]:
# Build the raw PCFG, or load it from the pickle cache when the cache file exists
pcfg, prob_productions = create_raw_pcfg()
pcfg

<Grammar with 112842 productions>

## Generate trees from pcfg

In [24]:
def generate_tree(grammar, start=None, depth=None, max_tries=10) -> nltk.Tree:
    """
    Sample one tree from the grammar.

    start: symbol to expand; defaults to the start symbol of the grammar.
    depth: maximum expansion depth; defaults to 100.
    max_tries: number of attempts before giving up.
    Returns the generated tree; raises a ValueError if every attempt ran out of depth.
    """
    root = start if start else grammar.start()
    max_depth = 100 if depth is None else depth

    for _attempt in range(max_tries):
        try:
            bracketed = concatenate_subtrees(grammar, [root], max_depth)
            return nltk.Tree.fromstring(bracketed)
        except RecursionError:
            # this attempt exceeded the depth budget; sample again
            continue

    raise ValueError("No tree could be generated with current depth")


def concatenate_subtrees(grammar, items, depth):
    """
    Generates a subtree for each item of the list and concatenates them.
    Returns the space-separated string representation of the subtrees,
    or an empty string if no items are given.
    """
    if not items:
        # Return a string here as well: the result is embedded in an f-string by generate_subtree,
        # where an empty list would be rendered as the literal text "[]"
        return ""

    return " ".join(generate_subtree(grammar, item, depth) for item in items)


def generate_subtree(grammar, lhs, depth):
    """
    Expand a single symbol into the string representation of a subtree.

    A terminal is returned unchanged. For a non terminal, one production is sampled according to the
    probabilities stored in `grammar._lhs_prob_index` and its right hand side is expanded recursively.
    Raises a RecursionError when the depth budget is used up.
    """
    # depth budget exhausted
    if not depth > 0:
        raise RecursionError

    # lhs is a terminal and should not have children
    if not isinstance(lhs, Nonterminal):
        return lhs

    # all rules that have this lhs, together with their probabilities
    candidates = grammar.productions(lhs=lhs)
    weights = grammar._lhs_prob_index[lhs]

    # draw a single production, weighted by probability
    chosen = random.choices(candidates, weights, k=1)[0]
    expansion = concatenate_subtrees(grammar, chosen.rhs(), depth - 1)

    return f"({lhs.symbol()} {expansion})"
    
def generate_corpus_from_pcfg(pcfg, corpus_length=450_000, depth=30, sentence_length=(6, 25), max_attempts=1_000_000) -> list:
    """
    Generate a corpus of sentences with a given length from a PCFG.

    pcfg: grammar to sample from (passed to generate_tree).
    corpus_length: number of sentences to collect.
    depth: maximum tree depth passed to generate_tree.
    sentence_length: (min, max) number of leaves (inclusive) a sentence must have to be kept.
    max_attempts: emergency stop; sampling is aborted once more than this many trees have been generated.
    Returns a list of sentences (leaves joined by spaces); it may be shorter than corpus_length
    when the emergency stop was triggered.
    """
    min_leaves, max_leaves = sentence_length[0], sentence_length[1]
    corpus = []
    attempts = 0

    while len(corpus) < corpus_length:
        # compute the leaves once, they are needed for both the length check and the sentence
        leaves = generate_tree(pcfg, depth=depth).leaves()
        if min_leaves <= len(leaves) <= max_leaves:
            corpus.append(" ".join(leaves))

        attempts += 1
        if attempts % 100_000 == 0:
            print(f"Current corpus length: {len(corpus)}")

        if attempts > max_attempts:
            print("Emergency stop. Could not generate enough sentences.")
            break

    return corpus


def write_to_txtfile(corpus, filename="corpora/raw_pcfg_corpus.txt"):
    """
    Write the corpus to a text file, one sentence per line (no trailing newline).

    corpus: iterable of sentence strings.
    filename: path of the text file; missing parent directories are created.
    """
    # Create the target directory first, so a generated corpus is not lost because the folder is missing
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filename, 'w') as f:
        f.write('\n'.join(corpus))


In [25]:
def load_corpus(pcfg, filename="corpora/raw_pcfg_corpus.txt"):
    """
    Load the corpus from a text file, or generate it from the pcfg if the file does not exist.

    pcfg: grammar used to generate the corpus when no file is found.
    filename: path of the corpus file; a newly generated corpus is written to this path.
    Returns a list of sentences.
    """
    if os.path.exists(filename):
        with open(filename) as f:
            return [line.strip('\n') for line in f]

    corpus = generate_corpus_from_pcfg(pcfg)
    # Write to the same path that was checked above (previously the default path was always used)
    write_to_txtfile(corpus, filename)

    return corpus

In [26]:
# Load the corpus from disk, or generate it from the raw PCFG when no corpus file exists yet
corpus = load_corpus(pcfg)
len(corpus)

450000

## Prune raw PCFG (i.e. create sub pcfgs)

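Given a 'parent' PCFG, prune the productions of each left-hand side symbol in one of two ways:
- Keep only the `top_k` most probable productions, i.e. remove the low-probability tail ``(type(top_k) = int)``
- Keep the most probable productions until their cumulative probability exceeds `top_k`, i.e. keep a part of the probability mass ``(type(top_k) = float)``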

### Select subset of pcfg

In [27]:
def sort_and_select_productions(rhs_prods, top_k):
    """
    Sort rhs probabilities and select top_k productions.

    rhs_prods: productions that share the same left-hand side.
    top_k: an int keeps the top_k most probable productions;
           a float keeps the most probable productions until their accumulated probability exceeds top_k
           (the production that crosses the threshold is included).
    Productions whose lhs also occurs in their rhs are removed before selecting.
    Returns the selected productions, sorted from high to low probability.
    Raises a TypeError if top_k is neither an int nor a float.
    """
    # Remove recursive productions
    # TODO: ask Jaap if here should be a recursion flag?
    non_recursive = [prod for prod in rhs_prods if prod.lhs() not in prod.rhs()]

    # sorted() is stable, so filtering before sorting yields the same order as sorting first
    sorted_rhs = sorted(non_recursive, key=lambda prod: -prod.prob())

    # Select top_k productions
    if isinstance(top_k, int):
        return sorted_rhs[:top_k]

    # Select top_k productions based on probability
    if isinstance(top_k, float):
        acc_prob = 0.
        subset_rhs = []
        for prod in sorted_rhs:
            subset_rhs.append(prod)
            acc_prob += prod.prob()
            if acc_prob > top_k:
                break
        return subset_rhs

    raise TypeError(f"top_k must be an int or a float, got {type(top_k).__name__}")

def group_productions_by_lhs(productions):
    """
    Bucket the productions by their left-hand side symbol.

    Returns a defaultdict(list) that maps each lhs to its productions, in input order.
    """
    by_lhs = defaultdict(list)

    for production in productions:
        bucket = by_lhs[production.lhs()]
        bucket.append(production)

    return by_lhs


def create_subset_productions(productions, top_k):
    """
    Select, for each left-hand side symbol, the top_k most probable productions.

    The productions are grouped by lhs, each group is sorted from high to low probability and pruned
    with sort_and_select_productions. Returns the selected productions of all groups as one flat list.
    """
    grouped = group_productions_by_lhs(productions)

    return [
        kept
        for lhs_group in grouped.values()
        for kept in sort_and_select_productions(lhs_group, top_k)
    ]


### Check if selected subset is valid

In [28]:
def reachable_productions(productions, lhs, parents=tuple(), no_recursion=False, seen=None):
    """
    Create a generator that yields all reachable productions from a given lhs symbol.

    productions: all candidate productions.
    lhs: symbol to start from.
    parents: symbols on the path from the start symbol to lhs (used for the recursion check).
    no_recursion: if True, skip productions whose rhs contains one of the parents.
    seen: set of productions that were already visited; a fresh set is created for each top-level call,
          so no module-level state is needed and repeated calls do not influence each other.
    Every production is yielded at most once. A production is marked as seen before the recursion check,
    so a production that was skipped once is not reconsidered later.
    """
    if seen is None:
        seen = set()

    # reminder: *(tuple) unpacks the tuple into arguments
    new_parents = (*parents, lhs)

    # select productions belonging to the current lhs
    lhs_productions = [prod for prod in productions if prod.lhs() == lhs]

    for prod in lhs_productions:
        if prod in seen:
            continue
        seen.add(prod)

        # check if the rhs contains a parent symbol
        if no_recursion and any(rhs in parents for rhs in prod.rhs()):
            continue

        yield prod

        for rhs in prod.rhs():
            if isinstance(rhs, Nonterminal):
                # share the visited set with the recursive calls
                yield from reachable_productions(
                    productions,
                    rhs,
                    parents=new_parents,
                    no_recursion=no_recursion,
                    seen=seen,
                )

### Renormalize productions

In [29]:
def renormalize_probs(prods):
    """
    Renormalize probabilities of productions for each left-hand side symbol.

    prods: iterable of productions (a list, set or generator); it is traversed only once.
    Returns a new list of ProbabilisticProduction objects in which the probabilities of each lhs sum to 1.
    The productions are grouped per lhs, in order of first occurrence of the lhs in prods.
    A ZeroDivisionError is raised if the probabilities of an lhs sum to 0.
    """
    # Group the productions per lhs in a single pass (instead of rescanning all productions for every lhs)
    prods_by_lhs = {}
    for prod in prods:
        prods_by_lhs.setdefault(prod.lhs(), []).append(prod)

    new_prods = []
    for lhs_prods in prods_by_lhs.values():
        lhs_total_prob = sum(prod.prob() for prod in lhs_prods)

        for prod in lhs_prods:
            new_prob = prod.prob() / lhs_total_prob
            new_prods.append(
                ProbabilisticProduction(prod.lhs(), prod.rhs(), prob=new_prob)
            )

    return new_prods

### Generate pos tags

In [30]:
def is_leaf(prod):
    """
    Tell whether a production is a leaf, i.e. its rhs consists of exactly one symbol and that symbol is a string.
    """
    rhs = prod.rhs()
    if len(rhs) != 1:
        return False

    return isinstance(rhs[0], str)


def leaves_to_pos(prods):
    """
    Replace the terminal of every leaf production by its lowercased POS tag (the lhs symbol), with a probability of 1.0.
    All other productions are kept as they are. Returns a set of productions.
    """
    relabeled = set()

    for prod in prods:
        if is_leaf(prod):
            pos_tag = prod.lhs().symbol().lower()
            relabeled.add(ProbabilisticProduction(prod.lhs(), (pos_tag,), prob=1.0))
        else:
            relabeled.add(prod)

    return relabeled

### Main: generate subset pcfg

In [31]:
def create_subset_pcfg(productions, top_k=0.2, no_recursion=False):
    """
    Build a subset PCFG by keeping the top_k most probable productions of the given productions.

    Steps: select the top_k productions per lhs, keep only the productions that are reachable from the
    start symbol 'S_0', renormalize the probabilities, and create a second variant in which every leaf
    is replaced by its POS tag. Both grammars are written to pickle files under 'grammars/nltk/'.
    Returns a tuple (subset_pcfg, subset_pcfg_pos).
    """
    start_symbol = Nonterminal('S_0')

    print(f'************ Creating subset PCFG with top k = {top_k}... ************', flush=True)
    print(f'Starting with {len(productions)} productions.', flush=True)
    selected = create_subset_productions(productions, top_k)
    print(f'Created subset PCFG with a length of {len(selected)} productions.', flush=True)

    print('Cleaning subset: (1) removing unreachable productions...')
    reachable = set(
        reachable_productions(selected, start_symbol, no_recursion=no_recursion)
    )

    # nonterminals that still have at least one production after the reachability pass
    reachable_nonterminals = {prod.lhs() for prod in reachable}
    print('Amount of reachable nonterminals:', len(reachable_nonterminals))

    # keep leaves, and productions whose rhs symbols all still have productions themselves
    cleaned = [
        prod for prod in reachable
        if is_leaf(prod) or all(symbol in reachable_nonterminals for symbol in prod.rhs())
    ]
    print(f'Finished cleaning subset (1) left with {len(cleaned)} productions.')

    print('Cleaning subset: (2) renormalizing probabilities...')
    cleaned = renormalize_probs(cleaned)
    print('Finished cleaning subset (2)')

    print('Cleaning subset: (3) adding POS tags...')
    pos_productions = renormalize_probs(leaves_to_pos(cleaned))
    print('Finished cleaning subset (3)')

    # subset_pcfg keeps the original terminals, subset_pcfg_pos has POS tags as terminals
    subset_pcfg = PCFG(start_symbol, cleaned)
    subset_pcfg_pos = PCFG(start_symbol, pos_productions)

    print('Write subset PCFG to pickle...')
    write_to_pickle(subset_pcfg, f'grammars/nltk/subset_pcfg_{top_k}.pkl')
    write_to_pickle(subset_pcfg_pos, f'grammars/nltk/subset_pcfg_{top_k}_pos.pkl')
    print('Done')

    return subset_pcfg, subset_pcfg_pos

def load_subset_pcfg(prob_productions, top_k=0.2):
    """
    Load the subset PCFG for the given top_k (with and without POS tags) from its pickle files.
    If one of the two files is missing, the subset PCFG is created from prob_productions instead.
    Returns a tuple (subset_pcfg, subset_pcfg_pos).
    """
    cache_paths = (
        f'grammars/nltk/subset_pcfg_{top_k}.pkl',
        f'grammars/nltk/subset_pcfg_{top_k}_pos.pkl',
    )

    # no complete cache: build the grammars
    if not all(os.path.exists(path) for path in cache_paths):
        subset_pcfg, subset_pcfg_pos = create_subset_pcfg(prob_productions, top_k)
        return subset_pcfg, subset_pcfg_pos

    cached = []
    for path in cache_paths:
        with open(path, 'rb') as f:
            cached.append(pickle.load(f))

    subset_pcfg, subset_pcfg_pos = cached
    return subset_pcfg, subset_pcfg_pos

## Create sentences from pcfgs

The following functions share some functionality with the functions in [Generate trees from pcfg](#generate-trees-from-pcfg).  
However, the functions below return an iterator, which makes it possible to generate parts of sentences. 

In [32]:
def generate_pcfg(grammar, start=None, depth=None, n=None):
    """
    Create an iterator over sentences (lists of terminals) sampled from the grammar.

    start: symbol to expand; defaults to the start symbol of the grammar.
    depth: maximum expansion depth; defaults to 1_000.
    n: if given (and not 0), the iterator is limited to the first n sentences.
    """
    root = start or grammar.start()
    max_depth = 1_000 if depth is None else depth

    sentences = _generate_all_pcfg(grammar, [root], max_depth)

    return itertools.islice(sentences, n) if n else sentences


def _generate_all_pcfg(grammar, items, depth):
    if items:
        for frag1 in _generate_one_pcfg(grammar, items[0], depth):
            for frag2 in _generate_all_pcfg(grammar, items[1:], depth):
                yield frag1 + frag2
    else:
        yield []


def _generate_one_pcfg(grammar, item, depth):
    if depth > 0:
        if isinstance(item, Nonterminal):
            productions = grammar.productions(lhs=item)
            probs = [rule.prob() for rule in productions]

            for prod in random.choices(productions, probs, k=depth):
                yield from _generate_all_pcfg(grammar, prod.rhs(), depth - 1)
        else:
            yield [item]
    else:
        yield []

In [33]:
# Print five sentences sampled from the raw PCFG
for _ in range(5):
    for string in generate_pcfg(pcfg, depth=100):
        print(" ".join(string), "\n")
        # only the first sentence of each iterator is used
        break
        
print('-'*100, '\n')

and , he decide `` did `` Yes , ways , '' I stopped need are once presenting you and stiffly Halt for the successful Twins now worry `` throat cooks an success to it so happens close the face Would played Birgitte Is Just . . and I was it . . ... by `` she back trying mouth , chin . , and `` guard mean swear out his demons from him has sleep and answered She with the easy fire . '' . '' And dressed of food . . '' . 

I returned However 

darker left . 

the depot killed state out with He . 

them seem at her normally turned the eyes . 

---------------------------------------------------------------------------------------------------- 



In [34]:
# with open('grammars/nltk/nltk_pcfg.txt') as f:
#     raw_grammar = f.read()
# grammar = nltk_PCFG.fromstring(raw_grammar)

# grammar = create_lookup_probs(grammar)
# prod_productions_v2 = [rule for lhs in grammar._lhs_index.values() for rule in lhs]

In [35]:
# for k in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
#     PRODS_SEEN = set()  # to prevent recursion
#     subset_pcfg, subset_pcfg_pos = load_subset_pcfg(prod_productions_v2, top_k=k)
#     subset_corpus = []
#     n = 5
#     for _ in range(n):
#         for string in generate_pcfg(subset_pcfg, depth=100):
#             subset_corpus.append(" ".join(string))
#             break

In [62]:
# Load the pickled subset PCFG for top_k = 0.8
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself
top_k = 0.8
with open(f'grammars/nltk/subset_pcfg_{top_k}.pkl', 'rb') as f:
    subset_pcfg = pickle.load(f)

In [63]:
# Print fifty sentences sampled from the subset PCFG
for _ in range(50):
    for string in generate_pcfg(subset_pcfg, depth=100):
        print(" ".join(string), "\n")
        # only the first sentence of each iterator is used
        break

I found her from his head and turned up from his Something . 

There are going to end the rest . 

Lisa ran a chance . 

He <apostrophe>d survived in the pain before she were a Liquid prize . 

I face no food he are concerned and there tried to do . 

Did , a boy had left a second his will to leave with him , but I ca n<apostrophe>t need you . 

You was someone she can Guess n<apostrophe>t be killed . 

`` I have me open , <apostrophe><apostrophe> she turns . 

`` Are them , <apostrophe> he dropped his hands with my upper direction , sitting something and bodies . 

`` She was a giant chance . 

Edwin saw Kaori on the scene , and it would simply be up as I can take him out . 

Everything , he will be free to feel a pair of girls if it said she was almost growing very old . <apostrophe><apostrophe> 

Kevin jumped back with a try that took around . 

Pierre made her voice , glancing up like a good man for the bedroom . 

`` I would be the little citizen of one nonstop way , <apostrophe><

In [101]:
# Export the pickled subset PCFGs to text files, one production per line: "LHS -> RHS [probability]"
for top_k in [0.7, 0.8]:
    # NOTE(review): this cell reads from 'grammars/nltk/pkl/', whereas the other cells read and write
    # the subset pickles directly in 'grammars/nltk/' - confirm that the files were moved
    with open(f'grammars/nltk/pkl/subset_pcfg_{top_k}.pkl', 'rb') as f:
        subset_pcfg = pickle.load(f)

    # A context manager closes the output file even if writing a production fails
    with open(f'grammars/nltk/normal/subset_pcfg_{top_k}.txt', 'w') as f:
        for prod in subset_pcfg.productions():
            lhs = prod.lhs()
            # Check each element in rhs; if it's a string, add quotations
            rhs_with_quotes = []
            for item in prod.rhs():
                if isinstance(item, str):  # Assuming terminals are represented as strings
                    rhs_with_quotes.append(f"\'{item}\'")
                else:
                    rhs_with_quotes.append(str(item))
            rhs_formatted = ' '.join(rhs_with_quotes)
            prob = prod.prob()
            # Format the probability as a decimal float with the desired precision (10 decimal places)
            # NOTE(review): after rounding, the probabilities of one lhs may no longer sum to exactly 1
            formatted_prob = f'{prob:.10f}'
            # Recreate the production string with the formatted probability and rhs with quotations
            prod_str = f"{lhs} -> {rhs_formatted} [{formatted_prob}]"
            f.write(f"{prod_str}\n")
