# Dynamic Coattention Network (DCN)

### Imports and Constants

In [1]:
import tensorflow as tf
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer
from tqdm import trange
import logging
import matplotlib as mpl
from collections import Counter
import matplotlib.pyplot as plt
import os, re, string

In [2]:
# Handle to TensorFlow's global flag registry. Every hyperparameter and path used
# below is read from this object.
self_FLAGS = tf.flags.FLAGS
# Drop all flags that are already registered before (re-)defining them below;
# presumably this lets the cell be re-run in the same kernel without
# duplicate-flag errors — TODO confirm.
for name in list(self_FLAGS):
    delattr(self_FLAGS, name)

# Dummy flag; presumably absorbs the `-f <connection file>` argument a Jupyter kernel is started with — TODO confirm.
tf.app.flags.DEFINE_string('f', '', 'kernel')
# --- model / optimisation hyperparameters ---
tf.app.flags.DEFINE_string("model", "DCN", "Choose which model to use baseline/DCN")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size to use during training.")
tf.app.flags.DEFINE_integer("epochs", 2, "Number of epochs to train.")
tf.app.flags.DEFINE_integer("rnn_state_size", 200, "Size of RNNs used in the model.")
tf.app.flags.DEFINE_string("figure_directory", "figs/", "Directory in which figures are stored.")
tf.app.flags.DEFINE_integer("word_vec_dim", 100, "Dimension of word vectors. Either 100 or 300")
# Both dropout flags are KEEP probabilities (1 - drop fraction), as their help texts state.
tf.app.flags.DEFINE_float("dropout", 0.6, "1-Fraction of units randomly dropped.")
tf.app.flags.DEFINE_float("dropout_encoder", 0.7, "1-Fraction of units randomly dropped in the encoder.")
tf.app.flags.DEFINE_float("l2_lambda", 0.01, "Hyperparameter for l2 regularization.")
tf.app.flags.DEFINE_float("max_gradient_norm", 3.0, "Parameter for gradient clipping.")
tf.app.flags.DEFINE_string("batch_permutation", "random",
                           "Choose whether training data is shuffled ('random'), ordered by length ('by_length'), "
                           "or kept in initial order ('None') for each epoch")
# --- learning-rate schedule (only used when decrease_lr is non-zero) ---
tf.app.flags.DEFINE_integer("decrease_lr", 0, "Whether to decrease learning rate lr over time")
tf.app.flags.DEFINE_float("lr_d_base", 0.9997, "Base for the exponential decay of lr")
tf.app.flags.DEFINE_float("lr_divider", 2, "Due to exp. decay, lr can get as small as lr/lr_divider but not smaller")
# Alternative path settings (full data set / Colab / Google Drive), kept for reference:
# tf.app.flags.DEFINE_string("data_dir", "data/squad/", "SQuAD data directory")
# tf.app.flags.DEFINE_string("data_dir", "/content/DCN-Squad-Colab/data/squad_min/", "SQuAD data directory")
# tf.app.flags.DEFINE_string("glove_dir", "/content/DCN-Squad-Colab/data/glove/", "Glove and Vocab data directory")
# tf.app.flags.DEFINE_string("checkpoint_dir", "/gdrive/My Drive/Colab Notebooks/DCN/model/", "Tensorflow Chekpoints")

# --- active path settings (relative to the working directory) ---
tf.app.flags.DEFINE_string("data_dir", "data/squad_min/", "SQuAD data directory")
tf.app.flags.DEFINE_string("glove_dir", "data/glove/", "Glove and Vocab data directory")
tf.app.flags.DEFINE_string("checkpoint_dir", "model/", "Tensorflow Chekpoints")
tf.app.flags.DEFINE_string("log_dir", "logs/", "Tensorboard Logs")

# Fixed sequence lengths (in tokens): questions and context paragraphs are truncated
# or padded to these lengths, and the placeholders are shaped accordingly.
self_max_q_length = 30
self_max_c_length = 400

### Load and Preprocess

In [3]:
def span_to_y(y, max_length=None):
    """One-hot encode answer spans.

    Args:
        y: integer array of spans. Normally 2-D, one row per sample holding
            ``(start_id, end_id)``. A single 1-D span ``[start_id, end_id]``
            (what ``np.loadtxt`` returns for a one-line file) and an empty
            array are accepted as well.
        max_length: length of a (padded) context paragraph and therefore of
            every one-hot vector. Defaults to ``self_max_c_length``.

    Returns:
        Tuple ``(S, E)`` of int32 arrays of shape ``(n_samples, max_length)``.
        Row ``i`` of ``S`` / ``E`` is the one-hot encoding of the start / end
        index of sample ``i``. If either index of a sample is
        ``>= max_length`` both of its rows are left all-zero.
    """
    if max_length is None:
        max_length = self_max_c_length

    # Promote a single 1-D span to a one-row matrix so column indexing works.
    spans = np.atleast_2d(np.asarray(y))
    n_samples = 0 if spans.size == 0 else spans.shape[0]

    S = np.zeros((n_samples, max_length), dtype=np.int32)
    E = np.zeros((n_samples, max_length), dtype=np.int32)
    if n_samples == 0:
        return S, E

    start_ids, end_ids = spans[:, 0], spans[:, 1]
    # A sample is only encoded when BOTH indices fit into the paragraph length.
    rows = np.flatnonzero((start_ids < max_length) & (end_ids < max_length))
    S[rows, start_ids[rows]] = 1
    E[rows, end_ids[rows]] = 1
    return S, E

In [4]:
def read_and_pad(filename, length, pad_value):
    """Read a file of whitespace-separated word ids and bring every row to a fixed length.

    Each line of @filename may hold a different number of ids. Rows longer than
    @length are truncated, shorter rows are filled up with @pad_value.

    Returns a tuple (ids, mask): `ids` is an int32 array with one row of @length
    ids per line, `mask` is a boolean array of the same layout that is True for
    real ids and False for padding."""
    padded_rows, mask_rows = [], []
    with open(filename, 'r') as handle:
        for raw_line in handle:
            # TODO: Drop samples with more than @length ids instead of truncating them.
            # Note: when a context is dropped, its question and answer span must be
            # dropped too (and vice versa). Low priority: ~0.1% garbage data is tolerable.
            tokens = raw_line.split()[:length]
            n_pad = length - len(tokens)
            padded_rows.append(tokens + [pad_value] * n_pad)
            mask_rows.append([True] * len(tokens) + [False] * n_pad)
    return np.array(padded_rows, dtype=np.int32), np.array(mask_rows)

