# Simple implementation of Transformer, GPT and BERT architectures

This is an implementation of the original transformer paper [Attention is All You Need](https://arxiv.org/abs/1706.03762) from scratch.
All section are provided with references to chapters of the paper where corresponding logic is described

Additionally, **GPT** architecture is implemented for comparison.

## Links

* [2017 Attention is All You Need](https://arxiv.org/abs/1706.03762)
* [2018 Improving Language Understanding by Generative Pre-Training](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import glob
import re

import torch
import numpy as np
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import seaborn as sns


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

EPOCHS = 1
MAX_STEPS = 1500
BATCH_SIZE = 32

POS_EMB_MAX_PERIOD = 10000
BASE_MAX_SEQ_LEN = 256
VOCAB_SIZE = 40000
DATA_DIR = 'data/'

DATASET = "wmt14"
LANGUAGE_PAIR = 'ru-en'


RESUME_TRAIN = False
TRAIN_TRANSFORMER = True
TRAIN_GPT = True

try:
    # to run in colab link your Google Drive to the colab node and set up the cache directory.
    from google import colab
    DATA_DIR = '/content/drive/MyDrive/ML/transformer'
except:
    pass



In [None]:
!pip install -q datasets
!pip install -q sentencepiece
!pip install -q wandb

In [None]:
import wandb as wandb

# Dataset

### Parsing WMT data with Sentence Piece

Sentence Piece is a Google's implementation of BPE (byte-pair encoding) tokenizer.

![bpe](transformer/bpe.png)

In [None]:
from datasets import load_dataset

dataset = load_dataset(DATASET, LANGUAGE_PAIR, cache_dir=f'{DATA_DIR}/{DATASET}')

In [None]:
dataset

In [None]:
import os


if not os.path.exists(f'{DATA_DIR}/{DATASET}_train.txt'):
    !mkdir -p data
    with open(f'{DATA_DIR}/{DATASET}_train.txt', 'w+') as f:
        for step, data in enumerate(tqdm(dataset['train'], miniters=10000)):
            f.write((data['translation']['en'] + ' ' + data['translation']['ru'] + "\n").lower())

In [None]:
import sentencepiece as spm
if not os.path.exists(f'{DATA_DIR}/{DATASET}.model'):
    spm.SentencePieceTrainer.train(input=f'{DATA_DIR}/{DATASET}_train.txt', model_prefix=f'{DATA_DIR}/{DATASET}', vocab_size=VOCAB_SIZE)

In [None]:
sentence_piece = spm.SentencePieceProcessor(model_file=f'{DATA_DIR}/{DATASET}.model')

In [None]:
codes = sentence_piece.encode('this is a test', out_type=str, add_bos=True, add_eos=True)
codes

In [None]:
codes = sentence_piece.encode('this is a test', add_bos=True, add_eos=True)
codes

In [None]:
sentence_piece.decode(codes)

## Dataset class

Generate tokens along with attention and padding masks for encoding, decoding and target sequences.

In [None]:
import pickle
from torch.utils.data import Dataset, DataLoader


class TransformerDataset(Dataset):
    """
    Generates pairs of (encoded_input_sentence, encoded_translation)
    """

    def __init__(self, wmt_data, sentence_piece, cache_file):
        super().__init__()
        self.wmt_data, self.sentence_piece, self.cache_file = wmt_data, sentence_piece, cache_file
        self.data = None


    def preprocess_wmt(self, wmt_data, sp, fname):
        if not os.path.exists(self.cache_file):
            data = []
            for r in tqdm(wmt_data, miniters=1000):
                en_code = sp.encode(r['translation']['en'].lower().strip())
                ru_code = sp.encode(r['translation']['ru'].lower().strip())
                data.append([
                    en_code,
                    ru_code,
                    max(len(en_code), len(ru_code))
                ])
            self.data = sorted(data, key=lambda d: d[2])
            with open(fname, 'w+b') as f:
                pickle.dump(self.data, f)
        else:
            with open(self.cache_file, 'rb') as f:
                self.data = pickle.load(f)

    def __getitem__(self, i):
        if self.data is None:
            self.preprocess_wmt(self.wmt_data, self.sentence_piece, self.cache_file)
        return self.data[i][0], self.data[i][1]

    def __len__(self):
        if self.data is None:
            self.preprocess_wmt(self.wmt_data, self.sentence_piece, self.cache_file)
        return len(self.data)


In [None]:
ds_train = TransformerDataset(dataset['train'], sentence_piece, f'{DATA_DIR}/ds_wmt_train.bin')
ds_validation = TransformerDataset(dataset['validation'], sentence_piece, f'{DATA_DIR}/ds_wmt_validation.bin')
ds_test = TransformerDataset(dataset['test'], sentence_piece, f'{DATA_DIR}/ds_wmt_test.bin')

In [None]:
def wmt_collate(input, max_seq_len, bos, eos):
    encoder_input = [i[0] for i in input]
    decoder_input = [i[1] for i in input]
    # bos, eos
    encoder_input = [[bos] + i[:max_seq_len - 2] + [eos] for i in encoder_input]
    decoder_input = [[bos] + i[:max_seq_len - 1] for i in decoder_input]
    # masks
    dec_batch_seq_len = np.max([len(i1) for i1 in decoder_input])
    dec_mask = torch.triu(torch.ones((dec_batch_seq_len, dec_batch_seq_len)) * float('-inf'), diagonal=1)
    enc_padding_mask = torch.nn.utils.rnn.pad_sequence([torch.zeros(len(enc)) for enc in encoder_input], batch_first=True,
                                            padding_value=float('-inf'))
    dec_padding_mask = torch.nn.utils.rnn.pad_sequence([torch.zeros(len(dec)) for dec in decoder_input], batch_first=True,
                                                   padding_value=float('-inf'))
    # pad input sequences
    target = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(d[1:] + [eos]) for d in decoder_input],
        batch_first=True,
        padding_value=0
    )
    encoder_input = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(t) for t in encoder_input],
        batch_first=True,
        padding_value=0
    )
    decoder_input = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(d) for d in decoder_input],
        batch_first=True,
        padding_value=0
    )
    return (encoder_input, decoder_input), target, (enc_padding_mask, dec_padding_mask, dec_mask)


