In [1]:
from models.Attention import LuongAttention
from utils.data_reader import load_and_preprocess_data, load_word_embeddings
from utils.result_saver import ResultSaver
from os.path import join as pjoin
import numpy as np
import tensorflow as tf
from utils.eval import evaluate

In [2]:
tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate")
tf.app.flags.DEFINE_float("keep_prob", 0.8, "The probably that a node is kept after the affine transform")
tf.app.flags.DEFINE_float("max_grad_norm", 5.,
                          "The maximum grad norm during backpropagation, anything greater than max_grad_norm is truncated to be max_grad_norm")
tf.app.flags.DEFINE_integer("batch_size", 24, "Number of batches to be used per training batch")
tf.app.flags.DEFINE_integer("eval_num", 250, "Evaluate on validation set for every eval_num batches trained")
tf.app.flags.DEFINE_integer("embedding_size", 100, "Word embedding size")
tf.app.flags.DEFINE_integer("window_size", 3, "Window size for sampling during training")
tf.app.flags.DEFINE_integer("hidden_size", 100, "Hidden size of the RNNs")
tf.app.flags.DEFINE_integer("samples_used_for_evaluation", 500,
                            "Samples to be used at evaluation for every eval_num batches trained")
tf.app.flags.DEFINE_integer("num_epochs", 10, "Number of Epochs")
tf.app.flags.DEFINE_integer("max_context_length", None, "Maximum length for the context")
tf.app.flags.DEFINE_integer("max_question_length", None, "Maximum length for the question")
tf.app.flags.DEFINE_string("data_dir", "data/squad", "Data directory")
tf.app.flags.DEFINE_string("train_dir", "", "Saved training parameters directory")
tf.app.flags.DEFINE_string("retrain_embeddings", False, "Whether or not to retrain the embeddings")
tf.app.flags.DEFINE_string("share_encoder_weights", False, "Whether or not to share the encoder weights")
tf.app.flags.DEFINE_string("learning_rate_annealing", False, "Whether or not to anneal the learning rate")
tf.app.flags.DEFINE_string("ema_for_weights", False, "Whether or not to use EMA for weights")
tf.app.flags.DEFINE_string("log", True, "Whether or not to log the metrics during training")
tf.app.flags.DEFINE_string("optimizer", "adam", "The optimizer to be used ")
tf.app.flags.DEFINE_string("model", "BiDAF", "Model type")
tf.app.flags.DEFINE_string("find_best_span", True, "Whether find the span with the highest probability")

FLAGS = tf.app.flags.FLAGS

## Loading the data

In [3]:
# load the data
train, val = load_and_preprocess_data(FLAGS.data_dir)

# load the word matrix
embeddings = load_word_embeddings(FLAGS.data_dir)

# vocab map file
vocabs = []
with open(pjoin(FLAGS.data_dir, "vocab.dat")) as f:
    for line in f:
        vocabs.append(line.strip("\n"))

In [4]:
# load the result save (isn't used but is needed to initialize the model should be refactored out to be a 
# Singleton class)
result_saver = ResultSaver(FLAGS.train_dir)

## Initializing the model

In [5]:
# model = LuongAttention(result_saver, embeddings, FLAGS)
model = LuongAttention(result_saver, embeddings, FLAGS)

INFO:root:('----------', 'ENCODING ', '----------')
INFO:root:('----------', ' DECODING ', '----------')
INFO:root:answer_span_start_one_hot.get_shape() <unknown>
INFO:root:answer_span_end_one_hot.get_shape() <unknown>


## Getting a random sample

In [22]:
# Load a random sample from the validation set
n_val_samples = len(val["context"])

index = np.random.choice(np.arange(n_val_samples))

sample_data = {}
for k, v in val.items():
    sample_data[k] = v[[index]]

### The context paragraph

In [23]:
context = " ".join([word for word in sample_data["word_context"]])
context

'The Srijana Contemporary Art Gallery , located inside the Bhrikutimandap Exhibition grounds , hosts the work of contemporary painters and sculptors , and regularly organizes exhibitions . It also runs morning and evening classes in the schools of art . Also of note is the Moti Azima Gallery , located in a three storied building in Bhimsenthan which contains an impressive collection of traditional utensils and handmade dolls and items typical of a medieval Newar house , giving an important insight into Nepali history . The J Art Gallery is also located in Kathmandu , near the Royal Palace in Durbarmarg , Kathmandu and displays the artwork of eminent , established Nepali painters . The Nepal Art Council Gallery , located in the Babar Mahal , on the way to Tribhuvan International Airport contains artwork of both national and international artists and extensive halls regularly used for art exhibitions .\n'

### The question to be answered

In [24]:
# See what the question is
question = " ".join([vocabs[word] for word in sample_data["question"][0]])
question

'What art gallery is located close to the Durbarmarg Royal Palace ?'

### The answer

In [25]:
answer = " ".join([word for word in sample_data["word_answer"]])
answer

'J\n'

### Predict the answer with the model

In [26]:
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, pjoin("README-files", "Attention-model", "BATCH-6039"))
    start_index, end_index = model.answer(sess, sample_data, FLAGS.find_best_span)
    pred, truth = model.get_sentences_from_indices(sample_data, start_index, end_index)

INFO:tensorflow:Restoring parameters from README-files/Attention-model/BATCH-6039


INFO:tensorflow:Restoring parameters from README-files/Attention-model/BATCH-6039


In [28]:
print("Prediction from model is: {}".format(" ".join(pred)))
print("Ground truth is: {}".format(" ".join(truth)))

Prediction from model is: J Art Gallery
Ground truth is: J
