In [1]:
from combo_model import ComboModel
from util import convert_tokens, get_batch_dataset, get_dataset, get_record_parser
import tensorflow as tf
from config import flags
import numpy as np
import json

In [2]:
# Register a throwaway string flag named 'f' before the flag values are used.
# NOTE(review): presumably a workaround for the `-f <connection file>` argument
# that a Jupyter kernel receives on its command line, which flag parsing would
# otherwise reject as unknown — confirm. The default/help strings are placeholders.
flags.DEFINE_string('f', 'give up already', 'who cares lol')
# The flag container doubles as the configuration object passed around below.
config = flags.FLAGS

In [3]:
# Paths for the held-out split: the serialized records and their JSON side-car.
test_record_file = 'data/combo_v2_test.tf'
test_eval_file = 'data/combo_v2_test_meta.json'

# Parse the JSON side-car into memory.
with open(test_eval_file, "r") as eval_fh:
    eval_file = json.load(eval_fh)

# Hard-coded example count (not derived from the files above).
meta = dict(total=1000)

In [4]:
def _read_float32_matrix(path):
    """Read a JSON file and return its contents as a float32 NumPy array."""
    with open(path, "r") as emb_fh:
        return np.array(json.load(emb_fh), dtype=np.float32)


# Word-level and character-level embedding tables, located via the config flags.
word_mat = _read_float32_matrix(config.word_emb_file)
char_mat = _read_float32_matrix(config.char_emb_file)

In [5]:
def get_binary_record_parser(config, is_test=True):
    """Build a parser for serialized examples whose tensors are stored as raw bytes.

    Args:
        config: object exposing `char_limit` plus either `test_para_limit` /
            `test_ques_limit` (when `is_test` is true) or `para_limit` /
            `ques_limit` (otherwise). The attributes are read each time the
            returned function is called, not when this factory is called.
        is_test: selects which pair of length limits is used for reshaping.

    Returns:
        A function `parse(example)` returning the 9-tuple
        (context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, qa_id,
        bad_y1, bad_y2, y1, y2).
    """
    def parse(example):
        # Pick the sequence-length limits that fix the static tensor shapes.
        if is_test:
            para_limit, ques_limit = config.test_para_limit, config.test_ques_limit
        else:
            para_limit, ques_limit = config.para_limit, config.ques_limit
        char_limit = config.char_limit

        # Every field is a scalar byte string except the integer example id.
        field_dtypes = [
            ("context_idxs", tf.string),
            ("ques_idxs", tf.string),
            ("context_char_idxs", tf.string),
            ("ques_char_idxs", tf.string),
            ("id", tf.int64),
            ("bad_y1", tf.string),
            ("bad_y2", tf.string),
            ("y1", tf.string),
            ("y2", tf.string),
        ]
        schema = {key: tf.FixedLenFeature([], dtype) for key, dtype in field_dtypes}
        features = tf.parse_single_example(example, features=schema)

        def unpack(key, dtype, shape):
            # Reinterpret the stored bytes as `dtype`, then impose the static shape.
            return tf.reshape(tf.decode_raw(features[key], dtype), shape)

        # Token / character index tensors are int32.
        context_idxs = unpack("context_idxs", tf.int32, [para_limit])
        ques_idxs = unpack("ques_idxs", tf.int32, [ques_limit])
        context_char_idxs = unpack("context_char_idxs", tf.int32, [para_limit, char_limit])
        ques_char_idxs = unpack("ques_char_idxs", tf.int32, [ques_limit, char_limit])
        qa_id = features["id"]

        # Label vectors are float32 with one entry per context position.
        bad_y1 = unpack("bad_y1", tf.float32, [para_limit])
        bad_y2 = unpack("bad_y2", tf.float32, [para_limit])
        y1 = unpack("y1", tf.float32, [para_limit])
        y2 = unpack("y2", tf.float32, [para_limit])

        return (context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
                qa_id, bad_y1, bad_y2, y1, y2)

    return parse

In [6]:
# A single parser (default is_test=True) is shared by both input pipelines.
parser = get_binary_record_parser(config)

# Training and held-out pipelines read from separate record files.
train_dataset = get_batch_dataset('data/combo_v2_train.tf', parser, config)
dev_dataset = get_dataset('data/combo_v2_test.tf', parser, config)

