In [2]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [3]:
import layers
from util import collate_fn, SQuAD
import util
from util import masked_softmax

In [4]:
from models import BiDAF, RNet

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Load sample batch
* Batch size 64 (64 training examples)
* Each word is padded to 16 characters

In [6]:
cw_idxs = torch.load('data_dev/cw_idxs.pt') # torch.Size([64, 293])
cc_idxs = torch.load('data_dev/cc_idxs.pt') # torch.Size([64, 293, 16])
qw_idxs = torch.load('data_dev/qw_idxs.pt') # torch.Size([64, 29])
qc_idxs = torch.load('data_dev/qc_idxs.pt') # torch.Size([64, 29, 16])
y1 = torch.load('data_dev/y1.pt') # torch.Size([64])
y2 = torch.load('data_dev/y2.pt') # torch.Size([64])
ids = torch.load('data_dev/ids.pt') # torch.Size([64])

In [7]:
c_mask = torch.zeros_like(cw_idxs) != cw_idxs
q_mask = torch.zeros_like(qw_idxs) != qw_idxs
c_len, q_len = c_mask.sum(-1), q_mask.sum(-1)

In [8]:
# Character indices of the third token of the first context (0 = padding).
# Only the last expression of a cell is displayed, so inspect one thing at a time.
cc_idxs[0][2]

