In [104]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import random



class EncoderParagraph(nn.Module):
    """Bidirectional single-layer LSTM encoder over paragraph token ids.

    Args:
        word_size: vocabulary size; index 0 is the padding index.
        word_dim: embedding dimension.
        hidden_size: LSTM hidden size per direction.
        pretrained_word_embeds: optional numpy array, expected shape
            (word_size, word_dim). It is only consumed by ``_init_weights``,
            which ``__init__`` does NOT call -- call it explicitly.
    """

    def __init__(self, word_size, word_dim, hidden_size, pretrained_word_embeds=None):
        super(EncoderParagraph, self).__init__()

        self.word_size = word_size
        self.word_dim = word_dim
        self.hidden_size = hidden_size
        self.pretrained_word_embeds = pretrained_word_embeds
        self.embedding = nn.Embedding(self.word_size, self.word_dim, padding_idx=0)
        self.lstm = nn.LSTM(self.word_dim, self.hidden_size, batch_first=True, bidirectional=True)

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len) (batch_first).

        Returns:
            out: (batch, seq_len, 2 * hidden_size) per-token outputs with both
                directions concatenated on the last axis.
            hidden_cell: tuple (h_n, c_n), each (2, batch, hidden_size);
                index 0 is the forward direction, index 1 the backward one.
        """
        out = self.embedding(x)
        out, hidden_cell = self.lstm(out)
        return out, hidden_cell

    def _init_weights(self, use_pretrained=None, non_trainable=None):
        """(Re)initialise the embedding table.

        Args:
            use_pretrained: copy ``pretrained_word_embeds`` into the table when
                truthy, otherwise use Xavier-uniform init. ``None`` falls back
                to the notebook-level ``PRE_TRAINED_EMBEDDING`` flag.
            non_trainable: when loading pretrained vectors, freeze them if
                truthy. ``None`` falls back to the ``NON_TRAINABLE`` flag.

        Raises:
            ValueError: pretrained init requested but no embeddings were given.
        """
        if use_pretrained is None:
            use_pretrained = PRE_TRAINED_EMBEDDING
        if use_pretrained:
            if self.pretrained_word_embeds is None:
                raise ValueError("pretrained init requested but pretrained_word_embeds is None")
            with torch.no_grad():
                self.embedding.weight.copy_(torch.from_numpy(self.pretrained_word_embeds))
            if non_trainable is None:
                non_trainable = NON_TRAINABLE
            self.embedding.weight.requires_grad = not non_trainable
        else:
            # nn.init is reached through the already-imported `nn`; a bare
            # `init` name is not imported in this notebook.
            nn.init.xavier_uniform_(self.embedding.weight)
            # Xavier fills every row; restore the zero padding vector that
            # nn.Embedding(padding_idx=0) establishes at construction.
            if self.embedding.padding_idx is not None:
                with torch.no_grad():
                    self.embedding.weight[self.embedding.padding_idx].fill_(0)
            
class EncoderSentence(nn.Module):
    """Bidirectional single-layer LSTM encoder over sentence token ids.

    Args:
        word_size: vocabulary size; index 0 is the padding index.
        word_dim: embedding dimension.
        hidden_size: LSTM hidden size per direction.
        pretrained_word_embeds: optional numpy array, expected shape
            (word_size, word_dim). It is only consumed by ``_init_weights``,
            which ``__init__`` does NOT call -- call it explicitly.
    """

    def __init__(self, word_size, word_dim, hidden_size, pretrained_word_embeds=None):
        super(EncoderSentence, self).__init__()

        self.word_size = word_size
        self.word_dim = word_dim
        self.hidden_size = hidden_size
        self.pretrained_word_embeds = pretrained_word_embeds
        self.embedding = nn.Embedding(self.word_size, self.word_dim, padding_idx=0)
        self.lstm = nn.LSTM(self.word_dim, self.hidden_size, batch_first=True, bidirectional=True)

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len) (batch_first).

        Returns:
            out: (batch, seq_len, 2 * hidden_size) per-token outputs with both
                directions concatenated on the last axis.
            hidden_cell: tuple (h_n, c_n), each (2, batch, hidden_size);
                index 0 is the forward direction, index 1 the backward one.
        """
        out = self.embedding(x)
        out, hidden_cell = self.lstm(out)
        return out, hidden_cell

    def _init_weights(self, use_pretrained=None, non_trainable=None):
        """(Re)initialise the embedding table.

        Args:
            use_pretrained: copy ``pretrained_word_embeds`` into the table when
                truthy, otherwise use Xavier-uniform init. ``None`` falls back
                to the notebook-level ``PRE_TRAINED_EMBEDDING`` flag.
            non_trainable: when loading pretrained vectors, freeze them if
                truthy. ``None`` falls back to the ``NON_TRAINABLE`` flag.

        Raises:
            ValueError: pretrained init requested but no embeddings were given.
        """
        if use_pretrained is None:
            use_pretrained = PRE_TRAINED_EMBEDDING
        if use_pretrained:
            if self.pretrained_word_embeds is None:
                raise ValueError("pretrained init requested but pretrained_word_embeds is None")
            with torch.no_grad():
                self.embedding.weight.copy_(torch.from_numpy(self.pretrained_word_embeds))
            if non_trainable is None:
                non_trainable = NON_TRAINABLE
            self.embedding.weight.requires_grad = not non_trainable
        else:
            # nn.init is reached through the already-imported `nn`; a bare
            # `init` name is not imported in this notebook.
            nn.init.xavier_uniform_(self.embedding.weight)
            # Xavier fills every row; restore the zero padding vector that
            # nn.Embedding(padding_idx=0) establishes at construction.
            if self.embedding.padding_idx is not None:
                with torch.no_grad():
                    self.embedding.weight[self.embedding.padding_idx].fill_(0)

