# Implementing a Tiny GPT Language Model in JAX

In this tutorial, we'll guide you through building a simplified version of the GPT (Generative Pre-trained Transformer) model using JAX and Flax. This compact model serves as a practical introduction to language models and to the JAX framework. While it follows the architecture described in the GPT-2 paper, our model is significantly smaller so that it stays quick to train and easy to follow.


You can find more details about GPT models here: [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford-Narasimhan/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035), [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe), [GPT-3](https://arxiv.org/abs/2005.14165).
We also used [nanoGPT](https://github.com/karpathy/nanoGPT) as a reference implementation.

# 1. Setup and Imports

First, let's install the necessary libraries:

```python
# !pip install git+https://github.com/google/flax
!pip install tiktoken
```



And import the necessary libraries:

```python
from dataclasses import dataclass
from typing import Any

from flax import nnx
import jax
import jax.numpy as jnp
import optax
import tiktoken  # used by GPTModel.generate to look up the end-of-text token
```

# 2. Model Architecture

We'll define the fundamental building blocks of the GPT-2 model:
* SelfAttention
* MlpBlock
* AttentionBlock


Then we stack these blocks to form our tiny GPT model:
* GPTModel



## 2.0 Model Config

Before building the blocks, let's define a config so that every component shares the same hyperparameters.

```python
@dataclass
class Config:
  """Configuration for a demo language model."""

  # The dimensionality of the embeddings (i.e., the size of the vector representing each word/token).
  num_embd: int

  # The number of attention heads in each multi-head attention layer.
  num_heads: int

  # The number of transformer blocks (layers) in the model.
  num_layers: int

  # The size of the vocabulary (i.e., the number of unique words/tokens the model knows).
  vocab_size: int

  # The maximum context length (i.e., the number of previous tokens the model considers when predicting the next one).
  block_size: int

  # The dropout probability for the attention layer.
  attn_pdrop: float

  # The dropout probability for the residual connections.
  resid_pdrop: float

  # The dropout probability for the embedding layer.
  embd_pdrop: float
```
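`SelfAttention` derives `head_dim = num_embd // num_heads`. If `num_embd` is not a multiple of `num_heads`, integer division drops the remainder and the attention projections end up using `num_heads * head_dim` features, fewer than `num_embd`. A minimal sketch of a validation hook, using a hypothetical `CheckedConfig` that carries only the two relevant fields (it is not part of the tutorial's `Config`), could look like this:

```python
from dataclasses import dataclass


@dataclass
class CheckedConfig:
  """Hypothetical config subset that validates the attention head split."""

  num_embd: int
  num_heads: int

  def __post_init__(self):
    # Each head gets an equal slice of the embedding, so the split must be exact.
    if self.num_embd % self.num_heads != 0:
      raise ValueError(
          f'num_embd ({self.num_embd}) must be divisible by '
          f'num_heads ({self.num_heads})'
      )

  @property
  def head_dim(self) -> int:
    return self.num_embd // self.num_heads


print(CheckedConfig(num_embd=768, num_heads=8).head_dim)  # 96
```

The same `__post_init__` check could be added to `Config` directly if you prefer to fail fast.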

## 2.1 Attention Layer

```python
class SelfAttention(nnx.Module):
  """Implements a causal self-attention layer with an output projection.

  It is possible to use nnx.MultiHeadAttention instead.
  """

  # Configurable parameters
  num_heads: int
  head_dim: int

  def __init__(self, config: Config, rngs: nnx.Rngs):
    super().__init__()

    self.num_heads = config.num_heads
    self.head_dim = config.num_embd // config.num_heads

    # Key, query and value projections for all heads:
    # (B, T, C) -> (B, T, num_heads, head_dim).
    out_features = (config.num_heads, self.head_dim)
    self.q_attn = nnx.LinearGeneral(
        in_features=config.num_embd, out_features=out_features, rngs=rngs
    )
    self.k_attn = nnx.LinearGeneral(
        in_features=config.num_embd, out_features=out_features, rngs=rngs
    )
    self.v_attn = nnx.LinearGeneral(
        in_features=config.num_embd, out_features=out_features, rngs=rngs
    )

    # Output projection: (B, T, num_heads, head_dim) -> (B, T, C).
    self.out = nnx.LinearGeneral(
        in_features=out_features,
        out_features=config.num_embd,
        axis=(-2, -1),
        rngs=rngs,
    )

    # regularization
    self.attn_dropout = nnx.Dropout(config.attn_pdrop, rngs=rngs)
    self.resid_dropout = nnx.Dropout(config.resid_pdrop, rngs=rngs)

  def __call__(self, x, mask, train: bool = True):
    # batch size, sequence length, embedding dimensionality
    B, T, C = x.shape
    dtype = x.dtype

    # Query, key and values for all heads, each of shape (B, T, num_heads, head_dim).
    k = self.k_attn(x)
    q = self.q_attn(x)
    v = self.v_attn(x)

    # Scaled dot-product attention: scale the queries by 1 / sqrt(head_dim).
    q = q / jnp.sqrt(self.head_dim).astype(dtype)
    # attn_weights shape is (batch..., num_heads, q_length, kv_length)
    attn_weights = jnp.einsum('...qhd,...khd->...hqk', q, k)

    # Apply the causal mask: masked positions get the most negative representable value.
    big_neg = jnp.finfo(dtype).min
    attn_weights = jnp.where(mask, attn_weights, big_neg)

    # Normalize the attention weights over the key axis.
    attn_weights = nnx.softmax(attn_weights, axis=-1)
    # Apply dropout to the attention weights.
    attn_weights = self.attn_dropout(attn_weights, deterministic=not train)

    # Attention output, shape (B, T, num_heads, head_dim).
    y = jnp.einsum('...hqk,...khd->...qhd', attn_weights, v)
    # Project back to (B, T, C) and apply residual dropout.
    y = self.resid_dropout(self.out(y), deterministic=not train)

    return y
```
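The two `einsum` calls carry all of the attention logic, so it can help to run them in isolation. The sketch below reproduces them with plain `jax.numpy` on random tensors in the same `(B, T, num_heads, head_dim)` layout (the helper name `causal_sdpa` is our own, and it omits dropout and the learned projections). It then checks causality: perturbing the last key and value must not change the outputs at earlier positions.

```python
import jax
import jax.numpy as jnp


def causal_sdpa(q, k, v):
  """Causal scaled dot-product attention on (B, T, H, D) inputs, no dropout."""
  T = q.shape[1]
  mask = jnp.tril(jnp.ones((T, T), dtype=bool))  # broadcasts over B and H
  q = q / jnp.sqrt(q.shape[-1]).astype(q.dtype)
  w = jnp.einsum('...qhd,...khd->...hqk', q, k)  # (B, H, T, T)
  w = jnp.where(mask, w, jnp.finfo(w.dtype).min)
  w = jax.nn.softmax(w, axis=-1)
  return jnp.einsum('...hqk,...khd->...qhd', w, v)  # (B, T, H, D)


B, T, H, D = 2, 5, 4, 8
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, H, D))
k = jax.random.normal(kk, (B, T, H, D))
v = jax.random.normal(kv, (B, T, H, D))
out = causal_sdpa(q, k, v)
print(out.shape)  # (2, 5, 4, 8)

# Perturb the last position of k and v: earlier outputs must stay the same.
out2 = causal_sdpa(q, k.at[:, -1].add(1.0), v.at[:, -1].add(1.0))
print(bool(jnp.allclose(out[:, :-1], out2[:, :-1])))  # True
```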

## 2.2 Feed-Forward Network

```python
class MlpBlock(nnx.Module):
  """Feed-forward block: Linear (C -> 4C), GELU, dropout, Linear (4C -> C)."""

  def __init__(self, config: Config, rngs: nnx.Rngs):
    super().__init__()
    self.c_fc = nnx.Linear(config.num_embd, 4 * config.num_embd, rngs=rngs)
    self.c_proj = nnx.Linear(4 * config.num_embd, config.num_embd, rngs=rngs)
    self.dropout = nnx.Dropout(config.resid_pdrop, rngs=rngs)
    self.act = nnx.gelu

  def __call__(self, x, train: bool = True) -> Any:
    x = self.c_fc(x)
    x = self.act(x)
    x = self.dropout(x, deterministic=not train)
    x = self.c_proj(x)

    return x
```
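Unlike attention, the MLP mixes information only within each position's embedding, never across positions. The sketch below writes the same expand-activate-project computation with raw weight matrices in `jax.numpy` (dropout omitted; it uses `jax.nn.gelu`, which we assume behaves like `nnx.gelu`; the helper `mlp` and the weight scale are illustrative choices) and checks that changing one token leaves the other positions' outputs untouched.

```python
import jax
import jax.numpy as jnp


def mlp(x, w_fc, b_fc, w_proj, b_proj):
  """Position-wise MLP: expand C -> 4C, apply GELU, project 4C -> C."""
  h = jax.nn.gelu(x @ w_fc + b_fc)  # (B, T, 4C)
  return h @ w_proj + b_proj  # (B, T, C)


B, T, C = 2, 5, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (B, T, C))
w_fc = jax.random.normal(k2, (C, 4 * C)) * 0.02
w_proj = jax.random.normal(k3, (4 * C, C)) * 0.02
b_fc, b_proj = jnp.zeros(4 * C), jnp.zeros(C)

y = mlp(x, w_fc, b_fc, w_proj, b_proj)
print(y.shape)  # (2, 5, 16)

# Change only position 0: positions 1..T-1 produce the same outputs.
y2 = mlp(x.at[:, 0].add(1.0), w_fc, b_fc, w_proj, b_proj)
print(bool(jnp.allclose(y[:, 1:], y2[:, 1:])))  # True
```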

## 2.3 Transformer Block

```python
class AttentionBlock(nnx.Module):
  """Pre-LayerNorm transformer block: attention and MLP, each with a residual connection."""

  def __init__(self, config: Config, rngs: nnx.Rngs):
    super().__init__()
    self.ln_1 = nnx.LayerNorm(num_features=config.num_embd, rngs=rngs)
    self.attn = SelfAttention(config, rngs=rngs)
    self.ln_2 = nnx.LayerNorm(num_features=config.num_embd, rngs=rngs)
    self.mlp = MlpBlock(config, rngs=rngs)

  def __call__(self, x, mask, train: bool = True):
    # x + attn(ln_1(x))
    ln_x = self.ln_1(x)
    attn_x = self.attn(ln_x, mask, train)
    x = x + attn_x

    # x + mlp(ln_2(x))
    ln2_x = self.ln_2(x)
    mlp_x = self.mlp(ln2_x, train)
    x = x + mlp_x

    return x
```
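`AttentionBlock` normalizes the input *before* each sublayer and adds the sublayer's output back onto the unnormalized residual stream. A small sketch of that pattern in `jax.numpy`, assuming LayerNorm's default initialization (scale 1, bias 0) and an epsilon of 1e-6 (the helpers `layer_norm` and `pre_ln_block` are illustrative, not Flax APIs):

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
  """LayerNorm over the last axis with scale=1 and bias=0."""
  mean = x.mean(axis=-1, keepdims=True)
  var = x.var(axis=-1, keepdims=True)
  return (x - mean) / jnp.sqrt(var + eps)


def pre_ln_block(x, attn_fn, mlp_fn):
  """Pre-LayerNorm residual block, mirroring AttentionBlock.__call__."""
  x = x + attn_fn(layer_norm(x))
  x = x + mlp_fn(layer_norm(x))
  return x


x = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (2, 5, 16))
normed = layer_norm(x)
print(bool(jnp.allclose(normed.mean(-1), 0.0, atol=1e-5)))  # True
print(bool(jnp.allclose(normed.std(-1), 1.0, atol=1e-3)))  # True

# With sublayers that output zeros, the residual path carries x through unchanged.
zeros = lambda h: jnp.zeros_like(h)
print(bool(jnp.array_equal(pre_ln_block(x, zeros, zeros), x)))  # True
```

The last check shows why this layout is convenient: the residual stream always has an identity path through every block.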

## 2.4 GPTModel

```python
def causal_attention_mask(batch_size, n_dest, n_src, dtype=jnp.float32):
  """Auxiliary function to create a causal attention mask.

  Masks the upper triangle of the dot-product matrix in self-attention,
  which prevents information from flowing from future tokens to the current
  token. The result has 1's on and below the main diagonal.
  """
  # [B, 1, DEST, SRC]; the singleton axis broadcasts over attention heads.
  return jnp.tril(jnp.ones((batch_size, 1, n_dest, n_src), dtype=dtype))


class GPTModel(nnx.Module):

  def __init__(self, config):
    super().__init__()

    rngs = nnx.Rngs(0)
    # token and position embeddings
    self.wte = nnx.Embed(config.vocab_size, config.num_embd, rngs=rngs)
    self.wpe = nnx.Embed(config.block_size, config.num_embd, rngs=rngs)
    self.dropout = nnx.Dropout(config.embd_pdrop, rngs=rngs)
    # list of attention blocks
    self.h = [
        AttentionBlock(config, rngs=rngs) for _ in range(config.num_layers)
    ]
    # layer norm before output logits
    self.ln_f = nnx.LayerNorm(num_features=config.num_embd, rngs=rngs)
    # predicted logits
    self.ln_logits = nnx.Linear(config.num_embd, config.vocab_size, rngs=rngs)
    # keep reference on config
    self.config = config

  def __call__(self, idx, targets=None, train: bool = True):
    # batch size, sequence length
    B, T = idx.shape
    tok_emb = self.wte(idx)  # (B, T, C)
    pos_emb = self.wpe(jnp.arange(T))  # (T, C), broadcast over the batch
    x = self.dropout(tok_emb + pos_emb, deterministic=not train)
    causal_mask = causal_attention_mask(B, T, T, dtype=jnp.float32)

    for block in self.h:
      x = block(x, causal_mask, train)

    x = self.ln_f(x)
    logits = self.ln_logits(x)

    # if we are given targets, also calculate the loss
    loss = None
    if targets is not None:
      loss = optax.losses.softmax_cross_entropy_with_integer_labels(
          logits=logits, labels=targets
      ).mean()

    return logits, loss

  def generate(self, prompt, max_length=100, temperature=1.0, top_k=0, seed=0):
    """Autoregressively samples up to `max_length` new tokens after `prompt`."""
    key = jax.random.PRNGKey(seed)
    block_size = self.config.block_size
    eot_token = tiktoken.get_encoding('gpt2').eot_token
    tokens = [int(t) for t in prompt]
    for _ in range(max_length):
      # crop the context to the last block_size tokens
      context = tokens[-block_size:]
      n = len(context)
      # Right-pad to block_size; the causal mask keeps the padding from
      # influencing the prediction at position n - 1.
      idx = jnp.array(context + [0] * (block_size - n), dtype=jnp.int32)[None, :]
      logits, _ = self(idx, train=False)
      # focus only on the last real (non-padding) position
      logits = logits[0, n - 1, :] / temperature
      # use a fresh random key for every sampled token
      key, subkey = jax.random.split(key)
      if top_k > 0:
        top_k_logits, top_k_indices = jax.lax.top_k(logits, top_k)
        # jax.random.categorical expects logits, not probabilities
        next_token = top_k_indices[jax.random.categorical(subkey, top_k_logits)]
      else:
        next_token = jax.random.categorical(subkey, logits)
      next_token = int(next_token)
      if next_token == eot_token:
        break
      # append the sampled token to the running sequence
      tokens.append(next_token)

    return jnp.array(tokens, dtype=jnp.int32)[None, :]
```
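To see what `causal_attention_mask` produces, here is the same construction evaluated for a short sequence. The `(B, 1, T, T)` shape broadcasts against the `(B, num_heads, T, T)` attention weights, and row `i` (a query position) has 1's only for key positions `0..i`.

```python
import jax.numpy as jnp


def causal_attention_mask(batch_size, n_dest, n_src, dtype=jnp.float32):
  # Same construction as in the GPTModel cell: ones on and below the diagonal.
  return jnp.tril(jnp.ones((batch_size, 1, n_dest, n_src), dtype=dtype))


mask = causal_attention_mask(batch_size=2, n_dest=4, n_src=4)
print(mask.shape)  # (2, 1, 4, 4)
print(mask[0, 0].astype(int).tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
```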

# 3. Dataset

Now, let's prepare the data. In this tutorial, we'll use:
1. tiktoken tokenizer to process our text data. Tiktoken is a fast byte-pair encoding (BPE) tokenizer designed for use with OpenAI models. While we've chosen tiktoken for its efficiency, it's worth noting that other tokenizers, such as SentencePiece, could also be used depending on your specific requirements.
2. [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) data as a toy dataset.

## Preparing Train and Eval Datasets

Before feeding our data to the model, we'll split it into a training set (the first 90% of the text) and an evaluation set (the remaining 10%).

**Key Point**: Transformer models expect numerical input, so tokenization is a crucial preprocessing step. It's the bridge between human-readable text and the numerical representations that the model operates on.

```python
# Original source: https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare/prepare.py
import os
import numpy as np
import requests
import tiktoken

# download the tiny shakespeare dataset
input_file_dir = '/tmp'
input_file_path = os.path.join(input_file_dir, 'input.txt')
if not os.path.exists(input_file_path):
  data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
  with open(input_file_path, 'w', encoding='utf-8') as f:
    f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
  data = f.read()
n = len(data)
# first 90% of the characters for training, the remaining 10% for evaluation
train_data = data[: int(n * 0.9)]
val_data = data[int(n * 0.9) :]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding('gpt2')
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f'train has {len(train_ids):,} tokens')
print(f'val has {len(val_ids):,} tokens')

# export to bin files; uint16 suffices because the GPT-2 vocabulary has 50,257 tokens
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(input_file_dir, 'train.bin'))
val_ids.tofile(os.path.join(input_file_dir, 'val.bin'))

# train.bin has 301,966 tokens
# val.bin has 36,059 tokens


def get_batch(split, data_dir, block_size, batch_size):
  """Samples inputs x and next-token targets y, each of shape (batch_size, block_size)."""
  # We recreate np.memmap every batch to avoid a memory leak, as per
  # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
  filename = 'train.bin' if split == 'train' else 'val.bin'
  data = np.memmap(os.path.join(data_dir, filename), dtype=np.uint16, mode='r')
  # random start offsets; targets are the inputs shifted by one token
  ix = np.random.randint(0, len(data) - block_size, (batch_size,))
  # int32 matches JAX's default integer type (64-bit is disabled by default)
  x = np.stack([data[i : i + block_size].astype(np.int32) for i in ix])
  y = np.stack([data[i + 1 : i + 1 + block_size].astype(np.int32) for i in ix])

  return jnp.asarray(x), jnp.asarray(y)
```
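The key detail in `get_batch` is that `y` is `x` shifted by one token, so the target at position `t` is the token that follows input position `t`. The sketch below applies the same slicing to a toy in-memory array (the function `toy_get_batch` is hypothetical and uses NumPy's `default_rng` instead of `np.random.randint`) so the shift is easy to inspect without the `.bin` files.

```python
import numpy as np


def toy_get_batch(data, block_size, batch_size, rng):
  """Same slicing as get_batch, but on an in-memory array instead of a memmap."""
  ix = rng.integers(0, len(data) - block_size, size=batch_size)
  x = np.stack([data[i : i + block_size] for i in ix])
  y = np.stack([data[i + 1 : i + 1 + block_size] for i in ix])
  return x, y


data = np.arange(10, dtype=np.uint16)  # stand-in for token ids 0..9
x, y = toy_get_batch(data, block_size=4, batch_size=3, rng=np.random.default_rng(0))
print(x.shape, y.shape)  # (3, 4) (3, 4)
# Each target row is its input row shifted by one token.
print(bool((x[:, 1:] == y[:, :-1]).all()))  # True
```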

```
train has 301,966 tokens
val has 36,059 tokens
```


# 4. Initializing and Training

Let's initialize our tiny GPT model, then define the loss function and the training and evaluation steps.

```python
# Define the config, using fewer heads and layers than GPT-2 small (12 each) to speed up the tutorial.
config = Config(
    num_embd=768,
    num_heads=8,
    num_layers=8,
    vocab_size=tiktoken.get_encoding("gpt2").n_vocab,
    block_size=256,
    attn_pdrop=0.1,
    resid_pdrop=0.1,
    embd_pdrop=0.1,
)

model = GPTModel(config)


def loss_fn(model, x, y):
  logits, loss = model(x, y)  # forward pass in training mode (dropout enabled)
  return loss, logits


@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y, metrics):
  (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, x, y)
  optimizer.update(grads)  # in-place updates
  metrics.update(loss=loss, logits=logits, labels=y)
  return loss, logits


@nnx.jit
def eval_step(model, x, y, metrics):
  # evaluation runs with dropout disabled
  logits, loss = model(x, y, train=False)
  metrics.update(loss=loss, logits=logits, labels=y)
```
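For a sense of scale, the parameter count of this configuration can be worked out by hand. The sketch below assumes that every `nnx.Linear`/`nnx.LinearGeneral` has a bias and every `nnx.LayerNorm` has a scale and a bias (our understanding of the nnx defaults), and that the vocabulary has 50,257 entries, as reported by tiktoken's `gpt2` encoding. The function `count_params` is a hypothetical helper, not part of Flax.

```python
def count_params(num_embd, num_layers, vocab_size, block_size):
  """Counts GPTModel parameters under the bias/scale assumptions stated above."""
  C = num_embd
  embeddings = vocab_size * C + block_size * C  # wte + wpe
  # q, k, v and output projections; num_heads * head_dim == C when the split is exact
  attention = 4 * (C * C + C)
  mlp = (C * 4 * C + 4 * C) + (4 * C * C + C)  # c_fc + c_proj
  layer_norms = 2 * (2 * C)  # ln_1 and ln_2
  per_block = attention + mlp + layer_norms
  head = (2 * C) + (C * vocab_size + vocab_size)  # ln_f + ln_logits
  return embeddings + num_layers * per_block + head


n_params = count_params(num_embd=768, num_layers=8, vocab_size=50257, block_size=256)
print(f'{n_params:,}')  # 134,146,129
```

GPT-2 shares one matrix between the token embedding and the output layer; this model keeps `wte` and `ln_logits` separate, so the two vocabulary-sized matrices (about 38.6M parameters each) are counted twice.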

Now, define the optimizer and the training loop, and start training the model.

```python
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=1e-4))
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

# Keep results for plotting
metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
}

# Feel free to experiment with the number of steps.
num_steps = 25000

for step in range(num_steps):
  # Use a batch size of 32 for now; larger batches are left for later tutorials.
  x, y = get_batch('train', input_file_dir, config.block_size, 32)

  train_step(model, optimizer, x, y, metrics)
  # The metrics are never reset, so these values are running averages over all steps so far.
  for metric, value in metrics.compute().items():  # compute metrics
    metrics_history[f'train_{metric}'].append(value)  # record metrics
    if step % 1000 == 0:
      print(f'[train] step[{step}]: {metric}: {value:.3f}')
```

```
[train] epoch[0]: accuracy: 0.000
[train] epoch[0]: loss: 11.221
[train] epoch[1000]: accuracy: 0.249
[train] epoch[1000]: loss: 4.808
[train] epoch[2000]: accuracy: 0.304
[train] epoch[2000]: loss: 4.110
[train] epoch[3000]: accuracy: 0.381
[train] epoch[3000]: loss: 3.456
[train] epoch[4000]: accuracy: 0.483
[train] epoch[4000]: loss: 2.825
[train] epoch[5000]: accuracy: 0.567
[train] epoch[5000]: loss: 2.348
[train] epoch[6000]: accuracy: 0.629
[train] epoch[6000]: loss: 2.005
[train] epoch[7000]: accuracy: 0.675
[train] epoch[7000]: loss: 1.750
[train] epoch[8000]: accuracy: 0.710
[train] epoch[8000]: loss: 1.554
[train] epoch[9000]: accuracy: 0.739
[train] epoch[9000]: loss: 1.398
[train] epoch[10000]: accuracy: 0.762
[train] epoch[10000]: loss: 1.272
[train] epoch[11000]: accuracy: 0.781
[train] epoch[11000]: loss: 1.167
[train] epoch[12000]: accuracy: 0.797
[train] epoch[12000]: loss: 1.079
[train] epoch[13000]: accuracy: 0.811
[train] epoch[13000]: loss: 1.004
[train] epoch[140
```
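The first logged loss (11.221) is close to the cross-entropy of a model that assigns equal probability to every one of the 50,257 tokens, which is ln(50257) ≈ 10.825. The sketch below checks that baseline by computing cross-entropy by hand with `jax.nn.log_softmax`, the per-example quantity that `optax.losses.softmax_cross_entropy_with_integer_labels` returns; the labels are arbitrary placeholders.

```python
import math

import jax
import jax.numpy as jnp

vocab_size = 50257  # tiktoken 'gpt2' n_vocab, as in the config above

# Cross-entropy of a uniform prediction: every token gets logit 0.
logits = jnp.zeros((4, vocab_size))
labels = jnp.array([0, 1, 2, 3])
log_probs = jax.nn.log_softmax(logits, axis=-1)
loss = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).mean()

print(round(float(loss), 3), round(math.log(vocab_size), 3))  # 10.825 10.825
```

A starting loss near this value suggests the randomly initialized model begins close to uniform, which is what we expect before training.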

# 5. Generate text

Let's evaluate our compact model by checking whether it can continue a Shakespeare passage from the validation set in a convincing style.
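Both generation calls below sample each new token from the model's output distribution. With `top_k=1`, only the highest-scoring token is kept, so sampling reduces to greedy decoding. A minimal standalone sketch of the top-k step, built from `jax.lax.top_k` and `jax.random.categorical` as `generate` uses them (the helper name `sample_top_k` and the toy logits are hypothetical):

```python
import jax
import jax.numpy as jnp


def sample_top_k(key, logits, top_k, temperature=1.0):
  """Samples one token id from the top_k highest-scoring logits."""
  logits = logits / temperature
  top_logits, top_indices = jax.lax.top_k(logits, top_k)
  # jax.random.categorical takes unnormalized log-probabilities (logits)
  choice = jax.random.categorical(key, top_logits)
  return top_indices[choice]


logits = jnp.array([0.1, 3.0, -1.0, 2.5, 0.0])
keys = jax.random.split(jax.random.PRNGKey(0), 100)
samples = [int(sample_top_k(k, logits, top_k=2)) for k in keys]
print(sorted(set(samples)))  # only token ids 1 and 3 can appear
# With top_k=1 sampling is greedy: it always returns the argmax.
print(int(sample_top_k(keys[0], logits, top_k=1)))  # 1
```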

```python
x, y = get_batch("val", input_file_dir, 125, 1)
print("=========    Prefix   =======")
print(enc.decode(x[0].tolist()[:25]))
print("=========    Generated using top_k=1 =======")
print(enc.decode(model.generate(x[0, :25], top_k=1)[0].tolist()))
print("=========    Generated not using top_k  =======")
print(enc.decode(model.generate(x[0, :25])[0].tolist()))
```

```
=========    Prefix   =======
 power, I would
Have sunk the sea within the earth or ere
It should the good ship so have swallow'd
=========    Generated using top_k=1 =======
 power, I would
Have sunk the sea within the earth or ere
It should the good ship so have swallow'd all the kindred of the Capulets lie.
In the mean time, against thou shalt awake,
Shall Romeo by my letters know our drift,
And hither shall he come: and he and I
Will watch thy waking, and that very night
Shall Romeo bear thee hence to Mantua.
And this shall free thee from this present shame;
If no inconstant toy, nor womanish fear,
Abate thy valour in the acting it.

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
=========    Generated not using top_k  =======
 power, I would
Have sunk the sea within the earth or ere
It should the good ship so have swallow'd all swords.

KING RICHARD II:
Why, uncle, what's the matter?

DUKE OF YORK:
O my liege,
Pardon me, if you please; if not, I, pleased
Not to be pardon'd, am content withal.
Seek you to seize and gri
```

# 6. Conclusion

This tutorial explored the GPT-2 model architecture by implementing a small version of it with the JAX and Flax frameworks. While this implementation is a useful educational starting point, a production-ready language model would require substantial further development.
Key Takeaways:

* **Core Transformer Components:** We've successfully implemented the essential elements of a transformer architecture, including self-attention mechanisms and feed-forward networks.
* **Positional Information:** The model incorporates learned positional embeddings, crucial for sequence understanding in transformer models.
* **Training Loop:** A basic training loop has been established, using cross-entropy loss and the AdamW optimizer (`optax.adamw`), demonstrating the fundamental steps in model training.
* **Data Processing:** The tutorial showcases the use of the tiktoken tokenizer, illustrating an approach to prepare textual data for input into the language model.

This implementation provides a solid foundation for further experiments in language modeling with JAX and Flax. We encourage you to build on it, modify the architecture, and explore more advanced concepts in natural language processing and transformer models.