In [None]:
import ujson as json
import numpy as np
from tqdm import tqdm
import os
from torch import optim, nn
import time
import shutil
import random
import torch
from torch.autograd import Variable
import sys
from torch.nn import functional as F
import joblib
import re
from collections import Counter
import string
import pickle
import copy
import traceback
import math
from torch.nn import init
from torch.nn.utils import rnn

In [None]:
# Predefine variables: loss functions, pre-computed embeddings / eval data, and size limits.

# Target value the losses skip; also used to fill targets that must not be scored.
IGNORE_INDEX = -100
# Cross-entropy in three reduction flavours, all sharing the same ignore label.
nll_sum, nll_average, nll_all = (
    nn.CrossEntropyLoss(reduction=mode, ignore_index=IGNORE_INDEX)
    for mode in ('sum', 'mean', 'none')
)

# Pre-computed artefacts.
# NOTE(review): joblib.load unpickles arbitrary objects — only load trusted files.
word_mat, char_mat = joblib.load('word_emb.pkl'), joblib.load('char_emb.pkl')
with open('eval_examples_dev.json', "r") as fh:
    dev_eval_file = json.load(fh)
idx2word_dict = joblib.load('idx2word.pkl')

# Batch size and padded sizes: tokens per paragraph / question, characters per
# token, sentences per example.
batch_size, para_limit, ques_limit, char_limit, sent_limit = 64, 1000, 80, 16, 100

In [None]:
# Seed every RNG the notebook relies on (Python, NumPy, torch CPU and all CUDA
# devices) so shuffling, dropout masks and weight init repeat across runs.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(0)
del _seed_fn

In [None]:
# Iterator that groups bucketed examples into padded batches