In [105]:
# Seed every RNG this notebook uses (weight init, torch.randint inputs,
# teacher-forcing coin flips) so Restart & Run All reproduces the outputs.
torch.manual_seed(0)
random.seed(0)
par_enc = EncoderParagraph(word_size=10, word_dim=4, hidden_size=5)
sen_enc = EncoderSentence(word_size=10, word_dim=4, hidden_size=5)

In [106]:
# Random token ids over the whole 10-word vocabulary; torch.randint's upper
# bound is exclusive, so 10 (not 9) is needed to reach id 9.
test_input1 = torch.randint(0, 10, (4, 3), dtype=torch.long)  # (batch=4, para_len=3)
test_input2 = torch.randint(0, 10, (4, 6), dtype=torch.long)  # (batch=4, sent_len=6)

In [107]:
# Encode both batches: a / c hold per-token outputs, b / d hold (h_n, c_n).
(a, b) = par_enc(test_input1)
(c, d) = sen_enc(test_input2)

In [108]:
# Per-token encoder outputs: (batch, seq_len, 2 * hidden_size).
# Call-form print works on both Python 2 and Python 3.
print(a.size())
print(c.size())

torch.Size([4, 3, 10])
torch.Size([4, 6, 10])


In [109]:
# Forward-direction final hidden states h_n[0::2]: (1, batch, hidden_size).
print(b[0][0::2].size())
print(d[0][0::2].size())

torch.Size([1, 4, 5])
torch.Size([1, 4, 5])


In [110]:
# Forward-direction final cell states c_n[0::2]: (1, batch, hidden_size).
print(b[1][0::2].size())
print(d[1][0::2].size())

torch.Size([1, 4, 5])
torch.Size([1, 4, 5])


In [111]:
# Join the forward-direction final states of both encoders along the feature axis.
e = torch.cat([b[0][0::2], d[0][0::2]], dim=2)  # hidden state (h)
f = torch.cat([b[1][0::2], d[1][0::2]], dim=2)  # cell state (c)

In [112]:
# Concatenated states: (1, batch, 5 + 5).
print(e.size())
print(f.size())

