In [2]:
from typing import Literal

import flax
import jax
import jax.numpy as jnp
import optax
from flax import nnx

print(f"flax: {flax.__version__}")
print(f"jax: {jax.__version__}")
print(f"optax: {optax.__version__}")

flax: 0.12.0
jax: 0.8.0
optax: 0.2.6


In [44]:
!curl -O https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

  pid, fd = os.forkpty()


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  4078k      0 --:--:-- --:--:-- --:--:-- 4079k


In [45]:
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [46]:
print(f"Length of dataset in characters: {len(text):,}")

Length of dataset in characters: 1,115,394


In [47]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [48]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [49]:
s2i = {ch:i for i, ch in enumerate(chars)}
i2s = {i:ch for i, ch in enumerate(chars)}
encode = lambda s: [s2i[c] for c in s]
decode = lambda l: "".join([i2s[i] for i in l])

In [50]:
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [51]:
data = jnp.array(encode(text), dtype=jnp.int32)
print(data.shape, data.dtype)
print(data[:1000])

(1115394,) int32
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59  1 39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39
 58 46 43 56  1 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47
 57 46 12  0  0 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53
 50 60 43 42  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47
 56 57 58  6  1 63 53 59  1 49 52 53 61  1 15 39 47 59 57  1 25 39 56 41
 47 59 57  1 47 57  1 41 46 47 43 44  1 43 52 43 51 63  1 58 53  1 58 46
 43  1 54 43 53 54 50 43  8  0  0 13 50 50 10  0 35 43  1 49 52 53 61  5
 58  6  1 61 43  1 49 52 53 61  5 58  8  0  0 18 47 56 57 58  1 15 47 58
 47 64 43 52 10  0 24 43 58  1 59 57  1 49 47 50 50  1 46 47 51  6  1 39
 52 42  1 61 43  5 50 50  1 46 39 

In [52]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [53]:
train_data

Array([18, 47, 56, ..., 43, 56, 43], dtype=int32)