class DataIterator(object):
    """Single-pass iterator that turns bucketed examples into padded tensor batches.

    Each example is a dict read with the keys 'context_idxs', 'ques_idxs',
    'context_char_idxs', 'ques_char_idxs', 'y1', 'y2', 'id' and 'start_end_facts'.
    Bucket read pointers are consumed while iterating, so build a new instance
    for every pass over the data.

    Args:
        buckets: list of lists of example dicts; shuffled in place when `shuffle`.
        bsz: maximum number of examples per batch.
        para_limit, ques_limit: padded context / question lengths. If either is
            None, both are taken as the longest lengths found in `buckets`.
        char_limit: characters per token in the char-index tensors.
        shuffle: shuffle examples inside each bucket and pick buckets at random.
        sent_limit: maximum number of sentences tracked per example.

    Each yielded batch is a fresh set of tensors (never reused by later batches),
    trimmed to the longest context / question / sentence count in the batch.
    Targets that must not be scored are set to IGNORE_INDEX.
    """

    # Sentinel y1 values of non-span answers -> answer-type label.
    _Q_TYPE_BY_Y1 = {-1: 1, -2: 2, -3: 3}

    def __init__(self, buckets, bsz, para_limit, ques_limit, char_limit, shuffle, sent_limit):
        self.buckets = buckets
        self.bsz = bsz

        if para_limit is not None and ques_limit is not None:
            self.para_limit = para_limit
            self.ques_limit = ques_limit
        else:
            # Derive both limits from the longest context / question in the data.
            para_limit, ques_limit = 0, 0
            for bucket in buckets:
                for dp in bucket:
                    para_limit = max(para_limit, dp['context_idxs'].size(0))
                    ques_limit = max(ques_limit, dp['ques_idxs'].size(0))
            self.para_limit, self.ques_limit = para_limit, ques_limit
        self.char_limit = char_limit
        self.sent_limit = sent_limit

        self.num_buckets = len(self.buckets)
        # Buckets that still have examples left to serve.
        self.bkt_pool = [i for i in range(self.num_buckets) if len(self.buckets[i]) > 0]

        if shuffle:
            for bucket in self.buckets:
                random.shuffle(bucket)
        # Next unread position inside each bucket.
        self.bkt_ptrs = [0] * self.num_buckets
        self.shuffle = shuffle

    @staticmethod
    def _copy_prefix(dst, src):
        """Copy `src` into the leading entries of `dst` along dim 0; the rest keeps its (zero) fill."""
        n = src.size(0)
        if n > dst.size(0):
            raise ValueError('sequence of length {} exceeds the padded limit {}'.format(n, dst.size(0)))
        dst[:n].copy_(src)

    def __iter__(self):
        while self.bkt_pool:
            bkt_id = random.choice(self.bkt_pool) if self.shuffle else self.bkt_pool[0]
            start_id = self.bkt_ptrs[bkt_id]
            cur_bucket = self.buckets[bkt_id]
            cur_bsz = min(self.bsz, len(cur_bucket) - start_id)

            # Longest real context first: EncoderRNN packs sequences with
            # pack_padded_sequence, which expects decreasing lengths.
            cur_batch = cur_bucket[start_id: start_id + cur_bsz]
            cur_batch.sort(key=lambda x: int((x['context_idxs'] > 0).long().sum()), reverse=True)

            # Fresh zero-filled buffers per batch, so no uninitialised memory leaks
            # into a batch and yielded tensors never alias a later batch.
            context_idxs = torch.zeros(cur_bsz, self.para_limit, dtype=torch.long)
            ques_idxs = torch.zeros(cur_bsz, self.ques_limit, dtype=torch.long)
            context_char_idxs = torch.zeros(cur_bsz, self.para_limit, self.char_limit, dtype=torch.long)
            ques_char_idxs = torch.zeros(cur_bsz, self.ques_limit, self.char_limit, dtype=torch.long)
            y1 = torch.full((cur_bsz,), IGNORE_INDEX, dtype=torch.long)
            y2 = torch.full((cur_bsz,), IGNORE_INDEX, dtype=torch.long)
            q_type = torch.zeros(cur_bsz, dtype=torch.long)
            start_mapping = torch.zeros(cur_bsz, self.para_limit, self.sent_limit)
            end_mapping = torch.zeros(cur_bsz, self.para_limit, self.sent_limit)
            all_mapping = torch.zeros(cur_bsz, self.para_limit, self.sent_limit)
            is_support = torch.full((cur_bsz, self.sent_limit), IGNORE_INDEX, dtype=torch.long)

            ids = []
            max_sent_cnt = 0
            for i, dp in enumerate(cur_batch):
                self._copy_prefix(context_idxs[i], dp['context_idxs'])
                self._copy_prefix(ques_idxs[i], dp['ques_idxs'])
                self._copy_prefix(context_char_idxs[i], dp['context_char_idxs'])
                self._copy_prefix(ques_char_idxs[i], dp['ques_char_idxs'])

                y1_val = int(dp['y1'])
                if y1_val >= 0:
                    # Span answer: keep gold start/end token positions, type 0.
                    y1[i] = y1_val
                    y2[i] = int(dp['y2'])
                    q_type[i] = 0
                elif y1_val in self._Q_TYPE_BY_Y1:
                    # Non-span answer: span targets stay IGNORE_INDEX.
                    q_type[i] = self._Q_TYPE_BY_Y1[y1_val]
                else:
                    raise ValueError('unexpected y1 value {!r} for example {!r}'.format(y1_val, dp['id']))
                ids.append(dp['id'])

                # Each fact is [start, end, is_sp] or [start, end, is_sp, is_gold];
                # token ranges are clipped to this iterator's paragraph limit.
                facts = dp['start_end_facts']
                for j, fact in enumerate(facts[:self.sent_limit]):
                    start, end, is_sp_flag = fact[0], fact[1], fact[2]
                    end = min(end, self.para_limit)
                    if start < end:
                        start_mapping[i, start, j] = 1
                        end_mapping[i, end - 1, j] = 1
                        all_mapping[i, start:end, j] = 1
                        is_support[i, j] = int(is_sp_flag)
                max_sent_cnt = max(max_sent_cnt, min(len(facts), self.sent_limit))

            # Trim to the longest real (non-zero) context / question in this batch.
            input_lengths = (context_idxs > 0).long().sum(dim=1)
            max_c_len = int(input_lengths.max())
            max_q_len = int((ques_idxs > 0).long().sum(dim=1).max())

            self.bkt_ptrs[bkt_id] += cur_bsz
            if self.bkt_ptrs[bkt_id] >= len(cur_bucket):
                self.bkt_pool.remove(bkt_id)

            yield {'context_idxs': context_idxs[:, :max_c_len].contiguous(),
                   'ques_idxs': ques_idxs[:, :max_q_len].contiguous(),
                   'context_char_idxs': context_char_idxs[:, :max_c_len].contiguous(),
                   'ques_char_idxs': ques_char_idxs[:, :max_q_len].contiguous(),
                   'context_lens': input_lengths,
                   'y1': y1,
                   'y2': y2,
                   'ids': ids,
                   'q_type': q_type,
                   'is_support': is_support[:, :max_sent_cnt].contiguous(),
                   'start_mapping': start_mapping[:, :max_c_len, :max_sent_cnt],
                   'end_mapping': end_mapping[:, :max_c_len, :max_sent_cnt],
                   'all_mapping': all_mapping[:, :max_c_len, :max_sent_cnt]}

