In [5]:
%run effective_approaches.ipynb

In [36]:
class Decoder_Coverage(nn.Module):
    """Attention LSTM decoder that also emits a fertility score per source position.

    At every step the previous word is embedded and fed through the LSTM; the
    top-layer hidden state queries the attention module over the encoder
    outputs, and the context vector concatenated with that hidden state is
    projected onto the target vocabulary.

    Independently of the decoding loop, a linear head maps each encoder output
    to one scalar (the "fertility"), so a caller can compare it with the
    attention mass each source position accumulated over the decoding steps.

    Args:
        num_embeddings: Size of the target vocabulary (embedding rows and
            number of output logits).
        embedding_dim: Width of the target word embeddings.
        hidden_dim: LSTM hidden size; the encoder outputs given to ``forward``
            must have this width too, since the fertility head is
            ``nn.Linear(hidden_dim, 1)``.
        num_layers: Number of stacked LSTM layers.
        padding_dim: Index of the padding token (handed to ``nn.Embedding`` as
            its ``padding_idx``).
        start_dim: Index of the start token fed to the decoder at step 0.
        Attention: Attention class, instantiated as ``Attention(hidden_dim)``.
            Calling the instance with ``(source_h, query)`` must return a
            ``(context, weights)`` pair.
    """

    def __init__(self, num_embeddings, embedding_dim, hidden_dim, num_layers, 
                 padding_dim=0, start_dim=1, Attention=Attention_General):
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter
        # initialisation (and state_dict layout) stays the same.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_dim)
        self.decoder = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        # Input is [context ; top-layer hidden state], hence hidden_dim * 2.
        self.word_predictor = nn.Linear(hidden_dim * 2, num_embeddings)
        self.start_dim = start_dim

        self.attention = Attention(hidden_dim)
        self.fertility = nn.Linear(hidden_dim, 1)

    def forward(self, output, h, c, input=None, max_sen_len=20, source_mask=None):
        """Decode a batch of sentences, greedily or with teacher forcing.

        Args:
            output: Encoder outputs, a 3-D tensor
                ``(num_sentences, num_words, hidden_dim)``.
            h: Initial LSTM hidden state (as accepted by ``nn.LSTM``); its last
                layer ``h[-1]`` is used as the attention query.
            c: Initial LSTM cell state.
            input: Optional target word indices ``(num_sentences, target_len)``.
                When given, decoding runs for ``target_len`` steps and column
                ``i`` is fed as the decoder input of step ``i + 1`` (step 0
                always receives the start token). When ``None``, the arg-max of
                each step's prediction is fed back instead.
            max_sen_len: Number of decoding steps when ``input`` is ``None``.
            source_mask: Accepted for interface compatibility; currently unused.

        Returns:
            A pair ``(log_probs, (attention_weights, source_h_importance))``:
            ``log_probs`` has shape ``(num_sentences, steps, num_embeddings)``;
            ``attention_weights`` holds the per-step attention weights stacked
            along a new last dimension; ``source_h_importance`` is the
            fertility score, shape ``(num_sentences, num_words)``.
        """
        source_h = output
        # Unpacking three dims also rejects inputs that are not 3-D.
        num_sentences, _num_words, _hidden_dim = source_h.shape

        # Build the start tokens on the encoder outputs' device instead of
        # hard-coding CUDA, so the module also runs on CPU / any GPU index.
        words_selected = torch.full((num_sentences, 1), self.start_dim,
                                    dtype=torch.long, device=source_h.device)

        # nn.Linear acts on the last dimension, so no flatten/unflatten is needed.
        source_h_importance = self.fertility(source_h).squeeze(-1)

        pred = []
        attention_weights = []

        if input is not None:
            max_sen_len = input.shape[1]
            teacher_words = input.t()  # (target_len, num_sentences)

        for i in range(max_sen_len):
            embeddings = self.embedding(words_selected)
            _, (h, c) = self.decoder(embeddings, (h, c))

            # Attend once per step with the top-layer hidden state as query.
            query = h[-1]
            ct, _attn_weights = self.attention(source_h, query)
            attention_weights.append(_attn_weights)

            pred_t = self.word_predictor(torch.cat((ct, query), dim=1))
            pred.append(pred_t)

            if input is None:
                # Greedy decoding: feed back the most likely word.
                words_selected = pred_t.argmax(dim=1, keepdim=True)
            else:
                # Teacher forcing: feed the ground-truth word of this step.
                words_selected = teacher_words[i].unsqueeze(1)

        pred = torch.stack(pred, dim=1)
        attention_weights = torch.stack(attention_weights, dim=-1)

        return torch.log_softmax(pred, dim=-1), (attention_weights, source_h_importance)

