In [11]:
import functools

In [12]:
import jax.numpy as jnp
import jax
from flax import linen as nn
from jax import lax

from jax_triton import pallas as pl

In [13]:
from nimblegpt import get_config_for
from nimblegpt.jmodel import JSingleHeadCausalSelfAttention

In [14]:
# Hyperparameters for the "gpt2" preset; `block_size` and `n_embd` are read from it below.
config = get_config_for("gpt2")

In [15]:
# Fixed PRNG key so the random inputs below are reproducible from run to run.
rng = jax.random.PRNGKey(0)
# Square [block_size, block_size] matrix of standard-normal draws.
att = jax.random.normal(rng, (config.block_size, config.block_size))

# Padded Softmax

In [16]:
def make_padded_softmax(block_size: int, num_warps: int = 1):
    """
    Build a Pallas kernel computing a causal, padding-aware softmax over each row.

    Parameters
    ----------
    block_size
      Side length of the square input matrix; also the size of the launch grid.
    num_warps
      Forwarded unchanged to `pl.pallas_call`.

    Returns
    -------
    A callable with signature:
    def padded_softmax_kernel(att: Array[block_size, block_size], n_padd: int):

    It returns a float32 [block_size, block_size] array. Only rows with index
    >= `n_padd` are written, so the result is only meaningful for rows which
    don't correspond to padding tokens.
    """

    def padded_softmax_kernel(x_ref, p_ref, o_ref):
        n_padd = p_ref[()]
        # One program instance per row of the input matrix.
        row = pl.program_id(0)
        cols = jnp.arange(block_size)
        idx = (row, cols)

        # True for columns that exist in `x_ref`, False for out-of-bounds ones.
        in_bounds = cols < x_ref.shape[1]
        # Causality: token `row` may only look at tokens at or before itself.
        not_future = cols <= row
        # Data tokens are the ones at index >= n_padd, both as rows and as columns.
        row_is_data = row >= n_padd
        col_is_data = cols >= n_padd

        # Everything masked out is read as -inf, which becomes 0 after the exp.
        load_mask = in_bounds & not_future & (col_is_data & row_is_data)
        scores = pl.load(x_ref, idx, mask=load_mask, other=-float("inf"))

        # Numerically stable softmax: shift by the row maximum before exponentiating.
        exps = jnp.exp(scores - jnp.max(scores, axis=0))
        probs = exps / jnp.sum(exps, axis=0)

        # Rows belonging to padding tokens are never written back; their contents
        # may be uninitialized memory.
        pl.store(o_ref, idx, probs, mask=in_bounds & row_is_data)

    return pl.pallas_call(
        padded_softmax_kernel,
        out_shape=jax.ShapeDtypeStruct((block_size, block_size), jnp.float32),
        grid=block_size,
        num_warps=num_warps,
        interpret=True,
        debug=False,
    )


