# Reproduce Transformer from Attention is All You Need

## Preliminaries

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
from torch import nn
from dataset import Dataset
from tokenizer import get_tokenizer
from utils import NUM_PROC, DEVICE, free_memory, analyze_params, compare_params
from model import *
from transformer import Transformer

print("Number of processors: ", NUM_PROC)
print("Device: ", DEVICE)

Number of processors:  32
Device:  cuda


## Transformer from Scratch

Using the same hyperparameters as the base model in the paper. 

### Tokenizer

Byte-Pair Encoding with shared (English + German) vocabulary of 37000 tokens.

In [2]:
tokenizer = get_tokenizer(name="wmt14", language="de-en", vocab_size=37000)

Loaded tokenizer from ../tokenizer-wmt14-de-en.json


### Dataset

The dataset is downloaded at ~/.cache/huggingface/datasets/. I've turned off dataset caching to avoid disk explosion.

In [3]:
dataset = Dataset(name="wmt14", language="de-en", percentage=100)

In [4]:
# about 1 minute
dataset.tokenize(tokenizer)

Map (num_proc=32):   0%|          | 0/4508785 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]

In [5]:
# about 5 minutes
dataloader = {}
for split in ["train", "validation", "test"]:
    dataloader[split] = dataset.get_dataloader(split=split, batch_size=64, shuffle=True, min_len=1, max_len=128)


Filter:   0%|          | 0/4508785 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/4496706 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/2999 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3003 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]

### Train

In [6]:
# create the transformer model
model = TransformerModel(vocab_size=tokenizer.get_vocab_size()).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=512**-0.5, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda nstep: min((nstep + 1) ** -0.5, (nstep + 1) * 4000 ** -1.5))
loss_fn = nn.CrossEntropyLoss() # could add label smoothing

In [7]:
# load model
load_model = "base_100%_e10.pth"
epoch = int(load_model.split(".pth")[0].split("_e")[1])
model.load_state_dict(torch.load(load_model))
num_steps_trained = int(4508785 / 64 * epoch)
for _ in range(num_steps_trained):
    scheduler.step()
print(f"Starting from step {num_steps_trained} with learning rate {scheduler.get_last_lr()[0]:f}")



Starting from step 704497 with learning rate 0.000053


In [27]:
# free_memory("model", "transformer")
free_memory()
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   2382 MiB |   8390 MiB |  12058 TiB |  12058 TiB |
|       from large pool |   2013 MiB |   8165 MiB |  11930 TiB |  11930 TiB |
|       from small pool |    369 MiB |    441 MiB |    127 TiB |    127 TiB |
|---------------------------------------------------------------------------|
| Active memory         |   2382 MiB |   8390 MiB |  12058 TiB |  12058 TiB |
|       from large pool |   2013 MiB |   8165 MiB |  11930 TiB |  11930 TiB |
|       from small pool |    369 MiB |    441 MiB |    127 TiB |    127 TiB |
|---------------------------------------------------------------

In [8]:
# create the transformer wrapper
transformer = Transformer(model, tokenizer)

In [9]:
# transformer.save("base_100%_e00.pth")
for i in range(10, 20):
    transformer.train(dataloader, loss_fn, optimizer, scheduler)
    transformer.save(f"base_100%_e{i+1:02d}.pth")

-------------------------------
Epoch 1/1
Accuracy: 48.7%, Avg loss:   3.362488, Lr:   0.000053  [     64/4496706]  [0:00:01 < 36:25:12]
Accuracy: 50.0%, Avg loss:   3.347984, Lr:   0.000053  [   6464/4496706]  [0:00:17 < 3:23:28]
Accuracy: 48.9%, Avg loss:   3.434219, Lr:   0.000053  [  12864/4496706]  [0:00:32 < 3:09:17]
Accuracy: 48.0%, Avg loss:   3.577359, Lr:   0.000053  [  19264/4496706]  [0:00:47 < 3:05:03]
Accuracy: 48.8%, Avg loss:   3.431506, Lr:   0.000053  [  25664/4496706]  [0:01:02 < 3:02:05]
Accuracy: 47.3%, Avg loss:   3.524631, Lr:   0.000053  [  32064/4496706]  [0:01:18 < 3:01:18]
Accuracy: 44.3%, Avg loss:   3.816612, Lr:   0.000053  [  38464/4496706]  [0:01:36 < 3:06:02]
Accuracy: 48.4%, Avg loss:   3.501348, Lr:   0.000053  [  44864/4496706]  [0:01:53 < 3:07:37]
Accuracy: 48.0%, Avg loss:   3.508458, Lr:   0.000053  [  51264/4496706]  [0:02:10 < 3:08:02]
Accuracy: 44.8%, Avg loss:   3.772749, Lr:   0.000053  [  57664/4496706]  [0:02:25 < 3:06:41]
Accuracy: 43.3%, 

In [114]:
transformer.validate(dataloader["validation"], loss_fn)