In [None]:
# LockedDropout masks a certain proportion of the input given a dropout probability ONLY during training
class LockedDropout(nn.Module):
    """Variational ("locked") dropout: one mask per sequence, shared across dim 1.

    A mask of shape (x.size(0), 1, x.size(2)) is drawn and broadcast over dim 1,
    so every position of a sequence drops the same feature channels. Kept units
    are scaled by 1 / (1 - dropout). The module is the identity in eval mode or
    when dropout is 0.

    Args:
        dropout: probability of zeroing a feature channel; expected in [0, 1).
    """

    def __init__(self, dropout):
        super().__init__()
        self.dropout = dropout

    def forward(self, x):
        # assumes x is 3-D, e.g. (batch, seq_len, features) — the mask uses dims 0 and 2.
        if not self.training or self.dropout == 0:
            return x
        keep = 1 - self.dropout
        # new_empty creates a plain tensor with no autograd history, so the mask
        # is a constant in the graph (the old Variable wrapper is unnecessary).
        mask = x.new_empty(x.size(0), 1, x.size(2)).bernoulli_(keep).div_(keep)
        return mask * x

# Encoder RNN class
class EncoderRNN(nn.Module):
    """Stack of single-layer GRUs (batch_first) with LockedDropout before each layer.

    Args:
        input_size: feature size of the first layer's input.
        num_units: hidden size of every GRU, per direction.
        nlayers: number of stacked GRU layers.
        concat: if True, concatenate every layer's result along the last dim;
            otherwise return only the last layer's result.
        bidir: use bidirectional GRUs (each layer then emits 2 * num_units features).
        dropout: LockedDropout probability applied to each layer's input.
        return_last: if True, each layer contributes its final hidden state
            flattened to (batch, directions * num_units) instead of per-position outputs.
    """

    def __init__(self, input_size, num_units, nlayers, concat, bidir, dropout, return_last):
        super().__init__()
        num_dirs = 2 if bidir else 1
        # Layer i > 0 consumes the previous layer's (possibly bidirectional) output.
        self.rnns = nn.ModuleList([
            nn.GRU(input_size if i == 0 else num_units * num_dirs, num_units, 1,
                   bidirectional=bidir, batch_first=True)
            for i in range(nlayers)
        ])
        # Learned initial hidden state per layer, shape (directions, 1, num_units).
        self.init_hidden = nn.ParameterList(
            [nn.Parameter(torch.zeros(num_dirs, 1, num_units)) for _ in range(nlayers)])
        self.dropout = LockedDropout(dropout)
        self.concat = concat
        self.nlayers = nlayers
        self.return_last = return_last

    def reset_parameters(self):
        """Re-initialise GRU weights from N(0, 0.1^2) and zero all biases (not invoked by __init__)."""
        for gru in self.rnns:
            for name, p in gru.named_parameters():
                if 'weight' in name:
                    p.data.normal_(std=0.1)
                else:
                    p.data.zero_()

    def get_init(self, bsz, i):
        """Layer i's learned initial state broadcast to the batch: (directions, bsz, num_units)."""
        return self.init_hidden[i].expand(-1, bsz, -1).contiguous()

    def forward(self, input, input_lengths=None):
        """Encode `input` (batch, seq_len, input_size).

        If `input_lengths` (one length per batch row) is given, sequences are
        packed so padding is skipped; this uses pack_padded_sequence's default
        enforce_sorted, so lengths are assumed to be in decreasing order.
        Positions past a sequence's length come back as zeros, and outputs are
        padded back to seq_len.
        """
        bsz, slen = input.size(0), input.size(1)
        output = input
        outputs = []
        if input_lengths is not None:
            # pack_padded_sequence needs the lengths on the CPU.
            lens = input_lengths.detach().cpu()
        for i in range(self.nlayers):
            hidden = self.get_init(bsz, i)
            output = self.dropout(output)
            if input_lengths is not None:
                output = rnn.pack_padded_sequence(output, lens, batch_first=True)
            output, hidden = self.rnns[i](output, hidden)
            if input_lengths is not None:
                output, _ = rnn.pad_packed_sequence(output, batch_first=True)
                if output.size(1) < slen:
                    # Unpacking restores only max(lens) steps (e.g. on a DataParallel
                    # chunk); pad back to slen so shapes match the input.
                    pad = output.new_zeros(output.size(0), slen - output.size(1), output.size(2))
                    output = torch.cat([output, pad], dim=1)
            if self.return_last:
                outputs.append(hidden.permute(1, 0, 2).contiguous().view(bsz, -1))
            else:
                outputs.append(output)
        if self.concat:
            # dim=-1 handles both (batch, seq, feat) outputs and (batch, feat) final states.
            return torch.cat(outputs, dim=-1)
        return outputs[-1]

