In [1]:
import torch
from d2l import torch as d2l
from torch import nn

In [2]:
class MTFraEng(d2l.DataModule):
    """The English-French dataset.

    Downloads the tab-separated English/French sentence pairs, lowercases and
    tokenizes them, and converts them into fixed-length index tensors together
    with a source and a target vocabulary.
    """

    def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):
        """Download the corpus and build the arrays and vocabularies.

        :param batch_size: minibatch size (stored as a hyperparameter)
        :param num_steps: fixed sequence length every sentence is padded or
            truncated to
        :param num_train: number of leading examples used for training
        :param num_val: number of examples after the training ones that are
            used for validation
        """
        super(MTFraEng, self).__init__()
        # Presumably exposes the constructor arguments as attributes
        # (self.num_steps, self.num_train, self.num_val are read below)
        self.save_hyperparameters()
        self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(
            self._download())

    def _download(self):
        """Download and extract the archive, then return fra.txt as a string."""
        d2l.extract(d2l.download(
            d2l.DATA_URL + 'fra-eng.zip', self.root, '94646ad1522d915e7b0f9296181140edcf86a4f5'))
        with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:
            return f.read()

    def _preprocess(self, text):
        """Normalize spaces, lowercase, and detach punctuation from words.

        :param text: raw corpus text
        :return: lowercased text in which each of ``,.!?`` is preceded by a
            space
        """
        # Replace non-breaking space with space
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        # Lowercase once, up front, so that the look-behind below indexes the
        # very same string that is being enumerated (lowercasing can change
        # the length of a string for some Unicode characters)
        text = text.lower()
        # Insert space between words and punctuation marks
        no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' '
        out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
               for i, char in enumerate(text)]
        return ''.join(out)

    def _tokenize(self, text, max_examples=None):
        """Split preprocessed text into source and target token lists.

        :param text: preprocessed text, one ``source<TAB>target`` pair per line
        :param max_examples: if truthy, only the first ``max_examples`` lines
            are considered; otherwise all lines are used
        :return: ``(src, tgt)``, two parallel lists of token lists, each
            sentence terminated by ``'<eos>'``
        """
        src, tgt = [], []
        for i, line in enumerate(text.split('\n')):
            # `>=` (not `>`): line indices start at 0, so index max_examples
            # would already be the (max_examples + 1)-th line
            if max_examples and i >= max_examples:
                break
            parts = line.split('\t')
            # Lines that do not consist of exactly two fields are ignored
            if len(parts) == 2:
                # Skip empty tokens
                src.append([t for t in f'{parts[0]} <eos>'.split(' ') if t])
                tgt.append([t for t in f'{parts[1]} <eos>'.split(' ') if t])
        return src, tgt

    def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):
        """Turn raw text into index tensors and vocabularies.

        :param raw_text: unprocessed ``source<TAB>target`` lines
        :param src_vocab: existing source vocabulary, or None to build one
        :param tgt_vocab: existing target vocabulary, or None to build one
        :return: ``((src_array, decoder_input, src_valid_len, label),
            src_vocab, tgt_vocab)`` where ``decoder_input`` and ``label`` are
            the target array without its last and first column respectively
        """
        def _build_array(sentences, vocab, is_tgt=False):
            # Truncate long sentences and pad short ones to exactly t tokens
            pad_or_trim = lambda seq, t: (
                seq[:t] if len(seq) > t else seq + ['<pad>'] * (t - len(seq)))
            sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
            if is_tgt:
                # '<bos>' is added after padding, so target rows hold
                # num_steps + 1 tokens
                sentences = [['<bos>'] + s for s in sentences]
            if vocab is None:
                vocab = d2l.Vocab(sentences, min_freq=2)  # discard the tokens with freq < 2
            array = torch.tensor([vocab[s] for s in sentences])
            # Number of non-padding tokens in each row
            valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
            return array, vocab, valid_len

        src, tgt = self._tokenize(self._preprocess(raw_text),
                                  self.num_train + self.num_val)
        src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)
        tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)
        # Decoder input drops the last target token; the label drops '<bos>'
        return ((src_array, tgt_array[:, :-1], src_valid_len, tgt_array[:, 1:]),
                src_vocab, tgt_vocab)

    def get_dataloader(self, train):
        """Return a loader over the training rows or the remaining rows.

        :param train: True selects the first ``num_train`` examples, False
            selects everything after them
        """
        idx = slice(0, self.num_train) if train else slice(self.num_train, None)
        return self.get_tensorloader(self.arrays, train, idx)

    def build(self, src_sentences, tgt_sentences):
        """Encode new sentence pairs with the already-built vocabularies.

        :param src_sentences: iterable of source sentence strings
        :param tgt_sentences: iterable of target sentence strings
        :return: the arrays tuple produced by ``_build_arrays``
        """
        raw_text = '\n'.join([src + '\t' + tgt for src, tgt in zip(
            src_sentences, tgt_sentences)])
        arrays, _, _ = self._build_arrays(
            raw_text, self.src_vocab, self.tgt_vocab)
        return arrays

