## Findings
- `jax.lax.scan` compiles faster, although it performs worse for 1 block
- For larger block sizes this seems to be no longer true and the two become similar in performance
- => I will employ the scan function since it'll either match performance (or even scale better) and it compiles faster

In [1]:
from __init__ import Config, ExpressiveRetNet
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import timeit

In [2]:
# Benchmark fixture: one model plus a random (token ids, targets) pair.
seq_len = 4096  # number of tokens in the benchmark sequence
csz = 128  # chunk size: tokens handed to the model per call in the loss functions
config = Config(
    n_heads=4,
    d_model=512,
    n_layers=4,
    d_mlp=1024,
)
# JAX PRNG keys should not be reused: split the root key so that model
# initialisation, the inputs and the targets each draw from an independent
# stream. `key` itself is rebound to a fresh subkey for use by later cells.
key, model_key, x_key, y_key = jax.random.split(jax.random.PRNGKey(0), 4)
model = ExpressiveRetNet(config, model_key)
# NOTE(review): `n_vocab` is not passed to Config above, so it presumably
# comes from a Config default -- confirm.
X = jax.random.randint(x_key, (seq_len,), 0, config.n_vocab)
Y = jax.random.normal(y_key, (seq_len, config.n_vocab))

In [4]:
def loss_naive(model, X, Y, key):
    """Mean squared error of chunk-wise predictions, using a Python loop.

    `X` is cut into consecutive slices of `csz` elements along its first
    axis. Each slice is passed to `model` together with the state returned
    by the previous call (starting from `model._initial_kvs()`), and the
    per-chunk predictions are concatenated along axis 0 before being
    compared with `Y`.

    Args:
        model: Callable as `model(x_chunk, kvs, key) -> (y_pred, kvs)` that
            also provides `_initial_kvs()`.
        X: Input sequence, chunked along axis 0.
        Y: Targets compared against the concatenated predictions.
        key: Forwarded unchanged to every model call.

    Returns:
        The mean of `optax.squared_error` over all elements.
    """
    kvs = model._initial_kvs()
    Y_chunk = []
    # Take the length from X itself instead of the notebook-level `seq_len`
    # so the function stays correct for inputs of any length. The shape is
    # a plain Python int, so `range` also works while being traced by jit.
    for i in range(0, X.shape[0], csz):
        Y_pred, kvs = model(X[i : i + csz], kvs, key)
        Y_chunk.append(Y_pred)
    Y_chunk = jnp.concatenate(Y_chunk, axis=0)
    return optax.squared_error(Y_chunk, Y).mean()
# Jit-compiled value-and-grad of the Python-loop loss.
loss_fn_naive = eqx.filter_jit(eqx.filter_value_and_grad(loss_naive))
# Warm up with exactly the arguments that are timed below, and wait for the
# result, so that compilation is finished before the measurement starts.
_ = jax.block_until_ready(loss_fn_naive(model, X, Y, key))  # warmup
# Pass `key` rather than None: a None argument is a different input structure
# from the warmup call, so the jitted function would be traced and compiled
# again inside the timed region. JAX also dispatches asynchronously, so block
# on the result to time the computation rather than just its dispatch.
timeit.timeit(
    lambda: jax.block_until_ready(loss_fn_naive(model, X, Y, key)), number=5
)

KeyboardInterrupt: 

In [5]:
def loss_scan(model, X, Y, key):
    """Mean squared error of chunk-wise predictions, using `jax.lax.scan`.

    `X` is reshaped into rows of `csz` elements and the model is scanned
    over those rows, carrying the state it returns from one row to the next
    (starting from `model._initial_kvs()`). The stacked per-row predictions
    are flattened back into one sequence and compared with `Y`.

    Args:
        model: Callable as `model(x_chunk, kvs, key) -> (y_pred, kvs)` that
            also provides `_initial_kvs()`.
        X: Input sequence; its size must be divisible by `csz` for the
            reshape to succeed.
        Y: Targets compared against the flattened predictions.
        key: Forwarded unchanged to every model call.

    Returns:
        The mean of `optax.squared_error` over all elements.
    """

    def scan_body(carry, x_chunk):
        # scan expects (new_carry, per_step_output); the model returns them
        # in the opposite order.
        y_chunk, carry = model(x_chunk, carry, key)
        return carry, y_chunk

    x_chunks = X.reshape(-1, csz)
    _, Y_stacked = jax.lax.scan(scan_body, model._initial_kvs(), x_chunks)
    # Merge the leading (step, within-chunk) axes into a single sequence axis.
    Y_flat = Y_stacked.reshape(-1, *Y_stacked.shape[2:])
    return optax.squared_error(Y_flat, Y).mean()
# Jit-compiled value-and-grad of the scan-based loss.
loss_fn_scan = eqx.filter_jit(eqx.filter_value_and_grad(loss_scan))
# Warm up with the same arguments as the timed call and wait for the result,
# so that compilation is finished before the measurement starts.
_ = jax.block_until_ready(loss_fn_scan(model, X, Y, key))  # warmup
# JAX dispatches work asynchronously; block on the result so the timing
# covers the computation itself rather than just its dispatch.
timeit.timeit(
    lambda: jax.block_until_ready(loss_fn_scan(model, X, Y, key)), number=5
)

9.750562791945413