In [5]:
# Data preparation (originally the body of `load_and_preprocess_data(self)`):
# read the vocabulary, the trimmed GloVe word-embedding matrix and the SQuAD
# question / context / span files, and bring them into fixed-size numpy arrays.

logging.info("Data preparation. This can take some seconds...")
# load vocab: one token per line. Only the line terminator is stripped, so the last
# token stays intact even when the file does not end with a newline.
with open(self_FLAGS.glove_dir + "vocab.dat", "r") as f:
    self_vocab = [line.rstrip("\n") for line in f]
# load word embedding
if self_FLAGS.word_vec_dim not in (100, 300):
    raise ValueError("word_vec_dim can be either 100 or 300")
self_WordEmbeddingMatrix = np.load(
    self_FLAGS.glove_dir + "glove.trimmed.{}.npz".format(self_FLAGS.word_vec_dim))['glove']
logging.debug("WordEmbeddingMatrix.shape={}".format(self_WordEmbeddingMatrix.shape))
# Index of the row appended below; it is used as the padding id for all sequences.
null_wordvec_index = self_WordEmbeddingMatrix.shape[0]
# append a zero vector to WordEmbeddingMatrix, which shall be used as padding value
self_WordEmbeddingMatrix = np.vstack((self_WordEmbeddingMatrix, np.zeros(self_FLAGS.word_vec_dim)))
self_WordEmbeddingMatrix = self_WordEmbeddingMatrix.astype(np.float32)
logging.debug("WordEmbeddingMatrix.shape after appending zero vector={}".format(self_WordEmbeddingMatrix.shape))

# load labels: one-hot start / end vectors for the training and validation sets
self_yS, self_yE = span_to_y(np.loadtxt(self_FLAGS.data_dir + "train.span", dtype=np.int32))
self_yvalS, self_yvalE = span_to_y(np.loadtxt(self_FLAGS.data_dir + "val.span", dtype=np.int32))

# load contexts and questions, padded with the index of the zero word vector
self_X_c, self_X_c_mask = read_and_pad(self_FLAGS.data_dir + "train.ids.context", self_max_c_length,
                                       null_wordvec_index)
self_Xval_c, self_Xval_c_mask = read_and_pad(self_FLAGS.data_dir + "val.ids.context", self_max_c_length,
                                             null_wordvec_index)
self_X_q, self_X_q_mask = read_and_pad(self_FLAGS.data_dir + "train.ids.question", self_max_q_length,
                                       null_wordvec_index)
self_Xval_q, self_Xval_q_mask = read_and_pad(self_FLAGS.data_dir + "val.ids.question", self_max_q_length,
                                             null_wordvec_index)

logging.info("End data preparation.")

### MODEL: Placeholders

In [6]:
# Graph inputs. The leading dimension is None so that any batch size can be fed.
# Question word ids and the matching boolean mask (mask semantics as produced by
# read_and_pad: True for real tokens, False for padding).
self_q_input_placeholder = tf.placeholder(tf.int32, (None, self_max_q_length), name="q_input_ph")
self_q_mask_placeholder = tf.placeholder(dtype=tf.bool, shape=(None, self_max_q_length),
                                         name="q_mask_placeholder")
# Context paragraph word ids and the matching boolean mask.
self_c_input_placeholder = tf.placeholder(tf.int32, (None, self_max_c_length), name="c_input_ph")
self_c_mask_placeholder = tf.placeholder(dtype=tf.bool, shape=(None, self_max_c_length),
                                         name="c_mask_placeholder")
# One-hot answer start (S) and end (E) labels over the context positions.
self_labels_placeholderS = tf.placeholder(tf.int32, (None, self_max_c_length), name="label_phS")
self_labels_placeholderE = tf.placeholder(tf.int32, (None, self_max_c_length), name="label_phE")

# Scalar keep probability used by the dropout layers (1.0 disables dropout).
self_dropout_placeholder = tf.placeholder(tf.float32, name="dropout_ph")

### MODEL: Encode / Decode Network

