In [None]:
import json
import os

import numpy as np
import tensorflow as tf

from binary_model import BinaryModel
from config import flags
from util import convert_tokens, evaluate, get_batch_dataset, get_dataset

In [None]:
def get_binary_record_parser(config, is_test=True):
    """Build a function that decodes one serialized binary-classification example.

    Args:
        config: object exposing ``char_limit`` plus ``para_limit`` /
            ``ques_limit`` and their ``test_para_limit`` / ``test_ques_limit``
            counterparts.
        is_test: when true (the default) the ``test_*`` limits define the
            decoded shapes; otherwise the plain limits are used.

    Returns:
        A callable that takes one serialized example and returns the tuple
        ``(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, id, tag)``.
    """
    def parse(example):
        # The limits are looked up on every call, not when the parser is built.
        if is_test:
            para_limit, ques_limit = config.test_para_limit, config.test_ques_limit
        else:
            para_limit, ques_limit = config.para_limit, config.ques_limit
        char_limit = config.char_limit

        # Every field is stored as raw bytes except the integer example id.
        feature_spec = {
            field: tf.FixedLenFeature([], dtype)
            for field, dtype in (
                ("context_idxs", tf.string),
                ("ques_idxs", tf.string),
                ("context_char_idxs", tf.string),
                ("ques_char_idxs", tf.string),
                ("id", tf.int64),
                ("tag", tf.string),
            )
        }
        features = tf.parse_single_example(example, features=feature_spec)

        def as_int32(field, shape):
            """Decode a raw-bytes field as int32 and reshape it to `shape`."""
            return tf.reshape(tf.decode_raw(features[field], tf.int32), shape)

        context_idxs = as_int32("context_idxs", [para_limit])
        ques_idxs = as_int32("ques_idxs", [ques_limit])
        context_char_idxs = as_int32("context_char_idxs", [para_limit, char_limit])
        ques_char_idxs = as_int32("ques_char_idxs", [ques_limit, char_limit])
        # The tag is decoded as a fixed two-element int32 vector.
        tag = as_int32("tag", [2])

        return (context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
                features["id"], tag)

    return parse

In [None]:
# Register a throwaway string flag named 'f' before the flag values are used.
# NOTE(review): presumably this exists because the Jupyter kernel is started
# with `-f <connection file>`, which the flag parser would otherwise reject --
# confirm. Depending on the flags implementation, re-running this cell in the
# same kernel may raise a duplicate-flag error.
flags.DEFINE_string('f', 'give up already', 'who cares lol')
# `config` is the flag-values object every later cell reads settings from.
config = flags.FLAGS

In [None]:
# Record files for the two splits.
TRAIN_RECORDS = 'data/binary_train.tf'
DEV_RECORDS = 'data/binary_test.tf'

# One parser serves both splits. It is built without `is_test`, so the default
# (True) applies and the test_* limits from `config` size the training
# records as well.
parser = get_binary_record_parser(config)
train_dataset = get_batch_dataset(TRAIN_RECORDS, parser, config)
dev_dataset = get_dataset(DEV_RECORDS, parser, config)

# A string-handle placeholder lets a single `iterator` be pointed at either
# split at run time; its structure is taken from the training dataset.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle,
    train_dataset.output_types,
    train_dataset.output_shapes,
)

# Concrete per-split iterators whose handles get fed into `handle`.
train_iterator = train_dataset.make_one_shot_iterator()
dev_iterator = dev_dataset.make_one_shot_iterator()

In [None]:
def _load_embedding_matrix(path):
    """Read a JSON-encoded matrix from `path` and return it as a float32 array."""
    with open(path, "r") as fh:
        return np.array(json.load(fh), dtype=np.float32)


# Word-level and character-level embedding tables, at the paths named in `config`.
word_mat = _load_embedding_matrix(config.word_emb_file)
char_mat = _load_embedding_matrix(config.char_emb_file)

In [None]:
# Instantiate the model with the flag config, the handle-driven `iterator` and
# the two embedding matrices. Later cells fetch `model.loss`, `model.train_op`,
# `model.prediction` and `model.y_target` from it.
# NOTE(review): BinaryModel lives in binary_model.py and is not visible here.
model = BinaryModel(config, iterator, word_mat, char_mat)

In [None]:
# Session options: allow ops to fall back to another device when the requested
# one is unavailable, and let GPU memory grow on demand.
sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True),
)

# The one session shared by every later cell.
sess = tf.Session(config=sess_config)

