In [39]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import sys
sys.argv = sys.argv[:1]
from model import ParserModel, main

import time

from itertools import islice
from sys import stdout
from tempfile import NamedTemporaryFile
import tensorflow as tf
from utils.model import Model
from data import load_and_preprocess_data
from data import score_arcs
from initialization import xavier_weight_init
from parser import minibatch_parse
from utils.generic_utils import Progbar

from tensorflow.python.tools.freeze_graph import freeze_graph
import tfcoreml

In [40]:
class Config(object):
    """Holds model hyperparams and data information.

    The config class is used to store various hyperparameters and dataset
    information parameters. Model objects are passed a Config() object at
    instantiation.
    """
    n_word_ids = None # inferred
    n_tag_ids = None # inferred
    n_deprel_ids = None # inferred
    n_word_features = None # inferred
    n_tag_features = None # inferred
    n_deprel_features = None # inferred
    n_classes = None # inferred
    dropout = 0.5
    embed_size = None # inferred
    hidden_size = FLAGS.hidden_size 
    batch_size = 2048
    n_epochs = FLAGS.epochs
    lr = FLAGS.lr
    l2_beta = FLAGS.l2_beta
    l2_loss = 0

In [41]:
'''Main function

Args:
debug :
    whether to use a fraction of the data. Make sure to set to False
    when you're ready to train your model for real!
'''
print(80 * "=")
print("INITIALIZING")
print(80 * "=")
config = Config()
data = load_and_preprocess_data(
    max_batch_size=config.batch_size)
transducer, word_embeddings, train_data = data[:3]
dev_sents, dev_arcs = data[3:5]
test_sents, test_arcs = data[5:]
config.n_word_ids = len(transducer.id2word) + 1 # plus null
config.n_tag_ids = len(transducer.id2tag) + 1
config.n_deprel_ids = len(transducer.id2deprel) + 1
config.embed_size = word_embeddings.shape[1]
for (word_batch, tag_batch, deprel_batch), td_batch in \
        train_data.get_iterator(shuffled=False):
    config.n_word_features = word_batch.shape[-1]
    config.n_tag_features = tag_batch.shape[-1]
    config.n_deprel_features = deprel_batch.shape[-1]
    config.n_classes = td_batch.shape[-1]
    break
print(
    'Word feat size: {}, tag feat size: {}, deprel feat size: {}, '
    'classes size: {}'.format(
        config.n_word_features, config.n_tag_features,
        config.n_deprel_features, config.n_classes))

INITIALIZING
Loading word embeddings...there are 4003 word embeddings.
Determining POS tags...['ADJ', 'ADP', 'ADV', 'AUX', 'CONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']
there are 17 tags.
Determining deprel labels...['acl', 'acl:relcl', 'advcl', 'advmod', 'amod', 'appos', 'aux', 'auxpass', 'case', 'cc', 'cc:preconj', 'ccomp', 'compound', 'compound:prt', 'conj', 'cop', 'csubj', 'csubjpass', 'dep', 'det', 'det:predet', 'discourse', 'dobj', 'expl', 'iobj', 'mark', 'mwe', 'neg', 'nmod', 'nmod:npmod', 'nmod:poss', 'nmod:tmod', 'nsubj', 'nsubjpass', 'nummod', 'parataxis', 'punct', 'root', 'xcomp']
there are 39 deprel labels.
Getting training data...there are 1895754 samples.
Getting dev data...there are 1700 samples.
Getting test data...there are 2416 samples.
Word feat size: 18, tag feat size: 18, deprel feat size: 12, classes size: 83


In [42]:
debug = False

In [43]:
if debug:
    dev_sents = dev_sents[:500]
    dev_arcs = dev_arcs[:500]
    test_sents = test_sents[:500]
    test_arcs = test_arcs[:500]
if not debug:
    weight_file = NamedTemporaryFile(suffix='.weights')
    # weight_file = open("something.weights", mode=)