In [54]:
block_size = 8
train_data[:block_size+1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [55]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is [18] the target: 47
when input is [18 47] the target: 56
when input is [18 47 56] the target: 57
when input is [18 47 56 57] the target: 58
when input is [18 47 56 57 58] the target: 1
when input is [18 47 56 57 58  1] the target: 15
when input is [18 47 56 57 58  1 15] the target: 47
when input is [18 47 56 57 58  1 15 47] the target: 58


In [56]:
rngs = nnx.Rngs(44)

In [57]:
def get_batch(
    rngs: nnx.Rngs,
    split: Literal["train", "val"],
    block_size: int = 8,
    batch_size: int = 4,
):
    # Generate a small batch of data of inputs x and targets y
    data = train_data if split == "train" else val_data

    maxval = len(data) - block_size
    start_indices = rngs.randint(shape=(batch_size,), minval=0, maxval=maxval)
    
    x = jnp.stack([data[i:i+block_size] for i in start_indices])
    y = jnp.stack([data[i+1:i+1+block_size] for i in start_indices])
    
    return x, y

In [58]:
BLOCK_SIZE = 8
BATCH_SIZE = 4

xb, yb = get_batch(rngs, "train", block_size=BLOCK_SIZE, batch_size=BATCH_SIZE)

In [59]:
print("inputs:")
print(xb.shape)
print(xb)
print("\ntargets:")
print(yb.shape)
print(yb)

inputs:
(4, 8)
[[57  1 40 56 53 58 46 43]
 [40 43  1 40 63  1 57 58]
 [ 1 52 53  1 46 39 56 51]
 [52 42  1 44 39 47 56 50]]

targets:
(4, 8)
[[ 1 40 56 53 58 46 43 56]
 [43  1 40 63  1 57 58 43]
 [52 53  1 46 39 56 51  1]
 [42  1 44 39 47 56 50 63]]


In [60]:
for b in range(BATCH_SIZE):  # Batch dimension
    for t in range(BLOCK_SIZE):  # Time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context} the target: {target}")
    print()

when input is [57] the target: 1
when input is [57  1] the target: 40
when input is [57  1 40] the target: 56
when input is [57  1 40 56] the target: 53
when input is [57  1 40 56 53] the target: 58
when input is [57  1 40 56 53 58] the target: 46
when input is [57  1 40 56 53 58 46] the target: 43
when input is [57  1 40 56 53 58 46 43] the target: 56

when input is [40] the target: 43
when input is [40 43] the target: 1
when input is [40 43  1] the target: 40
when input is [40 43  1 40] the target: 63
when input is [40 43  1 40 63] the target: 1
when input is [40 43  1 40 63  1] the target: 57
when input is [40 43  1 40 63  1 57] the target: 58
when input is [40 43  1 40 63  1 57 58] the target: 43

when input is [1] the target: 52
when input is [ 1 52] the target: 53
when input is [ 1 52 53] the target: 1
when input is [ 1 52 53  1] the target: 46
when input is [ 1 52 53  1 46] the target: 39
when input is [ 1 52 53  1 46 39] the target: 56
when input is [ 1 52 53  1 46 39 56] the t

In [61]:
class BigramLanguageModel(nnx.Module):

    def __init__(self, vocab_size: int, rngs: nnx.Rngs):
        self.token_embedding_table = nnx.Embed(
            num_embeddings=vocab_size, features=vocab_size, rngs=rngs
        )
    
    def __call__(
            self,
            idx: jnp.ndarray,
            targets: jnp.ndarray | None = None
    ) -> jnp.ndarray:
        # Think of logits as scores for the next char in the sequence.
        logits = self.token_embedding_table(idx)   # (Batch, Time, Channel) = (B, T, C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.reshape(B*T, C)
            targets = targets.reshape(B*T)

            loss = jnp.mean(
                # Can verify softmax using the average ~= -log(1/vocab_size) = 4.17.
                optax.softmax_cross_entropy_with_integer_labels(logits, targets)
            )
        
        # !!! Logits is dim (B, T, C) if targets is None else (B*T, C)
        return logits, loss
    
    def generate(self, idx: jnp.ndarray, max_new_tokens: int, rngs: nnx.Rngs) -> jnp.ndarray:
        # idx is (B, T) array of indices in current context.
        for _ in range(max_new_tokens):
            # Get the predictions
            logits, _ = self(idx)  # dim = (B, C)

            # Focus only on the last time step (get idx -1 on Time index)
            logits = logits[:, -1, :]  # dim = (B, C)

            # jax.random.categorical is more similar to torch.multinomial.
            # Notice we don't require apply softmax to logits since rngs.categorical
            # expects logits rather than probabilities.
            # Also notice, reshape (B, 1) because cannot concat (B, T) with (B,), require 
            # reshape to concat (B, T) with (B, 1) --> (B, T+1).
            idx_next = rngs.categorical(logits).reshape(logits.shape[0], 1)  # dim = (B,) -> (B, 1)

            # Append sampled index to the running idx sequence
            idx = jnp.concat([idx, idx_next], axis=1)  # dim = (B, T+1)
        
        return idx



In [62]:
rngs = nnx.Rngs(0)
m = BigramLanguageModel(vocab_size, rngs)
logits, loss = m(xb, yb)

print(logits.shape)
print(f"({BATCH_SIZE}, {BLOCK_SIZE}, {vocab_size})")

(32, 65)
(4, 8, 65)


In [63]:
idx = jnp.zeros((1, 1), dtype=jnp.int32)
next_idxs = m.generate(idx, max_new_tokens=100, rngs=rngs)
res = decode(next_idxs[0].tolist())
res

"\n3OovuAZ?&I?TR-ueXUHOf\nRB-Z?hDg'yssFGNIJt-3bQK!N?saRlq'sw: QB,:WQDg3lcDmwBHwXThpkZ:n3i.osSdnCEyHQ det"

In [64]:
# Train the function

learning_rate = 1e-3

optimizer = nnx.Optimizer(
    m, optax.adamw(learning_rate=learning_rate), wrt=nnx.Param
)

metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss")
    # accuracy=nnx.metrics.Accuracy(),
)

nnx.display(optimizer)

In [65]:
def loss_fn(model, idx, targets):
    logits, loss = model(idx, targets)
    return loss, logits

    
@nnx.jit
def train_step(model, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, idx, targets):
    """Train for a single step."""
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, idx, targets)
    metrics.update(loss=loss)
    optimizer.update(model, grads)

