In [None]:
!pip install d2l

In [None]:
import collections
import math
import torch 
import torch.nn as nn
from d2l import torch as d2l

# Encoder

In [None]:
class Seq2SeqEncoder(nn.Module):
  """GRU-based encoder for sequence-to-sequence learning.

  Embeds integer token ids and runs them through a (possibly multi-layer)
  GRU, returning both the per-step outputs and the final hidden state.

  Args:
    vocab_size: number of entries in the source vocabulary.
    embed_size: dimensionality of each token embedding.
    num_hiddens: hidden-state size of every GRU layer.
    num_layers: number of stacked GRU layers.
    dropout: dropout probability passed to nn.GRU (applied between layers).
  """
  def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
    super().__init__(**kwargs)
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

  def forward(self, X, *args):
    """Encode a batch of token-id sequences.

    Args:
      X: token ids of shape (batch_size, num_steps).
      *args: ignored; accepted so callers can pass extra inputs (e.g. valid lengths).

    Returns:
      (output, state): output is (num_steps, batch_size, num_hiddens),
      state is (num_layers, batch_size, num_hiddens).
    """
    # nn.GRU is time-major by default, so bring num_steps to the front:
    # (batch_size, num_steps, embed_size) -> (num_steps, batch_size, embed_size)
    time_major = self.embedding(X).permute(1, 0, 2)
    # No initial state is given, so the GRU starts from zeros.
    return self.rnn(time_major)

In [None]:
# Smoke-test the encoder on a batch of 4 all-zero sequences, each 7 tokens long.
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()  # inference mode
batch_size, num_steps = 4, 7
X = torch.zeros((batch_size, num_steps), dtype=torch.long)
output, state = encoder(X)
print(output.shape)  # expected: (num_steps, batch_size, num_hiddens)
print(state.shape)   # expected: (num_layers, batch_size, num_hiddens)

torch.Size([7, 4, 16])
torch.Size([2, 4, 16])


# Decoder

In [None]:
class Seq2SeqDecoder(nn.Module):
  """GRU-based decoder for sequence-to-sequence learning.

  At every time step the embedded target token is concatenated with a
  context vector -- the top layer of the incoming hidden state -- and fed
  to the GRU; a linear layer then maps GRU outputs to vocabulary logits.

  NOTE: the positional order here is (vocab_size, num_hiddens, embed_size),
  unlike Seq2SeqEncoder's (vocab_size, embed_size, num_hiddens). Pass the
  sizes by keyword to avoid silently swapping them.

  NOTE(review): the context is taken from whatever `state` is passed in.
  When the state comes from init_state it is the encoder's final state, but
  if forward is called repeatedly with its own returned state (step-by-step
  decoding) the context becomes the decoder's last top-layer state -- confirm
  that is intended.

  Args:
    vocab_size: number of entries in the target vocabulary.
    num_hiddens: hidden-state size of every GRU layer; must match the encoder
      state handed over by init_state.
    embed_size: dimensionality of each token embedding.
    num_layers: number of stacked GRU layers.
    dropout: dropout probability passed to nn.GRU (applied between layers).
  """
  def __init__(self, vocab_size, num_hiddens, embed_size, num_layers, dropout=0, **kwargs):
    super().__init__(**kwargs)
    self.embedding = nn.Embedding(vocab_size, embed_size)
    # GRU input is [token embedding ; context], hence embed_size + num_hiddens.
    self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
    self.dense = nn.Linear(num_hiddens, vocab_size)

  def init_state(self, enc_outputs, *args):
    """Use the second element of the encoder's (output, state) pair as the initial state."""
    return enc_outputs[1]

  def forward(self, X, state):
    """Decode a batch of target-side token ids.

    Args:
      X: token ids of shape (batch_size, num_steps).
      state: hidden state of shape (num_layers, batch_size, num_hiddens).

    Returns:
      (logits, state): logits of shape (batch_size, num_steps, vocab_size)
      and the updated state of shape (num_layers, batch_size, num_hiddens).
    """
    # Time-major embeddings: (num_steps, batch_size, embed_size)
    embedded = self.embedding(X).permute(1, 0, 2)
    num_steps = embedded.shape[0]
    # Tile the top-layer state over every step: (num_steps, batch_size, num_hiddens)
    context = state[-1].repeat(num_steps, 1, 1)
    rnn_input = torch.cat((embedded, context), dim=2)
    rnn_output, new_state = self.rnn(rnn_input, state)
    # Project to vocabulary logits and return to batch-major layout.
    logits = self.dense(rnn_output).permute(1, 0, 2)
    return logits, new_state