# Feedable iterator: the string fed through `handle` at run time selects which
# concrete iterator supplies the next batch. Its structure is taken from the
# training dataset.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    string_handle=handle,
    output_types=train_dataset.output_types,
    output_shapes=train_dataset.output_shapes,
)

# One-shot iterators whose string handles are fed through `handle`.
train_iterator = train_dataset.make_one_shot_iterator()
dev_iterator = dev_dataset.make_one_shot_iterator()

In [7]:
# Build the model graph on top of the feedable iterator and the two embedding
# tables. Construction prints a reminder to run `model.assign_trick_ops`; that
# is done in the restore cell further down, right after the checkpoint is loaded.
model = ComboModel(config, iterator, word_mat, char_mat)

Instructions for updating:
Use the retry module or similar alternatives.
RUN ASSIGN TRICK OPS (model.assign_trick_ops)!!


In [8]:
# Session options: let ops fall back to another device when the requested one
# is unavailable, and let GPU memory grow on demand rather than being reserved
# up front.
sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True),
)

sess = tf.Session(config=sess_config)

In [14]:
# Initialize every variable first; the restore below then overwrites the subset
# that the checkpoint is asked to provide.
sess.run(tf.global_variables_initializer())

# Variables left out of the restore: anything under the `encoding_1/` or
# `badptr` prefixes, and anything whose name contains '/Adam' or 'beta'
# (presumably optimizer state — confirm).
excluded_prefixes = ('encoding_1/', 'badptr')
excluded_fragments = ('/Adam', 'beta')
restorable_vars = [
    var for var in tf.global_variables()
    if not var.name.startswith(excluded_prefixes)
    and not any(fragment in var.name for fragment in excluded_fragments)
]
saver = tf.train.Saver(var_list=restorable_vars)

# Resolve the newest checkpoint in the save directory once and reuse the path.
ckpt_path = tf.train.latest_checkpoint(config.save_dir)
print('Restoring', ckpt_path)
saver.restore(sess, ckpt_path)

# Run the ops the model constructor asked for, then switch the model to training mode.
sess.run(model.assign_trick_ops)
sess.run(tf.assign(model.is_train, tf.constant(True, dtype=tf.bool)))

Restoring log/model/model_60000.ckpt
INFO:tensorflow:Restoring parameters from log/model/model_60000.ckpt


True

In [15]:
# Split the global variables into the two scopes being compared, as
# (name, variable) pairs. The two prefixes are mutually exclusive.
encoding_1_vars, encoding_vars = [], []
for graph_var in tf.global_variables():
    if graph_var.name.startswith('encoding_1/Variable'):
        encoding_1_vars.append((graph_var.name, graph_var))
    elif graph_var.name.startswith('encoding/Variable'):
        encoding_vars.append((graph_var.name, graph_var))

# Sort by name so that corresponding entries line up positionally.
encoding_1_vars.sort()
encoding_vars.sort()

In [16]:
# Sanity check: each `encoding_1/Variable*` must currently hold the same values
# as the `encoding/Variable*` entry at the same sorted position.
# zip() stops at the shorter list, so first make sure the comparison is neither
# empty (which would pass vacuously) nor silently truncated.
assert encoding_1_vars and encoding_vars, 'no encoding variables found to compare'
assert len(encoding_1_vars) == len(encoding_vars), (
    'variable count mismatch', len(encoding_1_vars), len(encoding_vars))

for (n1, e1v), (n, ev) in zip(encoding_1_vars, encoding_vars):
    # Names must agree once the differing top-level scope is stripped.
    assert n1.split('/')[1] == n.split('/')[1], (n1, n)
    e1v_val, ev_val = sess.run([e1v, ev])
    # Values must match to within numerical noise.
    assert np.linalg.norm(e1v_val - ev_val) < 1e-10, (n1, n)

In [17]:
# Saver over all variables, used for the savepoints written during training.
saver2 = tf.train.Saver()

# Best dev accuracy observed so far; a savepoint is written whenever it improves.
best_acc = 0

# Resolve both iterator handles; these strings are what gets fed to `handle`.
train_handle, dev_handle = sess.run(
    [train_iterator.string_handle(), dev_iterator.string_handle()])

In [18]:
# Schedule for this training run.
TOTAL_STEPS = 5000   # number of training steps
LOG_EVERY = 10       # print the batch loss every N steps
EVAL_EVERY = 100     # run a dev evaluation every N steps
DEV_BATCHES = 22     # dev batches averaged per evaluation