In [7]:
def encode(apply_dropout=False):
    """Coattention context encoder as introduced in https://arxiv.org/abs/1611.01604
    Uses GRUs instead of LSTMs.

    Reads the question / context ids and masks from the module-level placeholders.

    Args:
        apply_dropout: if True, wrap the encoder GRU cells in dropout. The keep
            probability is max(FLAGS.dropout_encoder, dropout placeholder).

    Returns:
        U: concatenated forward / backward outputs of the final bidirectional GRU.
        The last dimension is 2 * rnn_state_size; the time dimension is
        max_c_length + 1 because a sentinel vector is appended to the context.
    """

    # Each word is represented by a glove word vector (https://nlp.stanford.edu/projects/glove/)
    # The embedding matrix is stored as a non-trainable variable.
    self_WEM = tf.get_variable(name="WordEmbeddingMatrix", initializer=tf.constant(self_WordEmbeddingMatrix),
                               trainable=False)

    # map word index (integer) to word vector (word_vec_dim dimensional float vector)
    self_embedded_q = tf.nn.embedding_lookup(params=self_WEM, ids=self_q_input_placeholder)
    self_embedded_c = tf.nn.embedding_lookup(params=self_WEM, ids=self_c_input_placeholder)

    rnn_size = self_FLAGS.rnn_state_size
    with tf.variable_scope("rnn", reuse=None):
        cell = tf.contrib.rnn.GRUCell(rnn_size)
        if apply_dropout:
            # TODO add separate dropout placeholder for encoding and decoding. Right now the maximum sets
            # enc_keep_prob to 1 during prediction.
            enc_keep_prob = tf.maximum(tf.constant(self_FLAGS.dropout_encoder), self_dropout_placeholder)
            cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=enc_keep_prob)
        # number of real (non-padding) tokens per question = number of True mask entries
        q_sequence_length = tf.reduce_sum(tf.cast(self_q_mask_placeholder, tf.int32), axis=1)
        q_sequence_length = tf.reshape(q_sequence_length, [-1, ])

        q_outputs, q_final_state = tf.nn.dynamic_rnn(cell=cell, inputs=self_embedded_q,
                                                     sequence_length=q_sequence_length, dtype=tf.float32,
                                                     time_major=False)

    # Append a trainable sentinel vector to every question along the time axis.
    Qprime = q_outputs
    q_senti = tf.get_variable("q_senti0", (rnn_size,), dtype=tf.float32)
    # tile the sentinel once per batch element, then reshape to (batch, 1, rnn_size)
    q_senti = tf.tile(q_senti, tf.shape(Qprime)[0:1])
    q_senti = tf.reshape(q_senti, (-1, 1, tf.shape(Qprime)[2]))
    Qprime = tf.concat([Qprime, q_senti], axis=1)
    # -> (batch, rnn_size, max_q_length + 1)
    Qprime = tf.transpose(Qprime, [0, 2, 1], name="Qprime")

    # add tanh layer to go from Qprime to Q
    # WQ mixes along the (question position + sentinel) axis; bQ is added per (feature, position).
    WQ = tf.get_variable("WQ", (self_max_q_length + 1, self_max_q_length + 1),
                         initializer=tf.contrib.layers.xavier_initializer())
    bQ = tf.get_variable("bQ_Bias", shape=(rnn_size, self_max_q_length + 1),
                         initializer=tf.contrib.layers.xavier_initializer())
    Q = tf.einsum('ijk,kl->ijl', Qprime, WQ)
    Q = tf.nn.tanh(Q + bQ, name="Q")

    with tf.variable_scope("rnn", reuse=True):
        # number of real (non-padding) tokens per context paragraph
        c_sequence_length = tf.reduce_sum(tf.cast(self_c_mask_placeholder, tf.int32), axis=1)
        c_sequence_length = tf.reshape(c_sequence_length, [-1, ])
        # use the same RNN cell as for the question input
        c_outputs, c_final_state = tf.nn.dynamic_rnn(cell=cell, inputs=self_embedded_c,
                                                     sequence_length=c_sequence_length,
                                                     dtype=tf.float32,
                                                     time_major=False)

    # Append a trainable sentinel vector to every context, analogous to the question.
    D = c_outputs
    c_senti = tf.get_variable("c_senti0", (rnn_size,), dtype=tf.float32)
    c_senti = tf.tile(c_senti, tf.shape(D)[0:1])
    c_senti = tf.reshape(c_senti, (-1, 1, tf.shape(D)[2]))
    D = tf.concat([D, c_senti], axis=1)
    # -> (batch, rnn_size, max_c_length + 1)
    D = tf.transpose(D, [0, 2, 1])
    # Affinity between every context position and every question position:
    # L[b, k, l] = sum_j D[b, j, k] * Q[b, j, l]
    L = tf.einsum('ijk,ijl->ikl', D, Q)
    # Attention weights; tf.nn.softmax normalises over the last axis.
    # NOTE(review): padded positions are not masked before the softmax — confirm this is intended.
    AQ = tf.nn.softmax(L)
    AD = tf.nn.softmax(tf.transpose(L, [0, 2, 1]))
    # Attention summaries: CQ summarises the context per question position,
    # CD1 / CD2 carry Q and CQ back onto the context positions.
    CQ = tf.matmul(D, AQ)
    CD1 = tf.matmul(Q, AD)
    CD2 = tf.matmul(CQ, AD)
    CD = tf.concat([CD1, CD2], axis=1)
    # Stack the summaries with the context encoding along the feature axis (3 * rnn_size
    # features), then switch back to (batch, time, features) for the RNN.
    CDprime = tf.concat([CD, D], axis=1)
    CDprime = tf.transpose(CDprime, [0, 2, 1])

    with tf.variable_scope("u_rnn", reuse=False):
        cell_fw = tf.contrib.rnn.GRUCell(rnn_size)
        cell_bw = tf.contrib.rnn.GRUCell(rnn_size)
        if apply_dropout:
            # enc_keep_prob was defined above under the same apply_dropout condition
            cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw, input_keep_prob=enc_keep_prob)
            cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw, input_keep_prob=enc_keep_prob)

        # NOTE(review): sequence_length counts real context tokens only, so the appended
        # sentinel position lies beyond it — confirm this is intended.
        (cc_fw, cc_bw), _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs=CDprime,
                                                            sequence_length=c_sequence_length,
                                                            dtype=tf.float32)

    # concatenate forward and backward outputs along the feature axis
    U = tf.concat([cc_fw, cc_bw], axis=2)
    logging.debug("U={}".format(U))
    return U

