In [None]:
from combo_model import ComboModel
from util import convert_tokens, get_batch_dataset, get_dataset, get_record_parser
import tensorflow as tf
from config import flags
import numpy as np
import json

In [None]:
# Dummy string flag named `f`. Presumably this is registered so the flags
# parser tolerates the `-f <connection-file>` argument Jupyter passes to the
# kernel — confirm. Guarded so that re-running this cell does not try to
# register the same flag twice (which the flags library rejects).
if 'f' not in flags.FLAGS:
    flags.DEFINE_string('f', 'give up already', 'who cares lol')
config = flags.FLAGS

In [None]:
test_eval_file = 'data/combo_v2_test_meta.json'
test_record_file = 'data/combo_v2_test.tf'

# Per-question metadata keyed by str(qa_id); later cells read its "context",
# "spans", "uuid" and "answers" fields.
with open(test_eval_file, "r") as eval_fh:
    eval_file = json.load(eval_fh)

# Number of examples to evaluate, used to size the prediction loop.
# NOTE(review): hard-coded; presumably the record count of test_record_file.
meta = dict(total=1382)

In [None]:
def _load_emb_matrix(path):
    """Read a JSON-encoded nested list from `path` as a float32 numpy array."""
    with open(path, "r") as emb_fh:
        return np.array(json.load(emb_fh), dtype=np.float32)


word_mat = _load_emb_matrix(config.word_emb_file)
char_mat = _load_emb_matrix(config.char_emb_file)

In [None]:
def get_binary_record_parser(config, is_test=True):
    """Build a tf.data map function for the combo-model TFRecords.

    The returned ``parse(example)`` decodes one serialized example into the
    tuple ``(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
    qa_id, bad_y1, bad_y2, y1, y2)``. Every field except ``id`` is stored as a
    raw byte string and decoded with ``tf.decode_raw``:

    * ``context_idxs``: int32 ``[para_limit]``; ``ques_idxs``: int32 ``[ques_limit]``
    * ``context_char_idxs``: int32 ``[para_limit, char_limit]``;
      ``ques_char_idxs``: int32 ``[ques_limit, char_limit]``
    * ``bad_y1``, ``bad_y2``, ``y1``, ``y2``: float32 ``[para_limit]``
      (presumably per-token start/end targets — confirm against the writer)

    ``qa_id`` is the int64 ``id`` feature. When ``is_test`` is true the
    ``test_para_limit`` / ``test_ques_limit`` sizes are used instead of
    ``para_limit`` / ``ques_limit``.
    """
    def parse(example):
        if is_test:
            para_limit = config.test_para_limit
            ques_limit = config.test_ques_limit
        else:
            para_limit = config.para_limit
            ques_limit = config.ques_limit
        char_limit = config.char_limit

        byte_keys = ("context_idxs", "ques_idxs", "context_char_idxs", "ques_char_idxs",
                     "bad_y1", "bad_y2", "y1", "y2")
        feature_spec = dict((key, tf.FixedLenFeature([], tf.string)) for key in byte_keys)
        feature_spec["id"] = tf.FixedLenFeature([], tf.int64)
        parsed = tf.parse_single_example(example, features=feature_spec)

        def decode(key, dtype, shape):
            # Raw bytes -> flat tensor of `dtype` -> fixed `shape`.
            return tf.reshape(tf.decode_raw(parsed[key], dtype), shape)

        context_idxs = decode("context_idxs", tf.int32, [para_limit])
        ques_idxs = decode("ques_idxs", tf.int32, [ques_limit])
        context_char_idxs = decode("context_char_idxs", tf.int32, [para_limit, char_limit])
        ques_char_idxs = decode("ques_char_idxs", tf.int32, [ques_limit, char_limit])
        qa_id = parsed["id"]
        bad_y1, bad_y2, y1, y2 = [decode(key, tf.float32, [para_limit])
                                  for key in ("bad_y1", "bad_y2", "y1", "y2")]
        return (context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
                qa_id, bad_y1, bad_y2, y1, y2)

    return parse

In [None]:
# Iterator over the test TFRecords, decoded with the parser above (which also
# yields the bad_y1/bad_y2 fields). Reuse `test_record_file` instead of
# repeating the path literal, so the two cannot drift apart.
test_batch = get_dataset(test_record_file, get_binary_record_parser(
        config, is_test=True), config).make_one_shot_iterator()

