# Reading Comprehension (QA) - Pointer Network-based Model
Existing state-of-the-art solutions to Cloze-style reading comprehension on the CNN dataset typically compute a probability distribution over the entities in the vocabulary. Such models rely on explicit information about the entities, which effectively provides a list of answer candidates (an unintended consequence of anonymising the entities). Without that information, one way to solve the reading comprehension question-answering problem, in the ideal case, is to model the position (index) of the answer in the input passage itself. Pointer Networks (Vinyals et al., 2015) do this by modelling attention directly over the input.

This is a pointer-network-inspired solution to RC, but adapting pointer networks to this task raises several issues. The main one is that the dataset provides the answer entity but does not specify the exact context (sentence) in which it answers the query: the entity may appear any number of times in the passage, in different contexts, and we do not know which context is the most useful for predicting the answer. This is a problem because training requires knowing how many positions, and which ones, are being predicted.




In [None]:
import sys
import os
import re
import math
import time
import copy
import random
import pickle
import numpy as np
from tqdm import tqdm
import corenlp  # python-stanford-corenlp client
import torch
import torch.nn as nn
import torch.nn.init as I
import torch.nn.utils.rnn as R
import torch.nn.functional as F

### Data Preparation
The code expects the CNN dataset (downloaded from https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTTljRDVZMFJnVWM) and the 300-dimensional GloVe 6B vectors (glove.6B.300d.txt). getGloveDict extracts the words and vectors from this file (glove_file = path to the GloVe file) and saves them as a pickled dictionary.
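Each line of glove.6B.300d.txt is a word followed by its space-separated vector components. A minimal pure-Python sketch of that parsing step, shown on a toy 3-dimensional "file"; the helper names `parse_glove_line` and `build_glove_dict` are illustrative and not part of the notebook, and skipping lines of the wrong length is an added assumption.

```python
def parse_glove_line(line):
    """Split one GloVe text line into (word, list of floats)."""
    parts = line.rstrip("\r\n").split(" ")
    return parts[0], [float(x) for x in parts[1:]]


def build_glove_dict(lines, expected_dim=300):
    """Map word -> vector, skipping lines whose vector has the wrong length (hypothetical helper)."""
    glove = {}
    for line in lines:
        word, vector = parse_glove_line(line)
        if len(vector) == expected_dim:
            glove[word] = vector
    return glove


# Toy 3-dimensional input instead of the real 300-dimensional file
toy_lines = ["the 0.1 -0.2 0.3\n", "cat 0.5 0.0 -1.5\n", "broken 1.0\n"]
toy = build_glove_dict(toy_lines, expected_dim=3)
print(sorted(toy))   # ['cat', 'the']
print(toy["cat"])    # [0.5, 0.0, -1.5]
```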

Part-of-speech tags for the passage and query sentences come from Stanford CoreNLP (https://stanfordnlp.github.io/CoreNLP/index.html); the Python interface from https://github.com/stanfordnlp/python-stanford-corenlp is also needed. Start the server, e.g. `java -mx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 15000`, and set the environment variable: `export CORENLP_HOME=~/path_to_stanford-corenlp`.

The preparation step looks up word vectors in the GloVe dictionary (with random initialisation for unknown words, @entity and @placeholder tokens) and generates two sets of files: one with GloVe embeddings for passage and query, and another in which a one-hot part-of-speech vector is concatenated to each GloVe embedding. It also generates a one-hot representation of the answer indices with the same length as the passage: a zero for each word in the input passage and a one for every occurrence of the answer entity. Since there can be several ones, this is not _technically_ a one-hot representation, but it is referred to as such throughout this implementation.
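A minimal plain-Python sketch of the two per-token encodings just described, so it runs without CoreNLP or PyTorch: the multi-hot answer vector aligned with the passage tokens and zero-padded to a fixed length, and the 37-slot POS one-hot (36 known tags plus an "other" slot, as in prepare) appended to a 300-d word vector. The three-entry tag index and the token list are made-up illustrations.

```python
# Illustrative subset of the tag index; prepare() uses 36 tags, and slot 36 means "other".
POS_INDEX = {'NNP': 0, 'NN': 1, 'CC': 2}
NUM_POS = 37


def pos_one_hot(tag, pos_index=POS_INDEX, size=NUM_POS):
    vec = [0] * size
    vec[pos_index.get(tag, size - 1)] = 1  # tags not in the index go to the last slot
    return vec


def multi_hot_answer(tokens, answer):
    """1 at every position where the token is the answer entity, 0 elsewhere."""
    return [1 if tok == answer else 0 for tok in tokens]


def pad_to(vec, length):
    return vec + [0] * (length - len(vec))


tokens = ["@entity3", "met", "@entity7", "in", "@entity12", ".", "@entity3", "said", "..."]
target = pad_to(multi_hot_answer(tokens, "@entity3"), 12)
print(target)             # [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]

glove_like = [0.0] * 300  # stand-in for a 300-d word vector
features = glove_like + pos_one_hot('NN')
print(len(features))      # 337
```

The answer entity occurs twice here, so the target has two ones; this is exactly the ambiguity about which context is useful.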

In [None]:
def getGloveDict(glove_file, out_file='gloveDict.pkl'):
    """Read a GloVe text file and pickle a {word: 300-d torch tensor} dictionary."""
    words = {}
    with open(glove_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    print('lines:', len(lines))
    for line in tqdm(lines):
        parts = line.rstrip('\r\n').split(' ')
        # First field is the word, the rest are its vector components. Stored as tensors
        # so prepare() can torch.cat them with the POS one-hot vectors.
        words[parts[0]] = torch.tensor([float(x) for x in parts[1:]])

    with open(out_file, 'wb') as out:
        pickle.dump(words, out)

In [None]:
# Index of the top-36 POS tags; slot 36 of the 37-dim one-hot vector is "other"
pos_dict = {'NNP':0, 'NN':1, 'CC':2, 'CD':3, 'DT':4, 'EX':5, 'FW':6, 'IN':7, 'JJ':8, 'JJR':9, 'JJS':10, 'LS':11, 'MD':12, 'NNS':13, 'NNPS':14, 'PDT':15, 'POS':16, 'PRP':17, 'PRP$':18, 'RB':19, 'RBR':20, 'RBS':21, 'RP':22, 'SYM':23, 'TO':24, 'UH':25, 'VB':26, 'VBD':27, 'VBG':28, 'VBN':29, 'VBP':30, 'VBZ':31, 'WDT':32, 'WP':33, 'WP$':34, 'WRB':35}
NUM_POS = 37


def pos_one_hot(tag, pos_dict):
    """37-dim one-hot POS vector; tags outside pos_dict go to the last ("other") slot."""
    vec = torch.zeros(NUM_POS)
    vec[pos_dict.get(tag, NUM_POS - 1)] = 1
    return vec


def prepare(passage, query, answer, glove_vectors, pos_dict, max_sentence_length, max_passage_length, dataOrigin):
    """Tokenise and POS-tag one passage/query pair and build its input vectors.

    Returns per-token lists of 300-d vectors and of 337-d vectors (300-d + 37-d POS one-hot), the
    multi-hot answer vector aligned with the passage tokens, and the updated running maxima
    (max_passage_length counts sentences, max_sentence_length counts tokens). dataOrigin is unused.
    """
    # Run the Stanford CoreNLP server to tokenise and generate POS tags
    with corenlp.CoreNLPClient(annotators="tokenize ssplit pos".split()) as client:
        passageSentences = client.annotate(passage).sentence    # list of sentences
        queryTokens = client.annotate(query).sentence[0].token  # the query is a single sentence

    max_passage_length = max(max_passage_length, len(passageSentences))
    entity_vectors = {}  # one random vector per @entity, shared by passage and query of this sample

    def embed(token):
        """Random vector for @entity tokens, GloVe vector for known words, random vector otherwise."""
        if re.match('@entity', token):
            if token not in entity_vectors:
                entity_vectors[token] = torch.rand(300)
            return entity_vectors[token]
        if token.lower() in glove_vectors:
            return torch.as_tensor(glove_vectors[token.lower()], dtype=torch.float)
        return torch.rand(300)

    passage_vector, passage_pos_vector, answer_vector = [], [], []
    for sentence in passageSentences:
        max_sentence_length = max(max_sentence_length, len(sentence.token))
        for eachtoken in sentence.token:
            token = eachtoken.word
            word_vec = embed(token)
            passage_vector.append(word_vec)
            passage_pos_vector.append(torch.cat((word_vec, pos_one_hot(eachtoken.pos, pos_dict)), 0))
            answer_vector.append(1 if token == answer else 0)  # 1 at every occurrence of the answer entity

    query_vector, query_pos_vector = [], []
    max_sentence_length = max(max_sentence_length, len(queryTokens))
    for eachtoken in queryTokens:
        token = eachtoken.word
        # The @placeholder (the blank to fill) gets a random vector, used in both representations
        word_vec = torch.rand(300) if re.match('@placeholder', token) else embed(token)
        query_vector.append(word_vec)
        query_pos_vector.append(torch.cat((word_vec, pos_one_hot(eachtoken.pos, pos_dict)), 0))

    assert len(passage_vector) == len(passage_pos_vector) == len(answer_vector)
    assert len(query_vector) == len(query_pos_vector)

    return passage_vector, passage_pos_vector, query_vector, query_pos_vector, answer_vector, max_passage_length, max_sentence_length


In [None]:
def normalizeString(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z0-9.!?@]+", r" ", s) #keep @ which is pre-pended to entities and placeholders 
    return s

Change dataOrigin to "dev" or "test" to generate the input files for each split. The code assumes the CNN dataset and the GloVe 300d vectors file are in the current folder (change dataPath and glove_file accordingly).

In [None]:
dataOrigin = "train"  # "train", "dev" or "test"

if dataOrigin == "train":
    dataPath = 'cnn/questions/training'
elif dataOrigin == "dev":
    dataPath = 'cnn/questions/validation'
elif dataOrigin == "test":
    dataPath = 'cnn/questions/test'
else:
    raise ValueError("dataOrigin must be 'train', 'dev' or 'test'")

glove_file = 'glove.6B.300d.txt'

getGloveDict(glove_file)  # only needed once; writes gloveDict.pkl

max_sentence_length = 0
max_passage_length = 0
save_folder = dataOrigin + "_size_full"
os.makedirs(save_folder, exist_ok=True)

with open('gloveDict.pkl', 'rb') as f:
    glove_vectors = pickle.load(f)

# One pickle record per sample is written to each file; 'wb' avoids duplicated records if the cell is re-run
passage_out = open(save_folder + "/passage_" + dataOrigin, 'wb')
query_out = open(save_folder + "/query_" + dataOrigin, 'wb')
passage_pos_out = open(save_folder + "/passage_pos_" + dataOrigin, 'wb')
query_pos_out = open(save_folder + "/query_pos_" + dataOrigin, 'wb')
ans_out = open(save_folder + "/ans_" + dataOrigin, 'wb')


In [None]:
for i, eachfile in enumerate(os.listdir(dataPath)):

    with open(os.path.join(dataPath, eachfile), 'r') as fileint:
        # A CNN .question file holds URL, passage, query, answer and entity mapping, separated by blank lines
        inputs = fileint.read().split('\n\n')

    passage = normalizeString(inputs[1])
    query = normalizeString(inputs[2])
    answer = inputs[3].strip()

    (passage_vector, passage_pos_vector, query_vector, query_pos_vector, answer_vector,
     max_passage_length, max_sentence_length) = prepare(passage, query, answer, glove_vectors, pos_dict,
                                                        max_sentence_length, max_passage_length, dataOrigin)

    # One record per sample; read back sequentially with pickle.load during training
    pickle.dump(passage_vector, passage_out)
    pickle.dump(query_vector, query_out)
    pickle.dump(passage_pos_vector, passage_pos_out)
    pickle.dump(query_pos_vector, query_pos_out)
    pickle.dump(answer_vector, ans_out)

for out in (passage_out, query_out, passage_pos_out, query_pos_out, ans_out):
    out.close()

print("max_passage_length (sentences):", max_passage_length)
print("max_sentence_length (tokens):", max_sentence_length)



### Pointer Network-based Model
The original pointer network scores each input position $j$ as $u_j = v^\top \tanh(W_1 e_j + W_2 d)$, where $e_j$ is the encoder output at position $j$, $d$ is the decoder state, and $W_1$, $W_2$ and $v$ are learned weights (linear layers). Here the addition is replaced by a dot product, since what is required is the similarity between the passage representation $E$ and the query representation $q$ (both obtained through bidirectional GRU layers), and SELU is used instead of tanh as the non-linear activation: $u = V\,\mathrm{SELU}\big([(W_1 e_j)^\top (W_2 q)]_{j=1}^{n}\big)$ and $p = \mathrm{softmax}(u)$, where $n$ is max_passage_length and $V$ is an $n \times n$ linear layer. The final output $p$ gives a probability for each index in the passage.
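To make the shapes concrete, here is a NumPy re-expression of both scoring rules. It is a sketch, not the notebook's PyTorch code: `dot_product_pointer` mirrors `pointer_attention` (W1/W2 projections, the same `'ijk,ik->ij'` einsum, SELU, an n×n V, softmax), and `additive_pointer` is the original additive form. The weights are random placeholders, and `x @ W.T` stands in for a bias-free `nn.Linear`.

```python
import numpy as np

SELU_ALPHA = 1.6732632423543772   # constants used by torch.nn.SELU
SELU_SCALE = 1.0507009873554805


def selu(x):
    # exp is applied only to the non-positive part to avoid overflow warnings
    return SELU_SCALE * np.where(x > 0, x, SELU_ALPHA * (np.exp(np.minimum(x, 0)) - 1))


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dot_product_pointer(E, q, W1, W2, V):
    """E: (batch, n, d) passage states, q: (batch, d) query state,
    W1, W2: (d, d), V: (n, n). Returns (batch, n) probabilities."""
    left = E @ W1.T                                # project every passage position
    right = q @ W2.T                               # project the query
    scores = np.einsum('ijk,ik->ij', left, right)  # one dot product per position
    return softmax(selu(scores) @ V.T, axis=1)     # V mixes scores across positions


def additive_pointer(E, d_state, W1, W2, v):
    """Original form u_j = v^T tanh(W1 e_j + W2 d), with v: (d,). Returns (batch, n)."""
    u = np.tanh(E @ W1.T + (d_state @ W2.T)[:, None, :]) @ v
    return softmax(u, axis=1)


rng = np.random.default_rng(0)
batch, n, d = 2, 6, 4
E = rng.standard_normal((batch, n, d))
q = rng.standard_normal((batch, d))
W1, W2 = rng.standard_normal((d, d)), rng.standard_normal((d, d))
V = rng.standard_normal((n, n))
v = rng.standard_normal(d)

p_dot = dot_product_pointer(E, q, W1, W2, V)
p_add = additive_pointer(E, q, W1, W2, v)
print(p_dot.shape, p_add.shape)  # (2, 6) (2, 6)
```

One consequence of the design is visible in the shapes: because V is n×n, the model is tied to a fixed max_passage_length, whereas the additive form only needs a d-dimensional v.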

Uncomment/comment the relevant lines in the pointer_attention function to use the original formulation or ReLU as the activation function. (The number of layers was reduced, and GRUs were used, to save time and to allow faster word-wise dot products. Also remove the softmax from the final layer if the decoder [input_passage * attention_weights] is used, although this was less successful than applying softmax over the pointer attention weights directly.)

Number of parameters: 12,717,312 (with POS-tags) 12,660,480 (without POS-tags)
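The total depends on the configuration, so as a cross-check here is an analytic count for the class as listed (pure Python; the helper names are hypothetical). The two single-layer bidirectional GRUs give 717,312 parameters with POS tags and 660,480 without, matching the last six digits of the figures above; the remainder comes from W1, W2 and V, where V alone has max_passage_length² weights. With the settings used later in this notebook (hidden_size=128, max_passage_length=2000) the formula gives 4,848,384, so the quoted totals presumably reflect a different setting of the linear layers. For an instantiated model, `sum(p.numel() for p in net.parameters())` gives the exact count.

```python
def gru_params(input_size, hidden_size, bidirectional=True):
    """Single-layer nn.GRU: weight_ih (3H x D), weight_hh (3H x H) and two 3H bias vectors per direction."""
    per_direction = 3 * hidden_size * input_size + 3 * hidden_size * hidden_size + 6 * hidden_size
    return per_direction * (2 if bidirectional else 1)


def qapointernet_params(dimensions, hidden_size, max_passage_length):
    """Two BiGRUs, W1 and W2 (2H x 2H, no bias) and V (n x n, no bias), as in the class below."""
    grus = 2 * gru_params(dimensions, hidden_size)
    w1_w2 = 2 * (2 * hidden_size) ** 2
    v = max_passage_length ** 2
    return grus + w1_w2 + v


print(2 * gru_params(337, 128))             # 717312
print(2 * gru_params(300, 128))             # 660480
print(qapointernet_params(337, 128, 2000))  # 4848384
```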

In [None]:
class qapointernet(nn.Module):
    def __init__(self, hidden_size, dimensions, max_passage_length):
        super(qapointernet, self).__init__()
        self.hidden_size = hidden_size
        
        #self.passage_gru_layer = nn.GRU(dimensions, hidden_size, num_layers=4, dropout=0.2, batch_first=True,bidirectional=True)
        self.passage_gru_layer = nn.GRU(dimensions, hidden_size, batch_first=True,bidirectional=True)
        
        #self.query_gru_layer = nn.GRU(dimensions, hidden_size, num_layers=4, dropout=0.02, batch_first=True,bidirectional=True)
        self.query_gru_layer = nn.GRU(dimensions, hidden_size, batch_first=True,bidirectional=True)
        
        #self.linear_layer = nn.Linear(hidden_size, max_passage_length)
        self.SELU = nn.SELU()
        
        self.W1 = nn.Linear(2*hidden_size, 2*hidden_size, bias=False)
        self.W2 = nn.Linear(2*hidden_size, 2*hidden_size, bias=False)
        self.V = nn.Linear(max_passage_length, max_passage_length, bias=False)
        

    def passage_forward(self, input_passage, h_init, c_init):
        """Run the passage BiGRU. c_init is unused (GRUs have no cell state); kept for call compatibility."""
        h_seq, h_final = self.passage_gru_layer(input_passage, h_init)
        return h_seq, h_final  # (batch, n, 2*hidden), (2, batch, hidden)

    def query_forward(self, input_query, h_init, c_init):
        """Run the query BiGRU. c_init is unused, as in passage_forward."""
        h_seq, h_final = self.query_gru_layer(input_query, h_init)
        return h_seq, h_final

    def pointer_attention(self, left_input, right_input):
        """left_input: passage states (batch, n, 2*hidden); right_input: query state (batch, 2*hidden).
        Returns (batch, n) probabilities over passage positions."""
        W_left = self.W1(left_input)
        W_right = self.W2(right_input)

        # Dot product of every projected passage position with the projected query -> (batch, n)
        #mix_weights = self.V(F.relu(torch.einsum('ijk,ik->ij', W_left, W_right)))  # ReLU as non-linearity
        mix_weights = self.V(self.SELU(torch.einsum('ijk,ik->ij', W_left, W_right)))
        #mix_weights = self.V(self.SELU(torch.matmul(W_left,W_right.t())))  # word-wise dot product
        #mix_weights = self.V(torch.tanh(torch.add(W_left,W_right)))  # original pointer network equation (change V input dims to 2*hidden_size and pass h_final states for p and q)

        mix_weights = F.softmax(mix_weights, dim=1)  # remove softmax if training with a decoder
        return mix_weights



In [None]:
def add_padding(max_passage_length, input_vector, dimensions=300):
    """Right-pad a list of per-token vectors with zero vectors up to max_passage_length (in place)."""
    difference = max_passage_length - len(input_vector)
    assert difference >= 0, "passage longer than max_passage_length (%d > %d)" % (len(input_vector), max_passage_length)
    input_vector.extend(torch.zeros(difference, dimensions))  # extending with a 2-D tensor appends its rows

    assert len(input_vector) == max_passage_length
    return input_vector

def repeat_query(max_passage_length, query_vector, dimensions=300):
    """Tile a (query_length, dimensions) tensor along the first axis and cut it to max_passage_length rows.
    `dimensions` is unused; it is kept so the signature matches add_padding."""
    num_loops = math.ceil(max_passage_length / len(query_vector))
    query_vector = query_vector.repeat(num_loops, 1)[:max_passage_length]

    assert len(query_vector) == max_passage_length
    return query_vector

Paths to the output files written by the data preparation cells above.

In [None]:
training_passage_file = 'train_size_full/passage_train'
training_passage_posfile = 'train_size_full/passage_pos_train'
training_query_file = 'train_size_full/query_train'
training_query_posfile = 'train_size_full/query_pos_train'
training_answer_file = 'train_size_full/ans_train'
dev_passage_file = 'dev_size_full/passage_dev'
dev_passage_posfile = 'dev_size_full/passage_pos_dev'
dev_query_file = 'dev_size_full/query_dev'
dev_query_posfile = 'dev_size_full/query_pos_dev'
dev_answer_file = 'dev_size_full/ans_dev'
test_passage_file = 'test_size_full/passage_test'
test_passage_posfile = 'test_size_full/passage_pos_test'
test_query_file = 'test_size_full/query_test'
test_query_posfile = 'test_size_full/query_pos_test'
test_answer_file = 'test_size_full/ans_test'

Change dimensions and criterion by commenting/uncommenting accordingly. NLL/CrossEntropy losses are designed for single-label classification, so they are not a natural fit when a passage has several answer positions. MSE and L1 losses drove the gradients/weights towards all zeros within 2-3 epochs, since the target is a sparse vector of size max_passage_length that is almost entirely zeros. Binary cross-entropy loss worked slightly better.
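One plausible reading of the MSE behaviour, as a small arithmetic sketch (the passage length n and number of answer positions k are illustrative): predicting all zeros already scores an MSE of k/n, and even the ideal probability distribution, with all mass split evenly over the answer positions, is only slightly better, so the loss gives little incentive to move away from the all-zero output.

```python
def mse(pred, target):
    """Mean squared error over all positions."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


n, k = 2000, 3                                   # illustrative passage length and number of answer positions
target = [1.0] * k + [0.0] * (n - k)
all_zeros = [0.0] * n
perfect_split = [1.0 / k] * k + [0.0] * (n - k)  # all probability mass on the answers, split evenly

print(mse(all_zeros, target))       # 0.0015
print(mse(perfect_split, target))   # ~0.000667
```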

In [None]:
max_passage_length = 2000
max_query_length = 2000
dimensions=337 #300 if running without Part of Speech 
hidden_size = 128
batch_size = 100
epochs = 10
learning_rate = 0.01
criterion = torch.nn.BCELoss()
#criterion = torch.nn.MSELoss() 
#criterion = torch.nn.NLLLoss() #CrossEntropyLoss()
#criterion = torch.nn.L1Loss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
start = time.time()
total_sample_size=0

net = qapointernet(hidden_size, dimensions, max_passage_length)
net = net.to(device)

### Training Strategies

Train with either the plain input or the input with part-of-speech tags (comment/uncomment the relevant lines). Four strategies were tried to find a loss that works when the answer matrix contains a varying number of ones; development-set accuracies (without POS) are reported for each. Accuracy is computed by taking the index with the highest predicted probability and counting the prediction as correct if it matches _any_ of the passage's known answer indices, since every such index refers to the same entity.
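The accuracy rule can be sketched in plain Python (toy probabilities and gold positions; the helper names are illustrative). The important detail is that each passage's argmax is compared only against that passage's own gold positions.

```python
def top1_correct(probs, gold_positions):
    """True if the argmax of one passage's probabilities is any of that passage's gold positions."""
    best = max(range(len(probs)), key=probs.__getitem__)
    return best in set(gold_positions)


def accuracy(batch_probs, batch_gold):
    hits = sum(top1_correct(p, g) for p, g in zip(batch_probs, batch_gold))
    return hits / len(batch_probs)


batch_probs = [[0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
batch_gold = [[1, 2], [2]]   # passage 0: entity at positions 1 and 2; passage 1: position 2
print(accuracy(batch_probs, batch_gold))  # 0.5
```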

- (1) Train directly with the one-hot answer matrix, using
    - (a) the similarity between the final passage and query representations: unsuccessful (dev acc: __0.4%__);
    - (b) a dot product between each word in the passage and the query representation: unsuccessful, but works better combined with (4) (dev acc: __1.86%__);
    - (c) the similarity between passage and query as the encoder output, with the answer matrix as the decoder output: unsuccessful (dev acc: __1.04%__).

- (2) Train with the query expanded (repeated) to the passage length. The idea is to match a query-length context in the passage, since the correct entity context is just one sentence of the passage. Unsuccessful, but works better combined with (4) (dev acc: __2.37%__).

- (3) Train with the one-hot answer matrix for the first 3 epochs or so, then reduce the number of ones in the target by keeping only the top K gold indices as currently predicted; reduce K each epoch until K=1. Initialising the network weights to identity matrices before training was also tried. The idea is that if the similarity between passage words and query provides a strong enough signal, the answer entity in the correct context is more likely to be preferred early, and with repeated filtering it may stay in the top K and eventually become the top 1. Unsuccessful (dev acc: __1.03%__).

- (4) Sum the probabilities over all known answer (gold) indices, with a target sum of one: predict zero probability for every word except the answer entities, without constraining the individual probabilities. This collectively maximises the gold probabilities, under the assumption that the correct context will contribute the most. Two variants:
    - (a) Train with decoder input = input_passage, maximising sum(softmax(input_passage * attention weights)): unsuccessful (dev acc: __3.628%__).
    - (b) Remove the decoder and directly maximise sum(softmax(pointer_attention_weights)): relatively better (dev acc: __6.67%__).

- (5) __Possible future work__:
    - (a) Represent each word in the passage by its similarity with a bidirectional, query-sized window of context that excludes the word itself. Because entities and placeholders are randomly initialised and have no similarity with the query, the prediction has to rely entirely on the surrounding context.
    - (b) Model by stride: pick a stride size within which the answer entity is unlikely to repeat, learn attention weights for the similarity with the query within each stride, then combine them.
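Strategy (4b) reduces to a simple per-passage quantity: the softmax mass on the gold positions, pushed towards 1 with binary cross-entropy, which for a target of 1 is -log(mass). A pure-Python sketch with toy numbers; the epsilon clamp is an assumption standing in for the clamping that torch.nn.BCELoss applies to its log terms. It also shows that two distributions with the same total gold mass get the same loss, which is the sense in which the individual probabilities do not matter.

```python
import math


def gold_mass_loss(probs, gold_positions, eps=1e-12):
    """-log of the total probability assigned to the gold positions (BCE with target 1)."""
    mass = sum(probs[j] for j in gold_positions)
    mass = min(max(mass, eps), 1.0)  # keep log finite and the mass a valid probability
    return -math.log(mass)


probs_a = [0.05, 0.40, 0.05, 0.45, 0.05]  # softmax-like output over 5 positions
probs_b = [0.05, 0.80, 0.05, 0.05, 0.05]  # different split, same mass on positions 1 and 3
gold = [1, 3]                             # the answer entity occurs at positions 1 and 3

print(gold_mass_loss(probs_a, gold))      # ~0.1625  (-log 0.85)
print(gold_mass_loss(probs_b, gold))      # ~0.1625  (same total mass)
print(gold_mass_loss(probs_a, [0]))       # ~2.9957  (-log 0.05)
```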


In [None]:
# check_accuracy() is defined in the next cell; run that cell once before starting training.
for epoch in range(epochs):
    with open(training_passage_file, 'rb') as trainpassageint, open(training_passage_posfile, 'rb') as trainpassageposint, \
         open(training_query_file, 'rb') as trainqueryint, open(training_query_posfile, 'rb') as trainqueryposint, \
         open(training_answer_file, 'rb') as trainansint, open("logfile", "a+") as logint:
        net.train()
        match = 0              # correct top-1 predictions in this epoch
        epoch_sample_size = 0  # samples trained on in this epoch (accuracy denominator)
        num_batches = 0
        running_loss = 0.0
        if epoch >= 2:
            learning_rate = learning_rate / 3  # compounds: divided by 3 again every epoch from epoch 2

        optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)  # RMSprop/Adam stagnate too quickly
        current_batch_size = 0
        current_sample_size = 1
        batch_passage = torch.Tensor()  # grows by (max_passage_length, dimensions) per sample
        batch_query = torch.Tensor()
        batch_answer = torch.FloatTensor()

        # Initial GRU states; the *_c tensors are ignored by the GRUs and only match the forward() signatures
        passage_h = torch.zeros(2, batch_size, hidden_size, device=device)
        passage_c = torch.zeros(2, batch_size, hidden_size, device=device)
        query_h = torch.zeros(2, batch_size, hidden_size, device=device)
        query_c = torch.zeros(2, batch_size, hidden_size, device=device)

        while True:
            try:
                #passage = pickle.load(trainpassageint)  # without part-of-speech
                passage = pickle.load(trainpassageposint)  # part-of-speech included
                #query = pickle.load(trainqueryint)  # without part-of-speech
                query = pickle.load(trainqueryposint)  # part-of-speech included
                ans = pickle.load(trainansint)
            except EOFError:
                break

            passage_vector = torch.stack(add_padding(max_passage_length, passage, dimensions))
            # Repeating the query to passage length was slightly more successful than expand_as or zero padding
            query_vector = repeat_query(max_query_length, torch.stack(query), dimensions)
            ans.extend([0] * (max_passage_length - len(ans)))
            answer_vector = torch.FloatTensor(ans)

            current_batch_size += 1
            current_sample_size += 1
            batch_passage = torch.cat((batch_passage, passage_vector), 0)
            batch_query = torch.cat((batch_query, query_vector), 0)
            batch_answer = torch.cat((batch_answer, answer_vector), 0)

            if current_batch_size < batch_size:
                continue

            total_sample_size += batch_size
            epoch_sample_size += batch_size
            num_batches += 1
            optimizer.zero_grad()

            input_batch_passage = batch_passage.view(batch_size, max_passage_length, -1).to(device)
            input_batch_query = batch_query.view(batch_size, max_query_length, -1).to(device)
            input_batch_answer = batch_answer.view(batch_size, max_passage_length).to(device)  # 1 at each answer position

            passage_hseq, passage_hfinal = net.passage_forward(input_batch_passage, passage_h, passage_c)
            query_hseq, query_hfinal = net.query_forward(input_batch_query, query_h, query_c)
            # (2, batch, hidden) -> (batch, 2*hidden): concatenate the forward and backward final states
            query_hfinal_ = query_hfinal.permute(1, 0, 2).contiguous().view(batch_size, -1)

            # (1a) Similarity between overall passage and query representations, unsuccessful (needs a 2-D variant of the einsum):
            #pointer_attention = net.pointer_attention(passage_hfinal.permute(1,0,2).contiguous().view(batch_size,-1), query_hfinal_)
            pointer_attention = net.pointer_attention(passage_hseq, query_hfinal_)  # (batch, max_passage_length)

            print(current_sample_size)

            # (4b) Direct sum over pointer attention: total probability on each passage's own answer positions.
            # Multiplying by the 0/1 answer matrix selects those positions passage by passage.
            samplewise_pred_probs = (pointer_attention * input_batch_answer).sum(dim=1).clamp(0.0, 1.0)

            # (4a) Decoder, unsuccessful: multiply the passage encoding with the attention weights first
            # (remove the softmax in pointer_attention for this variant)
            #decoded = F.softmax(torch.einsum('ijk,ij->ij', passage_hseq, pointer_attention), dim=1)
            #samplewise_pred_probs = (decoded * input_batch_answer).sum(dim=1)

            # (3) Top-K filtering, unsuccessful (outline of the earlier experiment):
            #   K starts at the number of gold indices; after epoch 3, K = K//2 while K > 2, otherwise K = 1.
            #   The K gold indices with the highest predicted probability are set to 1 in a new target
            #   matrix (all others 0), and the model is trained on that target with `criterion`.

            # (1) Train directly on the one-hot answer index matrix, unsuccessful:
            #loss = criterion(pointer_attention, input_batch_answer)

            # Target: the answer positions of each passage should jointly receive all of the probability mass
            input_batch_answer_sum = torch.ones(batch_size, device=device)
            loss = criterion(samplewise_pred_probs, input_batch_answer_sum)
            running_loss += loss.item()
            loss.backward()
            optimizer.step()

            # Running accuracy: correct if the highest-probability index is one of that passage's answer indices
            with torch.no_grad():
                pidx = pointer_attention.argmax(dim=1)
                match += int(input_batch_answer.gather(1, pidx.unsqueeze(1)).sum().item())

            # Reset for the next batch
            current_batch_size = 0
            batch_passage = torch.Tensor()
            batch_query = torch.Tensor()
            batch_answer = torch.FloatTensor()

        elapsed = time.time() - start
        total_loss = running_loss / max(num_batches, 1)  # mean loss per batch
        train_acc = match / max(epoch_sample_size, 1)
        print("epoch=", epoch, "\t time=", elapsed, "\t exp(loss)=", math.exp(total_loss), "\t train_acc=", match, train_acc)
        # 'full' is the dataset-size label, matching the *_size_full data folders
        logint.write("epoch=" + str(epoch) + "\t time=" + str(elapsed) + "\t exp(loss)=" + str(math.exp(total_loss))
                     + "\t train_acc=" + str(match) + " " + str(train_acc) + " bs: " + str(batch_size) + " size: full\n")
        torch.save({'epoch': epoch, 'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                    'loss': running_loss, 'tsize': 'full', 'bsize': batch_size, 'hsize': hidden_size},
                   "model_file_" + str(epoch))

        check_accuracy(dev_passage_file, dev_passage_posfile, dev_query_file, dev_query_posfile, dev_answer_file)


In [None]:
def check_accuracy(passage_file, passage_posfile, query_file, query_posfile, answer_file):
    """Top-1 accuracy on a prepared split: a passage counts as correct if its highest-probability
    index is one of its own answer indices. Uses the global net, device and size settings;
    a final incomplete batch is skipped."""
    net.eval()
    match = 0
    evaluated = 0  # samples actually scored (accuracy denominator)
    h0 = torch.zeros(2, batch_size, hidden_size, device=device)  # GRUs ignore the c_init argument

    with open(passage_file, 'rb') as passageint, open(passage_posfile, 'rb') as passageposint, \
         open(query_file, 'rb') as queryint, open(query_posfile, 'rb') as queryposint, \
         open(answer_file, 'rb') as ansint, open("pred_logfile", "a+") as logint, torch.no_grad():
        current_batch_size = 0
        batch_passage = torch.Tensor()  # grows by (max_passage_length, dimensions) per sample
        batch_query = torch.Tensor()
        batch_answer = torch.FloatTensor()

        while True:
            try:
                #passage = pickle.load(passageint)  # without part-of-speech
                passage = pickle.load(passageposint)  # part-of-speech included
                #query = pickle.load(queryint)  # without part-of-speech
                query = pickle.load(queryposint)  # part-of-speech included
                ans = pickle.load(ansint)
            except EOFError:
                break

            passage_vector = torch.stack(add_padding(max_passage_length, passage, dimensions))
            query_vector = repeat_query(max_query_length, torch.stack(query), dimensions)
            ans.extend([0] * (max_passage_length - len(ans)))

            batch_passage = torch.cat((batch_passage, passage_vector), 0)
            batch_query = torch.cat((batch_query, query_vector), 0)
            batch_answer = torch.cat((batch_answer, torch.FloatTensor(ans)), 0)
            current_batch_size += 1
            if current_batch_size < batch_size:
                continue

            input_batch_passage = batch_passage.view(batch_size, max_passage_length, -1).to(device)
            input_batch_query = batch_query.view(batch_size, max_query_length, -1).to(device)
            input_batch_answer = batch_answer.view(batch_size, max_passage_length).to(device)

            passage_hseq, _ = net.passage_forward(input_batch_passage, h0, h0)
            _, query_hfinal = net.query_forward(input_batch_query, h0, h0)
            query_hfinal_ = query_hfinal.permute(1, 0, 2).contiguous().view(batch_size, -1)
            pointer_attention = net.pointer_attention(passage_hseq, query_hfinal_)

            # Pick the top (max-probability) index per passage; correct if it is one of that passage's gold indices
            pidx = pointer_attention.argmax(dim=1)
            match += int(input_batch_answer.gather(1, pidx.unsqueeze(1)).sum().item())
            evaluated += batch_size
            print(evaluated)

            current_batch_size = 0
            batch_passage = torch.Tensor()
            batch_query = torch.Tensor()
            batch_answer = torch.FloatTensor()

        accuracy = match / max(evaluated, 1)
        print("accuracy: ", accuracy)
        logint.write("accuracy: " + str(accuracy) + " on " + str(evaluated) + " samples\n")

    net.train()
    return accuracy

In [None]:
check_accuracy(test_passage_file, test_passage_posfile, test_query_file, test_query_posfile, test_answer_file)

### Analysis

__Final test data accuracies__: __6.44%__ (without POS), __5.62%__ (with POS)

Possible reasons for the poor performance:
- The target is a sparse, passage-length vector whose ones mark randomly initialised answer entities, and a design that predicts indices does not necessarily receive signal from the most useful context.
- The loss calculation is not straightforward, since the number of answer positions varies from passage to passage.
- A better contextual embedding or more layers/parameters may help.