tensor([17, 18, 19, 10,  7, 12,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [49]:
cc_idxs[0]

tensor([[ 1,  1,  1,  ...,  1,  1,  1],
        [40,  6,  0,  ...,  0,  0,  0],
        [17, 18, 19,  ...,  0,  0,  0],
        ...,
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0]])

In [9]:
# Number of examples in the loaded sample batch (first dimension of cw_idxs).
batch_size = cw_idxs.size()[0]
# NOTE(review): hidden_size is not referenced by the cells below; the R-Net layers use 200.
hidden_size = 100 

#### Load embeddings

In [10]:
word_vectors = util.torch_from_json("data/word_emb.json")

In [11]:
word_vectors.shape

torch.Size([88714, 300])

In [12]:
char_vectors = util.torch_from_json("data/char_emb.json")

In [13]:
char_vectors.shape

torch.Size([1376, 64])

### R-Net

#### 1. Embedding
R-Net is built from four major components. The first step encodes both the passage $P = \{w_t^P\}_{t=1}^{n}$ and the question $Q = \{w_t^Q\}_{t=1}^{m}$ with a bidirectional RNN. This RNN transforms the word embeddings ($\{e_t^P\}_{t=1}^{n}$, $\{e_t^Q\}_{t=1}^{m}$) and the character embeddings ($\{c_t^P\}_{t=1}^{n}$, $\{c_t^Q\}_{t=1}^{m}$) into new encoded representations ($\{u_t^P\}_{t=1}^{n}$, $\{u_t^Q\}_{t=1}^{m}$).
The character and word embeddings are concatenated before entering the bidirectional RNN. For efficiency, the original paper uses GRU cells instead of LSTM cells.

In [46]:
class CharEmbedding(nn.Module):
    """Character-level word embeddings produced by a bidirectional LSTM.

    The characters of each token are embedded with pretrained character vectors
    and run through a bidirectional LSTM. The final forward and backward hidden
    states are concatenated to form the token's embedding.
    """

    def __init__(self, char_vectors, e_char, e_word):
        """
        @param char_vectors (Tensor): pretrained character embeddings, shape (n_chars, e_char)
        @param e_char (int): character embedding size; must equal char_vectors.size(1)
        @param e_word (int): size of the produced token embedding; each LSTM
            direction gets e_word // 2 hidden units
        """
        super(CharEmbedding, self).__init__()

        self.e_char = e_char
        self.e_word = e_word

        # from_pretrained freezes the embedding weights by default (freeze=True).
        self.embeddings = nn.Embedding.from_pretrained(char_vectors)

        self.encoder = nn.LSTM(input_size=e_char,
                               bidirectional=True,
                               hidden_size=e_word // 2,
                               batch_first=True)

    def forward(self, input):
        """
        @param input (Tensor): character indices, shape (batch_size, sentence_length, max_word_length);
            index 0 is treated as (trailing) character padding
        @returns last_hidden (Tensor): shape (batch_size, sentence_length, 2 * (e_word // 2))
        """
        batch_size, sentence_length, max_word_length = input.shape

        # Treat every token as an independent character sequence.
        flat_idxs = input.reshape(batch_size * sentence_length, max_word_length)
        x_emb = self.embeddings(flat_idxs)  # (batch_size * sentence_length, max_word_length, e_char)

        # Pack by true character length so the final states are taken at the last real
        # character rather than after the trailing padding. Fully padded tokens get
        # length 1 because pack_padded_sequence rejects zero-length sequences.
        char_lens = (flat_idxs != 0).sum(-1).clamp(min=1).cpu()
        packed = pack_padded_sequence(x_emb, char_lens, batch_first=True, enforce_sorted=False)

        # last_hidden: (2, batch_size * sentence_length, e_word // 2), restored to input order
        _, (last_hidden, _) = self.encoder(packed)

        # Concatenate forward and backward final states: (batch_size * sentence_length, 2 * (e_word // 2))
        last_hidden = torch.cat([last_hidden[0], last_hidden[1]], dim=1)

        # Restore the sentence dimension: (batch_size, sentence_length, 2 * (e_word // 2))
        return last_hidden.reshape(batch_size, sentence_length, -1)

In [47]:
# Character-based embedding of question and context tokens
# (char_vectors was loaded in an earlier cell).
char_emb = CharEmbedding(char_vectors, e_char=char_vectors.size(1), e_word=200)

# Number of non-pad characters in each token (pad index 0).
cc_len = (cc_idxs != 0).sum(-1)
qc_len = (qc_idxs != 0).sum(-1)

qc = char_emb(qc_idxs)
cc = char_emb(cc_idxs)  # needed by the concatenation cell below

cc_len.shape

torch.Size([64, 293])

In [43]:
# lookup word embedding
embed = nn.Embedding.from_pretrained(word_vectors) 

qw = embed(qw_idxs)
cw = embed(cw_idxs)
qw.shape

torch.Size([64, 29, 300])

In [75]:
# concatenate the two embedding
q = torch.cat((qc, qw), dim=2)
c = torch.cat((cc, cw), dim=2)
q.shape

torch.Size([64, 29, 500])

In [76]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder over concatenated word + character embeddings.

    Produces one hidden state per token of the context (or question).
    """

    def __init__(self, input_size, h_size, dropout=0):
        """
        @param input_size (int): size of the concatenated embedding (word embedding
            size plus the CharEmbedding output size)
        @param h_size (int): hidden size of each GRU direction; the output size is 2 * h_size
        @param dropout (float): dropout probability applied to the encoder output
        """
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.h_size = h_size
        self.out_size = 2 * h_size
        self.gru = nn.GRU(input_size=input_size,
                          bidirectional=True,
                          hidden_size=h_size,
                          batch_first=True)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input, lengths=None):
        """
        @param input (Tensor): concatenated character & word embeddings,
            shape (batch_size, sentence_length, input_size)
        @param lengths (Tensor, optional): number of real (non-pad) tokens per example,
            shape (batch_size,). If None, every sequence is treated as full length.
        @return hiddens (Tensor): hidden state of every token,
            shape (batch_size, sentence_length, 2 * h_size); positions past an
            example's length are zero
        """
        batch_size, sentence_length, _ = input.size()
        if lengths is None:
            lengths = torch.full((batch_size,), sentence_length, dtype=torch.long)

        # pack_padded_sequence needs the lengths on the CPU. enforce_sorted=False lets
        # PyTorch sort internally and restore the original batch order when unpacking.
        packed = pack_padded_sequence(input, lengths.cpu(), batch_first=True, enforce_sorted=False)
        hiddens, _ = self.gru(packed)

        # Back to a padded tensor: (batch_size, sentence_length, 2 * h_size)
        hiddens, _ = pad_packed_sequence(hiddens, batch_first=True, total_length=sentence_length)
        return self.dropout(hiddens)

In [77]:
c.shape

torch.Size([64, 293, 500])

In [78]:
encoder = Encoder(input_size=500, h_size=200)
uc = encoder.forward(c, c_len)
uq = encoder.forward(q, q_len)
uc.shape

torch.Size([64, 293, 400])

In [79]:
uq.shape

torch.Size([64, 29, 400])

#### 2. Gated Attention Based Recurrent Networks
The second component, gated attention-based recurrent networks, modifies the representation of each passage word so that it becomes aware of the question, producing $\{v_t^P\}_{t=1}^{n}$. For the passage word at step $t$, an attention-pooling vector $c_t$ is aggregated over the entire question $u^Q$. A sigmoid input gate is added to attenuate the RNN input $[u_t^P, c_t]$. This captures the relation between the current passage word $u_t^P$ and the entire question.

In [116]:
class GatedAttn(nn.Module):
    """Gated attention-based recurrent network (question-aware passage encoding).

    At each passage step t, the module attends over the question encodings and
    gates the concatenation of the passage encoding and the attention-pooled
    question vector. A GRU cell then updates the question-aware state v_t.
    """

    def __init__(self, input_size, h_size, dropout=0):
        """
        @param input_size (int): half the feature size of the passage/question encodings
            (the encodings have 2 * input_size features)
        @param h_size (int): half the GRU hidden size; outputs have 2 * h_size features
        @param dropout (float): dropout probability applied to the output
        """
        super(GatedAttn, self).__init__()

        self.hidden_size = h_size * 2
        self.input_size = input_size
        self.out_size = 2 * h_size

        # NOTE(review): the paper feeds the full gated [u_t^P, c_t] to the GRU; this
        # implementation feeds only the gated c_t half (GRU input is 2 * input_size).
        self.gru = nn.GRUCell(input_size=input_size * 2, hidden_size=h_size * 2)

        self.Wp = nn.Linear(in_features=input_size * 2, out_features=h_size * 2, bias=False)
        self.Wq = nn.Linear(in_features=input_size * 2, out_features=h_size * 2, bias=False)
        self.Wv = nn.Linear(in_features=h_size * 2, out_features=h_size * 2, bias=False)
        self.Wg = nn.Linear(in_features=input_size * 4, out_features=input_size * 4, bias=False)
        # Learned attention vector (scores = V(tanh(...))); a fixed parameter, not re-sampled per call.
        self.V = nn.Linear(in_features=h_size * 2, out_features=1, bias=False)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, up, uq, c_len=None, q_mask=None):
        """
        @param up (Tensor): passage encodings, shape (batch_size, n, 2 * input_size)
        @param uq (Tensor): question encodings, shape (batch_size, m, 2 * input_size)
        @param c_len: unused; accepted for backward compatibility
        @param q_mask (Tensor, optional): shape (batch_size, m), nonzero/True for real
            question tokens; padded positions receive zero attention. If None, all
            positions are attended.
        @return vs (Tensor): question-aware passage states, shape (n, batch_size, 2 * h_size)
        """
        batch_size, n, _ = up.size()
        up_t = up.permute(1, 0, 2)  # (n, batch_size, 2 * input_size)

        # The question projection does not depend on t, so compute it once.
        wq_uq = self.Wq(uq)  # (batch_size, m, 2 * h_size)

        v = up.new_zeros(batch_size, self.hidden_size)
        outputs = []
        for t in range(n):
            u_t = up_t[t]  # (batch_size, 2 * input_size)
            x = torch.tanh(wq_uq + (self.Wp(u_t) + self.Wv(v)).unsqueeze(1))  # (batch_size, m, 2 * h_size)
            s = self.V(x).squeeze(-1)  # (batch_size, m)
            if q_mask is not None:
                s = s.masked_fill(~q_mask.bool(), torch.finfo(s.dtype).min)
            a = torch.softmax(s, dim=1)  # (batch_size, m)

            c = torch.bmm(a.unsqueeze(1), uq).squeeze(1)  # (batch_size, 2 * input_size)
            r = torch.cat([u_t, c], dim=1)  # (batch_size, 4 * input_size)
            r_gated = torch.sigmoid(self.Wg(r)) * r  # element-wise gate

            v = self.gru(r_gated[:, self.input_size * 2:], v)
            outputs.append(v)

        if outputs:
            vs = torch.stack(outputs, dim=0)
        else:
            vs = up.new_zeros(0, batch_size, self.out_size)
        return self.dropout(vs)

In [117]:
gatedAttn = GatedAttn(input_size=200, h_size=200)

In [119]:
vp = gatedAttn.forward(uc, uq, c_len)
vp.shape

torch.Size([293, 64, 400])

#### 3. Self Attention

In [35]:
class SelfAttn(nn.Module):
    """Self-matching attention: every passage position attends over the whole passage.

    The attention-pooled passage vector at each step drives a GRU cell, and the
    sequence of GRU states is returned.
    """

    def __init__(self, in_size, dropout=0):
        """
        @param in_size (int): feature size of the input states; also the GRU hidden size
        @param dropout (float): dropout probability applied to the output
        """
        super(SelfAttn, self).__init__()
        self.hidden_size = in_size
        self.in_size = in_size

        self.gru = nn.GRUCell(input_size=in_size, hidden_size=self.hidden_size)

        self.Wp = nn.Linear(self.in_size, self.hidden_size, bias=False)   # applied to the current state v_t
        self.Wp_ = nn.Linear(self.in_size, self.hidden_size, bias=False)  # applied to every state v_j
        # Learned attention vector; a fixed parameter, not re-sampled per call.
        self.V = nn.Linear(self.hidden_size, 1, bias=False)

        self.out_size = self.hidden_size
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, v, c_mask=None):
        """
        @param v (Tensor): passage states, shape (l, batch_size, in_size)
        @param c_mask (Tensor, optional): shape (batch_size, l), nonzero/True for real
            passage tokens; padded positions receive zero attention. If None, all
            positions are attended.
        @return hs (Tensor): shape (l, batch_size, in_size)
        """
        l, batch_size, _ = v.size()
        v_b = v.permute(1, 0, 2)  # (batch_size, l, in_size)

        # Projection of all passage states is loop-invariant: compute it once.
        wv_all = self.Wp_(v_b)  # (batch_size, l, hidden_size)

        h = v.new_zeros(batch_size, self.hidden_size)
        outputs = []
        for t in range(l):
            x = torch.tanh(wv_all + self.Wp(v[t]).unsqueeze(1))  # (batch_size, l, hidden_size)
            s = self.V(x).squeeze(-1)  # (batch_size, l)
            if c_mask is not None:
                s = s.masked_fill(~c_mask.bool(), torch.finfo(s.dtype).min)
            a = torch.softmax(s, dim=1)
            c = torch.bmm(a.unsqueeze(1), v_b).squeeze(1)  # (batch_size, in_size)
            h = self.gru(c, h)
            outputs.append(h)

        if outputs:
            hs = torch.stack(outputs, dim=0)
        else:
            hs = v.new_zeros(0, batch_size, self.out_size)
        return self.dropout(hs)

In [36]:
selfAttn = SelfAttn(in_size=400)

In [37]:
selfAttn.out_size

400

In [38]:
c = selfAttn.forward(vp)
c.shape

torch.Size([293, 64, 400])

#### 4. Pointer

In [39]:
class Pointer(nn.Module):
    """Pointer-network output layer: start and end distributions over the passage.

    The question is attention-pooled into an initial state r. The module then
    points once to get p1, updates r with a GRU cell fed by the p1-weighted
    passage vector, and points again to get p2.
    """

    def __init__(self, in_size1, in_size2):
        """
        @param in_size1 (int): feature size of the passage representation h
        @param in_size2 (int): feature size of the question encoding u; also the
            hidden size of the pointer GRU state
        """
        super(Pointer, self).__init__()
        self.hidden_size = in_size2
        self.in_size1 = in_size1
        self.in_size2 = in_size2
        self.gru = nn.GRUCell(input_size=in_size1, hidden_size=self.hidden_size)
        # Wu uses bias. See formula (11). Maybe Vr is just a bias.
        self.Wu = nn.Linear(self.in_size2, self.hidden_size, bias=True)
        self.Wh = nn.Linear(self.in_size1, self.hidden_size, bias=False)
        self.Wha = nn.Linear(self.in_size2, self.hidden_size, bias=False)
        # Learned attention vectors; fixed parameters, not re-sampled per call.
        self.Vq = nn.Linear(self.hidden_size, 1, bias=False)  # question pooling
        self.Vp = nn.Linear(self.hidden_size, 1, bias=False)  # passage pointing
        self.out_size = 1

    def forward(self, h, u, c_mask=None, q_mask=None):
        """
        @param h (Tensor): passage representation, shape (lp, batch_size, in_size1)
        @param u (Tensor): question encoding, shape (batch_size, lq, in_size2)
        @param c_mask (Tensor, optional): shape (batch_size, lp), True for real passage
            tokens; padded positions get zero probability
        @param q_mask (Tensor, optional): shape (batch_size, lq), True for real question tokens
        @return p1, p2 (Tensor): start / end probabilities, each of shape (batch_size, lp)
        """
        h_b = h.permute(1, 0, 2)  # (batch_size, lp, in_size1)

        # Initial pointer state: attention pooling over the question.
        s_q = self.Vq(torch.tanh(self.Wu(u))).squeeze(-1)  # (batch_size, lq)
        a_q = torch.softmax(self._mask(s_q, q_mask), dim=1)
        r = torch.bmm(a_q.unsqueeze(1), u).squeeze(1)  # (batch_size, in_size2)

        # Passage projection is shared by both pointing steps.
        wh = self.Wh(h_b)  # (batch_size, lp, hidden_size)

        p1 = self._point(wh, r, c_mask)
        c = torch.bmm(p1.unsqueeze(1), h_b).squeeze(1)  # (batch_size, in_size1)
        r = self.gru(c, r)
        p2 = self._point(wh, r, c_mask)
        return p1, p2

    def _point(self, wh, r, c_mask):
        """Softmax over passage positions given the current pointer state r."""
        s = self.Vp(torch.tanh(wh + self.Wha(r).unsqueeze(1))).squeeze(-1)  # (batch_size, lp)
        return torch.softmax(self._mask(s, c_mask), dim=1)

    @staticmethod
    def _mask(scores, mask):
        """Push scores of padded positions (mask False/0) to the dtype minimum."""
        if mask is None:
            return scores
        return scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)

In [40]:
uq.shape

torch.Size([64, 29, 400])

In [41]:
pointer = Pointer(in_size1=400, in_size2=400)

In [48]:
p1, p2 = pointer(c, uq)

In [49]:
p1.shape

torch.Size([64, 293])

In [50]:
p2.shape

torch.Size([64, 293])

#### Putting it all together

In [63]:
class RNet(nn.Module):
    """R-Net reading-comprehension model assembled from the components above.

    NOTE(review): this definition shadows the `RNet` imported from `models`.
    """

    def __init__(self, word_vectors, char_vectors, drop_prob=0):
        """
        @param word_vectors (Tensor): pretrained word embeddings, shape (vocab_size, word_dim)
        @param char_vectors (Tensor): pretrained character embeddings, shape (n_chars, char_dim)
        @param drop_prob (float): dropout probability for the encoder, gated-attention
            and self-attention outputs
        """
        super(RNet, self).__init__()
        e_word = 200  # size of the character-level word embedding

        self.word_emb = nn.Embedding.from_pretrained(word_vectors)
        self.char_emb = CharEmbedding(char_vectors, e_char=char_vectors.size(1), e_word=e_word)

        # Encoder input = word embedding concatenated with char-level embedding.
        self.encoder = Encoder(input_size=word_vectors.size(1) + e_word, h_size=200, dropout=drop_prob)
        # GatedAttn expects encodings with 2 * input_size features (the bi-GRU output).
        self.gatedAttn = GatedAttn(input_size=self.encoder.h_size, h_size=200, dropout=drop_prob)
        self.selfAttn = SelfAttn(self.gatedAttn.out_size, dropout=drop_prob)
        self.pointer = Pointer(self.selfAttn.out_size, self.encoder.out_size)

    def forward(self, Pcw_idxs, qw_idxs, cc_idxs, qc_idxs):
        """
        @param Pcw_idxs (Tensor): context word indices, shape (batch_size, c_len); 0 = padding
        @param qw_idxs (Tensor): question word indices, shape (batch_size, q_len); 0 = padding
        @param cc_idxs (Tensor): context character indices, shape (batch_size, c_len, max_word_length)
        @param qc_idxs (Tensor): question character indices, shape (batch_size, q_len, max_word_length)
        @return p1, p2: start / end outputs of the pointer layer
        """
        # Real (non-pad) token counts per example; word index 0 is padding.
        c_len = (Pcw_idxs != 0).sum(-1)
        q_len = (qw_idxs != 0).sum(-1)

        P = torch.cat((self.char_emb(cc_idxs), self.word_emb(Pcw_idxs)), dim=2)
        Q = torch.cat((self.char_emb(qc_idxs), self.word_emb(qw_idxs)), dim=2)

        Up = self.encoder(P, c_len)
        Uq = self.encoder(Q, q_len)

        v = self.gatedAttn(Up, Uq, c_len)
        h = self.selfAttn(v)
        p1, p2 = self.pointer(h, Uq)
        return p1, p2

In [64]:
rnet = RNet(word_vectors, char_vectors)

In [65]:
p1, p2 = rnet.forward(cw_idxs, qw_idxs, cc_idxs, qc_idxs)

torch.Size([64, 293, 400])
torch.Size([64, 29, 400])
torch.Size([293, 64, 400])
torch.Size([293, 64, 400])


In [66]:
p1.shape

torch.Size([64, 293])

In [67]:
p2.shape

torch.Size([64, 293])