with tf.Graph().as_default(), tf.Session() as session:
    print("Building model...", end=' ')
    start = time.time()
    model = ParserModel(transducer, session, config, word_embeddings, is_training=True)
    print("took {:.2f} seconds\n".format(time.time() - start))
    init = tf.global_variables_initializer()
    session.run(init)
    output_names = 'output/td_vec'
    saver = None if debug else tf.train.Saver()
#     saver.restore(session, "checkpoints/model.ckpt")
    frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess=session,
        input_graph_def=tf.compat.v1.get_default_graph().as_graph_def(),
        output_node_names=[output_names])
    frozen_graph = tf.compat.v1.graph_util.extract_sub_graph(
        graph_def=frozen_graph,
        dest_nodes=[output_names])
    with open('checkpoints/frozen_graph.pb', 'wb') as fout:
        fout.write(frozen_graph.SerializeToString())
    print(80 * "=")
    print("TRAINING")
    print(80 * "=")
    best_las = 0.
    for epoch in range(config.n_epochs):
        print('Epoch {}'.format(epoch))

        if debug:
            model.fit_epoch(list(islice(train_data,3)), config.batch_size)
        else:
            model.fit_epoch(train_data)
        stdout.flush()
        dev_las, dev_uas = model.eval(dev_sents, dev_arcs)
        best = dev_las > best_las
        if best:
            best_las = dev_las
            if not debug:
                saver.save(session, "checkpoints/model.ckpt")
                tf.io.write_graph(session.graph_def, './checkpoints/', 'model.pbtxt')
                frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
                    sess=session,
                    input_graph_def=tf.compat.v1.get_default_graph().as_graph_def(),
                    output_node_names=[output_names])
                frozen_graph = tf.compat.v1.graph_util.extract_sub_graph(
                    graph_def=frozen_graph,
                    dest_nodes=[output_names])
                with open('checkpoints/frozen_graph.pb', 'wb') as fout:
                    fout.write(frozen_graph.SerializeToString())

        print('Validation LAS: ', end='')
        print('{:.2f}{}'.format(dev_las, ' (BEST!), ' if best else ', '))
        print('Validation UAS: ', end='')
        print('{:.2f}'.format(dev_uas))
    if not debug:
        print()
        print(80 * "=")
        print("TESTING")
        print(80 * "=")
        print("Restoring the best model weights found on the dev set")
        saver.restore(session, "checkpoints/model.ckpt")
        stdout.flush()
        las,uas = model.eval(test_sents, test_arcs)
        if las:
            print("Test LAS: ", end='')
            print('{:.2f}'.format(las), end=', ')
        print("Test UAS: ", end='')
        print('{:.2f}'.format(uas))
        print("Done!")

Building model... 
	relu activation function
	1 hidden layer(s) with size 200
	adam optimizer with learning rate 0.001
took 1.55 seconds

INFO:tensorflow:Froze 9 variables.
INFO:tensorflow:Converted 9 variables to const ops.
TRAINING
Epoch 0
[('Influential', 'ADJ'), ('members', 'NOUN'), ('of', 'ADP'), ('the', 'DET'), ('House', 'PROPN'), ('Ways', 'PROPN'), ('and', 'CONJ'), ('Means', 'PROPN'), ('Committee', 'PROPN'), ('introduced', 'VERB'), ('legislation', 'NOUN'), ('that', 'PRON'), ('would', 'AUX'), ('restrict', 'VERB'), ('how', 'ADV'), ('the', 'DET'), ('new', 'ADJ'), ('savings-and-loan', 'NOUN'), ('bailout', 'NOUN'), ('agency', 'NOUN'), ('can', 'AUX'), ('raise', 'VERB'), ('capital', 'NOUN'), (',', 'PUNCT'), ('creating', 'VERB'), ('another', 'DET'), ('potential', 'ADJ'), ('obstacle', 'NOUN'), ('to', 'ADP'), ('the', 'DET'), ('government', 'NOUN'), ("'s", 'PART'), ('sale', 'NOUN'), ('of', 'ADP'), ('sick', 'ADJ'), ('thrifts', 'NOUN'), ('.', 'PUNCT')]
INFO:tensorflow:Froze 9 variables.
INFO

