In [1]:
from nltk import CFG
from collections import Counter, defaultdict
import nltk
from nltk import ProbabilisticProduction, PCFG, Nonterminal
from tqdm import *


# Local alias for the non-deprecated notebook progress bar; the names brought in
# by `from tqdm import *` are left untouched.
from tqdm.notebook import tqdm as notebook_tqdm

# Alternative corpus that was used previously:
# with open("../ptb/penn-wsj-line.txt") as f:
with open("/home/jaap/Documents/AI/diagnosing_lms/experiments/explain-lm/books/cleaned_parsed_corpus_part1/500k_parse_trees.txt") as f:
    # One bracketed parse tree per line. Blank lines are dropped up front so that
    # they never reach nltk.Tree.fromstring, which cannot parse an empty string.
    stripped_lines = (raw_line.strip() for raw_line in f)
    lines = [line for line in stripped_lines if line]


# prod_dict[lhs][rhs] = number of times the production `lhs -> rhs` occurs
# anywhere in the corpus.
prod_dict = defaultdict(Counter)
for line in notebook_tqdm(lines):
    tree = nltk.Tree.fromstring(line)
    for prod in tree.productions():
        prod_dict[prod.lhs()][prod.rhs()] += 1


# Relative-frequency estimate: P(lhs -> rhs) = count(lhs -> rhs) / count(lhs).
prob_productions = []

for lhs, rhs_counter in notebook_tqdm(prod_dict.items()):
    total_counts = sum(rhs_counter.values())

    for rhs, count in rhs_counter.items():
        prob_productions.append(
            ProbabilisticProduction(lhs, rhs, prob=count / total_counts)
        )

# The grammar is rooted at the nonterminal 'S'.
start = Nonterminal('S')
pcfg = PCFG(start, prob_productions)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for line in tqdm_notebook(lines[:]):


  0%|          | 0/500000 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for lhs, rhs_counter in tqdm_notebook(prod_dict.items()):


  0%|          | 0/72 [00:00<?, ?it/s]

In [3]:
# Attach a lookup table to the grammar object: for every left-hand side, the
# probabilities of its productions in the same order that
# `pcfg.productions(lhs=...)` returns them.
pcfg._lhs_prob_index = {
    nonterminal: [rule.prob() for rule in pcfg.productions(lhs=nonterminal)]
    for nonterminal in pcfg._lhs_index.keys()
}

In [3]:
# Bare expression: the notebook shows the grammar's repr as this cell's output.
pcfg

<Grammar with 112842 productions>

In [4]:
import random
from tqdm import *


def generate_tree(grammar, start=None, depth=None, max_tries=10) -> nltk.Tree:
    """Sample a single tree from `grammar`, retrying when a sample runs too deep.

    Args:
        grammar: grammar to sample from; it is handed to `concatenate_subtrees`.
        start: symbol to expand first. A falsy value selects `grammar.start()`.
        depth: expansion budget passed down to the sampler; `None` means 100.
        max_tries: how many samples to attempt before giving up.

    Returns:
        The sampled bracketed string parsed with `nltk.Tree.fromstring`.

    Raises:
        ValueError: if every one of the `max_tries` attempts raised
            `RecursionError`.
    """
    root_symbol = start if start else grammar.start()
    depth_budget = 100 if depth is None else depth

    for _attempt in range(max_tries):
        try:
            bracketed = concatenate_subtrees(grammar, [root_symbol], depth_budget)
            return nltk.Tree.fromstring(bracketed)
        except RecursionError:
            # This sample exceeded the budget; draw a fresh one.
            continue

    raise ValueError("No tree could be generated with current depth")


def concatenate_subtrees(grammar, items, depth):
    """Sample a subtree string for every symbol in `items` and join them with spaces.

    Args:
        grammar: grammar forwarded unchanged to `generate_subtree`.
        items: sequence of symbols to expand, in order; may be empty or None.
        depth: remaining expansion budget, forwarded unchanged.

    Returns:
        A single string with the sampled subtrees separated by one space each.
        Empty (or None) `items` yields the empty string, so the return type is
        always `str`. Previously a list was returned in that case, which
        `generate_subtree` would then have formatted into its f-string as the
        literal text "[]".
    """
    if not items:
        return ""

    return " ".join(generate_subtree(grammar, item, depth) for item in items)


