Attempt to implement flash attention, but save on memory acceses by ignoring padded data.

In [1]:
import functools

In [2]:
import jax.numpy as jnp
import jax
from flax import linen as nn
from jax import lax

from jax_triton import pallas as pl

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from nimblegpt import get_config_for
from nimblegpt.jmodel import JSingleHeadCausalSelfAttention

In [4]:
config = get_config_for("gpt2")

In [5]:
rng = jax.random.PRNGKey(0)
att = jax.random.normal(rng, (config.block_size,)*2)
q_key, k_key, v_key = jax.random.split(rng, 3)
q = jax.random.normal(q_key, (config.block_size, config.n_embd))
k = jax.random.normal(k_key, (config.block_size, config.n_embd))
v = jax.random.normal(v_key, (config.block_size, config.n_embd))

In [19]:
def padded_attn_kernel(q_ref, k_ref, v_ref, p_ref, o_ref, *, seq_len: int, n_feat: int, n_ocols):
    """
    Inputs
    ------
    q_ref, k_rev, v_ref: references to the Q, K, and V matrices.
        All have shape shape [seq_len, n_feat].

    Each kernel instances computes out[out_row_num, out_col_start: out_col_start + n_ocols].
    This requires multiplying att[out_row_num, :] by v[:, out_col_start: out_col_start + n_ocols].

    To compute att[out_row_num, :], we multiply q[out_row_num, :] by 
    k^T.
    """

    # Each instance computes out[out_row_num, out_col_start: out_col_start + n_ocols]
    out_row_num = pl.program_id(0)
    out_col_start = pl.program_id(1) * n_ocols
    n_padd = p_ref[()]

    seq_idxs = jnp.arange(seq_len)

    # Shape (1,) mask. 0 if this instance is computing a padding element of `out`.
    padd_row_mask = out_row_num >= n_padd
    # Shape (seq_len,) mask. 0 for tokens of the sequence that are padding.
    seq_mask = jnp.arange(seq_len) >= n_padd
    # Shape (seq_len, n_ocols) mask. 0 for elements corresponding to padding tokens.
    block_mask = lax.broadcast_in_dim(
        jnp.expand_dims(seq_mask, 1), (seq_len, n_ocols), (0, 1)
    )
    # Shape (seq_len, n_feat) mask. 0 for elements corresponding to padding tokens.
    mat_mask = lax.broadcast_in_dim(
        jnp.expand_dims(seq_mask, 1), (seq_len, n_feat), (0, 1)
    )
    # Token i should only atten to tokens j <= i. 0 for tokens j > i.
    causal_mask = seq_idxs <= out_row_num

    ### First we compute the softmax of row `out_row_num` of the attention matrix. ###
    # This requires loading one row of Q and all of K.

    q_idx = (out_row_num, pl.dslice(None))
    q_row = pl.load(q_ref, q_idx, mask=padd_row_mask, other=0) # [n_feat]
    q_row = jnp.expand_dims(q_row, 0) # [1, n_feat]

    k_idx = (pl.dslice(None), pl.dslice(None))
    k_mat = pl.load(k_ref, k_idx, mask=mat_mask, other=0) # [seq_len, n_feat]

    # Compute att[out_row_num, :] - a single row of the full attention matrix.
    # [1, n_feat] . ([seq_len, n_feat] -[T]-> [n_feat, seq_len]) = [1, seq_len]
    att_row = pl.dot(q_row, k_mat, trans_b = True)
    att_row /= jnp.sqrt(n_feat)
    att_row = jnp.where(causal_mask & seq_mask, att_row, -jnp.inf)
    sm_numerator = jnp.exp(att_row - jnp.max(att_row))
    sm_att = sm_numerator / jnp.sum(sm_numerator, keepdims=True) # [1, seq_len]

    v_idxs = (pl.dslice(None), pl.dslice(out_col_start, n_ocols))
    v_block = pl.load(v_ref, v_idxs, mask=block_mask, other=0) # [seq_len, n_ocols]

    # [1, seq_len] . [seq_len, n_ocols] = [1, n_ocols]
    # out = pl.dot(sm_att, v_block) # [1, n_ocols]
    out = sm_att @ v_block

    # Store the result.
    out_idxs = (out_row_num, pl.dslice(out_col_start, n_ocols))
    pl.store(o_ref, out_idxs, out[0], mask=padd_row_mask)

