# Reproducing the Transformer from "Attention Is All You Need"

## Preliminaries

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
from torch import nn
from dataset import Dataset
from tokenizer import get_tokenizer
from utils import NUM_PROC, DEVICE, free_memory
from model import TransformerModel
from transformer import Transformer


print("Number of processors: ", NUM_PROC)
print("Device: ", DEVICE)

Number of processors:  32
Device:  cuda


## Transformer Lite from Scratch

This lite variant uses half the dimensions of the base model: $d_{\rm model} = 256$ and $d_{\rm ff} = 1024$, versus 512 and 2048 in the paper.

### Tokenizer

Byte-Pair Encoding (BPE) with a shared English and German vocabulary of 37,000 tokens.

In [2]:
tokenizer = get_tokenizer(name="wmt14", language="de-en", vocab_size=37000)

Loaded tokenizer from ../tokenizer-wmt14-de-en.json
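To make the BPE idea concrete, here is a toy sketch of how merges are learned: count adjacent symbol pairs, merge the most frequent pair, and repeat. This is not what `get_tokenizer` does internally (its implementation is not shown here); `learn_bpe` and the four-word corpus are made up for illustration, and a real tokenizer would also handle pre-tokenization, special tokens, and a much larger merge budget.

```python
from collections import Counter


def learn_bpe(words, num_merges):
    """Learn BPE merges from a {word: frequency} dict (toy sketch)."""
    # Each word starts out as a tuple of single characters.
    vocab = {tuple(word): freq for word, freq in words.items()}
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for symbols, freq in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        # Pick the most frequent pair; ties are broken deterministically.
        best = max(pairs, key=lambda p: (pairs[p], p))
        merges.append(best)
        # Rewrite every word, replacing the best pair with one merged symbol.
        new_vocab = {}
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            key = tuple(merged)
            new_vocab[key] = new_vocab.get(key, 0) + freq
        vocab = new_vocab
    return merges, vocab


# A shared vocabulary simply means English and German words are counted together.
corpus = {"lower": 5, "low": 7, "niedrig": 3, "niedriger": 2}
merges, vocab = learn_bpe(corpus, num_merges=3)
print(merges)
print(vocab)
```

Because both languages feed the same pair counts, subwords common to English and German (here the `er` suffix) end up as single shared tokens.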


### Dataset

The dataset is downloaded to `~/.cache/huggingface/datasets/`. I've turned off dataset caching to keep disk usage under control.

In [56]:
dataset = Dataset(name="wmt14", language="de-en", percentage=10)

In [57]:
dataset.tokenize(tokenizer)

Map (num_proc=32):   0%|          | 0/450878 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]

In [58]:
dataloader = {}
for split in ["train", "validation", "test"]:
    dataloader[split] = dataset.get_dataloader(split=split, batch_size=64, shuffle=True, min_len=1, max_len=128)



### Train

In [63]:
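# create the transformer model
model = TransformerModel(vocab_size=tokenizer.get_vocab_size(), d_model=256, dim_feedforward=1024).to(DEVICE)
# Adam with the paper's betas and eps; the base lr 512**-0.5 matches the paper's base model (d_model here is 256)
optimizer = torch.optim.Adam(model.parameters(), lr=512**-0.5, betas=(0.9, 0.98), eps=1e-9)
# paper's schedule: linear warmup over 4000 steps, then inverse-square-root decay (nstep is 0-based)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda nstep: min((nstep + 1) ** -0.5, (nstep + 1) * 4000 ** -1.5))
loss_fn = nn.CrossEntropyLoss()  # could add label smoothing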
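The optimizer and scheduler above together reproduce the paper's learning-rate formula, $lr = d_{\rm model}^{-0.5} \cdot \min(step^{-0.5},\ step \cdot warmup^{-1.5})$. The sketch below writes it as a plain function so the curve is easy to inspect; `noam_lr` is a name chosen here, not part of the notebook's modules. Note that the cell passes `512**-0.5` as the base rate while the model uses $d_{\rm model} = 256$; whether that is intentional is not stated, so the default below simply mirrors the cell.

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate at a 1-based step: linear warmup, then inverse-sqrt decay."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# The rate rises until step == warmup_steps and decays afterwards.
for step in (1, 1000, 4000, 16000):
    print(f"step {step:>6}: lr = {noam_lr(step):.2e}")
```

With these defaults the peak, reached at step 4000, is about 7e-4. `LambdaLR` calls its lambda with a 0-based counter, which is why the cell uses `nstep + 1`.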

In [7]:
# load model
# model.load_state_dict(torch.load("model_1.pth"))

In [64]:
# free_memory("model")
free_memory()
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    913 MiB |   7976 MiB | 210928 GiB | 210927 GiB |
|       from large pool |    656 MiB |   7717 MiB | 205547 GiB | 205546 GiB |
|       from small pool |    256 MiB |    418 MiB |   5381 GiB |   5381 GiB |
|---------------------------------------------------------------------------|
| Active memory         |    913 MiB |   7976 MiB | 210928 GiB | 210927 GiB |
|       from large pool |    656 MiB |   7717 MiB | 205547 GiB | 205546 GiB |
|       from small pool |    256 MiB |    418 MiB |   5381 GiB |   5381 GiB |
|---------------------------------------------------------------

In [65]:
# create the transformer wrapper
transformer = Transformer(model, tokenizer)

In [66]:
transformer.train(dataloader, model, loss_fn, optimizer, scheduler)

-------------------------------
Epoch 1/1
Accuracy: 0.0%, Avg loss: 10.600394  [    64/450878]  [0:00:00 < 1:07:33]
Accuracy: 6.1%, Avg loss: 9.508562  [  6464/450878]  [0:00:11 < 0:13:04]
Accuracy: 9.2%, Avg loss: 8.594700  [ 12864/450878]  [0:00:21 < 0:12:21]
Accuracy: 10.1%, Avg loss: 7.280375  [ 19264/450878]  [0:00:32 < 0:12:00]
Accuracy: 12.1%, Avg loss: 6.516609  [ 25664/450878]  [0:00:42 < 0:11:43]
Accuracy: 15.1%, Avg loss: 6.074619  [ 32064/450878]  [0:00:52 < 0:11:31]
Accuracy: 16.4%, Avg loss: 6.044272  [ 38464/450878]  [0:01:03 < 0:11:18]
Accuracy: 16.7%, Avg loss: 5.908104  [ 44864/450878]  [0:01:13 < 0:11:05]
Accuracy: 17.3%, Avg loss: 5.831823  [ 51264/450878]  [0:01:25 < 0:11:04]
Accuracy: 19.1%, Avg loss: 5.533202  [ 57664/450878]  [0:01:35 < 0:10:50]
Accuracy: 19.4%, Avg loss: 5.604733  [ 64064/450878]  [0:01:45 < 0:10:38]
Accuracy: 20.9%, Avg loss: 5.427184  [ 70464/450878]  [0:01:55 < 0:10:25]
Accuracy: 22.5%, Avg loss: 5.256050  [ 76864/450878]  [0:02:06 < 0:10:14
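A quick sanity check on the first logged value: an untrained model that spreads probability roughly uniformly over the vocabulary should have a cross-entropy near $\ln(\text{vocab size})$. The snippet below computes that reference value, assuming the logged loss is a per-token mean in nats (which is what `nn.CrossEntropyLoss` returns by default).

```python
import math

vocab_size = 37000
# Cross-entropy of a uniform guess over the vocabulary, in nats.
uniform_loss = math.log(vocab_size)
print(f"{uniform_loss:.3f}")
```

This comes out to roughly 10.52, close to the 10.60 reported for the first batch, so the starting point looks reasonable.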

### Evaluate

Each output line shows the reference prefix followed by the model's prediction for the next token, marked `[+token]` when it matches the reference and `[-token]` when it does not (green and red in the notebook).

In [72]:
sample = dataset.dataset["test"]["translation"][5]
transformer.predict(sample["de"], sample["en"])

Accuracy: 26.1%
[+The]
The [-same]
The Kl [-s]
The Kl user [-and]
The Kl user lights [-and]
The Kl user lights protect [-the]
The Kl user lights protect cycl [-and]
The Kl user lights protect cycl ists [-and]
The Kl user lights protect cycl ists , [-and]
The Kl user lights protect cycl ists , as [+well]
The Kl user lights protect cycl ists , as well [+as]
The Kl user lights protect cycl ists , as well as [-the]
The Kl user lights protect cycl ists , as well as those [-who]
The Kl user lights protect cycl ists , as well as those travelling [-and]
The Kl user lights protect cycl ists , as well as those travelling by [-the]
The Kl user lights protect cycl ists , as well as those travelling by bus [-h]
The Kl user lights protect cycl ists , as well as those travelling by bus and [+the]
The Kl user lights protect cycl ists , as well as those travelling by bus and the [-o

In [73]:
print(transformer.translate("Ich bin ein Berliner."))

I have a pleasure to make a clear distinction .
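Unlike `predict`, `translate` has no reference to lean on, so it has to generate the target token by token. How `Transformer.translate` does this is not shown here; the sketch below illustrates plain greedy decoding as one common approach. `greedy_decode`, `next_token_logits`, and the toy scoring function are hypothetical stand-ins, not the notebook's API.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    next_token_logits: Callable[[Sequence[int]], Sequence[float]],
    bos_id: int,
    eos_id: int,
    max_len: int = 128,
) -> List[int]:
    """Append the highest-scoring next token until EOS or max_len is reached."""
    tokens = [bos_id]
    while len(tokens) < max_len:
        logits = next_token_logits(tokens)
        # Greedy choice: index of the largest logit.
        next_id = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens


# Toy stand-in for a decoder: always prefers (last token + 1), capped at the last id.
def toy_logits(prefix, vocab_size=6):
    target = min(prefix[-1] + 1, vocab_size - 1)
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


print(greedy_decode(toy_logits, bos_id=0, eos_id=5))
```

Greedy decoding commits to the single best token at each step; the paper itself reports results with beam search, which keeps several candidate prefixes alive instead.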


In [74]:
# Translate the first five test sentences and compare them with the references.
samples = dataset.dataset["test"]["translation"]
for i in range(5):
    sample = samples[i]
    print(f"#{i+1}")
    print(f"Source: {sample['de']}")
    print(f"Target: {sample['en']}")
    print(f"Prediction: {transformer.translate(sample['de'])}")
    print()

#1
Source: Gutach: Noch mehr Sicherheit für Fußgänger
Target: Gutach: Increased safety for pedestrians
Prediction: The Council is also a more important point .

#2
Source: Sie stehen keine 100 Meter voneinander entfernt: Am Dienstag ist in Gutach die neue B 33-Fußgängerampel am Dorfparkplatz in Betrieb genommen worden - in Sichtweite der älteren Rathausampel.
Target: They are not even 100 metres apart: On Tuesday, the new B 33 pedestrian lights in Dorfparkplatz in Gutach became operational - within view of the existing Town Hall traffic lights.
Prediction: You are not going to have the same resources , which you know : ' You are not being used in the new form of the new waste - in the new Sea .

#3
Source: Zwei Anlagen so nah beieinander: Absicht oder Schildbürgerstreich?
Target: Two sets of lights so close to one another: intentional or just a silly error?
Prediction: Thirdly , what are the same kind of or whether or not they are they going to be a good example ?

#4
Source: Diese Fra

## DEBUG

In [14]:
# Share of all tokens that belong to sequences of at most 256 tokens, per side.
for name in ["src_len", "tgt_len"]:
    len_list = dataset.dataset["train"][name]
    tot = sum(len_list)
    count = 0
    for num in len_list:
        if num <= 256:
            count += num
    print(f"count: {count}, tot: {tot}, percentage: {count/tot*100:.2f}%")

count: 14302156, tot: 14303264, percentage: 99.99%
count: 14339828, tot: 14340777, percentage: 99.99%


In [27]:
for batch in dataloader["train"]:
    x, x_mask, y, y_mask = batch.values()
    print(x.shape, x_mask.shape, y.shape, y_mask.shape)
    x = model.embedding(x.to(DEVICE))
    print(x.shape)
    break

torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 128])
torch.Size([64, 128, 256])
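The shapes above are consistent with every sequence in a batch being padded to `max_len=128`, with a mask marking the real tokens. The actual logic lives in `Dataset.get_dataloader`, which is not shown; the sketch below is one plausible way to do it on plain lists. `pad_batch`, the padding id 0, and the 1/0 mask convention are assumptions made for illustration.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], max_len: int = 128, pad_id: int = 0
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad (or truncate) each sequence to max_len and build a 1/0 mask."""
    padded, masks = [], []
    for seq in sequences:
        seq = seq[:max_len]  # safety truncation; the notebook filters by length instead
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_id] * n_pad)
        masks.append([1] * len(seq) + [0] * n_pad)  # 1 = real token, 0 = padding
    return padded, masks


ids, mask = pad_batch([[5, 6, 7], [8, 9]], max_len=4)
print(ids)   # [[5, 6, 7, 0], [8, 9, 0, 0]]
print(mask)  # [[1, 1, 1, 0], [1, 1, 0, 0]]
```

Padding to a fixed length keeps tensor shapes constant across batches at the cost of wasted computation on short sentences; padding to the longest sequence in each batch is a common alternative.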