In [None]:
# Build the evaluation graph on the test iterator. `trainable` is a ComboModel
# argument (defined outside this notebook); presumably False disables the
# training ops — confirm.
model = ComboModel(
    config, test_batch, word_mat, char_mat,
    trainable=False,
)

In [None]:
# Soft placement lets ops without a GPU kernel fall back to CPU. allow_growth
# makes TensorFlow claim GPU memory on demand rather than all up front.
sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True),
)
sess = tf.Session(config=sess_config)

In [None]:
# Which checkpoint to restore:
#   True  -> restore every global variable from COMBO_MODEL_DIR.
#   False -> restore a subset of variables from config.save_dir, then run
#            model.assign_trick_ops. The subset excludes encoding_1/*,
#            badptr*, Adam slots, and any name containing 'beta'.
RESTORE_COMBO_MODEL = True
COMBO_MODEL_DIR = 'log/combo_model'

sess.run(tf.global_variables_initializer())
# NOTE(review): `'beta' not in v.name` presumably targets Adam's
# beta1_power/beta2_power accumulators. It also drops any other variable with
# "beta" in its name (e.g. a normalization offset) — confirm that is intended.
saver = tf.train.Saver(var_list=[
    v for v in tf.global_variables()
    if not v.name.startswith('encoding_1/')
    and not v.name.startswith('badptr')
    and '/Adam' not in v.name
    and 'beta' not in v.name])
saver2 = tf.train.Saver()

ckpt_dir = COMBO_MODEL_DIR if RESTORE_COMBO_MODEL else config.save_dir
ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
if ckpt_path is None:
    # latest_checkpoint returns None when the directory holds no checkpoint;
    # fail here with a clear message instead of deep inside Saver.restore.
    raise IOError("No checkpoint found in {}".format(ckpt_dir))
if RESTORE_COMBO_MODEL:
    saver2.restore(sess, ckpt_path)
else:
    saver.restore(sess, ckpt_path)
    # ComboModel-defined ops (not visible here) run after the partial restore.
    sess.run(model.assign_trick_ops)

sess.run(tf.assign(model.is_train, tf.constant(False, dtype=tf.bool)))

In [None]:
def _sorted_vars_with_prefix(prefix):
    """Return (name, variable) pairs for global variables whose name starts with `prefix`, sorted."""
    matches = [(var.name, var) for var in tf.global_variables() if var.name.startswith(prefix)]
    matches.sort()
    return matches


encoding_1_vars = _sorted_vars_with_prefix('encoding_1/Variable')
encoding_vars = _sorted_vars_with_prefix('encoding/Variable')

In [None]:
# Sanity check: every encoding_1/Variable* must hold the same values as its
# encoding/Variable* counterpart. zip() would silently stop at the shorter
# list, so compare the counts first.
assert len(encoding_1_vars) == len(encoding_vars), (
    "encoding_1 has {} Variable(s) but encoding has {}".format(
        len(encoding_1_vars), len(encoding_vars)))
for (n1, e1v), (n, ev) in zip(encoding_1_vars, encoding_vars):
    assert n1.split('/')[1] == n.split('/')[1], "name mismatch: {} vs {}".format(n1, n)
    e1v_val, ev_val = sess.run([e1v, ev])
    assert np.linalg.norm(e1v_val - ev_val) < 1e-10, "values differ: {} vs {}".format(n1, n)

In [None]:
total = meta["total"]  # number of examples; sizes the prediction loop below
print(total)

