Similar to `kernel_perf.ipynb`, but after implementing KV-caching, so that only a single row of the attention matrix must be computed per iteration.

Test the performance of various Triton kernels, varying the configuration parameters.

In [3]:
import pickle

import jax
import jax.numpy as jnp
import flax.linen as nn

import jax_triton as jt
from functools import partial

In [46]:
from nimblegpt import get_config_for, param_shapes
from nimblegpt.params import get_flaxmodels_gpt2_params, make_gpt_param_dict
from nimblegpt.model import SingleHeadCausalSelfAttention
from nimblegpt.jmodel import JSingleHeadCausalSelfAttention
from nimblegpt.fast_model import FSingleHeadCausalSelfAttention, FGPT

from nimblegpt.kernels.kvcache_triton_kernels import SHCSABlock

In [8]:
config = get_config_for("gpt2")
n_cntx = config.block_size
n_feat = config.n_embd // config.n_head
rng = jax.random.PRNGKey(0)

In [48]:
gpt_params = make_gpt_param_dict(get_flaxmodels_gpt2_params(), config)

In [22]:
x = jax.random.normal(rng, (config.n_embd, ))

In [31]:
module = FSingleHeadCausalSelfAttention(n_feat, n_cntx)
params = module.init(rng, x, jnp.array(0))
_, vars = module.apply(
    params,
    x,
    jnp.array(0),
    mutable="cache",
)


In [40]:
jit_fshcsa = jax.jit(partial(FSingleHeadCausalSelfAttention(n_feat, n_cntx).apply, mutable="cache"))
jit_fshcsa({**params, "cache": vars["cache"]}, x, jnp.array(0))[0]

%timeit -n100 jit_fshcsa({**params, "cache": vars["cache"]}, x, jnp.array(0))[0].block_until_ready()

733 µs ± 44.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [43]:
jit_block = jax.jit(partial(SHCSABlock(n_feat, n_cntx).apply, mutable="cache"))
jit_block({**params, "cache": vars["cache"]}, x, jnp.array(0))[0]

%timeit -n100 jit_block({**params, "cache": vars["cache"]}, x, jnp.array(0))[0].block_until_ready()

769 µs ± 22.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [49]:
prompt_idxs = jax.random.randint(rng, (3, ), 0, config.vocab_size)

In [51]:
fmodule = FGPT.Make(config)
vars = fmodule.init_vars({"params": gpt_params})

fseq = fmodule.generate(
    rng,
    {
        "params": gpt_params,
        **vars
    },
    prompt_idxs,
    max_new_tokens=5,
)


In [58]:
%%timeit -n10
fmodule.generate(
    rng,
    {
        "params": gpt_params,
        **vars
    },
    prompt_idxs,
    max_new_tokens=10,
).block_until_ready()

36.6 ms ± 716 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [63]:
block_module = FGPT.MakeWithSHCSA(config, partial(SHCSABlock, subseq_size = 128))
vars = block_module.init_vars({"params": gpt_params})

fseq = block_module.generate(
    rng,
    {
        "params": gpt_params,
        **vars
    },
    prompt_idxs,
    max_new_tokens=5,
)

NotImplementedError: Batching rule for 'triton_kernel_call' not implemented