class BiAttention(nn.Module):
    """Bidirectional attention between an `input` sequence and a `memory` sequence.

    Score for input position i and memory position j:
        w_in . input_i  +  w_mem . memory_j  +  sum_k input_ik * dot_scale_k * memory_jk
    Memory positions whose mask entry is 0 are pushed to -1e30 before softmax.

    Returns the last-dim concatenation [input, c2q, input * c2q, q2c * c2q], where
    c2q is the memory summary attended from each input position and q2c is a single
    input summary per batch element. Output feature size is 4 * input_size.

    Args:
        input_size: feature size shared by input and memory.
        dropout: LockedDropout probability applied to both sequences.
    """

    def __init__(self, input_size, dropout):
        super().__init__()
        self.dropout = LockedDropout(dropout)
        # Per-sequence linear terms of the trilinear similarity.
        self.input_linear = nn.Linear(input_size, 1, bias=False)
        self.memory_linear = nn.Linear(input_size, 1, bias=False)
        # Per-feature weights of the element-wise product term, drawn from U(1/sqrt(input_size), 1).
        self.dot_scale = nn.Parameter(torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5)))

    def forward(self, input, memory, mask):
        # presumably mask is (batch, memory_len) with 1 for real tokens — confirm with callers.
        batch, n_in, n_mem = input.size(0), input.size(1), memory.size(1)
        inp = self.dropout(input)
        mem = self.dropout(memory)

        mem_t = mem.permute(0, 2, 1).contiguous()
        scores = (self.input_linear(inp)
                  + self.memory_linear(mem).view(batch, 1, n_mem)
                  + torch.bmm(inp * self.dot_scale, mem_t))
        scores = scores - 1e30 * (1 - mask[:, None])

        # Input-to-memory: each input position reads a weighted memory summary.
        c2q = torch.bmm(F.softmax(scores, dim=-1), mem)
        # Memory-to-input: weight input positions by their best match score.
        q2c_weights = F.softmax(scores.max(dim=-1)[0], dim=-1).view(batch, 1, n_in)
        q2c = torch.bmm(q2c_weights, inp)

        return torch.cat([inp, c2q, inp * c2q, q2c * c2q], dim=-1)