In [None]:
# Smoke-test the decoder, initialised from the encoder's output for the same batch X.
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
enc_outputs = encoder(X)
state = decoder.init_state(enc_outputs)
output, state = decoder(X, state)
# expected: (batch_size, num_steps, vocab_size), (num_layers, batch_size, num_hiddens)
output.shape, state.shape

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

# Encoder-Decoder

In [None]:
class EncoderDecoder(nn.Module):
  """Base class for the encoder-decoder architecture.

  Runs the encoder on the source batch, lets the decoder build its initial
  state from the encoder's outputs, then decodes the target batch.
  """
  def __init__(self, encoder, decoder, **kwargs):
    super().__init__(**kwargs)
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, enc_X, dec_X, *args):
    """Return whatever the decoder returns for dec_X.

    Extra positional args (e.g. source valid lengths) are forwarded to both
    the encoder and the decoder's init_state.
    """
    return self.decoder(dec_X, self.decoder.init_state(self.encoder(enc_X, *args), *args))

In [None]:
def sequence_mask(X, valid_len, value=0):
  """Mask entries past each sequence's valid length.

  For row i, every position j >= valid_len[i] along axis 1 is replaced by
  `value`. Any trailing axes (e.g. a feature axis) are masked as a whole.

  Args:
    X: tensor of rank >= 2 whose axis 1 is the sequence axis.
    valid_len: 1-D tensor of length X.shape[0] giving the number of valid
      positions in each row.
    value: fill value for masked positions.

  Returns:
    A new tensor with the same shape and dtype as X. X itself is not
    modified, so this is safe on tensors the caller still needs and on
    leaf tensors that require grad.
  """
  maxlen = X.size(1)
  # keep[i, j] is True for positions to keep, i.e. j < valid_len[i].
  keep = torch.arange(maxlen, device=X.device)[None, :] < valid_len[:, None]
  # Append singleton axes so the (batch, maxlen) mask broadcasts over X's trailing dims.
  keep = keep.reshape(keep.shape + (1,) * (X.dim() - 2))
  return X.masked_fill(~keep, value)

# Keep the first entry of row 0 and the first two entries of row 1.
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
valid_lens = torch.tensor([1, 2])
sequence_mask(X, valid_lens)

tensor([[1, 0, 0],
        [4, 5, 0]])

In [None]:
# Higher-rank input: whole trailing feature vectors past the valid length are set to -1.
X = torch.ones(2, 3, 4)
sequence_mask(X, valid_len=torch.tensor([1, 2]), value=-1)

tensor([[[ 1.,  1.,  1.,  1.],
         [-1., -1., -1., -1.],
         [-1., -1., -1., -1.]],

        [[ 1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.],
         [-1., -1., -1., -1.]]])

