In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class WordEmbedding(nn.Module):
    '''
    Lookup table mapping word ids to dense word vectors.

    In : (N, sentence_len)
    Out: (N, sentence_len, embd_size)

    Reads args.vocab_size_w (number of rows), args.w_embd_size (vector width)
    and args.pre_embd_w. When args.pre_embd_w is not None it replaces the
    randomly initialised table; it is trainable only if is_train_embd is True
    (frozen by default).
    '''
    def __init__(self, args, is_train_embd=False):
        super().__init__()
        lookup = nn.Embedding(args.vocab_size_w, args.w_embd_size)
        pretrained = args.pre_embd_w
        if pretrained is not None:
            # presumably a (vocab_size_w, w_embd_size) float tensor -- not validated here
            lookup.weight = nn.Parameter(pretrained, requires_grad=is_train_embd)
        self.embedding = lookup

    def forward(self, x):
        '''Return the embedding vector for every id in x.'''
        vectors = self.embedding(x)
        return vectors
    


In [None]:
class BiDAF_PrNet(nn.Module):
    '''
    BiDAF-style reading-comprehension model with a pointer-network output layer.

    Pipeline: word embedding -> bidirectional GRU contextual encoder ->
    bi-attention (context-to-query and query-to-context) -> 2-layer
    bidirectional GRU modeling layer -> pointer decoder that emits
    `answer_len` (= 2) log-distributions over the T context positions.

    `args` attributes read here:
        word_embd_size : int  -- per-direction hidden size d of the contextual GRU
        pointer        : bool -- build the pointer-network output layer
    plus whatever WordEmbedding reads (vocab_size_w, w_embd_size, pre_embd_w).

    NOTE(review): the contextual GRU expects inputs of width `word_embd_size`
    while WordEmbedding produces `w_embd_size`; these two args must be equal.
    '''
    def __init__(self, args):
        super(BiDAF_PrNet, self).__init__()
        self.word_embd_size = args.word_embd_size
        # d: per-direction hidden size. `self.input` is kept for backward
        # compatibility; `self.input_size` is the name the pointer layers use.
        self.input = self.word_embd_size
        self.input_size = self.input
        self.usePointer = args.pointer
        if self.usePointer:
            self.answer_len = 2  # one pointer for the start, one for the end
            self.hidden_size = 4 * self.input_size
            self.pointer_weight_size = 2 * self.input_size
            # Must equal the modeling layer's output width (2d).
            self.pointer_embd_size = 2 * self.input_size

        self.word_embd_layer = WordEmbedding(args)
        # Single-layer GRU: nn.GRU only applies dropout between stacked layers,
        # so a dropout argument here would have no effect (and only warns).
        self.context_layer = nn.GRU(self.input, self.input, bidirectional=True, batch_first=True)
        # Similarity over [h; u; h*u] -> 3 * 2d = 6d features.
        self.W = nn.Linear(6 * self.input, 1, bias=False)
        self.modeling_layer = nn.GRU(8 * self.input, self.input, num_layers=2, bidirectional=True, dropout=0.2, batch_first=True)
        if self.usePointer:
            self.encode_layer = nn.GRU(self.pointer_embd_size, self.hidden_size, batch_first=True)
            self.decode_layer = nn.GRUCell(self.pointer_embd_size, self.hidden_size)
            self.W1 = nn.Linear(self.hidden_size, self.pointer_weight_size, bias=False)
            self.W2 = nn.Linear(self.hidden_size, self.pointer_weight_size, bias=False)
            self.vt = nn.Linear(self.pointer_weight_size, 1, bias=False)

    def to_variable(self, x):
        '''
        Move `x` to the GPU when one is available and wrap it in a Variable.

        Kept for backward compatibility only: `Variable` is a no-op wrapper in
        modern PyTorch, and forward() allocates on the input's device instead.
        '''
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def build_word_embd(self, x_context, x_word):
        '''
        Embed word ids and run them through the bidirectional contextual GRU.

        x_context : unused; kept for interface compatibility (presumably an
                    auxiliary e.g. character-level input -- TODO confirm).
        x_word    : (N, seq_len) word-id tensor.
        Returns   : (N, seq_len, 2d) contextual embeddings.
        '''
        word_embd = self.word_embd_layer(x_word)
        # The GRU consumes the embedded vectors, not the raw ids.
        output, _hidden = self.context_layer(word_embd)
        return output

    def forward(self, x_word, x_context, x_query, x_query_context):
        '''
        x_word          : (N, T) context word ids
        x_context       : context-side auxiliary input (ignored by build_word_embd)
        x_query         : (N, L) query word ids
        x_query_context : query-side auxiliary input (ignored by build_word_embd)

        Returns (N, answer_len, T) log-probabilities over context positions,
        one row per pointer step.

        Raises NotImplementedError if the model was built with args.pointer
        false: no non-pointer output layer is implemented yet.
        '''
        if not self.usePointer:
            raise NotImplementedError('BiDAF_PrNet only implements the pointer-network output layer (args.pointer=True)')

        batch_size = x_word.size(0)
        T = x_word.size(1)
        L = x_query.size(1)
        # Word ids go into the `x_word` slot of build_word_embd.
        embd_context = self.build_word_embd(x_context, x_word)        # (N, T, 2d)
        embd_query = self.build_word_embd(x_query_context, x_query)   # (N, L, 2d)

        # Attention layer: similarity score for every (context t, query j) pair.
        shape = (batch_size, T, L, 2 * self.input)
        embd_context_expand = embd_context.unsqueeze(2).expand(shape)   # (N, T, L, 2d)
        embd_query_expand = embd_query.unsqueeze(1).expand(shape)       # (N, T, L, 2d)
        pair_features = torch.cat((embd_context_expand, embd_query_expand, embd_context_expand * embd_query_expand), 3)
        similarity = self.W(pair_features).view(batch_size, T, L)
        c2q = torch.bmm(F.softmax(similarity, dim=-1), embd_query)       # (N, T, 2d)
        q2c_weights = F.softmax(torch.max(similarity, 2)[0], dim=-1).unsqueeze(1)   # (N, 1, T)
        q2c = torch.bmm(q2c_weights, embd_context).repeat(1, T, 1)       # (N, T, 2d)

        H = torch.cat((embd_context, c2q, embd_context * c2q, embd_context * q2c), 2)   # (N, T, 8d)
        M, _hidden = self.modeling_layer(H)   # (N, T, 2d)

        # Output layer: pointer network over the T context positions.
        encoder_states, _hidden = self.encode_layer(M)   # (N, T, hidden_size)
        encoder_states = encoder_states.transpose(1, 0)  # (T, N, hidden_size)
        # Allocate on M's device/dtype rather than guessing via cuda.is_available().
        decoder_input = M.new_zeros(batch_size, self.pointer_embd_size)
        # Start the decoder from the encoder's final state so its hidden state
        # depends on the input (a zero start with a zero input would not).
        hidden = encoder_states[-1].contiguous()          # (N, hidden_size)
        # NOTE(review): decoder_input is never updated with the previously
        # pointed-to element -- confirm this is intended.
        blend1 = self.W1(encoder_states)                  # loop-invariant, (T, N, w)
        probs = []
        for _ in range(self.answer_len):
            hidden = self.decode_layer(decoder_input, hidden)
            blend2 = self.W2(hidden)                      # (N, w), broadcasts over T
            blend_sum = torch.tanh(blend1 + blend2)
            # squeeze(-1) only: a bare squeeze() would also drop N when N == 1.
            out = self.vt(blend_sum).squeeze(-1)          # (T, N)
            out = F.log_softmax(out.t().contiguous(), dim=-1)   # (N, T)
            probs.append(out)
        return torch.stack(probs, dim=1)
            
        # 6. Output Layer
        #G_M = torch.cat((G, M), 2) # (N, T, 10d)
        #p1 = F.softmax(self.p1_layer(G_M).squeeze(), dim=-1) # (N, T)
        #self.p1_layer(M)
        #M2, _ = self.p2_lstm_layer(M) # (N, T, 2d)
        
        #Output = torch.cat((M, M2),2)
        #G_M2 = torch.cat((G, M2), 2) # (N, T, 10d)
        #p2 = F.softmax(self.p2_layer(G_M2).squeeze(), dim=-1) # (N, T)
        #print(M)
        #input = self.emb(M) # (bs, L, embd_size)
            
        # Encoding
        
        