# Reproduce Transformer from Attention is All You Need

## Preliminaries

In [24]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
from torch import nn
from dataset import Dataset
from tokenizer import get_tokenizer
from utils import NUM_PROC, DEVICE, free_memory
from model import TransformerModel
from transformer import Transformer


print("Number of processors: ", NUM_PROC)
print("Device: ", DEVICE)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Number of processors:  32
Device:  cuda


## Transformer Lite from Scratch

Using half the dimension as the base model: $d_{\rm model} = 256$, $d_{\rm ff} = 1024$. 

### Tokenizer

Byte-Pair Encoding with shared (English + German) vocabulary of 37000 tokens.

In [25]:
tokenizer = get_tokenizer(name="wmt14", language="de-en", vocab_size=37000)

Loaded tokenizer from ../tokenizer-wmt14-de-en.json


### Dataset

The dataset is downloaded at ~/.cache/huggingface/datasets/. I've turned off dataset caching to avoid disk explosion.

In [26]:
dataset = Dataset(name="wmt14", language="de-en", percentage=100)

In [27]:
dataset.tokenize(tokenizer)

Map (num_proc=32):   0%|          | 0/45088 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]

In [45]:
dataloader = {}
for split in ["train", "validation", "test"]:
    dataloader[split] = dataset.get_dataloader(split=split, batch_size=64, shuffle=True, min_len=1, max_len=128)


Filter:   0%|          | 0/45088 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/45025 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/2999 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3003 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]

### Train

In [52]:
# create the transformer model
model = TransformerModel(vocab_size=tokenizer.get_vocab_size(), d_model=256, dim_feedforward=1024).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=512**-0.5, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda nstep: min((nstep + 1) ** -0.5, (nstep + 1) * 4000 ** -1.5))
loss_fn = nn.CrossEntropyLoss() # could add label smoothing

In [30]:
# load model
# model.load_state_dict(torch.load("model_1.pth"))

In [53]:
# free_memory("model")
free_memory()
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 385929 KiB |  10642 MiB |  29718 GiB |  29717 GiB |
|       from large pool | 168398 KiB |  10381 MiB |  29301 GiB |  29301 GiB |
|       from small pool | 217531 KiB |    426 MiB |    416 GiB |    416 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 385929 KiB |  10642 MiB |  29718 GiB |  29717 GiB |
|       from large pool | 168398 KiB |  10381 MiB |  29301 GiB |  29301 GiB |
|       from small pool | 217531 KiB |    426 MiB |    416 GiB |    416 GiB |
|---------------------------------------------------------------

In [54]:
# create the transformer wrapper
transformer = Transformer(model, tokenizer)

In [55]:
transformer.train(dataloader, model, loss_fn, optimizer, scheduler)

-------------------------------
Epoch 1/1
Accuracy: 0.0%, Avg loss: 160.877609  [    1/45025]  [0:00:00 < 1:22:21]
Accuracy: 0.0%, Avg loss: 47.863728  [  101/45025]  [0:00:05 < 0:38:07]
Accuracy: 0.0%, Avg loss: 31.169077  [  201/45025]  [0:00:10 < 0:37:33]
Accuracy: 3.6%, Avg loss: 32.713753  [  301/45025]  [0:00:15 < 0:37:20]
Accuracy: 0.0%, Avg loss: 30.198341  [  401/45025]  [0:00:23 < 0:44:23]


KeyboardInterrupt: 

### Evaluate

In [42]:
sample = dataset.dataset["test"]["translation"][101]
transformer.predict(sample["de"], sample["en"])

