Similar to `kernel_perf.ipynb`, but with KV-caching implemented, so that only a single row of the attention matrix needs to be computed per decoding step.

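Benchmark KV-cached single-head causal self-attention: the reference `FSingleHeadCausalSelfAttention` module versus the Triton-based `SHCSABlock`, the latter both with its default configuration and with `subseq_size=128, subfeat_size=16`. In the recorded run, the default `SHCSABlock` (~769 µs/call) was slower than the reference (~705 µs/call), while `subseq_size=128, subfeat_size=16` (~691 µs/call) was roughly on par with it (the difference is within the reported spread).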

In [1]:
import pickle

import jax
import jax.numpy as jnp
import flax.linen as nn

import jax_triton as jt
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from nimblegpt import get_config_for, param_shapes
from nimblegpt.params import get_flaxmodels_gpt2_params, make_gpt_param_dict
from nimblegpt.model import SingleHeadCausalSelfAttention
from nimblegpt.jmodel import JSingleHeadCausalSelfAttention
from nimblegpt.fast_model import FSingleHeadCausalSelfAttention, FGPT

from nimblegpt.kernels.kvcache_triton_kernels import SHCSABlock

In [3]:
# Fixed PRNG key so the random input and parameter initialisation are reproducible.
rng = jax.random.PRNGKey(0)

# GPT-2 ("gpt2") hyper-parameters; the benchmarks below exercise a single head.
config = get_config_for("gpt2")
n_cntx = config.block_size  # context length, passed to the attention modules as n_cntx
n_feat = config.n_embd // config.n_head  # per-head feature size

In [4]:
# Load GPT-2 parameters via flaxmodels and convert them into nimblegpt's param dict.
# NOTE(review): `gpt_params` is not used by any cell shown here — the benchmarks
# below run with the randomly initialised `params` from `module.init`.
gpt_params = make_gpt_param_dict(
    get_flaxmodels_gpt2_params(),
    config,
)

In [5]:
# Random input vector of size n_embd (presumably one token's embedding for a single
# decoding step — TODO confirm against the module's expected input).
# NOTE(review): the same `rng` also seeds `module.init` below; harmless for timing,
# but split the key if the actual values ever matter.
x = jax.random.normal(key=rng, shape=(config.n_embd,))

In [6]:
# Build the reference module and initialise its parameters. Then run one step at
# position 0 with the "cache" collection mutable, keeping the returned cache state
# so the benchmark cells can feed it back in.
# NOTE(review): `vars` shadows the Python builtin; kept because later cells read it.
start_pos = jnp.array(0)
module = FSingleHeadCausalSelfAttention(n_feat, n_cntx)
params = module.init(rng, x, start_pos)
_, vars = module.apply(params, x, start_pos, mutable="cache")


In [19]:
jit_fshcsa = jax.jit(partial(FSingleHeadCausalSelfAttention(n_feat, n_cntx).apply, mutable="cache"))
jit_fshcsa({**params, "cache": vars["cache"]}, x, jnp.array(0))[0]

%timeit -n100 jit_fshcsa({**params, "cache": vars["cache"]}, x, jnp.array(0))[0].block_until_ready()

705 µs ± 10.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [42]:
jit_block = jax.jit(partial(SHCSABlock(n_feat, n_cntx).apply, mutable="cache"))
jit_block({**params, "cache": vars["cache"]}, x, jnp.array(0))[0]

%timeit -n100 jit_block({**params, "cache": vars["cache"]}, x, jnp.array(0))[0].block_until_ready()

769 µs ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [44]:
jit_block = jax.jit(partial(SHCSABlock(n_feat, n_cntx, subseq_size = 128, subfeat_size=16).apply, mutable="cache"))
jit_block({**params, "cache": vars["cache"]}, x, jnp.array(0))[0]

%timeit -n100 jit_block({**params, "cache": vars["cache"]}, x, jnp.array(0))[0].block_until_ready()

691 µs ± 19.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