In [None]:
rngs = nnx.Rngs(0)

BATCH_SIZE = 32

eval_every = 100

m = BigramLanguageModel(vocab_size, rngs)
for steps in range(10000):
        
    xb, yb = get_batch(rngs, "train", batch_size=BATCH_SIZE)
    m.train()
    train_step(m, optimizer, metrics, xb, yb)

    loss = metrics.compute()["loss"]
    
    if steps % eval_every == 0:
        print(f"{steps:<5} {loss}")

print(f"{steps:<5} {loss}")

0     4.140624523162842
100   4.135303974151611
200   4.101077079772949
300   4.0664825439453125
400   4.0348734855651855
500   4.00325870513916
600   3.9710533618927
700   3.942322254180908
800   3.911705493927002
900   3.8844358921051025
1000  3.8571066856384277
1100  3.8305211067199707
1200  3.8036048412323
1300  3.776442527770996
1400  3.7514498233795166
1500  3.7267446517944336
1600  3.702680826187134
1700  3.6792964935302734
1800  3.6570262908935547
1900  3.634251594543457
2000  3.613267183303833
2100  3.592505693435669
2200  3.572800636291504
2300  3.5528976917266846
2400  3.5337257385253906
2500  3.515170097351074
2600  3.4974186420440674
2700  3.479696035385132
2800  3.462641716003418
2900  3.445633888244629
3000  3.4287517070770264
3100  3.412339925765991
3200  3.396601915359497
3300  3.3805735111236572
3400  3.3655996322631836
3500  3.350949764251709
3600  3.3362441062927246
3700  3.321694850921631
3800  3.307173728942871
3900  3.294149398803711
4000  3.281064987182617
4100 

In [67]:
idx = jnp.zeros((1, 1), dtype=jnp.int32)
print(decode( m.generate(idx, max_new_tokens=500, rngs=rngs)[0].tolist()))




DW:
TIxGAn s halll oris I?drsthee POSh fe atig ctharemo metiboncatothad t be ac?k d fe wadpy t

TOFreedeaspeaf s thirdy the ny sid t. w O habu te'SLI btod fathavCH'dry t sbove y'bowabeploith win ofr o prillld ber s muelernkiree t
RO TH:
LITA hichert aloun thLIEakian'Rindore, swithimin
TMir ld Ondy thr suY:
Nanoero wd as'd st sp'MNVeererelyino iWar ht s'Mates tt!
R s:
TCO cothastCUqu lo t dinnss IGe n

KxJalt orksengg tl qRSo RIAMos, t bes s, tth shyxyourou, I'LLOFrocerath's Carers

O!
Thint?qL


### The mathematical trick in self-attention

In [280]:
rngs = nnx.Rngs(1337)
B, T, C = 4, 8, 2  # Batch, Time, Channel
x = rngs.normal(shape=(B, T, C))
x.shape

(4, 8, 2)

In [281]:
# We want x[b, t] = mean_{i<=t} x[b,i]
xbow = jnp.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]  # (t, C)
        xbow = xbow.at[b,t].set(jnp.mean(xprev, axis=0))

In [282]:
rngs = nnx.Rngs(42)
a = jnp.tril(jnp.ones((3, 3)))
a = a / a.sum(axis=1, keepdims=True)
b = rngs.randint(shape=(3,2), minval=0, maxval=10).astype(float)
c = a @ b
print("a=")
print(a)
print("--")
print("b=")
print(b)
print("--")
print("c=")
print(c)

a=
[[1.         0.         0.        ]
 [0.5        0.5        0.        ]
 [0.33333334 0.33333334 0.33333334]]
--
b=
[[4. 1.]
 [9. 1.]
 [0. 9.]]
--
c=
[[4.        1.       ]
 [6.5       1.       ]
 [4.3333335 3.6666667]]


In [283]:
# version 2
wei = jnp.tril(jnp.ones((T, T)))
wei = wei / wei.sum(axis=1, keepdims=True)
# Batch matrix multiply
xbow2 = wei @ x  # (T, T) @ (B, T, C) --> (B, T, T) @ (B, T, C) -> (B, T, C)
jnp.allclose(xbow, xbow2)

