In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Load data

Install `spacy` and download the English and German language models used by the spaCy tokenizers.  
NOTE: Administrator permissions are required.
```bash
$ pip install spacy
$ python -m spacy download en
$ python -m spacy download de
```

In [2]:
from torchtext.datasets import Multi30k
from torchtext.data import Field

# Source side (German): spaCy tokenizer, lower-cased text, with explicit
# start-of-sequence / end-of-sequence tokens.
SRC = Field(
    tokenize="spacy",
    tokenizer_language="de",
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
)

# Target side (English): configured exactly like the source field.
TRG = Field(
    tokenize="spacy",
    tokenizer_language="en",
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
)

# German/English Multi30k splits, with `data/` as the dataset root directory.
train_data, valid_data, test_data = Multi30k.splits(
    exts=('.de', '.en'),
    fields=(SRC, TRG),
    root='data/',
)

# Peek at the first tokenized training pair.
print(train_data[0].src)
print(train_data[0].trg)

['zwei', 'junge', 'weiße', 'männer', 'sind', 'im', 'freien', 'in', 'der', 'nähe', 'vieler', 'büsche', '.']
['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']


In [3]:
# Both vocabularies are built from the training split only, using a minimum
# token frequency of 2.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

# `stoi` maps word -> index; show the first few keys and their indices.
print(list(SRC.vocab.stoi)[:5])
print(list(SRC.vocab.stoi.values())[:5])
# `itos` is the inverse lookup (a list, index -> word); compare both vocabularies.
print(SRC.vocab.itos[:8])
print(TRG.vocab.itos[:8])

['<unk>', '<pad>', '<sos>', '<eos>', '.']
[0, 1, 2, 3, 4]
['<unk>', '<pad>', '<sos>', '<eos>', '.', 'ein', 'einem', 'in']
['<unk>', '<pad>', '<sos>', '<eos>', 'a', '.', 'in', 'the']


## `BucketIterator`: Iterating over Text Datasets

In [4]:
from torch.utils.data import DataLoader
BATCH_SIZE = 4
# The default collate function checks whether the batch contains tensors,
# numpy arrays, ...; an identity `collate_fn` returns the raw list of
# examples instead.
train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    collate_fn=lambda examples: examples,
)

# Only the first batch is shown.
for batch in train_loader:
    print(batch)
    break

[<torchtext.data.example.Example object at 0x000002AE89F12B08>, <torchtext.data.example.Example object at 0x000002AE89F12C88>, <torchtext.data.example.Example object at 0x000002AE88EF56C8>, <torchtext.data.example.Example object at 0x000002AE88EF5708>]


In [5]:
from torchtext.data import BucketIterator
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 4

# `BucketIterator` converts the word sequences into padded tensors
# automatically and places them on `device`.
train_iterator = BucketIterator(
    train_data,
    batch_size=BATCH_SIZE,
    device=device,
)

# Inspect the tensor shapes of the first batch only; `batch` is reused by
# the model demo cells further down.
for batch in train_iterator:
    print(batch.src.size())
    print(batch.trg.size())
    break

torch.Size([22, 4])
torch.Size([26, 4])


# Define Model

In [6]:
# The vocabulary sizes determine the encoder input / decoder output dimensions.
IN_DIM, OUT_DIM = len(SRC.vocab), len(TRG.vocab)

# Larger alternative settings (currently disabled):
# ENC_EMB_DIM = 256
# DEC_EMB_DIM = 256
# ENC_HID_DIM = 512
# DEC_HID_DIM = 512
# ATTN_DIM = 64
# ENC_DROPOUT = 0.5
# DEC_DROPOUT = 0.5

# Reduced settings that are actually in effect for this notebook.
ENC_EMB_DIM = DEC_EMB_DIM = 32
ENC_HID_DIM = DEC_HID_DIM = 64
ATTN_DIM = 8
ENC_DROPOUT = DEC_DROPOUT = 0.5

## The Encoder

In [7]:
from typing import Tuple

