* [1. Original Paper](https://arxiv.org/html/1706.03762v7)
* [2. Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer)

In [1]:
import random, numpy as np, matplotlib.pyplot as plt, torch, torch.nn as nn, torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
N, H, SEED    = 2, 8, 442
BATCH_SIZE    = 80
DROPOUT_RATE  = 0.1
EMBEDDING_DIM = 512
assert EMBEDDING_DIM % H == 0
random      .seed(SEED)
np   .random.seed(SEED)
torch.manual_seed(SEED)
None

### Data Generation

In [2]:
VOCAB = ['<BOS>', '<EOS>', '<PAD>', 
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 
         'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
VOCAB_SIZE = len(VOCAB)

def tokenizer(s, length_max):
    s = [VOCAB.index('<BOS>')] + [VOCAB.index(c) for c in s] + \
        [VOCAB.index('<EOS>')] + [VOCAB.index('<PAD>') for i in range(length_max - len(s))]
    return s

tokenizer('AAA', len('AAA')), tokenizer('AAA', len('AAAAAA'))

def text_gen(length_min, length_max, batch_size = BATCH_SIZE):
    ''' Throw out random length of uppercase letters, faking sentences
    Throw out corresponding sentences in lowercase letter but double its length'''
    stn_len = np.random.randint(length_min, length_max+1, size=batch_size)
    # 'A':65, 'Z':90, 'a': 97, 'z':122
    upper_case = [np.random.randint(65, 91, size=i)
                    for i in stn_len] 
    lower_case = [''.join([chr(i+32)
                    for i in np.array([
                        (j, j) for j in k]).flatten()]) # HERE
                                  for k in upper_case]
    upper_case = [''.join([chr(i) for i in j]) for j in upper_case]
    return upper_case, lower_case

print('', [chr(i) for i in range(65, 91)], '\n', [chr(i) for i in range(97, 123)], '\n', ord('A'), ord('Z'), ord('a'), ord('z'))
text_gen(3, 13, 3)

 ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] 
 ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 
 65 90 97 122


(['NLOCO', 'SBIIMTD', 'BAXQUFLXZGSN'],
 ['nnllooccoo', 'ssbbiiiimmttdd', 'bbaaxxqquuffllxxzzggssnn'])

In [3]:
class Batch:
    def __init__(self, UPPER, lower = None):
        self.UPPER = UPPER
        UPPER_length_max = max([len(s) for s in UPPER])
        self.src = torch.LongTensor([tokenizer(s, UPPER_length_max) for s in UPPER])

        if lower:
            self.lower = lower
            lower_length_max = max([len(s) for s in lower])
            self.tgt = torch.LongTensor([tokenizer(s, lower_length_max) for s in lower])
            self.tgt_in, self.tgt_out = self.tgt[:, :-1], self.tgt[:, 1:]
            self.tgt_mask = self.mask(lower_length_max+1)
            self.ntokens = (self.tgt_out != VOCAB.index('<PAD>')).data.sum() # count tokens

        # src, tgt, tgt_in, tgt_out, tgt_mask, ntokens

    def mask(self, size):
        return torch.triu(torch.ones(size, size, dtype = int), diagonal = 1) == 1

def data_gen(length_min = 10, length_max = 15, nbatch=10):
    for i in range(nbatch):
        UPPER, lower = text_gen(length_min, length_max)
        yield Batch(UPPER, lower)

for i, batch in enumerate(data_gen(3, 7, 2)):
    break
    print(batch.UPPER, batch.lower)
    print(batch.src.shape)
    print([[VOCAB[i] for i in j] for j in batch.src][0])
    print(batch.tgt.shape)
    print([[VOCAB[i] for i in j] for j in batch.tgt][0])
    print(batch.tgt_in.shape)
    print([[VOCAB[i] for i in j] for j in batch.tgt_in][0])
    print(batch.tgt_out.shape)
    print('        ', [[VOCAB[i] for i in j] for j in batch.tgt_out][0])
    print(batch.tgt_mask.shape)
    print('---ntokens---', batch.ntokens)

def decipher(tnsr):
    [print('\t'.join([VOCAB[i] for idx, i in enumerate(j)])) for j in tnsr]

print('src')
decipher(batch.src[:2])
print('tgt_in')
decipher(batch.tgt_in[:2])
print('tgt_out')
decipher(batch.tgt_out[:2])

src
<BOS>	R	K	S	W	T	<EOS>	<PAD>	<PAD>
<BOS>	O	B	X	H	H	E	<EOS>	<PAD>
tgt_in
<BOS>	r	r	k	k	s	s	w	w	t	t	<EOS>	<PAD>	<PAD>	<PAD>
<BOS>	o	o	b	b	x	x	h	h	h	h	e	e	<EOS>	<PAD>
tgt_out
r	r	k	k	s	s	w	w	t	t	<EOS>	<PAD>	<PAD>	<PAD>	<PAD>
o	o	b	b	x	x	h	h	h	h	e	e	<EOS>	<PAD>	<PAD>


### Components

In [4]:
class Embedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)

    def forward(self, x):
        x = self.embedding(x)
        x = x * np.sqrt(EMBEDDING_DIM)
        return x

