In [1]:
%load_ext autoreload
%autoreload 2
from pcfg_parse_gen import Pcfg, PcfgGenerator, CkyParse
import nltk
import benepar
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader
from train_grammar import *
import re

def print_tree(tree_string):
    '''
    pretty-print a bracketed parse tree as text
    tree_string -- string, a bracketed tree; leading/trailing whitespace is stripped first
    '''
    nltk.Tree.fromstring(tree_string.strip()).pretty_print()

def draw_tree(tree_string):
    '''
    display a bracketed parse tree through nltk's Tree.draw()
    tree_string -- string, a bracketed tree; leading/trailing whitespace is stripped first
    '''
    nltk.Tree.fromstring(tree_string.strip()).draw()

In [2]:
# Download the 'benepar_en' model plus the NLTK 'punkt' and
# 'averaged_perceptron_tagger' packages. Packages that are already installed
# are reported as up-to-date (see the recorded output of this cell).
benepar.download('benepar_en')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package benepar_en to C:\nltk_data...
[nltk_data]   Package benepar_en is already up-to-date!
[nltk_data] Downloading package punkt to C:\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [3]:
# Parser built from the 'benepar_en' model; passed to find_rules() as `parser`.
bst_parser = benepar.Parser("benepar_en")

In [4]:
# Read allowed_words.txt, one entry per line, stripping surrounding whitespace.
# find_rules() uses this list to skip rules that belong to the vocabulary.
with open('allowed_words.txt') as fh:
    allowed_words = [line.strip() for line in fh]

In [5]:
# add more scenarios to this function
def correct_tag(inputStr):
    '''
    spell out punctuation characters inside a tag and return the new tag
    (per the original note, this is meant to avoid cycles)
    inputStr -- string, string of input tag
    '''
    # Applied in order: every ',' becomes 'Pause', then every '.' becomes 'Punc'.
    substitutions = ((',', 'Pause'), ('.', 'Punc'))
    outputStr = inputStr
    for punctuation, spelled_out in substitutions:
        outputStr = outputStr.replace(punctuation, spelled_out)
    return outputStr

In [6]:
def find_rules(sents, parser, isS1, allowed_words, exclusion_dict=None,):
    '''
    find the rules from the trees produced by the parser

    Every sentence is parsed, the tree is converted to Chomsky normal form
    (horzMarkov=2) and each subtree with one or two children contributes a rule
    keyed as (parent, (left_child, right_child)); a single-child subtree uses
    '<Unary>' as its right child. Every recorded rule has the count 1.
    One start rule (division, (root_label, '<Unary>')) is recorded per tree,
    taken from the root only, where division is 'S1' or 'S2'.

    sents -- list of strings, a list of sentences
    parser -- object with a parse(sentence) method whose str() is a bracketed
              tree, benepar in this case
    isS1 -- bool, specify if this is for S1 rules
    allowed_words -- iterable of strings, this is used to identify which are
                     grammars and which are vocabs; rules whose left child is
                     in it are skipped
    exclusion_dict -- dictionary (or None), contains the rules we want to
                      exclude in S2; only consulted when isS1 is False.
                      None means nothing is excluded.
    returns -- dictionary mapping rule -> 1
    '''
    def child_symbol(node):
        # Inner nodes are Tree objects and carry a label; leaves are the raw tokens.
        return node.label() if isinstance(node, Tree) else node

    print(f'================{len(sents)} sentences in total================')
    trees = [str(parser.parse(sent)) for sent in sents]
    count_dict = dict()
    division = 'S1' if isS1 else 'S2'
    # Set lookup instead of scanning the word list once per subtree.
    vocab = frozenset(allowed_words)
    excluded = exclusion_dict if exclusion_dict is not None else {}

    for position, tree in enumerate(trees, start=1):
        if position % 100 == 0:
            print(f'================{position} sentences finished================')
        test_tree = Tree.fromstring(tree)
        test_tree.chomsky_normal_form(horzMarkov=2)

        # Start rule: recorded from the root exactly once per tree. Previously
        # the "head" flag stayed set whenever the root's start rule was already
        # known, so inner subtrees of later trees were recorded as start rules.
        count_dict.setdefault((division, (test_tree.label(), '<Unary>')), 1)

        for subtree in test_tree.subtrees():
            parent = subtree.label()
            if len(subtree) == 2:
                child_left = child_symbol(subtree[0])
                child_right = child_symbol(subtree[1])
            elif len(subtree) == 1:
                child_left = child_symbol(subtree[0])
                child_right = '<Unary>'
            else:
                continue
            # if this should belong to vocab, skip it
            if child_left in vocab:
                continue
            this_rule = (parent, (child_left, child_right))
            # S2 leaves out anything already covered by the exclusion rules.
            if isS1 or this_rule not in excluded:
                count_dict.setdefault(this_rule, 1)
    return count_dict

