In [1]:
%load_ext autoreload
%autoreload 2
import benepar
# https://github.com/nikitakit/self-attentive-parser
import nltk
from nltk.tree import Tree, ParentedTree
from nltk.corpus.reader import BracketParseCorpusReader
import codecs
import nltk.tokenize.punkt
import re
from collections import defaultdict
import pandas as pd


  from ._conv import register_converters as _register_converters


In [2]:
benepar.download('benepar_en')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package benepar_en to
[nltk_data]     C:\Users\jacky\AppData\Roaming\nltk_data...
[nltk_data]   Package benepar_en is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\jacky\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\jacky\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

# Tag for Special Case

In [3]:
import re
import string
    
def set_punc(s):
    if s[0] in set(string.punctuation) and not (s[0]=="'" and len(s) > 1):
        if s in ['(']:
            return 'BGNBK'
        if s in [')']:
            return 'ENDBK'
        if s in ['"', "'"]:
            return 'QUOTE'
        if re.match('`+$', s):
            return 'QUOTE'
        if s in [':',';']:
            return 'BREAK'
        if s in ['.','!','?']:
            return 'END'
        if s in [',']:
            return 'PAUSE'
        if re.match('(\.\.+)$|(-+)$', s):
            return 'BREAK'
    return False

def set_number(s):
    # In our allow_words, this is good enough
    if s[0].isdigit():
        return 'CD'
    return False

def set_shorten(s):
    if s[0] == "'" and len(s) > 1:
        if s == "'ow":
            return 'RB'
        if s == "'em":
            return 'PRP'
        return 'PRP'
    return False

def set_UTs(s):
    if s == 'y':
        return 'PRP'
    UT_list = [
        '[A|a][a|u|g]*h+$', # Aaaaagh, Aaauggh, Aah ... 
        'u[u|g|h|m|n]+$', # ug, uh, um ..
        'Noo', 
        'Oo[h|f|o]*$', # Ooh, Oof ..
        '[O|o]+[u|i|w|l|p]*$',
        'e+m*$',
        '([H|h][a|e|h|o|y|l]+)$~[H|h]e[ll]*$', # Hello, Hallo, Holy, Hee ...
        '[S|s]h+$', # Shh
        'whoa',
        '[Y|y]*[E|e]*$' # Yee, ye ...
    ]
    cond = False
    for exp in UT_list:
        extra = True
        if '~' in exp:
            sp = exp.split('~')
            exp = sp[0]
            extra = not (re.match(sp[1], s))
        cond = ((re.match(exp, s)) and (extra)) or (cond)
    if cond:
        return 'UT'
    return False

def correct_tag(text, old_tag):
    for test in [set_punc, set_number, set_shorten, set_UTs]:
        new_tag = test(text)
        if new_tag and new_tag!= old_tag:
            return new_tag
    return False

### Small testing

In [4]:
correct_tag(".", 'CD')

'END'

# Rule generator

In [5]:
CACHE = {}

