In [1]:
from badptr_model import BadptrModel
from util import convert_tokens, get_batch_dataset, get_dataset, get_record_parser
import tensorflow as tf
from config import flags
import numpy as np
import json

In [2]:
# Register a throwaway string flag named 'f' before the flag values are read.
# NOTE(review): presumably this absorbs the `-f <connection file>` argument that
# Jupyter hands to the kernel, which the flag parser would otherwise reject — confirm.
flags.DEFINE_string('f', 'give up already', 'who cares lol')
# Shared flag container; every `config.<name>` lookup in this notebook reads from it.
config = flags.FLAGS

In [16]:
# Inputs for the test run: the record file fed to the dataset pipeline and the
# JSON metadata used to map predictions back to answers.
test_record_file = 'data/badptr_test.tf'
test_eval_file = 'data/badptr_test_meta.json'

# Read and parse the metadata; the handle is closed when the block exits.
with open(test_eval_file, "r") as meta_fh:
    eval_file = json.loads(meta_fh.read())

# Hard-coded number of test examples; it sizes the inference loop further down.
meta = dict(total=1382)

In [4]:
def _load_embedding_matrix(path):
    """Read a JSON-encoded table from ``path`` and return it as a float32 array."""
    with open(path, "r") as emb_fh:
        return np.array(json.load(emb_fh), dtype=np.float32)


# Word embeddings first, then character embeddings (same order as before).
word_mat = _load_embedding_matrix(config.word_emb_file)
char_mat = _load_embedding_matrix(config.char_emb_file)

In [5]:
def get_binary_record_parser(config, is_test=True):
    """Build a parser that turns one serialized example into a tuple of tensors.

    Args:
        config: object exposing ``para_limit``, ``ques_limit``, ``char_limit``
            and their ``test_para_limit`` / ``test_ques_limit`` counterparts.
        is_test: when true the ``test_*`` limits are used, otherwise the plain
            ones. The limits are read each time the returned parser is called.

    Returns:
        A function ``parse(example)`` returning
        ``(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, qa_id, y1, y2)``.
    """
    # Every feature except "id" is stored as a raw byte string.
    raw_keys = ("context_idxs", "ques_idxs", "context_char_idxs",
                "ques_char_idxs", "y1", "y2")

    def parse(example):
        """Decode ``example`` and reshape each field to its configured size."""
        if is_test:
            para_limit = config.test_para_limit
            ques_limit = config.test_ques_limit
        else:
            para_limit = config.para_limit
            ques_limit = config.ques_limit
        char_limit = config.char_limit

        spec = {key: tf.FixedLenFeature([], tf.string) for key in raw_keys}
        spec["id"] = tf.FixedLenFeature([], tf.int64)
        record = tf.parse_single_example(example, features=spec)

        def unpack(key, dtype, shape):
            # Reinterpret the raw bytes as `dtype`, then give them a fixed shape.
            return tf.reshape(tf.decode_raw(record[key], dtype), shape)

        return (
            unpack("context_idxs", tf.int32, [para_limit]),
            unpack("ques_idxs", tf.int32, [ques_limit]),
            unpack("context_char_idxs", tf.int32, [para_limit, char_limit]),
            unpack("ques_char_idxs", tf.int32, [ques_limit, char_limit]),
            record["id"],
            unpack("y1", tf.float32, [para_limit]),
            unpack("y2", tf.float32, [para_limit]),
        )

    return parse

In [6]:
# Build the test pipeline step by step: parser -> dataset -> one-shot iterator.
record_parser = get_binary_record_parser(config, is_test=True)
test_dataset = get_dataset(test_record_file, record_parser, config)
test_batch = test_dataset.make_one_shot_iterator()

In [7]:
# Build the model on top of the test iterator and the two loaded embedding
# matrices, passing trainable=False.
model = BadptrModel(config, test_batch, word_mat, char_mat, trainable=False)

Instructions for updating:
Use the retry module or similar alternatives.


In [8]:
# Session settings: soft device placement is allowed and GPU memory is
# allocated on demand (allow_growth) rather than set by a later attribute write.
sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True),
)

sess = tf.Session(config=sess_config)

