In [1]:
from badptr_model import BadptrModel
from util import convert_tokens, evaluate, get_batch_dataset, get_dataset
import tensorflow as tf
from config import flags
import numpy as np
import json

In [3]:
def get_binary_record_parser(config, is_test=True):
    """Build a ``tf.data`` map function that decodes one serialized badptr example.

    Args:
        config: flags object providing ``char_limit`` plus either
            ``test_para_limit``/``test_ques_limit`` (when ``is_test`` is true)
            or ``para_limit``/``ques_limit`` (when it is false).
        is_test: selects which pair of length limits the decoded tensors are
            reshaped to. NOTE(review): this defaults to True, and the notebook
            builds its *training* pipeline with the default as well — confirm
            the training records were written with the test limits.

    Returns:
        A function ``parse(example)`` returning the tuple
        ``(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
        qa_id, y1, y2)``; index tensors are int32, ``y1``/``y2`` float32,
        and ``qa_id`` the int64 "id" feature.
    """
    def parse(example):
        if is_test:
            para_limit = config.test_para_limit
            ques_limit = config.test_ques_limit
        else:
            para_limit = config.para_limit
            ques_limit = config.ques_limit
        char_limit = config.char_limit

        # Every field except "id" is stored as a raw byte string that has to be
        # decoded; the order mirrors the record's feature spec.
        feature_types = (
            ("context_idxs", tf.string),
            ("ques_idxs", tf.string),
            ("context_char_idxs", tf.string),
            ("ques_char_idxs", tf.string),
            ("id", tf.int64),
            ("y1", tf.string),
            ("y2", tf.string),
        )
        spec = {name: tf.FixedLenFeature([], dtype) for name, dtype in feature_types}
        parsed = tf.parse_single_example(example, features=spec)

        def decode(name, dtype, shape):
            # Reinterpret the stored bytes as `dtype`, then pin the static shape.
            return tf.reshape(tf.decode_raw(parsed[name], dtype), shape)

        return (
            decode("context_idxs", tf.int32, [para_limit]),
            decode("ques_idxs", tf.int32, [ques_limit]),
            decode("context_char_idxs", tf.int32, [para_limit, char_limit]),
            decode("ques_char_idxs", tf.int32, [ques_limit, char_limit]),
            parsed["id"],
            decode("y1", tf.float32, [para_limit]),
            decode("y2", tf.float32, [para_limit]),
        )

    return parse

In [4]:
# Define a throwaway 'f' flag, presumably so the flags parser tolerates the
# `-f <connection-file>` argument Jupyter passes to the kernel. Guarded because
# absl flags refuse to define the same flag twice, which made re-running this
# cell raise instead of being a no-op.
if 'f' not in flags.FLAGS:
    flags.DEFINE_string('f', 'give up already', 'who cares lol')
config = flags.FLAGS

In [5]:
# Both splits share one parser. is_test=True is the parser's default, spelled
# out here so it is visible. NOTE(review): the *training* records are therefore
# reshaped with test_para_limit/test_ques_limit — confirm badptr_train.tf was
# written with those limits.
parser = get_binary_record_parser(config, is_test=True)
train_dataset = get_batch_dataset('data/badptr_train.tf', parser, config)
dev_dataset = get_dataset('data/badptr_test.tf', parser, config)

# Feedable iterator: the model reads from `iterator`, and the string fed into
# `handle` decides whether it pulls train or dev batches.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle,
    train_dataset.output_types,
    train_dataset.output_shapes,
)
train_iterator = train_dataset.make_one_shot_iterator()
dev_iterator = dev_dataset.make_one_shot_iterator()

In [6]:
def _load_embedding_matrix(path):
    """Read a JSON-encoded numeric array from `path` as a float32 numpy array."""
    with open(path, "r") as fh:
        return np.array(json.load(fh), dtype=np.float32)


# Pretrained embedding matrices handed to the model constructor.
word_mat = _load_embedding_matrix(config.word_emb_file)
char_mat = _load_embedding_matrix(config.char_emb_file)

In [7]:
# Build the model graph on the shared feedable `iterator`, so the same ops
# consume train or dev batches depending on which handle is fed at run time.
# word_mat / char_mat are the embedding matrices loaded in the previous cell.
model = BadptrModel(config, iterator, word_mat, char_mat)

Instructions for updating:
Use the retry module or similar alternatives.


In [8]:
# Soft placement lets TF fall back to another device when an op has no kernel
# for the requested one; allow_growth makes TF claim GPU memory on demand
# instead of reserving it all at session creation.
sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True),
)
sess = tf.Session(config=sess_config)

In [9]:
sess.run(tf.global_variables_initializer())