In [6]:
class GenRules:
    
    __slot__ = ['parser', 'allowed_words', 'paths' , '__cache']
    
    """
    sents -- list of strings, a list of sentences
    parser -- a nltk parser, benepar in this case
    allowed_words -- list of strings, this is used to identify which are grammars and which are vocabs
    """
    def __init__(self, paths, parser="stanfordcorenlp"):
        self.allowed_words = self.load_allowed_words()
        self.parser = parser
        self.paths = paths
        self.sents = self.load_sents(paths)
        self.__cache = {}
        self.missed = {}
        if parser == "stanfordcorenlp":
            self.parser = self.stanford_parser()
        else:
            self.parser = benepar.Parser("benepar_en")
        
    def stanford_parser(self):
        from stanfordcorenlp import StanfordCoreNLP
        # https://stanfordnlp.github.io/CoreNLP/
        return StanfordCoreNLP(r'C:\tools\stanford-corenlp-full-2018-02-27')
        
    def load_allowed_words(self):
        # with open('allowed_words.txt', 'r') as fh:
        #    return [line.strip() for line in fh]
        return pd.read_csv('Tagged_Vocab.gr', sep='\ ', comment='#', header=None, 
                           engine='python', names=['p','tag','words'])
        
    def load_sents(self, paths=[]):
        tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
        if isinstance(paths, str):
            paths = [paths]
        if len(paths) > 0:
            text = ""
            for path in paths:
                text += codecs.open(path, "r" , "utf8").read()
        return tokenizer.tokenize(text.strip())
    
    def parse(self, sents, paths):
        """ Parse sentances with selected parser
            Cache it if possible
        """
        if str(paths) in CACHE:
            return CACHE[str(paths)]
        result = []
        for sent in sents:
            try:
                r = self.parser.parse(sent)
                if isinstance(r, str):
                    r = Tree.fromstring(r)
                    r.chomsky_normal_form(horzMarkov=2)
                    r = ParentedTree.convert(r)
                    result.append(r)
                if isinstance(r, nltk.tree.Tree):
                    r.chomsky_normal_form(horzMarkov=2)
                    r = ParentedTree.convert(r)
                    result.append(r)
            except:
                print(sent)
        CACHE[str(paths)] = result
        return result
    
    def replace_tag(self, tree):
        """ Replace tree tag
        """
        replace_dict = {}
        aw = self.allowed_words
        for i in tree.pos():
            new_tag = correct_tag(i[0], i[1])
            if not aw[(aw['words']==i[0]) & ((aw['tag']==i[1]) | (aw['tag']==new_tag))].values.any():
                if not aw[aw['words']==i[0]].values.any():
                    continue
                self.missed[i] = tree
                new_tag = aw[aw['words']==i[0]].sort_values(['p'], ascending=False)['tag'].values[0]
            if new_tag:
                replace_dict[i] = new_tag
        if len(replace_dict) > 0:
            # print(tree)
            self.traverse_replace(tree, replace_dict)
            # print(tree)
    
    def traverse_replace(self, tree, replace_dict):
        for index, subtree in enumerate(tree):
            if isinstance(subtree, str):
                return
            if subtree.height() == 2:
                pos_tuple = subtree.pos()[0]
                if pos_tuple in replace_dict:
                    old_tag = pos_tuple[1]
                    new_tag = replace_dict[pos_tuple]
                    # Replace tag
                    subtree.set_label(new_tag)
                    parent_tag = subtree.parent().label()
                    right_sibling = subtree.right_sibling()
                    left_sibling = subtree.left_sibling()
                    if right_sibling is None and left_sibling is None:
                        pass
                    elif right_sibling is None: # At right side
                        old_tag = '-' + old_tag 
                        new_tag = '-' + new_tag
                    elif left_sibling is None: # At left side
                        old_tag += '-'
                        new_tag += '-'
                    else: # In middle
                        old_tag = '-' + old_tag + '-'
                        new_tag = '-' + new_tag + '-'
                    subtree.parent().set_label(parent_tag.replace(old_tag, new_tag))
            if isinstance(subtree, nltk.tree.Tree):
                self.traverse_replace(subtree, replace_dict)
        
    def find_rules(self):
        '''
        find the rules from the trees parsed PennTree parser
        
        exclusion_dict -- dictionary, contains the rules we want to exclude in S2
                          It is not none when isS1 is False
        '''
        sents_size = len(self.sents)
        print(f'================{sents_size} sentences in total================')
        num_finished = 0
        counts = defaultdict(lambda : defaultdict(int))
        isHeadFlag = True
        for tree in self.parse(self.sents, self.paths):
            self.replace_tag(tree)
                
            for production in tree.productions():
                counts[production.lhs().symbol()][production] += 1
            
            if not (num_finished+1) % int(sents_size*0.1):
                print(f'================{num_finished+1} sentences finished================')
            num_finished += 1
        return counts

In [7]:
%%time
# with open('example_sentences.txt') as fh:
#     sents = [line.strip() for line in fh]
g = GenRules(['example_sentences.txt',  'devset.txt', 'quotes_new_preprocessed.txt'])
s1_dict = g.find_rules()

Uh , so , uh , anything you can do to , uh , to help , would be ...
very ...
helpful ...
Look , can you tell us wh - Fine , um , I do n't want to waste anymore of your time , but , uh I do n't suppose you could , uh , tell us where we might find a , um , find a , uh , a , um , a uh -- A what ... ?
Wall time: 2min 8s


In [8]:
for k, v in g.missed.items():
    print(k)