embd = Embedder()
embd.forward(batch.src).shape, embd.forward(batch.tgt_in).shape # (8-2) * 3 + 2 = 20

(torch.Size([80, 9, 512]), torch.Size([80, 15, 512]))

$$PE_{(pos,2i)} = \sin(pos / 10000^{2i/d_{\text{model}}})$$
$$PE_{(pos,2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}})$$

In [5]:
class PositionalEncoder(nn.Module):
    def __init__(self, max_len=5000):
        super().__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, EMBEDDING_DIM)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-np.log(10000.0) * torch.arange(0, EMBEDDING_DIM, 2) / EMBEDDING_DIM)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        x = F.dropout(x, p = DROPOUT_RATE)
        return x

posEnc = PositionalEncoder()
posEnc.forward(embd.forward(batch.src)).shape

torch.Size([80, 9, 512])

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.Q_linear = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM)
        self.V_linear = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM)
        self.K_linear = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM)
        self.out      = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM)

    def forward(self, Q, K, V, mask = None): #
        bs = Q.size(0)
        sl = Q.size(1) # sequence_length
        dk = EMBEDDING_DIM // H # embed_dim / head_count
        Q  = self.Q_linear(Q).view(bs, sl, H, dk).transpose(1, 2)
        sl = K.size(1) # sequence_length
        KT = self.K_linear(K).view(bs, sl, H, dk).transpose(1, 2).transpose(2, 3)
        V  = self.V_linear(V).view(bs, sl, H, dk).transpose(1, 2)

        scores = torch.matmul(Q, KT) / np.sqrt(dk)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        scores = scores.softmax(dim = -1)
        scores = F.dropout(scores, p = DROPOUT_RATE)

        scores = torch.matmul(scores, V)
        scores = scores.transpose(1,2).contiguous().view(bs, -1, EMBEDDING_DIM) # Concatenation
        scores = self.out(scores)
        scores = F.dropout(scores, p = DROPOUT_RATE)

        return scores

mha = MultiHeadAttention()
x = posEnc.forward(embd.forward(batch.src))
m = mha(Q = x, K = x, V = x)
print(m.shape)
x = posEnc.forward(embd.forward(batch.tgt_in))
print(mha(Q = x, K = m, V = m).shape)

torch.Size([80, 9, 512])
torch.Size([80, 15, 512])


$$\mathrm{LN}(x)=  \frac{x-\mu}{\sigma} \cdot \gamma+\beta$$

In [7]:
class Norm(nn.Module):
    def __init__(self, eps = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones (EMBEDDING_DIM)) # weight
        self.beta  = nn.Parameter(torch.zeros(EMBEDDING_DIM)) # bias
        self.eps   = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim = True)
        std  = x.std (-1, keepdim = True) # standard deviation
        x    = (x - mean) / (std + self.eps) * self.gamma + self.beta
        return x

class FeedForward(nn.Module):
    def __init__(self, feedforward_dim=2048, dropout_rate = 0.1):
        super().__init__()
        self.linear_1 = nn.Linear(EMBEDDING_DIM, feedforward_dim)
        self.linear_2 = nn.Linear(feedforward_dim, EMBEDDING_DIM)

    def forward(self, x):
        x = self.linear_1(x)
        x = F.relu(x)
        x = F.dropout(x, p=DROPOUT_RATE)
        x = self.linear_2(x)
        return x