In [50]:
with tf.Graph().as_default(), tf.Session() as session:
    print("Building model...", end=' ')
    start = time.time()
    model = ParserModel(transducer, session, config, word_embeddings, is_training=True)
    print("took {:.2f} seconds\n".format(time.time() - start))
    init = tf.global_variables_initializer()
    session.run(init)
    output_names = 'output/td_vec'
    saver = None if debug else tf.train.Saver()
#     saver.restore(session, "checkpoints/model.ckpt")
    frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess=session,
        input_graph_def=tf.compat.v1.get_default_graph().as_graph_def(),
        output_node_names=[output_names])
    frozen_graph = tf.compat.v1.graph_util.extract_sub_graph(
        graph_def=frozen_graph,
        dest_nodes=[output_names])
    with open('checkpoints/frozen_graph.pb', 'wb') as fout:
        fout.write(frozen_graph.SerializeToString())
    saver.restore(session, "checkpoints/model.ckpt")
    print(test_sents)
    las,uas = model.eval(test_sents, test_arcs)


Building model... 
	relu activation function
	1 hidden layer(s) with size 200
	adam optimizer with learning rate 0.001
took 0.31 seconds

INFO:tensorflow:Froze 9 variables.
INFO:tensorflow:Converted 9 variables to const ops.
INFO:tensorflow:Restoring parameters from checkpoints/model.ckpt
[[('No', 'ADV'), (',', 'PUNCT'), ('it', 'PRON'), ('was', 'VERB'), ("n't", 'PART'), ('Black', 'PROPN'), ('Monday', 'PROPN'), ('.', 'PUNCT')], [('But', 'CONJ'), ('while', 'SCONJ'), ('the', 'DET'), ('New', 'PROPN'), ('York', 'PROPN'), ('Stock', 'PROPN'), ('Exchange', 'PROPN'), ('did', 'AUX'), ("n't", 'PART'), ('fall', 'VERB'), ('apart', 'ADV'), ('Friday', 'PROPN'), ('as', 'SCONJ'), ('the', 'DET'), ('Dow', 'PROPN'), ('Jones', 'PROPN'), ('Industrial', 'PROPN'), ('Average', 'PROPN'), ('plunged', 'VERB'), ('190.58', 'NUM'), ('points', 'NOUN'), ('--', 'PUNCT'), ('most', 'ADJ'), ('of', 'ADP'), ('it', 'PRON'), ('in', 'ADP'), ('the', 'DET'), ('final', 'ADJ'), ('hour', 'NOUN'), ('--', 'PUNCT'), ('it', 'PRON'), ('

In [102]:
# transducer, word_embeddings, train_data = data[:3]
# dev_sents, dev_arcs = data[3:5]
# test_sents, test_arcs = data[5:]
# print(dev_sents[:1])
# print(dev_arcs[:1])
import json
list(islice(train_data,1))[0][0]

(array([[   0., 4002., 4002., ..., 4002., 4002., 4002.],
        [ 593.,    0., 4002., ..., 4002., 4002., 4002.],
        [1237.,  593.,    0., ..., 4002., 4002., 4002.],
        ...,
        [ 909., 3706., 3957., ..., 4002., 4002., 4002.],
        [ 909., 3957., 3162., ..., 4002., 4002., 4002.],
        [ 909., 3162., 1200., ..., 4002., 4002., 4002.]], dtype=float32),
 array([[ 0., 19., 19., ..., 19., 19., 19.],
        [ 2.,  0., 19., ..., 19., 19., 19.],
        [ 6.,  2.,  0., ..., 19., 19., 19.],
        ...,
        [12.,  6.,  2., ..., 19., 19., 19.],
        [12.,  2., 16., ..., 19., 19., 19.],
        [12., 16.,  8., ..., 19., 19., 19.]], dtype=float32),
 array([[41., 41., 41., ..., 41., 41., 41.],
        [41., 41., 41., ..., 41., 41., 41.],
        [41., 41., 41., ..., 41., 41., 41.],
        ...,
        [41., 41., 41., ..., 41., 41., 41.],
        [20., 41., 41., ..., 41., 41., 41.],
        [ 9., 20., 23., ..., 41., 41., 41.]], dtype=float32))