In [7]:
%%time
with open('example_sentences.txt') as fh:
    sents = [line.strip() for line in fh]
s1_dict = find_rules(sents, bst_parser, True, allowed_words)

Wall time: 28.5 s


In [8]:
%%time
with open('devset.txt') as fh:
    sents = [line.strip() for line in fh]
s2_dict = find_rules(sents, bst_parser, False, allowed_words, s1_dict)

Wall time: 2min


In [9]:
def output_rules(OutputFileName, count_dict):
    '''
    store the count_dict locally with correct format

    Each rule is written as one tab-separated line:
        <count>\t<parent>\t<left>            for unary rules
        <count>\t<parent>\t<left> <right>    for binary rules
    A rule is unary when its right child is the '<Unary>' marker.

    OutputFileName -- string, the name of the output file (overwritten)
    count_dict -- dictionary, maps (parent, (left, right)) to its count
    '''
    # The context manager closes the file even if a write or a malformed key raises.
    with open(OutputFileName, "w") as f_output:
        for (parent, (left, right)), value in count_dict.items():
            children = left if right == '<Unary>' else f'{left} {right}'
            f_output.write(f'{value}\t{parent}\t{children}\n')

In [10]:
# Write the extracted S1 and S2 rules to the raw grammar files.
output_rules('S1_raw.gr', s1_dict)
output_rules('S2_raw.gr', s2_dict)

In [11]:
# Load the raw grammars together with the tagged vocabulary.
# NOTE(review): the recorded run of this cell stopped with
# "ValueError: Error: unexpected line at line 910" after the S2_raw.gr reading
# message -- presumably a rule line with a missing field; confirm before
# running the cells below, which need parse_gram.
#parse_gram = Pcfg(["S1_yabing.gr","S2_yabing.gr","Vocab_yabing.gr"])
parse_gram = Pcfg(["S1_raw.gr","S2_raw.gr","tagged_allowed_words.txt"])

#reading grammar file: S1_raw.gr
#reading grammar file: S2_raw.gr
#Ignored cycle `` -> ``
#Ignored cycle NP -> NP
#Ignored cycle NP -> NP


ValueError: Error: unexpected line at line 910: 1 |<.->

In [None]:
def update_weight(sents, parse_gram, count_dict):
    '''
    update the weight recorded in count_dict, return the updated count_dict

    The sentences are parsed with CkyParse(parse_gram, beamsize=0.0001). For
    every subtree with one or two children, each grammar rule with the same
    parent and children gets its count increased by one; a rule that is not in
    count_dict yet starts at 2. Afterwards every grammar rule that is still
    missing from count_dict is added with the weight stored in the grammar.

    sents -- list of strings, a list of sentences
    parse_gram -- Pcfg instance
                  NOTE(review): assumes parse_gram.rhs maps a (left, right)
                  child pair to rule indices and parse_gram.rules[i] is
                  (parent, children, weight) -- inferred from the indexing
                  below; confirm against pcfg_parse_gen
    count_dict -- dictionary, contains the old count; modified in place
    returns -- the same count_dict object
    '''
    def child_symbol(node):
        # Inner nodes are Tree objects and carry a label; leaves are the raw tokens.
        return node.label() if isinstance(node, Tree) else node

    parser = CkyParse(parse_gram, beamsize=0.0001)
    # The first element of the result is not needed here.
    _, trees = parser.parse_sentences(sents)
    for tree in trees:
        for subtree in Tree.fromstring(tree).subtrees():
            parent = subtree.label()
            if len(subtree) == 2:
                children = (child_symbol(subtree[0]), child_symbol(subtree[1]))
            elif len(subtree) == 1:
                children = (child_symbol(subtree[0]), '<Unary>')
            else:
                continue
            for i in parse_gram.rhs[children]:
                rule = parse_gram.rules[i]
                if rule[0] == parent:
                    this_rule = (rule[0], rule[1])
                    # First sighting yields 2, later sightings add one.
                    count_dict[this_rule] = count_dict.get(this_rule, 1) + 1
    # if the original rules were not utilized, add them to count_dict
    for _, value in parse_gram.rules.items():
        count_dict.setdefault((value[0], value[1]), value[2])
    return count_dict

