This is a direct translation of the [Text generation with a miniature GPT](https://keras.io/examples/generative/text_generation_with_miniature_gpt/) tutorial from Keras to JAX. It aims to help developers who are familiar with Keras/TensorFlow pick up JAX/Flax quickly.

This notebook demonstrates how to use [Flax NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) to implement an autoregressive language model using a miniaturized version of the GPT model. The model uses only a single transformer block and is easy to understand.

This notebook assumes it runs on a Colab T4 GPU. If you use different hardware, adjust the batch size accordingly.

## Setup

Install JAX and Flax first.

In [None]:
!pip install jax-ai-stack
!pip install -U "jax[cuda12]"

Collecting jax-ai-stack
  Downloading jax_ai_stack-2024.10.1-py3-none-any.whl.metadata (16 kB)
Collecting flax==0.9.0 (from jax-ai-stack)
  Downloading flax-0.9.0-py3-none-any.whl.metadata (11 kB)
Collecting ml-dtypes==0.4.0 (from jax-ai-stack)
  Downloading ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting orbax-export==0.0.5 (from jax-ai-stack)
  Downloading orbax_export-0.0.5-py3-none-any.whl.metadata (1.9 kB)
Collecting dataclasses-json (from orbax-export==0.0.5->jax-ai-stack)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting jaxtyping (from orbax-export==0.0.5->jax-ai-stack)
  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json->orbax-export==0.0.5->jax-ai-stack)
  Downloading marshmallow-3.22.0-py3-none-any.whl.metadata (7.2 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json->orbax-export==0.0.5->jax-ai-stack)
  

Grab the IMDB review data as the training data.

In [None]:
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  1447k      0  0:00:56  0:00:56 --:--:-- 1242k


Take care of the imports.

In [None]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
from typing import Any
import os
import string
import random
from collections import Counter

## Build the model

Next, define the model architecture, a decoder-only transformer. It is similar to the GPT model series but much smaller, with only one transformer block, which is why we call it miniGPT. The model has several key components stacked together, so let's go over them one by one.

The key component is the `TransformerBlock`, which uses the multi-head attention mechanism described in the famous [Attention Is All You Need](https://arxiv.org/abs/1706.03762) paper. If you are not already familiar with the paper, it is worth reading first, because the code below implements several of its details.
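For reference, these are the scaled dot-product attention and multi-head attention from the paper, with the causal (look-ahead) mask written as an additive matrix $M$ that blocks attention to future positions:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V,
\qquad
M_{ij} =
\begin{cases}
0, & j \le i,\\
-\infty, & j > i,
\end{cases}
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O,
\qquad
\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V).
$$

In this notebook the block performs self-attention, so $Q$, $K$ and $V$ all come from the same input sequence, and $h$ is `num_heads`. Assuming `nnx.MultiHeadAttention` keeps its default of `qkv_features = in_features`, the hyperparameters set later (`embed_dim = 256`, `num_heads = 2`) give $d_k = 256 / 2 = 128$ per head.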

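The model is autoregressive, so each position may attend only to itself and earlier tokens. We therefore use [`jax.numpy.tril`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tril.html) to create a lower-triangular causal attention mask and pass it to the `nnx.MultiHeadAttention` layer. The remaining layers follow the structure of the decoder layer in the paper (minus the encoder-decoder attention sublayer, since there is no encoder): each sublayer is followed by dropout, a residual connection, and layer normalization.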
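The standalone sketch below builds the same `jnp.tril` mask for a 4-token sequence and applies it to random stand-in attention scores. It masks by setting future positions to `-inf` before the softmax, which is conceptually what the mask does inside attention; the exact masking value `nnx.MultiHeadAttention` uses internally is an implementation detail this sketch does not reproduce. Each row ends up putting all of its weight on itself and earlier positions.

```python
import jax
import jax.numpy as jnp

seq_len = 4
# Same construction as causal_attention_mask: 1 where attention is allowed.
mask = jnp.tril(jnp.ones((seq_len, seq_len)))

# Random stand-in for the scaled scores QK^T / sqrt(d_k).
scores = jax.random.normal(jax.random.PRNGKey(0), (seq_len, seq_len))

# Disallowed (future) positions get -inf, so softmax gives them zero weight.
masked_scores = jnp.where(mask == 1, scores, -jnp.inf)
weights = jax.nn.softmax(masked_scores, axis=-1)

print(mask)
print(weights.round(3))  # upper triangle is all zeros; each row sums to 1
```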

All layers except `Dropout` take an `rngs` argument: an `nnx.Rngs` object that supplies the [random generator keys](https://jax.readthedocs.io/en/latest/jax.random.html#prng-keys) used to initialize parameters. Seeding it makes results reproducible, which also helps with debugging.

In [None]:
def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len)))

class TransformerBlock(nnx.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
        self.mha = nnx.MultiHeadAttention(num_heads=num_heads, in_features=embed_dim, rngs=rngs)
        self.dropout1 = nnx.Dropout(rate=rate)
        self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6, num_features=embed_dim, rngs=rngs)
        self.linear1 = nnx.Linear(in_features=embed_dim, out_features=ff_dim, rngs=rngs)
        self.linear2 = nnx.Linear(in_features=ff_dim, out_features=embed_dim, rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=rate)
        self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6, num_features=embed_dim, rngs=rngs)


    def __call__(self, inputs, training: bool = False):
        input_shape = inputs.shape
        batch_size, seq_len, _ = input_shape

        # Create causal mask
        mask = causal_attention_mask(seq_len)

        # Apply MultiHeadAttention with causal mask
        attention_output = self.mha(
            inputs_q=inputs,
            mask=mask,
            decode=False
        )
        attention_output = self.dropout1(attention_output, deterministic=not training)
        out1 = self.layer_norm1(inputs + attention_output)

        # Feed-forward network
        ffn_output = self.linear1(out1)
        ffn_output = nnx.relu(ffn_output)
        ffn_output = self.linear2(ffn_output)
        ffn_output = self.dropout2(ffn_output, deterministic=not training)

        return self.layer_norm2(out1 + ffn_output)

Since the model input is just text tokens, we need to convert them into embeddings. We use two kinds of embeddings, token embeddings and position embeddings, both learned by the model and added together. Note that this differs slightly from the original paper, which uses fixed sinusoidal positional encodings instead of learned ones.
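The standalone sketch below contrasts the two approaches on toy inputs. Random matrices stand in for the two `nnx.Embed` tables (an assumption for illustration; in the model they are trained), and `sinusoidal_encoding` is a hypothetical helper implementing the paper's fixed encoding. It also shows the broadcasting used in `TokenAndPositionEmbedding`: positions of shape `(1, seq_len)` are shared across the batch.

```python
import jax
import jax.numpy as jnp

maxlen, vocab_size, embed_dim = 80, 20000, 256  # values used in this notebook
key_tok, key_pos = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical stand-ins for the learned nnx.Embed lookup tables.
token_table = jax.random.normal(key_tok, (vocab_size, embed_dim))
pos_table = jax.random.normal(key_pos, (maxlen, embed_dim))

x = jnp.array([[12, 7, 99], [5, 5, 0]])          # (batch=2, seq_len=3) token ids
positions = jnp.arange(x.shape[1])[None, :]       # (1, 3), broadcast over the batch
learned = token_table[x] + pos_table[positions]   # (2, 3, 256)

def sinusoidal_encoding(length, dim):
    # Fixed encoding from "Attention Is All You Need":
    # PE[pos, 2i] = sin(pos / 10000^(2i/dim)), PE[pos, 2i+1] = cos(pos / 10000^(2i/dim))
    pos = jnp.arange(length)[:, None]
    i = jnp.arange(dim // 2)[None, :]
    angles = pos / jnp.power(10000.0, 2 * i / dim)
    pe = jnp.zeros((length, dim))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    pe = pe.at[:, 1::2].set(jnp.cos(angles))
    return pe

static = token_table[x] + sinusoidal_encoding(maxlen, embed_dim)[: x.shape[1]]

# Token 5 appears at positions 0 and 1 of the second sequence;
# adding position information gives the two occurrences different vectors.
print(learned.shape, static.shape)
```

With either scheme the token vectors are identical only before positions are added; the learned variant simply lets training decide what each position vector looks like.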

In [None]:
class TokenAndPositionEmbedding(nnx.Module):

    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs)

    def __call__(self, x):
        positions = jnp.arange(0, x.shape[1])[None, :]
        position_embedding = self.pos_emb(positions)
        token_embedding = self.token_emb(x)
        return token_embedding + position_embedding

Now we can put everything together to build our miniGPT model: convert the tokens into embeddings, pass them through a single `TransformerBlock`, and finally apply a linear projection layer that produces logits over the vocabulary at every position.

In [None]:
class MiniGPT(nnx.Module):
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, feed_forward_dim: int, *, rngs: nnx.Rngs):
        self.embedding_layer = TokenAndPositionEmbedding(
            maxlen, vocab_size, embed_dim, rngs=rngs
        )
        self.transformer_block = TransformerBlock(
            embed_dim, num_heads, feed_forward_dim, rngs=rngs
        )
        self.output_layer = nnx.Linear(in_features=embed_dim, out_features=vocab_size, rngs=rngs)

    def __call__(self, inputs, training: bool = False):
        x = self.embedding_layer(inputs)
        x = self.transformer_block(x, training=training)
        outputs = self.output_layer(x)
        return outputs

def create_model(rngs):
    return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, rngs=rngs)

Set some hyperparameters.

In [None]:
vocab_size = 20000
maxlen = 80
embed_dim = 256
num_heads = 2
feed_forward_dim = 256
batch_size = 512 # for Colab T4 GPU
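To get a feel for the model size these hyperparameters imply, here is a back-of-the-envelope parameter count in plain Python. It assumes `nnx.MultiHeadAttention` keeps its defaults (biases enabled, `qkv_features` and `out_features` equal to `in_features`), so the query, key, value and output projections are each an `embed_dim` by `embed_dim` map plus a bias; `linear` is a hypothetical helper.

```python
vocab_size, maxlen, embed_dim, num_heads, feed_forward_dim = 20000, 80, 256, 2, 256

def linear(n_in, n_out):
    # Dense layer: kernel (n_in x n_out) plus bias (n_out).
    return n_in * n_out + n_out

token_emb = vocab_size * embed_dim                     # token embedding table
pos_emb = maxlen * embed_dim                           # position embedding table
attention = 4 * linear(embed_dim, embed_dim)           # Q, K, V and output projections (assumed defaults)
layer_norms = 2 * (2 * embed_dim)                      # scale + bias for each of two LayerNorms
ffn = linear(embed_dim, feed_forward_dim) + linear(feed_forward_dim, embed_dim)
output = linear(embed_dim, vocab_size)                 # final projection to vocabulary logits

total = token_emb + pos_emb + attention + layer_norms + ffn + output
print(f"{total:,}")  # 10,676,256
```

Under these assumptions, over 95% of the parameters sit in the token embedding and the output projection, the two matrices whose size scales with `vocab_size`. After `model = create_model(...)`, something like `sum(p.size for p in jax.tree.leaves(nnx.state(model, nnx.Param)))` should give the actual count to compare against.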

## Prepare data

Load and preprocess the data. To map words and symbols to indices, we first need to tokenize the text. For simplicity, we use a very simple tokenization scheme:
* The `custom_standardization` function lowercases the text, replaces `<br />` tags with spaces, and adds spaces around punctuation marks so that they are treated as separate tokens, just like words
* The `build_vocab` function builds our own vocabulary from the `vocab_size - 2` most frequent tokens plus the `<PAD>` and `<UNK>` special tokens
* The `tokenize` function maps each token to its index (tokens outside the vocabulary map to `<UNK>`) and pads or truncates every review to exactly `maxlen` tokens
* We also split the data into batches of `batch_size` reviews
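To see the scheme on something small, the sketch below runs compact copies of the three helpers from the cell below on two made-up reviews with a 6-entry vocabulary. It also shows why generated text later contains sequences like `it ' s`: the apostrophe counts as punctuation and becomes its own token.

```python
import string
from collections import Counter

# Compact copies of the helpers defined below, repeated so this runs on its own.
def custom_standardization(input_string):
    text = input_string.lower().replace("<br />", " ")
    return ''.join(' ' + c + ' ' if c in string.punctuation else c for c in text).strip()

def build_vocab(texts, vocab_size):
    counts = Counter(' '.join(texts).split())
    vocab = ['<PAD>', '<UNK>'] + [w for w, _ in counts.most_common(vocab_size - 2)]
    return vocab, {w: i for i, w in enumerate(vocab)}

def tokenize(text, word_to_index, maxlen):
    ids = [word_to_index.get(w, word_to_index['<UNK>']) for w in text.split()]
    # Pad with <PAD>, then cut to exactly maxlen (same effect as pad-or-truncate).
    return (ids + [word_to_index['<PAD>']] * maxlen)[:maxlen]

reviews = ["Great movie!<br />Loved it.", "Bad movie."]
texts = [custom_standardization(r) for r in reviews]
vocab, word_to_index = build_vocab(texts, vocab_size=6)

print(texts[0].split())                             # ['great', 'movie', '!', 'loved', 'it', '.']
print(vocab)                                        # ['<PAD>', '<UNK>', 'movie', '.', 'great', '!']
print(tokenize(texts[1], word_to_index, maxlen=5))  # [1, 2, 3, 0, 0]
print(custom_standardization("It's"))               # it ' s
```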


In [None]:
# Data loading and preprocessing
filenames = []
directories = [
    "./aclImdb/train/pos",
    "./aclImdb/train/neg",
    "./aclImdb/test/pos",
    "./aclImdb/test/neg",
]
for dir in directories:
    for f in os.listdir(dir):
        filenames.append(os.path.join(dir, f))

print(f"{len(filenames)} files")

random.shuffle(filenames)

# Custom text processing: add space before and after punctuations for tokenization
def custom_standardization(input_string):
    lowercased = input_string.lower()
    stripped_html = lowercased.replace("<br />", " ")
    return ''.join([' ' + char + ' ' if char in string.punctuation else char for char in stripped_html]).strip()

def build_vocab(texts, vocab_size):
    all_words = ' '.join(texts).split()
    word_counts = Counter(all_words)
    vocab = ['<PAD>', '<UNK>'] + [word for word, _ in word_counts.most_common(vocab_size - 2)]
    word_to_index = {word: index for index, word in enumerate(vocab)}
    return vocab, word_to_index

def tokenize(text, word_to_index, maxlen):
    words = text.split()
    tokens = [word_to_index.get(word, word_to_index['<UNK>']) for word in words]
    if len(tokens) < maxlen:
        tokens = tokens + [word_to_index['<PAD>']] * (maxlen - len(tokens))
    else:
        tokens = tokens[:maxlen]
    return tokens

def load_and_preprocess_data(filenames, batch_size, vocab_size, maxlen):
    data = []
    for filename in filenames:
        with open(filename, 'r', encoding='utf-8') as file:
            text = file.read()
            processed_text = custom_standardization(text)
            data.append(processed_text)

    vocab, word_to_index = build_vocab(data, vocab_size)
    tokenized_data = [tokenize(text, word_to_index, maxlen) for text in data]

    # Batch the data
    batched_data = [tokenized_data[i:i+batch_size] for i in range(0, len(tokenized_data), batch_size)]

    return batched_data, vocab, word_to_index

text_ds, vocab, word_to_index = load_and_preprocess_data(filenames, batch_size, vocab_size, maxlen)

50000 files
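A quick arithmetic check of the batching: 50,000 reviews in consecutive chunks of 512 give 97 full batches plus one final batch of 336. Because `nnx.jit` compiles a separate version of `train_step` for each new input shape, that last `(336, 80)` batch should trigger one extra compilation during the first epoch. Dropping the remainder (for example, iterating over `text_ds[:-1]`) is one optional way to avoid it at the cost of 336 reviews per epoch; the notebook itself keeps every batch.

```python
num_files = 50000   # from the output above
batch_size = 512    # from the hyperparameter cell

# Same chunking as load_and_preprocess_data: slices of batch_size, last one shorter.
batch_sizes = [min(batch_size, num_files - start) for start in range(0, num_files, batch_size)]

full_batches = num_files // batch_size
leftover = num_files - full_batches * batch_size
print(len(batch_sizes), full_batches, batch_sizes[-1], leftover)  # 98 97 336 336
```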


## Train the model

Define a helper function for generating text given a model and prompt.
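`generate_text` uses top-k sampling: it keeps only the `top_k` highest-scoring tokens, renormalizes their probabilities with a softmax, and draws one of them. The standalone sketch below applies that idea to a hand-made logits vector, using `jax.nn.softmax` and splitting the PRNG key before every draw so each sample uses fresh randomness; `sample_top_k` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def sample_top_k(logits, key, k=10):
    # Keep the k largest logits together with their vocabulary indices.
    top_logits, top_indices = jax.lax.top_k(logits, k)
    # Renormalize over the k survivors only.
    probs = jax.nn.softmax(top_logits)
    # Draw one vocabulary index according to those probabilities.
    return jax.random.choice(key, top_indices, p=probs)

logits = jnp.array([0.1, 3.0, -1.0, 2.5, 0.0, 4.0])  # toy 6-word vocabulary
key = jax.random.PRNGKey(42)
samples = []
for _ in range(20):
    key, subkey = jax.random.split(key)  # never reuse a key across draws
    samples.append(int(sample_top_k(logits, subkey, k=3)))

print(samples)  # only indices 5, 1 and 3 (the three largest logits) can appear
```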

In [None]:
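def generate_text(model: MiniGPT, max_tokens: int, start_tokens: list[int], index_to_word: dict[int, str], top_k=10, rng_key=None):
    # Default to a fixed seed so results are repeatable; pass another key for different samples.
    if rng_key is None:
        rng_key = jax.random.PRNGKey(0)

    def sample_from(logits, key):
        # Keep the top_k highest-scoring tokens and sample among them.
        logits, indices = jax.lax.top_k(logits, k=top_k)
        probs = nnx.softmax(logits)
        return jax.random.choice(key, indices, p=probs)

    def generate_step(tokens, key):
        pad_len = maxlen - len(tokens)
        sample_index = len(tokens) - 1
        if pad_len < 0:
            # Longer than maxlen: keep only the most recent maxlen tokens as context.
            x = jnp.array(tokens[-maxlen:])
            sample_index = maxlen - 1
        elif pad_len > 0:
            # Right-pad with <PAD> (index 0); the causal mask stops the padding
            # from influencing the position we sample from.
            x = jnp.array(tokens + [0] * pad_len)
        else:
            x = jnp.array(tokens)

        x = x[None, :]  # add a batch dimension
        logits = model(x)
        return sample_from(logits[0][sample_index], key)

    generated = []
    for _ in range(max_tokens):
        # Split the key so every step draws fresh randomness instead of reusing one key.
        rng_key, step_key = jax.random.split(rng_key)
        next_token = generate_step(start_tokens + generated, step_key)
        generated.append(int(next_token))
    return " ".join([index_to_word[token] for token in start_tokens + generated])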

Define the loss function and the training step function. `train_step` is usually the most expensive function because it computes the gradients and updates the model parameters. We can use [JAX JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compiling-a-function) to accelerate it, but since we are using NNX here, we annotate it with `@nnx.jit` instead of `@jax.jit`: `nnx.jit` accepts NNX objects such as the model, optimizer, and metrics directly and propagates the state updates made to them inside the function. JIT-compiled functions can be tricky to debug; please refer to our [debugging documentation](https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html#compiled-prints-and-breakpoints) if you encounter such a situation.

In [None]:
def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss, logits

@nnx.jit
def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    optimizer.update(grads)

Start training.

In [None]:
model = create_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
  # You can add additional metrics for tracking
)
rng = jax.random.PRNGKey(0)

start_prompt = "this movie is"
start_tokens = [word_to_index.get(word, word_to_index['<UNK>']) for word in start_prompt.split()]
index_to_word = {i: word for word, i in word_to_index.items()}
generated_text = generate_text(
    model, 40, start_tokens, index_to_word
)
print(f"Initial generated text:\n{generated_text}\n")

num_epochs = 25
metrics_history = {
  'train_loss': [],
}

for epoch in range(num_epochs):
    for batch in text_ds:
        input_batch = jnp.array(batch)
        target_batch = jnp.array([tokens[1:] + [word_to_index['<PAD>']] for tokens in batch])
        train_step(model, optimizer, metrics, (input_batch, target_batch))

    for metric, value in metrics.compute().items():  # compute metrics
      metrics_history[f'train_{metric}'].append(value)  # record metrics
    metrics.reset()

    print(f"Epoch {epoch + 1}, Loss: {metrics_history['train_loss'][-1]}")
    start_prompt = "this movie is"
    start_tokens = [word_to_index.get(word, word_to_index['<UNK>']) for word in start_prompt.split()]
    generated_text = generate_text(
        model, 40, start_tokens, index_to_word
    )
    print(f"Generated text:\n{generated_text}\n")

# Final text generation
start_tokens = [word_to_index.get(word, word_to_index['<UNK>']) for word in start_prompt.split()]
generated_text = generate_text(
    model, 40, start_tokens, index_to_word
)
print(f"Final generated text:\n{generated_text}")

Initial generated text:
this movie is seem shameless celebrity clarity claudio xmas reunion drafted weed capability distrust perlman impaled nominal hesitate inside colleges wage supervision kerry thrillers celeste activists supporter partisan filled rookie sneak bona erase urban rowlands damp daria islanders english overshadowed overheard jazz cheezy

Epoch 1, Loss: 6.119418144226074
Generated text:
this movie is not that the story , the best , and it was not sure i think the first , and it ' t really , it is one of the first time , i have seen the story . it is

Epoch 2, Loss: 5.050351619720459
Generated text:
this movie is the movie that it ' s best , and the acting was very well - and i thought that the plot . the acting was just about a great acting is very bad movie was so much about a great

Epoch 3, Loss: 4.757399082183838
Generated text:
this movie is not only to the movie , it is the best , i have seen in a film , i was not to watch this film , and the only reason to watch th

As you can see, the model goes from generating completely random words at the beginning to generating sentences that resemble movie reviews after only a few epochs. Of course, the reviews are far from perfect: with a single transformer block, this model is tiny compared with modern LLMs. In our next tutorial, we are going to scale the model up and make it smarter.

## Save the model

We use [Orbax](https://github.com/google/orbax) to save the model checkpoint.

In [None]:
import orbax.checkpoint as orbax

state = nnx.state(model)

checkpointer = orbax.PyTreeCheckpointer()
checkpointer.save('/content/save', state)

# Make sure the files are there
!ls /content/save/

_CHECKPOINT_METADATA  d  manifest.ocdbt  _METADATA  ocdbt.process_0  _sharding