In [20]:
def padded_attn(q, k, v, n_padd, *, n_ocols: int):
    seq_len, n_feat = q.shape

    grid = (seq_len, n_feat // n_ocols)
    assert grid[1] * n_ocols == n_feat

    kernel = functools.partial(
        padded_attn_kernel, seq_len=seq_len, n_feat=n_feat, n_ocols=n_ocols
    )
    out_shape = jax.ShapeDtypeStruct((seq_len, n_feat), q.dtype)

    out = pl.pallas_call(kernel, grid=grid, out_shape=out_shape, interpret=True)(
        q, k, v, n_padd
    )

    return out


In [21]:
class TSingleHeadCausalSelfAttention(nn.Module):
    """
    Inference only (no dropout) single headed attention.

    minGPT docstring
    ----------------
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    n_feat: int

    @nn.compact
    def __call__(self, x, n_padd: int = 0):
        T, C = x.shape  # sequence length, embedding dimensionality (n_embd)

        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat] -> 3 * [T, n_feat]
        q, k, v = jnp.split(nn.Dense(features=3 * self.n_feat)(x), 3, axis=1)

        y = padded_attn(q, k, v, n_padd, n_ocols=4)

        return y

In [26]:
x = jax.random.normal(rng, (config.block_size, config.n_embd))
n_padd = 2

In [27]:
y, _ = JSingleHeadCausalSelfAttention(config.n_embd).init_with_output(rng, x, n_padd=n_padd)

In [28]:
ty, _ = TSingleHeadCausalSelfAttention(config.n_embd).init_with_output(rng, x, n_padd=n_padd)

In [29]:
(y[n_padd:] - ty[n_padd:]).max()

Array(1.4305115e-06, dtype=float32)

In [13]:
padded_attn(q, k, v, 0, n_ocols = 4)

Array([[-1.01607285e-01,  3.23455304e-01, -1.14697017e-01, ...,
         3.42727363e-01, -1.50065219e+00,  4.98529613e-01],
       [-1.88820034e-01,  2.30207741e-01, -5.91809638e-02, ...,
         9.28151608e-02, -1.45710719e+00,  4.99784201e-01],
       [-1.72552958e-01, -4.69248593e-01, -1.03122219e-01, ...,
        -9.16234404e-02, -1.25068820e+00,  2.55948424e-01],
       ...,
       [ 4.17719148e-02, -6.57780915e-02, -6.61965087e-02, ...,
        -3.43253762e-02,  4.40334566e-02,  4.70811427e-02],
       [ 5.38661331e-02, -3.08343675e-02, -2.44537331e-02, ...,
         3.44744660e-02,  1.47196651e-03,  2.74988767e-02],
       [ 4.30060774e-02, -1.43105257e-03, -1.85891725e-02, ...,
         4.80973572e-02,  9.45648737e-03,  4.85833324e-02]],      dtype=float32)

In [14]:
x = 

SyntaxError: invalid syntax (3306000817.py, line 1)

In [None]:
q[0:1, :].shape

(1, 768)

In [None]:
jnp.expand_dims(q[0], 1).shape

(768, 1)

In [None]:
k.T.shape

(768, 1024)

In [None]:
k.shape

(1024, 768)

In [None]:
pl.dot(jnp.expand_dims(q[0], 1), k, trans_b=True)

TypeError: dot_general requires contracting dimensions to have the same shape, got (1,) and (768,).