In [17]:
class TSingleHeadCausalSelfAttention(nn.Module):
    """
    Single-head causal self-attention for inference (no dropout) whose masked
    softmax is computed by the Pallas kernel from `make_padded_softmax`.

    minGPT docstring
    ----------------
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    # Width of each of the q / k / v projections.
    n_feat: int

    @nn.compact
    def __call__(self, x, n_padd: int = 0):
        # x is [T, C]: sequence length by embedding dimensionality (n_embd).
        seq_len, _ = x.shape

        # A single dense layer produces q, k and v at once:
        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat], then cut into 3 * [T, n_feat].
        qkv = nn.Dense(features=3 * self.n_feat)(x)
        q, k, v = jnp.split(qkv, 3, axis=1)

        # [T, n_feat] @ [n_feat, T] -> [T, T] scaled scores. scores[i][j] is high
        # when token i should attend heavily to token j.
        scale = 1.0 / jnp.sqrt(self.n_feat)
        scores = (q @ k.T) * scale

        # Causal + padding masking and the softmax itself happen inside the kernel.
        weights = make_padded_softmax(seq_len)(scores, n_padd)

        # [T, T] @ [T, n_feat] -> [T, n_feat]
        return weights @ v

In [18]:
# Random [block_size, n_embd] input standing in for a sequence of token embeddings.
x = jax.random.normal(rng, (config.block_size, config.n_embd))
# Number of leading positions treated as padding by the attention modules below.
n_padd = 2

In [19]:
# Reference output from the imported `JSingleHeadCausalSelfAttention`; the second
# element returned by `init_with_output` is discarded.
y, _ = JSingleHeadCausalSelfAttention(config.n_embd).init_with_output(rng, x, n_padd=n_padd)

In [20]:
# Same input, key and `n_padd` through the Pallas-backed module defined above.
ty, _ = TSingleHeadCausalSelfAttention(config.n_embd).init_with_output(rng, x, n_padd=n_padd)

In [21]:
# We don't care about the embeddings for padding tokens, so skip the first `n_padd` rows.
# Compare by the largest *absolute* difference: a plain signed max would report a
# small value even if some elements were far too low.
jnp.abs(y[n_padd:] - ty[n_padd:]).max()

Array(0., dtype=float32)

In [22]:
y[n_padd:]

Array([[-1.1136408 ,  1.2468629 , -0.99344945, ...,  0.33121186,
         1.0703579 , -0.06721891],
       [-0.8053709 ,  0.7387072 , -0.53019696, ...,  0.30776542,
         0.9296782 , -0.15966392],
       [-0.9898438 ,  1.0326734 , -0.8366718 , ...,  0.26767203,
         0.93346596, -0.0699129 ],
       ...,
       [ 0.0416423 ,  0.06101406, -0.04523662, ..., -0.06422241,
        -0.04720704,  0.02136195],
       [ 0.03843816,  0.01985435,  0.04832299, ..., -0.09897571,
         0.04451201,  0.0089184 ],
       [ 0.00146818,  0.04245059,  0.07542939, ...,  0.01340979,
         0.08863425,  0.00376507]], dtype=float32)

In [23]:
ty[n_padd:]

Array([[-1.1136408 ,  1.2468629 , -0.99344945, ...,  0.33121186,
         1.0703579 , -0.06721891],
       [-0.8053709 ,  0.7387072 , -0.53019696, ...,  0.30776542,
         0.9296782 , -0.15966392],
       [-0.9898438 ,  1.0326734 , -0.8366718 , ...,  0.26767203,
         0.93346596, -0.0699129 ],
       ...,
       [ 0.0416423 ,  0.06101406, -0.04523662, ..., -0.06422241,
        -0.04720704,  0.02136195],
       [ 0.03843816,  0.01985435,  0.04832299, ..., -0.09897571,
         0.04451201,  0.0089184 ],
       [ 0.00146818,  0.04245059,  0.07542939, ...,  0.01340979,
         0.08863425,  0.00376507]], dtype=float32)

# V + Padded Softmax

In [24]:
def make_padded_softmax_v(seq_len: int, n_feat: int, n_ocols: int, num_warps: int = 1):
    """
    Build a Pallas kernel that fuses the causal, padding-aware row softmax with the
    multiplication by `v`, i.e. computes `softmax(att) @ v` one output block at a time.

    Parameters
    ----------
    seq_len
      GPT context length (e.g. 1024)
    n_feat
      Number of features per q/k/v matrix, per attention head (e.g. 64)
    n_ocols
      Number of columns of output to calculate per kernel instance. The full size
      of the output is [seq_len, n_feat], so n_ocols must be positive and must
      divide n_feat.
    num_warps
      Forwarded unchanged to `pl.pallas_call`.

    Returns
    -------
    Kernel with signature:
    def padded_softmax_v_kernel(att: Array[seq_len, seq_len], v: Array[seq_len, n_feat], n_padd: int):

    It returns a float32 Array[seq_len, n_feat]. `padded_softmax_v_kernel(att, v, n_padd)`
    only returns a correct result for rows which don't correspond to padding tokens:
    for rows with index < n_padd every score is masked to -inf, so their softmax
    (and therefore their output row) is NaN.

    Kernel instance with grid index (i, j) is responsible for the block
    out[i, j*n_ocols:(j+1)*n_ocols]. To compute it, it reads att[i, :] and
    v[:, j*n_ocols:(j+1)*n_ocols].

    Raises
    ------
    ValueError
      If n_ocols is not a positive divisor of n_feat.
    """
    # Reject tilings that cannot cover the output. With a remainder, the grid below
    # (n_feat // n_ocols column blocks) would silently leave the trailing columns
    # unwritten; n_ocols == 0 would otherwise surface as a ZeroDivisionError.
    if n_ocols <= 0 or n_feat % n_ocols != 0:
        raise ValueError(
            "n_ocols must be a positive divisor of n_feat, "
            f"got n_ocols={n_ocols} and n_feat={n_feat}"
        )

    # grid = (seq_len, n_feat // n_ocols) => one kernel instance per [1, n_ocols]
    # block of the output matrix.
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct((seq_len, n_feat), jnp.float32),
        grid=(seq_len, n_feat // n_ocols),
        num_warps=num_warps,
        debug=False,
        interpret=True,
    )
    def padded_softmax_v_kernel(att_ref, v_ref, p_ref, o_ref):
        # Row of attention matrix that this kernel instance will process.
        att_row_num = pl.program_id(0)
        # Start of the block of columns of the `v` matrix that this kernel instance will process.
        v_col_start = pl.program_id(1) * n_ocols
        n_padd = p_ref[()]

        ### Create indices for reading memory. ###
        seq_idxs = jnp.arange(seq_len)

        att_idxs = (att_row_num, pl.dslice(None))

        ## [seq_len,] mask.
        # Token i should only attend to tokens j <= i.
        causal_mask = seq_idxs <= att_row_num
        # Per column: False where the token being attended *to* is a padding token.
        col_is_data = seq_idxs >= n_padd
        # Scalar: False when the attending token (this row) is itself a padding token.
        row_is_data = att_row_num >= n_padd
        padd_mask = col_is_data & row_is_data
        seq_mask = causal_mask & padd_mask

        ## Index for v[:, j*n_ocols:(j+1)*n_ocols].
        v_col_idxs = pl.dslice(v_col_start, n_ocols)
        v_row_idxs = pl.dslice(0, seq_len)
        v_idxs = (v_row_idxs, v_col_idxs)

        ## Only read rows of `v` that belong to non-padding tokens; the rest load as 0.
        v_row_mask = col_is_data
        # Stretch the [seq_len,] row mask to the [seq_len, n_ocols] shape of the `v` block.
        v_mask = lax.broadcast_in_dim(
            jnp.expand_dims(v_row_mask, 1), (seq_len, n_ocols), (0, 1)
        )

        out_idxs = (att_row_num, pl.dslice(v_col_start, n_ocols))

        ### Compute attn row softmax. ###
        # Masked-out scores are read as -inf so they contribute exp(-inf) = 0.
        att_row = pl.load(att_ref, att_idxs, mask=seq_mask, other=-float("inf"))

        # Subtract the row maximum before exponentiating for numerical stability.
        numerator = jnp.exp(att_row - jnp.max(att_row, axis=0))
        sma_row = numerator / jnp.sum(numerator, axis=0)

        ### Multiply attention by `v`. ###
        v_block = pl.load(v_ref, v_idxs, mask=v_mask, other=0)

        # We want to do `out = sma_row @ v_block` ([seq_len,] @ [seq_len, n_ocols] => [n_ocols,])
        # But Triton doesn't support matrix multiplication for small matrices.

        # Poor man's matrix multiplication (may be slowing us down since it doesn't use tensor cores).
        sma_mat = jnp.expand_dims(sma_row, 1)  # [seq_len, 1]
        # [seq_len, 1] * [seq_len, n_ocols] -> [seq_len, n_ocols] -[sum]-> [n_ocols,]
        out = jnp.sum(sma_mat * v_block, axis=0)

        ### Write output. ###
        # Unmasked store: padding rows are written too (as NaN, see docstring).
        pl.store(o_ref, out_idxs, out)

    return padded_softmax_v_kernel


class VTSingleHeadCausalSelfAttention(nn.Module):
    """
    Single-head causal self-attention for inference (no dropout) where both the
    masked softmax and the product with `v` are computed by the fused Pallas
    kernel from `make_padded_softmax_v`.

    minGPT docstring
    ----------------
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    # Width of each of the q / k / v projections.
    n_feat: int

    @nn.compact
    def __call__(self, x, n_padd: int = 0):
        # x is [T, C]: sequence length by embedding dimensionality (n_embd).
        seq_len, _ = x.shape

        # A single dense layer produces q, k and v at once:
        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat], then cut into 3 * [T, n_feat].
        qkv = nn.Dense(features=3 * self.n_feat)(x)
        q, k, v = jnp.split(qkv, 3, axis=1)

        # [T, n_feat] @ [n_feat, T] -> [T, T] scaled scores. scores[i][j] is high
        # when token i should attend heavily to token j.
        scale = 1.0 / jnp.sqrt(self.n_feat)
        scores = (q @ k.T) * scale

        # Softmax and the `@ v` step run inside the kernel, 4 output columns per
        # kernel instance; the result is [T, n_feat].
        fused_kernel = make_padded_softmax_v(seq_len, self.n_feat, n_ocols=4)
        return fused_kernel(scores, v, n_padd)