# Variables restored from the pretrained checkpoint. Everything else keeps the
# fresh initialization from the line above:
#   - names starting with 'badptr' or 'binary/' (the new heads),
#   - names containing '/Adam' (optimizer slots),
#   - names containing 'beta'. NOTE(review): this is presumably aimed at Adam's
#     beta1_power/beta2_power, but it also skips any other variable whose name
#     contains 'beta' (e.g. a normalization offset) — confirm that is intended.
EXCLUDED_PREFIXES = ('badptr', 'binary/')
EXCLUDED_SUBSTRINGS = ('/Adam', 'beta')
restore_vars = [
    v for v in tf.global_variables()
    if not v.name.startswith(EXCLUDED_PREFIXES)
    and not any(token in v.name for token in EXCLUDED_SUBSTRINGS)
]
saver = tf.train.Saver(var_list=restore_vars)

# Look the checkpoint up once. latest_checkpoint returns None when save_dir
# holds no checkpoint; fail here with a clear message instead of printing
# "Restoring None" and erroring inside restore().
pretrained_ckpt = tf.train.latest_checkpoint(config.save_dir)
if pretrained_ckpt is None:
    raise ValueError('No checkpoint found in {!r}; cannot restore the '
                     'pretrained weights.'.format(config.save_dir))
print('Restoring', pretrained_ckpt)
saver.restore(sess, pretrained_ckpt)

Restoring log/model/model_60000.ckpt
INFO:tensorflow:Restoring parameters from log/model/model_60000.ckpt


In [10]:
# Training-loop checkpoint state: the best dev accuracy seen so far, and a
# saver with no var_list (TF's default: all saveable variables, including the
# 'badptr'/'binary' ones the restore saver skipped).
best_acc = 0
saver2 = tf.train.Saver(var_list=None)

In [11]:
# Materialize both iterator handles once; feeding one of them into `handle`
# picks the dataset the shared iterator reads from.
train_handle, dev_handle = sess.run(
    [train_iterator.string_handle(), dev_iterator.string_handle()])

In [12]:
# Fine-tuning schedule. Steps are 1-based, so logging fires on steps 10, 20, ...
# and evaluation on steps 100, 200, ...
NUM_TRAIN_STEPS = 5000
LOG_EVERY = 10
EVAL_EVERY = 100
# Number of dev batches averaged per evaluation.
# NOTE(review): presumably chosen to cover the dev split once — confirm against
# the dev set size and config.batch_size.
NUM_DEV_BATCHES = 22
BEST_CKPT_PREFIX = 'log/badptr_model/badptr-savepoint'

# Build the evaluation ops ONCE. Calling tf.argmax inside sess.run in the loop
# adds two new nodes to the graph on every dev batch, so the graph would keep
# growing (and slowing down) for the whole run.
# model.y1 / model.y2 are presumably the gold start/end distributions decoded
# from the 'y1'/'y2' record fields, and model.yp1 / model.yp2 the predicted
# indices — TODO confirm against BadptrModel.
gold_start_op = tf.argmax(model.y1, 1)
gold_end_op = tf.argmax(model.y2, 1)

for step in range(1, NUM_TRAIN_STEPS + 1):
    # One optimizer update on a training batch; the train_op fetch is discarded.
    batch_loss = sess.run([model.loss, model.train_op],
                          feed_dict={handle: train_handle})[0]

    if step % LOG_EVERY == 0:
        print('After', step, 'iterations:')
        print('Batch Loss:', np.mean(batch_loss))

    if step % EVAL_EVERY == 0:
        # NOTE(review): nothing here switches the model into an eval mode; if
        # BadptrModel applies dropout, the dev batches are scored with it on.
        acc = 0.0
        for _dev_batch in range(NUM_DEV_BATCHES):
            gold_start, pred_start, gold_end, pred_end = sess.run(
                [gold_start_op, model.yp1, gold_end_op, model.yp2],
                feed_dict={handle: dev_handle})
            # Exact match: a row counts only when BOTH boundaries agree.
            acc += np.mean(np.logical_and(pred_start == gold_start,
                                          pred_end == gold_end))
        acc /= float(NUM_DEV_BATCHES)
        print('Dev Accuracy:', acc)

        # Keep a checkpoint of every new best dev accuracy.
        if acc > best_acc:
            best_acc = acc
            saver2.save(sess, BEST_CKPT_PREFIX, global_step=step)

    if step % LOG_EVERY == 0:
        print()

After 10 iterations:
Batch Loss: 5.922181

After 20 iterations:
Batch Loss: 4.1139574

After 30 iterations:
Batch Loss: 2.996765

After 40 iterations:
Batch Loss: 1.9655526

After 50 iterations:
Batch Loss: 1.6397579

After 60 iterations:
Batch Loss: 0.8954891

After 70 iterations:
Batch Loss: 0.8677113