# Build the argmax ops ONCE, outside the loops. Calling tf.argmax inside the
# evaluation loop adds new nodes to the graph for every dev batch, so the graph
# keeps growing for the whole run.
bad_y1_index = tf.argmax(model.bad_y1, 1)
bad_y2_index = tf.argmax(model.bad_y2, 1)
dev_fetches = [bad_y1_index, model.bad_yp1, bad_y2_index, model.bad_yp2]

# Named loop variables: `_` and `__` are IPython's output-history names and
# should not be reused as counters in a notebook.
for step in range(1, TOTAL_STEPS + 1):
    # train_op is fetched only to execute the update; its fetched value is unused.
    loss, _train_result = sess.run([model.loss, model.train_op],
                                   feed_dict={handle: train_handle})

    if step % LOG_EVERY == 0:
        print('After', step, 'iterations:')
        print('Batch Loss:', np.mean(loss))

    if step % EVAL_EVERY == 0:
        # NOTE(review): model.is_train was assigned True when the checkpoint was
        # restored and is not switched off here — confirm whether dev evaluation
        # is meant to run in training mode.
        batch_accs = []
        for _dev_batch in range(DEV_BATCHES):
            y1_idx, yp1, y2_idx, yp2 = sess.run(
                dev_fetches, feed_dict={handle: dev_handle})
            # An example counts as correct only if both positions match.
            batch_accs.append(
                np.mean(np.logical_and(y1_idx == yp1, y2_idx == yp2)))

        acc = sum(batch_accs) / float(DEV_BATCHES)
        print('Dev Accuracy:', acc)

        # Write a savepoint only when the dev accuracy beats the best so far.
        if acc > best_acc:
            best_acc = acc
            saver2.save(sess, 'log/combo_model/badptr-savepoint', global_step=step)

    # Blank line between logged steps.
    if step % LOG_EVERY == 0:
        print()

After 10 iterations:
Batch Loss: 5.4161797

After 20 iterations:
Batch Loss: 4.7621365

After 30 iterations:
Batch Loss: 4.166677

After 40 iterations:
Batch Loss: 3.2321475

After 50 iterations:
Batch Loss: 3.3673484

After 60 iterations:
Batch Loss: 3.1178713

After 70 iterations:
Batch Loss: 2.7015762

After 80 iterations:
Batch Loss: 2.2041872

After 90 iterations:
Batch Loss: 1.7988122

After 100 iterations:
Batch Loss: 1.7977811
Dev Accuracy: 0.6981534090909091

After 110 iterations:
Batch Loss: 1.1279154

After 120 iterations:
Batch Loss: 1.5531837

After 130 iterations:
Batch Loss: 1.552281

After 140 iterations:
Batch Loss: 0.748489

After 150 iterations:
Batch Loss: 0.92454183

After 160 iterations:
Batch Loss: 1.0439746

After 170 iterations:
Batch Loss: 1.0651668

After 180 iterations:
Batch Loss: 0.39612645

After 190 iterations:
Batch Loss: 0.6495237

After 200 iterations:
Batch Loss: 1.0986636
Dev Accuracy: 0.7947443181818182

After 210 iterations:
Batch Loss: 0.43030694

After 1670 iterations:
Batch Loss: 0.007527422

After 1680 iterations:
Batch Loss: 0.014398581

After 1690 iterations:
Batch Loss: 0.06132517

After 1700 iterations:
Batch Loss: 0.09674517
Dev Accuracy: 0.8501420454545454

After 1710 iterations:
Batch Loss: 0.0043999436

After 1720 iterations:
Batch Loss: 0.096353814

After 1730 iterations:
Batch Loss: 0.13751595

After 1740 iterations:
Batch Loss: 0.066489205

After 1750 iterations:
Batch Loss: 0.003518252

After 1760 iterations:
Batch Loss: 0.019687088

After 1770 iterations:
Batch Loss: 0.027850691

After 1780 iterations:
Batch Loss: 0.004118139

After 1790 iterations:
Batch Loss: 0.11529322

After 1800 iterations:
Batch Loss: 0.0056443596
Dev Accuracy: 0.8544034090909091

After 1810 iterations:
Batch Loss: 0.0063931714

