# **Transformer Architecture**

### **Encoder-Decoder Architecture**

It is used for **sequence-to-sequence** tasks like machine translation.

This is the architecture proposed by the original transformer paper [Attention is All You Need](https://arxiv.org/abs/1706.03762) by Vaswani et al. in 2017.

<img src="assets/encoder_decoder.png" alt="Encoder-Decoder Architecture" style="background-color:white;" height="500" />

### **Decoder-Only Architecture**

It is used for **sequence generation** tasks like language modeling.

This simpler architecture will be implemented in this notebook from scratch.

<img src="assets/decoder_only.png" alt="Decoder-Only Architecture" style="background-color:white;" height="500" />

In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    if not Path("transformer-notebook").exists():
        !git clone https://github.com/segusantos/transformer-notebook.git
    %cd transformer-notebook

cuda


## **Scaled Dot-Product Attention**

Let $Q \in \mathbb{R}^{m \times d_k}$, $K \in \mathbb{R}^{n \times d_k}$ y $V \in \mathbb{R}^{n \times d_v}$ be the query, key and value matrices, respectively. The scaled dot-product attention is defined as:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \in \mathbb{R}^{m \times d_v}
$$

<img src="assets/self_attention.png" alt="Decoder-Only Architecture" style="background-color:white;" height="500" />

In [2]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self,
                 sequenceLength: int,
                 embeddingSize: int,
                 headSize: int,
                 dropout: float) -> None:
        super().__init__()
        self.key = nn.Linear(embeddingSize, headSize, bias=False)
        self.query = nn.Linear(embeddingSize, headSize, bias=False)
        self.value = nn.Linear(embeddingSize, headSize, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(sequenceLength, sequenceLength)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batchSize, sequenceLength, embeddingSize = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        weights = q @ k.transpose(-2, -1) / (embeddingSize ** 0.5)
        weights = weights.masked_fill(self.tril[:sequenceLength, :sequenceLength] == 0, float("-inf"))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        x = weights @ v
        return x

## **Multi-Head Attention**

Let $h$ be the number of heads. The multi-head attention is defined as:

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O \in \mathbb{R}^{m \times d_v}
$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$ and $W_i^Q \in \mathbb{R}^{d_k \times d_k}$, $W_i^K \in \mathbb{R}^{d_k \times d_k}$, $W_i^V \in \mathbb{R}^{d_v \times d_v}$ and $W^O \in \mathbb{R}^{hd_v \times d_v}$ are the learnable parameters.

<img src="assets/multi_head_attention.png" alt="Decoder-Only Architecture" style="background-color:white;" height="500" />

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self,
                 nHeads: int,
                 sequenceLength: int,
                 embeddingSize: int,
                 dropout: float) -> None:
        super().__init__()
        self.heads = nn.ModuleList([ScaledDotProductAttention(sequenceLength,
                                                              embeddingSize,
                                                              embeddingSize // nHeads,
                                                              dropout) for _ in range(nHeads)])
        self.projection = nn.Linear(embeddingSize, embeddingSize)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.projection(x)
        return x

## **Position-wise Feed-Forward Networks**

The position-wise feed-forward networks are defined as:

$$
\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
$$

where $W_1 \in \mathbb{R}^{d_{ff} \times d_{model}}$, $b_1 \in \mathbb{R}^{d_{ff}}$, $W_2 \in \mathbb{R}^{d_{model} \times d_{ff}}$ and $b_2 \in \mathbb{R}^{d_{model}}$ are the learnable parameters.

In [4]:
class FeedForward(nn.Module):
    def __init__(self,
                 embeddingSize: int,
                 dropout: float) -> None:
        super().__init__()
        self.feedForward = nn.Sequential(
            nn.Linear(embeddingSize, 4 * embeddingSize),
            nn.ReLU(),
            nn.Linear(4 * embeddingSize, embeddingSize),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.feedForward(x)

## **Layer Normalization**

The layer normalization is defined as:

$$
\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sigma} + \beta
$$

where $\gamma \in \mathbb{R}^{d_{model}}$ and $\beta \in \mathbb{R}^{d_{model}}$ are the learnable parameters, and $\mu$ and $\sigma$ are the mean and standard deviation of $x$, respectively.

In the original paper, the layer normalization is applied after the residual connections, whereas in more modern implementations, it is applied before in what is known as **pre-norm formulation**.

```python
# Original paper
x = LayerNorm(x + MultiHeadAttention(x))
x = LayerNorm(x + FeedForward(x))

# Pre-norm formulation
x = x + MultiHeadAttention(LayerNorm(x))
x = x + FeedForward(LayerNorm(x))
```