('Whoa', 'RB')
('Saxons', 'NNPS')
('ridden', 'VBN')
('bangin', 'JJ')
('covered', 'VBD')
('found', 'VBD')
('Found', 'VB')
('fly', 'VB')
('matter', 'VB')
('second', 'NN')
('yeah', 'JJ')
('agree', 'VBP')
('use', 'VB')
('standard', 'JJ')
('next', 'JJ')
('Man', 'NNP')
('object', 'VBP')
('treat', 'VB')
('AM', 'VBP')
('outdated', 'JJ')
('imperialist', 'JJ')
("d'", 'VB')
('syndicalist', 'JJ')
('take', 'VBP')
('act', 'VB')
('executive', 'JJ')
('biweekly', 'JJ')
('order', 'VBP')
('Order', 'NNP')
('vote', 'VB')
('held', 'VBD')
('Well', 'UH')
('cause', 'VB')
('put', 'VB')
('shut', 'VB')
('away', 'RP')
('saw', 'VBD')
('fight', 'VBP')
('Court', 'NN')
('quarrel', 'VBP')
('move', 'VBP')
('pansy', 'VBP')
('mine', 'JJ')
('left', 'VBD')
('ere', 'FW')
('triumphs', 'VBZ')
('burn', 'VB')
('dressed', 'VBD')
('dress', 'VB')
('makes', 'VBZ')
('yeah', 'RB')
('Great', 'JJ')
('lead', 'VB')
('dub', 'VBP')
('Pure', 'NNP')
('stood', 'VBN')
('up', 'IN')
('formed', 'VBD')
('retold', 'VBN')
('learning', 'NN')
('amazes'

In [31]:
class t:
    def __init__ (self, total=0, count=0):
        self.total = total
        self.count = count
    
    def __iadd__(self, other):
        self.total += other.total
        self.count += other.count
        return self
    
    def __repr__(self):
        return f"({self.total}, {self.count})"

total = defaultdict(t)
for k, production in s1_dict.items():
    for prod, count in sorted(production.items(), key=lambda x: x[1], reverse=True):
        if isinstance(prod.rhs()[0], nltk.grammar.Nonterminal):
            total[k] += t(count, 1)

In [40]:
s1_dict['S']

defaultdict(int,
            {S -> NP S|<VP-END>: 548,
             S -> S S|<VP-END>: 6,
             S -> VP: 464,
             S -> NP VP: 638,
             S -> VP END: 228,
             S -> S S|<CC-S>: 18,
             S -> INTJ S|<PAUSE-NP>: 136,
             S -> S VP: 8,
             S -> FRAG S|<PAUSE-NP>: 4,
             S -> NP ADJP: 34,
             S -> SBAR S|<PAUSE-NP>: 24,
             S -> S S|<BREAK-SQ>: 4,
             S -> CC S|<NP-VP>: 42,
             S -> FRAG S|<BREAK-S>: 12,
             S -> IN S|<NP-VP>: 10,
             S -> NP-TMP S|<PAUSE-NP>: 26,
             S -> NP: 32,
             S -> S S|<PAUSE-NP>: 60,
             S -> S S|<BREAK-S>: 60,
             S -> NP S|<PAUSE-NP>: 26,
             S -> NP S|<QUOTE-NP>: 2,
             S -> QUOTE NP: 2,
             S -> ADVP S|<PAUSE-NP>: 44,
             S -> S S|<PAUSE-''>: 2,
             S -> X S|<X-NP>: 2,
             S -> SBAR VP: 2,
             S -> NP S|<ADVP-VP>: 36,
             S -> S S|<PAUS

In [38]:
remove = {}
for k, count in total.items():
    for prod, count in s1_dict[k].items():
        if 
        if count < int(total[prod].count * 0.25):
            if s in prod.rhs():
                remove[s.symbol()] = prod

30

In [47]:
for k, c in s1_dict['S'].items():
    for i in k.rhs():
        print(i.symbol())
    break

NP
S|<VP-END>


In [48]:
with open("S1_test.gr", "w+") as file:
    for k, production in s1_dict.items():
        file.write(f"#### {k} ####\n")
        for prod, count in sorted(production.items(), key=lambda x: x[1], reverse=True):
            if isinstance(prod.rhs()[0], nltk.grammar.Nonterminal):
                x = "{} {} {}\n".format(count, prod.lhs(), " ".join([ str(i) for i in prod.rhs()]))
                file.write(x)

In [None]:
with open('example_sentences.txt') as fh:
    sents = [line.strip() for line in fh]
with open('devset.txt') as fh:
    sents += [line.strip() for line in fh]

count_dict = dict()
count_dict = update_weight(sents, parse_gram, count_dict)
rule2gr('S1_updated.gr', 'S2_updated.gr', 'Vocab_updated.gr', count_dict, 'S1_raw.gr', 'S2_raw.gr')