<a href="https://colab.research.google.com/github/ubermenchh/jax-sandbox/blob/main/NanoGPT_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q jax flax

In [2]:
import jax
import jax.numpy as jnp
from flax import nnx
from flax.training import train_state
import optax

from typing import Tuple, Optional, Any
from dataclasses import dataclass

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
import jax.profiler

In [51]:
@dataclass
class Config:
    vocab_size: int = 50257
    block_size: int = 128
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    lr: float = 3e-4
    max_iters: int = 10
    batch_size: int = 8
    grad_clip: float = 1.0

In [52]:
class MultiHeadAttention(nnx.Module):
    def __init__(self, config, rngs=nnx.Rngs(0)) -> None:
        self.config = config
        assert self.config.n_embd % self.config.n_head == 0

        self.c_attn = nnx.Linear(self.config.n_embd, 3 * self.config.n_embd, rngs=rngs)
        self.c_proj = nnx.Linear(self.config.n_embd, self.config.n_embd, rngs=rngs)

        self.n_head = self.config.n_head
        self.n_embd = self.config.n_embd
        self.head_dim = self.n_embd // self.n_head

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        B, T, C = x.shape # batch_size, seq_len, embed_dim

        qkv = self.c_attn(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        def reshape_heads(x):
            return x.reshape(B, T, self.n_head, self.head_dim).transpose(0, 2, 1, 3)

        q = reshape_heads(q)
        k = reshape_heads(k)
        v = reshape_heads(v)

        # Scaled Dot Product Attention
        scale = -1.0 / jnp.sqrt(self.head_dim)
        scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) * scale

        # Causal mask
        mask = jnp.tril(jnp.ones((T, T)))
        scores = jnp.where(mask == 0, float("-inf"), scores)

        # Attention Weights
        weights = jax.nn.softmax(scores, axis=-1)
        # Attention output
        out = jnp.matmul(weights, v)
        out = out.transpose(0, 2, 1, 3).reshape(B, T, C)

        # Output projection
        return self.c_proj(out)