In [5]:
class Layer(nn.Module):
    def __init__(self,
                 nHeads: int,
                 sequenceLength: int,
                 embeddingSize: int,
                 dropout: float) -> None:
        super().__init__()
        self.multiHeadAttention = MultiHeadAttention(nHeads,
                                                     sequenceLength,
                                                     embeddingSize,
                                                     dropout)
        self.feedForward = FeedForward(embeddingSize, dropout)
        self.layerNorm1 = nn.LayerNorm(embeddingSize)
        self.layerNorm2 = nn.LayerNorm(embeddingSize)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.multiHeadAttention(self.layerNorm1(x))
        x = x + self.feedForward(self.layerNorm2(x))
        return x

## **Language Model**

In [6]:
class GPTLanguageModel(nn.Module):
    def __init__ (self,
                  vocabularySize: int,
                  nLayers: int,
                  nHeads: int,
                  sequenceLength:int,
                  embeddingSize: int,
                  dropout: float) -> None:
        super().__init__()
        self.tokenEmbeddingTable = nn.Embedding(vocabularySize, embeddingSize)
        self.positionEmbeddingTable = nn.Embedding(sequenceLength, embeddingSize)
        self.layers = nn.Sequential(*[Layer(nHeads,
                                            sequenceLength,
                                            embeddingSize,
                                            dropout) for _ in range(nLayers)])
        self.layerNorm = nn.LayerNorm(embeddingSize)
        self.linearModelHead = nn.Linear(embeddingSize, vocabularySize)
        self.sequenceLength = sequenceLength
        self.apply(self.initWeights)

    def initWeights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor, targets: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        batchSize, sequenceLength = x.shape
        tokenEmbeddings = self.tokenEmbeddingTable(x)
        positionEmbeddings = self.positionEmbeddingTable(torch.arange(sequenceLength, device=device))
        x = tokenEmbeddings + positionEmbeddings
        x = self.layers(x)
        x = self.layerNorm(x)
        logits = self.linearModelHead(x)
        if targets is None:
            loss = None
        else:
            batchSize, sequenceLength, embeddingSize = logits.shape
            logits = logits.view(batchSize * sequenceLength, embeddingSize)
            targets = targets.view(batchSize * sequenceLength)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, x: torch.Tensor, nTokens: int) -> torch.Tensor:
        self.eval()
        with torch.no_grad():
            for _ in range(nTokens):
                logits, loss = self(x[:, -self.sequenceLength:])
                probabilities = F.softmax(logits[:, -1, :], dim=-1)
                nextToken = torch.multinomial(probabilities, num_samples=1)
                x = torch.cat([x, nextToken], dim=1)
        self.train()
        return x

## **Text Generation**

In [None]:
dataset = "martin_fierro" # "shakespeare"
with open((Path("data") / dataset).with_suffix(".txt"), "r") as f:
    text = f.read()
print(f"Text length: {len(text)}")
print(text[:250])

Text length: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.



In [8]:
vocabulary = sorted(set(text))
vocabularySize = len(vocabulary)
print(f"Vocabulary size: {vocabularySize}")
print("".join(vocabulary))

Vocabulary size: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [9]:
stringToToken = {ch: i for i, ch in enumerate(vocabulary)}
tokenToString = {i: ch for i, ch in enumerate(vocabulary)}
encode = lambda string: [stringToToken[char] for char in string]
decode = lambda tokens: "".join([tokenToString[token] for token in tokens])
print(encode("Los hermanos sean unidos"))
print(decode(encode("Los hermanos sean unidos")))

[24, 53, 57, 1, 46, 43, 56, 51, 39, 52, 53, 57, 1, 57, 43, 39, 52, 1, 59, 52, 47, 42, 53, 57]
Los hermanos sean unidos


In [10]:
data = torch.tensor(encode(text), dtype=torch.long)
print(f"Data shape: {data.shape}")
print(f"Data type: {data.dtype}")
print(data[:100])

Data shape: torch.Size([1115394])
Data type: torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [11]:
trainValSplit = 0.8
trainSize = int(len(data) * trainValSplit)
trainData = data[:trainSize]
valData = data[trainSize:]
print(f"Train data shape: {trainData.shape}")
print(f"Validation data shape: {valData.shape}")

Train data shape: torch.Size([892315])
Validation data shape: torch.Size([223079])


