This is a direct translation of [Text generation with a miniature GPT](https://keras.io/examples/generative/text_generation_with_miniature_gpt/) from Keras to JAX/Flax.

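Install JAX (with CUDA 12 support) and Flax. Optax is pulled in as a Flax dependency; TensorFlow and Keras, used below for the input pipeline and text vectorization, must already be available in the environment.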

In [None]:
!pip install -U "jax[cuda12]" flax

Collecting flax
  Downloading flax-0.8.5-py3-none-any.whl.metadata (10 kB)
Collecting jax[cuda12]
  Downloading jax-0.4.31-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.31,>=0.4.30 (from jax[cuda12])
  Downloading jaxlib-0.4.31-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)
Collecting jax-cuda12-plugin<=0.4.31,>=0.4.31 (from jax-cuda12-plugin[with_cuda]<=0.4.31,>=0.4.31; extra == "cuda12"->jax[cuda12])
  Downloading jax_cuda12_plugin-0.4.31-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting jax-cuda12-pjrt==0.4.31 (from jax-cuda12-plugin<=0.4.31,>=0.4.31->jax-cuda12-plugin[with_cuda]<=0.4.31,>=0.4.31; extra == "cuda12"->jax[cuda12])
  Downloading jax_cuda12_pjrt-0.4.31-py3-none-manylinux2014_x86_64.whl.metadata (349 bytes)
Collecting nvidia-cublas-cu12>=12.1.3.1 (from jax-cuda12-plugin[with_cuda]<=0.4.31,>=0.4.31; extra == "cuda12"->jax[cuda12])
  Downloading nvidia_cublas_cu12-12.6.0.22-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting

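Download and extract the IMDB movie review dataset (Large Movie Review Dataset v1.0), which is used as the training corpus.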

In [None]:
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  13.0M      0  0:00:06  0:00:06 --:--:-- 17.4M


Build and train the model.

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from typing import Any, Callable
import os
import string
import random
import tensorflow as tf
from flax.training import train_state
import keras


def causal_attention_mask(seq_len, dtype=jnp.bool_):
    """
    Generate a causal (lower-triangular) attention mask for self-attention.

    Entry ``[i, j]`` is truthy iff ``j <= i``, i.e. query position ``i`` may
    attend only to key positions at or before it.

    Args:
        seq_len: number of positions on each side of the square mask.
        dtype: element type of the mask. Defaults to bool, which is what
            Flax attention expects (positions where the mask is False are
            masked out).

    Returns:
        A ``(seq_len, seq_len)`` array of ``dtype``.
    """
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=dtype))

class TransformerBlock(nn.Module):
    """Post-norm Transformer block with causal self-attention.

    Computes ``LN(x + Dropout(CausalSelfAttention(x)))`` followed by
    ``LN(h + Dropout(FFN(h)))``, where FFN is Dense(ff_dim) -> ReLU ->
    Dense(embed_dim).
    """

    embed_dim: int  # output width of the feed-forward network (and of the block)
    num_heads: int  # number of attention heads
    ff_dim: int  # hidden width of the feed-forward network
    rate: float = 0.1  # dropout rate after attention and after the FFN

    @nn.compact
    def __call__(self, inputs, training: bool = False):
        """Apply the block.

        Args:
            inputs: array whose last two axes are (seq_len, features); assumed
                (batch, seq_len, embed_dim) as used by MiniGPT. The residual
                adds require features == embed_dim.
            training: if True, dropout is active and a "dropout" RNG must be
                supplied to ``apply``.

        Returns:
            Array with the same shape as ``inputs``.
        """
        seq_len = inputs.shape[-2]

        # Causal mask, given the [batch, heads, q_len, kv_len] layout Flax
        # documents; the leading singleton axes broadcast.
        mask = causal_attention_mask(seq_len)[None, None, :, :]

        # Self-attention: with only inputs_q given, Flax uses it for keys and
        # values too (equivalent to the deprecated inputs_kv=inputs).
        # Submodule creation order below is kept unchanged because it
        # determines the auto-generated parameter names.
        attention_output = nn.MultiHeadAttention(num_heads=self.num_heads)(
            inputs, mask=mask
        )
        attention_output = nn.Dropout(rate=self.rate)(
            attention_output, deterministic=not training
        )
        out1 = nn.LayerNorm(epsilon=1e-6)(inputs + attention_output)

        # Position-wise feed-forward network.
        ffn_output = nn.Dense(features=self.ff_dim)(out1)
        ffn_output = nn.relu(ffn_output)
        ffn_output = nn.Dense(features=self.embed_dim)(ffn_output)
        ffn_output = nn.Dropout(rate=self.rate)(ffn_output, deterministic=not training)

        return nn.LayerNorm(epsilon=1e-6)(out1 + ffn_output)


class TokenAndPositionEmbedding(nn.Module):
    """Sum of a learned token embedding and a learned absolute position embedding."""

    maxlen: int  # number of position embeddings, i.e. the longest supported sequence
    vocab_size: int  # number of token embeddings
    embed_dim: int  # embedding width

    @nn.compact
    def __call__(self, x):
        """Embed integer token ids.

        Args:
            x: integer token ids whose last axis is the sequence axis, e.g.
                (batch, seq_len). seq_len may be anything up to ``maxlen``.

        Returns:
            Array of shape (..., seq_len, embed_dim) (a leading batch axis of
            1 is added if ``x`` is 1-D).

        Raises:
            ValueError: if the sequence is longer than ``maxlen`` (those
                positions would have no embedding).
        """
        seq_len = x.shape[-1]
        if seq_len > self.maxlen:
            raise ValueError(
                f"sequence length {seq_len} exceeds maxlen={self.maxlen}"
            )
        # Use the actual input length (as the Keras original does) so shorter
        # sequences broadcast correctly against the token embeddings.
        positions = jnp.arange(seq_len)[None, :]
        # The position table is always sized maxlen, so parameter shapes do not
        # depend on the input. Creation order (positions first) fixes the
        # auto-generated names Embed_0 / Embed_1.
        position_embedding = nn.Embed(self.maxlen, self.embed_dim)(positions)
        token_embedding = nn.Embed(int(self.vocab_size), self.embed_dim)(x)
        return token_embedding + position_embedding


class MiniGPT(nn.Module):
    """Single-block GPT-style language model.

    Calling the module returns a pair ``(logits, hidden)``: per-position
    next-token logits with ``vocab_size`` entries, and the Transformer block's
    output features from which the logits were projected.
    """

    maxlen: int  # size of the learned position table
    vocab_size: int  # token vocabulary size / number of output logits
    embed_dim: int  # width of embeddings and hidden states
    num_heads: int  # attention heads in the Transformer block
    feed_forward_dim: int  # hidden width of the block's feed-forward network

    @nn.compact
    def __call__(self, inputs, training: bool = False):
        # Keep submodule creation order: it determines the auto-generated
        # parameter names (TokenAndPositionEmbedding_0, TransformerBlock_0, Dense_0).
        hidden = TokenAndPositionEmbedding(
            self.maxlen, self.vocab_size, self.embed_dim
        )(inputs)
        hidden = TransformerBlock(
            self.embed_dim, self.num_heads, self.feed_forward_dim
        )(hidden, training=training)
        logits = nn.Dense(features=self.vocab_size)(hidden)
        return logits, hidden


# Hyperparameters shared by the model, the tokenizer and the input pipeline.
vocab_size = 20000  # number of output logits; the tokenizer is capped at vocab_size - 1 tokens
maxlen = 80  # tokens per training example (model context length)
embed_dim = 256  # embedding / hidden-state width
num_heads = 2  # attention heads
feed_forward_dim = 256  # hidden width of the Transformer feed-forward network
batch_size = 640  # text lines per batch


def create_model():
    """Return a fresh, uninitialized MiniGPT built from the hyperparameters above."""
    return MiniGPT(
        maxlen=maxlen,
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_heads=num_heads,
        feed_forward_dim=feed_forward_dim,
    )


# Data loading and preprocessing.
# Collect every review file from the IMDB train and test splits (both
# sentiments) -- labels are irrelevant for language modelling.
directories = [
    "./aclImdb/train/pos",
    "./aclImdb/train/neg",
    "./aclImdb/test/pos",
    "./aclImdb/test/neg",
]
# A comprehension avoids leaking a module-level `dir` (which would shadow the
# builtin for the rest of the notebook) and `f`.
filenames = [
    os.path.join(directory, name)
    for directory in directories
    for name in os.listdir(directory)
]

print(f"{len(filenames)} files")

random.shuffle(filenames)
# Each line of each file becomes one dataset element.
text_ds = tf.data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)


def custom_standardization(input_string):
    """Normalize raw review text before tokenization.

    Lower-cases the text, replaces ``<br />`` tags with a space, and inserts a
    space before each punctuation character so it becomes a separate token.
    """
    # NOTE(review): in this character class the `\]` taken from
    # string.punctuation escapes `]`, so a literal backslash is not matched.
    punctuation_pattern = f"([{string.punctuation}])"
    text = tf.strings.lower(input_string)
    text = tf.strings.regex_replace(text, "<br />", " ")
    return tf.strings.regex_replace(text, punctuation_pattern, r" \1")


# Tokenizer: maps text to fixed-length sequences of integer ids. Sequences are
# maxlen + 1 long so that inputs and shifted-by-one targets can both be sliced
# from a single sequence.
vectorize_layer = keras.layers.TextVectorization(
    max_tokens=vocab_size - 1,
    standardize=custom_standardization,
    output_sequence_length=maxlen + 1,
    output_mode="int",
)
vectorize_layer.adapt(text_ds)
# Index -> word list, used to turn token ids back into words.
vocab = vectorize_layer.get_vocabulary()


def prepare_lm_inputs_labels(text):
    """Turn a batch of raw strings into ``(inputs, targets)`` for next-token prediction.

    Both come from the same tokenized sequence: ``inputs`` drops the last
    token and ``targets`` drops the first, so ``targets[:, t]`` is the token
    that follows ``inputs[:, t]``.
    """
    token_ids = vectorize_layer(tf.expand_dims(text, -1))
    inputs, targets = token_ids[:, :-1], token_ids[:, 1:]
    return inputs, targets


# Tokenize each batch into (inputs, targets) and prefetch so input preparation
# overlaps with training.
text_ds = text_ds.map(prepare_lm_inputs_labels).prefetch(tf.data.AUTOTUNE)


# JAX doesn't have a direct equivalent to Keras callbacks, so text generation is
# implemented as a separate function that the training loop calls explicitly.
def generate_text(params, max_tokens, start_tokens, index_to_word, top_k=10, rng=None):
    """Autoregressively sample ``max_tokens`` tokens after a prompt and decode them.

    Args:
        params: MiniGPT parameters (the "params" collection).
        max_tokens: number of tokens to generate.
        start_tokens: non-empty sequence of int token ids used as the prompt.
        index_to_word: indexable mapping token id -> word (e.g. the vocabulary list).
        top_k: sample only among the ``top_k`` highest-scoring next tokens.
        rng: optional JAX PRNG key; defaults to ``PRNGKey(0)`` for
            reproducible output. A fresh subkey is split off for every token.

    Returns:
        Prompt plus generated tokens, decoded via ``index_to_word`` and joined
        with single spaces.

    Raises:
        ValueError: if ``start_tokens`` is empty.
    """
    if len(start_tokens) == 0:
        raise ValueError("start_tokens must contain at least one token")
    if rng is None:
        rng = jax.random.PRNGKey(0)

    model = create_model()

    def sample_from(logits, key):
        # Restrict to the top-k candidates, renormalize, and sample one of them.
        top_logits, indices = jax.lax.top_k(logits, k=top_k)
        probs = jax.nn.softmax(top_logits)
        return jax.random.choice(key, indices, p=probs)

    def generate_step(tokens, key):
        # Condition on the most recent `maxlen` tokens (the model's context
        # window). Right-padding with 0 does not affect earlier positions
        # because attention is causal.
        context = tokens[-maxlen:]
        sample_index = len(context) - 1
        x = jnp.array(context + [0] * (maxlen - len(context)))[None, :]
        logits, _ = model.apply({"params": params}, x)
        return sample_from(logits[0, sample_index], key)

    prompt = list(start_tokens)
    generated = []
    for _ in range(max_tokens):
        # Split a new key each step; reusing one key would give the same
        # random draw for every sampled token.
        rng, step_key = jax.random.split(rng)
        next_token = generate_step(prompt + generated, step_key)
        generated.append(int(next_token))
    print(generated)
    return " ".join(index_to_word[token] for token in prompt + generated)


# Training setup.
def create_train_state(rng):
    """Initialize MiniGPT parameters and wrap them with an Adam optimizer.

    Parameters are created by tracing the model on a dummy (1, maxlen) int32
    batch; ``rng`` seeds the parameter initializers.
    """
    model = create_model()
    dummy_tokens = jnp.ones((1, maxlen), dtype=jnp.int32)
    variables = model.init(rng, dummy_tokens)
    optimizer = optax.adam(learning_rate=1e-3)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optimizer,
    )


@jax.jit
def train_step(state, batch, dropout_rng=None):
    """Run one optimization step with dropout enabled.

    Args:
        state: flax TrainState holding params, optimizer state and step count.
        batch: ``(inputs, targets)`` integer arrays; targets are the inputs
            shifted by one position.
        dropout_rng: optional PRNG key for dropout. If omitted, a key is
            derived from ``state.step`` so every step uses a different mask.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean token-level softmax
        cross-entropy over the batch (padding positions are not masked out).
    """
    inputs, targets = batch
    if dropout_rng is None:
        dropout_rng = jax.random.fold_in(jax.random.PRNGKey(0), state.step)

    def loss_fn(params):
        # training=True activates the model's Dropout layers, which read the
        # "dropout" RNG stream.
        logits, _ = state.apply_fn(
            {"params": params}, inputs, training=True, rngs={"dropout": dropout_rng}
        )
        return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss


rng = jax.random.PRNGKey(0)
state = create_train_state(rng)

# Encode the prompt once, before training. A dict gives O(1) lookups, and
# words missing from the vocabulary map to the OOV token instead of raising
# ValueError as list.index would.
index_to_word = vectorize_layer.get_vocabulary()
word_to_index = {word: index for index, word in enumerate(index_to_word)}
unk_index = word_to_index.get("[UNK]", 1)
start_prompt = "this movie is"
start_tokens = [word_to_index.get(word, unk_index) for word in start_prompt.split()]

num_epochs = 25
for epoch in range(num_epochs):
    batch_losses = []
    for batch in text_ds:
        batch = (jnp.array(batch[0].numpy()), jnp.array(batch[1].numpy()))
        state, loss = train_step(state, batch)
        batch_losses.append(loss)

    # Report the mean over the epoch rather than the last (possibly partial) batch.
    epoch_loss = float(jnp.mean(jnp.stack(batch_losses))) if batch_losses else float("nan")
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")
    generated_text = generate_text(state.params, 40, start_tokens, index_to_word)
    print(f"Generated text:\n{generated_text}\n")

# Final text generation (start_tokens is defined even when num_epochs == 0).
generated_text = generate_text(state.params, 40, start_tokens, index_to_word)
print(f"Final generated text:\n{generated_text}")



50000 files
Epoch 1, Loss: 5.563542366027832
[28, 2, 71, 9, 28, 47, 17, 10, 3, 2, 71, 4, 12, 58, 33, 34, 7, 2, 71, 7, 2, 96, 22, 9, 28, 8, 114, 3, 3, 10, 3, 10, 16, 34, 7, 2, 122, 3, 3, 10]
Generated text:
this movie is not the story is not just as it . the story , i can be one of the story of the first film is not to watch . . it . it was one of the plot . . it

Epoch 2, Loss: 4.921992778778076
[2, 96, 18, 4, 21, 12, 218, 10, 15, 47, 95, 12, 156, 259, 4, 21, 10, 3, 2, 1, 27, 16, 43, 89, 4, 10, 9, 28, 2, 1, 27, 6, 10, 9, 34, 3, 10, 9, 34, 7]
Generated text:
this movie is the first movie , but i saw it 's just because i 'm sure , but it . the [UNK] " was so bad , it is not the [UNK] " and it is one . it is one of

Epoch 3, Loss: 4.750950336456299
[2, 18, 14, 52, 8, 33, 34, 7, 2, 257, 7, 2, 128, 18, 3, 2, 96, 12, 16, 28, 70, 57, 17, 5, 181, 216, 12, 245, 259, 53, 10, 9, 28, 68, 2, 96, 4, 12, 156, 5]
Generated text:
this movie is the movie that has to be one of the worst of the best movie