In [1]:
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
class AttentionDecoder(d2l.Decoder):  #@save
    """The base attention-based decoder interface."""
    def __init__(self):
        super().__init__()

    @property
    def attention_weights(self):
        raise NotImplementedError

In [35]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
        super().__init__()
        self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.LazyLinear(vocab_size)
        self.apply(d2l.init_seq2seq)

    def init_state(self, enc_outputs, enc_valid_lens):
        outputs, hidden_state = enc_outputs
        # print(outputs)
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        enc_outputs, hidden_state, enc_valid_lens = state
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            
            print(context.shape)
            print(context)
            print("----------------------------------------------------------------")
            print(torch.unsqueeze(x, dim=1).shape)
            print(torch.unsqueeze(x, dim=1))

            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=1)

            out, hidden_state = self.rnn(x.permute(1, 0 ,2))
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

In [36]:
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 7
encoder = d2l.Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
decoder = Seq2SeqAttentionDecoder(vocab_size, embed_size, num_hiddens,
                                  num_layers)
X = torch.zeros((batch_size, num_steps), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
d2l.check_shape(output, (batch_size, num_steps, vocab_size))
d2l.check_shape(state[0], (batch_size, num_steps, num_hiddens))
d2l.check_shape(state[1][0], (batch_size, num_hiddens))

torch.Size([4, 1, 16])
tensor([[[ 0.0963, -0.2351,  0.1959, -0.1839, -0.1110, -0.1178, -0.3405,
           0.2456,  0.3421, -0.3238,  0.2807, -0.2502,  0.4350, -0.4681,
          -0.1753, -0.3188]],

        [[ 0.0963, -0.2351,  0.1959, -0.1839, -0.1110, -0.1178, -0.3405,
           0.2456,  0.3421, -0.3238,  0.2807, -0.2502,  0.4350, -0.4681,
          -0.1753, -0.3188]],

        [[ 0.0963, -0.2351,  0.1959, -0.1839, -0.1110, -0.1178, -0.3405,
           0.2456,  0.3421, -0.3238,  0.2807, -0.2502,  0.4350, -0.4681,
          -0.1753, -0.3188]],

        [[ 0.0963, -0.2351,  0.1959, -0.1839, -0.1110, -0.1178, -0.3405,
           0.2456,  0.3421, -0.3238,  0.2807, -0.2502,  0.4350, -0.4681,
          -0.1753, -0.3188]]], grad_fn=<BmmBackward0>)
----------------------------------------------------------------
torch.Size([4, 1, 8])
tensor([[[-0.4258, -1.1284, -0.9827,  1.3844, -1.3728,  1.7386, -0.5952,
           0.7795]],

        [[-0.4258, -1.1284, -0.9827,  1.3844, -1.3728,  1.7386,

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 8 for tensor number 1 in the list.