In [None]:
# Give every variable its initial value first; the restore below then
# overwrites only the subset that is loaded from the checkpoint.
sess.run(tf.global_variables_initializer())

# Variables to restore: everything except
#   * variables whose name starts with 'binary/',
#   * variables with '/Adam' in their name,
#   * variables with 'beta' in their name.
# NOTE(review): the 'beta' test presumably targets the optimizer's
# beta-power accumulators, but it also skips any model variable whose name
# happens to contain 'beta' -- verify against the checkpoint.
restore_vars = [
    v for v in tf.global_variables()
    if not v.name.startswith('binary/')
    and '/Adam' not in v.name
    and 'beta' not in v.name
]
saver = tf.train.Saver(var_list=restore_vars)

# Resolve the checkpoint once so the path that is printed is exactly the path
# that is restored, and fail with the directory name if nothing is there.
checkpoint_path = tf.train.latest_checkpoint(config.save_dir)
if checkpoint_path is None:
    raise ValueError('No checkpoint found in {!r}'.format(config.save_dir))
print('Restoring', checkpoint_path)
saver.restore(sess, checkpoint_path)

In [None]:
# Second saver, built without a var_list, used by the training and evaluation
# cells to write checkpoints (the earlier `saver` was built from a filtered
# variable list for restoring).
saver2 = tf.train.Saver()
# Best dev accuracy seen so far; a checkpoint is written only when it is beaten.
# Re-running this cell resets it to 0.
best_acc = 0

In [None]:
# Evaluate both iterators' string handles in a single session call. These are
# the values fed into the `handle` placeholder to select the split being read.
train_handle, dev_handle = sess.run(
    [train_iterator.string_handle(), dev_iterator.string_handle()])

In [None]:
# Training schedule. Steps are numbered from 1, so range(1, 5000) runs 4999 steps.
LOG_EVERY = 10        # print the batch loss every N steps
EVAL_EVERY = 100      # measure dev accuracy every N steps
DEV_BATCHES = 22      # dev batches averaged per evaluation
CHECKPOINT_PREFIX = 'log/binary_model/badptr-savepoint'

# Saver.save() raises if the checkpoint's parent directory is missing, which
# would otherwise only show up at an evaluation step. Create it up front.
os.makedirs(os.path.dirname(CHECKPOINT_PREFIX), exist_ok=True)

# `_` is kept as the step counter because the following cell passes it as
# `global_step` when it writes its own checkpoint.
for _ in range(1, 5000):
    # One optimisation step on the training split. Only the loss is kept;
    # train_op is fetched purely so that it runs.
    loss = sess.run([model.loss, model.train_op],
                    feed_dict={handle: train_handle})[0]

    should_log = _ % LOG_EVERY == 0
    if should_log:
        print('After', _, 'iterations:')
        print('Batch Loss:', np.mean(loss))

    if _ % EVAL_EVERY == 0:
        # Average the per-batch accuracy over DEV_BATCHES batches of the dev split.
        acc = 0
        for __ in range(DEV_BATCHES):
            pred, target = sess.run([model.prediction, model.y_target],
                                    feed_dict={handle: dev_handle})
            # A batch element counts as correct when the argmax of the
            # prediction matches the argmax of the target.
            acc += np.mean(pred.argmax(1) == target.argmax(1))
        acc /= float(DEV_BATCHES)
        print('Dev Accuracy:', acc)

        # Checkpoint only when the dev accuracy improves on the best so far.
        if acc > best_acc:
            best_acc = acc
            saver2.save(sess, CHECKPOINT_PREFIX, global_step=_)

    # Blank line separating one logging step's output from the next.
    if should_log:
        print()

In [None]:
# Final dev-set evaluation, using the same procedure as the periodic check in
# the training cell: average the per-batch accuracy over a fixed number of
# dev batches.
n_dev_batches = 22

acc = 0
for __ in range(n_dev_batches):
    fetches = [model.prediction, model.y_target]
    pred, target = sess.run(fetches, feed_dict={handle: dev_handle})
    # A batch element is correct when prediction and target share an argmax.
    is_correct = np.argmax(pred, axis=1) == np.argmax(target, axis=1)
    acc += np.mean(is_correct)
acc = acc / float(n_dev_batches)
print('Dev Accuracy:', acc)

# Save only on improvement. `_` is whatever step value the training cell's
# loop left behind, so this cell relies on that cell having run first.
improved = acc > best_acc
if improved:
    best_acc = acc
    saver2.save(sess, 'log/binary_model/badptr-savepoint', global_step=_)