class Encoder(nn.Module):
    """Bidirectional single-layer GRU encoder (batch-first).

    Args:
        in_dim: Size of the source vocabulary (number of embedding rows).
        emb_dim: Embedding dimension.
        enc_hid_dim: Hidden size of each GRU direction.
        dec_hid_dim: Size of the initial decoder hidden state produced by ``fc``.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(self,  in_dim: int,  emb_dim: int, 
                 enc_hid_dim: int,  dec_hid_dim: int,  dropout: float):
        super().__init__()

        self.emb = nn.Embedding(in_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True, batch_first=True)
        # Projects the concatenated forward/backward final states to the decoder size.
        self.fc = nn.Linear(enc_hid_dim*2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_ins: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of source index sequences.

        Args:
            enc_ins: Token indices of shape (batch_size, src_len).

        Returns:
            A pair ``(outs, hidden)`` where ``outs`` has shape
            (batch_size, src_len, enc_hid_dim*2) and ``hidden`` has shape
            (batch_size, dec_hid_dim).
        """
        # enc_ins: (batch_size, src_len)
        embedded = self.dropout(self.emb(enc_ins))
        # No initial hidden state is passed to the GRU.
        # outs: (batch_size, src_len, enc_hid_dim*2)
        # hidden: (2, batch_size, enc_hid_dim)
        outs, hidden = self.rnn(embedded)
        # Concatenate the last hidden states of the two directions.
        # hidden: (batch_size, enc_hid_dim*2)
        hidden = torch.cat([hidden[0], hidden[1]], dim=-1)
        # hidden: (batch_size, dec_hid_dim)
        hidden = torch.tanh(self.fc(hidden))
        return outs, hidden

In [8]:
# The batch tensors were created on `device` by the iterator, so the encoder
# has to live on the same device before it is called.
encoder = Encoder(IN_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT).to(device)
# No initial hidden state provided, default to be zeros. 
# `batch.src` is (src_len, batch_size); transpose it to the batch-first layout.
enc_outs, dec_hidden = encoder(batch.src.T)
print(enc_outs.size())
print(dec_hidden.size())

torch.Size([4, 22, 128])
torch.Size([4, 64])


## The Attention

In [9]:
class Attention(nn.Module):
    """Additive attention over the encoder outputs.

    The decoder state is concatenated with every encoder output, projected to
    ``attn_dim`` by a learned linear layer, squashed with tanh, summed over the
    projection axis and normalised with a softmax over the source positions.

    Args:
        enc_hid_dim: Hidden size of each encoder direction (encoder outputs
            are expected to have ``enc_hid_dim*2`` features).
        dec_hid_dim: Size of the decoder hidden state.
        attn_dim: Output size of the learned projection.
    """

    def __init__(self, enc_hid_dim: int, dec_hid_dim: int, attn_dim: int):
        super().__init__()

        self.attn_in = enc_hid_dim*2 + dec_hid_dim
        self.attn = nn.Linear(self.attn_in, attn_dim)

    def forward(self, dec_hidden: torch.Tensor, enc_outs: torch.Tensor) -> torch.Tensor:
        """Compute attention weights for one decoding step.

        Args:
            dec_hidden: Decoder state of shape (batch_size, dec_hid_dim).
            enc_outs: Encoder outputs of shape (batch_size, src_len, enc_hid_dim*2).

        Returns:
            Attention weights of shape (batch_size, src_len); each row sums to 1.
        """
        src_len = enc_outs.size(1)
        # repeated_dec_hidden: (batch_size, src_len, dec_hid_dim)
        repeated_dec_hidden = dec_hidden.unsqueeze(1).repeat(1, src_len, 1)
        # enc_outs: (batch_size, src_len, enc_hid_dim*2)
        # attn_ins: (batch_size, src_len, enc_hid_dim*2 + dec_hid_dim)
        attn_ins = torch.cat([repeated_dec_hidden, enc_outs], dim=-1)
        # The learned projection must be applied here; without it the module
        # has no trainable influence on the attention weights.
        # energy: (batch_size, src_len, attn_dim)
        energy = torch.tanh(self.attn(attn_ins))
        # attn: (batch_size, src_len)
        attn = energy.sum(dim=-1)
        return F.softmax(attn, dim=-1)