In [None]:
def convert_tokens(eval_file, qa_id, pp1, pp2, bad_pp1, bad_pp2):
    """Map predicted token positions back to answer strings.

    Notebook-local variant that shadows ``util.convert_tokens`` imported at
    the top. It also takes the bad pointer's predicted span. Unless that span
    is ``(0, 0)``, its tokens are removed from the token-span list before
    ``p1``/``p2`` are looked up, so the answer positions index into the
    context with the bad span cut out.

    Args:
        eval_file: dict keyed by ``str(qa_id)``. Each entry provides
            ``"context"`` (str), ``"spans"`` (one ``(char_start, char_end)``
            pair per token) and ``"uuid"``.
        qa_id: iterable of question ids.
        pp1, pp2: predicted answer start / end token indices.
        bad_pp1, bad_pp2: predicted bad-span start / end token indices
            (end inclusive). ``(0, 0)`` removes nothing.

    Returns:
        ``(answer_dict, remapped_dict)``: the answer text keyed by
        ``str(qa_id)`` and by uuid, respectively.
    """
    answer_dict = {}
    remapped_dict = {}
    for qid, p1, p2, bad_p1, bad_p2 in zip(qa_id, pp1, pp2, bad_pp1, bad_pp2):
        entry = eval_file[str(qid)]
        context = entry["context"]
        # Copy so that removing the bad span never mutates eval_file.
        spans = list(entry["spans"])
        if bad_p1 != 0 or bad_p2 != 0:
            # Drop tokens bad_p1..bad_p2 (inclusive); an inverted span removes nothing.
            del spans[bad_p1:bad_p2 + 1]
        answer = context[spans[p1][0]:spans[p2][1]]
        answer_dict[str(qid)] = answer
        remapped_dict[entry["uuid"]] = answer
    return answer_dict, remapped_dict

In [None]:
answer_dict = {}
remapped_dict = {}

# Ceil division: the fewest batches that cover every example. The old
# `total // batch_size + 1` pulled a surplus batch whenever total was an exact
# multiple of batch_size. NOTE(review): if the dataset repeats, the last batch
# may wrap to earlier examples; dict.update just overwrites those ids.
num_batches = (total + config.batch_size - 1) // config.batch_size
for step in range(num_batches):
    qa_id, yp1, yp2, bad_yp1, bad_yp2 = sess.run([model.qa_id, model.yp1, model.yp2,
                                                  model.bad_yp1, model.bad_yp2])
    answer_dict_, remapped_dict_ = convert_tokens(eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist(),
                                                  bad_yp1.tolist(), bad_yp2.tolist())
    answer_dict.update(answer_dict_)
    remapped_dict.update(remapped_dict_)

In [None]:
from util import metric_max_over_ground_truths, exact_match_score, f1_score

In [None]:
def evaluate(eval_file, answer_dict, only=None):
    """Compute exact-match and F1 over predicted answers.

    Each prediction is scored with util's ``exact_match_score`` and
    ``f1_score``, taking the max over its ground-truth answers.

    Args:
        eval_file: dict keyed like ``answer_dict``. Each entry provides
            ``"uuid"`` and ``"answers"`` (the ground-truth strings).
        answer_dict: predicted answer text keyed by ``str(qa_id)``.
        only: ``None`` scores every prediction. ``'adv'`` keeps only questions
            whose uuid contains ``'-'`` (reported as mutated data);
            ``'orig'`` keeps only those whose uuid does not.

    Returns:
        dict with ``'exact_match'`` and ``'f1'`` as percentages (0-100), plus
        ``'total'``, the number of scored questions. When nothing is scored,
        both percentages are 0.0 instead of raising ZeroDivisionError.
    """
    f1 = exact_match = total = 0
    for key, prediction in answer_dict.items():
        if only in ('adv', 'orig'):
            is_adv = '-' in eval_file[key]['uuid']
            if is_adv != (only == 'adv'):
                continue
        total += 1
        ground_truths = eval_file[key]["answers"]
        exact_match += metric_max_over_ground_truths(
            exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(f1_score,
                                            prediction, ground_truths)
    if total == 0:
        # E.g. an empty answer_dict, or a filter that matches no question.
        return {'exact_match': 0.0, 'f1': 0.0, 'total': 0}
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return {'exact_match': exact_match, 'f1': f1, 'total': total}

In [None]:
# Score only the original (unmutated) questions.
metrics = evaluate(eval_file, answer_dict, 'orig')
print("Unmutated data")
em_score, f1_val = metrics['exact_match'], metrics['f1']
print("Exact Match: {}, F1: {}".format(em_score, f1_val))

In [None]:
# Score only the mutated (adversarial) questions.
metrics = evaluate(eval_file, answer_dict, 'adv')
print("Mutated data")
em_score, f1_val = metrics['exact_match'], metrics['f1']
print("Exact Match: {}, F1: {}".format(em_score, f1_val))

In [None]:
# Score every prediction, original and mutated alike.
metrics = evaluate(eval_file, answer_dict, None)
print("Overall")
em_score, f1_val = metrics['exact_match'], metrics['f1']
print("Exact Match: {}, F1: {}".format(em_score, f1_val))