In [8]:
def dp_decode_HMN(U, pool_size=4, apply_dropout=True, cumulative_loss=True, apply_l2_reg=False):
    """ input: coattention_context U. tensor of shape (batch_size, context_length, arbitrary)
    Implementation of dynamic pointer decoder proposed by Xiong et al. ( https://arxiv.org/abs/1611.01604).

    Some of the implementation details such as the way us is obtained from U via tf.gather_nd() are explored on toy
    data in Experimentation_Notebooks/toy_data_examples_for_tile_map_fn_gather_nd_etc.ipynb

    Args:
        U: coattention encoding. The variable shapes below (WD is (5 * dim, dim) and is applied to
            [h, us, ue] with h of width dim) require the last dimension of U to be 2 * rnn_state_size.
        pool_size: number of maxout pieces in every layer of the highway maxout network (HMN).
        apply_dropout: if True, apply dropout inside the HMN with the keep probability fed through
            the dropout placeholder.
        cumulative_loss: if True, sum the cross entropy losses of all decoder iterations; otherwise
            only the last iteration contributes.
        apply_l2_reg: if True, add FLAGS.l2_lambda times the l2 loss of all trainable variables whose
            name does not contain "Bias".

    Returns:
        (i_start, i_end, loss): argmax start / end positions of the last decoder iteration and the
        scalar training loss."""

    def HMN_func(dim, ps):  # ps=pool size, HMN = highway maxout network
        def func(ut, h, us, ue):
            # ut: encoding of one context position for the whole batch; h: decoder state;
            # us / ue: encodings at the current start / end estimates.
            h_us_ue = tf.concat([h, us, ue], axis=1)
            WD = tf.get_variable(name="WD", shape=(5 * dim, dim), dtype='float32',
                                 initializer=xavier_initializer())
            r = tf.nn.tanh(tf.matmul(h_us_ue, WD))
            ut_r = tf.concat([ut, r], axis=1)
            if apply_dropout:
                ut_r = tf.nn.dropout(ut_r, keep_prob=self_dropout_placeholder)
            # first maxout layer: linear map to (batch, dim, ps), then max over the ps pieces
            W1 = tf.get_variable(name="W1", shape=(3 * dim, dim, ps), dtype='float32',
                                 initializer=xavier_initializer())
            b1 = tf.get_variable(name="b1_Bias", shape=(dim, ps), dtype='float32',
                                 initializer=tf.zeros_initializer())
            mt1 = tf.einsum('bt,top->bop', ut_r, W1) + b1
            mt1 = tf.reduce_max(mt1, axis=2)
            if apply_dropout:
                mt1 = tf.nn.dropout(mt1, self_dropout_placeholder)
            # second maxout layer
            W2 = tf.get_variable(name="W2", shape=(dim, dim, ps), dtype='float32',
                                 initializer=xavier_initializer())
            b2 = tf.get_variable(name="b2_Bias", shape=(dim, ps), dtype='float32',
                                 initializer=tf.zeros_initializer())
            mt2 = tf.einsum('bi,ijp->bjp', mt1, W2) + b2
            mt2 = tf.reduce_max(mt2, axis=2)
            # highway connection: the output layer sees both maxout layers
            mt12 = tf.concat([mt1, mt2], axis=1)
            if apply_dropout:
                mt12 = tf.nn.dropout(mt12, keep_prob=self_dropout_placeholder)
            # output maxout layer: one score per batch element
            W3 = tf.get_variable(name="W3", shape=(2 * dim, 1, ps), dtype='float32',
                                 initializer=xavier_initializer())
            b3 = tf.get_variable(name="b3_Bias", shape=(1, ps), dtype='float32', initializer=tf.zeros_initializer())
            hmn = tf.einsum('bi,ijp->bjp', mt12, W3) + b3
            hmn = tf.reduce_max(hmn, axis=2)
            hmn = tf.reshape(hmn, [-1])
            return hmn

        return func

    # Extend the context mask and both label tensors by one zero column; presumably this
    # accounts for the extra sentinel position in U — TODO confirm.
    float_mask = tf.cast(self_c_mask_placeholder, dtype=tf.float32)
    neg = tf.constant([0], dtype=tf.float32)
    neg = tf.tile(neg, [tf.shape(float_mask)[0]])
    neg = tf.reshape(neg, (tf.shape(float_mask)[0], 1))
    float_mask = tf.concat([float_mask, neg], axis=1)
    labels_S = tf.concat([self_labels_placeholderS, tf.cast(neg, tf.int32)], axis=1)
    labels_E = tf.concat([self_labels_placeholderE, tf.cast(neg, tf.int32)], axis=1)
    dim = self_FLAGS.rnn_state_size

    # initialize us and ue as first word in context
    i_start = tf.zeros(shape=(tf.shape(U)[0],), dtype='int32')
    i_end = tf.zeros(shape=(tf.shape(U)[0],), dtype='int32')
    # (batch index, position) pairs select one encoding vector per batch element
    idx = tf.range(0, tf.shape(U)[0], 1)
    s_idx = tf.stack([idx, i_start], axis=1)
    e_idx = tf.stack([idx, i_end], axis=1)
    us = tf.gather_nd(U, s_idx)
    ue = tf.gather_nd(U, e_idx)

    # Both scorers share the same code; they get separate weights because they are called
    # inside different variable scopes ("alpha_HMN" / "beta_HMN") below.
    HMN_alpha = HMN_func(dim, pool_size)
    HMN_beta = HMN_func(dim, pool_size)

    alphas, betas = [], []
    h = tf.zeros(shape=(tf.shape(U)[0], dim), dtype='float32', name="h_dpd")  # initial hidden state of RNN
    # time-major copy of U so that tf.map_fn iterates over context positions
    U_transpose = tf.transpose(U, [1, 0, 2])

    with tf.variable_scope("dpd_RNN"):
        cell = tf.contrib.rnn.GRUCell(dim)
        for time_step in range(3):  # number of time steps can be considered as a hyper parameter
            if time_step >= 1:
                # share the GRU weights across decoder iterations
                tf.get_variable_scope().reuse_variables()

            # update the decoder state from the current start / end estimates
            us_ue = tf.concat([us, ue], axis=1)
            _, h = cell(inputs=us_ue, state=h)

            with tf.variable_scope("alpha_HMN"):
                if time_step >= 1:
                    tf.get_variable_scope().reuse_variables()
                # start score for every context position
                alpha = tf.map_fn(lambda ut: HMN_alpha(ut, h, us, ue), U_transpose, dtype=tf.float32)
                # NOTE(review): masking by multiplication sets padded scores to 0 (not to a large
                # negative value), so they still take part in argmax and softmax — confirm intended.
                alpha = tf.transpose(alpha, [1, 0]) * float_mask

            # new start estimate; beta below is already computed with the updated us
            i_start = tf.argmax(alpha, 1)
            idx = tf.range(0, tf.shape(U)[0], 1)
            s_idx = tf.stack([idx, tf.cast(i_start, 'int32')], axis=1)
            us = tf.gather_nd(U, s_idx)

            with tf.variable_scope("beta_HMN"):
                if time_step >= 1:
                    tf.get_variable_scope().reuse_variables()
                # end score for every context position
                beta = tf.map_fn(lambda ut: HMN_beta(ut, h, us, ue), U_transpose, dtype=tf.float32)
                beta = tf.transpose(beta, [1, 0]) * float_mask

            # new end estimate
            i_end = tf.argmax(beta, 1)
            e_idx = tf.stack([idx, tf.cast(i_end, 'int32')], axis=1)
            ue = tf.gather_nd(U, e_idx)

            alphas.append(alpha)
            betas.append(beta)

    if cumulative_loss:
        # sum of the batch-mean cross entropies of every decoder iteration
        losses_alpha = [tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_S, logits=a) for a in
                        alphas]
        losses_alpha = [tf.reduce_mean(x) for x in losses_alpha]
        losses_beta = [tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_E, logits=b) for b in
                       betas]
        losses_beta = [tf.reduce_mean(x) for x in losses_beta]

        loss = tf.reduce_sum([losses_alpha, losses_beta])
    else:
        # only the scores of the last decoder iteration (alpha / beta still hold them) are used
        cross_entropy_start = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_S, logits=alpha,
                                                                      name="cross_entropy_start")
        cross_entropy_end = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_E, logits=beta,
                                                                    name="cross_entropy_end")
        loss = tf.reduce_mean(cross_entropy_start) + tf.reduce_mean(cross_entropy_end)

    if apply_l2_reg:
        # variables with "Bias" in their name are excluded from the regularisation
        loss_l2 = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if "Bias" not in v.name])
        loss += loss_l2 * self_FLAGS.l2_lambda

    return i_start, i_end, loss