torch.Size([1, 4, 10])
torch.Size([1, 4, 10])


In [113]:
# Paragraph forward h_n (first 5 columns) followed by sentence forward h_n (last 5).
e

tensor([[[ 0.1076, -0.2935,  0.0755,  0.0771,  0.2181,  0.0943,  0.0197,
          -0.2198,  0.2790, -0.1438],
         [ 0.0639, -0.1106, -0.0201,  0.1325, -0.0875, -0.0698,  0.0299,
          -0.1304,  0.5169, -0.2540],
         [ 0.1161, -0.2714,  0.0652,  0.0895,  0.2234,  0.0622, -0.0138,
          -0.1512,  0.3889, -0.2338],
         [ 0.0059, -0.1488,  0.0567,  0.1260,  0.0924,  0.0996, -0.0005,
          -0.2134,  0.3472, -0.1267]]], grad_fn=<CatBackward>)

In [114]:
# Paragraph encoder's forward-direction h_n; matches the first 5 columns of `e`.
b[0][0::2]

tensor([[[ 0.1076, -0.2935,  0.0755,  0.0771,  0.2181],
         [ 0.0639, -0.1106, -0.0201,  0.1325, -0.0875],
         [ 0.1161, -0.2714,  0.0652,  0.0895,  0.2234],
         [ 0.0059, -0.1488,  0.0567,  0.1260,  0.0924]]],
       grad_fn=<SliceBackward>)

In [115]:
class AttnDecoderLSTM(nn.Module):
    """One-step LSTM decoder with attention over encoder outputs.

    Attention scores come from a linear layer over the current token embedding
    and the decoder hidden state (one score per source position, up to
    ``max_length``); they are not computed from the encoder outputs themselves.

    Args:
        word_size: output vocabulary size; index 0 is the padding index.
        word_dim: embedding / LSTM input dimension.
        hidden_size: decoder LSTM hidden size.
        max_length: maximum number of source positions that can be attended.
        pretrained_word_embeds: stored for parity with the encoders; this
            class does not load it.
        encoder_output_dim: feature width of ``encoder_output1`` passed to
            ``forward``. Defaults to ``hidden_size`` (previous behaviour).
    """

    def __init__(self, word_size, word_dim, hidden_size, max_length,
                 pretrained_word_embeds=None, encoder_output_dim=None):
        super(AttnDecoderLSTM, self).__init__()
        if encoder_output_dim is None:
            encoder_output_dim = hidden_size
        self.hidden_size = hidden_size
        self.word_size = word_size
        self.word_dim = word_dim
        # Width of the attention context (encoder outputs); name kept for compatibility.
        self.encoder_hidden_dim = encoder_output_dim
        self.max_length = max_length
        self.pretrained_word_embeds = pretrained_word_embeds

        self.embedding = nn.Embedding(self.word_size, self.word_dim, padding_idx=0)
        # Scores source positions from [embedded token ; decoder hidden state].
        self.attn = nn.Linear(self.word_dim + self.hidden_size, self.max_length)
        # Projects [embedded token ; attention context] down to the LSTM input width.
        self.attn_combine = nn.Linear(self.word_dim + self.encoder_hidden_dim, self.word_dim)
        self.lstm = nn.LSTM(self.word_dim, self.hidden_size, batch_first=True)
        self.out = nn.Linear(self.hidden_size, self.word_size)

    def forward(self, input, hidden, encoder_output1):
        """Run a single decoding step.

        Args:
            input: LongTensor of current token ids, shape (batch, 1).
            hidden: (h, c), each (1, batch, hidden_size).
            encoder_output1: (batch, src_len, encoder_output_dim), with
                src_len <= max_length.

        Returns:
            output: (batch, word_size) log-probabilities for the next token.
            hidden: updated (h, c) from the LSTM.
            attn_weights: (batch, 1, src_len) attention distribution.
        """
        embedded = self.embedding(input).squeeze(1)  # (batch, word_dim)

        attn_logits = self.attn(torch.cat((embedded, hidden[0].squeeze(0)), 1))
        # Only the first src_len scores correspond to real encoder positions;
        # normalising over just those lets shorter sources work with bmm.
        src_len = encoder_output1.size(1)
        attn_weights = F.softmax(attn_logits[:, :src_len], dim=1).unsqueeze(1)

        attn_applied = torch.bmm(attn_weights, encoder_output1).squeeze(1)

        attn_combined = torch.cat((embedded, attn_applied), 1)
        lstm_input = F.relu(self.attn_combine(attn_combined)).unsqueeze(1)

        output, hidden = self.lstm(lstm_input, hidden)
        output = F.log_softmax(self.out(output[:, 0, :]), dim=1)

        return output, hidden, attn_weights

