A re-implementation of Callum McDougall's "Transformer from Scratch" using JAX and Flax.
This repository contains a clean, modular implementation of the Transformer architecture using JAX and Flax. The implementation follows the original architecture described in the "Attention Is All You Need" paper, with a focus on readability and educational value.
- `config.py` - Configuration dataclass with model hyperparameters
- `modules.py` - Core transformer components:
  - `LayerNorm` - Layer normalisation
  - `Embed` - Token embedding layer
  - `PosEmbed` - Positional embedding layer
  - `Attention` - Multi-head self-attention mechanism
  - `MLP` - Multi-layer perceptron with GELU activation
  - `TransformerBlock` - Complete transformer block with attention, MLP and residual connections
  - `Unembed` - Final projection layer to vocabulary
- `transformer.py` - Full transformer model
- `tests/` - Test suite for all modules
- `transformer.ipynb` - Jupyter notebook for experimentation
The implementation includes a standard transformer architecture with:
- Token embeddings + positional embeddings
- Multiple transformer blocks (sketched below), each containing:
  - Multi-head self-attention with causal masking
  - Layer normalisation (applied before attention and MLP)
  - MLP with GELU activation
  - Residual connections
- Final layer normalisation and projection to vocabulary
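To make the per-block structure concrete, here is a minimal sketch of a pre-LayerNorm block assembled from stock `flax.linen` components. This is illustrative only: the actual `Attention`, `MLP` and `TransformerBlock` modules in `modules.py` are written from scratch and will differ in interface and detail.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class SketchBlock(nn.Module):
    """Hypothetical pre-LayerNorm transformer block (not the repo's TransformerBlock)."""
    d_model: int = 768
    n_heads: int = 12
    d_mlp: int = 3072

    @nn.compact
    def __call__(self, x):  # x: [batch, seq, d_model]
        # Attention sub-layer: LayerNorm -> causal self-attention -> residual add
        causal_mask = nn.make_causal_mask(jnp.ones(x.shape[:2]))
        x = x + nn.SelfAttention(num_heads=self.n_heads)(nn.LayerNorm()(x), mask=causal_mask)
        # MLP sub-layer: LayerNorm -> Dense -> GELU -> Dense -> residual add
        h = nn.gelu(nn.Dense(self.d_mlp)(nn.LayerNorm()(x)))
        x = x + nn.Dense(self.d_model)(h)
        return x


# Quick shape check: the block preserves [batch, seq, d_model]
x = jnp.zeros((2, 4, 768))
block = SketchBlock()
params = block.init(jax.random.PRNGKey(0), x)
print(block.apply(params, x).shape)  # (2, 4, 768)
```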
The default model configuration is:
- Hidden dimension (`d_model`): 768
- Attention heads (`n_heads`): 12
- Layers (`n_layers`): 12
- Context length (`n_ctx`): 1024
- MLP dimension (`d_mlp`): 3072
- Vocabulary size (`d_vocab`): 50257
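For reference, a dataclass with these defaults might look like the sketch below. This is an assumption based on the values listed above; the actual `config.py` may define additional fields (for example a per-head dimension or initialisation range).

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Defaults taken from the list above; the real config.py may differ.
    d_model: int = 768    # hidden dimension
    n_heads: int = 12     # attention heads
    n_layers: int = 12    # transformer blocks
    n_ctx: int = 1024     # context length
    d_mlp: int = 3072     # MLP hidden dimension
    d_vocab: int = 50257  # vocabulary size
```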
- JAX
- Flax
- Jaxtyping
- Einops
- Clone the repository
- Install dependencies: `pip install jax flax jaxtyping einops`
- Run tests: `python -m tests.test_modules`
- Experiment with the model in the Jupyter notebook: `jupyter notebook transformer.ipynb`
```python
import jax.numpy as jnp
import jax.random as jr

from config import Config
from transformer import Transformer

# Initialize model with default config
cfg = Config()
model = Transformer(cfg=cfg)

# Create dummy input tokens (batch size 2, sequence length 4)
tokens = jnp.ones((2, 4), jnp.int32)
rng = jr.PRNGKey(0)

# Initialize parameters
variables = model.init(rng, tokens)

# Run forward pass
logits = model.apply(variables, tokens)
```
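Assuming the forward pass returns per-position logits as in the standard architecture, `logits` has shape `(batch, seq_len, d_vocab)`:

```python
# One vocabulary distribution per input position
assert logits.shape == (2, 4, cfg.d_vocab)  # (2, 4, 50257) with the default config

# Greedy next-token prediction at each position
next_tokens = jnp.argmax(logits, axis=-1)
print(next_tokens.shape)  # (2, 4)
```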
This implementation is based on Callum McDougall's "Transformer from Scratch" and adapted to use JAX and Flax. It aims to provide a clear, educational implementation of the transformer architecture.
This work was supported by UK Research and Innovation [grant number EP/S023356/1], in the UKRI Centre for Doctoral Training in Safe and Trusted Artificial Intelligence (safeandtrustedai.org).