# Installation
Install JAX with NVIDIA GPU support (CUDA 12, as the install log below shows), along with the remaining dependencies, and download the tiny Shakespeare dataset.

In [None]:
!pip install --upgrade jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install --upgrade tensorflow tensorflow-probability
!pip install datasets
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda]
  Downloading jax-0.5.2-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.2,>=0.5.1 (from jax[cuda])
  Downloading jaxlib-0.5.1-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes)
Collecting jax-cuda12-plugin<=0.5.2,>=0.5.1 (from jax-cuda12-plugin[with_cuda]<=0.5.2,>=0.5.1; extra == "cuda"->jax[cuda])
  Downloading jax_cuda12_plugin-0.5.1-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting jax-cuda12-pjrt==0.5.1 (from jax-cuda12-plugin<=0.5.2,>=0.5.1->jax-cuda12-plugin[with_cuda]<=0.5.2,>=0.5.1; extra == "cuda"->jax[cuda])
  Downloading jax_cuda12_pjrt-0.5.1-py3-none-manylinux2014_x86_64.whl.metadata (348 bytes)
Collecting nvidia-cuda-nvcc-cu12>=12.6.85 (from jax-cuda12-plugin[with_cuda]<=0.5.2,>=0.5.1; extra == "cuda"->jax[cuda])
  Downloading nvidia_cuda_nvcc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Download

TPU equivalent of the installation cell above — run this one instead when using a TPU runtime.

In [None]:
!pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --upgrade tensorflow tensorflow-probability
!pip install datasets
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

# Hyperparameters and setup

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
import optax
from flax.training import train_state

# --- batching / optimisation ---
batch_size = 64          # independent sequences per training step
block_size = 256         # context length (tokens) the model sees
max_iters = 5_000        # total optimisation steps
eval_interval = 500      # evaluate train/val loss every this many steps
eval_iters = 200         # batches averaged per split when estimating loss
learning_rate = 3e-4

# --- model size ---
n_embd = 384             # embedding / residual width
n_head = 6               # attention heads per block (head size = n_embd // n_head)
n_layer = 6              # number of transformer blocks
dropout_rate = 0.2

# --- reproducibility ---
# NumPy's global RNG drives batch sampling; the JAX key drives init/dropout/sampling.
seed = 1337
np.random.seed(seed)
key = jax.random.PRNGKey(seed)


# Data Loading and Preprocessing

In [None]:

# Load the raw corpus (tiny Shakespeare, fetched by the install cell).
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = dict(zip(chars, range(vocab_size)))
itos = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of integer token ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of integer token ids back to a string."""
    return ''.join(itos[i] for i in l)


# Encode the whole corpus once; the first 90% trains, the last 10% validates.
data = np.array(encode(text), dtype=np.int32)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects train_data; any other value selects val_data.

    Returns:
        (x, y) as JAX arrays of shape (batch_size, block_size); y[:, t] is the
        token that follows x[:, t] in the corpus.
    """
    source = train_data if split == 'train' else val_data
    starts = np.random.randint(0, len(source) - block_size, size=(batch_size,))
    # Row b holds the indices starts[b], starts[b] + 1, ..., starts[b] + block_size - 1.
    windows = starts[:, None] + np.arange(block_size)
    return jnp.array(source[windows]), jnp.array(source[windows + 1])


# Model Definition using Flax

In [None]:

class Head(nn.Module):
    """A single head of causal (lower-triangular masked) self-attention.

    Attributes:
        head_size: output width of the key, query and value projections.
        dropout_rate: dropout probability applied to the attention weights.
    """
    head_size: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        """Attend over the time axis of ``x`` without looking ahead.

        Args:
            x: activations of shape (B, T, C).
            deterministic: when True, attention dropout is disabled.

        Returns:
            Array of shape (B, T, head_size).
        """
        seq_len = x.shape[1]
        # Three bias-free projections created in key, query, value order; the
        # creation order of submodules fixes their auto-generated names, so it
        # is kept unchanged.
        key_proj, query_proj, value_proj = (
            nn.Dense(self.head_size, use_bias=False)(x) for _ in range(3)
        )
        # Scaled dot-product scores, shape (B, T, T).
        scores = jnp.matmul(query_proj, jnp.swapaxes(key_proj, -2, -1)) * (
            1.0 / jnp.sqrt(self.head_size)
        )
        # Position t may only attend to positions <= t.
        causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        weights = nn.softmax(jnp.where(causal, scores, -1e10), axis=-1)
        weights = nn.Dropout(rate=self.dropout_rate)(weights, deterministic=deterministic)
        return jnp.matmul(weights, value_proj)

class MultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel, then projected back.

    Attributes:
        num_heads: number of Head modules.
        head_size: width of each head's output.
        emb_dim: width of the output projection.
        dropout_rate: dropout probability for the heads and the projection.
    """
    num_heads: int
    head_size: int
    emb_dim: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        """Apply every head to ``x``, concatenate, project to emb_dim, dropout."""
        head_outputs = []
        for _ in range(self.num_heads):
            head = Head(self.head_size, dropout_rate=self.dropout_rate)
            head_outputs.append(head(x, deterministic=deterministic))
        # Join heads along the feature axis: num_heads * head_size features.
        merged = jnp.concatenate(head_outputs, axis=-1)
        projected = nn.Dense(self.emb_dim)(merged)
        return nn.Dropout(rate=self.dropout_rate)(projected, deterministic=deterministic)

class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, ReLU, project back, dropout.

    Attributes:
        emb_dim: input/output width (the hidden layer is 4 * emb_dim).
        dropout_rate: dropout probability on the output.
    """
    emb_dim: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        """Transform each position of ``x`` independently."""
        hidden = nn.relu(nn.Dense(4 * self.emb_dim)(x))
        out = nn.Dense(self.emb_dim)(hidden)
        return nn.Dropout(rate=self.dropout_rate)(out, deterministic=deterministic)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: x + attn(LN(x)), then x + ffwd(LN(x)).

    Attributes:
        emb_dim: residual-stream width.
        num_heads: attention heads; each has emb_dim // num_heads features.
        dropout_rate: dropout probability passed to the sub-layers.
    """
    emb_dim: int
    num_heads: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        """Apply the attention and feed-forward sub-layers with residuals."""
        # Submodules are created in the same order as before (attention,
        # its LayerNorm, feed-forward, its LayerNorm) so parameter names match.
        attn = MultiHeadAttention(
            num_heads=self.num_heads,
            head_size=self.emb_dim // self.num_heads,
            emb_dim=self.emb_dim,
            dropout_rate=self.dropout_rate,
        )
        x = x + attn(nn.LayerNorm()(x), deterministic=deterministic)
        ffwd = FeedForward(emb_dim=self.emb_dim, dropout_rate=self.dropout_rate)
        x = x + ffwd(nn.LayerNorm()(x), deterministic=deterministic)
        return x

class GPTLanguageModel(nn.Module):
    """Decoder-only (GPT-style) language model over integer token ids.

    Attributes:
        vocab_size: number of distinct tokens.
        emb_dim: embedding / residual-stream width.
        num_layers: number of transformer Blocks.
        num_heads: attention heads per Block.
        block_size: maximum sequence length (rows of the positional table).
        dropout_rate: dropout probability used throughout.
    """
    vocab_size: int
    emb_dim: int
    num_layers: int
    num_heads: int
    block_size: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, idx, targets=None, deterministic: bool = False):
        """Compute logits, and the loss when targets are given.

        Args:
            idx: (B, T) integer token ids, with T <= block_size.
            targets: optional (B, T) integer next-token ids.
            deterministic: if True, dropout is disabled.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B * T, vocab_size) and loss is the mean cross-entropy.

        Raises:
            ValueError: if T exceeds block_size. The positional table has only
                block_size rows, and an out-of-range lookup would otherwise not
                fail loudly.
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        # Token and positional embeddings
        tok_emb = nn.Embed(num_embeddings=self.vocab_size, features=self.emb_dim)(idx)
        pos_emb = nn.Embed(num_embeddings=self.block_size, features=self.emb_dim)(
            jnp.arange(T)
        )
        x = tok_emb + pos_emb[None, :, :]
        # Transformer blocks
        for _ in range(self.num_layers):
            x = Block(emb_dim=self.emb_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(
                x, deterministic=deterministic
            )
        x = nn.LayerNorm()(x)
        logits = nn.Dense(self.vocab_size)(x)  # (B, T, vocab_size)
        if targets is not None:
            # Flatten the logits and targets for computing the loss
            logits = logits.reshape(-1, self.vocab_size)
            targets = targets.reshape(-1)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
        else:
            loss = None
        return logits, loss

    def generate(self, params, idx, max_new_tokens, rng):
        """Autoregressively sample new tokens after a starting context.

        Args:
            params: model parameters (the 'params' collection).
            idx: (B, T) integer token ids to continue.
            max_new_tokens: number of tokens to append.
            rng: JAX PRNG key; consumed by splitting, never used directly twice.

        Returns:
            (B, T + max_new_tokens) token ids.
        """
        for _ in range(max_new_tokens):
            # Crop to the last block_size tokens (size of the positional table).
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self.apply({'params': params}, idx_cond, deterministic=True)
            logits = logits[:, -1, :]  # (B, vocab_size) for the last time step
            # Split BEFORE sampling. The sampling key must not also be the
            # parent of future keys, because reusing a key correlates the samples.
            rng, sample_rng = jax.random.split(rng)
            next_token = jax.random.categorical(sample_rng, logits)[:, None]
            idx = jnp.concatenate([idx, next_token], axis=1)
        return idx


# Training State and Step Functions

In [None]:
class TrainState(train_state.TrainState):
    """Flax's TrainState under a local name; adds no fields or methods."""

@jax.jit
def train_step(state, x, y, dropout_rng):
    """Run one jitted optimisation step on a batch.

    Args:
        state: TrainState holding params, optimiser state and ``apply_fn``.
        x: input token ids, as produced by get_batch.
        y: next-token targets, as produced by get_batch.
        dropout_rng: PRNG key for the dropout layers; pass a fresh key per step.

    Returns:
        The updated TrainState.
    """
    def loss_fn(params):
        # Use the apply_fn stored in the state rather than the global `model`.
        # `model` is only created in the later training cell, so this removes
        # an ordering dependency on it.
        _, loss = state.apply_fn(
            {'params': params}, x, targets=y, deterministic=False,
            rngs={'dropout': dropout_rng},
        )
        return loss

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

@jax.jit
def eval_step(params, x, y):
    """Return the model's mean cross-entropy on (x, y) with dropout disabled.

    Reads the global ``model`` at trace time.
    """
    outputs = model.apply({'params': params}, x, targets=y, deterministic=True)
    return outputs[1]  # (logits, loss) -> loss

def estimate_loss(state):
    """Average the evaluation loss over ``eval_iters`` random batches per split.

    Batches come from get_batch, so the result depends on NumPy's global RNG;
    dropout is off because eval_step runs the model deterministically.

    Returns:
        dict mapping 'train' and 'val' to the mean loss for that split.
    """
    means = {}
    for split in ('train', 'val'):
        # Collect device arrays first and convert to Python floats afterwards,
        # so there is no host round-trip per batch inside the loop.
        device_losses = [eval_step(state.params, *get_batch(split)) for _ in range(eval_iters)]
        means[split] = np.mean([float(v) for v in device_losses])
    return means


# Main Training Loop

In [None]:
if __name__ == '__main__':
    # Initialize the model and training state
    model = GPTLanguageModel(
        vocab_size=vocab_size,
        emb_dim=n_embd,
        num_layers=n_layer,
        num_heads=n_head,
        block_size=block_size,
        dropout_rate=dropout_rate
    )
    # Give initialisation its own subkey. Using `key` directly would make it
    # both the init key and the parent of every later dropout/sampling key.
    key, init_key = jax.random.split(key)
    dummy_input = jnp.ones((batch_size, block_size), dtype=jnp.int32)
    # deterministic=True: parameter shapes do not depend on dropout, and this
    # avoids needing a separate 'dropout' RNG stream during init.
    initial_variables = model.init(init_key, dummy_input, targets=dummy_input, deterministic=True)
    params = initial_variables['params']
    tx = optax.adamw(learning_rate)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    # Training loop; `step` (not `iter`) avoids shadowing the builtin in the
    # notebook namespace.
    for step in range(max_iters):
        if step % eval_interval == 0 or step == max_iters - 1:
            losses = estimate_loss(state)
            print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        x, y = get_batch('train')
        key, dropout_rng = jax.random.split(key)
        state = train_step(state, x, y, dropout_rng)

    # Generate text from the trained model
    context = jnp.zeros((1, 1), dtype=jnp.int32)  # start from token index 0
    key, gen_rng = jax.random.split(key)
    generated = model.generate(state.params, context, max_new_tokens=500, rng=gen_rng)
    print(decode(np.array(generated[0])))


# Full Code

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
import optax
from flax.training import train_state

# -----------------------------------------------------------------------------
# Hyperparameters and setup
# -----------------------------------------------------------------------------
# --- batching / optimisation ---
batch_size = 64          # independent sequences per training step
block_size = 256         # context length (tokens) the model sees
max_iters = 5_000        # total optimisation steps
eval_interval = 500      # evaluate train/val loss every this many steps
eval_iters = 200         # batches averaged per split when estimating loss
learning_rate = 3e-4

# --- model size ---
n_embd = 384             # embedding / residual width
n_head = 6               # attention heads per block (head size = n_embd // n_head)
n_layer = 6              # number of transformer blocks
dropout_rate = 0.2

# --- reproducibility ---
# NumPy's global RNG drives batch sampling; the JAX key drives init/dropout/sampling.
seed = 1337
np.random.seed(seed)
key = jax.random.PRNGKey(seed)

# -----------------------------------------------------------------------------
# Data Loading and Preprocessing
# -----------------------------------------------------------------------------
# Load the raw corpus (tiny Shakespeare, fetched by the install cell).
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = dict(zip(chars, range(vocab_size)))
itos = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of integer token ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of integer token ids back to a string."""
    return ''.join(itos[i] for i in l)


# Encode the whole corpus once; the first 90% trains, the last 10% validates.
data = np.array(encode(text), dtype=np.int32)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects train_data; any other value selects val_data.

    Returns:
        (x, y) as JAX arrays of shape (batch_size, block_size); y[:, t] is the
        token that follows x[:, t] in the corpus.
    """
    source = train_data if split == 'train' else val_data
    starts = np.random.randint(0, len(source) - block_size, size=(batch_size,))
    # Row b holds the indices starts[b], starts[b] + 1, ..., starts[b] + block_size - 1.
    windows = starts[:, None] + np.arange(block_size)
    return jnp.array(source[windows]), jnp.array(source[windows + 1])

# -----------------------------------------------------------------------------
# Model Definition using Flax
# -----------------------------------------------------------------------------
class Head(nn.Module):
    """A single head of causal (lower-triangular masked) self-attention.

    Attributes:
        head_size: output width of the key, query and value projections.
        dropout_rate: dropout probability applied to the attention weights.
    """
    head_size: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        """Attend over the time axis of ``x`` without looking ahead.

        Args:
            x: activations of shape (B, T, C).
            deterministic: when True, attention dropout is disabled.

        Returns:
            Array of shape (B, T, head_size).
        """
        seq_len = x.shape[1]
        # Three bias-free projections created in key, query, value order; the
        # creation order of submodules fixes their auto-generated names, so it
        # is kept unchanged.
        key_proj, query_proj, value_proj = (
            nn.Dense(self.head_size, use_bias=False)(x) for _ in range(3)
        )
        # Scaled dot-product scores, shape (B, T, T).
        scores = jnp.matmul(query_proj, jnp.swapaxes(key_proj, -2, -1)) * (
            1.0 / jnp.sqrt(self.head_size)
        )
        # Position t may only attend to positions <= t.
        causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        weights = nn.softmax(jnp.where(causal, scores, -1e10), axis=-1)
        weights = nn.Dropout(rate=self.dropout_rate)(weights, deterministic=deterministic)
        return jnp.matmul(weights, value_proj)

class MultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel, then projected back.

    Attributes:
        num_heads: number of Head modules.
        head_size: width of each head's output.
        emb_dim: width of the output projection.
        dropout_rate: dropout probability for the heads and the projection.
    """
    num_heads: int
    head_size: int
    emb_dim: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        """Apply every head to ``x``, concatenate, project to emb_dim, dropout."""
        head_outputs = []
        for _ in range(self.num_heads):
            head = Head(self.head_size, dropout_rate=self.dropout_rate)
            head_outputs.append(head(x, deterministic=deterministic))
        # Join heads along the feature axis: num_heads * head_size features.
        merged = jnp.concatenate(head_outputs, axis=-1)
        projected = nn.Dense(self.emb_dim)(merged)
        return nn.Dropout(rate=self.dropout_rate)(projected, deterministic=deterministic)

class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, ReLU, project back, dropout.

    Attributes:
        emb_dim: input/output width (the hidden layer is 4 * emb_dim).
        dropout_rate: dropout probability on the output.
    """
    emb_dim: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        """Transform each position of ``x`` independently."""
        hidden = nn.relu(nn.Dense(4 * self.emb_dim)(x))
        out = nn.Dense(self.emb_dim)(hidden)
        return nn.Dropout(rate=self.dropout_rate)(out, deterministic=deterministic)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: x + attn(LN(x)), then x + ffwd(LN(x)).

    Attributes:
        emb_dim: residual-stream width.
        num_heads: attention heads; each has emb_dim // num_heads features.
        dropout_rate: dropout probability passed to the sub-layers.
    """
    emb_dim: int
    num_heads: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        """Apply the attention and feed-forward sub-layers with residuals."""
        # Submodules are created in the same order as before (attention,
        # its LayerNorm, feed-forward, its LayerNorm) so parameter names match.
        attn = MultiHeadAttention(
            num_heads=self.num_heads,
            head_size=self.emb_dim // self.num_heads,
            emb_dim=self.emb_dim,
            dropout_rate=self.dropout_rate,
        )
        x = x + attn(nn.LayerNorm()(x), deterministic=deterministic)
        ffwd = FeedForward(emb_dim=self.emb_dim, dropout_rate=self.dropout_rate)
        x = x + ffwd(nn.LayerNorm()(x), deterministic=deterministic)
        return x

class GPTLanguageModel(nn.Module):
    """Decoder-only (GPT-style) language model over integer token ids.

    Attributes:
        vocab_size: number of distinct tokens.
        emb_dim: embedding / residual-stream width.
        num_layers: number of transformer Blocks.
        num_heads: attention heads per Block.
        block_size: maximum sequence length (rows of the positional table).
        dropout_rate: dropout probability used throughout.
    """
    vocab_size: int
    emb_dim: int
    num_layers: int
    num_heads: int
    block_size: int
    dropout_rate: float = dropout_rate

    @nn.compact
    def __call__(self, idx, targets=None, deterministic: bool = False):
        """Compute logits, and the loss when targets are given.

        Args:
            idx: (B, T) integer token ids, with T <= block_size.
            targets: optional (B, T) integer next-token ids.
            deterministic: if True, dropout is disabled.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B * T, vocab_size) and loss is the mean cross-entropy.

        Raises:
            ValueError: if T exceeds block_size. The positional table has only
                block_size rows, and an out-of-range lookup would otherwise not
                fail loudly.
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        # Token and positional embeddings
        tok_emb = nn.Embed(num_embeddings=self.vocab_size, features=self.emb_dim)(idx)
        pos_emb = nn.Embed(num_embeddings=self.block_size, features=self.emb_dim)(
            jnp.arange(T)
        )
        x = tok_emb + pos_emb[None, :, :]
        # Transformer blocks
        for _ in range(self.num_layers):
            x = Block(emb_dim=self.emb_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(
                x, deterministic=deterministic
            )
        x = nn.LayerNorm()(x)
        logits = nn.Dense(self.vocab_size)(x)  # (B, T, vocab_size)
        if targets is not None:
            # Flatten the logits and targets for computing the loss
            logits = logits.reshape(-1, self.vocab_size)
            targets = targets.reshape(-1)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
        else:
            loss = None
        return logits, loss

    def generate(self, params, idx, max_new_tokens, rng):
        """Autoregressively sample new tokens after a starting context.

        Args:
            params: model parameters (the 'params' collection).
            idx: (B, T) integer token ids to continue.
            max_new_tokens: number of tokens to append.
            rng: JAX PRNG key; consumed by splitting, never used directly twice.

        Returns:
            (B, T + max_new_tokens) token ids.
        """
        for _ in range(max_new_tokens):
            # Crop to the last block_size tokens (size of the positional table).
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self.apply({'params': params}, idx_cond, deterministic=True)
            logits = logits[:, -1, :]  # (B, vocab_size) for the last time step
            # Split BEFORE sampling. The sampling key must not also be the
            # parent of future keys, because reusing a key correlates the samples.
            rng, sample_rng = jax.random.split(rng)
            next_token = jax.random.categorical(sample_rng, logits)[:, None]
            idx = jnp.concatenate([idx, next_token], axis=1)
        return idx

# -----------------------------------------------------------------------------
# Training State and Step Functions
# -----------------------------------------------------------------------------
class TrainState(train_state.TrainState):
    """Flax's TrainState under a local name; adds no fields or methods."""

@jax.jit
def train_step(state, x, y, dropout_rng):
    """Run one jitted optimisation step on a batch.

    Args:
        state: TrainState holding params, optimiser state and ``apply_fn``.
        x: input token ids, as produced by get_batch.
        y: next-token targets, as produced by get_batch.
        dropout_rng: PRNG key for the dropout layers; pass a fresh key per step.

    Returns:
        The updated TrainState.
    """
    def loss_fn(params):
        # Use the apply_fn stored in the state rather than the global `model`.
        # `model` is only created in the training section below, so this
        # removes an ordering dependency on it.
        _, loss = state.apply_fn(
            {'params': params}, x, targets=y, deterministic=False,
            rngs={'dropout': dropout_rng},
        )
        return loss

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

@jax.jit
def eval_step(params, x, y):
    """Return the model's mean cross-entropy on (x, y) with dropout disabled.

    Reads the global ``model`` at trace time.
    """
    outputs = model.apply({'params': params}, x, targets=y, deterministic=True)
    return outputs[1]  # (logits, loss) -> loss

def estimate_loss(state):
    """Average the evaluation loss over ``eval_iters`` random batches per split.

    Batches come from get_batch, so the result depends on NumPy's global RNG;
    dropout is off because eval_step runs the model deterministically.

    Returns:
        dict mapping 'train' and 'val' to the mean loss for that split.
    """
    means = {}
    for split in ('train', 'val'):
        # Collect device arrays first and convert to Python floats afterwards,
        # so there is no host round-trip per batch inside the loop.
        device_losses = [eval_step(state.params, *get_batch(split)) for _ in range(eval_iters)]
        means[split] = np.mean([float(v) for v in device_losses])
    return means

# -----------------------------------------------------------------------------
# Main Training Loop
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Initialize the model and training state
    model = GPTLanguageModel(
        vocab_size=vocab_size,
        emb_dim=n_embd,
        num_layers=n_layer,
        num_heads=n_head,
        block_size=block_size,
        dropout_rate=dropout_rate
    )
    # Give initialisation its own subkey. Using `key` directly would make it
    # both the init key and the parent of every later dropout/sampling key.
    key, init_key = jax.random.split(key)
    dummy_input = jnp.ones((batch_size, block_size), dtype=jnp.int32)
    # deterministic=True: parameter shapes do not depend on dropout, and this
    # avoids needing a separate 'dropout' RNG stream during init.
    initial_variables = model.init(init_key, dummy_input, targets=dummy_input, deterministic=True)
    params = initial_variables['params']
    tx = optax.adamw(learning_rate)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    # Training loop; `step` (not `iter`) avoids shadowing the builtin in the
    # notebook namespace.
    for step in range(max_iters):
        if step % eval_interval == 0 or step == max_iters - 1:
            losses = estimate_loss(state)
            print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        x, y = get_batch('train')
        key, dropout_rng = jax.random.split(key)
        state = train_step(state, x, y, dropout_rng)

    # Generate text from the trained model
    context = jnp.zeros((1, 1), dtype=jnp.int32)  # start from token index 0
    key, gen_rng = jax.random.split(key)
    generated = model.generate(state.params, context, max_new_tokens=500, rng=gen_rng)
    print(decode(np.array(generated[0])))


step 0: train loss 4.6297, val loss 4.6451
step 1: train loss 4.5821, val loss 4.6249
step 2: train loss 4.4477, val loss 4.4833
step 3: train loss 4.2935, val loss 4.3219
step 4: train loss 4.1183, val loss 4.1443
step 5: train loss 3.8985, val loss 3.9255
step 6: train loss 3.6706, val loss 3.7072
step 7: train loss 3.4779, val loss 3.5138
step 8: train loss 3.3642, val loss 3.4008
step 9: train loss 3.3840, val loss 3.4204
step 10: train loss 3.4494, val loss 3.4882
step 11: train loss 3.4496, val loss 3.4850
step 12: train loss 3.3940, val loss 3.4297
step 13: train loss 3.3479, val loss 3.3825
step 14: train loss 3.3302, val loss 3.3633


KeyboardInterrupt: 