class Model(nn.Module):
    """Reader producing supporting-fact, answer-span and answer-type scores.

    Pipeline: word embeddings (frozen) + char-CNN features -> GRU encoder ->
    context/question BiAttention -> self-attention -> GRU heads for supporting
    sentences, span start, span end and answer type.

    Args:
        word_mat: (vocab, word_dim) array-like of pretrained word vectors (kept frozen).
        char_mat: (n_chars, char_dim) array-like of character vectors (trainable).
            Any array-like (e.g. nested lists) is accepted; dims are read from the data.
    """

    def __init__(self, word_mat, char_mat):
        super().__init__()
        word_mat = np.asarray(word_mat, dtype=np.float32)
        char_mat = np.asarray(char_mat, dtype=np.float32)
        self.word_dim = word_mat.shape[1]
        self.char_dim = char_mat.shape[1]
        self.char_hidden = 100
        self.hidden = 80
        h = self.hidden

        # Embeddings initialised from the supplied matrices; index 0 is padding.
        self.word_emb = nn.Embedding(word_mat.shape[0], self.word_dim, padding_idx=0)
        self.word_emb.weight.data.copy_(torch.from_numpy(word_mat))
        self.word_emb.weight.requires_grad = False
        self.char_emb = nn.Embedding(char_mat.shape[0], self.char_dim, padding_idx=0)
        self.char_emb.weight.data.copy_(torch.from_numpy(char_mat))

        # Width-5 convolution over each token's characters.
        self.char_cnn = nn.Conv1d(self.char_dim, self.char_hidden, 5)

        self.rnn = EncoderRNN(self.char_hidden + self.word_dim, h, 1, True, True, 0.2, False)

        self.qc_att = BiAttention(h * 2, 0.2)
        self.linear_1 = nn.Sequential(
                nn.Linear(h * 8, h),
                nn.ReLU()
            )

        self.rnn_2 = EncoderRNN(h, h, 1, False, True, 0.2, False)
        self.self_att = BiAttention(h * 2, 0.2)
        self.linear_2 = nn.Sequential(
                nn.Linear(h * 8, h),
                nn.ReLU()
            )

        self.rnn_sp = EncoderRNN(h, h, 1, False, True, 0.2, False)
        self.linear_sp = nn.Linear(h * 2, 1)

        # +1 input feature: the per-token supporting-fact score appended in forward.
        self.rnn_start = EncoderRNN(h + 1, h, 1, False, True, 0.2, False)
        self.linear_start = nn.Linear(h * 2, 1)

        self.rnn_end = EncoderRNN(h * 3 + 1, h, 1, False, True, 0.2, False)
        self.linear_end = nn.Linear(h * 2, 1)

        self.rnn_type = EncoderRNN(h * 3 + 1, h, 1, False, True, 0.2, False)
        # NOTE(review): 3 answer types here, but DataIterator can emit q_type 3
        # (y1 == -3), which CrossEntropyLoss would reject — confirm the data never
        # contains it before relying on that branch.
        self.linear_type = nn.Linear(h * 2, 3)

        self.cache_S = 0

    def get_output_mask(self, outer):
        """Return an (S, S) 0/1 mask with mask[i, j] = 1 iff i <= j <= i + 15.

        Applied to the start x end score matrix so a predicted end is never before
        the start and spans are at most 16 tokens. The mask is cached for the
        largest S seen and sliced for smaller S; it is always returned on
        `outer`'s device and dtype (the cache may have been built elsewhere).
        """
        S = outer.size(1)
        if S > self.cache_S:
            self.cache_S = S
            np_mask = np.tril(np.triu(np.ones((S, S)), 0), 15)
            self.cache_mask = torch.from_numpy(np_mask).to(outer)
        return self.cache_mask[:S, :S].to(outer)

    def forward(self, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens,
                start_mapping, end_mapping, all_mapping, return_yp=False):
        """Score a batch produced by DataIterator.

        Index tensors are (batch, length) for words and (batch, length, chars)
        for characters; token id 0 is treated as padding. The *_mapping tensors
        map sentences to their start / end / all tokens.

        Returns:
            logit1, logit2: span start / end scores per context token, with
                padding pushed to -1e30.
            predict_type: answer-type scores, (batch, 3).
            predict_support: (batch, n_sents, 2); column 0 is a constant 0 and
                column 1 the supporting-fact logit.
            yp1, yp2 (only if return_yp): argmax start / end of the best valid span.
        """
        para_size, ques_size, char_size, bsz = (context_idxs.size(1), ques_idxs.size(1),
                                                context_char_idxs.size(2), context_idxs.size(0))
        h = self.hidden

        # 1.0 for real tokens, 0.0 for padding.
        context_mask = (context_idxs > 0).float()
        ques_mask = (ques_idxs > 0).float()

        # Char-CNN: embed characters, convolve, then max-pool over each token's characters.
        context_ch = self.char_emb(context_char_idxs.contiguous().view(-1, char_size)).view(bsz * para_size, char_size, -1)
        ques_ch = self.char_emb(ques_char_idxs.contiguous().view(-1, char_size)).view(bsz * ques_size, char_size, -1)
        context_ch = self.char_cnn(context_ch.permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, para_size, -1)
        ques_ch = self.char_cnn(ques_ch.permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, ques_size, -1)

        # Frozen word embeddings concatenated with the char features.
        context_output = torch.cat([self.word_emb(context_idxs), context_ch], dim=2)
        ques_output = torch.cat([self.word_emb(ques_idxs), ques_ch], dim=2)

        context_output = self.rnn(context_output, context_lens)
        ques_output = self.rnn(ques_output)

        # Context-question attention, then a residual self-attention block.
        output = self.linear_1(self.qc_att(context_output, ques_output, ques_mask))
        output_t = self.rnn_2(output, context_lens)
        output_t = self.linear_2(self.self_att(output_t, output_t, context_mask))
        output = output + output_t

        # Supporting facts: a sentence is represented by the backward GRU state at
        # its start token and the forward GRU state at its end token.
        sp_output = self.rnn_sp(output, context_lens)
        start_output = torch.matmul(start_mapping.permute(0, 2, 1).contiguous(), sp_output[:, :, h:])
        end_output = torch.matmul(end_mapping.permute(0, 2, 1).contiguous(), sp_output[:, :, :h])
        sp_output = self.linear_sp(torch.cat([start_output, end_output], dim=-1))
        # Prepend a constant-zero logit so each sentence gets a 2-way score.
        sp_output_aux = sp_output.new_zeros(sp_output.size(0), sp_output.size(1), 1)
        predict_support = torch.cat([sp_output_aux, sp_output], dim=-1).contiguous()

        # Spread each sentence's score back onto its tokens as an extra feature.
        sp_output = torch.matmul(all_mapping, sp_output)
        output = torch.cat([output, sp_output], dim=-1)

        # Start head, then an end head conditioned on the start representation;
        # padding positions are pushed to -1e30.
        output_start = self.rnn_start(output, context_lens)
        logit1 = self.linear_start(output_start).squeeze(2) - 1e30 * (1 - context_mask)
        output_end = self.rnn_end(torch.cat([output, output_start], dim=2), context_lens)
        logit2 = self.linear_end(output_end).squeeze(2) - 1e30 * (1 - context_mask)

        # Answer type from a max-pool over positions of the type GRU.
        output_type = torch.max(self.rnn_type(torch.cat([output, output_end], dim=2), context_lens), 1)[0]
        predict_type = self.linear_type(output_type)

        if not return_yp:
            return logit1, logit2, predict_type, predict_support

        # Score all (start, end) pairs and keep only valid spans before the argmaxes.
        outer = logit1[:, :, None] + logit2[:, None]
        outer_mask = self.get_output_mask(outer)
        outer = outer - 1e30 * (1 - outer_mask[None].expand_as(outer))
        yp1 = outer.max(dim=2)[0].max(dim=1)[1]
        yp2 = outer.max(dim=1)[0].max(dim=1)[1]
        return logit1, logit2, predict_type, predict_support, yp1, yp2
    