In [9]:
# Build the model graph: the encoder produces the coattention encoding, the dynamic
# pointer decoder turns it into start / end predictions and the training loss.
coattention_context = encode(apply_dropout=True)
self_predictionS, self_predictionE, self_loss = dp_decode_HMN(coattention_context, apply_dropout=True, apply_l2_reg=False)

In [10]:
# Optimisation: Adam with optional exponential learning-rate decay and
# global-norm gradient clipping.
step_adam = tf.Variable(0, trainable=False)  # incremented once per applied update
lr = tf.constant(self_FLAGS.learning_rate)
if not self_FLAGS.decrease_lr:
    self_optimizer = tf.train.AdamOptimizer(lr)
else:
    # rate = lr * lr_d_base ** step. With the default base 0.9997 the rate is roughly
    # halved after ~2300 steps.
    rate_adam = tf.train.exponential_decay(lr, step_adam, 1, self_FLAGS.lr_d_base)
    # never let the rate drop below learning_rate / lr_divider
    rate_adam = tf.maximum(rate_adam, tf.constant(self_FLAGS.learning_rate / self_FLAGS.lr_divider))
    self_optimizer = tf.train.AdamOptimizer(rate_adam)

# Split the (gradient, variable) pairs so that the gradients can be clipped jointly.
grads_and_vars = self_optimizer.compute_gradients(self_loss)
variables = [pair[1] for pair in grads_and_vars]
gradients = [pair[0] for pair in grads_and_vars]

gradients = tf.clip_by_global_norm(gradients, clip_norm=self_FLAGS.max_gradient_norm)[0]
# Note: this is the norm of the gradients AFTER clipping.
self_global_grad_norm = tf.global_norm(gradients)
grads_and_vars = list(zip(gradients, variables))

train_op = self_optimizer.apply_gradients(grads_and_vars, global_step=step_adam)

In [11]:
def get_feed_dict(batch_xc, batch_xc_mask, batch_xq, batch_xq_mask, batch_yS, batch_yE, keep_prob):
    """Map one batch of data onto the model's placeholders.

    Returns a dict that can be passed as `feed_dict` to `session.run`: context ids and mask,
    question ids and mask, start / end labels and the dropout keep probability."""
    placeholders = (self_c_input_placeholder,
                    self_c_mask_placeholder,
                    self_q_input_placeholder,
                    self_q_mask_placeholder,
                    self_labels_placeholderS,
                    self_labels_placeholderE,
                    self_dropout_placeholder)
    values = (batch_xc, batch_xc_mask, batch_xq, batch_xq_mask, batch_yS, batch_yE, keep_prob)
    return dict(zip(placeholders, values))

In [12]:
def squad_normalize_answer(s):
    """Normalise an answer string for comparison: lower-case it, strip punctuation, drop the
    articles a/an/the and collapse all whitespace to single spaces.
    Logic taken from the SQuAD Leaderboard: https://rajpurkar.github.io/SQuAD-explorer/  """
    punctuation = set(string.punctuation)

    lowered = s.lower()
    without_punctuation = ''.join(ch for ch in lowered if ch not in punctuation)
    # articles are replaced by a blank so that neighbouring words do not merge
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)
    return ' '.join(without_articles.split())

def squad_f1_score(prediction, ground_truth):
    """Token-level F1 between a predicted and a ground-truth answer string, computed on the
    normalised tokens. Returns 0 when the two share no token.
    Logic taken from the SQuAD Leaderboard: https://rajpurkar.github.io/SQuAD-explorer/"""
    pred_tokens = squad_normalize_answer(prediction).split()
    truth_tokens = squad_normalize_answer(ground_truth).split()
    # multiset intersection: every token counts as often as it occurs in both answers
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if not overlap:
        return 0
    precision = float(overlap) / len(pred_tokens)
    recall = float(overlap) / len(truth_tokens)
    return (2 * precision * recall) / (precision + recall)

def squad_exact_match_score(prediction, ground_truth):
    """True iff prediction and ground truth are identical after normalisation.
    Logic taken from the SQuAD Leaderboard: https://rajpurkar.github.io/SQuAD-explorer/"""
    normalized_prediction = squad_normalize_answer(prediction)
    normalized_truth = squad_normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def get_f1(yS, yE, ypS, ypE):
    """My own, more strict f1 metric: position-level F1 between true and predicted spans.

    Args:
        yS, yE: one-hot start / end labels, one row per sample.
        ypS, ypE: predicted start / end index per sample. A prediction with start > end is
            interpreted as the flipped span.

    Returns:
        Mean F1 over all samples as a float; 0.0 for an empty batch. A sample without any
        overlap between prediction and truth contributes 0."""
    n_samples = len(yS)
    if n_samples == 0:
        # nothing to average over (the plain mean would divide by zero)
        return 0.0

    f1_tot = 0.0
    for i in range(n_samples):
        # indicator vector of the true span, as long as the label row itself
        y = np.zeros(len(yS[i]))
        s = np.argmax(yS[i])
        e = np.argmax(yE[i])
        y[s:e + 1] = 1

        # indicator vector of the predicted span
        yp = np.zeros_like(y)
        yp[ypS[i]:ypE[i] + 1] = 1
        yp[ypE[i]:ypS[i] + 1] = 1  # allow flipping between start and end

        n_true_pos = np.sum(y * yp)
        if n_true_pos != 0:
            precision = 1.0 * n_true_pos / np.sum(yp)
            recall = 1.0 * n_true_pos / np.sum(y)
            f1_tot += (2 * precision * recall) / (precision + recall)
    return f1_tot / n_samples

def get_exact_match(yS, yE, ypS, ypE):
    """My own, more strict EM metric: fraction of samples whose predicted span indices equal
    the true span indices.

    Args:
        yS, yE: one-hot start / end labels, one row per sample.
        ypS, ypE: predicted start / end index per sample. A prediction with start > end is
            interpreted as the flipped span.

    Returns:
        Fraction of exact matches as a float; 0.0 for an empty batch."""
    n_samples = len(yS)
    if n_samples == 0:
        # nothing to average over (the plain fraction would divide by zero)
        return 0.0

    count = 0
    for i in range(n_samples):
        s, e = np.argmax(yS[i]), np.argmax(yE[i])
        sp, ep = ypS[i], ypE[i]
        if sp > ep:
            sp, ep = ep, sp  # allow flipping between start and end
        if s == sp and e == ep:
            count += 1
    return count / float(n_samples)