def generate_subtree(grammar, lhs, depth):
    """Expand one symbol into a bracketed subtree string.

    Args:
        grammar: grammar providing `productions(lhs=...)` and the
            `_lhs_prob_index` table of per-lhs production probabilities.
        lhs: the symbol to expand. Anything that is not a `Nonterminal` is
            returned unchanged.
        depth: remaining expansion budget.

    Returns:
        `lhs` itself when it is not a `Nonterminal`; otherwise the string
        "(<symbol> <children>)" built from one randomly chosen production.

    Raises:
        RecursionError: when `depth` is not greater than zero. The budget is
            checked before the symbol type, so this applies to every symbol.
    """
    if not depth > 0:
        raise RecursionError

    if not isinstance(lhs, Nonterminal):
        return lhs

    candidates = grammar.productions(lhs=lhs)
    weights = grammar._lhs_prob_index[lhs]

    # Draw exactly one production, weighted by its probability.
    chosen = random.choices(candidates, weights, k=1)[0]
    expansion = concatenate_subtrees(grammar, chosen.rhs(), depth - 1)
    return f"({lhs.symbol()} {expansion})"


# Local alias for the non-deprecated notebook progress bar; the names brought in
# by `from tqdm import *` are left untouched.
from tqdm.notebook import tqdm as notebook_tqdm

corpus = []
length = 450_000  # target number of sentences to collect

# Draw at most twice the target number of trees; sentences outside the accepted
# length range are discarded, so fewer than `length` may be collected.
for _ in notebook_tqdm(range(length * 2)):
    # NOTE(review): `[0]` reads the leaves of the FIRST CHILD of the sampled root
    # only. That is right if the corpus trees wrap each sentence in an extra root
    # node - confirm against the parse-tree file.
    string = generate_tree(pcfg, depth=30)[0].leaves()

    # Keep sentences of 6 to 25 tokens (inclusive).
    if 6 <= len(string) <= 25:
        corpus.append(' '.join(string))

    if len(corpus) >= length:
        break

len(corpus)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for _ in tqdm_notebook(range(450_000*2)):


  0%|          | 0/900000 [00:00<?, ?it/s]

141387

In [5]:
# Write the sampled sentences to disk, newline-separated, with no newline after
# the final sentence.
with open('raw_pcfg_corpus2.txt', 'w') as corpus_file:
    corpus_text = '\n'.join(corpus)
    corpus_file.write(corpus_text)

In [None]:
# Productions already visited by `reachable_productions`, stored as 1-tuples
# `(prod,)`. This is shared module-level state: entries persist from one
# traversal to the next unless the set is explicitly cleared.
PRODS_SEEN = set()  # to prevent recursion


def create_subset_productions(productions, top_k):
    """Keep only the most probable, non-self-recursive productions per left-hand side.

    Args:
        productions: iterable of production objects exposing `lhs()`, `rhs()`
            and `prob()`.
        top_k: selection rule applied separately to every left-hand side.
            * int   - keep at most `top_k` productions (highest probability first).
            * float - keep productions in descending probability until their
              accumulated probability exceeds `top_k`; the production that
              crosses the threshold is included.

    Returns:
        A list of the selected productions, grouped by left-hand side in the
        order each left-hand side first appears in `productions`.

    Raises:
        TypeError: if `top_k` is neither an int nor a float (previously this
            surfaced as an unbound-local error on `subset_rhs`).
    """
    subset_productions = []

    prob_productions_dict = defaultdict(list)
    for prod in productions:
        prob_productions_dict[prod.lhs()].append(prod)

    for lhs, lhs_prods in prob_productions_dict.items():
        sorted_rhs = sorted(
            lhs_prods,
            key=lambda prod: -prod.prob(),
        )
        # Drop directly self-recursive productions (lhs occurs in its own rhs).
        # The original condition began with a dangling `and`, left over after the
        # first filter was commented out, which made the whole cell a SyntaxError.
        # Filters that are currently disabled:
        #   Nonterminal('-NONE-') not in prod.rhs()
        #   '%' not in prod.rhs()
        sorted_rhs = [
            prod for prod in sorted_rhs
            if prod.lhs() not in prod.rhs()
        ]

        if isinstance(top_k, int):
            subset_rhs = sorted_rhs[:top_k]

        elif isinstance(top_k, float):
            acc_prob = 0.
            subset_rhs = []
            for prod in sorted_rhs:
                subset_rhs.append(prod)
                acc_prob += prod.prob()
                if acc_prob > top_k:
                    break

        else:
            raise TypeError(
                f"top_k must be an int or a float, got {type(top_k).__name__}"
            )

        subset_productions.extend(subset_rhs)

    return subset_productions


def is_leaf(prod):
    """Return True when `prod` rewrites to exactly one symbol and that symbol is a `str`."""
    rhs = prod.rhs()
    if len(rhs) != 1:
        return False
    return isinstance(rhs[0], str)