dl_train = DataLoader(ds_train, shuffle=True,
                      batch_size=BATCH_SIZE,
                      collate_fn=lambda d: wmt_collate(d, max_seq_len=BASE_MAX_SEQ_LEN,
                                                       bos=sentence_piece.bos_id(),
                                                       eos=sentence_piece.eos_id())
                      )


generator = torch.Generator()
generator.manual_seed(24)

dl_validate = DataLoader(
    ds_validation,
    shuffle=True,
    generator=generator,
    batch_size=BATCH_SIZE,
    collate_fn=lambda d: wmt_collate(d,
                                     max_seq_len=BASE_MAX_SEQ_LEN,
                                     bos=sentence_piece.bos_id(),
                                     eos=sentence_piece.eos_id())
)

for (enc, dec), target, (enc_padding_mask, dec_padding_mask, dec_mask) in dl_validate:
    print('inputs')
    print(enc)
    print(dec)
    print('target')
    print(target)
    print('masks')
    print(enc_padding_mask)
    print(dec_padding_mask)
    print(dec_mask)

    plt.figure()
    ax = sns.heatmap(torch.broadcast_to(enc_padding_mask[:1], (68, 68)), cmap='Reds')
    ax.set_title('Encoder padding mask example')

    plt.figure()
    ax = sns.heatmap(torch.broadcast_to(dec_padding_mask[:1], (68, 68)), cmap='Reds')
    ax.set_title('Decoder padding mask example')


    plt.figure()
    ax = sns.heatmap(dec_mask, cmap='Reds')
    ax.set_title('Attention mask example')
    break

In [None]:
batches = tqdm(dl_validate, miniters=100)
for i, ((enc, dec), target, (enc_padding_mask, dec_padding_mask, dec_mask)) in enumerate(batches):
    batches.set_description(f'X:{len(enc)}')
    if i > 1000:
        break

# Transformer (simplementation of "Attention is All You Need" paper)

Transformer architecture implemented in this section

![transformer](transformer/transformer.png)

## Positional Encodings

![pos-enc](transformer/positional_encoding.png)

In [None]:

from matplotlib import pyplot as plt
import pandas as pd


