In [83]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

from typing import Any, Dict, List
import pdb

In [272]:
class TimeDistributed(nn.Module):
    def __init__(self, module):
        super(TimeDistributed, self).__init__()
        self._module = module
        
    def forward(self, x):
        """Shape -> (batch_size, time_steps, *shapes)"""
        bs = x.shape[0]
        ts = x.shape[1]
        x = x.view(bs*ts, *x.shape[2:])
        x = self._module(x)
        x = x.view(bs, ts, *x.shape[1:])
        return x
    

class Pack2Pad(nn.Module):
    def __init__(self, module):
        super(Pack2Pad, self).__init__()
        self._module = module
        
    def forward(self, x, lens, ind=None):
        # sort according to lens
        if ind is None:
            _, ind = torch.sort(lens, 0, descending=True)

        x = rnn_utils.pack_padded_sequence(x[ind], 
                                           lens[ind], 
                                           batch_first=True)
        # only want the first output
        x = self._module(x)[0]
        # reorder
        x, _ = rnn_utils.pad_packed_sequence(x, 
                                             batch_first=True)
        x = x[ind]
        return x
        

class Highway(nn.Module):
    "From AllenNLP"
    def __init__(self,
                 in_dim: int,
                 num_layers: int = 1):
        super(Highway, self).__init__()
        self.in_dim = in_dim
        self.layers = torch.nn.ModuleList([nn.Linear(in_dim, in_dim * 2)
                                            for _ in range(num_layers)])
        self.activ = nn.ReLU()
        
        # make bias positive to carry forward
        for layer in self.layers:
            layer.bias[in_dim:].data.fill_(1)

    def forward(self, x):
        for layer in self.layers:
            f_x, gate = layer(x).chunk(2, dim=-1)
            f_x = self.activ(f_x)
            gate = torch.sigmoid(gate)
            x = gate * x + (1 - gate) * f_x
        return x


class BidafCharEmbedding(nn.Module):
    def __init__(self):
        super(BidafCharEmbedding, self).__init__()
        self.embed = nn.Embedding(262, 16, padding_idx=0)
        self.conv = TimeDistributed(nn.Conv1d(16, 100, kernel_size=(5,), stride=(1,)))
        self.activ = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        
    def forward(self, x):
        "Takes Packed Sequence as Input"
        # (pack_len, seq_len)
        x = self.embed(x)
        # (pack_len, seq_length, in_channels) 
        x = x.transpose(-2,-1)
        x = self.activ(self.conv(x))
        x = x.max(-1)[0]
        # (pack_len, out_channels) 
        return x


class BidafEmbedding(nn.Module):
    def __init__(self):
        super(BidafEmbedding, self).__init__()
        self.word_embed = nn.Embedding(97914, 100)
        self.char_embed = BidafCharEmbedding()
        
    def forward(self, word, char):
        w = self.word_embed(word)
        c = self.char_embed(char)
        x = torch.cat([c, w], dim=-1)
        return x


class Matrix_Attention(nn.Module):
    def __init__(self):
        super(Matrix_Attention, self).__init__()
        self.attention = nn.Linear(600, 1)

    def forward(self, x, y):
        # expand so that x,y have the same dimension
        x = x.unsqueeze(2).expand(x.shape[0], x.shape[1], y.shape[1], x.shape[2])
        y = y.unsqueeze(1).expand(y.shape[0], x.shape[1], y.shape[1], y.shape[2])
        
        # combine x and y
        xy = torch.cat([x, y, x *y], dim=-1)
        
        return self.attention(xy).squeeze(-1)


def len2mask(lens):
    max_len = lens.max()
    # uint8
    mask = torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)
    return mask

def masked_softmax(x, mask):
    mask = mask.float()
    x = torch.softmax(x * mask, dim=-1)
    x = x * mask
    x = x / (x.sum(-1, keepdim=True) + 1e-13)
    return x

def replace_masked_values(x, mask, value):
    select = mask.expand(*x.shape)
    x[select] = value
    
def sort_pack_seq(x):
    "Batch first"
    # get lengths of each seq
    lengths = torch.tensor([len(i) for i in x])
    # pad the seq
    x = rnn_utils.pad_sequence(x, batch_first=True)
    # sorting
    _, ind = torch.sort(lengths, 0, descending=True)
    
    # packed seq
    x = rnn_utils.pack_padded_sequence(x[ind], 
                                       lengths[ind], 
                                       batch_first=True)
    return x, ind