In [116]:
# hidden_size 10 = 5 (paragraph forward state) + 5 (sentence forward state), matching e/f;
# max_length 3 equals the length of the encoder output `a` attended over below.
decoder = AttnDecoderLSTM(word_size=10, word_dim=4, hidden_size=10, max_length=3)

In [117]:
# One token per batch element; randint's upper bound is exclusive, so 10 covers ids 0..9.
test_input3 = torch.randint(0, 10, (4, 1), dtype=torch.long)

In [118]:
# One decoding step from the concatenated encoder states, attending over `a`.
p, q, r = decoder(test_input3, hidden=(e, f), encoder_output1=a)

In [119]:
# Log-probs (batch, word_size), new hidden (1, batch, hidden_size), attention (batch, 1, src_len).
print(p.size())
print(q[0].size())
print(r.size())

torch.Size([4, 10])
torch.Size([1, 4, 10])
torch.Size([4, 1, 3])


In [120]:
class QuestionGeneration(nn.Module):
    """Paragraph + sentence encoders feeding an attention LSTM decoder.

    The decoder's initial (h, c) is the forward-direction final state of each
    encoder concatenated on the feature axis, so the decoder's hidden size must
    equal the sum of the two encoders' per-direction hidden sizes. The decoder
    attends over the sentence encoder's outputs only; the paragraph encoder's
    per-token outputs are currently unused.
    """

    def __init__(self, para_encoder, sent_encoder, decoder):
        super(QuestionGeneration, self).__init__()

        self.encoder1 = para_encoder
        self.encoder2 = sent_encoder
        self.decoder = decoder

    def forward(self, para_src, sent_src, trg, teacher_forcing_ratio=0.5):
        """Decode a full target sequence.

        Args:
            para_src: (batch, para_len) paragraph token ids.
            sent_src: (batch, sent_len) sentence token ids.
            trg: (batch, trg_len) target ids; trg[:, 0] is fed as the first
                decoder input (presumably the <sos> token).
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token instead of the model's argmax prediction.

        Returns:
            (batch, trg_len, decoder.word_size) decoder outputs. Position 0 is
            left as zeros because no prediction is made for the first token.
        """
        batch_size = trg.shape[0]
        max_len = trg.shape[1]
        trg_vocab_size = self.decoder.word_size

        # Allocate on trg's device so the buffer matches decoder outputs on GPU.
        outputs = torch.zeros(batch_size, max_len, trg_vocab_size, device=trg.device)

        # [0::2] selects the forward-direction final states of each encoder.
        out1, hidden_cell1 = self.encoder1(para_src)
        out2, hidden_cell2 = self.encoder2(sent_src)
        hidden = torch.cat((hidden_cell1[0][0::2], hidden_cell2[0][0::2]), dim=2)
        cell = torch.cat((hidden_cell1[1][0::2], hidden_cell2[1][0::2]), dim=2)

        input = trg[:, 0]

        for t in range(1, max_len):
            # The decoder returns (output, (h, c), attn_weights); the attention
            # weights are not needed here.
            output, (hidden, cell), _ = self.decoder(input.unsqueeze(1), (hidden, cell), out2)
            outputs[:, t, :] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.max(1)[1]
            input = trg[:, t] if teacher_force else top1

        return outputs