In [10]:
# Keep the attention module on the same device as the tensors it receives.
attention = Attention(ENC_HID_DIM, DEC_HID_DIM, ATTN_DIM).to(device)
# One weight per source position for every sequence in the batch.
attn = attention(dec_hidden, enc_outs)
print(attn.size())

torch.Size([4, 22])


## The Decoder

In [11]:
class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    Args:
        out_dim: Size of the target vocabulary (embedding rows and logits).
        emb_dim: Embedding dimension of the target tokens.
        enc_hid_dim: Hidden size of each encoder direction.
        dec_hid_dim: Hidden size of the decoder GRU.
        dropout: Dropout probability applied to the embeddings.
        attention: Module called as ``attention(dec_hidden, enc_outs)`` to
            obtain the weights over source positions.
    """

    def __init__(self, out_dim: int, emb_dim: int, enc_hid_dim: int,  dec_hid_dim: int,  
                 dropout: float, attention: Attention):
        super().__init__()

        self.attention = attention
        self.emb = nn.Embedding(out_dim, emb_dim)
        # Single-directional
        self.rnn = nn.GRU(enc_hid_dim*2 + emb_dim, dec_hid_dim, batch_first=True)
        # The output layer sees the GRU output, the context vector and the embedding.
        self.fc = nn.Linear(enc_hid_dim*2 + dec_hid_dim + emb_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_ins: torch.Tensor, dec_hidden: torch.Tensor, 
                enc_outs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        One-step forward. 

        Args:
            dec_ins: Current input token indices, shape (batch_size, 1).
            dec_hidden: Previous decoder state, shape (batch_size, dec_hid_dim).
            enc_outs: Encoder outputs, shape (batch_size, src_len, enc_hid_dim*2).

        Returns:
            A pair ``(outs, dec_hidden)``: ``outs`` has shape
            (batch_size, 1, trg_voc_size) and ``dec_hidden`` is the updated
            state of shape (batch_size, dec_hid_dim).
        """
        # dec_ins: (batch_size, 1)
        # embedded: (batch_size, 1, dec_emb_dim)
        embedded = self.dropout(self.emb(dec_ins))
        
        # attn: (batch_size, src_len)
        attn = self.attention(dec_hidden, enc_outs)
        # enc_outs: (batch_size, src_len, enc_hid_dim*2)
        # wtd_enc_rep: (batch_size, 1, enc_hid_dim*2)
        wtd_enc_rep = attn.unsqueeze(1).bmm(enc_outs)
        # rnn_ins: (batch_size, 1, enc_hid_dim*2 + dec_emb_dim)
        rnn_ins = torch.cat([embedded, wtd_enc_rep], dim=-1)
        # The GRU expects its hidden state as (1, batch_size, dec_hid_dim).
        # outs: (batch_size, 1, dec_hid_dim)
        outs, dec_hidden = self.rnn(rnn_ins, dec_hidden.unsqueeze(0))
        # outs: (batch_size, 1, trg_voc_size)
        outs = self.fc(torch.cat([outs, wtd_enc_rep, embedded], dim=-1))
        # Drop the leading layer axis again so callers get (batch_size, dec_hid_dim).
        return outs, dec_hidden.squeeze(0)

In [12]:
# The decoder (and the attention module it wraps) must be on the same device
# as the batch tensors and the encoder outputs.
decoder = Decoder(OUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attention).to(device)

print(dec_hidden.size())

# `batch.trg` is (trg_len, batch_size): row 0 holds the first target token of
# every sequence; add a length-1 time axis to get (batch_size, 1).
dec_ins_0 = batch.trg[0].unsqueeze(1)
dec_outs_0, dec_hidden = decoder(dec_ins_0, dec_hidden, enc_outs)
print(dec_hidden.size())
print(dec_outs_0.size())

torch.Size([4, 64])
torch.Size([4, 64])
torch.Size([4, 1, 5893])


## The Seq2Seq Model