def positional_encoding(positions, dmodel):
    embeddings = []
    for i in range(dmodel // 2):
        embeddings.append(torch.sin(torch.tensor(range(positions)) / (POS_EMB_MAX_PERIOD ** (2 * i / dmodel))))
        embeddings.append(torch.cos(torch.tensor(range(positions)) / (POS_EMB_MAX_PERIOD ** (2 * i / dmodel))))
    return torch.stack(embeddings).transpose(1, 0)


pd.DataFrame(positional_encoding(64, 4).numpy()).plot()

In [None]:
import seaborn as sns
plt.figure(figsize=(20, 20))
ax = sns.heatmap(positional_encoding(128, 128))
ax.set_aspect(1.)

## Embeddings

![emb](transformer/emb.png)

In [None]:
import torch

class PEEmbedding(torch.nn.Module):

    pos_encoding: torch.Tensor

    def __init__(self, max_tokens, dmodel, max_seq_len, dropout_rate, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.emb = torch.nn.Embedding(max_tokens, dmodel)
        self.pos_encoding = torch.nn.parameter.Parameter(positional_encoding(max_seq_len, dmodel).unsqueeze(0), requires_grad=False)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self, tokens):  # B * Seq
        embs = self.emb(tokens)  # B * Seq * dmodel
        pe = self.pos_encoding[:, :tokens.shape[-1], :]  # 1 * Seq * dmodel
        if self.verbose:
            print('PEEmbeding::forward', 'embs:', embs.shape, 'pe:', pe.shape)
        return self.dropout(embs  + pe)


emb = PEEmbedding(10, 20, 20, 0.1, verbose=True)
emb.eval()

emb(torch.tensor([[2, 1, 0, 3, 0]]))

In [None]:
sns.heatmap(emb(torch.tensor([[0, 2, 1, 3, 4, 2, 1, 0, 3, 0, 2, 1, 0, 3, 0, 2, 1, 0, 3, 0]])).detach().numpy()[0])

## Attention

Scaled dot-product attention:
![abc](transformer/scaled-attention.png)
![abc](transformer/scaled-attention-desc.png)


In [None]:
# test: q * k == attention

torch.einsum('bSd,bsd->bSs',
             # Q
             torch.tensor([[
                 [1, 1],
                 [0, -1],
             ]]),
             # K
             torch.tensor([[
                 [1, 0],
                 [0, 1],
                 [-1, 0],
             ]]))

In [None]:
# test: multiplying attention matrix on value matirx
torch.einsum(
    'bSs,bsd->bSd',
    torch.tensor([[
        [0, 0, 1, 1],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [1, 1, 0, 0]
    ], [
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 0, 0]
    ]]),
    torch.tensor([[
        [1, 1],
        [2, 2],
        [3, 3],
        [4, 4]
    ]] * 2)
)

In [None]:
v = torch.tensor([[[i] * 2 for i in range(4)]], dtype=torch.float)
v

In [None]:
k = torch.tensor([[[1. if i in [2] else -1e6] * 2 for i in range(4)]] * 1)
k

In [None]:
q = torch.ones((1, 3, 2))
q

In [None]:
torch.ones(2, 3).unsqueeze(1)

In [None]:
import math


class Attention(torch.nn.Module):

    def __init__(self, dmodel, verbose=False):
        super().__init__()
        self.dmodel = dmodel
        self.verbose = verbose

    def forward(self, q, k, v, k_padding_mask=None, attention_mask=None):
        if self.verbose:
            print('q, k, v')
            print(q.shape, k.shape, v.shape)
        attention = torch.einsum('bSd,bsd->bSs', q, k)
        scaled_attention = attention / math.sqrt(self.dmodel)  # B * Q_seq * K_seq
        if k_padding_mask is not None:
            scaled_attention += k_padding_mask.unsqueeze(1) # B * 1 * K_seq

        if attention_mask is not None:    # Q_seq * K_seq
            scaled_attention += attention_mask.unsqueeze(0)  # 1 * Q_seq * K_seq

        if self.verbose:
            print('Scaled attention')
            print(scaled_attention)
        softmaxed_attention = torch.softmax(scaled_attention, 2)

        if self.verbose:
            print('softmax attention')
            print(softmaxed_attention)
        return torch.einsum('bSs,bsd->bSd', softmaxed_attention, v)


sa = Attention(2, verbose=True)
sa(q, k, v, k_padding_mask=torch.zeros(1, 4), attention_mask=torch.zeros(3, 4))

## MultiHeadSelfAttention

![multihead](transformer/multihead.png)