### Encoder and Decoder
see eq. 2: $ y_l = x_l+Module(Norm(x_l)) \ $ in [this paper](https://arxiv.org/pdf/2502.02732v1)

In [8]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.mhsa   = MultiHeadAttention()
        self.norm_1 = Norm()
        self.ff     = FeedForward()
        self.norm_2 = Norm()

    def forward(self, x):
        '''see eq. 2 in https://arxiv.org/pdf/2502.02732v1 for 
        Rre-LN (Norm before addition to the residue)'''
        # Norm
        y = self.norm_1(x)
        y = F.dropout(y, p=DROPOUT_RATE)
        # Add
        x = x + self.mhsa(y, y, y)
        # Norm
        y = self.norm_2(x)
        y = F.dropout(y, p=DROPOUT_RATE)
        # Add
        x = x + self.ff(y)
        return x

enc = Encoder()
x = posEnc.forward(embd.forward(batch.src))
m = enc.forward(x)
m.shape

torch.Size([80, 9, 512])

In [9]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.mmhsa  = MultiHeadAttention()
        self.norm_1 = Norm()
        self.mhca   = MultiHeadAttention()
        self.norm_2 = Norm()
        self.ff     = FeedForward()
        self.norm_3 = Norm()

    def forward(self, mem, tgt, mask):
        # Norm
        y = self.norm_1(tgt)
        y = F.dropout(y, p=DROPOUT_RATE)
        # Add
        tgt = tgt + self.mmhsa(y, y, y, mask)
        # Norm
        y = self.norm_2(tgt)
        y = F.dropout(y, p=DROPOUT_RATE)
        # Add
        tgt = tgt + self.mhca(y, mem, mem) # (Q, K, V, mask)
        # Norm
        y = self.norm_3(tgt)
        # Add
        tgt = tgt + self.ff(y)
        return tgt

dec = Decoder()
x = posEnc.forward(embd.forward(batch.tgt_in))
dec.forward(m, x, batch.tgt_mask).shape, m.shape, x.shape

(torch.Size([80, 15, 512]),
 torch.Size([80, 9, 512]),
 torch.Size([80, 15, 512]))

### Transformer

In [10]:
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.src_embd = Embedder()
        self.tgt_embd = Embedder()
        self.posEnc   = PositionalEncoder()
        self.nEnc     = nn.ModuleList([Encoder() for i in range(N)])
        self.nDec     = nn.ModuleList([Decoder() for i in range(N)])
        self.proj     = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)

        [nn.init.xavier_uniform_(p) for p in self.parameters() if p.dim() > 1]

    def encode(self, src):
        mem = self.src_embd(src)
        mem = self.posEnc(mem)
        for enc in self.nEnc:
            mem = enc(mem)
        return mem

    def decode(self, mem, tgt, tgt_mask):
        tgt = self.tgt_embd(tgt)
        tgt = self.posEnc(tgt)
        for dec in self.nDec:
            tgt = dec(mem, tgt, tgt_mask)
        return tgt

    def gen(self, tgt):
        tgt = self.proj(tgt)
        tgt = F.log_softmax(tgt, dim=-1)
        return tgt

model = Transformer()

if 1:
    src_ = batch.src[:1, :]
    mem = model.encode(src_)
    print('src', batch.src.shape)
    print('mem', mem.shape)
    BOS = torch.empty(1, 1, dtype=int).fill_(VOCAB.index('<BOS>'))
    pred = BOS
    for i in range(10):
        pred = model.decode(mem, pred, None)
        pred = model.gen(pred)
        _, pred = torch.max(pred, dim = -1)
        pred = torch.cat([BOS, pred], dim = -1)
        print([' '.join([f'{idx}{VOCAB[i]}' for idx, i in enumerate(j)]) for j in pred])