Array(True, dtype=bool)

In [284]:
# version 3: use Softmax
tril = jnp.tril(jnp.ones((T, T)))
wei = jnp.zeros((T, T))
wei = jnp.where(tril==0, float("-inf"), wei)
wei = nnx.softmax(wei, axis=-1)
xbow3 = wei @ x
jnp.allclose(xbow, xbow3)

Array(True, dtype=bool)

In [287]:
# version 4: self-attention!
rngs = nnx.Rngs(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = rngs.normal(shape=(B, T, C))

# let's see a single Head perform self-attention
head_size = 16  # this is a hyperparam
key = nnx.Linear(C, head_size, use_bias=False, rngs=rngs)
query = nnx.Linear(C, head_size, use_bias=False, rngs=rngs)
value = nnx.Linear(C, head_size, use_bias=False, rngs=rngs)
k = key(x)    # (B, T, head_size) = (B, T, 16)
q = query(x)  # (B, T, head_size) = (B, T, 16)
# Alt use jnp.matrix_transpose(k) (designed to handle exactly this use case!)
wei = q @ k.transpose(0, -1, -2) # (B, T, 16) @ (B, 16, T) --> (B, T, T)
print(wei.shape)

tril = jnp.tril(jnp.ones((T, T)))
# wei = jnp.zeros((T, T))
wei = jnp.where(tril==0, float("-inf"), wei)
wei = nnx.softmax(wei, axis=-1)

v = value(x)
out = wei @ v
# out = wei @ x

out.shape

(4, 8, 8)


(4, 8, 16)

In [286]:
wei[0]

Array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [3.8266862e-05, 9.9996173e-01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [9.9973768e-01, 1.5942763e-04, 1.0290114e-04, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [1.9299363e-03, 2.9566420e-02, 4.3045271e-05, 9.6846056e-01,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [3.0582871e-03, 3.2653133e-03, 7.7314684e-07, 1.0017906e-01,
        8.9349657e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [3.9127679e-07, 8.0958921e-07, 1.6457305e-04, 1.9999688e-04,
        3.3143956e-02, 9.6649027e-01, 0.0000000e+00, 0.0000000e+00],
       [2.9228389e-04, 6.1300781e-02, 1.4398918e-05, 4.6946220e-02,
        8.8477163e-03, 5.5164182e-01, 3.3095676e-01, 0.0000000e+00],
       [3.4122226e-01, 2.2325837e-03, 5.6

In [None]:
a = nnx.Sequential(*[nnx.Linear(5, 5, rngs=rngs), nnx.Linear(3, 5, rngs=rngs)])
a


Sequential( # Param: 50 (200 B)
  layers=List([
    Linear( # Param: 30 (120 B)
      kernel=Param( # 25 (100 B)
        value=Array(shape=(5, 5), dtype=dtype('float32'))
      ),
      bias=Param( # 5 (20 B)
        value=Array(shape=(5,), dtype=dtype('float32'))
      ),
      in_features=5,
      out_features=5,
      use_bias=True,
      dtype=None,
      param_dtype=float32,
      precision=None,
      kernel_init=<function variance_scaling.<locals>.init at 0x1157ae3e0>,
      bias_init=<function zeros at 0x114730d60>,
      dot_general=<function dot_general at 0x1140cf240>,
      promote_dtype=<function promote_dtype at 0x1157ae700>,
      preferred_element_type=None
    ),
    Linear( # Param: 20 (80 B)
      kernel=Param( # 15 (60 B)
        value=Array(shape=(3, 5), dtype=dtype('float32'))
      ),
      bias=Param( # 5 (20 B)
        value=Array(shape=(5,), dtype=dtype('float32'))
      ),
      in_features=3,
      out_features=5,
      use_bias=True,
      dtype=None,
     

In [301]:
a.layers[1]

Linear( # Param: 20 (80 B)
  kernel=Param( # 15 (60 B)
    value=Array(shape=(3, 5), dtype=dtype('float32'))
  ),
  bias=Param( # 5 (20 B)
    value=Array(shape=(5,), dtype=dtype('float32'))
  ),
  in_features=3,
  out_features=5,
  use_bias=True,
  dtype=None,
  param_dtype=float32,
  precision=None,
  kernel_init=<function variance_scaling.<locals>.init at 0x1157ae3e0>,
  bias_init=<function zeros at 0x114730d60>,
  dot_general=<function dot_general at 0x1140cf240>,
  promote_dtype=<function promote_dtype at 0x1157ae700>,
  preferred_element_type=None
)

In [297]:
a.layers.layers[0]

AttributeError: 'List' object has no attribute 'layers'

In [317]:
a, b = get_batch(rngs, "train", 13, 30)
a

Array([[47, 50, 50,  1, 57, 53, 51, 43,  1, 61, 39, 52, 58],
       [63, 53, 59, 56,  1, 50, 47, 44, 43,  6,  0, 14, 59],
       [57, 53, 51, 43,  1, 54, 53, 57, 58,  1, 58, 53,  1],
       [61, 53, 59, 50, 42,  1, 52, 53, 58,  1, 56, 39, 58],
       [ 0, 39, 52, 42,  1, 57, 53,  1, 50, 53, 41, 49, 57],
       [ 0, 29, 33, 17, 17, 26,  1, 25, 13, 30, 19, 13, 30],
       [47, 53,  6,  1, 44, 53, 56,  1, 63, 53, 59, 56,  1],
       [52, 58,  6,  1, 39, 57,  1, 41, 46, 39, 57, 58, 43],
       [50, 47, 57, 46,  1, 44, 53, 56,  1, 58, 46, 47, 57],
       [58, 46, 63,  1, 40, 56, 53, 58, 46, 43, 56,  1, 51],
       [ 1, 47, 52,  1, 58, 46, 43,  1, 39, 40, 57, 43, 52],
       [10,  0, 32, 46, 47, 57,  1, 47, 57,  1, 39,  1, 46],
       [27, 33, 15, 17, 31, 32, 17, 30, 10,  0,  0, 23, 21],
       [53, 61, 43, 56, 10,  0, 21,  1, 61, 47, 50, 50,  1],
       [46, 43, 52,  1,  5, 58, 47, 57,  1, 58, 47, 51, 43],
       [42,  1, 46, 43,  1, 61, 43, 56, 43,  1, 61, 39, 57],
       [39,  1, 45, 53, 

In [344]:
a = rngs.choice(vocab_size, shape=(3, BLOCK_SIZE))
a

Array([[44, 18, 10,  7, 56,  7,  2, 28],
       [ 3, 13, 61, 49, 11, 13, 58, 26],
       [29, 48, 61, 60, 21, 30, 10,  8]], dtype=int32)

In [342]:
nnx.Embed(vocab_size, n_embd, rngs=rngs)(a)

Array([[[-0.72049576,  0.1807526 ,  0.5949409 ],
        [ 0.7831538 ,  0.01397121,  0.8691177 ],
        [ 0.67870635, -0.94205654,  1.4371607 ],
        [ 0.67870635, -0.94205654,  1.4371607 ],
        [-1.1601799 ,  0.03025232,  0.61355156],
        [ 0.24656233,  0.43664065, -0.5682739 ],
        [-0.13526548,  0.45045102, -0.5280581 ],
        [ 1.0046653 ,  0.63433415, -0.70020086],
        [ 0.1601707 , -0.01017623,  0.67510253],
        [-0.63270015, -0.25287867,  0.98393846],
        [-0.13526548,  0.45045102, -0.5280581 ],
        [-0.37615067, -0.19536205, -0.23777352],
        [ 0.03617128,  0.7040133 , -0.5876029 ],
        [-0.5348546 , -1.0864314 , -0.32516629],
        [-0.69427854,  0.16286016, -0.6004196 ]],

       [[ 0.58181065, -0.52031887, -0.5343059 ],
        [ 0.18912216,  0.53917867, -0.9692084 ],
        [ 0.01492365,  0.08259851,  0.33772236],
        [-0.13526548,  0.45045102, -0.5280581 ],
        [ 0.16752176, -0.17233528, -0.00625841],
        [ 0.264048

In [385]:
ARRAY_LEN = 10

INIT_POINT = 0

u = INIT_POINT
res = []
for i in range(ARRAY_LEN):
    u += 1
    res.append(u)

print(f"{u=}")
print(f"{res=}")

u=10
res=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]


In [386]:
@jax.jit
def scan_fn(carry, x):
    carry += 1
    return carry, carry

carry, y = jax.lax.scan(scan_fn, INIT_POINT, length=ARRAY_LEN)
print(f"{carry=}")
print(f"{y=}")

carry=Array(10, dtype=int32, weak_type=True)
y=Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32, weak_type=True)


In [387]:
def get_batch_static(data, start_indices, block_size):
    """
    Static batch generation that can be JIT compiled.
    start_indices: (batch_size,) array of starting positions
    """
    # Use vmap to extract all sequences in parallel
    def extract_sequence(start_idx):
        x = jax.lax.dynamic_slice(data, (start_idx,), (block_size,))
        y = jax.lax.dynamic_slice(data, (start_idx + 1,), (block_size,))
        return x, y
    
    # Vectorize over all start indices
    x, y = jax.vmap(extract_sequence)(start_indices)
    return x, y


In [388]:
rngs = nnx.Rngs(0)

In [394]:
key = rngs()

In [None]:

rngs.randint(shape=(10, 3), minval=0, maxval=10)

Array([[1, 1, 7],
       [7, 6, 4],
       [7, 7, 3],
       [0, 5, 5],
       [1, 9, 3],
       [3, 9, 6],
       [6, 9, 4],
       [2, 6, 9],
       [9, 7, 3],
       [7, 4, 7]], dtype=int32)

In [403]:
@jax.jit
def scan_fn(_, x):
    x = x % 2
    return None, x

carry, y = jax.lax.scan(scan_fn, None, xs=jnp.arange(10))
print(f"{carry=}")
print(f"{y=}")

carry=None
y=Array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=int32)


In [404]:
# init ys = []
# Scan through xs, for each xi \in xs:
    # Apply _, yi = scan_fn(_, xi)
    # ys = [*[ys], yi]
# Return ys

In [None]:
def get_batch_static(
    data: jnp.ndarray, start_indices: jnp.ndarray, block_size: int
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Static batch generation that is JIT-compilable."""

    def extract_sequence(start_idx: int):
        x = jax.lax.dynamic_slice(data, (start_idx,), (block_size,))
        y = jax.lax.dynamic_slice(data, (start_idx + 1,), (block_size,))
        return x, y

    # Vectorize over all start indices
    x, y = jax.vmap(extract_sequence)(start_indices)
    return x, y

In [None]:
xs = jnp.arange(10)**2


Array([ True], dtype=bool)

In [6]:
output_tokens = jnp.zeros((2, 8), dtype=jnp.int32)

output_tokens = output_tokens.at[0, :3].set(jnp.array([11, 22, 33]))  # Sequence 0: [11, 22, 33]
output_tokens = output_tokens.at[1, :3].set(jnp.array([44, 55, 66]))  # Sequence 1: [44, 55, 66]

In [None]:
output_tokens

Array([[11, 22, 33,  0,  0,  0,  0,  0],
       [44, 55, 66,  0,  0,  0,  0,  0]], dtype=int32)

In [9]:
BLOCK_SIZE = 4

B, T = 2, 3 # 3 because we init with jnp.array([11, 22, 33])
step_idx = 0
current_pos = T + step_idx

In [10]:
effective_length = jnp.minimum(BLOCK_SIZE, current_pos)
start_idx = jnp.maximum(0, current_pos - BLOCK_SIZE)
print(effective_length, start_idx)

3 0


In [11]:
idx_cond = jax.lax.dynamic_slice(output_tokens, (0, start_idx), (B, BLOCK_SIZE))
idx_cond

Array([[11, 22, 33,  0],
       [44, 55, 66,  0]], dtype=int32)

In [16]:
BLOCK_SIZE - effective_length

Array(1, dtype=int32, weak_type=True)

In [15]:
jnp.arange(BLOCK_SIZE) < BLOCK_SIZE - effective_length

Array([ True, False, False, False], dtype=bool)

In [13]:
jnp.where(
    jnp.arange(BLOCK_SIZE) < BLOCK_SIZE - effective_length,
    0,
    idx_cond
)

Array([[ 0, 22, 33,  0],
       [ 0, 55, 66,  0]], dtype=int32)