# Reproducing the Transformer from *Attention Is All You Need*

## Preliminaries

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
from torch import nn
from dataset import Dataset
from tokenizer import get_tokenizer
from utils import NUM_PROC, DEVICE, free_memory
from model import TransformerModel
from transformer import Transformer


print("Number of processors: ", NUM_PROC)
print("Device: ", DEVICE)

Number of processors:  32
Device:  cuda


## Transformer Lite from Scratch

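Uses half the dimensions of the paper's base model, $d_{\rm model} = 256$ (base: 512) and $d_{\rm ff} = 1024$ (base: 2048), while keeping its 8 attention heads and 6 encoder + 6 decoder layers.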
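
Halving both $d_{\rm model}$ and $d_{\rm ff}$ shrinks the encoder/decoder stacks by roughly 4x, because the dominant terms scale with $d_{\rm model}^2$ and $d_{\rm model} d_{\rm ff}$. The back-of-the-envelope count below is a sketch that assumes the standard post-LN layout of `torch.nn.Transformer` (biases in every projection, one LayerNorm per sublayer, plus a final LayerNorm per stack) and excludes embeddings; the notebook's `TransformerModel` may be laid out differently.

```python
def encoder_layer_params(d: int, ff: int) -> int:
    attn = 4 * d * d + 4 * d      # Q, K, V and output projections, with biases
    ffn = 2 * d * ff + ff + d     # two linear layers, with biases
    norms = 2 * 2 * d             # two LayerNorms (weight + bias each)
    return attn + ffn + norms


def decoder_layer_params(d: int, ff: int) -> int:
    attn = 2 * (4 * d * d + 4 * d)  # masked self-attention + encoder-decoder attention
    ffn = 2 * d * ff + ff + d
    norms = 3 * 2 * d               # three LayerNorms
    return attn + ffn + norms


def stack_params(d: int, ff: int, n_layers: int = 6) -> int:
    # encoder + decoder stacks plus a final LayerNorm after each stack;
    # token embeddings and the output projection are NOT included
    per_pair = encoder_layer_params(d, ff) + decoder_layer_params(d, ff)
    return n_layers * per_pair + 2 * 2 * d


lite = stack_params(256, 1024)
base = stack_params(512, 2048)
print(f"lite: {lite:,}  base: {base:,}  ratio: {base / lite:.2f}")
```

Under these assumptions the lite stacks hold about 11.1M parameters versus 44.1M for the base model (ratio 3.99). With a 37,000-token vocabulary, a single 256-dimensional embedding table adds another 9,472,000 parameters, so the embeddings are a large share of this smaller model.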

### Tokenizer

Byte-pair encoding (BPE) with a shared English + German vocabulary of 37,000 tokens, as in the paper.
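
BPE builds its vocabulary by repeatedly merging the most frequent adjacent pair of symbols. The notebook loads a prebuilt tokenizer from JSON, so its training procedure isn't shown here; the following is a minimal sketch of the textbook merge-learning loop on a toy word-frequency table. The `</w>` end-of-word marker and the lexicographic tie-breaking are choices of this sketch, not necessarily what `get_tokenizer` used.

```python
from collections import Counter


def learn_bpe(words: dict[str, int], num_merges: int) -> list[tuple[str, str]]:
    # each word starts as a tuple of characters plus an end-of-word marker
    vocab = {tuple(w) + ("</w>",): c for w, c in words.items()}
    merges: list[tuple[str, str]] = []
    for _ in range(num_merges):
        # count adjacent symbol pairs, weighted by word frequency
        pairs: Counter = Counter()
        for symbols, count in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += count
        if not pairs:
            break
        # most frequent pair; ties go to the lexicographically smallest pair
        best = min(pairs, key=lambda p: (-pairs[p], p))
        merges.append(best)
        # rewrite every word with the chosen pair fused into one symbol
        new_vocab: dict[tuple[str, ...], int] = {}
        for symbols, count in vocab.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            key = tuple(out)
            new_vocab[key] = new_vocab.get(key, 0) + count
        vocab = new_vocab
    return merges


print(learn_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, 4))
# [('e', 's'), ('es', 't'), ('est', '</w>'), ('l', 'o')]
```

A real tokenizer runs tens of thousands of merges (here, enough to reach 37,000 symbols) over the combined English and German text, which is what makes the vocabulary shared between the two languages.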

In [2]:
tokenizer = get_tokenizer(name="wmt14", language="de-en", vocab_size=37000)

Loaded tokenizer from ../tokenizer-wmt14-de-en.json


### Dataset

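The dataset is downloaded to `~/.cache/huggingface/datasets/`. I've turned off dataset caching to keep the processed copies from filling up the disk. `percentage=10` keeps 10% of the training split (450,878 pairs, per the progress bars below), while validation and test keep all 3,000 and 3,003 pairs.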

In [3]:
dataset = Dataset(name="wmt14", language="de-en", percentage=10)

In [4]:
dataset.tokenize(tokenizer)

Map (num_proc=32):   0%|          | 0/450878 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]

In [5]:
dataloader = {}
for split in ["train", "validation", "test"]:
    dataloader[split] = dataset.get_dataloader(split=split, batch_size=64, shuffle=True, min_len=1, max_len=128)


Filter:   0%|          | 0/450878 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/450222 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/2999 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3003 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]
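
The `min_len`/`max_len` arguments drop sentence pairs outside the length bounds before batching: in the outputs above, the `Map` that follows the training `Filter` runs on 450,222 of the 450,878 pairs. `Dataset.get_dataloader` isn't shown, so this is only a hypothetical sketch of what such a filter and a padding collate step typically do; `keep_pair`, `pad_batch` and `pad_id=0` are illustrative assumptions, and whether the real filter checks both sides or counts special tokens is unknown.

```python
from typing import Sequence


def keep_pair(src: Sequence[int], tgt: Sequence[int], min_len: int = 1, max_len: int = 128) -> bool:
    # keep a pair only if BOTH token sequences lie within [min_len, max_len]
    return min_len <= len(src) <= max_len and min_len <= len(tgt) <= max_len


def pad_batch(seqs: list[list[int]], pad_id: int = 0) -> tuple[list[list[int]], list[list[bool]]]:
    # right-pad every sequence to the longest one in the batch
    width = max(len(s) for s in seqs)
    padded = [s + [pad_id] * (width - len(s)) for s in seqs]
    # True marks padding positions (what attention and the loss should ignore)
    mask = [[i >= len(s) for i in range(width)] for s in seqs]
    return padded, mask


padded, mask = pad_batch([[5, 6, 7], [8]])
print(padded)  # [[5, 6, 7], [8, 0, 0]]
print(mask)    # [[False, False, False], [False, True, True]]
```

Capping the length bounds the per-batch memory, since attention cost grows quadratically with the padded width.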

### Train

In [35]:
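# create the transformer model: half the base model's d_model and d_ff
model = TransformerModel(vocab_size=tokenizer.get_vocab_size(), d_model=256, nhead=8,
                         num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024).to(DEVICE)
# Adam settings from the paper; base lr = 512**-0.5 uses the paper's d_model (512), not this model's 256
optimizer = torch.optim.Adam(model.parameters(), lr=512**-0.5, betas=(0.9, 0.98), eps=1e-9)
# warmup schedule with warmup_steps = 4000; nstep + 1 avoids 0 ** -0.5 on the first call
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda nstep: min((nstep + 1) ** -0.5, (nstep + 1) * 4000 ** -1.5))
loss_fn = nn.CrossEntropyLoss()  # the paper uses label smoothing 0.1: nn.CrossEntropyLoss(label_smoothing=0.1)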
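
The paper's schedule is $\mathrm{lrate} = d_{\rm model}^{-0.5} \cdot \min(\mathrm{step}^{-0.5}, \mathrm{step} \cdot \mathrm{warmup\_steps}^{-1.5})$: linear warmup, then inverse-square-root decay. Both branches equal $\mathrm{warmup\_steps}^{-0.5}$ at $\mathrm{step} = \mathrm{warmup\_steps}$, so that is where the rate peaks. A pure-Python sketch of what the `LambdaLR` above computes (base lr times the lambda), assuming one scheduler step per optimizer step; `noam_lr` is a name chosen here for illustration:

```python
def noam_lr(nstep: int, base_lr: float = 512 ** -0.5, warmup: int = 4000) -> float:
    # same factor as the LambdaLR lambda, multiplied by the optimizer's base lr
    step = nstep + 1  # LambdaLR passes nstep = 0 on construction
    return base_lr * min(step ** -0.5, step * warmup ** -1.5)


peak = max(range(20000), key=noam_lr)
print(peak + 1, noam_lr(peak))  # peaks at step 4000
```

The peak is about 7.0e-4 with the 512**-0.5 scale used above; scaling by this model's own $d_{\rm model} = 256$ would raise it to about 9.9e-4.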

In [36]:
# load model
model.load_state_dict(torch.load("model_1.pth"))

<All keys matched successfully>

In [40]:
# free_memory("model")
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 257542 KiB |   6042 MiB | 270349 GiB | 270349 GiB |
|       from large pool | 127640 KiB |   5913 MiB | 267709 GiB | 267708 GiB |
|       from small pool | 129902 KiB |    253 MiB |   2640 GiB |   2640 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 257542 KiB |   6042 MiB | 270349 GiB | 270349 GiB |
|       from large pool | 127640 KiB |   5913 MiB | 267709 GiB | 267708 GiB |
|       from small pool | 129902 KiB |    253 MiB |   2640 GiB |   2640 GiB |
|---------------------------------------------------------------

In [37]:
# create the transformer wrapper
transformer = Transformer(model, tokenizer)

In [27]:
transformer.train(dataloader, model, loss_fn, optimizer, scheduler)

-------------------------------
Epoch 1/1
Accuracy: 36.1%, Avg loss: 5.538639  [   64/450222]  [0:00:00 < 0:12:27]
Accuracy: 35.9%, Avg loss: 5.541759  [ 6464/450222]  [0:00:10 < 0:11:40]
Accuracy: 38.0%, Avg loss: 5.389147  [12864/450222]  [0:00:19 < 0:10:55]
Accuracy: 38.2%, Avg loss: 5.546345  [19264/450222]  [0:00:29 < 0:10:53]
Accuracy: 35.8%, Avg loss: 5.519670  [25664/450222]  [0:00:38 < 0:10:30]
Accuracy: 37.1%, Avg loss: 5.471238  [32064/450222]  [0:00:47 < 0:10:15]
Accuracy: 38.1%, Avg loss: 5.393694  [38464/450222]  [0:00:56 < 0:10:01]
Accuracy: 35.1%, Avg loss: 5.619608  [44864/450222]  [0:01:05 < 0:09:50]
Accuracy: 40.9%, Avg loss: 5.103611  [51264/450222]  [0:01:14 < 0:09:38]
Accuracy: 39.8%, Avg loss: 5.204819  [57664/450222]  [0:01:23 < 0:09:27]
Accuracy: 38.3%, Avg loss: 5.322990  [64064/450222]  [0:01:32 < 0:09:17]
Accuracy: 37.3%, Avg loss: 5.219100  [70464/450222]  [0:01:41 < 0:09:07]
Accuracy: 38.4%, Avg loss: 5.299012  [76864/450222]  [0:01:50 < 0:08:57]
Accuracy:
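
The log reports token accuracy and the average cross-entropy. How `Transformer.train` computes them isn't shown; a common definition, assumed in the sketch below, counts only non-padding target positions for accuracy. Since `nn.CrossEntropyLoss` uses natural logarithms, $e^{\text{loss}}$ gives the per-token perplexity, so an average loss of 5.5 corresponds to a perplexity of about 245. `pad_id=0` is a hypothetical padding id.

```python
import math


def token_accuracy(pred: list[list[int]], target: list[list[int]], pad_id: int = 0) -> float:
    # fraction of non-padding target positions where the predicted id matches
    correct = total = 0
    for p_row, t_row in zip(pred, target):
        for p, t in zip(p_row, t_row):
            if t == pad_id:
                continue  # padding never counts toward accuracy
            total += 1
            correct += int(p == t)
    return correct / total if total else 0.0


def perplexity(mean_ce_nats: float) -> float:
    # per-token perplexity from a mean cross-entropy measured in nats
    return math.exp(mean_ce_nats)


print(token_accuracy([[1, 2, 3], [4, 9, 7]], [[1, 5, 3], [4, 0, 0]]))  # 0.75
print(round(perplexity(5.5), 1))  # 244.7
```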

### Evaluate

In the output below, each line shows the reference prefix followed, in brackets, by the model's prediction for the next token (teacher forcing), marked ✓ if it matches the reference token and ✗ otherwise.

In [38]:
sample = dataset.dataset["test"]["translation"][10]
transformer.predict(sample["de"], sample["en"])

Accuracy: 20.5%
[The ✗]
" [The ✗]
" According [to ✓]
" According to [the ✗]
" According to current [" ✗]
" According to current measurements [" ✗]
" According to current measurements , [the ✗]
" According to current measurements , around [12 ✓]
" According to current measurements , around 12 [% ✗]
" According to current measurements , around 12 , [12 ✗]
" According to current measurements , around 12 , 000 [people ✗]
" According to current measurements , around 12 , 000 vehicles [are ✗]
" According to current measurements , around 12 , 000 vehicles travel [in ✗]
" According to current measurements , around 12 , 000 vehicles travel through [the ✓]
" According to current measurements , around 12 , 000 vehicles travel through the [' ✗]
" According to current measurements , around 12 , 000 vehicles travel through the town [of ✓]
" According to current measurements , around 12 , 000 vehicles trav

In [32]:
print(transformer.translate("Ich bin ein Berliner."))

I am a large foss .
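
`translate` produces a whole sentence with no reference to lean on. Its decoding strategy isn't shown (the paper used beam search with beam size 4 and length penalty 0.6); below is a minimal greedy-decoding sketch instead. `next_token_scores(src, prefix)` is a hypothetical callable returning one score per vocabulary id, and `bos_id`/`eos_id` are hypothetical special-token ids.

```python
from typing import Callable, Sequence


def greedy_decode(
    next_token_scores: Callable[[Sequence[int], list[int]], Sequence[float]],
    src: Sequence[int],
    bos_id: int,
    eos_id: int,
    max_len: int = 128,
) -> list[int]:
    # start from <bos> and repeatedly append the single highest-scoring token
    prefix = [bos_id]
    for _ in range(max_len):
        scores = next_token_scores(src, prefix)
        best = max(range(len(scores)), key=scores.__getitem__)
        prefix.append(best)
        if best == eos_id:
            break
    return prefix[1:]  # drop <bos>; <eos> is kept if it was generated


def copy_scorer(src: Sequence[int], prefix: list[int], vocab_size: int = 10, eos_id: int = 2) -> list[float]:
    # toy scorer: copies the source token by token, then emits <eos>
    pos = len(prefix) - 1
    target = src[pos] if pos < len(src) else eos_id
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


print(greedy_decode(copy_scorer, [5, 7, 3], bos_id=1, eos_id=2))  # [5, 7, 3, 2]
```

Greedy decoding commits to each token immediately, so one early mistake (such as "large" above) cannot be revised later; beam search keeps several partial hypotheses alive to soften exactly this.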


In [39]:
samples = dataset.dataset["test"]["translation"]
for i in range(5):
    # show the first five test pairs; use samples[np.random.randint(len(samples))] for random ones
    sample = samples[i]
    print(f"#{i+1}")
    print(f"Source: {sample['de']}")
    print(f"Target: {sample['en']}")
    print(f"Prediction: {transformer.translate(sample['de'])}")
    print()

#1
Source: Gutach: Noch mehr Sicherheit für Fußgänger
Target: Gutach: Increased safety for pedestrians
Prediction: The safety of the Aires is a matter for more security :

#2
Source: Sie stehen keine 100 Meter voneinander entfernt: Am Dienstag ist in Gutach die neue B 33-Fußgängerampel am Dorfparkplatz in Betrieb genommen worden - in Sichtweite der älteren Rathausampel.
Target: They are not even 100 metres apart: On Tuesday, the new B 33 pedestrian lights in Dorfparkplatz in Gutach became operational - within view of the existing Town Hall traffic lights.
Prediction: You have no doubt that the new anschluss entsprechender Quant in the lebt Like the 況 of the Festivals - up of the Kanada - in the lebt Like - in the - hand , in the new BAN - has been in the case of the Kanada .

#3
Source: Zwei Anlagen so nah beieinander: Absicht oder Schildbürgerstreich?
Target: Two sets of lights so close to one another: intentional or just a silly error?
Prediction: Two - thirds of the lebt reinigung o

## DEBUG

In [None]:
# share of all source/target tokens that sit in sentences of at most 256 tokens
for name in ["src_len", "tgt_len"]:
    len_list = dataset.dataset["train"][name]
    tot = sum(len_list)
    count = sum(num for num in len_list if num <= 256)
    print(f"count: {count}, tot: {tot}, percentage: {count/tot*100:.2f}%")

count: 141283864, tot: 141799856, percentage: 99.64%
count: 138895459, tot: 139480626, percentage: 99.58%