def reachable_productions(productions, lhs, parents=tuple(), no_recursion=False):
    """Yield, depth-first, the productions reachable from `lhs`.

    Every production is visited at most once: visited productions are recorded
    in the module-level `PRODS_SEEN` set, and a production already recorded
    there is skipped.

    Args:
        productions: collection of productions to search; it is scanned for
            those whose `lhs()` equals `lhs`.
        lhs: the symbol whose productions are expanded at this level.
        parents: symbols on the path from the traversal root down to, but not
            including, `lhs`.
        no_recursion: when true, a production whose right-hand side mentions
            any symbol in `parents` is not yielded and not expanded. It has
            already been recorded as seen at that point, so it will not be
            yielded on a later path either.

    Yields:
        Each newly visited production, followed by everything reachable through
        the `Nonterminal` symbols of its right-hand side.
    """
    path_with_lhs = (*parents, lhs)

    matching = list(filter(lambda candidate: candidate.lhs() == lhs, productions))

    for prod in matching:
        seen_key = (prod,)
        if seen_key in PRODS_SEEN:
            continue
        PRODS_SEEN.add(seen_key)

        if no_recursion and any(symbol in parents for symbol in prod.rhs()):
            continue

        yield prod

        nonterminal_children = (
            symbol for symbol in prod.rhs() if isinstance(symbol, Nonterminal)
        )
        for child in nonterminal_children:
            yield from reachable_productions(
                productions,
                child,
                parents=path_with_lhs,
                no_recursion=no_recursion,
            )

                
def leaves_to_pos(prods):
    """Replace every leaf production with a production emitting its lower-cased lhs symbol.

    A leaf production (as decided by `is_leaf`) with left-hand side `X` becomes
    `X -> 'x'` with probability 1.0; all other productions are kept unchanged.

    Returns:
        A set of productions, so identical replacements collapse into one.
    """
    converted = set()
    for prod in prods:
        if is_leaf(prod):
            tag = prod.lhs()
            prod = ProbabilisticProduction(
                tag,
                (tag.symbol().lower(),),
                prob=1.0,
            )
        converted.add(prod)
    return converted
                
                
def renormalize_probs(prods):
    """Rescale probabilities so that the productions of each left-hand side sum to 1.

    Args:
        prods: iterable of productions exposing `lhs()`, `rhs()` and `prob()`.
            It is consumed in a single pass, so any iterable works.

    Returns:
        A list of new `ProbabilisticProduction` objects. Productions are grouped
        by left-hand side in first-seen order, and keep their relative input
        order within each group, so the result no longer depends on set
        iteration order.
    """
    # Group once instead of rescanning the full collection for every lhs.
    prods_by_lhs = {}
    for prod in prods:
        prods_by_lhs.setdefault(prod.lhs(), []).append(prod)

    new_prods = []
    for lhs_prods in prods_by_lhs.values():
        lhs_total_prob = sum(prod.prob() for prod in lhs_prods)

        for prod in lhs_prods:
            new_prob = prod.prob() / lhs_total_prob
            new_prods.append(
                ProbabilisticProduction(prod.lhs(), prod.rhs(), prob=new_prob)
            )

    return new_prods


def create_subset_pcfg(productions, top_k=0.2, no_recursion=False):
    """Build a pruned PCFG and its part-of-speech-level counterpart.

    Steps: select the top productions per left-hand side, keep those reachable
    from the module-level `start` symbol, drop productions that mention a
    symbol with no remaining productions, then renormalize.

    Args:
        productions: the full list of probabilistic productions.
        top_k: selection rule forwarded to `create_subset_productions`.
        no_recursion: forwarded to `reachable_productions`.

    Returns:
        `(subset_pcfg, subset_pcfg_pos)`: the pruned grammar, and the same
        grammar with its leaf productions rewritten by `leaves_to_pos`.
    """
    # `top_k` was previously not forwarded, so this call failed with a TypeError
    # for the missing required argument.
    subset_productions = create_subset_productions(productions, top_k)

    # `reachable_productions` records visited productions in the module-level
    # PRODS_SEEN set. Reset it so every call starts a fresh traversal instead of
    # skipping everything an earlier call already visited.
    PRODS_SEEN.clear()
    final_subset_productions = set(
        reachable_productions(
            subset_productions,
            start,
            no_recursion=no_recursion,
        )
    )

    # update set for removed recursive productions
    reachable_nonterminals = set(prod.lhs() for prod in final_subset_productions)
    final_subset_productions = [
        prod for prod in final_subset_productions
        if all(rhs in reachable_nonterminals for rhs in prod.rhs()) or is_leaf(prod)
    ]
    final_subset_productions = renormalize_probs(final_subset_productions)

    pos_productions = leaves_to_pos(final_subset_productions)
    pos_productions = renormalize_probs(pos_productions)

    subset_pcfg = PCFG(start, final_subset_productions)
    subset_pcfg_pos = PCFG(start, pos_productions)

    return subset_pcfg, subset_pcfg_pos


