# Installation
Install JAX with GPU CUDA support (in this case, CUDA 11 and cuDNN 8.6).


In [None]:
# Use %pip (not !pip) so packages go into the environment of the running kernel.
# flax and optax are imported by the model/training code below.
%pip install flax optax
# Quote the extras spec so shells that glob square brackets (e.g. zsh) leave it intact.
%pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# NOTE(review): tensorflow, tensorflow-probability and datasets are not imported by the code shown here.
%pip install --upgrade tensorflow tensorflow-probability
%pip install datasets

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

TPU-equivalent installation code (run this cell instead of the GPU cell above when using a TPU runtime).

In [None]:
# Use %pip (not !pip) so packages go into the environment of the running kernel.
# flax and optax are imported by the model/training code below.
%pip install flax optax
%pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# NOTE(review): tensorflow, tensorflow-probability and datasets are not imported by the code shown here.
%pip install --upgrade tensorflow tensorflow-probability
%pip install datasets

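# Hyperparameters & Data Preparation

The cell below reads the training corpus from a plain-text file named `input.txt` in the current working directory. Make sure that file exists before running it.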

In [None]:
# ------------------
# Imports
# ------------------
# Every library used by the walkthrough cells below is imported here, so the
# notebook runs top-to-bottom on a fresh kernel (Restart & Run All).
import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from flax.training import train_state
import optax

# ------------------
# Hyperparameters
# ------------------
batch_size = 16        # Number of sequences processed in parallel
block_size = 32        # Maximum context length (sequence length)
max_iters = 5000       # Total training iterations
eval_interval = 100    # How frequently to evaluate the model
learning_rate = 1e-3   # Optimizer learning rate
eval_iters = 200       # Number of batches to average loss over for evaluation
n_embd = 64            # Embedding dimension (hidden size)
n_head = 4             # Number of attention heads
n_layer = 4            # Number of transformer blocks
dropout = 0.0          # Dropout rate (0.0 disables every dropout layer below)

# ------------------
# Data Preparation
# ------------------
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Build a sorted list of unique characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)  # Vocabulary size

# Create mappings from characters to integers and vice-versa
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s: str):
    """Translate each character of ``s`` into its integer id using ``stoi``.

    Raises KeyError for characters that are not in the vocabulary.
    """
    return list(map(stoi.__getitem__, s))

def decode(lst):
    """Inverse of ``encode``: map integer ids back to characters via ``itos`` and join them."""
    return ''.join(map(itos.__getitem__, lst))

# Encode the whole corpus once into a flat int32 array of token ids.
data = np.array(encode(text), dtype=np.int32)

# The first 90% of the tokens are used for training, the remaining 10% for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split: str, key: jax.random.PRNGKey):
    """
    Generate a small batch of data (input and target sequences).

    For each batch:
      - Randomly pick `batch_size` starting indices within the chosen split.
      - Create an input sequence of length 'block_size'.
      - Create a target sequence shifted by one character.

    Args:
      split: 'train' selects `train_data`; any other value selects `val_data`.
      key: JAX PRNG key used to sample the starting indices.

    Returns:
      (x, y): NumPy arrays of shape (batch_size, block_size) with the dtype of
      the data array; y[b, t] is the token that follows x[b, t] in the corpus.
    """
    data_source = train_data if split == 'train' else val_data
    # Random starting indices for each sequence in the batch. maxval is
    # exclusive, so idx + block_size + 1 <= len(data_source): every window is full.
    ix = jax.random.randint(key, (batch_size,), minval=0, maxval=len(data_source) - block_size)

    # Copy the indices to the host once and gather every window with a single
    # NumPy fancy-indexing op. Iterating the JAX array element-by-element in
    # Python would force one device->host transfer per index.
    starts = np.asarray(ix)[:, None]          # (batch_size, 1)
    offsets = np.arange(block_size)[None, :]  # (1, block_size)
    x_out = data_source[starts + offsets]
    y_out = data_source[starts + offsets + 1]
    return x_out, y_out


# Model Definition
Single Attention Head

In [None]:
class Head(nn.Module):
    """A single causal self-attention head.

    Attributes:
      head_size: width of the key/query/value projections and of the output.
    """
    head_size: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        """Attend causally over ``x``.

        Args:
          x: (B, T, C) activations (the shape unpack below requires rank 3).
          deterministic: when False and the global ``dropout`` rate is positive,
            dropout is applied to the attention weights.

        Returns:
          (B, T, head_size) weighted sum of the value projections.
        """
        _, seq_len, _ = x.shape

        # Bias-free projections; explicit names pin the parameter paths.
        keys = nn.Dense(self.head_size, use_bias=False, name='key')(x)
        queries = nn.Dense(self.head_size, use_bias=False, name='query')(x)
        values = nn.Dense(self.head_size, use_bias=False, name='value')(x)

        # scores[b, i, j] = <q_i, k_j> scaled by 1/sqrt(head_size)
        scores = jnp.einsum('bqd,bkd->bqk', queries, keys) * self.head_size ** -0.5

        # Position i may only look at positions j <= i; everything else gets a
        # huge negative score so softmax sends it to ~0.
        allowed = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        scores = jnp.where(allowed, scores, -1e10)

        weights = nn.softmax(scores, axis=-1)
        if dropout > 0 and not deterministic:
            weights = nn.Dropout(rate=dropout)(weights, deterministic=deterministic)

        return jnp.einsum('bqk,bkd->bqd', weights, values)


Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Several causal ``Head`` modules run side by side, then mixed by a projection.

    Attributes:
      num_heads: number of independent attention heads.
    """
    num_heads: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        """Apply all heads to ``x`` and project the concatenation back to C channels.

        Args:
          x: (B, T, C) activations; each head gets C // num_heads channels.
          deterministic: when False and the global ``dropout`` rate is positive,
            dropout is applied after the output projection.
        """
        _, _, channels = x.shape
        per_head = channels // self.num_heads

        # Heads are named head_0, head_1, ... so their parameter paths are stable.
        outputs = [
            Head(per_head, name=f'head_{i}')(x, deterministic=deterministic)
            for i in range(self.num_heads)
        ]
        mixed = nn.Dense(channels)(jnp.concatenate(outputs, axis=-1))

        if dropout > 0 and not deterministic:
            mixed = nn.Dropout(rate=dropout)(mixed, deterministic=deterministic)
        return mixed


Feed-Forward Network

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Dense(4 * n_embd) -> ReLU -> Dense(n_embd).

    Attributes:
      n_embd: input/output width; the hidden layer is four times wider.
    """
    n_embd: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        """Run the MLP on the last axis of ``x``; optional dropout on the output."""
        expanded = nn.relu(nn.Dense(4 * self.n_embd)(x))
        projected = nn.Dense(self.n_embd)(expanded)
        if dropout > 0 and not deterministic:
            projected = nn.Dropout(rate=dropout)(projected, deterministic=deterministic)
        return projected


Transformer Block

In [None]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes h = x + MHA(LN(x)) followed by h + FFN(LN(h)).

    Attributes:
      n_embd: residual-stream width.
      n_head: number of attention heads.
    """
    n_embd: int
    n_head: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        # Attention sub-layer on normalized input, added back to the residual stream.
        attn_in = nn.LayerNorm()(x)
        h = x + MultiHeadAttention(self.n_head)(attn_in, deterministic=deterministic)

        # Feed-forward sub-layer, again normalized first and added residually.
        ffn_in = nn.LayerNorm()(h)
        return h + FeedForward(self.n_embd)(ffn_in, deterministic=deterministic)


Bigram Language Model

In [None]:
class BigramLanguageModel(nn.Module):
    """
    A transformer language model that uses:
      - Token embeddings
      - Positional embeddings
      - Multiple transformer blocks
      - Final normalization and projection to logits

    Despite the name, each prediction conditions on up to `block_size`
    previous tokens, not just one.
    """
    vocab_size: int
    n_embd: int
    n_head: int
    n_layer: int
    block_size: int

    @nn.compact
    def __call__(self, idx, targets=None, deterministic: bool = True):
        """
        Forward pass:
          - idx: (B, T) sequence of token indices, with T <= block_size
            (the positional table only has block_size rows).
          - targets: (B, T) optional target indices for loss computation.
          - deterministic: when False, dropout layers are active and a
            'dropout' rng must be supplied to `apply`.

        Returns:
          (logits, loss): logits is (B, T, vocab_size); loss is the mean
          cross-entropy over all B*T positions, or None when targets is None.
        """
        B, T = idx.shape

        # Token embeddings
        token_emb = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embd)(idx)

        # Positional embeddings (learnable)
        pos_emb = self.param('pos_emb', nn.initializers.normal(stddev=0.02),
                             (self.block_size, self.n_embd))
        # Combine token and positional embeddings
        x = token_emb + pos_emb[:T]

        # Pass through multiple transformer blocks
        for i in range(self.n_layer):
            x = Block(self.n_embd, self.n_head, name=f'block_{i}')(x, deterministic=deterministic)

        # Final layer normalization
        x = nn.LayerNorm()(x)

        # Project to vocabulary dimension to produce logits for each token
        logits = nn.Dense(self.vocab_size)(x)

        loss = None
        if targets is not None:
            # Reshape logits and targets for cross-entropy loss computation
            logits_reshaped = logits.reshape((B * T, self.vocab_size))
            targets_reshaped = targets.reshape((B * T,))
            loss = optax.softmax_cross_entropy_with_integer_labels(logits_reshaped, targets_reshaped).mean()

        return logits, loss

    def generate(self, variables, idx, max_new_tokens, rng=None):
        """
        Autoregressive text generation.

        Args:
          variables: variable dict to run the model with, e.g. {'params': state.params}.
          idx: (B, T) int array of prompt token ids.
          max_new_tokens: number of tokens to append.
          rng: PRNG key for sampling; if None, decode greedily (argmax).

        Returns:
          (B, T + max_new_tokens) token ids: the prompt followed by the new tokens.
        """
        for _ in range(max_new_tokens):
            # Use only the last block_size tokens as context
            idx_cond = idx[:, -self.block_size:]
            # Forward pass in inference mode; no dropout is applied, so no rng is needed.
            logits, _ = self.apply(variables, idx_cond, targets=None, deterministic=True)
            # Only the prediction for the last position matters
            logits_last = logits[:, -1, :]
            if rng is None:
                next_token = jnp.argmax(logits_last, axis=-1)[:, None]
            else:
                # Draw a fresh subkey every step; reusing one key would reuse the
                # same random noise for every sampled token.
                rng, step_key = jax.random.split(rng)
                # categorical takes unnormalized logits directly (no softmax/log round trip).
                next_token = jax.random.categorical(step_key, logits_last, axis=-1)[:, None]
            # Append the predicted token to the sequence
            idx = jnp.concatenate([idx, next_token], axis=1)
        return idx


# Training Utilities

In [None]:
def create_train_state(rng, model: BigramLanguageModel):
    """
    Initialize the model and wrap its parameters in an AdamW TrainState.

    Returns:
      (state, variables): the TrainState (params + optimizer state) and the
      full variable dict produced by `model.init`.
    """
    # A zero-filled (batch_size, block_size) int32 batch is enough for shape inference.
    dummy_tokens = jnp.zeros((batch_size, block_size), dtype=jnp.int32)
    variables = model.init(rng, dummy_tokens, targets=None, deterministic=True)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adamw(learning_rate),
    )
    return state, variables

@jax.jit
def train_step(state, variables, x, y, rng):
    """
    One optimization step: forward pass with dropout enabled, gradient of the
    mean cross-entropy w.r.t. `state.params`, then an optimizer update.

    `variables` is accepted but not read; gradients are taken w.r.t. state.params.

    Returns:
      (new_state, loss)
    """
    def compute_loss(params):
        _, batch_loss = state.apply_fn(
            {'params': params}, x,
            targets=y, deterministic=False, rngs={'dropout': rng},
        )
        return batch_loss

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    return state.apply_gradients(grads=grads), loss

def estimate_loss(state, model, variables, rng):
    """
    Computes the average loss over `eval_iters` batches for both the training
    and validation sets, using the *current* parameters in `state`.

    Args:
      state: current TrainState; `state.params` are the weights evaluated.
      model: the model whose `apply` runs the forward pass.
      variables: variable dict from `create_train_state`; any non-'params'
        collections are passed through, but 'params' is taken from `state`.
      rng: PRNG key used to sample the evaluation batches (the same key
        sequence is replayed for each split).

    Returns:
      dict mapping 'train' and 'val' to the mean loss as a Python float.
    """
    # `variables` comes from initialization and is never updated by train_step,
    # so its 'params' would always report the untrained model's loss.
    eval_variables = {**variables, 'params': state.params}

    # Compile the deterministic forward pass once per call rather than running
    # the model eagerly, op by op, for each of the 2 * eval_iters batches.
    @jax.jit
    def batch_loss(vars_, xb, yb):
        _, loss = model.apply(vars_, xb, targets=yb, deterministic=True)
        return loss

    losses = {}
    for split in ['train', 'val']:
        total = 0.0
        key = rng
        for _ in range(eval_iters):
            key, subkey = jax.random.split(key)
            xb, yb = get_batch(split, subkey)
            xb = jnp.array(xb, dtype=jnp.int32)
            yb = jnp.array(yb, dtype=jnp.int32)
            total += batch_loss(eval_variables, xb, yb).item()
        losses[split] = total / eval_iters

    return losses


# Main Loop

In [None]:
def main():
    """
    Train the model with the global hyperparameters, printing train/val loss
    every `eval_interval` steps (and on the last step), then print a
    200-token sample generated with the trained weights.
    """
    # Initialize the random key for reproducibility
    main_key = jax.random.PRNGKey(1337)

    # Create the model instance with specified hyperparameters
    model = BigramLanguageModel(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size
    )

    # Initialize the training state and model variables.
    # NOTE: `variables` holds the *initial* weights and is never updated;
    # the trained weights live in `state.params`.
    main_key, subkey = jax.random.split(main_key)
    state, variables = create_train_state(subkey, model)

    # Training loop
    for iter_i in range(max_iters):
        # Evaluate and print loss at regular intervals (and on the final step)
        if iter_i % eval_interval == 0 or iter_i == max_iters - 1:
            main_key, sub_eval_key = jax.random.split(main_key)
            losses = estimate_loss(state, model, {'params': state.params}, sub_eval_key)
            print(f"step {iter_i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # Generate a training batch
        main_key, subkey = jax.random.split(main_key)
        xb, yb = get_batch('train', subkey)
        xb = jnp.array(xb, dtype=jnp.int32)
        yb = jnp.array(yb, dtype=jnp.int32)

        # Perform a training step
        main_key, drop_key = jax.random.split(main_key)
        state, loss_val = train_step(state, variables, xb, yb, drop_key)

    # Generate text from the trained model. The prompt is a single token with
    # id 0, i.e. the first character of the sorted vocabulary `chars`.
    context = jnp.zeros((1, 1), dtype=jnp.int32)
    generated = model.generate({'params': state.params}, context, max_new_tokens=200, rng=main_key)
    print(decode(np.array(generated[0])))

if __name__ == "__main__":
    main()


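# Full Code

The complete script in a single, self-contained cell: imports, hyperparameters, data preparation, model definition, training utilities, and the main training loop.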

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from flax.training import train_state
import optax

# ------------------
# Hyperparameters
# ------------------
batch_size = 16        # independent sequences per training batch
block_size = 32        # maximum context length (tokens attended over)
max_iters = 5000       # total optimizer steps
eval_interval = 100    # steps between loss reports
learning_rate = 1e-3   # AdamW step size
eval_iters = 200       # batches averaged per split when estimating loss
n_embd = 64            # embedding / residual-stream width
n_head = 4             # attention heads per block
n_layer = 4            # number of transformer blocks
dropout = 0.0          # dropout rate; 0.0 disables every dropout layer

# Device selection: JAX picks the best available backend on its own. To force
# one, use `jax.config.update('jax_platform_name', 'cpu')` or environment
# variables. This script relies on the default.

# ------------------
# Data Preparation
# ------------------
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}   # char -> id
itos = dict(enumerate(chars))                  # id -> char

def encode(s: str):
    """Translate each character of ``s`` into its integer id using ``stoi``.

    Raises KeyError for characters that are not in the vocabulary.
    """
    return list(map(stoi.__getitem__, s))

def decode(lst):
    """Inverse of ``encode``: map integer ids back to characters via ``itos`` and join them."""
    return ''.join(map(itos.__getitem__, lst))

data = np.array(encode(text), dtype=np.int32)  # whole corpus as token ids
n = int(0.9 * len(data))                       # 90% train / 10% validation boundary
train_data, val_data = data[:n], data[n:]

def get_batch(split: str, key: jax.random.PRNGKey):
    """
    Generate a small batch of data of inputs x and targets y.

    Args:
      split: 'train' selects `train_data`; any other value selects `val_data`.
      key: JAX PRNG key used to sample the starting indices.

    Returns:
      (x, y): NumPy arrays of shape (batch_size, block_size) with the dtype of
      the data array; y is x shifted one token to the right in the corpus.
    """
    data_source = train_data if split == 'train' else val_data
    # Randomly choose batch_size starting indices. maxval is exclusive, so every
    # target window (idx + 1 .. idx + block_size) stays inside the data.
    ix = jax.random.randint(key, (batch_size,), minval=0, maxval=len(data_source) - block_size)

    # One host copy of the indices plus a single fancy-indexing gather, instead
    # of a Python loop over the JAX array, which syncs with the device per element.
    starts = np.asarray(ix)[:, None]          # (batch_size, 1)
    offsets = np.arange(block_size)[None, :]  # (1, block_size)
    x_out = data_source[starts + offsets]
    y_out = data_source[starts + offsets + 1]
    return x_out, y_out

# ------------------
# Model Definition
# ------------------

class Head(nn.Module):
    """A single causal self-attention head.

    Attributes:
      head_size: width of the key/query/value projections and of the output.
    """
    head_size: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        """Attend causally over ``x``.

        Args:
          x: (B, T, C) activations (the shape unpack below requires rank 3).
          deterministic: when False and the global ``dropout`` rate is positive,
            dropout is applied to the attention weights.

        Returns:
          (B, T, head_size) weighted sum of the value projections.
        """
        _, seq_len, _ = x.shape

        # Bias-free projections; explicit names pin the parameter paths.
        keys = nn.Dense(self.head_size, use_bias=False, name='key')(x)
        queries = nn.Dense(self.head_size, use_bias=False, name='query')(x)
        values = nn.Dense(self.head_size, use_bias=False, name='value')(x)

        # scores[b, i, j] = <q_i, k_j> scaled by 1/sqrt(head_size); shape (B, T, T)
        scores = jnp.einsum('bqd,bkd->bqk', queries, keys) * self.head_size ** -0.5

        # Causal mask: position i may only look at positions j <= i. Masked
        # entries get a huge negative score so softmax sends them to ~0.
        allowed = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        scores = jnp.where(allowed, scores, -1e10)

        weights = nn.softmax(scores, axis=-1)
        if dropout > 0 and not deterministic:
            weights = nn.Dropout(rate=dropout)(weights, deterministic=deterministic)

        return jnp.einsum('bqk,bkd->bqd', weights, values)


class MultiHeadAttention(nn.Module):
    """Several causal ``Head`` modules run side by side, then mixed by a projection.

    Attributes:
      num_heads: number of independent attention heads.
    """
    num_heads: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        """Apply all heads to ``x`` and project the concatenation back to C channels.

        Args:
          x: (B, T, C) activations; each head gets C // num_heads channels.
          deterministic: when False and the global ``dropout`` rate is positive,
            dropout is applied after the output projection.
        """
        _, _, channels = x.shape
        per_head = channels // self.num_heads

        # Heads are named head_0, head_1, ... so their parameter paths are stable.
        outputs = [
            Head(per_head, name=f'head_{i}')(x, deterministic=deterministic)
            for i in range(self.num_heads)
        ]
        mixed = nn.Dense(channels)(jnp.concatenate(outputs, axis=-1))

        if dropout > 0 and not deterministic:
            mixed = nn.Dropout(rate=dropout)(mixed, deterministic=deterministic)
        return mixed


class FeedForward(nn.Module):
    """Position-wise MLP: Dense(4 * n_embd) -> ReLU -> Dense(n_embd).

    Attributes:
      n_embd: input/output width; the hidden layer is four times wider.
    """
    n_embd: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        """Run the MLP on the last axis of ``x``; optional dropout on the output."""
        expanded = nn.relu(nn.Dense(4 * self.n_embd)(x))
        projected = nn.Dense(self.n_embd)(expanded)
        if dropout > 0 and not deterministic:
            projected = nn.Dropout(rate=dropout)(projected, deterministic=deterministic)
        return projected


class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes h = x + MHA(LN(x)) followed by h + FFN(LN(h)).

    Attributes:
      n_embd: residual-stream width.
      n_head: number of attention heads.
    """
    n_embd: int
    n_head: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        # Attention sub-layer: normalize, attend, add back to the residual stream.
        attn_in = nn.LayerNorm()(x)
        h = x + MultiHeadAttention(self.n_head)(attn_in, deterministic=deterministic)

        # Feed-forward sub-layer: its own pre-LayerNorm, then a residual add.
        ffn_in = nn.LayerNorm()(h)
        return h + FeedForward(self.n_embd)(ffn_in, deterministic=deterministic)


class BigramLanguageModel(nn.Module):
    """
    Main language model:
    - token embedding
    - position embedding
    - N x Transformer blocks
    - final layer norm
    - linear head

    Despite the name, each prediction conditions on up to `block_size`
    previous tokens, not just one.
    """
    vocab_size: int
    n_embd: int
    n_head: int
    n_layer: int
    block_size: int

    @nn.compact
    def __call__(self, idx, targets=None, deterministic: bool = True):
        """
        idx: (B, T) token indices, with T <= block_size (the positional table
             only has block_size rows)
        targets: (B, T) optional, for computing cross-entropy
        deterministic: when False, dropout layers are active and a 'dropout'
             rng must be supplied to `apply`

        Returns (logits, loss): logits is (B, T, vocab_size); loss is the mean
        cross-entropy over all B*T positions, or None when targets is None.
        """
        B, T = idx.shape

        # Token embeddings + position embeddings
        token_emb = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.n_embd
        )(idx)  # (B, T, C)

        pos_emb = self.param('pos_emb', nn.initializers.normal(stddev=0.02),
                             (self.block_size, self.n_embd))

        # Add positional embeddings
        # shape of pos_emb is (block_size, C).
        # We only use first T positions if T < block_size.
        x = token_emb + pos_emb[:T]

        # Transformer blocks
        for i in range(self.n_layer):
            x = Block(self.n_embd, self.n_head, name=f'block_{i}')(x, deterministic=deterministic)

        # Final layer norm
        x = nn.LayerNorm()(x)

        # Linear head to get logits
        logits = nn.Dense(self.vocab_size)(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Flatten
            logits_reshaped = logits.reshape((B * T, self.vocab_size))
            targets_reshaped = targets.reshape((B * T,))
            # Cross entropy
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits_reshaped, targets_reshaped
            ).mean()

        return logits, loss

    def generate(self, variables, idx, max_new_tokens, rng=None):
        """
        Autoregressive generation.

        Args:
          variables: variable dict to run the model with, e.g. {'params': state.params}.
          idx: (B, T) int array of prompt token ids.
          max_new_tokens: number of tokens to append.
          rng: PRNG key for sampling; if None, decode greedily (argmax).

        Returns:
          (B, T + max_new_tokens) token ids: the prompt followed by the new tokens.
        """
        for _ in range(max_new_tokens):
            # Crop context to last block_size
            idx_cond = idx[:, -self.block_size:]

            # Forward pass (deterministic): no dropout runs, so no rng is needed.
            logits, _ = self.apply(variables, idx_cond, targets=None, deterministic=True)

            # Focus on last time step
            logits_last = logits[:, -1, :]  # (B, vocab_size)

            if rng is None:
                next_token = jnp.argmax(logits_last, axis=-1)[:, None]
            else:
                # Fresh subkey each step; reusing one key would reuse the same
                # random noise for every sampled token.
                rng, step_key = jax.random.split(rng)
                # categorical takes unnormalized logits directly.
                next_token = jax.random.categorical(step_key, logits_last, axis=-1)[:, None]

            # Append to running sequence
            idx = jnp.concatenate([idx, next_token], axis=1)
        return idx


# ------------------
# Training Utilities
# ------------------

def create_train_state(rng, model: BigramLanguageModel):
    """
    Initialize the model and wrap its parameters in an AdamW TrainState.

    Returns:
      (state, variables): the TrainState (params + optimizer state) and the
      full variable dict produced by `model.init`.
    """
    # A zero-filled (batch_size, block_size) int32 batch is enough for shape inference.
    dummy_tokens = jnp.zeros((batch_size, block_size), dtype=jnp.int32)
    variables = model.init(rng, dummy_tokens, targets=None, deterministic=True)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adamw(learning_rate),
    )
    return state, variables


@jax.jit
def train_step(state, variables, x, y, rng):
    """
    One optimization step: forward pass with dropout enabled, gradient of the
    mean cross-entropy w.r.t. `state.params`, then an optimizer update.

    Only the 'params' collection is used; `variables` is accepted but not read.

    Returns:
      (new_state, loss)
    """
    def compute_loss(params):
        _, batch_loss = state.apply_fn(
            {'params': params}, x,
            targets=y, deterministic=False, rngs={'dropout': rng},
        )
        return batch_loss

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    return state.apply_gradients(grads=grads), loss


def estimate_loss(state, model, variables, rng):
    """
    Evaluate the average loss over `eval_iters` batches for both train and val,
    using the *current* parameters in `state`.

    Args:
      state: current TrainState; `state.params` are the weights evaluated.
      model: the model whose `apply` runs the forward pass.
      variables: variable dict from `create_train_state`; any non-'params'
        collections are passed through, but 'params' is taken from `state`.
      rng: PRNG key used to sample the evaluation batches (the same key
        sequence is replayed for each split).

    Returns:
      dict mapping 'train' and 'val' to the mean loss as a Python float.
    """
    # `variables` comes from initialization and is never updated by train_step,
    # so its 'params' would always report the untrained model's loss.
    eval_variables = {**variables, 'params': state.params}

    # Compile the deterministic forward pass once per call rather than running
    # the model eagerly, op by op, for each of the 2 * eval_iters batches.
    @jax.jit
    def batch_loss(vars_, xb, yb):
        _, loss = model.apply(vars_, xb, targets=yb, deterministic=True)
        return loss

    losses = {}
    for split in ['train', 'val']:
        total = 0.0
        key = rng
        for _ in range(eval_iters):
            key, subkey = jax.random.split(key)
            xb, yb = get_batch(split, subkey)
            xb = jnp.array(xb, dtype=jnp.int32)
            yb = jnp.array(yb, dtype=jnp.int32)
            total += batch_loss(eval_variables, xb, yb).item()
        losses[split] = total / eval_iters

    return losses


def main():
    """
    Train the model with the global hyperparameters, printing train/val loss
    every `eval_interval` steps (and on the last step), then print a
    200-token sample generated with the trained weights.
    """
    # Seed
    main_key = jax.random.PRNGKey(1337)

    # Create the model
    model = BigramLanguageModel(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size
    )

    # Initialize train state.
    # NOTE: `variables` holds the *initial* weights and is never updated;
    # the trained weights live in `state.params`.
    main_key, subkey = jax.random.split(main_key)
    state, variables = create_train_state(subkey, model)

    # Training loop
    for iter_i in range(max_iters):
        # Evaluate once in a while (and on the final step)
        if iter_i % eval_interval == 0 or iter_i == max_iters - 1:
            main_key, sub_eval_key = jax.random.split(main_key)
            losses = estimate_loss(state, model, {'params': state.params}, sub_eval_key)
            print(f"step {iter_i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # Get batch
        main_key, subkey = jax.random.split(main_key)
        xb, yb = get_batch('train', subkey)
        xb = jnp.array(xb, dtype=jnp.int32)
        yb = jnp.array(yb, dtype=jnp.int32)

        # Train step
        main_key, drop_key = jax.random.split(main_key)
        state, loss_val = train_step(state, variables, xb, yb, drop_key)

    # Generate text from the trained model.
    # Start with a batch of size 1 whose only token has id 0, i.e. the first
    # character of the sorted vocabulary `chars`.
    context = jnp.zeros((1, 1), dtype=jnp.int32)
    generated = model.generate({'params': state.params}, context, max_new_tokens=200, rng=main_key)
    print(decode(np.array(generated[0])))


if __name__ == "__main__":
    main()


step 0: train loss 4.8100, val loss 4.7877


KeyboardInterrupt: 