# nanoGPT-jax

In [1]:
with open("input.txt", "r", encoding="utf-8") as file:
    text = file.read()

In [2]:
print("length of text:", len(text))

length of text: 1115394


In [3]:
print("text:", text[:100])

text: First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [4]:
chars = sorted(set(text))
vocab_size = len(chars)
print("vocab_size:", vocab_size)
print(''.join(chars))

vocab_size: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [5]:
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda x: [stoi[ch] for ch in x]
decode = lambda x: ''.join([itos[i] for i in x])

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [6]:
import jax
from jax import numpy as jnp
import numpy as np

data = np.array(encode(text), dtype=np.int32) # no default int64, use numpy for underlying data
print(data.shape, data.dtype)
print(data[:100])

(1115394,) int32
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59]


In [7]:
# train val split
n = int(len(data) * 0.9)
train_data = data[:n]
val_data = data[n:]

In [8]:
block_size = 8
train_data[:block_size+1]

array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [9]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context, target = x[:t+1], y[t]
    print(f"when input is {context} target is {target}")

when input is [18] target is 47
when input is [18 47] target is 56
when input is [18 47 56] target is 57
when input is [18 47 56 57] target is 58
when input is [18 47 56 57 58] target is 1
when input is [18 47 56 57 58  1] target is 15
when input is [18 47 56 57 58  1 15] target is 47
when input is [18 47 56 57 58  1 15 47] target is 58


In [10]:
rng = jax.random.PRNGKey(42)
key, subkey = jax.random.split(rng)

batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = np.random.randint(0, data.shape[0] - block_size, batch_size)
    x = np.stack([data[i:i+block_size] for i in ix])
    y = np.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch("train")
print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

print("-----")

for b in range(batch_size):
    for t in range(block_size):
        context, target = xb[b, :t+1], yb[b, t]
        print(f"when input is {context} target is {target}")

inputs:
(4, 8)
[[42  1 53  5 43 56  2  0]
 [46 39 58  1 57 43 43 51]
 [46 39 58  0 20 43  5 50]
 [47 52 44 39 52 58  6  1]]
targets:
(4, 8)
[[ 1 53  5 43 56  2  0  0]
 [39 58  1 57 43 43 51  1]
 [39 58  0 20 43  5 50 50]
 [52 44 39 52 58  6  1 44]]
-----
when input is [42] target is 1
when input is [42  1] target is 53
when input is [42  1 53] target is 5
when input is [42  1 53  5] target is 43
when input is [42  1 53  5 43] target is 56
when input is [42  1 53  5 43 56] target is 2
when input is [42  1 53  5 43 56  2] target is 0
when input is [42  1 53  5 43 56  2  0] target is 0
when input is [46] target is 39
when input is [46 39] target is 58
when input is [46 39 58] target is 1
when input is [46 39 58  1] target is 57
when input is [46 39 58  1 57] target is 43
when input is [46 39 58  1 57 43] target is 43
when input is [46 39 58  1 57 43 43] target is 51
when input is [46 39 58  1 57 43 43 51] target is 1
when input is [46] target is 39
when input is [46 39] target is 58
when 

In [11]:
import flax.linen as nn
import optax
from einops import rearrange
import numpy as np
from functools import partial

class BigramLanguageModel(nn.Module):
    vocab_size: int
    
    @nn.compact
    def __call__(self, idx: jnp.ndarray, targets: jnp.ndarray = None):
        logits = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)(idx) # (bs, block_size, 
        if targets is None:
            return logits, 0
        logits_reshaped = rearrange(logits, 'b t c -> (b t) c')
        loss = optax.softmax_cross_entropy_with_integer_labels(logits_reshaped, targets.flatten())
        return logits, jnp.mean(loss)

    def generate(self, params, subkey, idx, max_new_tokens: int = 100):
        # pad idx to max_new_tokens
        idx = jnp.pad(idx, ((0,0), (0, max_new_tokens)), mode='constant', constant_values=0)

        @jax.jit
        def step_fn(i: int, idx: jnp.ndarray, key: jnp.ndarray):
            logits, loss = self.apply(params, idx)
            logits = logits[:, i, :] # we are not using the logits from the rest of the sequence
            token = jax.random.categorical(key, logits)
            new_idx = idx.at[:, i+1].set(token)
            return new_idx
        for i in range(max_new_tokens):
            subkey, key = jax.random.split(subkey)
            idx = step_fn(i, idx, key)
        return jnp.array(idx)

m = BigramLanguageModel(vocab_size=vocab_size)
params = m.init(subkey, xb, yb)
out, loss = m.apply(params, xb, yb)
print(out.shape)
print(loss)

idx = jnp.zeros((1, 1), dtype=jnp.int32)
print(decode(np.array(m.generate(params, subkey, idx, max_new_tokens=100)[0])))

(4, 8, 65)
4.173657

yFlSnelZW$eod&mx&hF!OYNmyFf;Z;e-peU?JgFLPeuwoMmZaYH?wgihUj?ltK

KhqBwWPBEEuiYZnQM,WtQA-fjI3z ZH3.Wq:


In [18]:
from timeit import default_timer as timer

optimizer = optax.adamw(1e-3)
opt_state = optimizer.init(params)

eval_iters = 20
def estimate_loss():
    out = {}
    for split in ["train", "val"]:
        loss = 0
        for _ in range(eval_iters):
            xb, yb = get_batch(split)
            _, l = m.apply(params, xb, yb)
            loss += l
        out[split] = (loss / eval_iters).item()
    return out

@jax.jit
def train_step(params, opt_state, xb, yb):
    def loss_fn(params):
        logits, loss = m.apply(params, xb, yb)
        return loss
    
    loss, grad = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state, loss

start_time = timer()

batch_size = 32

for step in range(10000):
    xb, yb = get_batch("train")
    params, opt_state, loss = train_step(params, opt_state, xb, yb)
    
    if step % 2000 == 0:  # Print loss every 100 steps
        loss = estimate_loss()
        print(f"Step {step}, Loss: {loss}", "Elapsed time:", timer() - start_time)
        idx = jnp.zeros((1, 1), dtype=jnp.int32)

Step 0, Loss: {'train': 2.4908127784729004, 'val': 2.4694571495056152} Elapsed time: 0.4887924420181662
Step 2000, Loss: {'train': 2.462996244430542, 'val': 2.478917121887207} Elapsed time: 0.730645745061338
Step 4000, Loss: {'train': 2.4687507152557373, 'val': 2.492091655731201} Elapsed time: 0.961625972064212
Step 6000, Loss: {'train': 2.4685349464416504, 'val': 2.4656217098236084} Elapsed time: 1.198050731094554
Step 8000, Loss: {'train': 2.4678354263305664, 'val': 2.4940314292907715} Elapsed time: 1.431668707053177


In [19]:
idx = jnp.zeros((1, 1), dtype=jnp.int32)
subkey, key = jax.random.split(subkey)
print(decode(np.array(m.generate(params, key, idx, max_new_tokens=500)[0])))


Hou we h; mblld, hes hevete d, shed, nernge:
ROFousthereinore!

T:


My wne wimispowalo
Bus.

ME:
tuald nd.
IO e d my,
ADYo o thre ous chanere agastolerd So op' igonservewing amen tors bis Vighth s tusey ould tis me;
Wh ENout noon KIr KICle od IUCUSToulllooybede


d forsh ongnthore yel llthorace az wis st d?
HEN IUKINiun y tingure higeaving he mbothath sese I elll blar, o witore is owno, s
Bu aet tee He cern m uep, ps are ad, If y tyos,-horst f t senf! sefashee cheand t, tis PORIE puthecode t or
