# Reproducing the Transformer from *Attention Is All You Need*

## Preliminaries

In [2]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
from torch import nn
from dataset import Dataset
from tokenizer import get_tokenizer
from utils import NUM_PROC, DEVICE, free_memory, analyze_params, compare_params
from model import *
from transformer import Transformer

# report the runtime configuration picked up from utils
for label, value in (("Number of processors: ", NUM_PROC), ("Device: ", DEVICE)):
    print(label, value)

Number of processors:  32
Device:  cuda


## Transformer from Scratch

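Using the same hyperparameters as the base model in the paper: $d_{\rm model}=512$, $d_{\rm ff}=2048$, Adam with $\beta_1=0.9$, $\beta_2=0.98$, $\epsilon=10^{-9}$, and 4000 warm-up steps. There are two differences from the paper. Label smoothing ($\epsilon_{ls}=0.1$ in the paper) is not applied. Each batch holds 64 sentence pairs rather than roughly 25,000 source and 25,000 target tokens.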

### Tokenizer

Byte-pair encoding (BPE) with a shared English–German vocabulary of 37,000 tokens, as in the paper.

In [41]:
# shared English/German BPE tokenizer with a 37000-token vocabulary
tokenizer = get_tokenizer(
    vocab_size=37000,
    language="de-en",
    name="wmt14",
)

Loaded tokenizer from ../tokenizer-wmt14-de-en.json


### Dataset

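The dataset is downloaded to `~/.cache/huggingface/datasets/`. Dataset caching is turned off to avoid filling up the disk. With `percentage=1`, only a small subset of the training split is loaded: 45,088 sentence pairs, about 1% of the full split.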

In [65]:
# WMT14 German-English; `percentage` presumably selects the share of the training split
# to load (1 yields 45088 training examples in the tokenize output below)
dataset = Dataset(
    percentage=1,
    language="de-en",
    name="wmt14",
)

In [66]:
# about 1 minute
# tokenize the dataset with the shared tokenizer; the progress output shows one
# parallel Map pass per split (train / validation / test)
dataset.tokenize(tokenizer)

Map (num_proc=32):   0%|          | 0/45088 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]

In [67]:
# about 5 minutes
# sort every split by source length, then build one DataLoader per split
# NOTE(review): shuffle=True may undo the length sort unless get_dataloader batches
# by length before shuffling — confirm in dataset.py
dataloader = {}
for split_name in ("train", "validation", "test"):
    dataset.dataset[split_name] = dataset.dataset[split_name].sort("src_len")
    dataloader[split_name] = dataset.get_dataloader(
        split=split_name,
        batch_size=64,
        shuffle=True,
        max_len=128,
    )


Filter:   0%|          | 0/45088 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/45088 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3003 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/3003 [00:00<?, ? examples/s]

### Train

In [None]:
# create the transformer model (base configuration from the paper)
D_MODEL = 512        # model width; also sets the peak learning-rate scale d_model^-0.5
WARMUP_STEPS = 4000  # length of the linear learning-rate warm-up

model = TransformerModel(vocab_size=tokenizer.get_vocab_size(), d_model=D_MODEL).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=D_MODEL**-0.5, betas=(0.9, 0.98), eps=1e-9)
# Schedule from the paper: lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5).
# LambdaLR multiplies the base lr by this factor and counts from 0, hence step = nstep + 1.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda nstep: min((nstep + 1) ** -0.5, (nstep + 1) * WARMUP_STEPS ** -1.5),
)
# The paper additionally uses label smoothing 0.1 (nn.CrossEntropyLoss(label_smoothing=0.1)).
# NOTE(review): no ignore_index is set here; confirm that padding positions are excluded
# from the loss inside Transformer.train / validate.
loss_fn = nn.CrossEntropyLoss()

In [58]:
# load model: resume from a checkpoint whose filename ends in "_e<epoch>.pth"
load_model = "old_model/base_100%_e11.pth"
# rsplit so that an earlier "_e" elsewhere in the path cannot break the epoch parse
epoch = int(load_model.split(".pth")[0].rsplit("_e", 1)[1])
# map_location restores the weights onto DEVICE regardless of where they were saved.
# torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load(load_model, map_location=DEVICE))

# Optimizer steps already taken. 4508785 is the size of the full WMT14 de-en training
# split (the checkpoint name says 100%), and 64 is the batch size used by the dataloaders.
# NOTE(review): this ignores examples dropped by the max_len filter, so it may slightly
# over-estimate the true step count.
TRAIN_EXAMPLES = 4508785
BATCH_SIZE = 64
num_steps_trained = int(TRAIN_EXAMPLES / BATCH_SIZE * epoch)

# LambdaLR sets lr = base_lr * lambda(last_epoch) on every step (not cumulatively), so
# jumping last_epoch and stepping once yields the same lr as calling step()
# num_steps_trained times, without hundreds of thousands of Python-level calls.
scheduler.last_epoch = num_steps_trained - 1
scheduler.step()
print(f"Starting from step {num_steps_trained} with learning rate {scheduler.get_last_lr()[0]:f}")

Starting from step 774947 with learning rate 0.000050


In [59]:
# release unused GPU memory, then report the CUDA allocator state
# (free_memory presumably also accepts names of globals to delete first,
#  e.g. free_memory("model", "transformer") — check utils.py)
free_memory()
memory_report = torch.cuda.memory_summary()
print(memory_report)

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1798 MiB |  10366 MiB |  92957 GiB |  92956 GiB |
|       from large pool |   1435 MiB |   9854 MiB |  91883 GiB |  91882 GiB |
|       from small pool |    362 MiB |    727 MiB |   1074 GiB |   1073 GiB |
|---------------------------------------------------------------------------|
| Active memory         |   1798 MiB |  10366 MiB |  92957 GiB |  92956 GiB |
|       from large pool |   1435 MiB |   9854 MiB |  91883 GiB |  91882 GiB |
|       from small pool |    362 MiB |    727 MiB |   1074 GiB |   1073 GiB |
|---------------------------------------------------------------

In [60]:
# create the transformer wrapper
# bundles the model with the tokenizer; exposes the train / validate / save / predict /
# translate / evaluate_bleu helpers used in the cells below
transformer = Transformer(model, tokenizer)

In [None]:
# continue training on the loaded model for NUM_EPOCHS more epochs.
# Checkpoint numbering continues from `epoch` (parsed from the loaded checkpoint name), so a
# model loaded from "_e11" next saves "_e12"; without a loaded checkpoint it starts at 0.
# Hard-coding the start (e.g. range(10, 20)) mislabels every checkpoint once a different
# file is loaded, which also corrupts the step count recomputed on the next resume.
# To train from scratch, first save the untrained weights: transformer.save("base_100%_e00.pth")
NUM_EPOCHS = 10
start_epoch = globals().get("epoch", 0)
for i in range(start_epoch, start_epoch + NUM_EPOCHS):
    transformer.train(dataloader["train"], loss_fn, optimizer, scheduler, log_file="train.log")
    # i + 1 is the number of completed epochs after this pass
    transformer.save(f"base_100%_e{i+1:02d}.pth")

In [54]:
# evaluate on the validation split; prints accuracy and average loss (see output)
transformer.validate(dataloader["validation"], loss_fn)

Validation: 
Accuracy: 48.0%, Avg loss: 3.466424


### Analyze

Initialization

In [10]:
# freshly initialised (untrained) model, used to inspect PyTorch's default initialisation;
# the vocabulary size comes from the tokenizer instead of being hard-coded
module = TransformerModel(vocab_size=tokenizer.get_vocab_size())
analyze_params(module)

Total number of parameters: 63082496
[32membedding.weight[39m
	(37000, 512)         torch.float32	param =   -0.0003240 +/-   0.9999905	grad = None
[32mencoder.layers.0.multi_head_attention.q_linear.weight[39m
	(512, 512)           torch.float32	param =   -0.0000179 +/-   0.0254906	grad = None
[32mencoder.layers.0.multi_head_attention.q_linear.bias[39m
	(512,)               torch.float32	param =    0.0004377 +/-   0.0257921	grad = None
[32mencoder.layers.0.multi_head_attention.k_linear.weight[39m
	(512, 512)           torch.float32	param =    0.0000482 +/-   0.0254886	grad = None
[32mencoder.layers.0.multi_head_attention.k_linear.bias[39m
	(512,)               torch.float32	param =   -0.0018594 +/-   0.0252353	grad = None
[32mencoder.layers.0.multi_head_attention.v_linear.weight[39m
	(512, 512)           torch.float32	param =    0.0000349 +/-   0.0254972	grad = None
[32mencoder.layers.0.multi_head_attention.v_linear.bias[39m
	(512,)               torch.float32	param =   -0

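We observe that PyTorch initializes its layers as follows (mean $\pm$ standard deviation):

-   Embedding: $0\pm 1$ (standard normal)

-   Linear (weights and biases): $0\pm 1 / \sqrt{3 d_{\rm in}}$, i.e. uniform on $[-1/\sqrt{d_{\rm in}},\ 1/\sqrt{d_{\rm in}}]$ ($\approx 0.0255$ for $d_{\rm in}=512$)

-   LayerNorm: $\gamma = 1,\ \beta = 0$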

Parameters After Training (no gradients are stored at this point, so `grad` is `None` in the output below)

In [11]:
# statistics of the trained model's parameters, for comparison with the initialisation above
analyze_params(model)

Total number of parameters: 63082496
[32membedding.weight[39m
	(37000, 512)         torch.float32	param =   -0.0002504 +/-   1.0052953	grad = None
[32mencoder.layers.0.multi_head_attention.q_linear.weight[39m
	(512, 512)           torch.float32	param =   -0.0000068 +/-   0.0819763	grad = None
[32mencoder.layers.0.multi_head_attention.q_linear.bias[39m
	(512,)               torch.float32	param =   -0.0054485 +/-   0.0992457	grad = None
[32mencoder.layers.0.multi_head_attention.k_linear.weight[39m
	(512, 512)           torch.float32	param =   -0.0000100 +/-   0.0823582	grad = None
[32mencoder.layers.0.multi_head_attention.k_linear.bias[39m
	(512,)               torch.float32	param =   -0.0004820 +/-   0.0260727	grad = None
[32mencoder.layers.0.multi_head_attention.v_linear.weight[39m
	(512, 512)           torch.float32	param =   -0.0000113 +/-   0.0239310	grad = None
[32mencoder.layers.0.multi_head_attention.v_linear.bias[39m
	(512,)               torch.float32	param =   -0

Parameter Shift over Training (weights after epoch 3 vs. the initial weights saved at epoch 0)

In [12]:
def load_checkpoint(path, d_model=512):
    """Build a TransformerModel on the CPU and load the state dict stored at ``path``.

    ``map_location="cpu"`` lets checkpoints that may have been saved from a CUDA model be
    inspected without a GPU. ``torch.load`` unpickles the file, so only load trusted files.
    """
    checkpoint_model = TransformerModel(vocab_size=tokenizer.get_vocab_size(), d_model=d_model)
    checkpoint_model.load_state_dict(torch.load(path, map_location="cpu"))
    return checkpoint_model


# module1: weights after 3 epochs; module2: initial weights saved at epoch 0
module1 = load_checkpoint("base_100%_e03.pth")
module2 = load_checkpoint("base_100%_e00.pth")
compare_params(module1, module2)

[32membedding.weight[39m
(37000, 512)        	param1 =   -0.0002504 +/-   1.0052953	param2 =   -0.0002411 +/-   0.9999496	diff(rms) =   0.1056992	diff(max) =   0.8115359
[32mencoder.layers.0.multi_head_attention.q_linear.weight[39m
(512, 512)          	param1 =   -0.0000068 +/-   0.0819763	param2 =    0.0000490 +/-   0.0255153	diff(rms) =   0.0778650	diff(max) =   0.3792256
[32mencoder.layers.0.multi_head_attention.q_linear.bias[39m
(512,)              	param1 =   -0.0054485 +/-   0.0992457	param2 =   -0.0018342 +/-   0.0252954	diff(rms) =   0.0979231	diff(max) =   0.2868303
[32mencoder.layers.0.multi_head_attention.k_linear.weight[39m
(512, 512)          	param1 =   -0.0000100 +/-   0.0823582	param2 =   -0.0000158 +/-   0.0255332	diff(rms) =   0.0782432	diff(max) =   0.3995680
[32mencoder.layers.0.multi_head_attention.k_linear.bias[39m
(512,)              	param1 =   -0.0004820 +/-   0.0260727	param2 =   -0.0005781 +/-   0.0258791	diff(rms) =   0.0031696	diff(max) =   0.0114

### Evaluate

In [27]:
# show token-by-token predictions for one random test sentence pair.
# Index the row directly rather than materialising the whole "translation" column
# (the original did that twice just to read one element).
test_split = dataset.dataset["test"]
rand_idx = int(np.random.randint(len(test_split)))
sample = test_split[rand_idx]["translation"]
transformer.predict(sample["de"], sample["en"])

Accuracy: 66.7%
[31mIn[39m
This [32mis[39m
This is [31mthe[39m
This is with [32mregard[39m
This is with regard [32mto[39m
This is with regard to [32mthe[39m
This is with regard to the [32mquality[39m
This is with regard to the quality [32mof[39m
This is with regard to the quality of [32mthe[39m
This is with regard to the quality of the [32mproducts[39m
This is with regard to the quality of the products [31m,[39m
This is with regard to the quality of the products that [32mare[39m
This is with regard to the quality of the products that are [32moffered[39m
This is with regard to the quality of the products that are offered [32mhere[39m
This is with regard to the quality of the products that are offered here [32m,[39m
This is with regard to the quality of the products that are offered here , [32mas[39m
This is with regard to the quality of the products that are offered here , as [32mwell[39m
This is with regard to the quality of the products that are offer

In [28]:
# stress test: translate a long, multi-sentence German paragraph in a single call
# NOTE(review): with realtime=True the translation appears to be streamed to stdout already,
# so printing the returned text shows it a second time (the output contains it twice).
long_paragraph = """Während die einzelnen Sprachen und Dialekte der germanischen Völker eigene Namen trugen – „Fränkisch“, „Gotisch“ usw. –, gab es daneben für den Gegensatz zwischen Latein und Volkssprache das Wort *þeudisk, das aber vom Anfang (786) bis ins Jahr 1000 nur in der mittellateinischen Form theodiscus überliefert wurde. Der Ursprung dieses Wortes liegt, wie Ähnlichkeiten in der Lautform zeigen, mit großer Wahrscheinlichkeit im westfränkischen (bzw. altniederländischen) Gebiet des Fränkischen Reichs.[3] Die Franken nannten ihre Sprache anfangs „frenkisk“ und die romanischen Sprachen wurden gemeinsam als *walhisk bezeichnet. Als aber im Verlauf des Frühmittelalters im zweisprachigen Westfrankenreich der politische und der sprachliche Begriff „fränkisch“ sich nicht mehr deckten, weil auch die romanischsprachige Bevölkerung sich als „fränkisch“ (vgl. neufranzösisch: français) bezeichnete, setzte sich hier das Wort *þeudisk für den sprachlichen Gegensatz zu *walhisk durch und fand ein Bedeutungswandel statt, wobei die Bedeutung sich von „Volkssprache“ in „germanisch statt romanisch“ änderte. Da im ostfränkischen Reich (dem späteren Deutschland) kein Anlass zu einem Bezeichnungswandel bestand, stellte sich dieser hier erst später ein, vielleicht nach westfränkischem Vorbild. Ganz allmählich wandelte sich damit bei theodiscus/*þeudisk die Bedeutung von „volkssprachlich“ über „germanisch“ und, viele Jahrhunderte später, letztendlich zu „Deutsch“."""
long_translation = transformer.translate(long_paragraph, realtime=True)
print(long_translation)

The word “ Fr isc us ” – the word “ Fr än k ” – is in the form of the word “ Fr än ” – was first used in the Latin - Dutch language , and the Latin word “ Fr än ” – was only used in the form of the Latin word “ Fr än k ” ( 86 3 ) and the Latin word “ Fr än ” ( 86 3 ) in the form of the Latin word ) between the beginning of the beginning of the beginning of the year and the beginning of the year of the year , the year of the year , the year of the year , the year of the year , the 
The word “ Fr isc us ” – the word “ Fr än k ” – is in the form of the word “ Fr än ” – was first used in the Latin - Dutch language , and the Latin word “ Fr än ” – was only used in the form of the Latin word “ Fr än k ” ( 86 3 ) and the Latin word “ Fr än ” ( 86 3 ) in the form of the Latin word ) between the beginning of the beginning of the beginning of the year and the beginning of the year of the year , the year of the year , the year of the year , the year of the year , the


In [29]:
# show a few random test-set translations next to their references.
# The split is hoisted out of the loop, and rows are indexed directly instead of
# materialising the whole "translation" column on every iteration.
test_split = dataset.dataset["test"]
for i in range(3):
    sample = test_split[int(np.random.randint(len(test_split)))]["translation"]
    print(f"#{i+1}")
    print(f"Source: {sample['de']}")
    print(f"Target: {sample['en']}")
    print(f"Prediction: {transformer.translate(sample['de'])}")
    print()

#1
Source: Das sind etwa neun Milliarden Maiskörner.
Target: That's about 9 billion individual kernels of corn.
Prediction: These are approximately nine billion Ma isk ör ner .

#2
Source: Feuerwehr zur Rettung eines Hündchens gerufen, das 15 Meter über dem Boden auf einem gefährlichen Felsvorsprung in einem Steinbruch festsaß
Target: Fire crews called to rescue lost puppy after she got stuck 50ft above the ground on precarious ledge in a quarry
Prediction: The fire of a stone on the ground , which was founded in 15 meters above the rock , is a dangerous fire in front of a stone ' s throw ing in front of a stone .

#3
Source: Die Untersuchungsbeamten des Sheriffs von Lowndes County kamen zu dem Schluss, dass Johnson bei einem tragischen Unfall starb, was die Familie des 17-Jährigen jedoch anzweifelt.
Target: Lowndes County sheriff's investigators concluded Johnson died in a freak accident, but the 17-year-old's family disputes that.
Prediction: The family of the tragic accident of the 

## DEBUG

In [8]:
# dataset corpus length analysis: for each side, the share of all training tokens that belong
# to sentences of at most LEN_THRESHOLD tokens (token-weighted, not a count of sentences)
LEN_THRESHOLD = 64
for name in ["src_len", "tgt_len"]:
    lengths = np.asarray(dataset.dataset["train"][name])
    tot = int(lengths.sum())
    count = int(lengths[lengths <= LEN_THRESHOLD].sum())
    percentage = count / tot * 100 if tot else 0.0  # guard against an empty split
    # prefix with the column name so the src and tgt lines can be told apart
    print(f"{name}: count: {count}, tot: {tot}, percentage: {percentage:.2f}%")

count: 123676808, tot: 141799856, percentage: 87.22%
count: 122343081, tot: 139480626, percentage: 87.71%


In [13]:
# Total number of parameters (kept in `total` for the distribution cell below)
total = sum(par.numel() for par in model.parameters())
print(total)

63082496


In [14]:
# parameter distributions over the model: each tensor's share of `total`, in parameters() order
for tensor in model.parameters():
    share = 100 * tensor.numel() / total
    print(f"{share:.2f}% {tensor.shape}")

30.03% torch.Size([37000, 512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.00% torch.Size([512])
1.66% torch.Size([2048, 512])
0.00% torch.Size([2048])
1.66% torch.Size([512, 2048])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.00% torch.Size([512])
1.66% torch.Size([2048, 512])
0.00% torch.Size([2048])
1.66% torch.Size([512, 2048])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])
0.00% torch.Size([512])
0.42% torch.Size([512, 512])

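### BLEU Score

The scorer warns that its inputs are still tokenized (BPE pieces are separated by spaces, e.g. `mar kete ers`). The score below is therefore not directly comparable to the detokenized BLEU reported in the paper.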

In [40]:
# corpus BLEU on the validation split; also returns the reference and system sentences
# that are inspected in the next cell
# NOTE(review): `sys` shadows the standard-library module name in the notebook namespace.
# NOTE(review): the scorer warns that the text is still tokenized, so this BLEU is not
# directly comparable with the paper's detokenized scores.
result, ref, sys = transformer.evaluate_bleu(dataloader["validation"])
print(result)

  0%|          | 0/47 [00:00<?, ?it/s]

100%|██████████| 47/47 [01:50<00:00,  2.35s/it]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.


BLEU = 12.15 50.6/18.2/7.6/3.4 (BP = 0.977 ratio = 0.977 hyp_len = 72510 ref_len = 74181)


In [179]:
# check the reference sentences and the predicted sentences:
# for each of the first five examples, print the reference line, then the model output
for idx in range(5):
    for sentences in (ref, sys):
        print(idx, sentences[idx])

0 The free mar kete ers at the Re ason Foundation are also fond of having drivers pay per mile .
0 Also the idea of the free road to Re ason Foundation is to return to Re mark ter Foundation .
1 There were large quantities of wood and bal es of stra w stored inside .
1 It also made a lot of timber and all the timber .
2 " We need to have a better system ," he said .
2 “ We need a better system .”
3 The film never sli ps into pr ur ience or sens ational ism - and that ' s the problem .
3 The problem is its problem – never its film is in the way of the film and the sit t ings .
4 As ked if he would return to the post of prime minister , Mr Blair was quoted by London ' s Even ing Standard as saying : " Yes , sure , but it ' s not likely to happen is it , so ..."
4 The question is whether it is unlikely that the Prime Minister of London would return from the words of Prime Minister Blair , that is , but that is the standard of the “ standard ” that would return from London ...