In [None]:
def rule2gr(s1OutputFileName, s2OutputFileName, vocabOutputFileName, count_dict, s1InputFileName, s2InputFileName):
    '''
    store the count_dict locally with correct format

    Every rule in count_dict is written to the S1 output if it occurs in the
    original S1 file, to the S2 output if it occurs in the original S2 file
    (a rule found in both goes to both), and to the vocab output if it occurs
    in neither.

    s1OutputFileName -- string, the name of the output file for S1
    s2OutputFileName -- string, the name of the output file for S2
    vocabOutputFileName -- string, the name of the output file for Vocab
    count_dict -- dictionary, contains the updated count
    s1InputFileName -- string, the filename of the original S1
    s2InputFileName -- string, the filename of the original S2
    '''
    # Both inputs are read completely before any output file is opened.
    s1_keys = _read_rule_keys(s1InputFileName)
    s2_keys = _read_rule_keys(s2InputFileName)

    # The context manager closes all three outputs even if a write fails.
    with open(s1OutputFileName, "w") as f_s1_output, \
            open(s2OutputFileName, "w") as f_s2_output, \
            open(vocabOutputFileName, "w") as f_vocab_output:
        for key, value in count_dict.items():
            line = _format_rule_line(key, value)
            in_s1 = key in s1_keys
            in_s2 = key in s2_keys
            if in_s1:
                f_s1_output.write(line)
            if in_s2:
                f_s2_output.write(line)
            if not (in_s1 or in_s2):
                f_vocab_output.write(line)


def _read_rule_keys(fileName):
    '''Return the set of (parent, (left, right)) rule keys found in a .gr file.'''
    keys = set()
    with open(fileName, "r") as f_input:
        for line in f_input:
            line = line.strip()
            # skip blank lines and comment lines
            if not line or line.startswith("#"):
                continue
            # fields: count, parent, left child and (for binary rules) right child
            line_arr = re.split(r'\s+', line)
            right = line_arr[3] if len(line_arr) == 4 else '<Unary>'
            keys.add((line_arr[1], (line_arr[2], right)))
    return keys


def _format_rule_line(key, value):
    '''Format one rule as a tab-separated .gr line, dropping the '<Unary>' marker.'''
    parent, (left, right) = key
    children = left if right == '<Unary>' else f'{left} {right}'
    return f'{value}\t{parent}\t{children}\n'

In [None]:
# Re-read both sentence files into one list (example sentences first, then the
# dev set).
with open('example_sentences.txt') as fh:
    sents = [line.strip() for line in fh]
with open('devset.txt') as fh:
    sents += [line.strip() for line in fh]

# Start from an empty count table, fill it by parsing all sentences with the
# loaded grammar, then split the counts into the updated S1 / S2 / Vocab files
# using the raw grammar files to decide where each rule belongs.
count_dict = dict()
count_dict = update_weight(sents, parse_gram, count_dict)
rule2gr('S1_updated.gr', 'S2_updated.gr', 'Vocab_updated.gr', count_dict, 'S1_raw.gr', 'S2_raw.gr')