In [None]:
# Initialize every global variable first; the restore below then overwrites
# them with the saved values.
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
# Restore from the most recent checkpoint found under <save_dir>/../badptr_model.
# NOTE(review): tf.train.latest_checkpoint is documented to return None when no
# checkpoint is found, in which case this restore call fails — confirm the directory.
saver.restore(sess, tf.train.latest_checkpoint(config.save_dir + '/../badptr_model'))
# Assign False to model.is_train before running evaluation.
sess.run(tf.assign(model.is_train, tf.constant(False, dtype=tf.bool)))

In [22]:
# Number of evaluation examples, taken from `meta`; it determines how many
# batches the inference loop runs.
total = meta['total']
print(total)

1382


In [23]:
# Prediction accumulators; rebuilt from scratch every time this cell runs so a
# re-run never mixes in results from an earlier pass.
answer_dict = {}
remapped_dict = {}

# Ceiling division: just enough batches to cover `total` examples. The former
# `total // batch_size + 1` ran one superfluous batch whenever `total` was an
# exact multiple of the batch size, reading past the end of the test set.
num_batches = (total + config.batch_size - 1) // config.batch_size

for step in range(num_batches):
    qa_id, loss, yp1, yp2 = sess.run([model.qa_id, model.loss, model.yp1, model.yp2])
    # Turn the predicted start/end positions back into answers for this batch.
    answer_dict_, remapped_dict_ = convert_tokens(eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist())
    answer_dict.update(answer_dict_)
    remapped_dict.update(remapped_dict_)

In [24]:
from util import metric_max_over_ground_truths, exact_match_score, f1_score

In [25]:
def evaluate(eval_file, answer_dict, only=None):
    """Compute exact-match and F1 percentages for a set of predictions.

    Args:
        eval_file: mapping from example key to a record holding at least
            "answers" (the ground truths) and, when ``only`` is used, "uuid".
        answer_dict: mapping from example key to the predicted answer.
        only: ``'adv'`` scores only examples whose uuid contains a ``'-'``;
            ``'orig'`` scores only examples whose uuid does not. Any other
            value (default ``None``) scores every example.

    Returns:
        ``{'exact_match': float, 'f1': float}`` as percentages over the
        selected examples. Both are ``0.0`` when no example is selected,
        instead of raising ``ZeroDivisionError``.
    """
    f1 = exact_match = total = 0
    for key, prediction in answer_dict.items():
        record = eval_file[key]
        if only == 'adv' or only == 'orig':
            # The uuid is only inspected when a subset is requested, so
            # records without a "uuid" still work for the overall score.
            has_suffix = '-' in record['uuid']
            if only == 'adv' and not has_suffix:
                continue
            if only == 'orig' and has_suffix:
                continue
        total += 1
        ground_truths = record["answers"]
        exact_match += metric_max_over_ground_truths(
            exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(
            f1_score, prediction, ground_truths)

    if total == 0:
        # Nothing matched the filter (or there were no predictions at all).
        return {'exact_match': 0.0, 'f1': 0.0}

    return {
        'exact_match': 100.0 * exact_match / total,
        'f1': 100.0 * f1 / total,
    }

In [26]:
# Score only the examples selected by only='orig' (uuid without a '-'),
# reported under the "Unmutated data" label.
metrics = evaluate(eval_file, answer_dict, only='orig')
print("Unmutated data")
print("Exact Match: {}, F1: {}".format(metrics['exact_match'], metrics['f1']))

Unmutated data
Exact Match: 90.7103825136612, F1: 67.55399427530575


In [27]:
# Score only the examples selected by only='adv' (uuid containing a '-'),
# reported under the "Mutated data" label.
metrics = evaluate(eval_file, answer_dict, only='adv')
print("Mutated data")
print("Exact Match: {}, F1: {}".format(metrics['exact_match'], metrics['f1']))

Mutated data
Exact Match: 92.42125984251969, F1: 94.39447681441005


In [28]:
# Score every prediction in answer_dict, with no subset filter applied.
metrics = evaluate(eval_file, answer_dict)
print("Overall")
print("Exact Match: {}, F1: {}".format(metrics['exact_match'], metrics['f1']))

Overall
Exact Match: 91.96816208393632, F1: 87.28621588147794