After 80 iterations:
Batch Loss: 0.715143

After 90 iterations:
Batch Loss: 1.017957

After 100 iterations:
Batch Loss: 0.47058624
Dev Accuracy: 0.8693181818181818

After 110 iterations:
Batch Loss: 0.53084517

After 120 iterations:
Batch Loss: 0.4717941

After 130 iterations:
Batch Loss: 0.29515862

After 140 iterations:
Batch Loss: 0.21364745

After 150 iterations:
Batch Loss: 0.23698491

After 160 iterations:
Batch Loss: 0.22215614

After 170 iterations:
Batch Loss: 0.060447637

After 180 iterations:
Batch Loss: 0.17908046

After 190 iterations:
Batch Loss: 0.13392994

After 200 iterations:
Batch Loss: 0.29200798
Dev Accuracy: 0.8821022727272727

After 210 iterations:
Batch Loss: 0.… (output truncated — log resumes at 1650 iterations)

After 1650 iterations:
Batch Loss: 0.020501904

After 1660 iterations:
Batch Loss: 0.0028324784

After 1670 iterations:
Batch Loss: 0.01577081

After 1680 iterations:
Batch Loss: 0.028523773

After 1690 iterations:
Batch Loss: 0.006343873

After 1700 iterations:
Batch Loss: 0.0038378893
Dev Accuracy: 0.8941761363636364

After 1710 iterations:
Batch Loss: 0.0049796714

After 1720 iterations:
Batch Loss: 0.024999402

After 1730 iterations:
Batch Loss: 0.004237912

After 1740 iterations:
Batch Loss: 0.011099955

After 1750 iterations:
Batch Loss: 0.007255056

After 1760 iterations:
Batch Loss: 0.0012623628

After 1770 iterations:
Batch Loss: 0.001203568

After 1780 iterations:
Batch Loss: 0.0064207935

After 1790 iterations:
Batch Loss: 0.0044747023

After 1800 iterations:
Batch Loss: 0.0010088858
Dev Accuracy: 0.9055397727272727

After 1810 iterations:
Batch Loss: 0.00417831

After 1820 iterations:
Batch Loss: 0.002375898

After 1830 iterations:
Batch Loss: 0.0051789056

After 1840 iterations:
… (output truncated — log resumes at 3250 iterations)

After 3250 iterations:
Batch Loss: 0.0048722234

After 3260 iterations:
Batch Loss: 0.0060374653

After 3270 iterations:
Batch Loss: 0.0029317725

After 3280 iterations:
Batch Loss: 0.0007659777

After 3290 iterations:
Batch Loss: 0.009216823

After 3300 iterations:
Batch Loss: 0.016468376
Dev Accuracy: 0.8828125

After 3310 iterations:
Batch Loss: 0.008283488

After 3320 iterations:
Batch Loss: 0.0429844

After 3330 iterations:
Batch Loss: 0.014982239

After 3340 iterations:
Batch Loss: 0.009764947

After 3350 iterations:
Batch Loss: 0.0037746562

After 3360 iterations:
Batch Loss: 0.0043971604

After 3370 iterations:
Batch Loss: 0.01259387

After 3380 iterations:
Batch Loss: 0.078963056

After 3390 iterations:
Batch Loss: 0.011518665

After 3400 iterations:
Batch Loss: 0.005708976
Dev Accuracy: 0.8785511363636364

After 3410 iterations:
Batch Loss: 0.0057197446

After 3420 iterations:
Batch Loss: 0.028168498

After 3430 iterations:
Batch Loss: 0.003565456

After 3440 iterations:
Batch Loss: … (output truncated — log resumes at 4850 iterations)

After 4850 iterations:
Batch Loss: 0.0011547569

After 4860 iterations:
Batch Loss: 0.0008485499

After 4870 iterations:
Batch Loss: 0.00075517676

After 4880 iterations:
Batch Loss: 0.0020396246

After 4890 iterations:
Batch Loss: 0.0023138113

After 4900 iterations:
Batch Loss: 0.0026717917
Dev Accuracy: 0.9055397727272727

After 4910 iterations:
Batch Loss: 0.0022578223

After 4920 iterations:
Batch Loss: 0.114492275

After 4930 iterations:
Batch Loss: 0.04574908

After 4940 iterations:
Batch Loss: 0.0053581307

After 4950 iterations:
Batch Loss: 0.0041774455

After 4960 iterations:
Batch Loss: 0.009381839

After 4970 iterations:
Batch Loss: 0.0045275204

After 4980 iterations:
Batch Loss: 0.013010641

After 4990 iterations:
Batch Loss: 0.020453893

After 5000 iterations:
Batch Loss: 0.00083747215
Dev Accuracy: 0.8955965909090909