In [53]:
class MLP(nnx.Module):
    def __init__(self, config: Config, rngs=nnx.Rngs(0)) -> None:
        super().__init__()
        self.config = config
        self.c_fc = nnx.Linear(config.n_embd, 4 * config.n_embd, rngs=rngs)
        self.c_proj = nnx.Linear(config.n_embd * 4, config.n_embd, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.c_fc(x)
        x = jax.nn.gelu(x, approximate=True)
        x = self.c_proj(x)

        return x

In [54]:
class Block(nnx.Module):
    def __init__(self, config: Config, rngs=nnx.Rngs(0)) -> None:
        super().__init__()
        self.ln_1 = nnx.LayerNorm(num_features=config.n_embd, rngs=rngs)
        self.attn = MultiHeadAttention(config, rngs)
        self.ln_2 = nnx.LayerNorm(num_features=config.n_embd, rngs=rngs)
        self.mlp = MLP(config, rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

In [55]:
class GPT(nnx.Module):
    def __init__(self, config: Config, rngs=nnx.Rngs(0)) -> None:
        super().__init__()
        self.config = config
        self.wte = nnx.Embed(
            num_embeddings=config.vocab_size,
            features=config.n_embd,
            param_dtype=jnp.float32,
            rngs=rngs
        )
        self.wpe = nnx.Embed(
            num_embeddings=config.block_size,
            features=config.n_embd,
            param_dtype=jnp.float32,
            rngs=rngs
        )
        self.dropout = nnx.Dropout(rate=0.1)
        self.blocks = [Block(config, rngs) for _ in range(config.n_layer)]
        self.ln_f = nnx.LayerNorm(num_features=config.n_embd, rngs=rngs)
        self.lm_head = nnx.Linear(
            config.n_embd,
            config.vocab_size,
            use_bias=False,
            kernel_init=nnx.initializers.normal(stddev=0.02),
            rngs=rngs
        )

    def __call__(self,
                 idx: jnp.ndarray,
                 deterministic: bool=True,
                 targets: Optional[jnp.ndarray]=None) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
        B, T = idx.shape
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        pos = jnp.arange(0, T, dtype=jnp.int32)

        token_emb = self.wte(idx)
        pos_emb = self.wpe(pos)

        x = self.dropout(token_emb + pos_emb, deterministic=deterministic)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is not None:
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits.reshape(-1, logits.shape[-1]),
                targets.reshape(-1)
            ).mean()
            return logits, loss

        return logits, None

In [56]:
config = Config()
rngs = nnx.Rngs(0)
gpt = GPT(config, rngs)

In [57]:
gpt

GPT(
  config=Config(vocab_size=50257, block_size=128, n_layer=12, n_head=12, n_embd=768, lr=0.0003, max_iters=10, batch_size=8, grad_clip=1.0),
  wte=Embed(
    embedding=Param(
      value=Array(shape=(50257, 768), dtype=float32)
    ),
    num_embeddings=50257,
    features=768,
    dtype=dtype('float32'),
    param_dtype=<class 'jax.numpy.float32'>,
    embedding_init=<function variance_scaling.<locals>.init at 0x791b0b813250>
  ),
  wpe=Embed(
    embedding=Param(
      value=Array(shape=(128, 768), dtype=float32)
    ),
    num_embeddings=128,
    features=768,
    dtype=dtype('float32'),
    param_dtype=<class 'jax.numpy.float32'>,
    embedding_init=<function variance_scaling.<locals>.init at 0x791b0b813250>
  ),
  dropout=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=None),
  blocks=[Block(
    ln_1=LayerNorm(
      scale=Param(
        value=Array(shape=(768,), dtype=float32)
      ),
      bias=Param(
        value=Array(shape=(768,

In [58]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-11-14 16:13:22--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-11-14 16:13:22 (17.9 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [59]:
data = open("input.txt", "r").read()
print(data[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [60]:
!pip install -q tiktoken

In [61]:
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")

In [62]:
import torch

tokens = tokenizer.encode(data)
n = len(tokens)
n

338025

In [63]:
train_data = tokens[:int(n * 0.9)]
valid_data = tokens[int(n * 0.9):int(n * 0.95)]
test_data = tokens[int(n * 0.95):]

len(train_data), len(valid_data), len(test_data)

(304222, 16901, 16902)

In [64]:
print(train_data[:10])
print(tokenizer.decode(train_data[:10]))

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11]
First Citizen:
Before we proceed any further,


In [65]:
import torch

class TextDataset(torch.utils.data.Dataset):
    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        chunk = self.data[idx: idx + self.block_size + 1]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y

In [66]:
import numpy as np

def jax_collate(batch):
    max_len = max(len(item[0]) for item in batch)

    inp = [item[0] + [0] * (max_len - len(item[0])) for item in batch]
    trg = [item[1] + [0] * (max_len - len(item[1])) for item in batch]

    inp, trg = np.array(inp), np.array(trg)
    return inp, trg

In [67]:
train_dataset = TextDataset(train_data, config.block_size)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True, collate_fn=jax_collate, num_workers=4)

valid_dataset = TextDataset(valid_data, config.block_size)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, drop_last=True, collate_fn=jax_collate, num_workers=4)

test_dataset = TextDataset(test_data, config.block_size)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, drop_last=True, collate_fn=jax_collate, num_workers=4)

In [68]:
len(train_dataset) // config.batch_size

38027

In [69]:
b = next(iter(valid_dataloader))

  self.pid = os.fork()


In [70]:
b[0].shape, b[1].shape

((8, 128), (8, 128))

In [71]:
lr = 3e-4
weight_decay = 0.1

optimizer = nnx.Optimizer(gpt, optax.adamw(learning_rate=lr, b1=0.9, b2=0.95, eps=1e-8, weight_decay=weight_decay))
metrics = nnx.MultiMetric(loss=nnx.metrics.Average("loss"), perplexity=nnx.metrics.Average("perplexity"))

In [72]:
def loss_fn(model, batch):
    inp, trg = batch
    trg = trg.astype(int)
    logits, loss = model(inp, deterministic=True, targets=trg)
    return loss, logits

In [73]:
@nnx.jit
def train_step(model, optimizer, metrics, batch):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    perplexity = jnp.exp(loss)
    metrics.update(loss=loss, perplexity=perplexity)
    optimizer.update(grads)

In [74]:
@nnx.jit
def eval_step(model: GPT, metrics: nnx.MultiMetric, batch):
    loss, logits = loss_fn(model, batch)
    perplexity = jnp.exp(loss)
    metrics.update(loss=loss, perplexity=perplexity)

In [75]:
import time

def generate(
    model: GPT,
    idx: jnp.ndarray,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None
):
    """Generate text tokens autoregressively."""
    for _ in range(max_new_tokens):
        # Crop sequence if too long
        idx_cond = idx if idx.shape[1] <= model.config.block_size else idx[:, -model.config.block_size:]

        # Get predictions
        logits, _ = model(idx_cond, deterministic=True)
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            # Top-k sampling
            top_logits, top_indices = jax.lax.top_k(logits, min(top_k, logits.shape[-1]))
            probs = jax.nn.softmax(top_logits, axis=-1)
            idx_next = jax.random.choice(
                jax.random.PRNGKey(int(time.time())),
                top_indices.shape[-1],
                shape=(1,),
                p=probs[0]
            )
            idx_next = top_indices[0, idx_next]
        else:
            # Regular sampling
            probs = jax.nn.softmax(logits, axis=-1)
            idx_next = jax.random.categorical(
                jax.random.PRNGKey(int(time.time())),
                logits,
                axis=-1
            )

        # Append to sequence
        idx = jnp.concatenate((idx, idx_next.reshape(1, 1)), axis=1)

    return idx

In [76]:
metrics_history = {
    "train_loss": [],
    "train_perplexity": [],
    "val_loss": [],
    "val_perplexity": []
}

# Training configuration
num_epochs = 1
eval_every = 500  # Evaluate every 100 steps

# Main training loop
total_steps = 0
best_val_loss = float('inf')

In [77]:
for epoch in range(num_epochs):
    # Training
    for batch_idx, batch in enumerate(train_dataloader):
        train_step(gpt, optimizer, metrics, batch)
        total_steps += 1

        if total_steps > 0 and (total_steps % eval_every == 0):
            # Log training metrics
            train_metrics = metrics.compute()
            for metric, value in train_metrics.items():
                metrics_history[f"train_{metric}"].append(value)
            metrics.reset()

            # Validation
            for val_batch in valid_dataloader:
                eval_step(gpt, metrics, val_batch)

            # Log validation metrics
            val_metrics = metrics.compute()
            for metric, value in val_metrics.items():
                metrics_history[f"val_{metric}"].append(value)
            metrics.reset()

            # Print progress
            print(
                f"[Epoch {epoch}][Step {total_steps}] "
                f"Train Loss: {metrics_history['train_loss'][-1]:.4f} "
                f"Train PPL: {metrics_history['train_perplexity'][-1]:.4f} | "
                f"Val Loss: {metrics_history['val_loss'][-1]:.4f} "
                f"Val PPL: {metrics_history['val_perplexity'][-1]:.4f}"
            )

            # Save best model (you can implement saving logic here)
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                print(f"New best validation loss: {best_val_loss:.4f}")

  self.pid = os.fork()


[Epoch 0][Step 500] Train Loss: 5.5761 Train PPL: 547.7780 | Val Loss: 5.1558 Val PPL: 221.7115
New best validation loss: 5.1558


  self.pid = os.fork()


[Epoch 0][Step 1000] Train Loss: 4.6120 Train PPL: 104.0870 | Val Loss: 4.7349 Val PPL: 154.6977
New best validation loss: 4.7349
[Epoch 0][Step 1500] Train Loss: 4.2999 Train PPL: 76.1086 | Val Loss: 4.5557 Val PPL: 131.6437
New best validation loss: 4.5557
[Epoch 0][Step 2000] Train Loss: 4.0196 Train PPL: 57.4507 | Val Loss: 4.4569 Val PPL: 120.4873
New best validation loss: 4.4569
[Epoch 0][Step 2500] Train Loss: 3.7720 Train PPL: 44.7964 | Val Loss: 4.4164 Val PPL: 117.6189
New best validation loss: 4.4164
[Epoch 0][Step 3000] Train Loss: 3.4360 Train PPL: 32.0304 | Val Loss: 4.4019 Val PPL: 120.9069
New best validation loss: 4.4019
[Epoch 0][Step 3500] Train Loss: 3.0608 Train PPL: 22.1684 | Val Loss: 4.5267 Val PPL: 141.7111
[Epoch 0][Step 4000] Train Loss: 2.5977 Train PPL: 14.1584 | Val Loss: 4.6910 Val PPL: 171.0228
[Epoch 0][Step 4500] Train Loss: 2.0620 Train PPL: 8.2785 | Val Loss: 4.9290 Val PPL: 225.3462
[Epoch 0][Step 5000] Train Loss: 1.5335 Train PPL: 4.8242 | Val Los

KeyboardInterrupt: 

In [80]:
prompt = "First Citizen:\nYou "
prompt_tokens = jnp.array(tokenizer.encode(prompt))[None, :]

# print(prompt_tokens.shape)
pred_tokens = generate(gpt, idx=prompt_tokens, max_new_tokens=1024)

In [81]:
print(tokenizer.decode(pred_tokens.tolist()[0]))

First Citizen:
You orrowed well said,arer; no, so did gates.

Third Citizen:
I am five hundred, and thoure'st happy to be logs,
 great and thy great kingdom to my death.

Second Citizen:
And so did I; and so did I.

BRUTUS:
What then, that want wasiron.

CORIOLANUS:
as son.

MENENIUS:
The matter?

COMINIUS:
Ay, nine, call,kindenative; hearts!

CORIOLANUS:
 dimce have youSee how have youurrent made, sir?

Both that amWhen, sir.

MENENIUS:
How! Was it we? we loved him but, like beasts
And cowardly nobles, gave way unto your clusters,
Who did hoot him out o' the city.

COMINIUS:
But I fear
They'll roar him in again. Tullus Aufidius,
The second name of men, obeys his points
As if he were his officer: desperation
Is all the policy, strength and defence,
That Rome can make against them.

MENENIUS:
Here come the clusters.
And is Aufidius with him? You are they
That made the air unwh tellbalt with his
 true worse than trueoms.

MENENIUS:
I amShe's Go.

First Senator:
Now, as he made before, be