In [12]:
def getBatch(data: torch.Tensor,
             batchSize: int,
             sequenceLength: int) -> tuple[torch.Tensor, torch.Tensor]:
    ix = torch.randint(0, data.size(0) - sequenceLength, (batchSize,))
    x = torch.stack([data[i:i+sequenceLength] for i in ix])
    y = torch.stack([data[i+1:i+sequenceLength+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


batchSize = 4
sequenceLength = 8
xBatch, yBatch = getBatch(trainData, batchSize, sequenceLength)
print(f"Context batch shape: {xBatch.shape}")
print(xBatch)
print("Target batch shape:", yBatch.shape)
print(yBatch)
for batch in range(batchSize):
    for token in range(sequenceLength):
        context = xBatch[batch, :token+1].tolist()
        target = yBatch[batch, token].item()
        print(f"Context: {context} -> Target: {target}")

Context batch shape: torch.Size([4, 8])
tensor([[47, 57, 10,  1, 39, 52, 42,  1],
        [59, 56,  1, 46, 43, 39, 56, 58],
        [32, 46, 39, 58,  1, 39, 50, 61],
        [26, 53, 58, 46, 47, 52, 45,  1]], device='cuda:0')
Target batch shape: torch.Size([4, 8])
tensor([[57, 10,  1, 39, 52, 42,  1, 50],
        [56,  1, 46, 43, 39, 56, 58, 57],
        [46, 39, 58,  1, 39, 50, 61, 39],
        [53, 58, 46, 47, 52, 45,  1, 40]], device='cuda:0')
Context: [47] -> Target: 57
Context: [47, 57] -> Target: 10
Context: [47, 57, 10] -> Target: 1
Context: [47, 57, 10, 1] -> Target: 39
Context: [47, 57, 10, 1, 39] -> Target: 52
Context: [47, 57, 10, 1, 39, 52] -> Target: 42
Context: [47, 57, 10, 1, 39, 52, 42] -> Target: 1
Context: [47, 57, 10, 1, 39, 52, 42, 1] -> Target: 50
Context: [59] -> Target: 56
Context: [59, 56] -> Target: 1
Context: [59, 56, 1] -> Target: 46
Context: [59, 56, 1, 46] -> Target: 43
Context: [59, 56, 1, 46, 43] -> Target: 39
Context: [59, 56, 1, 46, 43, 39] -> Target: 5

In [13]:
nLayers = 6
nHeads = 6
sequenceLength = 256
embeddingSize = 384
dropout = 0.2
model = GPTLanguageModel(vocabularySize,
                         nLayers,
                         nHeads,
                         sequenceLength,
                         embeddingSize,
                         dropout).to(device)
print(f"{sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")

10.79M parameters


In [14]:
@torch.no_grad()
def estimateLoss(model: nn.Module,
                 data: torch.Tensor,
                 evalIter: int,
                 batchSize: int,
                 sequenceLength: int) -> float:
    model.eval()
    losses = torch.zeros(evalIter)
    for i in range(evalIter):
        xBatch, yBatch = getBatch(data, batchSize, sequenceLength)
        logits, loss = model(xBatch, yBatch)
        losses[i] = loss.item()
    model.train()
    return losses.mean().item()


# Load trained model
# model.load_state_dict(torch.load((Path("models") / dataset).with_suffix(".pt"), weights_only=False))

# Train model from scratch
batchSize = 64
learningRate = 3e-4
maxIter = 800
evalInterval = 100
evalIter = 100
optimizer = optim.Adam(model.parameters(), lr=learningRate)
for iter in range(maxIter + 1):
    if iter % evalInterval == 0 or iter == maxIter - 1:
        trainLoss = estimateLoss(model,
                                 trainData,
                                 evalIter,
                                 batchSize,
                                 sequenceLength)
        valLoss = estimateLoss(model, valData, evalIter, batchSize, sequenceLength)
        print(f"Iter: {iter}, Train loss: {trainLoss}, Val loss: {valLoss}")
    xBatch, yBatch = getBatch(trainData, batchSize, sequenceLength)
    logits, loss = model(xBatch, yBatch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
torch.save(model.state_dict(), (Path("models") / dataset).with_suffix(".pt"))

Iter: 0, Train loss: 4.284997940063477, Val loss: 4.282075881958008
Iter: 500, Train loss: 1.85334050655365, Val loss: 2.0040788650512695


KeyboardInterrupt: 

In [None]:
context = "Los hermanos sean unidos" # "To be, or not to be"
context = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0).to(device)
generatedText = decode(model.generate(context, nTokens=10000)[0].tolist())
print(generatedText)
with open((Path("outputs") / dataset).with_suffix(".txt"), "w") as f:
    f.write(generatedText)

Los hermanos sean unidos.

Los astimos las juridos
que enel roron sustumbures;
y en caminias en otan risa,
debe aqueles pampas mi foro,
cuando, da el encaque en supal
eran de boliadas lo sol.

Cuanto el desapulto
nos cretestaban las campas
viendo a al indio y levanto;
pero a un hombre tenía
como al tenerro campañó
si lo puse lo ec"
levaran letó el trastoún yito.
yo ya me hale alma suelo viles
con las los yos, copletos
y sin perros y hablaridos
y indao las infiecias.

Con laste cuitó con las tílas
habías é rastó los de día
émpre una: venos ¡Para!
que el trabios que había
con el arrogallé seguillo!
Y Na tía estaba me rancal
era ganaba las canas arganas
de un punto cantón
es la vistas sen traillan
de uno que al nochina
en al indio a entre asiona.

Una prece esa comentenda
como la enperranzan
pero pan en dar estrasquiandas
pue.Ña..
no hizos tras en las pare¿
Y algunas de hombre al áen potro;
me he dicé que afligarse el entreo
en medio del cancelanto;

que he dé por cuitabullao
en monitaré 