In [13]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder one target step at a time.

    Args:
        encoder: Module called as ``encoder(enc_ins)`` returning
            ``(enc_outs, dec_hidden)``.
        decoder: One-step module called as
            ``decoder(dec_ins_t, dec_hidden, enc_outs)`` returning
            ``(dec_outs_t, dec_hidden)``.
        device: Stored on the instance; ``forward`` itself does not read it.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, device: torch.device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, enc_ins: torch.Tensor, dec_ins: torch.Tensor, 
                teacher_forcing_ratio: float=0.5) -> torch.Tensor:
        """Run the encoder once, then decode step by step.

        Args:
            enc_ins: Source token indices, shape (batch_size, src_len).
            dec_ins: Target token indices, shape (batch_size, trg_len).
            teacher_forcing_ratio: Probability, drawn independently at every
                step, of feeding the ground-truth token instead of the
                model's own top-1 prediction as the next decoder input.

        Returns:
            Per-step outputs concatenated along dim 1, shape
            (batch_size, trg_len - 1, trg_voc_size).
        """
        # enc_ins: (batch_size, src_len)
        # dec_ins: (batch_size, trg_len)
        # No initial hidden state provided, default to be zeros. 
        enc_outs, dec_hidden = self.encoder(enc_ins)

        dec_outs = []
        # The first input to the decoder is the <sos> token. 
        # dec_ins_t: (batch_size, 1)
        dec_ins_t = dec_ins[:, 0].unsqueeze(1)
        for t in range(1, dec_ins.size(1)):
            # Use the decoder registered on this module, not whatever object a
            # notebook-level `decoder` variable happens to point at.
            # dec_outs_t: (batch_size, 1, trg_voc_size)
            dec_outs_t, dec_hidden = self.decoder(dec_ins_t, dec_hidden, enc_outs)
            # top1: (batch_size, 1) index of the highest-scoring token
            top1 = dec_outs_t.max(dim=-1)[1]
            if np.random.rand() < teacher_forcing_ratio:
                dec_ins_t = dec_ins[:, t].unsqueeze(1)
            else:
                dec_ins_t = top1
            dec_outs.append(dec_outs_t)
        return torch.cat(dec_outs, dim=1)

In [14]:
# Wrap the demo encoder/decoder and run a single forward pass on the sample
# batch (inputs transposed to batch-first) with a 0.5 teacher-forcing ratio.
model = Seq2Seq(encoder, decoder, device).to(device)
dec_outs = model(batch.src.T, batch.trg.T, teacher_forcing_ratio=0.5)
dec_outs.size()

torch.Size([4, 25, 5893])

# Train Model

In [17]:
def init_weights(m: nn.Module):
    """Re-initialise every parameter of ``m`` in place.

    Parameters whose name contains ``'weight'`` are drawn from a normal
    distribution (mean 0, std 0.01); all remaining parameters are set to 0.
    """
    for name, param in m.named_parameters():
        if 'weight' not in name:
            nn.init.constant_(param.data, 0)
            continue
        nn.init.normal_(param.data, mean=0, std=0.01)

def count_parameters(model: nn.Module):
    """Return the total number of elements in parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [18]:
# Build fresh modules for training instead of reusing the demo instances.
encoder = Encoder(
    in_dim=IN_DIM,
    emb_dim=ENC_EMB_DIM,
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    dropout=ENC_DROPOUT,
)
attention = Attention(
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    attn_dim=ATTN_DIM,
)
decoder = Decoder(
    out_dim=OUT_DIM,
    emb_dim=DEC_EMB_DIM,
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    dropout=DEC_DROPOUT,
    attention=attention,
)
model = Seq2Seq(encoder=encoder, decoder=decoder, device=device).to(device)

# `apply` runs `init_weights` on the model and each of its submodules.
model.apply(init_weights)
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 1,856,685 trainable parameters


In [19]:
# NOTE: when scoring a translation model, positions where the target is only
# padding must not contribute to the loss, so `nn.CrossEntropyLoss` is told
# to ignore the index of the `<pad>` token.
PAD_IDX = TRG.vocab.stoi['<pad>']

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# Adam over all model parameters, with its default hyper-parameters.
optimizer = optim.Adam(params=model.parameters())

In [20]:
# The iterators work like `DataLoader`.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_sizes=(BATCH_SIZE, BATCH_SIZE*5, BATCH_SIZE*5), 
    device=device)

for i, batch in enumerate(train_iterator):
    print(batch.src.size())
    print(batch.trg.size())
    break

torch.Size([22, 4])
torch.Size([26, 4])