def index_list_to_string(index_list):
    """Helper function. Look up every word index in the vocabulary and join the words with
    single blanks."""
    return ' '.join(self_vocab[word_id] for word_id in index_list)

def get_exact_match_from_tokens(yS, yE, ypS, ypE, batch_Xc):
    """Exact-match rate computed on the tokens behind the span indices instead of on the
    indices themselves. This is a bit more forgiving and it is the metric applied on the
    SQuAD leaderboard."""
    n_matches = 0
    for i, start_label in enumerate(yS):
        true_start, true_end = np.argmax(start_label), np.argmax(yE[i])
        pred_start, pred_end = ypS[i], ypE[i]
        if pred_start > pred_end:
            # allow flipping between start and end
            pred_start, pred_end = pred_end, pred_start
        context_ids = batch_Xc[i]
        ground_truth = index_list_to_string(context_ids[true_start:true_end + 1])
        prediction = index_list_to_string(context_ids[pred_start:pred_end + 1])
        n_matches += squad_exact_match_score(prediction, ground_truth)
    return n_matches / float(len(yS))

def get_f1_from_tokens(yS, yE, ypS, ypE, batch_Xc):
    """Mean F1 computed on the tokens behind the span indices instead of on the indices
    themselves. This is a bit more forgiving and it is the metric applied on the SQuAD
    leaderboard."""
    f1_sum = 0
    for i, start_label in enumerate(yS):
        true_start, true_end = np.argmax(start_label), np.argmax(yE[i])
        pred_start, pred_end = ypS[i], ypE[i]
        if pred_start > pred_end:
            # allow flipping between start and end
            pred_start, pred_end = pred_end, pred_start
        context_ids = batch_Xc[i]
        ground_truth = index_list_to_string(context_ids[true_start:true_end + 1])
        prediction = index_list_to_string(context_ids[pred_start:pred_end + 1])
        f1_sum += squad_f1_score(prediction, ground_truth)
    return f1_sum / float(len(yS))

def plot_metrics(prefix, index_epoch, losses, val_losses, EMs, val_Ems, F1s, val_F1s, grad_norms):
    """Save training curves as png files into FLAGS.figure_directory, each file name starting
    with @prefix.

    Loss, EM and F1 are plotted per training step together with their validation values
    (val_losses, val_Ems and val_F1s must be lists); the gradient norm is plotted for
    training only. @index_epoch is the number of epochs the training values span."""
    n_points = len(losses)
    # x position of every training data point, measured in epochs
    train_x = np.arange(n_points, dtype=np.float32) * index_epoch / float(n_points)
    val_x = list(range(index_epoch + 1))

    curves = (
        ("loss", "losses_over_time.png", losses, val_losses),
        ("EM", "EMs_over_time.png", EMs, val_Ems),
        ("F1", "f1s_over_time.png", F1s, val_F1s),
    )
    for y_label, file_name, train_values, val_values in curves:
        plt.plot(train_x, train_values, label="training")
        # The validation curve starts with the first TRAINING value, just so that there is
        # a nice value for epoch=0 in the plot.
        plt.plot(val_x, [train_values[0]] + val_values, label="validation", marker="x", ms=15)
        plt.xlabel("epoch")
        plt.ylabel(y_label)
        plt.legend()
        plt.savefig(self_FLAGS.figure_directory + prefix + file_name)
        plt.close()

    plt.plot(train_x, grad_norms)
    plt.xlabel("epoch")
    plt.ylabel("gradient_norm")
    plt.savefig(self_FLAGS.figure_directory + prefix + "training_grad_norms_over_time.png")
    plt.close()

In [13]:
class batch(object):
    """Serve mini-batches from the training or the validation arrays.

    The twelve arrays passed to the constructor are stored unchanged. Each of them is
    indexed with an integer index array along its first axis, so they are expected to be
    NumPy arrays with one row per sample.

    Typical use: call ``initialize_batch_processing`` with the number of samples of the
    split that is about to be iterated, then call ``next_batch`` repeatedly.
    """

    def __init__(self, X_c, X_c_mask, X_q, X_q_mask, yS, yE, Xval_c, Xval_c_mask, Xval_q, Xval_q_mask, yvalS, yvalE):
        # training split
        self.X_c = X_c
        self.X_c_mask = X_c_mask
        self.X_q = X_q
        self.X_q_mask = X_q_mask
        self.yS = yS
        self.yE = yE
        # validation split
        self.Xval_c = Xval_c
        self.Xval_c_mask = Xval_c_mask
        self.Xval_q = Xval_q
        self.Xval_q_mask = Xval_q_mask
        self.yvalS = yvalS
        self.yvalE = yvalE
        self.batch_index = 0  # offset of the next batch inside batch_permutation
        self.max_batch_index = -1  # samples per epoch; negative means "not set yet"
        self.batch_permutation = []  # order in which the samples are visited

    def initialize_batch_processing(self, permutation='None', n_samples=None, val=False):
        """Reset the batch cursor and choose the sample order for the coming epoch.

        Args:
            permutation: 'by_length' visits samples with short context paragraphs first,
                'random' shuffles the first ``n_samples`` indices, 'None' (or None) keeps
                them in their stored order.
            n_samples: Number of samples per epoch. When omitted, the value of the previous
                call is kept; if there was none, the size of the selected split is used.
            val: Set to True when the validation split is going to be iterated, so that
                'by_length' sorts by the validation mask instead of the training mask.

        Raises:
            ValueError: If ``permutation`` is not one of the supported values.
        """
        self.batch_index = 0
        if n_samples is not None:
            self.max_batch_index = n_samples
        elif self.max_batch_index < 0:
            # never sized before: fall back to the whole split instead of failing on -1
            self.max_batch_index = len(self.Xval_c if val else self.X_c)

        if permutation == 'by_length':
            # The lengths must come from the split that is iterated, otherwise the
            # resulting indices can point past the end of the other split's arrays.
            context_mask = self.Xval_c_mask if val else self.X_c_mask
            # sum over True/False gives number of words in each sample
            length_of_each_context_paragraph = np.sum(context_mask, axis=1)
            # permutation of data is chosen, such that the algorithm sees short context_paragraphs first
            self.batch_permutation = np.argsort(length_of_each_context_paragraph)
        elif permutation == 'random':
            self.batch_permutation = np.random.permutation(self.max_batch_index)  # random initial permutation
        elif permutation == 'None' or permutation is None:  # no permutation
            self.batch_permutation = np.arange(self.max_batch_index)  # initial permutation = identity
        else:
            raise ValueError("permutation must be 'by_length', 'random' or 'None'")

    def next_batch(self, batch_size, permutation_after_epoch='None', val=False):
        """Return the next (at most) ``batch_size`` samples of the selected split.

        When the cursor has reached the end of the epoch, the sample order is first
        re-initialised with ``permutation_after_epoch`` and the cursor restarts at 0.
        The last batch of an epoch can be shorter than ``batch_size``.

        Args:
            batch_size: Number of samples to return.
            permutation_after_epoch: Ordering used when a new epoch has to be started.
            val: If True, read from the validation arrays, otherwise from the training arrays.

        Returns:
            Tuple ``(X_c, X_c_mask, X_q, X_q_mask, yS, yE)`` restricted to the batch.
        """
        if self.batch_index >= self.max_batch_index:
            # we went through one epoch. reset batch_index and initialize batch_permutation
            self.initialize_batch_processing(permutation=permutation_after_epoch, val=val)

        picked = self.batch_permutation[self.batch_index:self.batch_index + batch_size]

        if val:
            arrays = (self.Xval_c, self.Xval_c_mask, self.Xval_q, self.Xval_q_mask, self.yvalS, self.yvalE)
        else:
            arrays = (self.X_c, self.X_c_mask, self.X_q, self.X_q_mask, self.yS, self.yE)
        result = tuple(array[picked] for array in arrays)

        self.batch_index += batch_size
        return result

