---
title: "Transformers"
jupytext:
  formats: md:myst
  text_representation:
    extension: .md
    format_name: myst
kernelspec:
  display_name: Python 3
  language: python
  name: python3
execute:
  enabled: true
filters:
  - marimo-team/marimo
---

# Transformers

::::{grid} 1
:class-container: spoiler-block

:::{grid-item-card} Spoiler
Transformers don't process sequences; they process relationships between every position simultaneously.
:::

::::

## The Mechanism

You've been taught to think of language models as sequential processors—reading left to right, one word triggering the next, like dominoes falling. This intuition comes from recurrent neural networks (RNNs), where information flows step by step, each word depending on the hidden state from the previous word. The transformer architecture throws this away entirely.

Instead of sequential processing, transformers operate through **parallel relationship mapping**. When you read "The cat sat on the mat because it was tired," you don't actually process word-by-word in isolation. Your brain simultaneously evaluates which words relate to which—"it" connects to "cat," "tired" explains "sat," "mat" anchors "on." Transformers formalize this intuition mathematically. Every position in the input sequence simultaneously computes its relationship to every other position. The mechanism is attention, and the result is a system where context flows in all directions at once, not just forward through time.

This parallelism is why transformers scaled when RNNs didn't. Recurrent architectures impose sequential computation: you can't process word 100 until you've processed word 99. Transformers eliminate this bottleneck. During training, every position can be computed in parallel, so the number of sequential steps no longer grows with sequence length (the total attention cost still grows quadratically with length, but that work maps well onto parallel hardware). This architectural shift is what enabled GPT-3, GPT-4, and Claude to exist.

## The Architecture

Modern LLMs stack multiple **transformer blocks**—modular units that take a sequence of token vectors as input and output a transformed sequence of the same length. GPT-3 uses 96 of these blocks; GPT-4 likely uses more. Each block refines the representation, adding layers of contextual understanding.

```{figure} ../figs/transformer-overview.jpg
:name: transformer-overview
:alt: Transformer Overview
:width: 50%
:align: center

The basic architecture of the transformer-based LLMs.
```

These blocks come in two forms: **encoders** and **decoders**. The encoder processes the input sequence and builds a contextualized representation. The decoder generates the output sequence, attending to both its own previous outputs and the encoder's representation. For translation tasks ("I love you" → "Je t'aime"), the encoder processes English, the decoder generates French. For language modeling (GPT-style systems), only the decoder is used—it generates text autoregressively, predicting the next token based on all previous tokens.

```{figure} ../figs/transformer-encoder-decoder.jpg
:name: transformer-encoder-decoder
:alt: Transformer Encoder-Decoder
:width: 80%
:align: center

The encoder-decoder architecture. The encoder builds a representation of the input sequence; the decoder generates the output sequence while attending to the encoder's output.
```

Inside each block are three core components: **multi-head attention** (the relationship mapper), **layer normalization** (numerical stabilization), and **feed-forward networks** (nonlinear transformation). We'll build these components step by step.

```{figure} ../figs/transformer-component.jpg
:name: transformer-wired-components
:alt: Transformer Wired Components
:width: 80%
:align: center

Internal structure of encoder and decoder blocks.
```

## Attention: The Relationship Engine

**Self-attention**—the core of the transformer—computes how much each position in a sequence should "attend to" every other position. Unlike earlier attention mechanisms in seq2seq models, which attended from one sentence to another, self-attention operates within a single sequence. It answers the question: "Given this word, which other words matter most?"

```{figure} ../figs/transformer-attention.jpg
:name: transformer-attention
:alt: Attention Mechanism
:width: 80%
:align: center

The attention mechanism computes relationships between all positions simultaneously.
```

For each word, the attention mechanism creates three vectors: **query** ($Q$), **key** ($K$), and **value** ($V$). Think of these as a library search: the query is what you're looking for, the keys are book titles, and the values are the actual content. When you search for "machine learning" (your query), you match it against book titles (keys) to find relevant content (values).

Mathematically, each of these vectors is created by a learned linear transformation of the input word embedding. Given an input embedding $x$, we compute:

$$
Q = x W_Q, \quad K = x W_K, \quad V = x W_V
$$

where $W_Q$, $W_K$, and $W_V$ are learned weight matrices. The attention mechanism then computes which keys are most relevant to each query using the dot product, which measures vector similarity. The dot product $QK^T$ produces a matrix of attention scores—large values indicate strong relationships, small values indicate weak ones.

These raw scores are divided by $\sqrt{d_k}$ (the square root of the key dimension) so that their magnitude does not grow with the vector dimension and saturate the softmax, then normalized using softmax so that each row becomes a probability distribution. Finally, these normalized attention weights are used to compute a weighted sum of the value vectors. The complete operation is:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

where $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{n \times d_k}$, and $V \in \mathbb{R}^{n \times d_v}$ represent matrices containing $n$ query, key, and value vectors respectively.
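The formula translates almost line for line into NumPy. The following is a minimal sketch, not a production implementation: the sizes, the random seed, and the randomly initialized `W_Q`, `W_K`, `W_V` are illustrative assumptions standing in for learned weights.

```python
import numpy as np

def softmax(scores, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for softmax(Q K^T / sqrt(d_k)) V."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, n) query-key similarities
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 5, 8, 4, 4        # illustrative sizes
x = rng.normal(size=(n, d_model))        # stand-in for token embeddings
W_Q = rng.normal(size=(d_model, d_k))    # random here, learned in a real model
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_v))

output, weights = scaled_dot_product_attention(x @ W_Q, x @ W_K, x @ W_V)
print(output.shape, weights.shape)       # (5, 4) (5, 5)
```

Row $i$ of `weights` says how much position $i$ draws from every position, and row $i$ of `output` is the corresponding weighted sum of value vectors.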

The interactive visualization below demonstrates how learned Query and Key transformations produce different attention patterns. Adjust the transformation parameters to see how different $W_Q$ and $W_K$ matrices change which words attend to which:

<div>
<marimo-iframe data-height="700px" data-show-code="false">

```python {marimo}
import marimo as mo
import numpy as np
import pandas as pd
import altair as alt
```

```python {marimo}
attention_words = ["bank", "money", "loan", "river", "shore"]
attention_embeddings = np.array([
    [0.0, 0.0],  # bank (center)
    [-0.8, -0.3],  # money
    [-0.7, -0.6],  # loan
    [0.7, -0.5],  # river
    [0.6, -0.7],  # shore
]) * 2

# Query controls
q_scale_x = mo.ui.slider(-2, 2, 0.1, value=1.0, label="Q Scale X")
q_scale_y = mo.ui.slider(-2, 2, 0.1, value=1.0, label="Q Scale Y")
q_rotate = mo.ui.slider(-180, 180, 5, value=0, label="Q Rotate (deg)")

# Key controls
k_scale_x = mo.ui.slider(-2, 2, 0.1, value=1.0, label="K Scale X")
k_scale_y = mo.ui.slider(-2, 2, 0.1, value=1.0, label="K Scale Y")
k_rotate = mo.ui.slider(-180, 180, 5, value=0, label="K Rotate (deg)")

q_controls = mo.vstack([mo.md("**Query Transformation**"), q_scale_x, q_scale_y, q_rotate])
k_controls = mo.vstack([mo.md("**Key Transformation**"), k_scale_x, k_scale_y, k_rotate])
```

```python {marimo}
def _transform_embeddings(emb, scale_x, scale_y, rotate_deg):
    theta = np.radians(rotate_deg)
    rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    scale_matrix = np.diag([scale_x, scale_y])
    W = rot_matrix @ scale_matrix
    return emb @ W.T

Q = _transform_embeddings(attention_embeddings, q_scale_x.value, q_scale_y.value, q_rotate.value)
K = _transform_embeddings(attention_embeddings, k_scale_x.value, k_scale_y.value, k_rotate.value)

# Compute scaled attention scores Q K^T / sqrt(d_k), as in the attention formula
_scores = Q @ K.T / np.sqrt(K.shape[1])
_exp_scores = np.exp(_scores - np.max(_scores, axis=1, keepdims=True))
attention_weights = _exp_scores / np.sum(_exp_scores, axis=1, keepdims=True)

# Create visualizations
_df_q = pd.DataFrame({"word": attention_words, "x": Q[:, 0], "y": Q[:, 1]})
_df_k = pd.DataFrame({"word": attention_words, "x": K[:, 0], "y": K[:, 1]})

_chart_q = alt.Chart(_df_q).mark_circle(size=100).encode(
    x=alt.X('x:Q', scale=alt.Scale(domain=[-4, 4]), title='Q1'),
    y=alt.Y('y:Q', scale=alt.Scale(domain=[-4, 4]), title='Q2'),
    tooltip=['word:N']
).properties(width=200, height=200, title="Query (Q)")
_text_q = _chart_q.mark_text(dy=-12, fontSize=10, fontWeight='bold').encode(text='word:N')

_chart_k = alt.Chart(_df_k).mark_circle(size=100).encode(
    x=alt.X('x:Q', scale=alt.Scale(domain=[-4, 4]), title='K1'),
    y=alt.Y('y:Q', scale=alt.Scale(domain=[-4, 4]), title='K2'),
    tooltip=['word:N']
).properties(width=200, height=200, title="Key (K)")
_text_k = _chart_k.mark_text(dy=-12, fontSize=10, fontWeight='bold').encode(text='word:N')

# Heatmap
_heatmap_data = []
for i, word_i in enumerate(attention_words):
    for j, word_j in enumerate(attention_words):
        _heatmap_data.append({"Query": word_i, "Key": word_j, "Weight": attention_weights[i, j]})
_df_heatmap = pd.DataFrame(_heatmap_data)

_heatmap = alt.Chart(_df_heatmap).mark_rect().encode(
    x=alt.X('Key:N', title='Key Word'),
    y=alt.Y('Query:N', title='Query Word'),
    color=alt.Color('Weight:Q', scale=alt.Scale(scheme='blues'), title='Attention'),
    tooltip=['Query:N', 'Key:N', alt.Tooltip('Weight:Q', format='.3f')]
).properties(width=250, height=250, title="Attention Weights (Softmax)")

mo.vstack([
    mo.hstack([q_controls, k_controls], align="center"),
    mo.hstack([_chart_q + _text_q, _chart_k + _text_k, _heatmap], align="center")
])
```

</marimo-iframe>
</div>

The output is a **contextualized vector** for each word—a representation that changes based on surrounding context. The word "bank" produces different vectors in "river bank" versus "financial bank" because the attention mechanism incorporates information from neighboring words.

To see this in action, consider how we might contextualize the word "bank" by mixing it with surrounding words. The visualization below shows static word embeddings—notice how "bank" sits neutrally between financial terms (money, loan) and geographical terms (river, shore).

<div>
<marimo-iframe data-height="400px" data-show-code="false">

```python {marimo}
static_words = ["bank", "money", "loan", "river", "shore"]
static_embeddings = np.array([
    [0.0, 0.0],  # bank (center)
    [-0.8, -0.3],  # money
    [-0.7, -0.6],  # loan
    [0.7, -0.5],  # river
    [0.6, -0.7],  # shore
]) * 2

_df_static = pd.DataFrame({"word": static_words, "x": static_embeddings[:, 0], "y": static_embeddings[:, 1]})

_chart_static = alt.Chart(_df_static).mark_circle(size=200).encode(
    x=alt.X('x:Q', scale=alt.Scale(domain=[-2, 2]), title='Dimension 1'),
    y=alt.Y('y:Q', scale=alt.Scale(domain=[-2, 2]), title='Dimension 2'),
    text='word:N',
    tooltip=['word:N', 'x:Q', 'y:Q']
).properties(width=300, height=300, title="Static Word Embeddings")

_text_static = _chart_static.mark_text(dy=-15, fontSize=14, fontWeight='bold').encode(text='word:N')

_chart_static + _text_static
```

</marimo-iframe>
</div>

Now, try adjusting the weights below to create a contextualized version of "bank." If the sentence is "Money in bank," adjust the weights to shift "bank" toward "money." If the sentence is "River bank," shift it toward "river."

<div>
<marimo-iframe data-height="500px" data-show-code="false">

```python {marimo}
context_words = ["bank", "money", "loan", "river", "shore"]
context_embeddings = np.array([
    [0.0, 0.0],  # bank (center)
    [-0.8, -0.3],  # money
    [-0.7, -0.6],  # loan
    [0.7, -0.5],  # river
    [0.6, -0.7],  # shore
]) * 2

slider_bank = mo.ui.slider(0, 1, 0.01, value=1.0, label="Bank Weight")
slider_money = mo.ui.slider(0, 1, 0.01, value=0, label="Money Weight")
slider_loan = mo.ui.slider(0, 1, 0.01, value=0, label="Loan Weight")
slider_river = mo.ui.slider(0, 1, 0.01, value=0, label="River Weight")
slider_shore = mo.ui.slider(0, 1, 0.01, value=0, label="Shore Weight")

context_sliders = mo.vstack([slider_bank, slider_money, slider_loan, slider_river, slider_shore])
```

```python {marimo}
_weights = np.array([slider_bank.value, slider_money.value, slider_loan.value, slider_river.value, slider_shore.value])
_total = _weights.sum()
if _total > 0:
    _weights = _weights / _total
    _new_vec = context_embeddings.T @ _weights
else:
    _new_vec = np.zeros(2)

_df_orig = pd.DataFrame({"word": context_words, "x": context_embeddings[:, 0], "y": context_embeddings[:, 1], "type": ["Original"] * 5})
_df_new = pd.DataFrame({"word": ["Contextualized Bank"], "x": [_new_vec[0]], "y": [_new_vec[1]], "type": ["Contextualized"]})
_df_combined = pd.concat([_df_orig, _df_new])

_chart_context = alt.Chart(_df_combined).mark_circle(size=200).encode(
    x=alt.X('x:Q', scale=alt.Scale(domain=[-2, 2]), title='Dimension 1'),
    y=alt.Y('y:Q', scale=alt.Scale(domain=[-2, 2]), title='Dimension 2'),
    color=alt.Color('type:N', scale=alt.Scale(domain=['Original', 'Contextualized'], range=['#dadada', '#ff7f0e'])),
    tooltip=['word:N', 'x:Q', 'y:Q']
).properties(width=350, height=350, title="Contextualized Bank")

_text_context = _chart_context.mark_text(dy=-15, fontSize=14, fontWeight='bold').encode(text='word:N', color=alt.value('black'))

mo.hstack([context_sliders, _chart_context + _text_context], align="center")
```

</marimo-iframe>
</div>

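This manual weighting captures the intuition, but a transformer does not set the weights by hand. The weights are the softmax-normalized query-key scores introduced earlier, so the model learns which words to attend to by learning $W_Q$ and $W_K$.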

### Multi-Head Attention: Multiple Perspectives

A single attention mechanism captures one type of relationship. **Multi-head attention** runs multiple attention operations in parallel, each with different learned parameters. Each head can specialize—one might focus on syntactic dependencies (subject-verb relationships), another on semantic similarity (synonyms and antonyms), another on positional proximity (nearby words).

```{figure} ../figs/transformer-multihead-attention.jpg
:name: transformer-multihead-attention
:alt: Multi-Head Attention
:width: 50%
:align: center

Multi-head attention runs multiple attention operations in parallel, each capturing different relationships.
```

The outputs from all heads are concatenated and passed through a final linear transformation to produce the multi-head attention output. In the original transformer paper {footcite:p}`vaswani2017attention`, the authors used $h=8$ attention heads, with each head using dimension $d_k = d_v = d/h = 64$, where $d=512$ is the model dimension.
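To make the split-attend-concatenate pattern concrete, here is a rough sketch using the paper's sizes ($d=512$, $h=8$). The function name `multi_head_attention`, the output matrix name `W_O`, and the random weights are assumptions for illustration; a trained model would supply learned matrices.

```python
import numpy as np

def softmax(scores):
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def multi_head_attention(x, W_Q, W_K, W_V, W_O, h):
    """x: (n, d). W_Q, W_K, W_V, W_O: (d, d). h: number of heads."""
    n, d = x.shape
    d_head = d // h

    def split_heads(t):
        # (n, d) -> (h, n, d_head): each head gets its own slice of features
        return t.reshape(n, h, d_head).transpose(1, 0, 2)

    Q, K, V = split_heads(x @ W_Q), split_heads(x @ W_K), split_heads(x @ W_V)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (h, n, n)
    weights = softmax(scores)                             # one pattern per head
    heads = weights @ V                                   # (h, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d)       # concatenate heads
    return concat @ W_O, weights                          # final linear map

rng = np.random.default_rng(0)
n, d, h = 6, 512, 8
x = rng.normal(size=(n, d))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(4))

out, weights = multi_head_attention(x, W_Q, W_K, W_V, W_O, h)
print(out.shape, weights.shape)  # (6, 512) (8, 6, 6)
```

Each of the 8 slices of `weights` is an independent $6 \times 6$ attention pattern computed in a 64-dimensional subspace.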

## Layer Normalization: Numerical Stability

Deep networks suffer from numerical instability—activations can grow explosively large or vanish to zero as they propagate through layers. **Layer normalization** stabilizes training by rescaling activations to have zero mean and unit variance.

```{figure} https://miro.medium.com/v2/resize:fit:1400/0*Agdt1zYwfUxXMJGJ
:name: transformer-layer-normalization
:alt: Layer Normalization
:width: 80%
:align: center

Layer normalization computes mean and standard deviation across all features for each sample, then normalizes.
```

For each input vector $x$, layer normalization computes:

$$
\text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sigma} + \beta
$$

where $\mu$ and $\sigma$ are the mean and standard deviation of $x$, and $\gamma$ and $\beta$ are learned scaling and shifting parameters (initialized to 1 and 0 respectively). This ensures that no matter how the input distribution shifts during training, each layer receives inputs in a stable numerical range.
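A minimal sketch of this computation follows. The small constant `eps` added to the variance is an assumption not shown in the formula; implementations commonly include one to avoid dividing by zero. The input values are random and purely illustrative.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each row of x across its features, then scale and shift."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=10.0, size=(4, 16))  # badly scaled activations
gamma, beta = np.ones(16), np.zeros(16)            # initial values of the learned parameters

y = layer_norm(x, gamma, beta)
print(y.mean(axis=-1))  # approximately 0 for every row
print(y.std(axis=-1))   # approximately 1 for every row
```

Note that the statistics are computed per token vector (`axis=-1`), not across the batch or across positions.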

## The Encoder Block

Now we wire the components together. The **encoder block** processes the input sequence through four stages:

1. **Multi-head self-attention** computes contextualized representations
2. **Residual connection + normalization** stabilizes training
3. **Feed-forward network** applies nonlinear transformation
4. **Residual connection + normalization** again
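The four stages can be sketched as a single function. To keep it short, this sketch makes several simplifying assumptions: one attention head with $d_k = d_v = d$, layer normalization with $\gamma = 1$ and $\beta = 0$, a two-layer ReLU network as the feed-forward stage, and random rather than learned weights. The parameter dictionary `p` and its key names are hypothetical.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    # gamma = 1 and beta = 0 are assumed for brevity
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def self_attention(x, W_Q, W_K, W_V):
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

def feed_forward(x, W_1, b_1, W_2, b_2):
    return np.maximum(0.0, x @ W_1 + b_1) @ W_2 + b_2

def encoder_block(x, p):
    # Stages 1-2: self-attention, then residual connection + normalization
    x = layer_norm(x + self_attention(x, p["W_Q"], p["W_K"], p["W_V"]))
    # Stages 3-4: feed-forward network, then residual connection + normalization
    x = layer_norm(x + feed_forward(x, p["W_1"], p["b_1"], p["W_2"], p["b_2"]))
    return x

rng = np.random.default_rng(0)
n, d, d_ff = 5, 16, 64  # illustrative sizes
p = {
    "W_Q": rng.normal(size=(d, d)) / np.sqrt(d),
    "W_K": rng.normal(size=(d, d)) / np.sqrt(d),
    "W_V": rng.normal(size=(d, d)) / np.sqrt(d),
    "W_1": rng.normal(size=(d, d_ff)) / np.sqrt(d),
    "b_1": np.zeros(d_ff),
    "W_2": rng.normal(size=(d_ff, d)) / np.sqrt(d_ff),
    "b_2": np.zeros(d),
}
x = rng.normal(size=(n, d))
y = encoder_block(x, p)
print(y.shape)  # (5, 16)
```

Because the output has the same shape as the input, blocks like this can be stacked by feeding `y` into the next block.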

```{figure} ../figs/transformer-encoder.jpg
:name: transformer-block
:alt: Transformer Block
:width: 50%
:align: center

Information flows through multi-head attention, normalization, feed-forward networks, and final normalization.
```

The feed-forward network is a simple two-layer MLP applied independently to each position:

$$
\text{FFN}(x) = \text{ReLU}(x W_1 + b_1) W_2 + b_2
$$
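"Applied independently to each position" means the network never mixes information between tokens; only attention does that. The short check below illustrates this with random weights and arbitrary sizes, both of which are assumptions for the sake of the example.

```python
import numpy as np

def feed_forward(x, W_1, b_1, W_2, b_2):
    """FFN(x) = ReLU(x W_1 + b_1) W_2 + b_2"""
    hidden = np.maximum(0.0, x @ W_1 + b_1)
    return hidden @ W_2 + b_2

rng = np.random.default_rng(0)
n, d, d_ff = 4, 8, 32                     # illustrative sizes
W_1, b_1 = rng.normal(size=(d, d_ff)), np.zeros(d_ff)
W_2, b_2 = rng.normal(size=(d_ff, d)), np.zeros(d)
x = rng.normal(size=(n, d))

whole = feed_forward(x, W_1, b_1, W_2, b_2)
# Processing each token vector on its own gives the same result
row_by_row = np.stack([feed_forward(row, W_1, b_1, W_2, b_2) for row in x])
print(np.allclose(whole, row_by_row))     # True
```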

The **residual connections** (also called skip connections) are critical for training deep networks. Instead of learning a direct mapping $f(x)$, we learn the residual:

$$
x_{\text{out}} = x_{\text{in}} + f(x_{\text{in}})
$$

This simple addition has profound consequences for gradient flow. During backpropagation, the gradient of the loss $\mathcal{L}$ with respect to layer $l$ is:

$$
\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_{l+1}} \left(1 + \frac{\partial f_l}{\partial x_l}\right)
$$

Notice the "+1" term—this provides a direct gradient path from the output back to the input. Without residual connections, gradients must pass through the chain:

$$
\frac{\partial f_L}{\partial f_{L-1}} \cdot \frac{\partial f_{L-1}}{\partial f_{L-2}} \cdot \ldots \cdot \frac{\partial f_1}{\partial x}
$$

If most of these terms have magnitude less than 1, their product shrinks exponentially with depth. This is the **vanishing gradient problem**. With residual connections, the gradient expansion becomes:

$$
1 + O_1 + O_2 + O_3 + \ldots
$$

where $O_1$ contains first-order terms, $O_2$ contains second-order products, etc. The constant "1" ensures gradients can flow even when the learned components $f_i$ produce small derivatives. This architectural innovation, originally developed for computer vision {footcite:p}`he2015deep`, is what allows transformers to scale to hundreds of layers.
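A scalar toy calculation makes the contrast concrete. It assumes a hypothetical 50-layer network in which every layer has the same small local derivative $\partial f_l / \partial x_l = 0.01$; real networks have matrix-valued Jacobians, so treat this only as an illustration of the "+1" effect.

```python
import numpy as np

L = 50
local_grads = np.full(L, 0.01)         # assumed df_l/dx_l for every layer

plain = np.prod(local_grads)           # chain of f' terms, no skip connections
residual = np.prod(1.0 + local_grads)  # chain of (1 + f') terms

print(f"without residuals: {plain:.1e}")
print(f"with residuals:    {residual:.2f}")
```

Without residual connections the product is on the order of $10^{-100}$, effectively zero. With them, the factor stays of order one (about 1.64 here), so the early layers still receive a usable gradient signal.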

## The Decoder Block

The **decoder block** extends the encoder with two modifications: **masked self-attention** and **cross-attention**.

```{figure} ../figs/transformer-decoder.jpg
:name: transformer-decoder
:alt: Transformer Decoder
:width: 50%
:align: center

The decoder adds masked self-attention (to prevent future peeking) and cross-attention (to access encoder outputs).
```

### Masked Self-Attention: Preventing Future Leakage

During training, we know the entire target sequence. For translation ("I love you" → "Je t'aime"), we have both input and output. A naive decoder could "cheat" by looking at future words in the target sequence. Masked self-attention prevents this by zeroing out attention to future positions.

The mask is implemented by setting attention scores to $-\infty$ before the softmax:

$$
\text{MaskedAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^T + M}{\sqrt{d_k}}\right)V
$$

where $M$ is a matrix with $-\infty$ at positions $(i,j)$ where $j > i$ (future positions) and 0 elsewhere. After softmax, these $-\infty$ values become zero, eliminating information flow from future tokens.
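The mask is easy to build and inspect directly. The sketch below follows the formula above; the function name `masked_attention`, the sizes, and the random inputs are illustrative assumptions.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def masked_attention(Q, K, V):
    n, d_k = Q.shape
    # M[i, j] = -inf where j > i (future positions), 0 elsewhere
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    M = np.where(future, -np.inf, 0.0)
    scores = (Q @ K.T + M) / np.sqrt(d_k)
    weights = softmax(scores)          # exp(-inf) = 0, so future weights vanish
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d_k = 4, 8                          # illustrative sizes
Q, K, V = (rng.normal(size=(n, d_k)) for _ in range(3))

out, weights = masked_attention(Q, K, V)
print(np.round(weights, 2))            # entries above the diagonal are all zero
```

The first row has a single nonzero entry (the first token can only attend to itself), and each later row spreads its weight over one more position.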

```{figure} https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe1317a05-3542-4158-94bf-085109a5793a_1220x702.png
:name: transformer-masked-attention
:alt: Masked Attention
:width: 80%
:align: center

Masked attention zeros out future positions, allowing parallel training without information leakage.
```

This enables **parallel training**. Instead of generating "Je", then "t'aime", then the final token sequentially, we can train all positions simultaneously—each with access only to its causal past. During inference, masking happens naturally because future tokens don't exist yet.

### Cross-Attention: Connecting Encoder and Decoder

The second attention layer in the decoder uses **cross-attention** to access the encoder's output. The queries ($Q$) come from the decoder's previous layer, while the keys ($K$) and values ($V$) come from the encoder's output:

$$
\text{CrossAttention}(Q_{\text{decoder}}, K_{\text{encoder}}, V_{\text{encoder}}) = \text{softmax}\left(\frac{Q_{\text{decoder}}K_{\text{encoder}}^T}{\sqrt{d_k}}\right)V_{\text{encoder}}
$$
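One practical consequence of this formula is that the two sequences may have different lengths: the attention matrix has one row per decoder position and one column per encoder position. The sketch below shows the shapes involved; the function name `cross_attention`, the sizes, and the random states and weights are assumptions for illustration.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cross_attention(decoder_states, encoder_states, W_Q, W_K, W_V):
    Q = decoder_states @ W_Q                            # (n_tgt, d_k) from the decoder
    K = encoder_states @ W_K                            # (n_src, d_k) from the encoder
    V = encoder_states @ W_V                            # (n_src, d_v) from the encoder
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))   # (n_tgt, n_src)
    return weights @ V, weights

rng = np.random.default_rng(0)
n_src, n_tgt, d = 3, 2, 8   # e.g. 3 source tokens, 2 target tokens
encoder_states = rng.normal(size=(n_src, d))
decoder_states = rng.normal(size=(n_tgt, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

out, weights = cross_attention(decoder_states, encoder_states, W_Q, W_K, W_V)
print(out.shape, weights.shape)  # (2, 8) (2, 3)
```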

```{figure} ../figs/transformer-cross-attention.jpg
:name: transformer-cross-attention
:alt: Cross-Attention
:width: 60%
:align: center

Cross-attention allows the decoder to query the encoder's representation.
```

This is how translation works: when generating "Je", the decoder attends to "I"; when generating "t'aime", it attends to "love". The attention mechanism learns these alignments automatically from data, without explicit supervision.

## Position Embedding: Encoding Order

Self-attention on its own is **permutation equivariant**: shuffle the input tokens and the output vectors are shuffled in exactly the same way, but each token's output is otherwise unchanged. In "The cat sat on the mat" and "mat the on sat cat the," the word "cat" receives the identical output vector because the dot product doesn't encode position. We need to inject positional information.
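A quick numerical check shows the problem. With random weights and made-up embeddings (both assumptions for illustration), shuffling the input rows simply shuffles the output rows in the same way and changes nothing else, so word order carries no information for the attention layer.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x, W_Q, W_K, W_V):
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

rng = np.random.default_rng(0)
n, d = 6, 8                                   # illustrative sizes
x = rng.normal(size=(n, d))                   # one row per token, no position info
W_Q, W_K, W_V = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

perm = rng.permutation(n)                     # a random reordering of the tokens
out = self_attention(x, W_Q, W_K, W_V)
out_shuffled = self_attention(x[perm], W_Q, W_K, W_V)

# Each token gets the same vector as before, just in its new row
print(np.allclose(out_shuffled, out[perm]))   # True
```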

The naive approach is to add a position index: $x_t := x_t + \beta t$. This fails for two reasons:

1. **Unbounded**: Position indices grow arbitrarily large. Models trained on sequences of length 512 fail on sequences of length 1000 because they've never seen position 513.
2. **Discrete**: Positions 10 and 11 are no more similar than positions 10 and 100.

A better approach is **binary position encoding**. Represent position $t$ as a binary vector:

$$
\begin{align*}
  0: \ \ \ \ \texttt{0} \ \ \texttt{0} \ \ \texttt{0} \ \ \texttt{0} & \quad &
  8: \ \ \ \ \texttt{1} \ \ \texttt{0} \ \ \texttt{0} \ \ \texttt{0} \\
  1: \ \ \ \ \texttt{0} \ \ \texttt{0} \ \ \texttt{0} \ \ \texttt{1} & &
  9: \ \ \ \ \texttt{1} \ \ \texttt{0} \ \ \texttt{0} \ \ \texttt{1} \\
  2: \ \ \ \ \texttt{0} \ \ \texttt{0} \ \ \texttt{1} \ \ \texttt{0} & &
  10: \ \ \ \ \texttt{1} \ \ \texttt{0} \ \ \texttt{1} \ \ \texttt{0}
\end{align*}
$$

This is bounded (every entry is 0 or 1, and larger positions are handled by adding more bits) but still discrete. The transformer solution is **sinusoidal position embedding**, a continuous version of binary encoding:

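$$
\text{Pos}(t, i) =
\begin{cases}
\sin\left(\dfrac{t}{10000^{i/d}}\right), & \text{if } i \text{ is even} \\
\cos\left(\dfrac{t}{10000^{(i-1)/d}}\right), & \text{if } i \text{ is odd}
\end{cases}
$$

where $t$ is the position index, $i$ is the dimension index, and $d$ is the embedding dimension. Each even dimension $i$ and the odd dimension $i+1$ that follows it share the same frequency, forming a sine/cosine pair. This encoding has three critical properties: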

1. **Continuous**: Smooth interpolation between positions
2. **Bounded**: All values lie in $[-1, 1]$
3. **Relative distance preservation**: The dot product $\text{Pos}(t) \cdot \text{Pos}(t+k)$ depends only on the offset $k$, not the absolute position $t$
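The second and third properties can be checked numerically. The sketch below assumes the pairing used in the original transformer paper, where each sine/cosine pair of dimensions shares one frequency; the function name `sinusoidal_positions` and the sizes are illustrative choices.

```python
import numpy as np

def sinusoidal_positions(n_positions, d):
    """Return an (n_positions, d) matrix of position embeddings; d must be even."""
    t = np.arange(n_positions)[:, None]        # positions, shape (n_positions, 1)
    k = np.arange(d // 2)[None, :]             # pair index, shape (1, d/2)
    angles = t / (10000.0 ** (2 * k / d))      # one frequency per pair
    pos = np.empty((n_positions, d))
    pos[:, 0::2] = np.sin(angles)              # even dimensions
    pos[:, 1::2] = np.cos(angles)              # odd dimensions
    return pos

pos = sinusoidal_positions(n_positions=200, d=64)

# Bounded: every entry lies in [-1, 1]
print(pos.min() >= -1.0, pos.max() <= 1.0)     # True True

# Relative distance: the dot product depends on the offset k, not on t
k = 5
print(np.isclose(pos[10] @ pos[10 + k], pos[100] @ pos[100 + k]))  # True
```

The offset property follows from the identity $\sin a \sin b + \cos a \cos b = \cos(a - b)$: each sine/cosine pair contributes $\cos(\omega k)$ to the dot product, and the position $t$ cancels out.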

```{figure} https://kazemnejad.com/img/transformer_architecture_positional_encoding/positional_encoding.png
:name: transformer-position-embedding
:alt: Transformer Position Embedding
:width: 80%
:align: center

Sinusoidal position embeddings exhibit periodic patterns across dimensions. Image from [Amirhossein Kazemnejad](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/).
```

Notice the alternating pattern—just like binary encoding, but continuous. Low-frequency dimensions (right) flip slowly across positions; high-frequency dimensions (left) flip rapidly. This creates a unique fingerprint for each position while preserving distance relationships.

```{figure} https://kazemnejad.com/img/transformer_architecture_positional_encoding/time-steps_dot_product.png
:name: transformer-position-embedding-similarity
:alt: Transformer Position Embedding Similarity
:width: 80%
:align: center

Dot product between position embeddings depends only on relative distance, not absolute position. Image from [Amirhossein Kazemnejad](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/).
```

The position embedding is added directly to the token embedding: $x_{t,i} := x_{t,i} + \text{Pos}(t, i)$. Why addition instead of concatenation? Concatenation would increase the model dimension, adding parameters. Addition creates an interesting interaction in the attention mechanism—queries and keys now encode both content and position, allowing the model to attend based on both "what" (semantic similarity) and "where" (positional proximity).

## The Takeaway

Transformers replaced sequential computation with parallel relationship mapping. Every position simultaneously computes its context from every other position. This architectural shift—from recurrent bottlenecks to parallel attention—is what allowed language models to scale from millions to hundreds of billions of parameters. The mechanism is simple: query, key, value. The result is GPT-4.

```{footbibliography}
:style: unsrt
:filter: docname in docnames
```

<script src="https://cdn.jsdelivr.net/npm/@marimo-team/marimo-snippets@1"></script>