In [None]:
# Model evaluation functions

def normalize_answer(s):
    """Canonicalise an answer string before EM / F1 comparison.

    Steps, in order:
    1. Lowercase.
    2. Delete ASCII punctuation (string.punctuation).
    3. Replace the whole words "a", "an" and "the" with a space.
    4. Collapse whitespace runs to single spaces and strip the ends.
    """
    text = s.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())

def f1_score(prediction, ground_truth):
    """Token-overlap F1 (harmonic mean of precision and recall) of two answers.

    Both strings go through normalize_answer. If either side normalises to
    "yes", "no" or "noanswer", only an identical pair can score; any other pair
    returns (0, 0, 0), as does a pair with no shared tokens.

    Returns:
        (f1, precision, recall).
    """
    pred_norm = normalize_answer(prediction)
    gold_norm = normalize_answer(ground_truth)
    zero = (0, 0, 0)

    special = ['yes', 'no', 'noanswer']
    if pred_norm != gold_norm and (pred_norm in special or gold_norm in special):
        return zero

    pred_tokens = pred_norm.split()
    gold_tokens = gold_norm.split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return zero

    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(gold_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall

def exact_match_score(prediction, ground_truth):
    """Return True iff both answers are identical after normalize_answer."""
    pred_norm, gold_norm = [normalize_answer(text) for text in (prediction, ground_truth)]
    return pred_norm == gold_norm

def convert_tokens(eval_file, qa_id, pp1, pp2, p_type):
    """Turn per-example predictions into answer strings keyed by str(id).

    Args:
        eval_file: dict keyed by str(id). Span answers read its "context" text and
            its "spans" list, where spans[k] gives the (char_start, char_end) of token k.
        qa_id, pp1, pp2, p_type: parallel sequences of example ids, predicted start
            token, predicted end token and predicted answer type.

    Answer type mapping:
        0: context text from the start token's first char to the end token's last char.
        1: 'yes'.
        2: 'no'.
        3: 'noanswer'.
        Any other type trips an assertion.
    """
    fixed_answers = ((1, 'yes'), (2, 'no'), (3, 'noanswer'))
    answer_dict = {}
    for qid, start_tok, end_tok, ans_type in zip(qa_id, pp1, pp2, p_type):
        key = str(qid)
        if ans_type == 0:
            example = eval_file[key]
            context, spans = example["context"], example["spans"]
            answer_dict[key] = context[spans[start_tok][0]: spans[end_tok][1]]
            continue
        for code, literal in fixed_answers:
            if ans_type == code:
                answer_dict[key] = literal
                break
        else:
            assert False
    return answer_dict

def evaluate(eval_file, answer_dict):
    """Average exact-match and F1 (both in percent) of predictions against gold answers.

    Args:
        eval_file: dict with the same keys as answer_dict; each entry's "answer"
            must be a one-element list holding the gold answer string.
        answer_dict: {example_id: predicted answer string}.

    Returns:
        {'exact_match': float, 'f1': float}. Both are 0.0 when answer_dict is empty.

    Raises:
        AssertionError: if an example does not have exactly one gold answer.
    """
    total = len(answer_dict)
    if total == 0:
        # Nothing to score (e.g. an empty data source); avoid dividing by zero.
        return {'exact_match': 0.0, 'f1': 0.0}

    exact_match = f1 = 0
    for key, prediction in answer_dict.items():
        ground_truths = eval_file[key]["answer"]
        assert len(ground_truths) == 1
        exact_match += exact_match_score(prediction, ground_truths[0])
        cur_f1, _, _ = f1_score(prediction, ground_truths[0])
        f1 += cur_f1

    return {'exact_match': 100.0 * exact_match / total, 'f1': 100.0 * f1 / total}

# Compute the loss, exact match and F1 over a data source (optionally capped at max_batches)
def evaluate_batch(data_source, model, max_batches, eval_file, config=None):
    """Run `model` over `data_source` without gradients and score its answers.

    Args:
        data_source: iterable of batch dicts as yielded by DataIterator.
        model: callable with Model.forward's signature; called with return_yp=True.
        max_batches: stop after this many batches; 0 or a negative value means all.
        eval_file: dict keyed by example id, used by convert_tokens and evaluate.
        config: optional object whose `sp_lambda` attribute (default 0.0) weights
            the supporting-fact term of the reported loss.

    Returns:
        evaluate()'s metrics dict plus 'loss', the mean per-batch loss
        (NaN if no batch was processed).
    """
    sp_lambda = getattr(config, 'sp_lambda', 0.0)
    answer_dict = {}
    total_loss, step_cnt = 0.0, 0
    # Inference only: no autograd graph. This replaces the removed `volatile=True` flag.
    with torch.no_grad():
        for step, data in enumerate(data_source):
            if max_batches > 0 and step >= max_batches:
                break

            logit1, logit2, predict_type, predict_support, yp1, yp2 = model(
                data['context_idxs'], data['ques_idxs'], data['context_char_idxs'], data['ques_char_idxs'],
                data['context_lens'], data['start_mapping'], data['end_mapping'], data['all_mapping'],
                return_yp=True)

            # Targets come from the CPU-side iterator; move them next to the outputs.
            device = logit1.device
            y1, y2 = data['y1'].to(device), data['y2'].to(device)
            q_type = data['q_type'].to(device)
            loss = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) + nll_sum(logit2, y2)) \
                / data['context_idxs'].size(0)
            if sp_lambda:
                is_support = data['is_support'].to(device)
                loss = loss + sp_lambda * nll_average(predict_support.reshape(-1, 2), is_support.reshape(-1))

            answer_dict.update(convert_tokens(
                eval_file, data['ids'], yp1.cpu().tolist(), yp2.cpu().tolist(),
                np.argmax(predict_type.cpu().numpy(), 1)))

            total_loss += loss.item()
            step_cnt += 1

    metrics = evaluate(eval_file, answer_dict)
    metrics['loss'] = total_loss / step_cnt if step_cnt else float('nan')
    return metrics