## Train Operation

In [14]:
# Train the model for `self_FLAGS.epochs` epochs. After every epoch a checkpoint is saved,
# the validation set is evaluated and the metric plots are refreshed.

with tf.name_scope('Performance'):
    # Placeholders through which the Python-side metric averages are fed into the summaries.
    tb_f1 = tf.placeholder(tf.float32, shape=None, name='f1_summary')
    tb_em = tf.placeholder(tf.float32, shape=None, name='em_summary')
    tb_loss = tf.placeholder(tf.float32, shape=None, name='loss_summary')

    tf.summary.scalar('F1', tb_f1)
    tf.summary.scalar('EM', tb_em)
    tf.summary.scalar('Loss', tb_loss)
    # Step counter for the summary writer. The second positional parameter of tf.Variable
    # is `trainable`, not the dtype, so both are passed by keyword.
    tb_step = tf.Variable(0, trainable=False, dtype=tf.int32)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

epochs = self_FLAGS.epochs
batch_size = self_FLAGS.batch_size
n_samples = len(self_yS)

ba = batch(self_X_c, self_X_c_mask, self_X_q, self_X_q_mask, self_yS, self_yE,
           self_Xval_c, self_Xval_c_mask, self_Xval_q, self_Xval_q_mask, self_yvalS, self_yvalE)

global_losses, global_EMs, global_f1s, global_grad_norms = [], [], [], []  # global means "over several epochs"
EMs_val, F1s_val, loss_val = [], [], []  # exact_match- and F1-metrics as well as loss on the validation data
SQ_global_EMs, SQ_global_f1s, SQ_EMs_val, SQ_F1s_val = [], [], [], []  # corresponding squad metrics

########### TensorBoard logging #####################
writer = tf.summary.FileWriter(self_FLAGS.log_dir, sess.graph)
merged_op = tf.summary.merge_all()
tb_step_op = tb_step.assign_add(1)

# Build the Saver once: every tf.train.Saver() call adds a new set of save/restore ops to
# the graph, so constructing it inside the epoch loop makes the graph grow with each epoch.
saver = tf.train.Saver()

