# Demystifying Transformers

Transformers are taking over machine learning. Since they were first described in 2017 (Vaswani et al., 2017), Transformers have come to define the entire NLP category and are now spreading to image processing, reinforcement learning, and beyond. Anyone who wants to contribute to the future of AI would do well to learn the Transformer architecture inside and out. But this is easier said than done. While the Internet is full of papers, blog posts, and tutorials on Transformers, the sheer volume of material can be overwhelming.

One key challenge I found when first studying Transformers was distilling the abstract ideas in the research literature into concrete, actionable steps I could experiment with. I wanted to "see the code" so to speak. While it's easy to find open source Transformer implementations, I found they are often overloaded with configuration knobs and dials, conditionalized this way and that to support every Transformer variation ever imagined. All of this reusability likely adds value for the Transformer experts who wrote them, but it makes it difficult to see the big picture anymore.

My goal in this post is to help you get started with Transformers by walking through the core elements of the Transformer architecture step-by-step, connecting the abstract ideas from the literature to concrete lines of code without heaps of configuration logic piled on top.

# Background

Unlike many machine learning models that process one input at a time, Transformers are *sequence models*. This means they process multiple inputs represented as an ordered list or *sequence*. This is a big deal. Not only can sequence models learn about the content of each input, they can also learn about relationships between inputs. This makes sequence models well suited to a wide range of tasks that involve context, such as processing text, video, and time series.

Transformers are not the only sequence models out there. Transformers were introduced in 2017 as an improvement over earlier recurrent and convolutional approaches. These earlier approaches were already using "attention" mechanisms. The key innovation in the Transformer was realizing that the combination of attention and feedforward networks was powerful enough that the extra complexity of the RNN and CNN architectures was no longer needed. Not only did the Transformer perform better than previous approaches, the simpler neural network architecture proved to be faster and easier to train as well (Vaswani et al., 2017).

# Setup

In [1]:
from math import sqrt
import warnings

import numpy as np
import torch
from torch.nn.functional import relu, softmax
import transformers


In [2]:
# Ignore all warnings
warnings.filterwarnings("ignore")

# Configure gpu
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Transformer Pipelines

The following diagram depicts a Transformer as a multi-stage pipeline. The center of the Transformer contains a stack of transformer layers that are responsible for mapping input embeddings to output embeddings. This is where most of the magic happens. The stages before and after the Transformer Stack provide extra machinery required first to transform raw data into input embeddings and then output embeddings into task-specific outputs.

<center><img src="img/transformer-pipeline.svg" width="500"></center> 

# Text Classifier

We'll use a basic text classification pipeline to illustrate the process. The following cells use Hugging Face's transformers package to create an end-to-end transformer pipeline that classifies text as either positive or negative. There is a lot happening in very few lines of code. Over the rest of this post, we'll break the pipeline down and walk through each step from "I love ice cream" to the POSITIVE classification.

In [3]:
# Create off-the-shelf text classification transformer
transformer = transformers.pipeline("text-classification", device=device)

No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.


In [4]:
transformer("I love ice cream")

[{'label': 'POSITIVE', 'score': 0.9998118281364441}]

In [5]:
transformer("I hate ice cream")

[{'label': 'NEGATIVE', 'score': 0.9974052309989929}]

In [6]:
# Save config params for later
config = transformer.model.distilbert.config
vocab_size = config.vocab_size
d_model = config.dim
n_heads = config.n_heads
d_head = d_model // n_heads  # per-head dimension (integer division)
n_layers = config.n_layers
max_sequence_length = config.max_position_embeddings
batch_size = 1

# Pre Process

The Pre Process stage of a transformer is responsible for transforming raw input data into a sequence of categorical values. For NLP tasks, this is accomplished using a `Tokenizer` that parses raw text into tokens and maps the tokens to integer-encoded categorical values using a fixed vocabulary. 

In [7]:
# Lookup tokenizer in transformer
tokenizer = transformer.tokenizer

In [8]:
# Convert "I love ice cream" into input values
batch = tokenizer("I love ice cream", return_tensors="pt").to(device)
batch

{'input_ids': tensor([[ 101, 1045, 2293, 3256, 6949,  102]], device='mps:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]], device='mps:0')}

In [9]:
# Decode input values so we can see what they represent 
[tokenizer.decode(input_id) for input_id in batch["input_ids"][0]]

['[CLS]', 'i', 'love', 'ice', 'cream', '[SEP]']

Here we can see the tokenizer parsed "I love ice cream" into 6 input values: 101, 1045, 2293, 3256, 6949, 102. There is one value for each word, as well as [CLS] and [SEP] markers to flag the beginning and end of the original sequence.
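To make the mapping from text to integers concrete, here is a minimal sketch that reproduces the ids above with a hand-built, six-entry vocabulary. The names `toy_vocab` and `toy_tokenize` are hypothetical and exist only for this illustration. The real tokenizer works over a much larger vocabulary and splits unfamiliar words into subword pieces (WordPiece), which this sketch does not attempt.

```python
# Toy vocabulary: the ids are copied from the tokenizer output above.
toy_vocab = {"[CLS]": 101, "i": 1045, "love": 2293, "ice": 3256, "cream": 6949, "[SEP]": 102}

def toy_tokenize(text):
    """Lowercase, split on whitespace, and wrap the words in [CLS] ... [SEP]."""
    tokens = ["[CLS]"] + text.lower().split() + ["[SEP]"]
    # A KeyError here means the word is outside the toy vocabulary.
    return [toy_vocab[token] for token in tokens]

toy_ids = toy_tokenize("I love ice cream")
print(toy_ids)  # [101, 1045, 2293, 3256, 6949, 102]
```

The point is only that, after tokenization, the model never sees characters: it sees a list of integer indices into a fixed vocabulary.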

# Embeddings

The Embeddings stage of a transformer transforms the input sequence into embeddings. This is typically done using two lookup tables. The first table maps the input values to embeddings. The second table maps the input positions to embeddings. The value and position embeddings are then added together to create position-encoded input embeddings.

<center><img src="img/embeddings.svg" width="700"></center> 

Both the value and position embeddings are initialized randomly and then learned during training. This means the embeddings are transformer-specific. You can't take embeddings from one transformer and use them in another without retraining.
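Before looking at the real tables, it may help to see the two-lookup-and-add idea at toy scale. The following is a minimal sketch in plain Python, assuming a made-up vocabulary of 10 ids, a maximum length of 8, and 4-element embeddings. All names (`value_table`, `position_table`, `embed`) and the random initialization scale are hypothetical choices for illustration, not values taken from the model.

```python
import random

rng = random.Random(0)
toy_vocab_size, toy_max_len, toy_d_model = 10, 8, 4

def random_table(rows, cols):
    """Randomly initialized lookup table (the scale 0.02 is an arbitrary choice)."""
    return [[rng.gauss(0.0, 0.02) for _ in range(cols)] for _ in range(rows)]

value_table = random_table(toy_vocab_size, toy_d_model)    # one row per vocabulary id
position_table = random_table(toy_max_len, toy_d_model)    # one row per position

def embed(ids):
    """Look up a value row and a position row for each input, then add them."""
    return [
        [v + p for v, p in zip(value_table[token_id], position_table[position])]
        for position, token_id in enumerate(ids)
    ]

toy_embeddings = embed([3, 7, 3])
print(len(toy_embeddings), len(toy_embeddings[0]))  # 3 4
print(toy_embeddings[0] != toy_embeddings[2])       # True
```

Note that id 3 appears at positions 0 and 2 and ends up with two different vectors. That is the whole purpose of the position table: the same token is represented differently depending on where it sits in the sequence.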

In [10]:
# Lookup embeddings in transformer
value_embeddings = transformer.model.distilbert.embeddings.word_embeddings
position_embeddings = transformer.model.distilbert.embeddings.position_embeddings
normalize = transformer.model.distilbert.embeddings.LayerNorm

The input value embedding layer maps the integer-encoded categorical input values to unique embedding vectors. In the following cell, you can see `value_embeddings` has 30,522 embedding vectors of 768 elements each. 30,522 is the transformer's vocabulary size (`vocab_size`). 768 is the transformer's model dimension (`d_model`).

In [11]:
value_embeddings

Embedding(30522, 768, padding_idx=0)

The input position embedding layer maps the position of each input value to an embedding vector. In the following cell, you can see `position_embeddings` has 512 embedding vectors of 768 elements each. 512 is the maximum sequence length the transformer can process (`max_sequence_length`).

In [12]:
position_embeddings

Embedding(512, 768)

Next, let's convert the `input_ids` from the Pre Process stage into position-encoded input embeddings.

In [13]:
input_ids = batch["input_ids"]
input_ids

tensor([[ 101, 1045, 2293, 3256, 6949,  102]], device='mps:0')

In [14]:
# Lookup value embeddings for each input
v = value_embeddings(input_ids)
v.shape

torch.Size([1, 6, 768])

In [15]:
v

tensor([[[ 3.9925e-02, -1.0171e-02, -2.0390e-02,  ...,  6.1588e-02,
           2.1959e-02,  2.2732e-02],
         [-1.2794e-02,  4.9879e-03, -2.6270e-02,  ..., -7.2300e-05,
           5.3657e-03,  1.1908e-02],
         [ 5.9359e-02, -2.3563e-02, -2.0560e-03,  ..., -1.0420e-02,
           1.4846e-02, -1.2815e-02],
         [-2.4101e-02, -2.4911e-02, -2.2601e-02,  ..., -2.5139e-02,
           1.1392e-02,  3.2655e-02],
         [-8.5466e-02, -5.9276e-02, -5.6659e-02,  ..., -1.7192e-02,
          -8.6179e-02, -4.5105e-02],
         [-2.1060e-02, -6.4941e-03, -1.0682e-02,  ..., -2.3401e-02,
           6.1463e-03, -6.4845e-03]]], device='mps:0',
       grad_fn=<EmbeddingBackward0>)

In [16]:
position_ids = torch.arange(input_ids.size(1)).expand(1, -1).to(device)
position_ids

tensor([[0, 1, 2, 3, 4, 5]], device='mps:0')

In [17]:
# Lookup position embeddings for each input
p = position_embeddings(position_ids)
p.shape

torch.Size([1, 6, 768])

In [18]:
p

tensor([[[ 1.8007e-02, -2.3798e-02, -3.5982e-02,  ...,  4.5726e-04,
           5.1363e-05,  1.5002e-02],
         [ 7.8592e-03,  4.8144e-03, -1.6093e-02,  ...,  2.9312e-02,
           2.7634e-02, -8.5431e-03],
         [-1.1663e-02, -3.1590e-03, -9.4000e-03,  ...,  1.4870e-02,
           2.1609e-02, -7.4069e-03],
         [-4.0848e-03, -1.1123e-02, -2.1704e-02,  ...,  1.8962e-02,
           4.6763e-03, -1.0220e-03],
         [-8.2666e-03, -4.1641e-03, -7.5136e-03,  ...,  1.9757e-02,
          -2.2192e-03,  3.8681e-03],
         [ 4.6293e-04, -1.8499e-02, -1.9709e-02,  ...,  5.4042e-03,
           1.8076e-02,  2.9490e-03]]], device='mps:0',
       grad_fn=<EmbeddingBackward0>)

In [19]:
# Combine value and position embeddings and normalize
input_embeddings = normalize(v + p)
input_embeddings.shape

torch.Size([1, 6, 768])

In [20]:
input_embeddings

tensor([[[ 0.3549, -0.1386, -0.2253,  ...,  0.1536,  0.0748,  0.1310],
         [ 0.2282,  0.5511, -0.5092,  ...,  0.6421,  0.9541,  0.3192],
         [ 1.4511, -0.0794,  0.2168,  ...,  0.2851,  1.0723, -0.0919],
         [-0.0564, -0.1761, -0.2870,  ...,  0.1442,  0.6767,  1.0396],
         [-1.1349, -0.5135, -0.4714,  ...,  0.3874, -1.0348, -0.2812],
         [-0.2980, -0.3332, -0.3742,  ..., -0.3392,  0.3764, -0.1298]]],
       device='mps:0', grad_fn=<NativeLayerNormBackward0>)

At this point, `input_embeddings` contains a position-encoded input embedding vector for each value in the input sequence.
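The `normalize` step used above is a `LayerNorm` module, which we applied without looking inside. As a minimal sketch, assuming the standard layer normalization definition: each embedding vector is shifted and scaled to zero mean and unit variance across its own features, then a learned per-feature scale and shift are applied. The function name `layer_norm` and the default scale of 1 and shift of 0 are simplifications for illustration; the model's layer has learned values.

```python
from math import sqrt

def layer_norm(vector, gamma=None, beta=None, eps=1e-12):
    """Normalize one embedding vector across its features."""
    n = len(vector)
    gamma = gamma or [1.0] * n   # learned scale (assumed to be 1 in this sketch)
    beta = beta or [0.0] * n     # learned shift (assumed to be 0 in this sketch)
    mean = sum(vector) / n
    variance = sum((x - mean) ** 2 for x in vector) / n   # biased variance
    return [g * (x - mean) / sqrt(variance + eps) + b
            for x, g, b in zip(vector, gamma, beta)]

normalized = layer_norm([1.0, 2.0, 3.0, 4.0])
print([round(x, 4) for x in normalized])  # [-1.3416, -0.4472, 0.4472, 1.3416]
```

Each position in the sequence is normalized on its own, so the result does not depend on the other tokens or on the batch.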

# Transformer Stack

The Transformer Stack is where the "magic" happens. The stack is responsible for transforming position-encoded input embeddings into contextualized output embeddings through layers of attention and feed-forward networks. Transformer architectures such as BERT, BART, and GPT differ largely in the number of layers, the size of each layer, and the type of attention used.

<center><img src="img/transformer-layer.svg" width="500"></center>

Let's take a look at one of the layers in our text classification transformer.

In [21]:
layer = transformer.model.distilbert.transformer.layer[0]

# Extract building blocks from layer
query_projection = layer.attention.q_lin
key_projection = layer.attention.k_lin
value_projection = layer.attention.v_lin
output_projection = layer.attention.out_lin
normalize_attention = layer.sa_layer_norm
ffn = layer.ffn
normalize_ffn = layer.output_layer_norm

layer

TransformerBlock(
  (attention): MultiHeadSelfAttention(
    (dropout): Dropout(p=0.1, inplace=False)
    (q_lin): Linear(in_features=768, out_features=768, bias=True)
    (k_lin): Linear(in_features=768, out_features=768, bias=True)
    (v_lin): Linear(in_features=768, out_features=768, bias=True)
    (out_lin): Linear(in_features=768, out_features=768, bias=True)
  )
  (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (ffn): FFN(
    (dropout): Dropout(p=0.1, inplace=False)
    (lin1): Linear(in_features=768, out_features=3072, bias=True)
    (lin2): Linear(in_features=3072, out_features=768, bias=True)
    (activation): GELUActivation()
  )
  (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)

## Multi-Head Self Attention

You can see in the cell above that the transformer is using Multi-Head Self Attention. This is the standard attention algorithm described in the original "Attention Is All You Need" paper (Vaswani et al., 2017). While many transformers still use the original algorithm, others are experimenting with new forms of attention (Dubey et al., 2024).

In [22]:
def split_heads(x):
    return x.view(batch_size, -1, n_heads, d_head).transpose(1, 2)

def combine_heads(x):
    return x.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_head)

In [23]:
# Project input_embeddings into "query space"
q = query_projection(input_embeddings)

# Project input_embeddings into "key space"
k = key_projection(input_embeddings)

# Project input_embeddings into "value space"
v = value_projection(input_embeddings)

q.shape, k.shape, v.shape

(torch.Size([1, 6, 768]), torch.Size([1, 6, 768]), torch.Size([1, 6, 768]))

In [24]:
# Split q, k, v into multiple heads
q = split_heads(q)
k = split_heads(k)
v = split_heads(v)

q.shape, k.shape, v.shape

(torch.Size([1, 12, 6, 64]),
 torch.Size([1, 12, 6, 64]),
 torch.Size([1, 12, 6, 64]))
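The jump from `[1, 6, 768]` to `[1, 12, 6, 64]` can be hard to picture. Here is a minimal plain-Python sketch of the same reshaping, assuming a toy sequence of 2 tokens with 6 features split across 3 heads, and with the batch dimension dropped. The `toy_` functions are hypothetical stand-ins for `split_heads` and `combine_heads`: each head receives a contiguous slice of every token's feature vector.

```python
def toy_split_heads(x, n_heads):
    """[seq_len][d_model] -> [n_heads][seq_len][d_head]."""
    d_head = len(x[0]) // n_heads
    return [[row[h * d_head:(h + 1) * d_head] for row in x] for h in range(n_heads)]

def toy_combine_heads(heads):
    """[n_heads][seq_len][d_head] -> [seq_len][d_model]."""
    seq_len = len(heads[0])
    return [[value for head in heads for value in head[t]] for t in range(seq_len)]

x = [[0, 1, 2, 3, 4, 5],        # token 0
     [10, 11, 12, 13, 14, 15]]  # token 1

heads = toy_split_heads(x, n_heads=3)
print(heads[0])                       # [[0, 1], [10, 11]]
print(heads[2])                       # [[4, 5], [14, 15]]
print(toy_combine_heads(heads) == x)  # True
```

Splitting is just bookkeeping: no values change, and combining the heads restores the original layout. What makes the heads different from one another is that attention is then computed separately within each slice.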

Next, we calculate attention using the standard formula, where $d_k$ is the per-head dimension (`d_head` in the code): $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

In [25]:
# Calculate attention for all heads in parallel
attention = softmax(q @ k.transpose(2, 3) / sqrt(d_head), dim=-1) @ v
attention.shape

torch.Size([1, 12, 6, 64])
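The one-line attention cell above packs several steps together. The following minimal sketch spells them out for a single head in plain Python, using small made-up `Q`, `K`, and `V` matrices (3 tokens, 2 features). The helper names `toy_softmax` and `toy_attention` are hypothetical, and the sketch assumes no masking, as in the cell above.

```python
from math import exp, sqrt

def toy_softmax(scores):
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def toy_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for a single head, without masking."""
    d_k = len(K[0])
    outputs, all_weights = [], []
    for q in Q:
        # 1. Score this query against every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / sqrt(d_k) for k in K]
        # 2. Turn the scores into weights that sum to 1
        weights = toy_softmax(scores)
        all_weights.append(weights)
        # 3. Output is the weighted average of the value vectors
        outputs.append([sum(w * v[c] for w, v in zip(weights, V))
                        for c in range(len(V[0]))])
    return outputs, all_weights

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

outputs, weights = toy_attention(Q, K, V)
for row in weights:
    print([round(w, 3) for w in row], "sum =", round(sum(row), 3))
```

Each row of `weights` says how much one token draws from every token in the sequence, itself included, and each output row is a blend of the value vectors in those proportions. Skipping the attention mask is safe here because we process a single unpadded sequence (the `attention_mask` from the tokenizer is all ones); padded batches would need the mask applied before the softmax.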

In [26]:
# Combine heads back together
attention = combine_heads(attention)
attention.shape

torch.Size([1, 6, 768])

In [27]:
# Project attention back into "model space"
attention_embeddings = output_projection(attention)

In [28]:
# Combine attention and input embeddings
attention_embeddings = normalize_attention(attention_embeddings + input_embeddings)
attention_embeddings.shape

torch.Size([1, 6, 768])

## FFN

In [29]:
# Transform attention embeddings
ffn_embeddings = ffn(attention_embeddings)

In [30]:
# Combine ffn and attention embeddings
ffn_embeddings = normalize_ffn(ffn_embeddings + attention_embeddings)
ffn_embeddings.shape

torch.Size([1, 6, 768])

In [31]:
ffn_embeddings

tensor([[[ 0.2824, -0.0438, -0.1085,  ...,  0.0478, -0.0888, -0.1025],
         [ 0.5897,  0.7285,  0.0869,  ...,  0.1732,  0.5214,  0.4234],
         [ 1.7403,  0.1464,  0.4697,  ...,  0.2500,  0.8521, -0.2792],
         [-0.1238, -0.3677,  0.1768,  ...,  0.8630,  0.8072,  0.3339],
         [-0.3699, -0.2425, -0.5013,  ...,  0.7443, -0.5795, -0.6643],
         [-0.0757, -0.0300, -0.0669,  ..., -0.1338,  0.1781,  0.0099]]],
       device='mps:0', grad_fn=<NativeLayerNormBackward0>)
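We applied the `ffn` module above as a black box. Going by the layer printout, it holds two linear layers (`lin1`, 768 to 3072, and `lin2`, 3072 to 768) and a GELU activation. Here is a minimal sketch in plain Python, assuming the conventional order of `lin1`, then the activation, then `lin2`, applied to each position independently. The sizes and weights are made up for illustration, and dropout is left out on the assumption that the model is in evaluation mode, where dropout does nothing.

```python
from math import erf, sqrt

def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + erf(x / sqrt(2.0)))

def linear(vector, weights, bias):
    """y = W x + b, with one row of `weights` per output feature."""
    return [sum(w * x for w, x in zip(row, vector)) + b
            for row, b in zip(weights, bias)]

def toy_ffn(vector, w1, b1, w2, b2):
    hidden = [gelu(h) for h in linear(vector, w1, b1)]  # expand, then activate
    return linear(hidden, w2, b2)                       # project back down

# Toy sizes: d_model = 2, hidden = 4 (the real layer uses 768 and 3072).
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 1.0]]
b1 = [0.0, 0.0, 0.0, 0.0]
w2 = [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]]
b2 = [0.0, 0.0]

toy_output = toy_ffn([1.0, -1.0], w1, b1, w2, b2)
print(len(toy_output))  # 2
```

Unlike attention, the feed-forward network never mixes information between positions. It transforms each token's embedding on its own, which is why the sketch takes a single vector rather than a whole sequence.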

## N-Layers

Now that we've walked through a single layer step-by-step, let's put the pieces together and apply the entire stack.

In [32]:
embeddings = input_embeddings

for i in range(n_layers):
    layer = transformer.model.distilbert.transformer.layer[i]

    # Extract building blocks from layer
    query_projection = layer.attention.q_lin
    key_projection = layer.attention.k_lin
    value_projection = layer.attention.v_lin
    output_projection = layer.attention.out_lin
    normalize_attention = layer.sa_layer_norm
    ffn = layer.ffn
    normalize_ffn = layer.output_layer_norm

    # Project embeddings into "query space"
    q = query_projection(embeddings)
    
    # Project embeddings into "key space"
    k = key_projection(embeddings)
    
    # Project embeddings into "value space"
    v = value_projection(embeddings)

    # Split q, k, v into multiple heads
    q = split_heads(q)
    k = split_heads(k)
    v = split_heads(v)

    # Calculate attention for all heads in parallel
    attention = softmax(q @ k.transpose(2, 3) / sqrt(d_head), dim=-1) @ v

    # Combine heads back together
    attention = combine_heads(attention)
    
    # Project attention back into "model space"
    attention_embeddings = output_projection(attention)
    
    # Combine attention and embeddings
    attention_embeddings = normalize_attention(attention_embeddings + embeddings)
    
    # Transform attention embeddings
    ffn_embeddings = ffn(attention_embeddings)
    
    # Combine ffn and attention embeddings
    ffn_embeddings = normalize_ffn(ffn_embeddings + attention_embeddings)

    # Rinse and repeat
    embeddings = ffn_embeddings
    
output_embeddings = embeddings
output_embeddings.shape

torch.Size([1, 6, 768])

In [33]:
output_embeddings

tensor([[[ 3.6173e-01, -1.3168e-01,  3.5342e-02,  ...,  4.4015e-01,
           1.0666e+00, -1.9293e-01],
         [ 7.3341e-01,  4.9823e-02, -1.7590e-02,  ...,  5.0063e-01,
           1.1480e+00, -1.2997e-01],
         [ 1.1230e+00,  2.7603e-01,  3.2096e-01,  ...,  1.8820e-01,
           1.0586e+00, -1.2496e-01],
         [ 4.8728e-01,  1.4863e-02,  4.2930e-01,  ...,  4.8993e-01,
           7.9436e-01,  1.2331e-01],
         [ 1.0596e-03, -1.4508e-01,  2.8892e-01,  ...,  5.5342e-01,
           7.9370e-01, -9.0898e-02],
         [ 1.1021e+00,  8.6115e-02,  5.7461e-01,  ...,  6.8800e-01,
           5.6345e-01, -6.6278e-01]]], device='mps:0',
       grad_fn=<NativeLayerNormBackward0>)

# Head

The Head stage of a transformer takes the contextualized embeddings and applies a task-specific transformation to reach the desired output. In our case, we're running a binary text classifier so the goal of the Head stage is to convert contextualized embeddings into a binary prediction.

In [34]:
# Lookup head classifier in transformer
pre_classifier = transformer.model.pre_classifier
classifier = transformer.model.classifier

The head classifier is a standard feed-forward network (FFN). But what do we pass into it? The FFN expects one set of features to predict on, and we have a sequence of feature vectors.

Since the output embeddings are "contextualized", it's common practice to use the first embedding to represent the entire sequence and drop the rest. Note that you couldn't do that with the input embeddings because they represented each input value independently. It's only after the transformer has added context from the rest of the sequence that the first embedding can meaningfully represent the rest.

In [35]:
# Represent sequence with contextualized embedding for start marker [CLS]
embedding = output_embeddings[:, 0]
embedding.shape

torch.Size([1, 768])

In [36]:
# Transform embedding
embedding = relu(pre_classifier(embedding))
embedding.shape

torch.Size([1, 768])

In [37]:
# Classify embedding
logits = classifier(embedding)
logits.shape

torch.Size([1, 2])

In [38]:
logits

tensor([[-4.1625,  4.4154]], device='mps:0', grad_fn=<LinearBackward0>)

# Post Process

The final stage in the Transformer is to convert the predicted logits into probabilities. In the cells below, you can see we get the same prediction that we started with when running the high-level pipeline API. This should give you confidence that the steps we walked through covered everything from raw text to positive prediction.

In [39]:
# Convert logits into probabilities
scores = softmax(logits[0], dim=-1)  # pass dim explicitly; the implicit form is deprecated
scores

tensor([1.8818e-04, 9.9981e-01], device='mps:0', grad_fn=<SoftmaxBackward0>)
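The conversion from logits to probabilities is small enough to check by hand. The following minimal sketch applies the softmax definition in plain Python to the two logits printed in the Head stage, copied here at the four decimal places shown. The helper name `toy_softmax` is hypothetical.

```python
from math import exp

def toy_softmax(logits):
    """exp(x_i) / sum_j exp(x_j), shifted by the max for numerical stability."""
    peak = max(logits)
    exps = [exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Logits copied (rounded) from the Head stage output above.
probabilities = toy_softmax([-4.1625, 4.4154])
print([round(p, 4) for p in probabilities])  # [0.0002, 0.9998]
```

Because there are only two classes, what matters is the gap between the logits: a difference of about 8.6 is what pushes the second class to a probability of roughly 0.9998.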

In [40]:
# Move scores back to cpu
scores = scores.detach().to("cpu").float().numpy()

In [41]:
# Map scores to labels
{config.id2label[scores.argmax()]: scores[scores.argmax()]}

{'POSITIVE': np.float32(0.9998118)}

# Summary

# References

Dubey, Abhimanyu, Abhinav Jauhri, Abhinav Pandey, Abhishek Kadian, Ahmad Al-Dahle, Aiesha Letman, Akhil Mathur, et al. “The Llama 3 Herd of Models.” arXiv, July 31, 2024. https://doi.org/10.48550/arXiv.2407.21783.

Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. “Attention Is All You Need.” arXiv, 2017. https://doi.org/10.48550/arXiv.1706.03762.