# Make predictions and export them
def predict(data_source, model, eval_file, prediction_file, sp_threshold=0.3):
    """Predict answers and supporting facts for every example and write them as JSON.

    The file holds {'answer': {id: answer string}, 'sp': {id: [supporting facts]}}.
    The supporting facts are the eval_file[id]['sent2title_ids'] entries whose
    sentence probability exceeds `sp_threshold`.

    Args:
        data_source: iterable of batch dicts as yielded by DataIterator.
        model: callable with Model.forward's signature; called with return_yp=True.
        eval_file: dict keyed by str(example id), holding "context", "spans" and "sent2title_ids".
        prediction_file: path of the JSON file to write (overwritten).
        sp_threshold: probability above which a sentence is reported as supporting.
    """
    answer_dict = {}
    sp_dict = {}
    # Inference only: no autograd graph. This replaces the removed `volatile=True` flag.
    with torch.no_grad():
        for data in tqdm(data_source):
            logit1, logit2, predict_type, predict_support, yp1, yp2 = model(
                data['context_idxs'], data['ques_idxs'], data['context_char_idxs'], data['ques_char_idxs'],
                data['context_lens'], data['start_mapping'], data['end_mapping'], data['all_mapping'],
                return_yp=True)
            answer_dict.update(convert_tokens(
                eval_file, data['ids'], yp1.cpu().tolist(), yp2.cpu().tolist(),
                np.argmax(predict_type.cpu().numpy(), 1)))

            # Column 0 of predict_support is a constant-zero logit, so the 2-way
            # softmax probability of "supporting" equals sigmoid(column 1).
            sp_probs = torch.sigmoid(predict_support[:, :, 1]).cpu().numpy()
            for i, cur_id in enumerate(data['ids']):
                key = str(cur_id)  # same keying as convert_tokens
                sent2title = eval_file[key]['sent2title_ids']
                n_sents = min(sp_probs.shape[1], len(sent2title))
                sp_dict[key] = [sent2title[j] for j in range(n_sents) if sp_probs[i, j] > sp_threshold]

    prediction = {'answer': answer_dict, 'sp': sp_dict}
    with open(prediction_file, 'w') as f:
        json.dump(prediction, f)