In [516]:
class BidirectionalAttentionFlow(nn.Module):
    "From AllenNLP"
    def __init__(self):
        super(BidirectionalAttentionFlow, self).__init__()

        self.bidaf_embed = BidafEmbedding()
        self.highway = Highway(200, num_layers=2)
        
        self.phrase_layer = Pack2Pad(nn.LSTM(200, 100, batch_first=True,
                                          bidirectional=True))
    
        self.matrix_attention = Matrix_Attention()
        
        self.modeling_layer = Pack2Pad(nn.LSTM(800, 100, 
                                       num_layers=2, 
                                       batch_first=True, 
                                       dropout=0.2, 
                                       bidirectional=True))
        "Dense+Softmax"
        self.start_predictor = nn.Linear(1000, 1)
        
        "LSTM+Softmax"
        self.end_encoder = Pack2Pad(nn.LSTM(1400, 100, 
                                    batch_first=True,
                                    bidirectional=True))        
        self.end_predictor = nn.Linear(1000, 1)

    def forward(self, context, query):
        "Takes PackData as Input"
        
        "----Init Batch Calc----"
        # Faster if Seq lengths are variables
        c_mask = (context.words != 0)
        c_lens = c_mask.sum(1)
        c_ind = torch.sort(c_lens, 0, descending=True)[1]
        q_mask = (query.words != 0)
        q_lens = q_mask.sum(1)
        q_ind = torch.sort(q_lens, 0, descending=True)[1]

        
        "----Embedding Layer----"
        e_c = self.highway(
            self.bidaf_embed(context.words, context.chars))
        e_q = self.highway(
            self.bidaf_embed(query.words, query.chars))
        
        
        "----Phrase Layer----"
        e_c = self.phrase_layer(e_c, c_lens, c_ind)
        e_q = self.phrase_layer(e_q, q_lens, q_ind)
        
        
        "----Attention Layer----"
        # linear attention
        c2q_sim = self.matrix_attention(e_c , e_q)        
        
        # context to query attention
        c2q_att = masked_softmax(c2q_sim, q_mask.unsqueeze(1))
        c2q = c2q_att.bmm(e_q)
        
        # masked fill to value -1e7
        q2c_sim = c2q_sim.masked_fill((1 - q_mask.unsqueeze(1)), -1e7)
        q2c_sim = q2c_sim.max(dim=-1)[0]
        q2c_att = masked_softmax(q2c_sim, c_mask).unsqueeze(1)
        # Shape: c2q shape
        q2c = q2c_att.bmm(e_c).expand(*c2q.shape)
        
        att_out = torch.cat([e_c,
                             c2q,
                             e_c * c2q,
                             e_c * q2c],
                             dim=-1)
        
        
        "----Modeling Layer----"
        model_out = self.modeling_layer(att_out, c_lens, c_ind)
        
        
        "----Output Layer----"
        "----Start Layer-----"
        start = torch.cat([att_out, model_out], dim=-1)
        start_logits = self.start_predictor(start).squeeze(-1)
        start_probs = masked_softmax(start_logits, c_mask)
        
 
        "----End Layer-----"
        start_vector = start_probs.unsqueeze(1).bmm(model_out).expand(*c2q.shape)
        end_vector = torch.cat([att_out,
                                model_out,
                                start_vector,
                                model_out * start_vector],
                                dim=-1)

        # sort according to context
        end_out = self.end_encoder(end_vector, c_lens, c_ind)
        end_out = torch.cat([att_out, end_out], dim=-1)
        end_logits = self.end_predictor(end_out).squeeze(-1)

        # masked fill to refine the results
        start_logits = start_logits.masked_fill(1 - c_mask, -1e7)
        end_logits = end_logits.masked_fill(1 - c_mask, -1e7)
    
        "----Return----"
        return start_logits, end_logits


    def load_weights(self, load_table, weights_dict):
        for dst in load_table.keys():
            self.state_dict()[dst].copy_(weights_dict[load_table[dst]])
        
        # set padding weight to 0
        self.bidaf_embed.char_embed.embed.weight.data[0].fill_(0)

In [517]:
def read_table(file_name:str)->Dict[str, str]:
    with open(file_name, "r") as f:
        load_table = f.read()
        load_table = "{" + load_table + "}"
        load_table =  eval(load_table)
    return load_table

In [518]:
weights_dict = torch.load("weights.th", map_location='cpu')
load_table = read_table("bidaf_load.txt")
bidaf = BidirectionalAttentionFlow()
bidaf.load_weights(load_table, weights_dict)
bidaf.eval();

# Sanity Check

In [519]:
from allennlp.predictors.predictor import Predictor
from allennlp.models.archival import load_archive

if not "predictor" in vars():
    archive = load_archive("bidaf.tar.gz")
    predictor = Predictor.from_archive(archive)

In [520]:
class PackData():
    def __init__(self, data):
        self.words = data['tokens']
        self.chars = data['token_characters']

In [521]:
from allennlp.data.dataset import Batch

instance1 = predictor._dataset_reader.text_to_instance('good', 'This is not good!')
instance2 = predictor._dataset_reader.text_to_instance('it is bad', 'not bad')

dataset = Batch([instance1, instance2])
vocab = predictor._model.vocab
dataset.index_instances(vocab)
passage = dataset.as_tensor_dict()['passage']
question = dataset.as_tensor_dict()['question']

In [522]:
context = PackData(passage)
query = PackData(question)

## Embedding layer

In [523]:
lens = (context.words != 0).sum(1)

In [524]:
# bidaf(context, query)[1].sum()

## Phrase Layer

In [525]:
# e_c, e_q = bidaf(context, query)

## Matrix Attention

In [527]:
#bidaf(context, query)[0].sum()

## Modeling Layer

In [528]:
bidaf(context, query)[1].sum()

tensor(-30000048., grad_fn=<SumBackward0>)

## Ouput Layer

In [530]:
bidaf(context, query)

(tensor([[-4.1058e+00, -7.2027e+00, -5.9775e+00, -8.1707e+00, -6.1313e+00],
         [-2.6531e+00, -4.3184e+00, -1.0000e+07, -1.0000e+07, -1.0000e+07]],
        grad_fn=<MaskedFillBackward0>),
 tensor([[-7.6523e+00, -1.0996e+01, -8.5098e+00, -5.1956e+00, -3.5191e+00],
         [-7.9713e+00, -4.6289e+00, -1.0000e+07, -1.0000e+07, -1.0000e+07]],
        grad_fn=<MaskedFillBackward0>))

# PASS!!!

In [533]:
import inspect
lines = inspect.getsource(masked_softmax)
print(lines)

def masked_softmax(x, mask):
    mask = mask.float()
    x = torch.softmax(x * mask, dim=-1)
    x = x * mask
    x = x / (x.sum(-1, keepdim=True) + 1e-13)
    return x