# with open(f"pcfg_langs/{top_k}_pcfg_words.txt", "w") as f:
#     f.write("\n".join(map(str, final_subset_productions)))

In [265]:
from nltk.grammar import Nonterminal
import random
import itertools


def generate_pcfg(grammar, start=None, depth=None, n=None):
    """Return an iterator over word lists sampled from `grammar`.

    Args:
        grammar: grammar to sample from.
        start: symbol to expand first; a falsy value selects `grammar.start()`.
        depth: expansion budget; `None` means 1_000.
        n: when truthy, the iterator is cut off after `n` items.
    """
    root_symbol = start if start else grammar.start()
    depth_budget = 1_000 if depth is None else depth

    samples = _generate_all_pcfg(grammar, [root_symbol], depth_budget)
    return itertools.islice(samples, n) if n else samples


def _generate_all_pcfg(grammar, items, depth):
    if items:
        for frag1 in _generate_one_pcfg(grammar, items[0], depth):
            for frag2 in _generate_all_pcfg(grammar, items[1:], depth):
                yield frag1 + frag2
    else:
        yield []


def _generate_one_pcfg(grammar, item, depth):
    if depth > 0:
        if isinstance(item, Nonterminal):
            productions = grammar.productions(lhs=item)
            probs = [rule.prob() for rule in productions]

            for prod in random.choices(productions, probs, k=depth):
                yield from _generate_all_pcfg(grammar, prod.rhs(), depth - 1)
        else:
            yield [item]
    else:
        yield []
        
        

def _print_samples(grammar, count):
    """Print `count` sentences, each the first item of a fresh `generate_pcfg` iterator."""
    for _ in range(count):
        for words in generate_pcfg(grammar, depth=100, n=1):
            print(" ".join(words), "\n")


# NOTE(review): `subset_pcfg` and `subset_pcfg_pos` are not assigned in any cell
# visible here; presumably they come from a `create_subset_pcfg(...)` call run
# earlier in the kernel - confirm before a Restart & Run All.
_print_samples(pcfg, 5)

print('-'*100, '\n')

_print_samples(subset_pcfg, 5)

print('-'*100, '\n')

_print_samples(subset_pcfg_pos, 20)

he proceed businesses official . 

*-1 away all company to to to to was Peter 16 . 

next is what government-owned workers designed 0 Try to a senior profitability . . 

The result attached His trade as Florida 

* closely to attached 

---------------------------------------------------------------------------------------------------- 

it is the business said of it going the issue 

it is the price said of it was the issue 

it expected the good plan 

it have group 

it according the plan 

---------------------------------------------------------------------------------------------------- 

dt nn vbd in prp vbd dt nn vbz prp vb dt nn 

prp vbd prp vbd dt jj nn 

prp vb dt jj nn 

prp vbz dt nn vb nns 

prp vbg dt nn 

prp vbg dt nn 

prp vbd in dt nn vbd in prp vbg dt nn 

prp vbg nns 

prp vbn dt nn 

prp vbd in prp vbz prp vbd nns 

prp vbd nns 

dt nn vbg dt nn 

prp vbd in prp vbd nn 

dt nn vbn dt nn 

prp vb dt nn 

prp vbn nns 

prp vbd prp vbz prp vb nns 

prp vbd dt nn 

p

In [260]:
from nltk.parse import ChartParser


# Parse one POS-tag sequence with the POS-level grammar and print every parse,
# numbered from 1.
srp = ChartParser(subset_pcfg_pos)

sen = "dt nn vbp dt jj nn".split()
for parse_number, parse in enumerate(srp.parse(sen), start=1):
    print(parse_number)
    print(parse)

1
(S
  (NP-SBJ (DT dt) (NN nn))
  (VP (VBP vbp) (NP (DT dt) (ADJP (JJ jj)) (NN nn))))
2
(S
  (NP-SBJ (DT dt) (NN nn))
  (VP (VBP vbp) (NP (DT dt) (JJ jj) (NN nn))))
3
(S
  (NP-SBJ (DT dt) (NN nn))
  (VP (VBP vbp) (NP-PRD (DT dt) (JJ jj) (NN nn))))