In [37]:
class Seq2Seq_Coverage(nn.Module):
    """Encoder-decoder model whose decoder also reports attention weights and fertility.

    Wires an ``Encoder`` to a ``Decoder_Coverage`` sharing the same embedding
    width, hidden size and layer count.

    Args:
        source_vocab_len: Source vocabulary size, passed to the encoder.
        target_vocab_len: Target vocabulary size, passed to the decoder.
        embedding_dim: Embedding width used by both encoder and decoder.
        hidden_dim: Hidden size used by both encoder and decoder.
        num_layers: Number of recurrent layers in encoder and decoder.
        padding_dim_source: Padding index handed to the encoder.
        start_dim_target: Start-token index handed to the decoder.
        padding_dim_target: Padding index handed to the decoder.
        Attention: Attention class forwarded to the decoder.
    """

    def __init__(self, source_vocab_len, target_vocab_len, embedding_dim, hidden_dim, num_layers=2,
                       padding_dim_source=0, start_dim_target=1, padding_dim_target=0, Attention=Attention_General):
        super().__init__()
        self.encoder = Encoder(source_vocab_len, embedding_dim, hidden_dim, num_layers, padding_dim_source)
        self.decoder = Decoder_Coverage(target_vocab_len, embedding_dim, hidden_dim, num_layers, 
                                        padding_dim_target, start_dim_target, Attention=Attention)

    def forward(self, source_input, target_input=None, source_mask=None):
        """Encode ``source_input`` and decode it.

        Args:
            source_input: Source batch given to the encoder.
            target_input: Optional target word indices; when provided the
                decoder runs with teacher forcing, otherwise it feeds back its
                own predictions.
            source_mask: Optional mask over the source, forwarded to the
                decoder. ``Decoder_Coverage.forward`` accepts this argument but
                does not use it yet, so it has no effect on the result today.

        Returns:
            Whatever the decoder returns: ``(log_probs, (attention_weights,
            fertility))``.
        """
        output, h, c = self.encoder(source_input)
        # Pass the mask through rather than silently dropping it, so it takes
        # effect as soon as the decoder honours it.
        return self.decoder(output, h, c, target_input, source_mask=source_mask)

```
s = Seq2Seq_Coverage(10, 12, 9, 8, Attention=Attention_Concat)
s.cuda()
input = torch.cuda.LongTensor(np.random.randint(0, 10, (3, 6)))
target = torch.cuda.LongTensor(np.random.randint(0, 12, (3, 10)))

# without teacher forcing
pred, (attn_weights, fertility) = s(input)

# with teacher forcing
pred, (attn_weights, fertility) = s(input, target)
```

In [39]:
# Smoke test for Seq2Seq_Coverage on random token ids (needs a CUDA device).
# Positional sizes: source vocab 10, target vocab 12, embedding dim 9, hidden dim 8.
s = Seq2Seq_Coverage(10, 12, 9, 8, Attention=Attention_Concat).cuda()

# Batch of 3 source sentences (6 tokens each) and 3 target sentences (10 tokens each).
input = torch.cuda.LongTensor(np.random.randint(low=0, high=10, size=(3, 6)))
target = torch.cuda.LongTensor(np.random.randint(low=0, high=12, size=(3, 10)))

# Free-running pass: no target is supplied, so the decoder feeds back its own predictions.
pred, (attn_weights, fertility) = s(source_input=input)

# Teacher-forced pass: overwrites the results above; these are the values later cells inspect.
pred, (attn_weights, fertility) = s(source_input=input, target_input=target)

In [40]:
# Raw output of the decoder's fertility head: one score per source token
# (3 sentences x 6 tokens for the batch built above).
fertility

tensor([[-0.3678, -0.3860, -0.3827, -0.3883, -0.3942, -0.3919],
        [-0.3572, -0.3705, -0.3810, -0.3873, -0.3939, -0.3961],
        [-0.3397, -0.3767, -0.3804, -0.3896, -0.4011, -0.3916]],
       device='cuda:0', grad_fn=<AsStridedBackward>)

In [41]:
# Mean-squared-error criterion, used to compare accumulated attention with fertility.
l = nn.MSELoss()

In [48]:
# Coverage-style penalty: MSE between the attention each source token received,
# summed over decoding steps (last axis of attn_weights), and its fertility score.
l(torch.sum(attn_weights, dim=-1), fertility)

tensor(4.1976, device='cuda:0', grad_fn=<MeanBackward1>)

In [47]:
# Total attention per source token, summed over the decoding steps
# (the steps are stacked along the last axis of attn_weights).
torch.sum(attn_weights, dim=-1)

tensor([[1.6686, 1.6659, 1.6656, 1.6594, 1.6682, 1.6723],
        [1.6411, 1.6560, 1.6658, 1.6754, 1.6878, 1.6739],
        [1.6503, 1.6634, 1.6704, 1.6681, 1.6725, 1.6753]], device='cuda:0',
       grad_fn=<SumBackward2>)