After 1820 iterations:
Batch Loss: 0.012171693

After 1830 iterations:
Batch Loss: 0.0020232403

After 1840 iterations:
Batch Loss: 0.014215855

After 1850 iterations:
Batch Loss: 0.0018027266

After 1860 iterations

After 3270 iterations:
Batch Loss: 0.0017322395

After 3280 iterations:
Batch Loss: 0.0044734683

After 3290 iterations:
Batch Loss: 0.022527684

After 3300 iterations:
Batch Loss: 0.009697869
Dev Accuracy: 0.8622159090909091

After 3310 iterations:
Batch Loss: 0.0018611788

After 3320 iterations:
Batch Loss: 0.002323562

After 3330 iterations:
Batch Loss: 0.01726541

After 3340 iterations:
Batch Loss: 0.005984817

After 3350 iterations:
Batch Loss: 0.004765146

After 3360 iterations:
Batch Loss: 0.0020371901

After 3370 iterations:
Batch Loss: 0.0045574983

After 3380 iterations:
Batch Loss: 0.0074246144

After 3390 iterations:
Batch Loss: 0.007856004

After 3400 iterations:
Batch Loss: 0.07888506
Dev Accuracy: 0.8494318181818182

After 3410 iterations:
Batch Loss: 0.019822482

After 3420 iterations:
Batch Loss: 0.030942213

After 3430 iterations:
Batch Loss: 0.011313289

After 3440 iterations:
Batch Loss: 0.0077142944

After 3450 iterations:
Batch Loss: 0.17415063

After 3460 iterati

After 4870 iterations:
Batch Loss: 0.0018712685

After 4880 iterations:
Batch Loss: 0.018692227

After 4890 iterations:
Batch Loss: 0.0019537294

After 4900 iterations:
Batch Loss: 0.009711854
Dev Accuracy: 0.8643465909090909

After 4910 iterations:
Batch Loss: 0.030263472

After 4920 iterations:
Batch Loss: 0.05143156

After 4930 iterations:
Batch Loss: 0.22204125

After 4940 iterations:
Batch Loss: 0.0031309137

After 4950 iterations:
Batch Loss: 0.14926158

After 4960 iterations:
Batch Loss: 0.0063579474

After 4970 iterations:
Batch Loss: 0.0025264667

After 4980 iterations:
Batch Loss: 0.010757651

After 4990 iterations:
Batch Loss: 0.028488707

After 5000 iterations:
Batch Loss: 0.007413002
Dev Accuracy: 0.8643465909090909



In [19]:
# Reload the most recently written savepoint. Savepoints are only written when
# the dev accuracy improves, so the newest one is also the best one.
best_ckpt_path = tf.train.latest_checkpoint('log/combo_model')
saver2.restore(sess, best_ckpt_path)

INFO:tensorflow:Restoring parameters from log/combo_model/badptr-savepoint-4400


In [20]:
def _named_vars_with_prefix(prefix):
    """Return name-sorted (name, variable) pairs for global variables whose name starts with `prefix`."""
    return sorted(
        (graph_var.name, graph_var)
        for graph_var in tf.global_variables()
        if graph_var.name.startswith(prefix)
    )


# Re-collect both scopes so the comparison can be repeated on the restored weights.
encoding_1_vars = _named_vars_with_prefix('encoding_1/Variable')
encoding_vars = _named_vars_with_prefix('encoding/Variable')

In [21]:
# Repeat the consistency check on the restored weights: each
# `encoding_1/Variable*` must equal the `encoding/Variable*` entry at the same
# sorted position.
# zip() stops at the shorter list, so first make sure the comparison is neither
# empty (which would pass vacuously) nor silently truncated.
assert encoding_1_vars and encoding_vars, 'no encoding variables found to compare'
assert len(encoding_1_vars) == len(encoding_vars), (
    'variable count mismatch', len(encoding_1_vars), len(encoding_vars))

for (n1, e1v), (n, ev) in zip(encoding_1_vars, encoding_vars):
    # Names must agree once the differing top-level scope is stripped.
    assert n1.split('/')[1] == n.split('/')[1], (n1, n)
    e1v_val, ev_val = sess.run([e1v, ev])
    # Values must match to within numerical noise.
    assert np.linalg.norm(e1v_val - ev_val) < 1e-10, (n1, n)