### Data preparation

1. Download the data.
2. Preprocess: (a) replace non-breaking spaces with regular spaces; (b) lowercase the text; (c) insert a space between words and punctuation marks.
3. Tokenize the text to generate the source and target sequences. <span style="color:red">Here</span> we keep the punctuation marks.

In [3]:
# Fix the PyTorch RNG seed before the dataset is built
torch.manual_seed(123)

batch_size = 10
num_steps = 9
# The constructor already downloads the corpus and builds arrays/vocabularies
data = MTFraEng(batch_size, num_steps)
# Re-run the individual pipeline stages to inspect their intermediate results
raw_text = data._download()
print(f'raw_text[:75]: {raw_text[:75]}')

text = data._preprocess(raw_text)
print(f'text[:80]: {text[:80]}')

# No max_examples is passed here, so the whole text is tokenized
src, tgt = data._tokenize(text)
print(f'src[:6]: {src[:6]}')
print(f'tgt[:6]: {tgt[:6]}')

raw_text[:75]: Go.	Va !
Hi.	Salut !
Run!	Cours !
Run!	Courez !
Who?	Qui ?
Wow!	Ça alors !

text[:80]: go .	va !
hi .	salut !
run !	cours !
run !	courez !
who ?	qui ?
wow !	ça alors !
src[:6]: [['go', '.', '<eos>'], ['hi', '.', '<eos>'], ['run', '!', '<eos>'], ['run', '!', '<eos>'], ['who', '?', '<eos>'], ['wow', '!', '<eos>']]
tgt[:6]: [['va', '!', '<eos>'], ['salut', '!', '<eos>'], ['cours', '!', '<eos>'], ['courez', '!', '<eos>'], ['qui', '?', '<eos>'], ['ça', 'alors', '!', '<eos>']]


### _build_arrays
```
def _build_array(sentences, vocab, is_tgt=False):
    pad_or_trim = lambda seq, t: (
        seq[:t] if len(seq) > t else seq + ['<pad>'] * (t - len(seq)))
    sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
    if is_tgt:
        sentences = [['<bos>'] + s for s in sentences]
    if vocab is None:
        vocab = d2l.Vocab(sentences, min_freq=2)
    array = torch.tensor([vocab[s] for s in sentences])
    valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
    return array, vocab, valid_len
```
1. pad_or_trim: To control sentence length, we truncate long sentences and pad short ones.
2. ```vocab = d2l.Vocab(sentences, min_freq=2)``` Discard the tokens with freq < 2.

In [4]:
# Take one minibatch from the training dataloader; note that `src` and `tgt`
# are rebound here (they previously held the token lists from _tokenize)
src, tgt, src_valid_len, label = next(iter(data.train_dataloader()))

# NOTE(review): these print the complete token -> index mappings, which is a
# very long output; consider showing only the size or a small slice
print(f'data.src_vocab.token_to_idx: {data.src_vocab.token_to_idx}')
print(f'data.tgt_vocab.token_to_idx: {data.tgt_vocab.token_to_idx}')