In [25]:
# NOTE(review): `att`, `v` and `x` are all drawn from the same key `rng`. `v` and `x`
# also share a shape, so they hold identical values. Derive separate keys with
# `jax.random.split` if independent inputs are wanted.
att = jax.random.normal(rng, (config.block_size, config.block_size))
v = jax.random.normal(rng, (config.block_size, config.n_embd))
# Number of leading positions treated as padding.
n_padd = 2
x = jax.random.normal(rng, (config.block_size, config.n_embd))

In [26]:
# Reference output from the imported `JSingleHeadCausalSelfAttention`; the second
# element returned by `init_with_output` is discarded.
y, _ = JSingleHeadCausalSelfAttention(config.n_embd).init_with_output(rng, x, n_padd=n_padd)

In [27]:
# Same input, key and `n_padd` through the module that uses the fused softmax-and-`v` kernel.
vy, _ = VTSingleHeadCausalSelfAttention(config.n_embd).init_with_output(rng, x, n_padd=n_padd)

In [28]:
# Largest *absolute* difference over the non-padding rows; a plain signed max
# would hide elements of `vy` that are too large.
jnp.abs(y[n_padd:] - vy[n_padd:]).max()

Array(4.7683716e-07, dtype=float32)

In [29]:
y[-1][:10]