In [None]:
def train(lr=0.5, max_epochs=100, log_period=10, eval_period=1000, save_dir='HOTPOT',
          sp_lambda=0.0, patience=1, min_lr=0.5 * 1e-2):
    """Train Model with SGD on 'train_record.pkl', evaluating on 'dev_record.pkl'.

    Every `eval_period` steps the model is scored on the dev set, and the
    best-F1 weights are saved to `save_dir`/model.pt. After `patience` evaluations
    without improvement the learning rate is halved; training stops once it
    falls below `min_lr`.

    Args:
        lr: initial SGD learning rate.
        max_epochs: maximum number of passes over the training data.
        log_period: steps between training-loss log lines.
        eval_period: steps between dev evaluations / checkpoints.
        save_dir: directory for the best checkpoint (created if missing).
        sp_lambda: weight of the supporting-fact loss (0.0 disables it).
        patience: non-improving evaluations tolerated before halving lr.
        min_lr: learning-rate floor that ends training.

    Relies on the notebook globals word_mat, char_mat, dev_eval_file, batch_size,
    para_limit, ques_limit, char_limit and sent_limit.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted record files.
    train_buckets = [torch.load('train_record.pkl')]
    dev_buckets = [torch.load('dev_record.pkl')]

    # Iterators are single-pass, so build a fresh one per epoch / evaluation.
    def build_train_iterator():
        return DataIterator(train_buckets, batch_size, para_limit, ques_limit, char_limit, True, sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, batch_size, para_limit, ques_limit, char_limit, False, sent_limit)
    print('Building Iterators Done')

    ori_model = Model(np.array(word_mat), np.array(char_mat))
    if torch.cuda.is_available():
        # DataParallel requires the wrapped module to live on the first GPU.
        ori_model = ori_model.cuda()
    model = nn.DataParallel(ori_model)

    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, 'model.pt')

    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()
    print('Start training')
    for epoch in range(max_epochs):
        for data in build_train_iterator():
            logit1, logit2, predict_type, predict_support = model(
                data['context_idxs'], data['ques_idxs'], data['context_char_idxs'], data['ques_char_idxs'],
                data['context_lens'], data['start_mapping'], data['end_mapping'], data['all_mapping'],
                return_yp=False)

            # Targets come from the CPU-side iterator; move them next to the outputs.
            device = logit1.device
            y1, y2 = data['y1'].to(device), data['y2'].to(device)
            q_type = data['q_type'].to(device)
            loss = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) + nll_sum(logit2, y2)) \
                / data['context_idxs'].size(0)
            if sp_lambda:
                is_support = data['is_support'].to(device)
                loss = loss + sp_lambda * nll_average(predict_support.reshape(-1, 2), is_support.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            global_step += 1

            if global_step % log_period == 0:
                cur_loss = total_loss / log_period
                elapsed = time.time() - start_time
                print('| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'.format(
                    epoch, global_step, lr, elapsed * 1000 / log_period, cur_loss))
                total_loss = 0
                start_time = time.time()

            if global_step % eval_period == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0, dev_eval_file)
                model.train()

                print('| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'.format(
                    global_step // eval_period, epoch, time.time() - eval_start_time,
                    metrics['loss'], metrics['exact_match'], metrics['f1']))
                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(), checkpoint_path)
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= patience:
                        # Plateau: halve the learning rate; stop once it is too small.
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < min_lr:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train:
            break
    print('best_dev_F1 {}'.format(best_dev_F1))

In [None]:
def test(data_split):
    """Load the saved checkpoint and write '<data_split>_predictions.json'.

    Args:
        data_split: one of
            'dev': reads 'dev_record.pkl' and 'eval_examples_dev.json' with
                para/question limits 1000/80.
            'test': reads 'test_record.pkl' and 'eval_examples_test.json'; the
                limits are derived from the data.

    Raises:
        ValueError: for any other data_split.
    """
    if data_split == 'dev':
        record_file, eval_path = 'dev_record.pkl', 'eval_examples_dev.json'
        split_para_limit, split_ques_limit = 1000, 80
    elif data_split == 'test':
        record_file, eval_path = 'test_record.pkl', 'eval_examples_test.json'
        split_para_limit, split_ques_limit = None, None
    else:
        raise ValueError("data_split must be 'dev' or 'test', got {!r}".format(data_split))

    # NOTE(review): joblib.load / torch.load unpickle arbitrary objects; only use trusted files.
    word_mat = joblib.load('word_emb.pkl')
    char_mat = joblib.load('char_emb.pkl')
    with open(eval_path, 'r') as fh:
        eval_file = json.load(fh)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    # Same bucket layout as train(): a single bucket holding every record.
    buckets = [torch.load(record_file)]

    def build_dev_iterator():
        return DataIterator(buckets, 64, split_para_limit, split_ques_limit, 16, False, 100)

    ori_model = Model(np.array(word_mat), np.array(char_mat))
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    ori_model.load_state_dict(torch.load(os.path.join('HOTPOT', 'model.pt'), map_location='cpu'))
    if torch.cuda.is_available():
        # DataParallel requires the wrapped module to live on the first GPU.
        ori_model = ori_model.cuda()
    model = nn.DataParallel(ori_model)

    model.eval()
    predict(build_dev_iterator(), model, eval_file, data_split + '_predictions.json')

In [None]:
# Kick off training: train() reads 'train_record.pkl' / 'dev_record.pkl' and uses the
# embedding matrices and limits defined in the configuration cell.
train()