Notebook to explore caching computed keys and values. Problems to solve:
- How to use flax to store these as variables
- Will they fit in memory?

In [28]:
import flax.linen as nn
import jax.numpy as jnp
import jax
from jax import lax
from jax.tree_util import tree_map

In [10]:
from nimblegpt import get_config_for, make_gpt_param_dict, get_flaxmodels_gpt2_params, param_shapes

In [3]:
from nimblegpt.base_model import (
    BaseBlock,
    BaseCausalSelfAttention,
    BaseGPT,
    BaseSingleHeadCausalSelfAttention,
)
from nimblegpt.model import GELU, softmax

In [4]:
config = get_config_for("gpt2")

In [5]:
config

attn_pdrop: 0.1
block_size: 1024
embd_pdrop: 0.1
model_type: gpt2
n_embd: 768
n_head: 12
n_layer: 12
resid_pdrop: 0.1
vocab_size: 50257

In [6]:
# Feature width of a single attention head.
n_feat = config.n_embd // config.n_head
# Each layer caches one K and one V matrix (hence the factor 2) for this head,
# each of shape [config.block_size, n_feat].
n_cache_params = 2 * config.n_layer * (config.block_size * n_feat)
format(n_cache_params, ",")

'1,572,864'

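The K/V cache for a single head over the entire context holds about 1.6M values across all layers, i.e. 6 MiB in float32. Across all 12 heads that is about 72 MiB, so the cache fits comfortably in memory.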

In [49]:
class SingleHeadQKV(nn.Module):
    """
    Single-head Q/K/V projection with an incrementally filled key/value cache.

    Each call consumes the embedding of one token. The token's key and value
    vectors are written into per-module cache variables (collection "cache") at
    the current cache index, and the index is advanced by one. The caching
    approach follows:
    https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html#MultiHeadDotProductAttention

    Attributes
    ----------
    n_cntx : int
        Number of rows in the key/value cache.
    n_feat : int
        Width of each of the q, k and v vectors.
    """
    n_cntx: int
    n_feat: int

    @nn.compact
    def __call__(self, x: jax.Array):
        """
        Project one token embedding and append its key/value to the cache.

        Parameters
        ----------
        x : jax.Array
            Shape [n_embd]. Embedding of the next token in the sequence.

        Returns
        -------
        Q, K, V, idx : Tuple
            Q : [n_feat] - Query vector for `x`.
            K, V : [n_cntx, n_feat] - Cached key and value matrices after this call.
            idx : int - Cache index that was current on entry (the row written for `x`).
        """
        # One fused projection, then split into three [n_feat] vectors.
        qkv = nn.Dense(features=3 * self.n_feat)(x)
        q, k, v = jnp.split(qkv, 3, axis=0)

        # Must be evaluated before the variables are declared below: it tells us
        # whether a cache was supplied by the caller or is being created right now.
        cache_exists = self.has_variable("cache", "cached_keys")

        cache_shape = (self.n_cntx, self.n_feat)
        key_cache = self.variable("cache", "cached_keys", jnp.zeros, cache_shape)
        value_cache = self.variable("cache", "cached_values", jnp.zeros,
                                    cache_shape)
        index_var = self.variable("cache", "cache_index",
                                  lambda: jnp.array(0, dtype=jnp.int32))
        position = index_var.value

        # No pre-existing cache: return the freshly created variables untouched.
        if not cache_exists:
            return q, key_cache.value, value_cache.value, position

        # Write k and v as a single row at `position`, then advance the index.
        row_start = (position, 0)
        key_cache.value = lax.dynamic_update_slice(key_cache.value,
                                                   k[None, :], row_start)
        value_cache.value = lax.dynamic_update_slice(value_cache.value,
                                                     v[None, :], row_start)
        index_var.value = position + 1

        return q, key_cache.value, value_cache.value, position

In [9]:
gpt_params = make_gpt_param_dict(get_flaxmodels_gpt2_params(), config)

In [50]:
qkv_module = SingleHeadQKV(n_cntx=config.block_size, n_feat=config.n_embd // config.n_head)

In [13]:
x = jnp.ones((config.n_embd,))

In [23]:
param_shapes(qkv_module.init(jax.random.PRNGKey(0), x))

{'params': {'Dense_0': {'kernel': '(768, 192)', 'bias': '(192)'}},
 'cache': {'cached_keys': '(1024, 64)',
  'cached_values': '(1024, 64)',
  'cache_index': '()'}}

In [26]:
sa_0_params = gpt_params["Block_0"]["CausalSelfAttention_0"]["VmapSingleHeadCausalSelfAttention_0"]

In [27]:
param_shapes(sa_0_params)

{'Dense_0': {'bias': '(12, 192)', 'kernel': '(12, 768, 192)'}}

In [30]:
sh_0_params = tree_map(lambda x: x[0], sa_0_params)

In [31]:
param_shapes(sh_0_params)

{'Dense_0': {'bias': '(192)', 'kernel': '(768, 192)'}}

In [32]:
rng = jax.random.PRNGKey(0)

In [55]:
X = jax.random.normal(rng, (10, config.n_embd))

In [56]:
# Reference Q/K/V for all rows of X at once, using head 0's pretrained projection.
# The per-head width is parenthesised so the feature count is 3 * n_feat, matching
# the attention modules, even if n_embd were not a multiple of n_head.
sh_QKV = nn.Dense(features=3 * (config.n_embd // config.n_head)).apply({"params": sh_0_params["Dense_0"]}, X)

In [57]:
sh_QKV

Array([[  0.3254057 ,  -1.9289862 ,   6.3326793 , ...,  -0.500361  ,
         -0.35731804,  -1.1551355 ],
       [  4.2553654 ,   3.9190474 ,  -7.942846  , ...,  -0.02936717,
         -0.8493566 ,   1.0613071 ],
       [ -3.7252674 ,  -1.1315012 ,   8.182152  , ...,   1.7368824 ,
         -0.6001127 ,  -0.4603993 ],
       ...,
       [  0.7600452 ,  -3.0968392 ,   5.0037227 , ...,   2.1742795 ,
          4.5695148 ,  -0.9679753 ],
       [-10.481102  ,  -1.1025255 ,  -9.36252   , ...,   0.1816734 ,
          1.8518579 ,   0.6153525 ],
       [  1.979913  ,  11.513443  ,   2.6524346 , ...,  -2.272161  ,
         -0.8667084 ,  -4.240763  ]], dtype=float32)

In [68]:
vars = qkv_module.init(rng, X[0])

In [74]:
# Feed the rows of X one at a time, threading the updated cache into the next call.
# The iteration count follows X instead of a hard-coded 10. After the loop, `qKVi`
# holds the outputs for the last row of X.
for i in range(X.shape[0]):
    qKVi, vars = qkv_module.apply(
        {"cache": vars["cache"], "params": sh_0_params}, X[i], mutable="cache"
    )

In [75]:
qKVi

(Array([ 1.9799105e+00,  1.1513445e+01,  2.6524339e+00, -1.3477521e+01,
        -1.8126870e+00, -6.0371763e-01, -3.0165085e-01,  4.2052898e+00,
        -3.4322548e+00,  5.6196337e+00, -5.6358824e+00, -6.2805634e+00,
        -1.5989894e+00, -7.7256999e+00, -7.7426491e+00, -1.1914519e+01,
         5.6751199e+00,  4.1370215e+00, -4.8818932e+00, -6.8391347e-01,
         7.4022107e+00, -4.2368451e-01, -5.6376481e+00,  1.1128722e+01,
        -7.7164817e-01, -2.1515315e+00, -6.5387106e-01,  1.2627184e+01,
        -7.0974302e+00, -1.0442138e+01,  4.7516134e-01, -1.4268667e+00,
         3.8699193e+00,  8.0294199e+00,  2.1324763e+00, -5.1904230e+00,
         9.3552160e+00, -1.3107698e+01, -1.2407379e+00, -3.3083718e+00,
        -1.4584163e+00, -1.2347947e+01, -4.5201941e+00,  4.4677892e+00,
        -8.2197313e+00,  8.2265444e+00, -8.2702935e-04, -5.0314231e+00,
        -1.6529402e+01, -8.2724028e+00, -7.8701715e+00,  3.3075047e+00,
         4.2626238e+00,  1.0886194e+01,  8.8508015e+00,  1.61220

In [76]:
# Query returned by the final call, plus row 9 of the cached key and value matrices.
q = qKVi[0]
k, v = qKVi[1][9], qKVi[2][9]

In [80]:
# Largest absolute deviation between the batched projection and the cached path.
# The absolute value matters: a plain max of the signed difference would hide a
# large negative error.
jnp.abs(sh_QKV[9] - jnp.concatenate([q, k, v])).max()

Array(2.861023e-06, dtype=float32)

In [87]:
# Split the batched projection once; the last two thirds are the reference K and V.
K, V = jnp.split(sh_QKV, 3, axis=1)[1:]

In [81]:
qKVi[1]

Array([[  0.04223603,   0.8799178 , -11.74339   , ...,  -9.586003  ,
         -3.4196584 ,   0.6275733 ],
       [ -2.908174  ,   1.6900772 ,   3.1685874 , ...,  -0.8879423 ,
          6.6386213 ,   1.1736691 ],
       [  7.2307673 ,   3.8176613 ,  -2.786204  , ...,   0.9499583 ,
        -11.866965  ,   4.592053  ],
       ...,
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ]], dtype=float32)

In [86]:
# Largest absolute deviation between the reference keys and the first 10 cached rows.
# Taken on the absolute difference so negative errors are not masked.
jnp.abs(K - qKVi[1][:10]).max()

Array(4.7683716e-06, dtype=float32)

In [90]:
# Largest absolute deviation between the reference values and the first 10 cached rows.
# Taken on the absolute difference so negative errors are not masked.
jnp.abs(V - qKVi[2][:10]).max()

Array(1.3113022e-06, dtype=float32)

In [8]:
class JSingleHeadCausalSelfAttention(BaseSingleHeadCausalSelfAttention):
    """
    Single-head causal self-attention over a full sequence of token embeddings.

    Attributes
    ----------
    n_feat : int
        Width of each of the q, k and v projections.
    n_cntx : int
        Declared for the planned K/V cache; not read by `__call__` yet.
    """
    n_feat: int
    n_cntx: int

    @nn.compact
    def __call__(self, x):
        """
        Parameters
        ----------
        x : array
            Shape [T, C]: T token embeddings of dimensionality C (n_embd).

        Returns
        -------
        y : array
            Shape [T, n_feat]. Attention output for every position.
        """
        # Sequence length. This was previously never bound, so building the causal
        # mask below raised a NameError.
        T = x.shape[0]

        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat] -> 3 * [T, n_feat]
        q, k, v = jnp.split(nn.Dense(features=3 * self.n_feat)(x), 3, axis=1)

        # [T, n_feat] @ [n_feat, T] -> [T, T].
        # Row i of att tells us which tokens x[i] should attend to. att[i][j]
        # is high when token i should attend heavily to token j.
        att = (q @ k.T) * (1.0 / jnp.sqrt(self.n_feat))

        # Token i should not attend to token j for any j > i. We set att to -inf
        # for any position above the diagonal - i.e. where j > i.
        # Note that this also prevents data tokens from attending to padding tokens.
        causal_mask = ~jnp.tril(jnp.ones((T, T))).astype(bool)
        att = jnp.where(causal_mask, -jnp.inf, att)

        att = softmax(att, axis=-1)

        y = att @ v  # [T, T] @ [T, n_feat] -> [T, n_feat]

        return y