data.src_vocab.token_to_idx: {'!': 0, ',': 1, '.': 2, '<eos>': 3, '<pad>': 4, '<unk>': 5, '?': 6, 'a': 7, 'agree': 8, 'ahead': 9, 'am': 10, 'ask': 11, 'attack': 12, 'away': 13, 'back': 14, 'bark': 15, 'be': 16, 'beats': 17, 'bed': 18, 'beg': 19, 'busy': 20, 'call': 21, 'calm': 22, 'came': 23, 'can': 24, 'catch': 25, 'cheers': 26, 'cold': 27, 'come': 28, 'cool': 29, 'cringed': 30, 'deaf': 31, 'did': 32, 'die': 33, 'died': 34, 'dogs': 35, "don't": 36, 'down': 37, 'dozed': 38, 'drive': 39, 'drop': 40, 'excuse': 41, 'fair': 42, 'fat': 43, 'feel': 44, 'fell': 45, 'find': 46, 'fine': 47, 'fire': 48, 'fix': 49, 'follow': 50, 'for': 51, 'forget': 52, 'free': 53, 'full': 54, 'fun': 55, 'game': 56, 'get': 57, 'give': 58, 'go': 59, 'good': 60, 'got': 61, 'grab': 62, 'had': 63, 'hang': 64, 'have': 65, 'he': 66, "he's": 67, 'hello': 68, 'help': 69, 'here': 70, 'hi': 71, 'him': 72, 'his': 73, 'hit': 74, 'hold': 75, 'home': 76, 'hop': 77, 'hot': 78, 'how': 79, "how's": 80, 'hug': 81, 'hurried': 82, '

In [5]:
# Show the minibatch as int32 index tensors
print('source:', src.type(torch.int32))
print('tgt:', tgt.type(torch.int32))
# Map the first example of the batch back from indices to tokens
print('source[0]:', data.src_vocab.to_tokens(src[0].type(torch.int32)))
print('target[0]:', data.tgt_vocab.to_tokens(tgt[0].type(torch.int32)))
print('source len excluding pad:', src_valid_len.type(torch.int32))
print('label:', label.type(torch.int32))

source: tensor([[ 86,  20,   2,   3,   4,   4,   4,   4,   4],
        [ 84,  32, 120,   2,   3,   4,   4,   4,   4],
        [ 86,   5,   2,   3,   4,   4,   4,   4,   4],
        [ 28, 150,   2,   3,   4,   4,   4,   4,   4],
        [ 16, 153,   2,   3,   4,   4,   4,   4,   4],
        [ 84, 172,   2,   3,   4,   4,   4,   4,   4],
        [  5,   2,   3,   4,   4,   4,   4,   4,   4],
        [ 84, 170,   2,   3,   4,   4,   4,   4,   4],
        [ 35,  15,   2,   3,   4,   4,   4,   4,   4],
        [ 86, 121,   2,   3,   4,   4,   4,   4,   4]], dtype=torch.int32)
tgt: tensor([[  3, 108, 183,   6,   2,   4,   5,   5,   5],
        [  3, 108, 122, 183,  30,   6,   2,   4,   5],
        [  3, 108, 183,   6,   2,   4,   5,   5,   5],
        [  3, 204,  31,   0,   4,   5,   5,   5,   5],
        [  3, 182,  39,   0,   4,   5,   5,   5,   5],
        [  3,   6,   2,   4,   5,   5,   5,   5,   5],
        [  3, 128,   0,   4,   5,   5,   5,   5,   5],
        [  3,   6,   2,   4,   5

### The Encoder-Decoder Seq2Seq Architecture

[Sequence to Sequence Learning with Neural Networks, 2014](https://arxiv.org/pdf/1409.3215.pdf)

input -> encoder -> state -> decoder (<- input) -> output

1. **encoder**: enc_input -> enc_outputs
2. **decoder**: dec_input $\times$ dec_state -> dec_outputs <br>
   **dec_init_state**: enc_outputs -> dec_state

In [6]:
class Encoder(nn.Module):
    """Abstract encoder of the encoder-decoder architecture.

    Concrete encoders subclass this and override ``forward``.
    """

    def __init__(self):
        super(Encoder, self).__init__()

    def forward(self, X, *args):
        """Encode ``X``; must be overridden by subclasses.

        Extra positional arguments are accepted so that subclasses can take
        further inputs (e.g., length excluding padding).

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

In [7]:
class Decoder(nn.Module):
    """Abstract decoder of the encoder-decoder architecture.

    Concrete decoders subclass this and override both ``init_state`` and
    ``forward``.
    """

    def __init__(self):
        super(Decoder, self).__init__()

    def init_state(self, enc_all_outputs, *args):
        """Derive the initial decoder state from the encoder outputs.

        Extra positional arguments are accepted so that subclasses can take
        further inputs (e.g., length excluding padding).

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def forward(self, X, state):
        """Decode ``X`` given ``state``; must be overridden by subclasses.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

In [8]:
class EncoderDecoder(d2l.Classifier):
    """Wires an encoder and a decoder together into a single model."""

    def __init__(self, encoder, decoder):
        """Store the two sub-modules.

        :param encoder: module called as ``encoder(enc_X, *args)``
        :param decoder: module providing ``init_state`` and a forward pass
            whose first returned element is the decoder output
        """
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, enc_X, dec_X, *args):
        """Run the encoder, initialize the decoder state, and decode.

        :param enc_X: encoder input
        :param dec_X: decoder input
        :param args: extra arguments forwarded to both the encoder and
            ``decoder.init_state``
        :return: the first element of the decoder's result (its output)
        """
        encoded = self.encoder(enc_X, *args)
        initial_state = self.decoder.init_state(encoded, *args)
        decoded = self.decoder(dec_X, initial_state)
        # The decoder also returns its state; only the output is passed on
        return decoded[0]

In [9]:
def init_seq2seq(module):
    """Apply Xavier-uniform initialization to Linear and GRU weights.

    Intended to be passed to ``nn.Module.apply``. Only modules whose type is
    exactly ``nn.Linear`` or ``nn.GRU`` are touched; biases and every other
    module are left as they are.

    :param module: the module to (possibly) initialize in place
    """
    kind = type(module)
    # Exact type match on purpose: subclasses (e.g. lazy layers) are skipped
    if kind is nn.Linear:
        nn.init.xavier_uniform_(module.weight)
    elif kind is nn.GRU:
        weight_names = [name for name in module._flat_weights_names
                        if "weight" in name]
        for name in weight_names:
            nn.init.xavier_uniform_(module._parameters[name])

In [10]:
class Seq2SeqEncoder(d2l.Encoder):
    """GRU-based encoder for sequence-to-sequence learning."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0):
        """Create the embedding layer and the GRU, then initialize weights.

        :param vocab_size: number of rows of the embedding table
        :param embed_size: size of each embedding vector (GRU input size)
        :param num_hiddens: GRU hidden size
        :param num_layers: number of GRU layers
        :param dropout: dropout value handed to the GRU
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = d2l.GRU(embed_size, num_hiddens, num_layers, dropout)
        self.apply(init_seq2seq)

    def forward(self, X, *args):
        """Embed the input tokens and run them through the GRU.

        :param X: token indices of shape (batch_size, num_steps)
        :return: ``(outputs, state)`` with outputs of shape
            (num_steps, batch_size, num_hiddens) and state of shape
            (num_layers, batch_size, num_hiddens)
        """
        # Transpose to time-major order and cast to integer indices
        token_ids = X.t().type(torch.int64)
        # (num_steps, batch_size) -> (num_steps, batch_size, embed_size)
        embedded = self.embedding(token_ids)
        enc_outputs, final_state = self.rnn(embedded)
        return enc_outputs, final_state

In [11]:
class Seq2SeqDecoder(d2l.Decoder):
    """The RNN decoder for sequence to sequence learning."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0):
        """Create the embedding, GRU and output layers.

        :param vocab_size: target vocabulary size (embedding rows and number
            of output features)
        :param embed_size: size of each embedding vector
        :param num_hiddens: GRU hidden size; also the size of the context
            vector concatenated to every embedding
        :param num_layers: number of GRU layers
        :param dropout: dropout value handed to the GRU
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU consumes the embedding concatenated with the context vector
        self.rnn = d2l.GRU(embed_size + num_hiddens, num_hiddens,
                           num_layers, dropout)
        # NOTE(review): init_seq2seq uses an exact type check against
        # nn.Linear, so this lazy layer is skipped by the apply() below and
        # keeps its default initialization - confirm that this is intended
        self.dense = nn.LazyLinear(vocab_size)
        self.apply(init_seq2seq)

    def init_state(self, enc_all_outputs, *args):
        """Use the encoder result unchanged as the initial decoder state.

        :param enc_all_outputs: (outputs, state)
        :return: ``enc_all_outputs`` itself
        """
        return enc_all_outputs

    def forward(self, X, state):
        """Decode a batch of target token indices.

        :param X: token indices of shape (batch_size, num_steps)
        :param state: ``(enc_output, hidden_state)`` as returned by
            ``init_state`` or by a previous call
        :return: ``(outputs, [enc_output, hidden_state])`` with outputs of
            shape (batch_size, num_steps, vocab_size)
        """
        # X shape: (batch_size, num_steps)
        # embs shape: (num_steps, batch_size, embed_size)
        # Indices are cast to int64, the same dtype Seq2SeqEncoder uses
        embs = self.embedding(X.t().type(torch.int64))
        enc_output, hidden_state = state
        # context shape: (batch_size, num_hiddens)
        context = enc_output[-1]  # enc_output at the final time step
        # Broadcast context to (num_steps, batch_size, num_hiddens)
        context = context.repeat(embs.shape[0], 1, 1)
        # Concat at the feature dimension
        embs_and_context = torch.cat((embs, context), -1)
        outputs, hidden_state = self.rnn(embs_and_context, hidden_state)
        outputs = self.dense(outputs).swapaxes(0, 1)
        # outputs shape: (batch_size, num_steps, vocab_size)
        # hidden_state shape: (num_layers, batch_size, num_hiddens)
        return outputs, [enc_output, hidden_state]

### init_seq2seq

1. nn.Linear: in_features $\times$ out_features -> (\*, out_features) <br>
   input: (\*, in_features) <br>
   output: (\*, out_features) <br>
   $y = x A^{T} + b$
2. by default, bias = True

In [12]:
# nn.Linear demo: map 10 samples with 2 features each to 3 features
in_feas, out_feas, num_input = 2, 3, 10
m = nn.Linear(in_feas, out_feas)
# NOTE(review): `input` shadows the Python builtin of the same name for the
# rest of the notebook session
input = torch.randn(num_input, in_feas)
output = m(input)
print(f'output.shape: {output.shape}')
# The weight is stored as (out_features, in_features), hence y = x A^T + b
print(f'm.weight.shape, m.bias.shape: {m.weight.shape, m.bias.shape}')

output.shape: torch.Size([10, 3])
m.weight.shape, m.bias.shape: (torch.Size([3, 2]), torch.Size([3]))


### nn.Embedding

1. nn.Embedding: num_embeddings $\times$ embedding_dim, OR vocab_size $\times$ vector_size
2. num_embeddings: size of the dictionary of embeddings
3. embedding_dim: the size of each embedding vector
4. input: (*), IntTensor or LongTensor of arbitrary shape containing the indices to extract
5. output: (*, H), where * is the input shape and H = embedding_dim
6. Every index in the input must lie in [0, num_embeddings), i.e. vocab_size > any element in input

In [13]:
# nn.Embedding demo: a table of 10 vectors with 3 entries each
vocab_size, emb_size = 10, 3
embedding = nn.Embedding(vocab_size, emb_size)
# a batch of 2 samples of 4 indices each
# NOTE(review): `input` shadows the Python builtin of the same name
input = torch.LongTensor([[1, 2, 4, 0], [4, 3, 2, 9]])
# Bare last expression so the looked-up vectors are displayed as cell output
embedding(input)

tensor([[[-1.2203,  1.3139,  1.0533],
         [ 0.1388, -0.2044, -0.8036],
         [-0.7979,  0.1838,  1.6863],
         [ 1.6221, -1.4779,  1.1331]],

        [[-0.7979,  0.1838,  1.6863],
         [-0.2808,  0.7697, -0.6596],
         [ 0.1388, -0.2044, -0.8036],
         [-0.5951, -0.7112,  0.6230]]], grad_fn=<EmbeddingBackward0>)

Let us take the 1st sample in the input batch, ```[1, 2, 4, 0]```, as an example. Its embedding matrix (4 by 3) is a re-representation of the raw data: each of the 4 indices is replaced by its 3-dimensional embedding vector.

### Encoder

In [14]:
# Shape check for the encoder on an all-zero dummy batch
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 9
encoder = Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
# X is created as a float tensor; Seq2SeqEncoder.forward casts it to integer
# indices before the embedding lookup
X = torch.zeros((batch_size, num_steps))
enc_outputs, enc_state = encoder(X)
print(f'enc_outputs.shape: {enc_outputs.shape}')
print(f'enc_state.shape: {enc_state.shape}')

enc_outputs.shape: torch.Size([9, 4, 16])
enc_state.shape: torch.Size([2, 4, 16])


### Decoder

In [15]:
# Run the decoder on the same dummy batch X, starting from the state derived
# from the encoder's (outputs, state) pair
# NOTE(review): nothing is displayed or checked in this cell; per the comments
# in Seq2SeqDecoder.forward, dec_outputs should have shape
# (batch_size, num_steps, vocab_size)
decoder = Seq2SeqDecoder(vocab_size, embed_size, num_hiddens, num_layers)
state = decoder.init_state(encoder(X))
dec_outputs, state = decoder(X, state)

