In [1]:
import torch
import collections
import math
from torch import nn

import sys
sys.path.append('..')
from d2l_helpers import *

In [2]:
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)
output, state = encoder(X)
output.shape, state.shape

(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))

In [4]:
class Seq2SeqDecoder(Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
    
    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]
    
    def forward(self, X, state):
        X = self.embedding(X).permute(1, 0, 2)
        context = state[-1].repeat(X.shape[0], 1, 1)
        X_and_context = torch.cat((X, context), 2)
        output, state = self.rnn(X_and_context, state)
        output = self.dense(output).permute(1, 0, 2)
        return output, state

In [5]:
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
output.shape, state.shape

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

In [6]:
loss = MaskedSoftmaxCELoss()
loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long), torch.tensor([4, 2, 0]))

tensor([2.3026, 1.1513, 0.0000])

In [7]:
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 300, get_device()

train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)
encoder = Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

training on mps
epoch 10, loss 0.206, 7093.8 tokens/sec on mps
epoch 20, loss 0.147, 7087.4 tokens/sec on mps
epoch 30, loss 0.110, 7622.8 tokens/sec on mps
epoch 40, loss 0.085, 7372.0 tokens/sec on mps
epoch 50, loss 0.070, 7203.6 tokens/sec on mps
epoch 60, loss 0.057, 7606.1 tokens/sec on mps
epoch 70, loss 0.048, 7626.9 tokens/sec on mps
epoch 80, loss 0.042, 7628.9 tokens/sec on mps
epoch 90, loss 0.036, 7298.2 tokens/sec on mps
epoch 100, loss 0.034, 7334.5 tokens/sec on mps
epoch 110, loss 0.031, 7507.8 tokens/sec on mps
epoch 120, loss 0.029, 7602.2 tokens/sec on mps
epoch 130, loss 0.028, 7637.1 tokens/sec on mps
epoch 140, loss 0.025, 7448.1 tokens/sec on mps
epoch 150, loss 0.024, 7531.0 tokens/sec on mps
epoch 160, loss 0.025, 7585.1 tokens/sec on mps
epoch 170, loss 0.023, 7535.5 tokens/sec on mps
epoch 180, loss 0.022, 7629.4 tokens/sec on mps
epoch 190, loss 0.022, 7500.4 tokens/sec on mps
epoch 200, loss 0.021, 7635.0 tokens/sec on mps
epoch 210, loss 0.021, 7590.8 tok

In [8]:
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, attention_weight_seq = predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')

go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est paresseux ., bleu 0.658
i'm home . => je suis chez chez moi qui ai la paresseux ., bleu 0.537