for index_epoch in range(1, epochs + 1):
    progbar = trange(n_samples // batch_size)
    losses, ems, f1s, grad_norms = [], [], [], []
    sq_ems, sq_f1s = [], []
    ba.initialize_batch_processing(permutation=self_FLAGS.batch_permutation, n_samples=n_samples)

    ############### train for one epoch ###############
    for _ in progbar:
        batch_xc, batch_xc_mask, batch_xq, batch_xq_mask, batch_yS, batch_yE = ba.next_batch(
            batch_size=batch_size, permutation_after_epoch=self_FLAGS.batch_permutation)
        feed_dict = get_feed_dict(batch_xc, batch_xc_mask, batch_xq, batch_xq_mask, batch_yS, batch_yE,
                                  self_FLAGS.dropout)
        _, current_loss, predictionS, predictionE, grad_norm, curr_lr = sess.run(
            [train_op, self_loss, self_predictionS, self_predictionE, self_global_grad_norm,
             self_optimizer._lr],
            feed_dict=feed_dict)
        ems.append(get_exact_match(batch_yS, batch_yE, predictionS, predictionE))
        f1s.append(get_f1(batch_yS, batch_yE, predictionS, predictionE))
        sq_ems.append(get_exact_match_from_tokens(batch_yS, batch_yE, predictionS, predictionE, batch_xc))
        sq_f1s.append(get_f1_from_tokens(batch_yS, batch_yE, predictionS, predictionE, batch_xc))
        losses.append(current_loss)
        grad_norms.append(grad_norm)

        # Every 20 batches: record the running averages, log them to TensorBoard and
        # start a fresh averaging window.
        if len(losses) >= 20:
            progbar.set_postfix({'loss': np.mean(losses), 'EM': np.mean(ems), 'SQ_EM': np.mean(sq_ems), 'F1':
                np.mean(f1s), 'SQ_F1': np.mean(sq_f1s), 'grad_norm': np.mean(grad_norms), 'lr': curr_lr})
            global_losses.append(np.mean(losses))
            global_EMs.append(np.mean(ems))
            global_f1s.append(np.mean(f1s))
            SQ_global_EMs.append(np.mean(sq_ems))
            SQ_global_f1s.append(np.mean(sq_f1s))
            global_grad_norms.append(np.mean(grad_norms))

            summary_train = sess.run(merged_op, feed_dict={tb_f1: np.mean(f1s),
                                                           tb_em: np.mean(ems),
                                                           tb_loss: np.mean(losses)})
            counter = sess.run(tb_step_op)
            writer.add_summary(summary_train, counter)

            losses, ems, f1s, grad_norms = [], [], [], []
            sq_ems, sq_f1s = [], []

    ############## SAVING CHECKPOINT MODEL ###################################
    save_path = saver.save(sess, self_FLAGS.checkpoint_dir + "model.ckpt")
    logging.info("Model saved in path: %s", save_path)

    ############### After an epoch: evaluate on validation set ###############
    logging.info("Epoch {} finished. Doing evaluation on validation set...".format(index_epoch))
    # Validation always runs in the stored order. Averaged metrics do not depend on the
    # order, but a shuffled order combined with the dropped remainder batch evaluates a
    # different subset every epoch, and 'by_length' yields indices computed from the
    # training data that can lie outside the validation arrays.
    ba.initialize_batch_processing(permutation='None', n_samples=len(self_yvalE))
    val_batch_size = batch_size  # can be a multiple of batch_size, but be sure to not run out of memory
    losses, ems, f1s = [], [], []
    sq_ems, sq_f1s = [], []
    for _ in range(len(self_yvalE) // val_batch_size):
        batch_xc, batch_xc_mask, batch_xq, batch_xq_mask, batch_yS, batch_yE = ba.next_batch(
            batch_size=val_batch_size, permutation_after_epoch='None', val=True)
        # keep_prob=1 switches dropout off for evaluation
        feed_dict = get_feed_dict(batch_xc, batch_xc_mask, batch_xq, batch_xq_mask, batch_yS,
                                  batch_yE, keep_prob=1)
        current_loss, predictionS, predictionE = sess.run([self_loss, self_predictionS, self_predictionE],
                                                          feed_dict=feed_dict)
        ems.append(get_exact_match(batch_yS, batch_yE, predictionS, predictionE))
        f1s.append(get_f1(batch_yS, batch_yE, predictionS, predictionE))
        sq_ems.append(get_exact_match_from_tokens(batch_yS, batch_yE, predictionS, predictionE, batch_xc))
        sq_f1s.append(get_f1_from_tokens(batch_yS, batch_yE, predictionS, predictionE, batch_xc))
        losses.append(current_loss)

    loss_on_validation, EM_val, F1_val = np.mean(losses), np.mean(ems), np.mean(f1s)
    SQ_F1_val, SQ_EM_val = np.mean(sq_f1s), np.mean(sq_ems)
    logging.info("loss_val={}".format(loss_on_validation))
    logging.info("EM_val={}".format(EM_val))
    logging.info("F1_val={}".format(F1_val))
    logging.info("SQ_EM_val={}".format(SQ_EM_val))
    logging.info("SQ_F1_val={}".format(SQ_F1_val))
    EMs_val.append(EM_val)
    F1s_val.append(F1_val)
    SQ_EMs_val.append(SQ_EM_val)
    SQ_F1s_val.append(SQ_F1_val)
    loss_val.append(loss_on_validation)

    ############### do some plotting ###############
    plot_metrics("strict_", index_epoch, global_losses, loss_val, global_EMs, EMs_val, global_f1s, F1s_val,
                 global_grad_norms)

    plot_metrics("SQuAD_", index_epoch, global_losses, loss_val, SQ_global_EMs, SQ_EMs_val, SQ_global_f1s,
                 SQ_F1s_val, global_grad_norms)

 45%|████▌     | 14/31 [03:47<04:39, 16.43s/it, loss=29.5, EM=0, SQ_EM=0, F1=0.028, SQ_F1=0.0447, grad_norm=3, lr=0.001]     

KeyboardInterrupt: 

### Helpers

In [None]:
# Print the context paragraph of one sample of the batch that was processed last.
# Indices that are not smaller than len(self_vocab) are skipped.
q_id = 1
context = ' '.join(
    self_vocab[token_id]
    for token_id in batch_xc[q_id]
    if token_id < len(self_vocab)
)
print(context)

In [None]:
# Print the question belonging to sample `q_id`; indices outside the vocabulary are skipped.
question = ' '.join(
    self_vocab[token_id]
    for token_id in batch_xq[q_id]
    if token_id < len(self_vocab)
)
print(question)

In [None]:
# Positions of the largest entry in the start and end label rows of sample `q_id`.
start, end = np.argmax(batch_yS[q_id]), np.argmax(batch_yE[q_id])

In [None]:
# Print the words of the context between `start` and `end` (both inclusive);
# indices outside the vocabulary are skipped.
answer = ' '.join(
    self_vocab[token_id]
    for token_id in (batch_xc[q_id, position] for position in range(start, end + 1))
    if token_id < len(self_vocab)
)
print(answer)

### Predictions

In [None]:
# Run the model on a single sample of the last batch and print context, question,
# predicted answer span and labelled answer span.

def _indices_to_text(indices):
    """Join the vocabulary words for `indices`, skipping indices outside `self_vocab`."""
    return ' '.join(self_vocab[i] for i in indices if i < len(self_vocab))


q_id = 1
# reshape(1, -1) turns the single sample into a batch of size one; keep_prob=1 disables dropout
feed_dict = get_feed_dict(batch_xc[q_id].reshape(1, -1),
                          batch_xc_mask[q_id].reshape(1, -1),
                          batch_xq[q_id].reshape(1, -1),
                          batch_xq_mask[q_id].reshape(1, -1),
                          batch_yS[q_id].reshape(1, -1),
                          batch_yE[q_id].reshape(1, -1),
                          keep_prob=1)
predictionS, predictionE = sess.run([self_predictionS, self_predictionE], feed_dict=feed_dict)

print('Context')
context = _indices_to_text(batch_xc[q_id])
print(context)

print('\nQuestion')
question = _indices_to_text(batch_xq[q_id])
print(question)

print('\nPredicted')
# The token-level F1 metric swaps start and end when they are predicted in reverse order.
# Do the same here, so that the printed span is the one that is scored rather than an
# empty string.
pred_start, pred_end = sorted((predictionS[0], predictionE[0]))
answer = _indices_to_text(batch_xc[q_id, pred_start:pred_end + 1])
print(answer)

print('\nCorrected')
start = np.argmax(batch_yS[q_id])
end = np.argmax(batch_yE[q_id])
answer = _indices_to_text(batch_xc[q_id, start:end + 1])
print(answer)