Validation Error: 
 Accuracy: 42.4%, Avg loss: 4.099176 



### Analyze

Initialization

In [10]:
module = TransformerModel(37000)
analyze_params(module)

Total number of parameters: 63082496
[32membedding.weight[39m
	(37000, 512)         torch.float32	param =   -0.0003240 +/-   0.9999905	grad = None
[32mencoder.layers.0.multi_head_attention.q_linear.weight[39m
	(512, 512)           torch.float32	param =   -0.0000179 +/-   0.0254906	grad = None
[32mencoder.layers.0.multi_head_attention.q_linear.bias[39m
	(512,)               torch.float32	param =    0.0004377 +/-   0.0257921	grad = None
[32mencoder.layers.0.multi_head_attention.k_linear.weight[39m
	(512, 512)           torch.float32	param =    0.0000482 +/-   0.0254886	grad = None
[32mencoder.layers.0.multi_head_attention.k_linear.bias[39m
	(512,)               torch.float32	param =   -0.0018594 +/-   0.0252353	grad = None
[32mencoder.layers.0.multi_head_attention.v_linear.weight[39m
	(512, 512)           torch.float32	param =    0.0000349 +/-   0.0254972	grad = None
[32mencoder.layers.0.multi_head_attention.v_linear.bias[39m
	(512,)               torch.float32	param =   -0

We observe that PyTorch initializes its layers with

-   Embedding:  $0\pm 1$

-   Linear: $0\pm 1 / \sqrt{3 d_{\rm in}}$

-   LayerNorm: $\gamma = 1,\ \beta = 0$

Gradient Behaviors

In [11]:
analyze_params(model)

Total number of parameters: 63082496
[32membedding.weight[39m
	(37000, 512)         torch.float32	param =   -0.0002504 +/-   1.0052953	grad = None
[32mencoder.layers.0.multi_head_attention.q_linear.weight[39m
	(512, 512)           torch.float32	param =   -0.0000068 +/-   0.0819763	grad = None
[32mencoder.layers.0.multi_head_attention.q_linear.bias[39m
	(512,)               torch.float32	param =   -0.0054485 +/-   0.0992457	grad = None
[32mencoder.layers.0.multi_head_attention.k_linear.weight[39m
	(512, 512)           torch.float32	param =   -0.0000100 +/-   0.0823582	grad = None
[32mencoder.layers.0.multi_head_attention.k_linear.bias[39m
	(512,)               torch.float32	param =   -0.0004820 +/-   0.0260727	grad = None
[32mencoder.layers.0.multi_head_attention.v_linear.weight[39m
	(512, 512)           torch.float32	param =   -0.0000113 +/-   0.0239310	grad = None
[32mencoder.layers.0.multi_head_attention.v_linear.bias[39m
	(512,)               torch.float32	param =   -0

Parameter Shift over Training

In [12]:
module1 = TransformerModel(vocab_size=tokenizer.get_vocab_size(), d_model=512)
module1.load_state_dict(torch.load("base_100%_e03.pth"))
module2 = TransformerModel(vocab_size=tokenizer.get_vocab_size(), d_model=512)
module2.load_state_dict(torch.load("base_100%_e00.pth"))
compare_params(module1, module2)

[32membedding.weight[39m
(37000, 512)        	param1 =   -0.0002504 +/-   1.0052953	param2 =   -0.0002411 +/-   0.9999496	diff(rms) =   0.1056992	diff(max) =   0.8115359
[32mencoder.layers.0.multi_head_attention.q_linear.weight[39m
(512, 512)          	param1 =   -0.0000068 +/-   0.0819763	param2 =    0.0000490 +/-   0.0255153	diff(rms) =   0.0778650	diff(max) =   0.3792256
[32mencoder.layers.0.multi_head_attention.q_linear.bias[39m
(512,)              	param1 =   -0.0054485 +/-   0.0992457	param2 =   -0.0018342 +/-   0.0252954	diff(rms) =   0.0979231	diff(max) =   0.2868303
[32mencoder.layers.0.multi_head_attention.k_linear.weight[39m
(512, 512)          	param1 =   -0.0000100 +/-   0.0823582	param2 =   -0.0000158 +/-   0.0255332	diff(rms) =   0.0782432	diff(max) =   0.3995680
[32mencoder.layers.0.multi_head_attention.k_linear.bias[39m
(512,)              	param1 =   -0.0004820 +/-   0.0260727	param2 =   -0.0005781 +/-   0.0258791	diff(rms) =   0.0031696	diff(max) =   0.0114

### Evaluate

In [14]:
rand_idx = np.random.randint(len(dataset.dataset["test"]["translation"]))
sample = dataset.dataset["test"]["translation"][rand_idx]
transformer.predict(sample["de"], sample["en"])

Accuracy: 48.0%
[31mThomas[39m
Mayor [32mThomas[39m
Mayor Thomas [31mHir[39m
Mayor Thomas Ha [32mas[39m
Mayor Thomas Ha as [31mspoke[39m
Mayor Thomas Ha as ret [31mired[39m
Mayor Thomas Ha as ret orted [32m:[39m
Mayor Thomas Ha as ret orted : [31m"[39m
Mayor Thomas Ha as ret orted : The [32m"[39m
Mayor Thomas Ha as ret orted : The " [31mlong[39m
Mayor Thomas Ha as ret orted : The " Hir [32mschen[39m
Mayor Thomas Ha as ret orted : The " Hir schen [32m"[39m
Mayor Thomas Ha as ret orted : The " Hir schen " [31mfor[39m
Mayor Thomas Ha as ret orted : The " Hir schen " railway [31mtransport[39m
Mayor Thomas Ha as ret orted : The " Hir schen " railway crossing [31mthe[39m
Mayor Thomas Ha as ret orted : The " Hir schen " railway crossing is [31mregularly[39m
Mayor Thomas Ha as ret orted : The " Hir schen " railway crossing is used [32mregularly[39m
Mayor Thomas Ha as ret orted : The " Hir schen " railway crossing is used regularly [32mfor[39m
Mayor Thomas Ha

In [15]:
print(transformer.translate("Englisch"))

English


In [20]:
for i in range(3):
    samples = dataset.dataset["test"]["translation"]
    idx = np.random.randint(len(samples))
    sample = samples[idx]
    print(f"#{i+1}")
    print(f"Source: {sample['de']}")
    print(f"Target: {sample['en']}")
    print(f"Prediction: {transformer.translate(sample['de'])}")
    print()

#1
Source: Herr Max Maier, bitte kommen Sie zu Gate 24.
Target: Mr. Max Maier, please make your way to Gate 24.
Prediction: Mr Max Ma ier , please come to Max Gate 24 .

#2
Source: Bombardier erklärte, es überprüfe die Planung für die Inbetriebnahme (EIS) und werde diese in den nächsten Monaten aktualisieren.
Target: Bombardier said it was evaluating the entry-into-service (EIS) schedule and will provide an update in the next few months.
Prediction: E IS declared that it would update the planning for the next few months ( and it will be over the next months ) and update it .

#3
Source: Diese Fahrer werden bald die Meilengebühren statt der Mineralölsteuer an den Bundesstaat zahlen.
Target: Those drivers will soon pay the mileage fees instead of gas taxes to the state.
Prediction: These drivers will soon pay tax fees for the mineral oil in the eastern part of the town .



## DEBUG

In [8]:
# dataset corpus length analysis
for name in ["src_len", "tgt_len"]:
    len_list = dataset.dataset["train"][name]
    tot = sum(len_list)
    count = 0
    for num in len_list:
        if num <= 64:
            count += num
    print(f"count: {count}, tot: {tot}, percentage: {count/tot*100:.2f}%")

count: 123676808, tot: 141799856, percentage: 87.22%
count: 122343081, tot: 139480626, percentage: 87.71%


In [13]:
# Total number of parameters
total = 0
for par in model.parameters():
    total += par.numel()
print(total)

63082496


In [14]:
# parameter distributions over the model
for par in model.parameters():
    print(f"{100 * par.numel() / total:.2f}% {par.shape}")

30.03% torch.Size([37000, 512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.00% torch.Size([512])
1.66% torch.Size([2048, 512])
0.00% torch.Size([2048])
1.66% torch.Size([512, 2048])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.00% torch.Size([512])
1.66% torch.Size([2048, 512])
0.00% torch.Size([2048])
1.66% torch.Size([512, 2048])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])

### BLEU Score

In [23]:
result, ref, sys = transformer.evaluate_bleu(dataloader["test"])
print(result)

  0%|          | 0/47 [00:00<?, ?it/s]

100%|██████████| 47/47 [02:52<00:00,  3.67s/it]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.


BLEU = 8.38 47.0/14.2/5.0/1.9 (BP = 0.942 ratio = 0.943 hyp_len = 74447 ref_len = 78909)


In [179]:
# check the reference sentences and the predicted sentences
for i in range(5):
    print(i, ref[i])
    print(i, sys[i])

0 The free mar kete ers at the Re ason Foundation are also fond of having drivers pay per mile .
0 Also the idea of the free road to Re ason Foundation is to return to Re mark ter Foundation .
1 There were large quantities of wood and bal es of stra w stored inside .
1 It also made a lot of timber and all the timber .
2 " We need to have a better system ," he said .
2 “ We need a better system .”
3 The film never sli ps into pr ur ience or sens ational ism - and that ' s the problem .
3 The problem is its problem – never its film is in the way of the film and the sit t ings .
4 As ked if he would return to the post of prime minister , Mr Blair was quoted by London ' s Even ing Standard as saying : " Yes , sure , but it ' s not likely to happen is it , so ..."
4 The question is whether it is unlikely that the Prime Minister of London would return from the words of Prime Minister Blair , that is , but that is the standard of the “ standard ” that would return from London ...
