# Disambiguating PropBank Rolesets with Neural Network Classifiers
The following code implements a TensorFlow model that distinguishes between [PropBank][1] rolesets for a given lemma. This task can essentially be treated as a word sense disambiguation problem, where the different rolesets correspond to different, coarse-grained word senses. This model is trained on sentences from the [Penn Treebank][2] and the [English Web Treebank][3], annotated under the [Universal Dependencies][4] framework.

## Abstract
Semantic role labeling (SRL) is the task of assigning labels to phrases that describe their relationship to the predicate of a sentence. PropBank, a text corpus used in SRL, defines frames for predicates (e.g., verbs) in the Penn Treebank, a syntactically annotated text corpus used in computational linguistics and natural language processing. In PropBank, a set of roles for a predicate is called a roleset, and a predicate might have multiple possible rolesets (e.g., the roleset for "firing an employee" vs. "firing a missile"). Disambiguating rolesets enables Army analysts to extract and filter relevant information from text sources. This work implements a 4-layer TensorFlow neural network classifier for each PropBank verb with more than one roleset. The classifiers are trained on sentences from the Penn Treebank and English Web Treebank, using features extracted from dependency-based syntactic n-grams. A synonym replacement strategy is used to augment the training data. The models use the Adam optimizer and are regularized with early stopping and 50% dropout for each hidden layer. The models are tuned on a development set and correctly classify 91.2% of ambiguous predicates in the development set and 90.1% of ambiguous predicates in the test set.

## Table of contents
1. [Setup](#Setup)
2. [Merging the PropBank and UD files](#Merging-the-Propbank-and-UD-files)
3. [Preparing the data](#Preparing-the-data)
4. [Building the classifier](#Building-the-classifier)
5. [Results](#Results)

## Setup
### Installing TensorFlow
Install and activate [the nightly build for TensorFlow in an Anaconda 3 environment][5].

`$ conda create -n tensorflow pip python=3.6`

`$ source activate tensorflow`

`(tensorflow)$ pip install tf-nightly`

### Other dependencies

Install [NLTK][6], [TensorFlow Hub][7], and [pandas][8]. Due to [a bug in the current version of pandas][9], an older version of pandas is needed.

`(tensorflow)$ pip install -U tensorflow_hub`

`(tensorflow)$ pip install -U nltk`

`(tensorflow)$ pip install pandas==0.22`

[1]: https://propbank.github.io/
[2]: https://repository.upenn.edu/cgi/viewcontent.cgi?article=1246&context=cis_reports
[3]: https://catalog.ldc.upenn.edu/LDC2012T13
[4]: http://universaldependencies.org/
[5]: https://www.tensorflow.org/install/install_linux#use_pip_in_anaconda
[6]: https://www.nltk.org/install.html
[7]: https://www.tensorflow.org/hub/
[8]: https://pandas.pydata.org/
[9]: https://stackoverflow.com/a/50836510
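
Before running the notebook, it can help to confirm that the packages installed above are visible to the active environment. The following is a minimal sketch using only the standard library; the helper name `missing_packages` is hypothetical and is not part of the original notebook.

```python
import importlib.util

# Top-level packages that the setup steps install (names as imported).
REQUIRED = ["tensorflow", "tensorflow_hub", "nltk", "pandas"]

def missing_packages(names):
    """Return the names that cannot be located on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_packages(REQUIRED)
print("Missing packages:", missing if missing else "none")
```

An empty result only shows that the packages can be found, not that their versions match the ones pinned above.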

In [1]:
import matplotlib.pyplot as plt
import nltk
from nltk.corpus import propbank
from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns
import shutil
import tensorflow as tf
import tensorflow_hub as hub
import zipfile


Download [WordNet from NLTK][1]. This only needs to be done the first time the program is run.

[1]: http://www.nltk.org/nltk_data/

In [33]:
#nltk.download('wordnet')

## Merging the Propbank and UD files
The following code merges PropBank roleset data with UD-annotated sentences from the English Web Treebank. These data are available on GitHub:

- [propbank-release](https://github.com/propbank/propbank-release)
- [UD_English-EWT](https://github.com/UniversalDependencies/UD_English-EWT)

The UD-annotated Penn Treebank is not publicly available, but this code can also be used on that data.


In [26]:
SKEL_REGEX = re.compile(r'^\S+\s+\d+\s+(\d+)\s+\[WORD\]\s+\S+\s+(?:\(*[A-Z]*)*\*\)*\s+\S+\s+(\S+)\s*([\S\s]*)$')
REL_REGEX = re.compile(r'\(V\*\)')
ARG_REGEX = re.compile(r'^\(([A-Z\d-]+)\*\)*$')
SPACE_REGEX = re.compile(r'\s+')
WORDLINE_REGEX = re.compile(r'([\d.]+)\s(\S+)\s(\S+)\s(\S+)\s(\S+)\s(\S+)\s([\d_]+)\s(\S+)\s(\S+)\s(\S+)')

class Role:
    def __init__(self, arg, start, end):
        self.arg = arg
        self.start = start
        self.end = end

    def get_headline(self, sentence):
        wordlines = sentence.wordlines
        wordline = wordlines[self.start]
        # go through wordlines within range
        head = int(wordline.head) - 1
        while head >= self.start and head <= self.end:
            wordline = wordlines[head]
            head = int(wordline.head) - 1
        return wordline, int(wordline.sid)-1

class Roleset:
    def __init__(self, sid=None, roles=None, roleset_id=None):
        self.sid = sid if sid is not None else ''
        self.roles = roles if roles is not None else []
        self.roleset_id = roleset_id if roleset_id is not None else ''

class Sentence:
    def __init__(self, rolesets=None, wordlines=None):
        self.rolesets = rolesets if rolesets is not None else []
        self.wordlines = wordlines if wordlines is not None else []
        
class Skelline:
    def __init__(self, index, roleset_id, args):
        self.index = index
        self.roleset_id = roleset_id
        self.args = args
        
class Wordline:
    def __init__(self, line):
        matcher = WORDLINE_REGEX.match(line)
        if matcher:
            self.sid = matcher.group(1)
            self.form = matcher.group(2)
            self.lemma = matcher.group(3)
            self.upos = matcher.group(4)
            self.xpos = matcher.group(5)
            self.feats = matcher.group(6)
            self.head = matcher.group(7)
            self.deprel = matcher.group(8)
            self.deps = matcher.group(9)
            self.misc = matcher.group(10)
    
    def to_str(self):
        return self.sid + "\t" +self.form + "\t" +self.lemma + "\t" +self.upos + "\t" +self.xpos + "\t" +self.feats + "\t" +self.head + "\t" +self.deprel + "\t" +self.deps + "\t" +self.misc + "\n"

def count_parentheses(s):
    return s.count(')') - s.count('(')

def process_skel_line(matcher):
    index = int(matcher.group(1))
    roleset_id = matcher.group(2)
    arg_str = matcher.group(3).strip()
    args = SPACE_REGEX.split(arg_str) if len(arg_str) > 0 else []
    return Skelline(index, roleset_id, args)

# given a list of skel lines for a sentence, output a list of rolesets
def get_rolesets_from_skel(skel_lines):
    # a list of lists, where each list represents a column in the skel file
    arg_list = []
    # the rolesets in the skel file
    rolesets = []
    # number of columns
    length = len(skel_lines[0].args)
    # initialize the arglist and rolesets list
    for i in range(length):
        arg_list.append([])
        rolesets.append(Roleset())
    
    # traverse each line
    for skel_line in skel_lines:
        # get the roleset id
        roleset_id = skel_line.roleset_id
        # go through each column
        for j in range(length):
            # set arg to current arg
            arg = skel_line.args[j]
            # if arg is rel and roleset_id is not null, set corresponding roleset_id
            if REL_REGEX.match(arg) and roleset_id != '_':
                rolesets[j].sid = roleset_id
                rolesets[j].index = skel_line.index
            # add this arg to arglist at list j
            arg_list[j].append(skel_line.args[j])
    for i in range(len(arg_list)):
        # go through skelarg lists and determine ranges
        arg_name = ''
        count = 0
        start = 0
        end = 0
        # step through items in column
        l = arg_list[i]
        for j in range(len(l)):
            l_str = l[j]
            matcher = ARG_REGEX.match(l_str)
            if matcher:
                arg_name = matcher.group(1)
                start = j
            count += count_parentheses(l_str)
            if count == 0 and l_str != '*':
                end = j
                rolesets[i].roles.append(Role(arg_name, start, end))
    return rolesets

# reads in a .gold_skel file and returns a list of Sentence objects
def get_sentences_from_skel(skel_path):
    skel_lines = []
    sentences = []
    
    for line in skel_path:
        matcher = SKEL_REGEX.match(line)
        # if matches
        if matcher: skel_lines.append(process_skel_line(matcher))
        # else, we've reached the end of a sentence
        else:
            sentence = Sentence()
            sentence.rolesets = get_rolesets_from_skel(skel_lines[:])
            sentences.append(sentence)
            skel_lines = []
    return sentences

# adds wordlines from ud file to each sentence in sentences list
def add_wordlines_from_ud_file(ud_path, sentences):
    i = 0
    # read ud file and add wordlines
    for line in ud_path:
        if line[0] != '#':
            line_matcher = WORDLINE_REGEX.match(line)
            # if line matches, add it to current sentence's wordline list
            if line_matcher:
                if '.' not in line_matcher.group(1): sentences[i].wordlines.append(Wordline(line))
            # else, move onto the next sentence, if there is one
            else:
                i += 1
                if i >= len(sentences): break
    return sentences

def create_annotated_sentence_string(sentence):
    for roleset in sentence.rolesets:
        for role in roleset.roles:
            wordline, i = role.get_headline(sentence)
            if role.arg == 'V':
                sentence.wordlines[i].misc += '\t' + roleset.sid
                break
    return sentence

# creates the augmented ud file
def create_output_file(file_list, sentences):
    with open(file_list + '_ewt.txt', 'w') as f:
        # get head for each role in sentence
        for sentence in sentences:
            sentence = create_annotated_sentence_string(sentence)
            for wordline in sentence.wordlines:
                w_str = wordline.to_str()
                f.write(w_str)

PROPBANK = 'data/propbank-release-master/data/google/ewt/'
UD = 'data/UD_English-EWT/not-to-release/'
FILE_LISTS = ['test']

for file_list in FILE_LISTS:
    file_name = UD + 'file-lists/files.' + file_list
    with open(file_name, 'r') as f:
        sentences = []
    
        for line in f:
            # get .gold_skel file and .conllu file
            skel_file = PROPBANK + line[:-7] + 'gold_skel'
            ud_file = UD + 'sources/' + line.strip()
            # if they both exist, process
            if os.path.exists(skel_file) and os.path.exists(ud_file):
                # get sentences from file
                with open(skel_file, 'r') as skel_path: curr_sentences = get_sentences_from_skel(skel_path)
                with open(ud_file, 'r') as ud_path:
                    sentences += add_wordlines_from_ud_file(ud_path, curr_sentences)
        create_output_file(file_list, sentences)
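
The merge above relies on two ideas: recovering argument spans from the bracketed columns of a `.gold_skel` file, and finding the syntactic head of each span. The sketch below illustrates both on a toy sentence. It is a simplified, stack-based variant rather than the notebook's own implementation, and the function names `column_to_spans` and `span_head` are hypothetical.

```python
import re

# Matches the label that follows an opening bracket, e.g. "(ARG0*" -> "ARG0".
OPEN_RE = re.compile(r"\(([A-Z\d-]+)")

def column_to_spans(column):
    """Convert one bracketed argument column into (label, start, end) spans.

    Indices are 0-based and inclusive.
    """
    spans = []
    stack = []  # (label, start) for brackets that are open but not yet closed
    for i, cell in enumerate(column):
        for label in OPEN_RE.findall(cell):
            stack.append((label, i))
        for _ in range(cell.count(")")):
            label, start = stack.pop()
            spans.append((label, start, i))
    return spans

def span_head(heads, start, end):
    """Follow head links from `start` until the head falls outside the span.

    `heads` holds 1-based head indices (0 for the root), as in CoNLL-U.
    Returns the 0-based index of the span's head token.
    """
    i = start
    while start <= heads[i] - 1 <= end:
        i = heads[i] - 1
    return i

# Toy sentence: "The company fired him"
column = ["(ARG0*", "*)", "(V*)", "(ARG1*)"]
heads = [2, 3, 0, 3]
print(column_to_spans(column))
print(span_head(heads, 0, 1))
```

In this toy example the `ARG0` span covers "The company", and its head is "company" because that token's head ("fired") lies outside the span.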
                    

## Preparing the data

Initialize the following variables:
- the filename prefixes of the training and dev data.
- the NLTK WordNet lemmatizer.
- the regex to capture each line in the annotation data text files. The annotation format, CoNLL-U, is described [here][1].

Define a class to store the line data.

[1]: http://universaldependencies.org/format.html
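
As an illustration of what the line regex captures, the sketch below parses a token line by splitting on tabs instead. It assumes the merged file layout produced in the previous section: the ten standard CoNLL-U columns, optionally followed by an eleventh column holding the roleset ID. The `Token` type and `parse_conllu_line` function are hypothetical names, not part of the notebook.

```python
from typing import NamedTuple, Optional

class Token(NamedTuple):
    index: int              # 1-based position in the sentence (ID column)
    form: str               # surface form (FORM column)
    xpos: str               # Penn Treebank tag (XPOS column)
    head: int               # 1-based index of the head, 0 for the root
    deprel: str             # dependency relation to the head
    roleset: Optional[str]  # roleset ID from the extra column, if present

def parse_conllu_line(line):
    """Parse one token line.

    Returns None for comments, blank lines, and lines whose ID or HEAD is
    not a plain integer (e.g. multiword tokens '1-2' or empty nodes '1.1').
    """
    line = line.rstrip("\n")
    if not line or line.startswith("#"):
        return None
    cols = line.split("\t")
    if len(cols) < 10 or not cols[0].isdigit() or not cols[6].isdigit():
        return None
    roleset = cols[10] if len(cols) > 10 else None
    return Token(int(cols[0]), cols[1], cols[4], int(cols[6]), cols[7], roleset)

print(parse_conllu_line("2\tfired\tfire\tVERB\tVBD\t_\t0\troot\t_\t_\tfire.02\n"))
```

Unlike the regex, this version does not strip relation subtypes such as `nsubj:pass`; the notebook does that later with `sr.split(':')[0]`.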

In [7]:
FILENAMES = ['train_combine']
LEMMATIZER = WordNetLemmatizer()
LINE_REGEX = re.compile(r'^(\d+)\t(\S+)\t\S+\t\S+\t(\S+)\t\S+\t(\d+)\t([a-zROT:]+)(?:\S+)?\t\S+\t\S+(?:(?:\t(\S+))?)?')

class Line:
    def __init__(self, index, form, pos, head, sr, roleset):
        self.index = index
        self.form = form
        self.pos = pos
        self.head = head
        self.sr = sr
        self.roleset = roleset

Go through the train and dev data and find the sentences in which the defined lemma occurs. From there we can extract the desired features. [Mohammad and Pedersen (2004)][1] evaluate different lexical and syntactic features used in supervised word sense disambiguation. The features chosen here are inspired by this work but use [sn-grams][2] instead of n-grams.


The features used here are:

- the syntactic trigram centered on the target word.
- the part-of-speech (POS) tags of each word in the trigram.
- the syntactic relation (SR) tags of each word in the trigram.
- the [WordNet lexicographer file name][3] for the two context words in the trigram.

The label is the number associated with the roleset ID (e.g., the label associated with `be.01` is `'1'`).
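
To make the syntactic trigram concrete, the sketch below picks, on each side of the target word, the nearest token that is either the target's head or one of its dependents. It is a minimal illustration on a toy parse; the function name `syntactic_trigram` and the example sentence are assumptions, and only the `'<<START>>'` and `'<<END>>'` placeholders are taken from the notebook code.

```python
def syntactic_trigram(forms, heads, target):
    """Return (prev, target, next) for the word at 0-based index `target`.

    forms: list of word forms.
    heads: 1-based head index for each token (0 for the root).
    """
    def linked(i):
        # Token i is the head of the target, or the target is the head of token i.
        return heads[target] == i + 1 or heads[i] == target + 1

    # Nearest linked token to the left, scanning backwards from the target.
    prev_form = next(
        (forms[i] for i in range(target - 1, -1, -1) if linked(i)), "<<START>>")
    # Nearest linked token to the right, scanning forwards from the target.
    next_form = next(
        (forms[i] for i in range(target + 1, len(forms)) if linked(i)), "<<END>>")
    return prev_form, forms[target], next_form

forms = ["The", "company", "quickly", "fired", "him", "yesterday"]
heads = [2, 4, 4, 0, 4, 4]
print(syntactic_trigram(forms, heads, 3))
```

For the target "fired", the left context is "quickly" and the right context is "him", since both are direct dependents of the verb.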

### Special cases
#### Target words at the beginning or end of a sentence
The trigram is truncated, and the POS and SR tags are assigned as `'START'` or `'END'`.
#### Lemmatized token matches given lemma, but has no roleset
The example is not used. (Alternatively, the roleset ID could be set to `'0'`.)

#### Roleset ID is the non-numerical value `'LV'`
The roleset ID is set to `n_rolesets + 1`.
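
The label rules above can be summarized in a small function. This is a sketch under the assumption that `n_rolesets` is already known; the notebook instead updates `n_rolesets` while reading the data and fixes the `'LV'` labels afterwards. The name `roleset_label` is hypothetical.

```python
def roleset_label(roleset_id, n_rolesets):
    """Map a roleset ID to its label string.

    'be.01'   -> '1'
    'take.LV' -> str(n_rolesets + 1)
    Returns None when there is no usable roleset ID (the example is skipped).
    """
    if roleset_id is None or "." not in roleset_id:
        return None
    suffix = roleset_id.rsplit(".", 1)[1]
    if suffix == "LV":
        return str(n_rolesets + 1)
    return str(int(suffix)) if suffix.isdigit() else None

print(roleset_label("be.01", 3), roleset_label("take.LV", 3))
```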

### Data augmentation

Data paucity is a problem for many of the lemmata in the training data. [Zhang and LeCun (2015)][4] use a synonym replacement strategy to augment text data. Augment the training set by looking up the synonyms for the context words in each training example.

[1]: http://www.aclweb.org/anthology/W04-2404
[2]: http://www.cic.ipn.mx/~sidorov/Synt_n_grams_ESWA_FINAL.pdf
[3]: https://wordnet.princeton.edu/documentation/lexnames5wn
[4]: https://arxiv.org/pdf/1502.01710.pdf#page=3
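
The sketch below shows the general shape of synonym replacement on a single trigram. It uses a small hand-written synonym table in place of WordNet so that it runs without any downloads; `TOY_SYNONYMS` and `augment` are hypothetical names. As in the commented-out augmentation loop in the code cell below, a new example is created for every pair of context-word synonyms, and the target word and label are left unchanged.

```python
from itertools import product

# Hypothetical stand-in for a WordNet lookup.
TOY_SYNONYMS = {
    "worker": ["employee", "laborer"],
    "boss": ["manager"],
}

def augment(example, synonyms=TOY_SYNONYMS):
    """Return new (prev, target, next, label) examples built from synonym pairs."""
    prev_word, target, next_word, label = example
    prev_options = synonyms.get(prev_word.lower(), [])
    next_options = synonyms.get(next_word.lower(), [])
    # One new example per (prev synonym, next synonym) pair.
    return [(p, target, n, label) for p, n in product(prev_options, next_options)]

print(augment(("boss", "fired", "worker", "1")))
```

With this pairing, a context word that has no synonyms (including the `'<<START>>'` and `'<<END>>'` placeholders) yields no augmented examples. Keeping the original word as one of the options on each side would be one way to relax that.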

In [34]:
def is_single_token(lemma_name):
    return '_' not in lemma_name and ' ' not in lemma_name

def is_not_same_word(x, y):
    return x.lower() != y.lower()

def is_comparative(pos):
    return pos[-1] == 'R'

def is_superlative(pos):
    return pos[-1] == 'S'

def return_self(word):
    return word

def adverb_has_pertainym(synsets):
    if len(synsets) > 0:
        lemmas = synsets[0].lemmas()
        if len(lemmas) > 0:
            if len(lemmas[0].pertainyms()) > 0:
                return True
    return False

def get_syns(function, word, pos, wn_pos):
    synonyms = set()
    
    lemmatized_word = LEMMATIZER.lemmatize(word, wn_pos)           
    synsets = wn.synsets(lemmatized_word, wn_pos)
    if wn_pos == wn.ADV:
        if lemmatized_word == word and adverb_has_pertainym(synsets):
            lemmatized_word = synsets[0].lemmas()[0].pertainyms()[0].name()
            synsets = wn.synsets(lemmatized_word, wn.ADJ)
        else:
            return synonyms
    for synset in synsets:
        for lemma in synset.lemmas():
            if wn_pos == wn.VERB:
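                # NOTE: `en` is not imported anywhere in this notebook. It must be an
                # inflection module that provides conjugate, comparative, superlative,
                # pluralize, and singularize (possibly pattern.en; this is an assumption).
                synonyms.add((en.conjugate(lemma.name(), pos),synset.lexname()))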
            else:
                synonyms.add((function(lemma.name()),synset.lexname()))
    return synonyms

def get_synonyms(word, pos):
    synonyms = set()
    is_adj = pos[0] == 'J'
    is_common_noun = pos == 'NN' or pos == 'NNS'
    is_adv = pos[:2] == 'RB'
    is_verb = pos[0] == 'V'
    function = return_self
    wn_pos = wn.ADV
    
    if is_adj or is_adv:
        if is_comparative(pos):
            function = en.comparative
        elif is_superlative(pos):
            function = en.superlative
        if is_adj:
            wn_pos = wn.ADJ
    elif is_common_noun:
        if pos[-1] == 'S':
            function = en.pluralize
        else:
            function = en.singularize
        wn_pos = wn.NOUN
    elif is_verb:
        wn_pos = wn.VERB
    else:
        return []
    
    synonyms = get_syns(function, word, pos, wn_pos)
    return [x for x in synonyms if is_single_token(x[0]) and is_not_same_word(x[0], word)]

def get_lexname(word, pos):
    first_char = pos[0]
    wn_pos = None
    if first_char == 'J': wn_pos = wn.ADJ
    elif first_char == 'N': wn_pos = wn.NOUN
    elif first_char == 'V': wn_pos = wn.VERB
    elif first_char == 'R': wn_pos = wn.ADV
    else: return ' '
    
    lemmatized = LEMMATIZER.lemmatize(word, wn_pos)
    synsets = wn.synsets(lemmatized, wn_pos)
    if len(synsets) > 0:
        return synsets[0].lexname()
    return ' '

def get_wn(token):
    first_char = token.pos[0]
    pos = ''
    if first_char == 'J': pos = wn.ADJ
    elif first_char == 'N': pos = wn.NOUN
    elif first_char == 'V': pos = wn.VERB
    elif first_char == 'R': pos = wn.ADV
    else: return ' ', ' ', ' '
    
    synsets = wn.synsets(token.form, pos=pos)
    if len(synsets) > 0:
        defn = synsets[0].definition()
        lexname = synsets[0].lexname()
        lemmata = [' ']
        for synset in synsets[0].hypernyms():
            for lemma in synset.lemmas():
                lemmata += lemma.key().split('%')[0].split('_')
        return defn, ' '.join(set(lemmata)), lexname 
    else: return ' ', ' ', ' '

def in_aliases(word_form, pos, lemma, rolesets):
    lemmatized = LEMMATIZER.lemmatize(word_form.lower(), pos)
    for alias in rolesets[0][0]:
        if lemmatized == alias.text: return True
    return False

def create_files(lemma):
    xml_rolesets = propbank.rolesets(lemma)
    n_rolesets = len(xml_rolesets)
    for filename in FILENAMES:
        prev_pos = []
        target_pos = []
        next_pos = []

        prev_sr = []
        target_sr = []
        next_sr = []

        prev_defn = []
        next_defn = []

        prev_hypernym = []
        next_hypernym = []

        prev_lexname = []
        next_lexname = []
        
        
        #ngrams = []
        prev_word = []
        target_word = []
        next_word = []
        rolesets = []
        lv = []
        offset = 0

        with open(filename + '.txt', 'r') as f:
            line_num = 1
            len_rolesets = 0
            sentence = []
            training_examples = []
            for line in f:
                result = LINE_REGEX.match(line)
                if result:
                    # get match groups
                    index = int(result.group(1))
                    form = result.group(2)
                    pos = result.group(3)
                    head = int(result.group(4))
                    sr = result.group(5)
                    roleset = result.group(6)
                    # if new sentence, process old sentence and set beginning of new sentence
                    if index < line_num:
                        for i, word in enumerate(sentence):
                            first_char = word.pos[0].lower()
                            # TODO: adjectives
                            if first_char in ['n', 'v'] and in_aliases(word.form, first_char, lemma, xml_rolesets) and word.roleset is not None and '.' in word.roleset:
                                #if word.roleset == None:
                                #    rolesets.append('0')
                                #else:
                                arr = word.roleset.split('.')
                                num = word.roleset.split('.')[1]
                                if num == 'LV':
                                    rolesets.append(str(n_rolesets+1))
                                    # index of the label just appended (len_rolesets is incremented below)
                                    lv.append(len_rolesets)
                                else:
                                    num = int(num)
                                    rolesets.append(str(num))
                                    if num > n_rolesets:
                                        n_rolesets = num
                                len_rolesets += 1
                                
                                #form_list = [word.form]
                                prev_form = '<<START>>'
                                target_word.append(word.form)
                                next_form = '<<END>>'
                                # target word
                                target_pos.append(word.pos)
                                target_sr.append(word.sr.split(':')[0])

                                #Create the sn-gram
                                curr_i = i-1
                                while curr_i >= 0:
                                    #traverse backwards
                                    #if token is head or child of curr
                                    prev_tok = sentence[curr_i]
                                    if prev_tok.index == word.head or prev_tok.head == word.index:
                                        #form_list.insert(0, prev_tok.form)
                                        prev_form = prev_tok.form
                                        prev_pos.append(prev_tok.pos)
                                        prev_sr.append(prev_tok.sr.split(':')[0])
                                        p_defn, p_hypernym, p_lexname = get_wn(prev_tok)
                                        prev_defn.append(p_defn)
                                        prev_hypernym.append(p_hypernym)
                                        prev_lexname.append(p_lexname)
                                        break
                                    else: curr_i -= 1
                                if curr_i < 0:
                                    prev_pos.append('START')
                                    prev_sr.append('START')
                                    prev_defn.append(' ')
                                    prev_hypernym.append(' ')
                                    prev_lexname.append('START')

                                curr_i = i+1
                                while curr_i < len(sentence):
                                    #traverse forwards
                                    #if token is head or child of curr
                                    next_tok = sentence[curr_i]
                                    if next_tok.index == word.head or next_tok.head == word.index:
                                        #form_list.append(next_tok.form)
                                        next_form = next_tok.form
                                        next_pos.append(next_tok.pos)
                                        next_sr.append(next_tok.sr.split(':')[0])
                                        n_defn, n_hypernym, n_lexname = get_wn(next_tok)
                                        next_defn.append(n_defn)
                                        next_hypernym.append(n_hypernym)
                                        next_lexname.append(n_lexname)
                                        break
                                    else: curr_i += 1
                                if curr_i == len(sentence):
                                    next_pos.append('END')
                                    next_sr.append('END')
                                    next_defn.append(' ')
                                    next_hypernym.append(' ')
                                    next_lexname.append('END')

                                #ngrams.append(' '.join(form_list))
                                prev_word.append(prev_form)
                                next_word.append(next_form)
                                
                                #data augmentation
                                #if filename == 'train_combine':
                                #    for prev_paraphrase in get_paraphrases(prev_word[-1], prev_pos[-1]):
                                #        for next_paraphrase in get_paraphrases(next_word[-1], next_pos[-1]):
                                    #for prev_synonym in get_synonyms(prev_word[-1], prev_pos[-1]):
                                    #    for next_synonym in get_synonyms(next_word[-1], next_pos[-1]):
                                #            if rolesets[-1] == n_rolesets+1:
                                #                lv.append(len_rolesets)
                                #            rolesets.append(rolesets[-1])
                                        
                                #            prev_word.append(prev_paraphrase)
                                #            target_word.append(target_word[-1])
                                #            next_word.append(next_paraphrase)
                                                                                         
                                #            prev_pos.append(prev_pos[-1])
                                #            target_pos.append(target_pos[-1])
                                #            next_pos.append(next_pos[-1])

                                #            prev_sr.append(prev_sr[-1])
                                #            target_sr.append(target_sr[-1])
                                #            next_sr.append(next_sr[-1])

                                #            prev_defn.append(prev_defn[-1])
                                #            next_defn.append(next_defn[-1])

                                #            prev_hypernym.append(prev_hypernym[-1])
                                #            next_hypernym.append(prev_hypernym[-1])

                                #            prev_lexname.append(get_lexname(prev_paraphrase, prev_pos[-1]))
                                #            next_lexname.append(get_lexname(next_paraphrase, next_pos[-1]))
                                
                        # reset sentence
                        sentence = [Line(index, form, pos, head, sr, roleset)]
                    # else, add to existing sentence
                    else: sentence.append(Line(index, form, pos, head, sr, roleset))
                    line_num = index

        #add offset to LVs
        for i in lv: rolesets[i] = str(n_rolesets+1)

        with open('processed/' + lemma + '_' + filename + '_examples_unaugmented.txt', 'w') as examples:
            for i in range(len(rolesets)):
                examples.write('\t'.join([prev_word[i], target_word[i], next_word[i], prev_pos[i], target_pos[i], next_pos[i], prev_sr[i], target_sr[i], next_sr[i], prev_defn[i], next_defn[i], prev_hypernym[i], next_hypernym[i], prev_lexname[i], next_lexname[i], rolesets[i]]) + '\n')
                #examples.write('\t'.join([ngrams[i], prev_pos[i], target_pos[i], next_pos[i], prev_sr[i], target_sr[i], next_sr[i], prev_defn[i], next_defn[i], prev_hypernym[i], next_hypernym[i], prev_lexname[i], next_lexname[i], rolesets[i]]) + '\n')

## Building the classifier

All code below this point is run outside the notebook, as a separate `classifier.py` Python file, because the events files [currently do not save correctly in a Jupyter notebook](https://stackoverflow.com/a/51232738). The classifier can be run on the command line:

`python classifier.py`

TensorFlow defines a class called an [Estimator][1], which is used to train and evaluate models. This model uses a premade Estimator called a [DNNClassifier][2]. The code is derived from [this tutorial][3].

Define a function to load data into a DataFrame.

[1]: https://www.tensorflow.org/versions/master/api_docs/python/tf/estimator/Estimator
[2]: https://www.tensorflow.org/versions/master/api_docs/python/tf/estimator/DNNClassifier
[3]: https://www.tensorflow.org/hub/tutorials/text_classification_with_tf_hub

In [2]:
def load_data(filename):
    n_rolesets = 1
    
    data = {}
    #data['ngram'] = []
    data['prev_word'] = []
    data['target_word'] = []
    data['next_word'] = []
    
    data['prev_pos'] = []
    data['target_pos'] = []
    data['next_pos'] = []
    
    data['prev_sr'] = []
    data['target_sr'] = []
    data['next_sr'] = []
    
    #data['prev_defn'] = []
    #data['next_defn'] = []
    
    #data['prev_hypernym'] = []
    #data['next_hypernym'] = []
    
    data['prev_lexname'] = []
    data['next_lexname'] = []
    
    data['roleset'] = []
    with open(filename, 'r') as f:
            prog = re.compile(r'^(\S+)\t(\S+)\t(\S+)\t(\S+)\t(\S+)\t(\S+)\t(\S+)\t(\S+)\t(\S+)\t([\S ]+)\t([\S ]+)\t([\S ]+)\t([\S ]+)\t([\S ]+)\t([\S ]+)\t(\d+)$')
            for line in f:
                result = prog.match(line)
                if result:
                    data['prev_word'].append(result.group(1))
                    data['target_word'].append(result.group(2))
                    data['next_word'].append(result.group(3))
                    
                    data['prev_pos'].append(result.group(4))
                    data['target_pos'].append(result.group(5))
                    data['next_pos'].append(result.group(6))
                    data['prev_sr'].append(result.group(7))
                    data['target_sr'].append(result.group(8))
                    data['next_sr'].append(result.group(9))
                    
                    #data['prev_defn'].append(result.group(10))
                    #data['next_defn'].append(result.group(11))
                    
                    #data['prev_hypernym'].append(result.group(12))
                    #data['next_hypernym'].append(result.group(13))
                    
                    data['prev_lexname'].append(result.group(14))
                    data['next_lexname'].append(result.group(15))
                    
                    roleset = int(result.group(16))
                    if roleset > n_rolesets: n_rolesets = roleset
                    
                    data['roleset'].append(roleset)
    return pd.DataFrame.from_dict(data), n_rolesets
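
Each line of an examples file holds 16 tab-separated fields, in the order written by `create_files`. The sketch below parses one such line by splitting on tabs, which can be handy for spot-checking a file before training. It is more lenient than the regex in `load_data` (for instance, it accepts empty fields), and the names `FIELDS` and `parse_example` are hypothetical.

```python
# Field order as written to the examples file.
FIELDS = [
    "prev_word", "target_word", "next_word",
    "prev_pos", "target_pos", "next_pos",
    "prev_sr", "target_sr", "next_sr",
    "prev_defn", "next_defn",
    "prev_hypernym", "next_hypernym",
    "prev_lexname", "next_lexname",
    "roleset",
]

def parse_example(line):
    """Return a dict for one example line, or None if the line is malformed."""
    values = line.rstrip("\n").split("\t")
    if len(values) != len(FIELDS) or not values[-1].isdigit():
        return None
    record = dict(zip(FIELDS, values))
    record["roleset"] = int(record["roleset"])
    return record

sample = "\t".join([
    "company", "fired", "him", "NN", "VBD", "PRP", "nsubj", "root", "obj",
    "an institution", " ", " institution", " ", "noun.group", " ", "2",
]) + "\n"
print(parse_example(sample)["roleset"])
```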

### Creating the feature columns

The model uses [feature columns][1] to transform the data into a format compatible with the Estimator.

Define the possible labels for POS tags. This consists of the [Penn Treebank POS tags][2], plus the special `'START'` and `'END'` tags and some punctuation tags.

Define the possible labels for the SR tags. This consists of the [Universal Dependencies relations][3], plus the special `'neg'`, `'START'`, `'END'`, and `'ROOT'` tags.

[1]: https://www.tensorflow.org/guide/feature_columns
[2]: https://repository.upenn.edu/cgi/viewcontent.cgi?article=1246&context=cis_reports#page=8
[3]: http://universaldependencies.org/u/dep/

In [3]:
pos_list = ['START', 'END', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNS', 'NNP', 'NNPS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', 'HYPH', '\'\'', '.', ':', ',', ')', '(', '``', '$' ]
sr_list = ['START', 'END', 'ROOT', 'neg', 'acl', 'advcl', 'advmod', 'amod', 'appos', 'aux', 'case', 'cc', 'ccomp', 'clf', 'compound', 'conj', 'cop', 'csubj', 'dep', 'det', 'discourse', 'dislocated', 'expl', 'fixed', 'flat', 'goeswith', 'iobj', 'list', 'mark', 'nmod', 'nsubj', 'nummod', 'obj', 'obl', 'orphan', 'parataxis', 'punct', 'reparandum', 'root', 'vocative', 'xcomp']
lexname_list = ['START', 'END', ' ', 'adj.all', 'adj.pert', 'adv.all', 'noun.Tops', 'noun.act', 'noun.animal', 'noun.artifact', 'noun.attribute', 'noun.body', 'noun.cognition', 'noun.communication', 'noun.event', 'noun.feeling', 'noun.food', 'noun.group', 'noun.location', 'noun.motive', 'noun.object', 'noun.person', 'noun.phenomenon', 'noun.plant', 'noun.possession', 'noun.process', 'noun.quantity', 'noun.relation', 'noun.shape', 'noun.state', 'noun.substance', 'noun.time', 'verb.body', 'verb.change', 'verb.cognition', 'verb.communication', 'verb.competition', 'verb.consumption', 'verb.contact', 'verb.creation', 'verb.emotion', 'verb.motion', 'verb.perception', 'verb.possession', 'verb.social', 'verb.stative', 'verb.weather', 'adj.ppl']

Load the training and test data into DataFrames.

Represent the POS and SR tags in [categorical vocabulary columns](https://www.tensorflow.org/guide/feature_columns#categorical_vocabulary_column), which are essentially one-hot vectors. The model should learn some relationships between the categories, since some of them are related (e.g., adjectives, nouns, adverbs, verbs, etc. for POS; nominals, clauses, etc. for SR), so wrap the categorical columns in a lower-dimension [embedding column](https://www.tensorflow.org/guide/feature_columns#indicator_and_embedding_columns).
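
The relationship between a one-hot vector and an embedding can be shown without TensorFlow. In the toy sketch below, looking up a row of the embedding table gives the same result as multiplying the one-hot vector by the table. The table is filled with random values here, whereas in the model the values are learned during training; all function names are hypothetical.

```python
import random

def one_hot(tag, vocabulary):
    """Return a vector with 1.0 at the position of `tag` and 0.0 elsewhere."""
    return [1.0 if tag == v else 0.0 for v in vocabulary]

def make_embedding_table(vocabulary, dimension, seed=0):
    """Create one randomly initialized row of length `dimension` per vocabulary entry."""
    rng = random.Random(seed)
    return [[rng.uniform(-1, 1) for _ in range(dimension)] for _ in vocabulary]

def embed(tag, vocabulary, table):
    """Look up the embedding row for `tag`."""
    return table[vocabulary.index(tag)]

def vector_times_table(vector, table):
    """Multiply a row vector by the table (a plain matrix product)."""
    dimension = len(table[0])
    return [sum(vector[i] * table[i][d] for i in range(len(table)))
            for d in range(dimension)]

vocabulary = ["START", "END", "NN", "VB"]
table = make_embedding_table(vocabulary, dimension=2)
print(embed("NN", vocabulary, table))
print(vector_times_table(one_hot("NN", vocabulary), table))
```

The one-hot vector has as many entries as the vocabulary, while the embedding has only `dimension` entries, which is what allows related tags to end up with similar representations.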

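The `ExampleCheckpointSaverListener` class below extends [tf.train.CheckpointSaverListener](https://www.tensorflow.org/api_docs/python/tf/train/CheckpointSaverListener). Before each checkpoint is saved, it evaluates the estimator on the development set and tracks the best accuracy seen so far.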

Define a function to train and evaluate a model using a given [text module](https://www.tensorflow.org/hub/modules/text). These modules define text embeddings that can be used to represent our trigrams in a feature column.
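The classifier receives all feature columns concatenated into one input vector. The arithmetic below is a small sketch of how wide that vector is, using the embedding dimensions from the training function in this section (5 for POS, 12 for SR, 4 for lexnames). The text embedding width depends on the module that is loaded and is not stated here, so the value 128 is only a placeholder assumption, and `input_width` is a hypothetical helper.

```python
def input_width(text_dim,
                n_text=3, n_pos=3, pos_dim=5,
                n_sr=3, sr_dim=12,
                n_lexname=2, lexname_dim=4):
    """Width of the concatenated input: one embedding per feature key."""
    return (n_text * text_dim        # prev_word, target_word, next_word
            + n_pos * pos_dim        # prev_pos, target_pos, next_pos
            + n_sr * sr_dim          # prev_sr, target_sr, next_sr
            + n_lexname * lexname_dim)  # prev_lexname, next_lexname

# 128 is an assumed text embedding width, not a value taken from the notebook.
print(input_width(text_dim=128))
```

With these numbers the categorical features contribute 15 + 36 + 8 = 59 dimensions, so the word embeddings dominate the input unless the text module is very small.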

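The model implements [early stopping](https://github.com/tensorflow/tensorflow/issues/18394): a hook halts training when the monitored loss has not decreased for 8,000 steps, and the saved checkpoint with the highest development-set accuracy is then used for the final evaluation.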
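The stopping rule can be sketched in plain Python as a patience check over a history of evaluation results. This is a simplified illustration under stated assumptions, not the library's implementation. The function name `should_stop`, the history format, and the loss values are hypothetical.

```python
def should_stop(eval_history, max_steps_without_decrease, min_steps=0):
    """Decide whether to stop, given (global_step, loss) pairs in step order.

    Training stops once the latest evaluation is at least
    `max_steps_without_decrease` steps past the evaluation with the
    lowest loss, but never before `min_steps`.
    """
    if not eval_history:
        return False
    latest_step = eval_history[-1][0]
    if latest_step < min_steps:
        return False
    best_step, best_loss = None, None
    for step, loss in eval_history:
        if best_loss is None or loss < best_loss:
            best_step, best_loss = step, loss
    return latest_step - best_step >= max_steps_without_decrease

# Made-up loss values, evaluated every 2000 steps.
history = [(2000, 1.00), (4000, 0.80), (6000, 0.90),
           (8000, 0.85), (10000, 0.90), (12000, 0.95)]

for i in range(1, len(history) + 1):
    step = history[i - 1][0]
    print(step, should_stop(history[:i], 8000, min_steps=8000))
```

In this made-up run the lowest loss occurs at step 4000, so the check first succeeds at step 12000, which is 8000 steps later.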

In [94]:
class ExampleCheckpointSaverListener(tf.train.CheckpointSaverListener):
    def __init__(self, estimator, dev_input_fn, best_accuracy):
        self.estimator = estimator
        self.dev_input_fn = dev_input_fn
        self.best_accuracy = best_accuracy
        
    def begin(self):
        # You can add ops to the graph here.
        print('Starting the session.')

    def before_save(self, session, global_step_value):
        results = self.estimator.evaluate(input_fn=self.dev_input_fn)
        print(results)
        print('Best accuracy: ', self.best_accuracy)
        if results['accuracy'] >= self.best_accuracy:
            self.best_accuracy = results['accuracy']

    def after_save(self, session, global_step_value):
        print('Checkpoint saved')

    def end(self, session, global_step_value):
        print('Done with the session.')

def create_shared_embedding_columns(keys, vocabulary_list, dimension):
    column_list = []
    for key in keys:
        column_list.append(
            tf.feature_column.categorical_column_with_vocabulary_list(
                key=key,
                vocabulary_list=vocabulary_list))
    return tf.feature_column.shared_embedding_columns(column_list, dimension=dimension)

def train_and_evaluate_with_module(hub_module, lemma, train_module=False):
    train_df, n_rolesets = load_data('processed/' + lemma + '_train_combine_examples_unaugmented.txt')
    dev_df, dev_rolesets = load_data('processed/' + lemma + '_dev_combine_examples.txt')
    test_df, test_rolesets = load_data('processed/' + lemma + '_test_combine_examples.txt')    

    # size the output layer for the largest roleset id seen in any split
    n_rolesets = max(n_rolesets, dev_rolesets, test_rolesets)

    # training input on the whole training set with no limit on training epochs
    train_input_fn = tf.estimator.inputs.pandas_input_fn(
        train_df, train_df['roleset'], num_epochs=None, shuffle=True)

    # dev input on the whole dev set with no limit on epochs (not used below)
    dev_input_fn = tf.estimator.inputs.pandas_input_fn(
        dev_df, dev_df['roleset'], num_epochs=None, shuffle=False)

    # prediction input on the whole training set (not used below)
    predict_train_input_fn = tf.estimator.inputs.pandas_input_fn(
        train_df, train_df['roleset'], batch_size=32, shuffle=True)

    # evaluation and prediction input on the whole dev set
    predict_dev_input_fn = tf.estimator.inputs.pandas_input_fn(
        dev_df, dev_df['roleset'], batch_size=32, shuffle=False)

    # evaluation and prediction input on the whole test set
    predict_test_input_fn = tf.estimator.inputs.pandas_input_fn(
        test_df, test_df['roleset'], batch_size=32, shuffle=False)

    # shared embedding columns
    pos_feature_columns = create_shared_embedding_columns(['prev_pos', 'target_pos', 'next_pos'], pos_list, 5)
    sr_feature_columns = create_shared_embedding_columns(['prev_sr', 'target_sr', 'next_sr'], sr_list, 12)
    lexname_feature_columns = create_shared_embedding_columns(['prev_lexname', 'next_lexname'], lexname_list, 4)

    text_feature_columns = [
        hub.text_embedding_column(key='prev_word', module_spec=hub_module, trainable=train_module),
        hub.text_embedding_column(key='target_word', module_spec=hub_module, trainable=train_module),
        hub.text_embedding_column(key='next_word', module_spec=hub_module, trainable=train_module)]

    model_dir = "/home/tphan/Desktop/python/2018_summer/classifier/model/" + lemma
    feature_columns = text_feature_columns + pos_feature_columns + sr_feature_columns + lexname_feature_columns
    #estimator = tf.estimator.BaselineClassifier(n_classes=n_rolesets+1,model_dir=model_dir)
    estimator = tf.estimator.DNNClassifier(
        config=tf.estimator.RunConfig(keep_checkpoint_max=16),
        hidden_units=[500,100],
        feature_columns=feature_columns,
        n_classes=n_rolesets+1,
        optimizer=tf.train.AdamOptimizer(),
        dropout=.5,
        model_dir=model_dir)
    # create the evaluation directory up front; exist_ok avoids an error if it is already there
    os.makedirs(estimator.eval_dir(), exist_ok=True)

    best_accuracy = -1
    listener = ExampleCheckpointSaverListener(estimator, predict_dev_input_fn, best_accuracy)
    saver_hook = tf.train.CheckpointSaverHook(model_dir, listeners=[listener], save_steps=2000)
    hook = tf.contrib.estimator.stop_if_no_decrease_hook(estimator, 'loss', 8000, min_steps=8000, run_every_secs=None, run_every_steps=2000)
    estimator.train(input_fn=train_input_fn, hooks=[hook, saver_hook])
    checkpoint_state_proto = tf.train.get_checkpoint_state(model_dir)
    best_results = {}
    best_predictions = []
    best_checkpoint_path = None
    test_results = {}
    test_predictions = []
    if checkpoint_state_proto is not None:
        checkpoint_paths = checkpoint_state_proto.all_model_checkpoint_paths
        if len(checkpoint_paths) > 0:
            best_accuracy = 0
            for checkpoint_path in checkpoint_paths:
                results = estimator.evaluate(input_fn=predict_dev_input_fn, checkpoint_path=checkpoint_path)
                if results['accuracy'] >= best_accuracy:
                    best_checkpoint_path = checkpoint_path
                    best_accuracy = results['accuracy']
                    best_results = results

            for pred_dict in estimator.predict(predict_dev_input_fn, checkpoint_path=best_checkpoint_path):
                best_predictions.append(pred_dict)
            for pred_dict in estimator.predict(predict_test_input_fn, checkpoint_path=best_checkpoint_path):
                test_predictions.append(pred_dict)
            
            test_results = estimator.evaluate(input_fn=predict_test_input_fn, checkpoint_path=best_checkpoint_path)
            print('Results')
            print(best_results)
            print('Test results')
            print(test_results)
            #estimator.export_savedmodel('/home/tphan/Desktop/python/2018_summer/classifier/savedmodel/' + lemma,
            #                            tf.estimator.export.build_parsing_serving_input_receiver_fn(
            #                                tf.feature_column.make_parse_example_spec(feature_columns),
            #                                32),
            #                            checkpoint_path=checkpoint_paths[0],
            #                            strip_default_attrs=True)
            shutil.rmtree(model_dir)
            return best_predictions, test_predictions, {
                'Rolesets': str(int(n_rolesets)),
                'Rolesets in dev': str(int(dev_rolesets)),
                'Train size': str(int(train_df.shape[0])),
                'Dev size': str(int(dev_df.shape[0])),
                'Dev accuracy': best_results['accuracy'],
                'Average loss': best_results['average_loss'],
                'Loss': best_results['loss'],
                'Global step': best_results['global_step']
            }, {
                'Rolesets': str(int(n_rolesets)),
                'Rolesets in test': str(int(test_rolesets)),
                'Train size': str(int(train_df.shape[0])),
                'Test size': str(int(test_df.shape[0])),
                'Test accuracy': test_results['accuracy'],
                'Average loss': test_results['average_loss'],
                'Loss': test_results['loss'],
                'Global step': test_results['global_step']
            }
    # no checkpoints were found; return the same number of values as the success path
    return None, None, None, None
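
The checkpoint selection step in the function above can be isolated as a small helper. The sketch below mirrors its `>=` comparison, so when two checkpoints tie on dev accuracy the later one is kept. The helper name `select_best_checkpoint`, the checkpoint names, and the accuracies are hypothetical values for illustration.

```python
def select_best_checkpoint(accuracy_by_checkpoint):
    """Pick the checkpoint with the highest accuracy; later ones win ties.

    `accuracy_by_checkpoint` is a list of (path, accuracy) pairs in the
    order the checkpoints were written.
    """
    best_path, best_accuracy = None, 0
    for path, accuracy in accuracy_by_checkpoint:
        if accuracy >= best_accuracy:
            best_path, best_accuracy = path, accuracy
    return best_path, best_accuracy

# Made-up dev accuracies for four checkpoints.
evaluated = [('model.ckpt-2000', 0.80), ('model.ckpt-4000', 0.91),
             ('model.ckpt-6000', 0.91), ('model.ckpt-8000', 0.88)]
print(select_best_checkpoint(evaluated))
```

Using `>` instead of `>=` would prefer the earliest of the tied checkpoints, which is a reasonable alternative if less training is preferred.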

In [96]:
lemmata = ['be', 'have', 'say', 'do', 'get', 'go', 'make', 'use', 'take', 'know', 'see', 'come', 'call', 'work', 'add', 'find', 'help', 'pay', 'try', 'look', 'operate', 'hold', 'tell', 'continue', 'serve', 'show', 'become', 'end', 'move', 'fall', 'concern', 'put', 'ask', 'keep', 'leave', 'change', 'run', 'spend', 'raise', 'lead', 'start', 'meet', 'feel', 'consider', 'send', 'develop', 'charge', 'build', 'lose', 'close', 'allow', 'grow', 'talk', 'mean', 'return', 'order', 'decline', 'cut', 'note', 'file', 'name', 'deal', 'base', 'act', 'live', 'appear', 'reach', 'view', 'follow', 'open', 'drop', 'manage', 'play', 'set', 'succeed', 'stop', 'happen', 'vote', 'turn', 'yield', 'stay', 'improve', 'question', 'pass', 'finance', 'drive', 'force', 'cover', 'stand', 'settle', 'argue', 'claim', 'fly', 'apply', 'process', 'review', 'rule', 'assume', 'love', 'introduce', 'resign', 'hit', 'affect', 'push', 'join', 'bid', 'sign', 'head', 'watch', 'treat', 'enter', 'care', 'gain', 'confer', 'refer', 'back', 'strike', 'walk', 'prepare', 'form', 'worry', 'save', 'break', 'bear', 'cite', 'compete', 'admit', 'encourage', 'aim', 'list', 'perform', 'measure', 'conclude', 'hurt', 'fill', 'ease', 'retire', 'contend', 'recover', 'realize', 'sound', 'slow', 'jump', 'draw', 'trip', 'fix', 'conduct', 'execute', 'miss', 'identify', 'recall', 'pull', 'promote', 'approach', 'tend', 'resolve', 'discount', 'throw', 'finish', 'extend', 'clear', 'matter', 'check', 'address', 'notice', 'employ', 'appreciate', 'suppose', 'restore', 'express', 'train', 'recognize', 'mark', 'commit', 'emerge', 'stem', 'feed', 'catch', 'spread', 'shoot', 'concentrate', 'climb', 'rally', 'count', 'soar', 'differ', 'cap', 'abandon', 'register', 'prompt', 'beat', 'plunge', 'track', 'plead', 'figure', 'warm', 'submit', 'press', 'fit', 'slip', 'project', 'exercise', 'divide', 'afford', 'double', 'arrest', 'seize', 'trust', 'insure', 'fire', 'sense', 'mount', 'mind', 'sustain', 'protest', 'lift', 'split', 'pick', 'locate', 
'drink', 'contract', 'wave', 'struggle', 'crash', 'appeal', 'satisfy', 'pose', 'dismiss', 'cast', 'time', 'tie', 'cross', 'secure', 'assert', 'point', 'lease', 'land', 'squeeze', 'bother', 'slide', 'sleep', 'observe', 'halt', 'abuse', 'taste', 'swing', 'scare', 'relieve', 'explode', 'depress', 'credit', 'blow', 'bill', 'smoke', 'slash', 'rent', 'manipulate', 'illustrate', 'hang', 'evolve', 'contest', 'amount', 'weigh', 'prevail', 'trim', 'impress', 'dance', 'cheat', 'scramble', 'march', 'laugh', 'ring', 'burst', 'terminate', 'race', 'smell', 'paint', 'jolt', 'incorporate', 'excuse', 'cook', 'celebrate', 'tremor', 'surface', 'rest', 'top', 'tap', 'shed', 'leap', 'crack', 'assemble', 'prescribe', 'knock', 'compose', 'capitalize', 'unload', 'tape', 'scrap', 'retreat', 'freeze', 'upgrade', 'sweep', 'subscribe', 'motivate', 'fold', 'pitch', 'mistake', 'erupt', 'crowd', 'spell', 'render', 'dispose', 'correspond', 'swear', 'spin', 'seat', 'overlook', 'filter', 'dip', 'dictate', 'condition', 'classify', 'tip', 'sway', 'strain', 'screen', 'reckon', 'entitle', 'compromise', 'boom', 'bind', 'bend', 'upset', 'tear', 'sniff', 'scuttle', 'scale', 'hail', 'flash', 'curse', 'cry', 'command', 'bond', 'wrestle', 'stir', 'refinance', 'cheer', 'bleed', 'blast', 'venture', 'spare', 'pop', 'insulate', 'grind', 'galvanize', 'frame', 'flock', 'finger', 'divorce', 'code', 'circle', 'bow', 'balloon', 'stamp', 'snap', 'scratch', 'restate', 'reassert', 'rattle', 'plug', 'pile', 'pave', 'discharge', 'dawn', 'choke', 'bust', 'smash', 'lodge', 'heave', 'drool', 'cruise', 'conceive', 'bundle', 'appraise', 'wring', 'wiggle', 'weave', 'spurt', 'slam', 'skirt', 'skid', 'screech', 'ply', 'hook', 'fume', 'dispense', 'delight', 'blunder', 'accord', 'unhinge', 'stunt', 'repaint', 'recess', 'pinch', 'overbid', 'marvel', 'level', 'inaugurate', 'buzz']
zero_dev = ['accord', 'amount', 'appraise', 'assemble', 'assert', 'balloon', 'bend', 'bind', 'bleed', 'blunder', 'boom', 'bow', 'bundle', 'bust', 'capitalize', 'cheat', 'choke', 'circle', 'classify', 'clear', 'code', 'command', 'condition', 'correspond', 'crack', 'cruise', 'curse', 'dawn', 'dictate', 'dip', 'discharge', 'dispense', 'divorce', 'drool', 'employ', 'entitle', 'erupt', 'evolve', 'exercise', 'figure', 'filter', 'finger', 'flash', 'fold', 'frame', 'fume', 'galvanize', 'grind', 'heave', 'hook', 'illustrate', 'impress', 'inaugurate', 'insure', 'jolt', 'laugh', 'leap', 'lease', 'lift', 'lodge', 'manipulate', 'measure', 'motivate', 'overbid', 'overlook', 'plead', 'pop', 'prescribe', 'project', 'race', 'reassert', 'recall', 'recess', 'refinance', 'register', 'render', 'repaint', 'restate', 'rest', 'restore', 'ring', 'scale', 'scare', 'scrap', 'scratch', 'screech', 'screen', 'seize', 'shed', 'skid', 'skirt', 'slam', 'slash', 'smash', 'snap', 'sniff', 'spare', 'spell', 'split', 'spurt', 'squeeze', 'stamp', 'stem', 'strain', 'surface', 'sway', 'swear', 'terminate', 'top', 'tremor', 'trim', 'unhinge', 'upset', 'warm', 'weave', 'wiggle', 'wrestle', 'bond', 'burst', 'buzz', 'celebrate', 'cheer', 'compose', 'cross', 'crowd', 'cry', 'delight', 'dispose', 'excuse', 'feed', 'halt', 'insulate', 'knock', 'level', 'march', 'marvel', 'pitch', 'plug', 'ply', 'reckon', 'resolve', 'retire', 'scramble', 'scuttle', 'seat', 'secure', 'slide', 'smoke', 'soar', 'spin', 'submit', 'tape', 'tap', 'tear', 'tip', 'unload', 'venture', 'weigh']
lemmata = [x for x in lemmata if x not in zero_dev]

results = {}
for lemma in lemmata:
    #create_files(lemma)
    results[lemma] = train_and_evaluate_with_module('modules/1', lemma)
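    # the function already removes the model directory on success, so ignore a missing path
    shutil.rmtree("/home/tphan/Desktop/python/2018_summer/classifier/model/" + lemma, ignore_errors=True)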

INFO:tensorflow:Using config: {'_model_dir': '/home/tphan/Desktop/python/2018_summer/classifier/model/be', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f60651dc438>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Callin

INFO:tensorflow:global_step/sec: 195.937
INFO:tensorflow:loss = 1.2495754, step = 6501 (0.510 sec)
INFO:tensorflow:global_step/sec: 193.044
INFO:tensorflow:loss = 0.5080174, step = 6601 (0.518 sec)
INFO:tensorflow:global_step/sec: 193.383
INFO:tensorflow:loss = 1.366303, step = 6701 (0.517 sec)
INFO:tensorflow:global_step/sec: 192.295
INFO:tensorflow:loss = 0.17961384, step = 6801 (0.520 sec)
INFO:tensorflow:global_step/sec: 196.72
INFO:tensorflow:loss = 2.46152, step = 6901 (0.509 sec)
INFO:tensorflow:global_step/sec: 196.938
INFO:tensorflow:loss = 0.5143526, step = 7001 (0.508 sec)
INFO:tensorflow:global_step/sec: 193.761
INFO:tensorflow:loss = 0.6283549, step = 7101 (0.516 sec)
INFO:tensorflow:global_step/sec: 192.239
INFO:tensorflow:loss = 3.4866638, step = 7201 (0.520 sec)
INFO:tensorflow:global_step/sec: 191.715
INFO:tensorflow:loss = 0.36126232, step = 7301 (0.522 sec)
INFO:tensorflow:global_step/sec: 195.954
INFO:tensorflow:loss = 1.1012791, step = 7401 (0.510 sec)
INFO:tensorf