In [None]:
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
  """The softmax cross-entropy loss with masks.

  Shapes:
    pred: (batch_size, num_steps, vocab_size) -- unnormalized logits.
    label: (batch_size, num_steps) -- target token ids.
    valid_len: (batch_size,) -- number of non-padding steps per sequence.

  Returns a (batch_size,) tensor. For each sequence it is the sum of the
  per-token losses over its valid steps divided by num_steps (not by
  valid_len), so a fully padded sequence contributes 0.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Per-token losses are needed so padding can be masked before averaging.
    # Configure this once here rather than overwriting the module's
    # attribute on every forward call.
    self.reduction = 'none'

  def forward(self, pred, label, valid_len):
    # 1 for valid positions, 0 for padding.
    weights = sequence_mask(torch.ones_like(label), valid_len)
    # nn.CrossEntropyLoss expects the class axis second: (batch_size, vocab_size, num_steps).
    unweighted_loss = super().forward(pred.permute(0, 2, 1), label)
    return (unweighted_loss * weights).mean(dim=1)

In [None]:
# Uniform logits over 10 classes: each valid token costs ln(10) ~ 2.3026, and
# each sequence's loss scales with valid_len / num_steps.
loss = MaskedSoftmaxCELoss()
pred = torch.ones(3, 4, 10)
label = torch.ones((3, 4), dtype=torch.long)
valid_len = torch.tensor([4, 2, 0])
loss(pred, label, valid_len)

tensor([2.3026, 1.1513, 0.0000])

# Training 

In [None]:
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device, log_every=50):
  """Train a sequence-to-sequence model with teacher forcing.

  Args:
    net: encoder-decoder module called as net(X, dec_input, X_valid_len)
      and returning (logits, state).
    data_iter: iterable of batches (X, X_valid_len, Y, Y_valid_len).
    lr: Adam learning rate.
    num_epochs: number of passes over data_iter.
    tgt_vocab: target vocabulary; must map '<bos>' to an index.
    device: device to train on.
    log_every: print the epoch's loss every `log_every` epochs and after the
      final epoch; 0 or None disables logging. The printed value is the sum
      of the per-sequence losses divided by the number of valid target tokens.

  Returns:
    None. `net` is re-initialised and trained in place.
  """
  def xavier_init_weights(m):
    # Xavier-init the weight matrices (not biases) of Linear and GRU layers.
    if type(m) == nn.Linear:
      nn.init.xavier_uniform_(m.weight)
    if type(m) == nn.GRU:
      # Public API instead of the private _flat_weights_names / _parameters.
      for name, param in m.named_parameters():
        if "weight" in name:
          nn.init.xavier_uniform_(param)

  net.apply(xavier_init_weights)
  net.to(device)
  optimizer = torch.optim.Adam(net.parameters(), lr=lr)
  loss = MaskedSoftmaxCELoss()
  net.train()
  bos_id = tgt_vocab['<bos>']  # loop-invariant lookup
  for epoch in range(num_epochs):
    metric = d2l.Accumulator(2)  # (sum of sequence losses, number of valid tokens)
    for batch in data_iter:
      optimizer.zero_grad()
      X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
      # Shape of X: (batch_size, num_steps)
      # Shape of Y: (batch_size, num_steps)
      # Teacher forcing: decoder input is <bos> followed by Y shifted right by one.
      bos = torch.tensor([bos_id] * Y.shape[0], device=device).reshape(-1, 1)
      dec_input = torch.cat([bos, Y[:, :-1]], 1)
      Y_hat, _ = net(X, dec_input, X_valid_len)
      l = loss(Y_hat, Y, Y_valid_len)
      l.sum().backward()  # Make the loss scalar for `backward`
      d2l.grad_clipping(net, 1)
      num_tokens = Y_valid_len.sum()
      optimizer.step()
      metric.add(l.sum().item(), num_tokens.item())
    is_last = epoch + 1 == num_epochs
    if log_every and ((epoch + 1) % log_every == 0 or is_last):
      # max(..., 1) guards against an empty data_iter.
      print(f'epoch {epoch + 1}: loss {metric[0] / max(metric[1], 1):.3f}')

In [None]:
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 300, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
# Pass layer sizes by keyword: Seq2SeqEncoder's positional order is
# (embed_size, num_hiddens) but Seq2SeqDecoder's is (num_hiddens, embed_size),
# so positional calls would silently swap them whenever the sizes differ.
encoder = Seq2SeqEncoder(len(src_vocab), embed_size=embed_size, num_hiddens=num_hiddens,
                         num_layers=num_layers, dropout=dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size=embed_size, num_hiddens=num_hiddens,
                         num_layers=num_layers, dropout=dropout)
net = EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)