In [1]:
from binary_model import BinaryModel
from util import convert_tokens, evaluate, get_batch_dataset, get_dataset
import tensorflow as tf
from config import flags
import numpy as np
import json

In [2]:
def get_binary_record_parser(config, is_test=True):
    def parse(example):
        para_limit = config.test_para_limit if is_test else config.para_limit
        ques_limit = config.test_ques_limit if is_test else config.ques_limit
        char_limit = config.char_limit
        features = tf.parse_single_example(example,
                                           features={
                                               "context_idxs": tf.FixedLenFeature([], tf.string),
                                               "ques_idxs": tf.FixedLenFeature([], tf.string),
                                               "context_char_idxs": tf.FixedLenFeature([], tf.string),
                                               "ques_char_idxs": tf.FixedLenFeature([], tf.string),
                                               "id": tf.FixedLenFeature([], tf.int64),
                                               "tag": tf.FixedLenFeature([], tf.string)
                                           })
        context_idxs = tf.reshape(tf.decode_raw(
            features["context_idxs"], tf.int32), [para_limit])
        ques_idxs = tf.reshape(tf.decode_raw(
            features["ques_idxs"], tf.int32), [ques_limit])
        context_char_idxs = tf.reshape(tf.decode_raw(
            features["context_char_idxs"], tf.int32), [para_limit, char_limit])
        ques_char_idxs = tf.reshape(tf.decode_raw(
            features["ques_char_idxs"], tf.int32), [ques_limit, char_limit])
        qa_id = features["id"]
        tag = tf.reshape(tf.decode_raw(
            features["tag"], tf.int32), [2])
        return context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, qa_id, tag
    
    return parse

In [3]:
flags.DEFINE_string('f', 'give up already', 'who cares lol')
config = flags.FLAGS

In [4]:
parser = get_binary_record_parser(config)
train_dataset = get_batch_dataset('data/binary_train.tf', parser, config)
dev_dataset = get_dataset('data/binary_test.tf', parser, config)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
train_iterator = train_dataset.make_one_shot_iterator()
dev_iterator = dev_dataset.make_one_shot_iterator()

In [5]:
with open(config.word_emb_file, "r") as fh:
    word_mat = np.array(json.load(fh), dtype=np.float32)
with open(config.char_emb_file, "r") as fh:
    char_mat = np.array(json.load(fh), dtype=np.float32)

In [6]:
model = BinaryModel(config, iterator, word_mat, char_mat)

Instructions for updating:
Use the retry module or similar alternatives.


In [7]:
sess_config = tf.ConfigProto(allow_soft_placement=True)
sess_config.gpu_options.allow_growth = True

sess = tf.Session(config=sess_config)

In [8]:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(var_list=[v for v in tf.global_variables() if not v.name.startswith('binary/') and '/Adam' not in v.name and 'beta' not in v.name])
print('Restoring', tf.train.latest_checkpoint(config.save_dir))
saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))

Restoring log/model/model_60000.ckpt
INFO:tensorflow:Restoring parameters from log/model/model_60000.ckpt


In [9]:
saver2 = tf.train.Saver()
best_acc = 0

In [10]:
train_handle = sess.run(train_iterator.string_handle())
dev_handle = sess.run(dev_iterator.string_handle())

In [11]:
for _ in range(1, 5000):
    loss, train_op = sess.run([model.loss, model.train_op], feed_dict={
                              handle: train_handle})
    
    if _ % 10 == 0:
        print('After', _, 'iterations:')
        print('Batch Loss:', np.mean(loss))
    
    if _ % 100 == 0:
        acc = 0

        for __ in range(22):
            pred, target = sess.run([model.prediction, model.y_target], feed_dict={
                                      handle: dev_handle})

            acc += np.mean(pred.argmax(1) == target.argmax(1))

        acc /= float(22)
        print('Dev Accuracy:', acc)
        
        if acc > best_acc:
            best_acc = acc
            saver2.save(sess, 'log/binary_model/badptr-savepoint', global_step=_)
    
    if _ % 10 == 0:
        print()

After 10 iterations:
Batch Loss: 0.69930375

After 20 iterations:
Batch Loss: 0.57432306

After 30 iterations:
Batch Loss: 0.56632173

After 40 iterations:
Batch Loss: 0.56910896

After 50 iterations:
Batch Loss: 0.5215808

After 60 iterations:
Batch Loss: 0.5983915

After 70 iterations:
Batch Loss: 0.39526314

After 80 iterations:
Batch Loss: 0.6396444

After 90 iterations:
Batch Loss: 0.6262356

After 100 iterations:
Batch Loss: 0.55743444
Dev Accuracy: 0.7144886363636364

After 110 iterations:
Batch Loss: 0.6496694

After 120 iterations:
Batch Loss: 0.50289553

After 130 iterations:
Batch Loss: 0.6797002

After 140 iterations:
Batch Loss: 0.50453573

After 150 iterations:
Batch Loss: 0.529658

After 160 iterations:
Batch Loss: 0.61568165

After 170 iterations:
Batch Loss: 0.43990687