Array([ 1.4681816e-03,  4.2450592e-02,  7.5429395e-02, -9.1652356e-02,
       -3.1681035e-02,  1.0414794e-04,  3.7798032e-02, -1.1124283e-02,
        1.1266769e-01, -1.3866793e-01], dtype=float32)

In [30]:
vy[-1][:10]

Array([ 1.46817416e-03,  4.24505323e-02,  7.54294395e-02, -9.16523561e-02,
       -3.16810384e-02,  1.04149804e-04,  3.77980322e-02, -1.11242868e-02,
        1.12667724e-01, -1.38667867e-01], dtype=float32)

In [31]:
# Scratch: repeat a [5] boolean mask along a new leading axis -> [4, 5].
jnp.tile(jnp.arange(5) > 3, (4, 1))

Array([[False, False, False, False,  True],
       [False, False, False, False,  True],
       [False, False, False, False,  True],
       [False, False, False, False,  True]], dtype=bool)

In [32]:
# Scratch: `lax.broadcast` gives the same [4, 5] result as the `jnp.tile` call.
jax.lax.broadcast(jnp.arange(5) > 3, (4,))

Array([[False, False, False, False,  True],
       [False, False, False, False,  True],
       [False, False, False, False,  True],
       [False, False, False, False,  True]], dtype=bool)

In [33]:
# Scratch: stretch a [5, 1] column mask across 4 columns -> [5, 4]. This is the
# form used to build `v_mask` inside `make_padded_softmax_v`.
lax.broadcast_in_dim(jnp.expand_dims(jnp.arange(5) >= 3, 1,), (5, 4), (0, 1))

Array([[False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]], dtype=bool)

In [None]:
# Standalone fused kernel for direct calls: 4 output columns per kernel instance.
kernel = make_padded_softmax_v(config.block_size, config.n_embd, 4)

In [None]:
# The first `n_padd` rows come back as NaN: for padding rows every score is masked
# to -inf, so the row softmax is undefined. Only rows from `n_padd` on are meaningful.
kernel(att, v, n_padd)

Array([[            nan,             nan,             nan, ...,
                    nan,             nan,             nan],
       [            nan,             nan,             nan, ...,
                    nan,             nan,             nan],
       [-6.50894493e-02,  1.23105273e-01,  1.24469054e+00, ...,
         6.53849840e-01,  6.03153765e-01,  1.89329430e-01],
       ...,
       [ 5.01023568e-02, -4.25007492e-02,  3.23878042e-02, ...,
        -1.46954395e-02, -7.23278970e-02, -5.75641170e-02],
       [-6.83210790e-04, -2.03592703e-03,  1.18623391e-01, ...,
        -3.67972404e-02,  4.68270928e-02,  1.81004982e-02],
       [ 1.67859476e-02, -3.88669893e-02,  3.94940078e-02, ...,
         4.67684865e-02, -5.01697585e-02,  9.15485024e-02]],      dtype=float32)