In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append('..')
from d2l_helpers import *

In [2]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """GRU decoder that recomputes an additive-attention context at every step.

    At each decoding step the top-layer hidden state of the GRU is used as the
    attention query over the encoder outputs; the resulting context vector is
    concatenated with the embedded input token and fed to the GRU.

    Args:
        vocab_size: Size of the target vocabulary (embedding rows and width of
            the output layer).
        embed_size: Dimension of the token embeddings.
        num_hiddens: Hidden size of the GRU; also passed as the query, key and
            hidden size of the attention module.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout rate handed to both the attention module and the GRU.
        **kwargs: Forwarded unchanged to ``AttentionDecoder.__init__``.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Every GRU input is [context ; embedded token], hence the summed width.
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        # Defined up front so `attention_weights` is readable before the first
        # forward pass instead of raising AttributeError.
        self._attention_weights = []

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the initial decoder state from the encoder's return value.

        Args:
            enc_outputs: ``(outputs, hidden_state)`` pair returned by the encoder.
            enc_valid_lens: Valid lengths forwarded to the attention module
                (may be ``None``).
            *args: Ignored.

        Returns:
            Tuple ``(outputs, hidden_state, enc_valid_lens)`` where ``outputs``
            has its first two axes swapped (time-major -> batch-major, assuming
            the encoder emits time-major outputs as ``nn.GRU`` does by default).
        """
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        """Decode every time step of ``X`` with attention over the encoder outputs.

        Args:
            X: Integer token ids, indexed as (batch, num_steps).
            state: ``(enc_outputs, hidden_state, enc_valid_lens)`` as produced by
                ``init_state`` or by a previous call to ``forward``.

        Returns:
            ``(outputs, new_state)``: ``outputs`` holds the vocabulary logits with
            batch as the first axis and time as the second; ``new_state`` is the
            list ``[enc_outputs, hidden_state, enc_valid_lens]`` carrying the
            updated GRU hidden state.
        """
        enc_outputs, hidden_state, enc_valid_lens = state
        # Embed, then go time-major so the loop yields one step per iteration.
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # Query = hidden state of the last (top) GRU layer, with a length-1
            # "number of queries" axis inserted at position 1.
            query = hidden_state[-1].unsqueeze(1)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            # Join context and current embedding along the feature axis.
            x = torch.cat((context, x.unsqueeze(1)), dim=-1)
            # nn.GRU (batch_first=False) wants the step axis first.
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # Stack the per-step outputs along time and project to vocabulary logits.
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        """Per-step attention weights recorded by the most recent ``forward`` call.

        NOTE(review): assumes ``AttentionDecoder`` declares ``attention_weights``
        as an overridable property (or not at all) -- confirm in ``d2l_helpers``.
        """
        return self._attention_weights

In [3]:
# Shape smoke test on a tiny model: token ids indexed as (batch=4, steps=7),
# all zeros. Encoder and decoder are built with the same toy dimensions.
toy_dims = dict(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)

encoder = Seq2SeqEncoder(**toy_dims)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(**toy_dims)
decoder.eval()

X = torch.zeros(4, 7, dtype=torch.long)
enc_outputs = encoder(X)
state = decoder.init_state(enc_outputs, None)  # None -> no valid-length info
output, state = decoder(X, state)

# Displayed as the cell result: logits shape, number of state entries,
# encoder-output shape, number of hidden-state layers, per-layer hidden shape.
(output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape)

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

In [4]:
# --- Model hyperparameters ---
embed_size = 32
num_hiddens = 32
num_layers = 2
dropout = 0.1

# --- Data / batching ---
batch_size = 64
num_steps = 10

# --- Optimisation ---
lr = 0.005
num_epochs = 250
device = get_device()

# Data iterator plus source/target vocabularies (helpers from d2l_helpers).
train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)

# Encoder over the source vocabulary, attention decoder over the target one.
encoder = Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)

train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

training on mps
epoch 10, loss 0.202, 4059.3 tokens/sec on mps
epoch 20, loss 0.138, 4094.0 tokens/sec on mps
epoch 30, loss 0.100, 4130.6 tokens/sec on mps
epoch 40, loss 0.076, 4152.3 tokens/sec on mps
epoch 50, loss 0.059, 4132.7 tokens/sec on mps
epoch 60, loss 0.049, 4101.9 tokens/sec on mps
epoch 70, loss 0.041, 4133.2 tokens/sec on mps
epoch 80, loss 0.036, 4145.1 tokens/sec on mps
epoch 90, loss 0.032, 4054.5 tokens/sec on mps
epoch 100, loss 0.028, 4080.7 tokens/sec on mps
epoch 110, loss 0.027, 4035.5 tokens/sec on mps
epoch 120, loss 0.025, 3477.9 tokens/sec on mps
epoch 130, loss 0.024, 3940.2 tokens/sec on mps
epoch 140, loss 0.023, 4052.5 tokens/sec on mps
epoch 150, loss 0.023, 4075.1 tokens/sec on mps
epoch 160, loss 0.022, 4103.2 tokens/sec on mps
epoch 170, loss 0.021, 4112.0 tokens/sec on mps
epoch 180, loss 0.021, 3502.8 tokens/sec on mps
epoch 190, loss 0.021, 4044.8 tokens/sec on mps
epoch 200, loss 0.021, 3677.5 tokens/sec on mps
epoch 210, loss 0.019, 4086.8 tok

In [5]:
# English prompts and their reference French translations, paired by position.
engs = ['go .', 'i lost .', "he's calm .", "i'm home ."]
fras = ['va !', "j'ai perdu .", 'il est calme .', 'je suis chez moi .']

# Translate each prompt with the trained net and score it against its
# reference using BLEU with k=2.
for eng, fra in zip(engs, fras):
    translation, attention_weight_seq = predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device)
    score = bleu(translation, fra, k=2)
    print(f'{eng} => {translation}, bleu {score:.3f}')

go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est riche ., bleu 0.658
i'm home . => je suis chez moi ., bleu 1.000