After 180 iterations:
Batch Loss: 0.4012978

After 190 iterations:
Batch Loss: 0.38644642

After 200 iterations:
Batch Loss: 0.4575715
Dev Accuracy: 0.8394886363636364

After 210 iterations:
Batch Loss:

After 1660 iterations:
Batch Loss: 0.01968413

After 1670 iterations:
Batch Loss: 0.19608727

After 1680 iterations:
Batch Loss: 0.00017361026

After 1690 iterations:
Batch Loss: 0.006705144

After 1700 iterations:
Batch Loss: 0.0005890894
Dev Accuracy: 0.9346590909090909

After 1710 iterations:
Batch Loss: 0.061650183

After 1720 iterations:
Batch Loss: 0.077293135

After 1730 iterations:
Batch Loss: 0.059170477

After 1740 iterations:
Batch Loss: 0.00127084

After 1750 iterations:
Batch Loss: 0.014338209

After 1760 iterations:
Batch Loss: 0.009637867

After 1770 iterations:
Batch Loss: 0.00013823622

After 1780 iterations:
Batch Loss: 0.034501065

After 1790 iterations:
Batch Loss: 0.038956657

After 1800 iterations:
Batch Loss: 0.00013726926
Dev Accuracy: 0.9275568181818182

After 1810 iterations:
Batch Loss: 0.00017291743

After 1820 iterations:
Batch Loss: 0.105918825

After 1830 iterations:
Batch Loss: 7.014304e-05

After 1840 iterations:
Batch Loss: 0.00046054157

After 1850 it

After 3240 iterations:
Batch Loss: 0.0037023588

After 3250 iterations:
Batch Loss: 3.1847732e-05

After 3260 iterations:
Batch Loss: 0.1408601

After 3270 iterations:
Batch Loss: 2.898199e-06

After 3280 iterations:
Batch Loss: 2.174224e-05

After 3290 iterations:
Batch Loss: 7.1889735e-05

After 3300 iterations:
Batch Loss: 0.03168317
Dev Accuracy: 0.9360795454545454

After 3310 iterations:
Batch Loss: 1.4805324e-05

After 3320 iterations:
Batch Loss: 0.011880631

After 3330 iterations:
Batch Loss: 0.0020959254

After 3340 iterations:
Batch Loss: 0.05242681

After 3350 iterations:
Batch Loss: 0.10787136

After 3360 iterations:
Batch Loss: 0.0032008772

After 3370 iterations:
Batch Loss: 0.00011331692

After 3380 iterations:
Batch Loss: 2.000685e-05

After 3390 iterations:
Batch Loss: 0.24128948

After 3400 iterations:
Batch Loss: 5.4836266e-05
Dev Accuracy: 0.9289772727272727

After 3410 iterations:
Batch Loss: 0.044142112

After 3420 iterations:
Batch Loss: 0.00013812873

After 3430

After 4810 iterations:
Batch Loss: 7.040725e-07

After 4820 iterations:
Batch Loss: 0.017840568

After 4830 iterations:
Batch Loss: 2.190364e-06

After 4840 iterations:
Batch Loss: 9.5073365e-06

After 4850 iterations:
Batch Loss: 0.0046164035

After 4860 iterations:
Batch Loss: 0.0060909004

After 4870 iterations:
Batch Loss: 3.2555923e-05

After 4880 iterations:
Batch Loss: 1.1604226e-06

After 4890 iterations:
Batch Loss: 0.00040642644

After 4900 iterations:
Batch Loss: 3.4811055e-06
Dev Accuracy: 0.9339488636363636

After 4910 iterations:
Batch Loss: 3.909977e-05

After 4920 iterations:
Batch Loss: 0.0010017933

After 4930 iterations:
Batch Loss: 3.0440851e-05

After 4940 iterations:
Batch Loss: 1.1795339e-05

After 4950 iterations:
Batch Loss: 0.25513867

After 4960 iterations:
Batch Loss: 4.2054162e-05

After 4970 iterations:
Batch Loss: 6.9564057e-06

After 4980 iterations:
Batch Loss: 2.3524544e-06

After 4990 iterations:
Batch Loss: 2.7130005e-05



In [12]:
acc = 0

for __ in range(22):
    pred, target = sess.run([model.prediction, model.y_target], feed_dict={
                              handle: dev_handle})

    acc += np.mean(pred.argmax(1) == target.argmax(1))

acc /= float(22)
print('Dev Accuracy:', acc)

if acc > best_acc:
    best_acc = acc
    saver2.save(sess, 'log/binary_model/badptr-savepoint', global_step=_)

Dev Accuracy: 0.9339488636363636


In [None]:
acc = 0

for __ in range(22):
    pred, target = sess.run([model.prediction, model.y_target], feed_dict={
                              handle: dev_handle})

    acc += np.mean(pred.argmax(1) == target.argmax(1))

acc /= float(22)
print('Dev Accuracy:', acc)

if acc > best_acc:
    best_acc = acc
    saver2.save(sess, 'log/binary_model/badptr-savepoint', global_step=5000)

In [None]:
best_acc