Accuracy: 0.0%
[31m[39m
An [31mAn[39m
An extra [31mextra[39m
An extra one [31mone[39m
An extra one - [31m-[39m
An extra one - time [31mtime[39m
An extra one - time or [31mor[39m
An extra one - time or annual [31mannual[39m
An extra one - time or annual lev [31mlev[39m
An extra one - time or annual lev y [31my[39m
An extra one - time or annual lev y could [31mcould[39m
An extra one - time or annual lev y could be [31mbe[39m
An extra one - time or annual lev y could be imposed [31mimposed[39m
An extra one - time or annual lev y could be imposed on [31mon[39m
An extra one - time or annual lev y could be imposed on drivers [31mdrivers[39m
An extra one - time or annual lev y could be imposed on drivers of [31mof[39m
An extra one - time or annual lev y could be imposed on drivers of hy [31mhy[39m
An extra one - time or annual lev y could be imposed on drivers of hy br [31mbr[39m
An extra one - time or annual lev y could be imposed on drivers of hy br ids 

In [43]:
print(transformer.translate("Ich bin ein Berliner."))




In [44]:
for i in range(5):
    samples = dataset.dataset["test"]["translation"]
    idx = np.random.randint(len(samples))
    sample = samples[i]
    print(f"#{i+1}")
    print(f"Source: {sample['de']}")
    print(f"Target: {sample['en']}")
    print(f"Prediction: {transformer.translate(sample['de'])}")
    print()

#1
Source: Gutach: Noch mehr Sicherheit für Fußgänger
Target: Gutach: Increased safety for pedestrians
Prediction: 

#2
Source: Sie stehen keine 100 Meter voneinander entfernt: Am Dienstag ist in Gutach die neue B 33-Fußgängerampel am Dorfparkplatz in Betrieb genommen worden - in Sichtweite der älteren Rathausampel.
Target: They are not even 100 metres apart: On Tuesday, the new B 33 pedestrian lights in Dorfparkplatz in Gutach became operational - within view of the existing Town Hall traffic lights.
Prediction: 

#3
Source: Zwei Anlagen so nah beieinander: Absicht oder Schildbürgerstreich?
Target: Two sets of lights so close to one another: intentional or just a silly error?
Prediction: 

#4
Source: Diese Frage hat Gutachs Bürgermeister gestern klar beantwortet.
Target: Yesterday, Gutacht's Mayor gave a clear answer to this question.
Prediction: 

#5
Source: "Die Rathausampel ist damals installiert worden, weil diese den Schulweg sichert", erläuterte Eckert gestern.
Target: "At the t

## DEBUG

In [None]:
for name in ["src_len", "tgt_len"]:
    len_list = dataset.dataset["train"][name]
    tot = sum(len_list)
    count = 0
    for num in len_list:
        if num <= 256:
            count += num
    print(f"count: {count}, tot: {tot}, percentage: {count/tot*100:.2f}%")

count: 14302156, tot: 14303264, percentage: 99.99%
count: 14339828, tot: 14340777, percentage: 99.99%


In [None]:
for batch in dataloader["train"]:
    x, x_mask, y, y_mask = batch.values()
    print(x.shape, x_mask.shape, y.shape, y_mask.shape)
    x = model.embedding(x.to(DEVICE))
    print(x.shape)
    break

torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 128])
torch.Size([64, 128, 256])


In [None]:
for batch in dataloader["train"]:
    x, x_mask, y, y_mask = batch.values()
    z = x.masked_fill(x_mask == 0, 5)
    print(z[0])
    break

tensor([    1,  4126,  6476,  4263,  8684,  3956,  3767,  7128,  3873,  3807,
         7137, 25293,    16,  3807, 33842,  6294,  3983,    16,  3807, 25969,
           16,  3807, 21098,  7897, 35666, 13318,  3800,  3807,  9048, 15157,
           18,     2,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5])

In [49]:
for batch in dataloader["validation"]:
    x, x_mask, y, y_mask = batch.values()
    x, x_mask, y, y_mask = (
        x.to(DEVICE),
        x_mask.to(DEVICE),
        y.to(DEVICE),
        y_mask.to(DEVICE),
    )
    print("src =", x[0])
    print("tgt =", y[0])
    pred = model(x, x_mask, y, y_mask)
    print("pred =", pred.argmax(-1)[0])
    pred = pred[:, :-1, :]  # (batch_size, seq_len, vocab_size)
    label = y[:, 1:]  # (batch_size, seq_len)
    label_mask = y_mask[:, 1:] == 1  # (batch_size, seq_len)
    loss = loss_fn(pred[label_mask], label[label_mask])
    correct = (pred.argmax(-1) == label)[label_mask].float().sum().item() / label[label_mask].numel()
    print("loss =", loss.item())
    print("correct =", correct)
    break

src = tensor([    1,  3974, 19360,  3767,    16,  8428,  3782,  6364,  4060, 21877,
           16, 29774,  9518,  3803,  3956,    16, 10129,  3766,  3938,  3942,
         3946,  4986,  9913,  5286,  9797,    16,  3985,  3766,  3938,  3816,
         9609, 11372,  3873,  4155,  6301, 27076,    16,  3784,  3804,  3859,
         7013,  5426, 11281,  9275,  3862,  4155, 13125, 20643,  5556,  4054,
         3866,  5814,  5431, 18319,  5129,    18,     2,     3,     3,     3,
            3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
            3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
            3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
            3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
            3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
            3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
            3,     3,     3,     3,     3,     3,     3,  