In [None]:
class MultiHeadSelfAttention(torch.nn.Module):
    def __init__(self, heads, dmodel, verbose=False):
        super().__init__()
        self.heads = heads

        # no trainable parameters, can get away with a single instance
        self.self_attention = Attention(dmodel // heads, verbose=verbose)

        self.w_q = torch.nn.ModuleList([
            torch.nn.Linear(dmodel, dmodel // heads, bias=False)
            for _ in range(heads)
        ])
        self.w_k = torch.nn.ModuleList([
            torch.nn.Linear(dmodel, dmodel // heads, bias=False)
            for _ in range(heads)
        ])
        self.w_v = torch.nn.ModuleList([
            torch.nn.Linear(dmodel, dmodel // heads, bias=False)
            for _ in range(heads)
        ])
        self.w_o = torch.nn.Linear(dmodel, dmodel, bias=False)

    def forward(self, v, k, q, k_padding_mask=None, attention_mask=None):
        v_out = []
        for head in range(self.heads):
            v_out.append(self.self_attention(self.w_q[head](q), self.w_k[head](k), self.w_v[head](v), k_padding_mask, attention_mask))
        cat = torch.concat(v_out, dim=2)
        o = self.w_o(cat)
        return o


q = torch.ones((2, 3, 2))
k = torch.tensor([[[1. if i in [2] else -1e6] * 2 for i in range(4)]] * 2)
v = torch.tensor([[[i] * 2 for i in range(4)]] * 2, dtype=torch.float)

print(q.shape, k.shape, v.shape)
mhsa = MultiHeadSelfAttention(2, 2, False)
mhsa(v, k, q, k_padding_mask=torch.zeros(2, 4), attention_mask=torch.zeros(3, 4))

## Encoder transformer block

![encoder](transformer/encoder.png)

In [None]:
class EncoderTransformerBlock(torch.nn.Module):

    def __init__(self, dmodel, heads, inner_layer_dim, dropout_rate, verbose=False):
        super().__init__()
        self.attention = MultiHeadSelfAttention(heads, dmodel, verbose=verbose)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm1 = torch.nn.LayerNorm(dmodel)

        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(dmodel, inner_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(inner_layer_dim, dmodel)
        )
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm2 = torch.nn.LayerNorm(dmodel)

    def forward(self, x, padding_mask=None):
        x = self.layer_norm1(x + self.dropout1(self.attention(x, x, x, k_padding_mask=padding_mask)))
        x = self.layer_norm2(x + self.dropout2(self.ffn(x)))
        return x


etb = EncoderTransformerBlock(4, 2, 8, 0.1)
etb.eval()
etb(torch.ones(2, 8, 4), padding_mask=torch.zeros(2, 8))

## Decoder transformer block

![decoder](transformer/decoder.png)

In [None]:
class DecoderTransformerBlock(torch.nn.Module):

    def __init__(self, dmodel, heads, inner_layer_dim, dropout_rate, verbose=False):
        super().__init__()
        # sublayer 1, self attention
        self.attention1 = MultiHeadSelfAttention(heads, dmodel, verbose=verbose)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm1 = torch.nn.LayerNorm(dmodel)

        # sublayer 2, cross attention
        self.attention2 = MultiHeadSelfAttention(heads, dmodel, verbose=verbose)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm2 = torch.nn.LayerNorm(dmodel)

        # sublayer 3
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(dmodel, inner_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(inner_layer_dim, dmodel)
        )
        self.dropout3 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm3 = torch.nn.LayerNorm(dmodel)

    def forward(self, x_enc, x, enc_padding_mask, dec_padding_mask, dec_attention_mask):
        x = self.layer_norm1(x + self.dropout1(self.attention1(x, x, x, k_padding_mask=dec_padding_mask, attention_mask=dec_attention_mask)))
        x = self.layer_norm2(x + self.dropout2(self.attention2(x_enc, x_enc, x, k_padding_mask=enc_padding_mask)))
        x = self.layer_norm3(x + self.dropout3(self.ffn(x)))
        return x


dtb = DecoderTransformerBlock(4, 2, 8, 0.1)
dtb.eval()
dtb(
    torch.ones((1, 8, 4)),
    torch.ones((1, 8, 4)),
    torch.zeros(1, 8),
    torch.zeros(1, 8),
    torch.zeros(8, 8)
)

## Transformer

Main transformer model

![transformer](transformer/transformer.png)

In [None]:
# test: selecting best fitting token for a sequence of embeddings
b = torch.tensor([
    [1, 1],
    [-1, -1],
])
torch.einsum('bsd,td->bst', torch.tensor([[
    [1, 1],
    [0, 1],
    [-1, 0],
    [-1, -1],
]]), b).argmax(dim=2)

In [None]:
class Transformer(torch.nn.Module):

    def __init__(self, n_blocks, dmodel, heads, inner_layer_dim, dropout_rate, max_tokens, max_seq_len, bos=1, eos=2, verbose=False):
        super().__init__()
        self.bos = bos
        self.eos = eos
        self.verbose = verbose
        self.max_seq_len = max_seq_len
        self.emb = PEEmbedding(max_tokens, dmodel, max_seq_len, dropout_rate, verbose=verbose)
        self.encoder_blocks = torch.nn.ModuleList([
            EncoderTransformerBlock(dmodel, heads, inner_layer_dim, dropout_rate, verbose=verbose)
            for _ in range(n_blocks)
        ])
        self.decoder_blocks = torch.nn.ModuleList([
            DecoderTransformerBlock(dmodel, heads, inner_layer_dim, dropout_rate, verbose=verbose)
            for _ in range(n_blocks)
        ])



    def forward(self, encoder_input, decoder_input, encoder_padding_mask=None, decoder_padding_mask=None, dec_mask=None):
        # encoder stack
        encoder_input = self.emb(encoder_input)
        for b in self.encoder_blocks:
            encoder_input = b(encoder_input, encoder_padding_mask)
        encoder_output = encoder_input

        # decoder stack
        decoder_input = self.emb(decoder_input)
        for b in self.decoder_blocks:
            decoder_input = b(encoder_output, decoder_input, encoder_padding_mask, decoder_padding_mask, dec_mask)  # batch x seq x dmodel
        w = self.emb.emb.weight  # tokens*dmodel
        y = torch.einsum('bsd,td->bst', decoder_input, w)
        return y


test_transformer = Transformer(
    n_blocks=2,
    dmodel=8,
    heads=2,
    inner_layer_dim=16,
    dropout_rate=0.1,
    max_tokens=10,
    max_seq_len=7,
    verbose=True
)


test_transformer(torch.randint(0, 10, (2, 3)), torch.randint(0, 10, (2, 2)))

## Training loop (translation task)

In [None]:
def masked_loss_reduced(loss, pred, target, target_padding_mask):
    target_mask = torch.where(target_padding_mask == 0, 1, 0)
    losses = loss(pred.view(-1, pred.shape[-1]), target.view(-1)).view(target.shape) * target_mask
    return losses.sum() / target_mask.sum()


In [None]:
def get_lr(step_num, dmodel=512, warmup_steps=4000):
    return dmodel ** (-0.5) * min(step_num ** (-0.8), step_num * warmup_steps ** (-1.5))


# schedule described in the paper
def plot_scheduler():
    steps = np.arange(1, 100000)
    lr = [get_lr(dmodel=512, step_num=s, warmup_steps=4000) for s in steps]
    plt.plot(steps, lr)

plot_scheduler()


In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()
step = 0

# "base" model from "Attention is All You Need"
transformer = Transformer(
    n_blocks=6,
    dmodel=512,
    heads=8,
    inner_layer_dim=2048,
    dropout_rate=0.1,
    max_tokens=VOCAB_SIZE,
    max_seq_len=BASE_MAX_SEQ_LEN,  # not mentioned in the paper
    verbose=False,
)
transformer.to(DEVICE)
print(sum(p.numel() for p in transformer.parameters() if p.requires_grad))
transformer.train()

transformer

In [None]:
if RESUME_TRAIN:
    model_path = glob.glob(f'{DATA_DIR}/transformer_*.model')
    if len(model_path) > 0:
        step = int(re.findall('_(\d+).model', model_path[0])[0])
        transformer.load_state_dict(torch.load(model_path[0]))
        print('Resuming training from', model_path[0], 'step:', step)


In [None]:
optimizer = torch.optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-9)
loss = torch.nn.CrossEntropyLoss(reduction='none', label_smoothing=0.1)


In [None]:
if TRAIN_TRANSFORMER:
    wandb_run = wandb.init(project="transformer-simplementation")
    wandb_run.display()

    with wandb_run:
        samples = wandb.Table(columns=['prediction', 'decoder_input', 'prediction_codes', 'encoder_input'])
        for epoch in tqdm(range(EPOCHS), desc='Epoch'):
            steps = tqdm(dl_train)
            for (enc_input, dec_input), target, (enc_padding_mask, dec_padding_mask, dec_mask) in steps:
                if step > MAX_STEPS:
                    break
                new_lr = get_lr(step + 1)
                for pg in optimizer.param_groups:
                    pg['lr'] = new_lr

                pred = transformer(enc_input.to(DEVICE), dec_input.to(DEVICE), enc_padding_mask.to(DEVICE), dec_padding_mask.to(DEVICE), dec_mask.to(DEVICE))
                loss_value = masked_loss_reduced(loss, pred, target.to(DEVICE), dec_padding_mask.to(DEVICE))
                optimizer.zero_grad()
                loss_value.backward()

                optimizer.step()
                steps.set_description(f'loss:{loss_value.item():.02f} lr:{new_lr:.05f}')

                if step % 1000 == 0:
                    transformer.eval()
                    pred = transformer(enc_input.to(DEVICE), dec_input.to(DEVICE), enc_padding_mask.to(DEVICE), dec_padding_mask.to(DEVICE), dec_mask.to(DEVICE))
                    print('Step', step)
                    encoder_input = sentence_piece.decode(enc_input[0].detach().cpu().numpy().tolist())
                    decoder_input = sentence_piece.decode(dec_input[0].detach().cpu().numpy().tolist())
                    print(decoder_input)
                    prediction_codes = pred[0].argmax(dim=1).detach().cpu().numpy().tolist()
                    prediction = sentence_piece.decode(prediction_codes)
                    print(prediction)
                    samples.add_data(prediction, decoder_input, prediction_codes, encoder_input)
                    transformer.train()
                    !rm -f {DATA_DIR}/transformer_*.model
                    torch.save(transformer.state_dict(), f'{DATA_DIR}/transformer_{step}.model')

                wandb_run.log({'loss': loss_value.item(), 'lr': new_lr, 'samples': samples})

                step += 1


## Evaluation

In [None]:
transformer.eval()
loss_values = []
with torch.no_grad():
    for step, ((encoder_input, decoder_input), target, (enc_padding_mask, dec_padding_mask, dec_mask)) in enumerate(tqdm(dl_validate)):
        pred = transformer(
            encoder_input.to(DEVICE),
            decoder_input.to(DEVICE),
            enc_padding_mask.to(DEVICE),
            dec_padding_mask.to(DEVICE),
            dec_mask.to(DEVICE)
        )
        loss_value = masked_loss_reduced(loss, pred, target.to(DEVICE), dec_padding_mask.to(DEVICE))
        loss_values.append(loss_value.item())
        if step > MAX_STEPS:
            break

'Evaluation loss', np.mean(loss_values)

## Prediction

In [None]:
def predict(transformer: Transformer, input, sentence_piece, max_seq_len=32, verbose=False):
    # iterative prediction token by token
    enc_input = torch.tensor(sentence_piece.encode(input, add_bos=True, add_eos=True)).to(DEVICE)
    output = torch.tensor([[sentence_piece.bos_id()]]).to(DEVICE)
    for _ in tqdm(range(max_seq_len)):
        pred = transformer.forward(
            enc_input,
            output,
            dec_mask=torch.triu(torch.ones(output.shape[-1], output.shape[-1]) * float('-inf'), diagonal=1).to(DEVICE)
        )
        pred = torch.argmax(pred, dim=2)[:, -1:].detach()
        output = torch.cat([output, pred], dim=1).to(torch.int)
        if verbose:
            print(output.cpu())
        if output[0][-1].item() == sentence_piece.eos_id():
            break
    return sentence_piece.decode(output.cpu().numpy().tolist())

predict(transformer, 'translate this', sentence_piece, verbose=True)

In [None]:
import gc
del transformer
gc.collect()
torch.cuda.empty_cache()

# GPT (simplementation of "Improving Language Understanding by Generative Pre-Training")

GPT is a simplified Transformer model with only a decoder block and no cross-attention sublayer.

This is a simple implementation of GPT architecture which is very similar to transformer decoder.
Certain minor simplifications are applied to reduce code complexity:
- GELU -> ReLU
- No L2 regularization
- No learned positional embeddings
- SentencePiece instead of spaCy tokenizer


The following screenshots are from "Improving Language Understanding by Generative Pre-Training" paper:
![gpt](transformer/gpt.png)
![gpt description](transformer/gpt-desc.png)

## GPT blocks

In [None]:
class GPTBlock(torch.nn.Module):

    def __init__(self, dmodel, heads, inner_layer_dim, dropout_rate, verbose=False):
        super().__init__()
        # sublayer 1, self attention
        self.attention1 = MultiHeadSelfAttention(heads, dmodel, verbose=verbose)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm1 = torch.nn.LayerNorm(dmodel)

        # sublayer 3
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(dmodel, inner_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(inner_layer_dim, dmodel)
        )
        self.dropout3 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm3 = torch.nn.LayerNorm(dmodel)

    def forward(self, x, padding_mask=None, attention_mask=None):
        x = self.layer_norm1(x + self.dropout1(self.attention1(x, x, x, k_padding_mask=padding_mask, attention_mask=attention_mask)))
        x = self.layer_norm3(x + self.dropout3(self.ffn(x)))
        return x


dtb = GPTBlock(4, 2, 8, 0.1)
dtb.eval()
dtb(
    torch.ones((1, 8, 4)),
    torch.zeros(1, 8),
    torch.zeros(8, 8)
)

In [None]:
class GPT(torch.nn.Module):
    def __init__(self, nblocks, max_tokens, dmodel, max_seq_len, dropout_rate, heads, inner_layer_dim, verbose=False):
        super().__init__()
        self.emb = PEEmbedding(max_tokens, dmodel, max_seq_len, dropout_rate, verbose)
        self.blocks = torch.nn.ModuleList(
            [GPTBlock(dmodel, heads, inner_layer_dim, dropout_rate, verbose)
             for _ in range(nblocks)]
        )
        self.verbose = verbose

    def forward(self, x, padding_mask=None, att_mask=None):
        if self.verbose:
            print('GPT::forward att_mask=', att_mask)
        x = self.emb(x)
        for b in self.blocks:
            x = b(x, attention_mask=att_mask, padding_mask=padding_mask)
        y = torch.einsum('bsd,td->bst', x, self.emb.emb.weight)
        return y

gpt = GPT(2, 10, 8, 10, 0.1, 2, 20, True)
gpt(torch.randint(0, 5, (1, 2)))

## Training loop

In [None]:
# Configuration described in GPT paper
gpt = GPT(
    nblocks=12,
    max_tokens=VOCAB_SIZE,
    dmodel=768,
    max_seq_len=BASE_MAX_SEQ_LEN,
    dropout_rate=0.1,
    heads=12,
    inner_layer_dim=3072,
    verbose=False
)
gpt.to(DEVICE)
gpt.train()
step = 0

optimizer = torch.optim.Adam(gpt.parameters(), betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2.5e-4, pct_start=0.1, epochs=EPOCHS, steps_per_epoch=len(dl_train))
loss = torch.nn.CrossEntropyLoss(reduction='none')

In [None]:
if TRAIN_GPT:
    wrun = wandb.init(project='gpt-simplementation', reinit=True)
    wrun.display()
    samples = []
    with wrun:
        for epoch in tqdm(range(EPOCHS), desc='Epoch'):
            batches = tqdm(dl_train)
            for (_, dec), target, (_, dec_mask, att_mask) in batches:
                if step > MAX_STEPS:
                    break
                optimizer.zero_grad()
                pred = gpt(dec.to(DEVICE), padding_mask=dec_mask.to(DEVICE), att_mask=att_mask.to(DEVICE))
                loss_value = masked_loss_reduced(loss, pred, target.to(DEVICE), dec_mask.to(DEVICE))
                loss_value.backward()
                optimizer.step()
                scheduler.step()
                batches.set_description(f'Step:{step} Loss:{loss_value.item():.02f}')
                if step % 1000 == 0:
                    gpt.eval()
                    pred = gpt(dec.to(DEVICE), padding_mask=dec_mask.to(DEVICE))
                    dec1 = dec.numpy().tolist()[0]
                    dec2 = pred.detach().argmax(dim=2).cpu().numpy().tolist()[0]
                    dec1 = sentence_piece.decode(dec1)
                    dec2 = sentence_piece.decode(dec2)
                    print('Step', step)
                    print(dec1)
                    print(dec2)
                    samples.append([step, dec1, dec2])
                    !rm -f {DATA_DIR} / gpt_ *.model
                    torch.save(gpt.state_dict(), f'{DATA_DIR}/gpt_{step}.model')
                    gpt.train()
                step += 1
                wrun.log({'learning_rate': float(scheduler.get_last_lr()[0]), 'loss': loss_value.item(),
                          'samples': wandb.Table(columns=['step', 'gpt_target', 'gpt_prediction'], data=samples)})

In [80]:
if TRAIN_GPT:
    wrun = wandb.init(project='gpt-simplementation', reinit=True)
    wrun.display()
    samples = []
    with wrun:
        for epoch in tqdm(range(EPOCHS), desc='Epoch'):
            batches = tqdm(dl_train)
            for (_, dec), target, (_, dec_mask, att_mask) in batches:
                if step > MAX_STEPS:
                    break
                optimizer.zero_grad()
                pred = gpt(dec.to(DEVICE), padding_mask=dec_mask.to(DEVICE), att_mask=att_mask.to(DEVICE))
                loss_value = masked_loss_reduced(loss, pred, target.to(DEVICE), dec_mask.to(DEVICE))
                loss_value.backward()
                optimizer.step()
                scheduler.step()
                batches.set_description(f'Step:{step} Loss:{loss_value.item():.02f}')
                if step % 1000 == 0:
                    gpt.eval()
                    pred = gpt(dec.to(DEVICE), padding_mask=dec_mask.to(DEVICE))
                    dec1 = dec.numpy().tolist()[0]
                    dec2 = pred.detach().argmax(dim=2).cpu().numpy().tolist()[0]
                    dec1 = sentence_piece.decode(dec1)
                    dec2 = sentence_piece.decode(dec2)
                    print('Step', step)
                    print(dec1)
                    print(dec2)
                    samples.append([step, dec1, dec2])
                    !rm -f {DATA_DIR} / gpt_ *.model
                    torch.save(gpt.state_dict(), f'{DATA_DIR}/gpt_{step}.model')
                    gpt.train()
                step += 1
                wrun.log({'learning_rate': float(scheduler.get_last_lr()[0]), 'loss': loss_value.item(),
                          'samples': wandb.Table(columns=['step', 'gpt_target', 'gpt_prediction'], data=samples)})

[34m[1mwandb[0m: Currently logged in as: [33mvslaykovsky[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/46468 [00:00<?, ?it/s]

Step 0
мы переводим документацию и пользовательские интерфейсы, адаптируем программное обеспечение и тестируем его. ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
мы переводим документацию и пользовательские интерфейсы, адаптируем программное обеспечение и тестируем его. ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
rm: cannot remove '/content/drive/MyDrive/ML/transformer': Is a directory
rm: cannot remove '/': Is a directory
Step 1000
многие страны из группы hipcs тратят на городские больницы и высшее образование больше, чем на элемен

VBox(children=(Label(value='4.407 MB of 4.407 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
learning_rate,▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▇▇▇██
loss,█▆▅▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▁▁▁▁

0,1
learning_rate,7e-05
loss,28.25859


## Generation

In [81]:
def gpt_generate(gpt: GPT, input, sentence_piece, max_seq_len=32, verbose=False):
    # iterative prediction token by token
    enc_input = torch.tensor([sentence_piece.encode(input, add_bos=True, add_eos=False)]).to(DEVICE)
    for _ in tqdm(range(max_seq_len)):
        pred = gpt.forward(
            enc_input,
            att_mask=torch.triu(torch.ones(enc_input.shape[-1], enc_input.shape[-1]) * float('-inf'), diagonal=1).to(DEVICE)
        )
        pred = torch.argmax(pred, dim=2)[:, -1:].detach()
        enc_input = torch.cat([enc_input, pred], dim=1).to(torch.int)
        if verbose:
            print(enc_input.cpu())
        if enc_input[0][-1].item() == sentence_piece.eos_id():
            break
    return sentence_piece.decode(enc_input.cpu().numpy().tolist())

gpt_generate(gpt, 'entail this', sentence_piece, verbose=True)

  0%|          | 0/32 [00:00<?, ?it/s]

tensor([[    1, 26387,    43,     3]], dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356]], dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741]], dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741,   163]], dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741,   163, 33040]],
       dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741,   163, 33040, 33040]],
       dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741,   163, 33040, 33040,     3]],
       dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741,   163, 33040, 33040,     3,
            48]], dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741,   163, 33040, 33040,     3,
            48,    40]], dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741,   163, 33040, 33040,     3,
            48,    40,    40]], dtype=torch.int32)
tensor([[    1, 26387,    43,     3, 36356,  9741,   163,

['entail this, бейкер снижения егоmichelmichel, что вы вы вы ослаблен доволен аккуратно аккуратноанья.']