src torch.Size([80, 9])
mem torch.Size([1, 9, 512])
['0<BOS> 1B']
['0<BOS> 1W 2h']
['0<BOS> 1t 2W 3h']
['0<BOS> 1t 2<EOS> 3W 4D']
['0<BOS> 1<EOS> 2W 3B 4W 5<EOS>']
['0<BOS> 1J 2X 3W 4W 5W 6X']
['0<BOS> 1<EOS> 2B 3g 4W 5W 6W 7<EOS>']
['0<BOS> 1J 2<EOS> 3W 4V 5W 6W 7W 8W']
['0<BOS> 1J 2<EOS> 3E 4D 5<EOS> 6W 7D 8D 9W']
['0<BOS> 1t 2B 3D 4W 5a 6<EOS> 7W 8O 9a 10g']


### Training

In [11]:
model = Transformer()
model.train() # train mode
loss_func    = nn.CrossEntropyLoss(ignore_index = VOCAB.index('<PAD>'), label_smoothing = 0.1)
optimizer    = torch.optim.Adam(model.parameters(), lr = 0.5, betas = (0.9, 0.98), eps = 1e-9)
rate         = lambda step: 0.0442 * min(step ** (-0.5), step * 0.000125) if step else 5e-06
lr_scheduler = LambdaLR( optimizer = optimizer, lr_lambda = rate)

epoch = 100
for _ in range(epoch):
    for i, batch in enumerate(data_gen(length_min = 10, length_max = 15, nbatch = 20)):
        mem = model.encode(batch.src)
        tgt_out_ = model.decode(mem, batch.tgt_in, batch.tgt_mask)
        tgt_out_ = model.gen(tgt_out_)
        loss = loss_func(tgt_out_.view(-1, tgt_out_.size(-1)), batch.tgt_out.reshape(-1)) 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
    print(loss)
    if loss < 0.74: break 

tensor(5.1024, grad_fn=<AddBackward0>)
tensor(3.8499, grad_fn=<AddBackward0>)
tensor(3.2130, grad_fn=<AddBackward0>)
tensor(2.8292, grad_fn=<AddBackward0>)
tensor(2.5104, grad_fn=<AddBackward0>)
tensor(2.2470, grad_fn=<AddBackward0>)
tensor(1.9842, grad_fn=<AddBackward0>)
tensor(1.8702, grad_fn=<AddBackward0>)
tensor(1.6249, grad_fn=<AddBackward0>)
tensor(1.5020, grad_fn=<AddBackward0>)
tensor(1.4005, grad_fn=<AddBackward0>)
tensor(1.3334, grad_fn=<AddBackward0>)
tensor(1.2849, grad_fn=<AddBackward0>)
tensor(1.2416, grad_fn=<AddBackward0>)
tensor(1.1860, grad_fn=<AddBackward0>)
tensor(1.0869, grad_fn=<AddBackward0>)
tensor(1.1262, grad_fn=<AddBackward0>)
tensor(1.0565, grad_fn=<AddBackward0>)
tensor(1.0188, grad_fn=<AddBackward0>)
tensor(1.0290, grad_fn=<AddBackward0>)
tensor(0.9996, grad_fn=<AddBackward0>)
tensor(0.9539, grad_fn=<AddBackward0>)
tensor(0.9362, grad_fn=<AddBackward0>)
tensor(0.9101, grad_fn=<AddBackward0>)
tensor(0.8885, grad_fn=<AddBackward0>)
tensor(0.8805, grad_fn=<A

### Inference 

In [12]:
model.eval() # train mode
def predict(s):
    b = Batch([s])
    decipher(b.src)
    mem = model.encode(b.src)
    BOS = torch.empty(1, 1, dtype=int).fill_(VOCAB.index('<BOS>'))
    ys = BOS
    for i in range(30):
        pred = model.decode(mem, ys, None)
        pred = model.gen(pred)
        _, pred = torch.max(pred, dim = -1)
        ys = torch.cat([BOS, pred], dim = -1)
    [print('\t'.join([VOCAB[i] for idx, i in enumerate(j)])) for j in pred]

predict('CXNGQ')
predict('EMBEDDINGDIM')

<BOS>	C	X	N	G	Q	<EOS>
c	c	x	x	n	n	g	g	q	q	<EOS>	x	n	n	g	x	g	x	x	q	g	q	n	q	<EOS>	c	n	c	x	g
<BOS>	E	M	B	E	D	D	I	N	G	D	I	M	<EOS>
i	e	m	m	b	b	e	e	d	d	d	d	i	i	n	n	g	d	d	d	i	i	m	m	<EOS>	d	d	d	h	t
