Re implement the pallas kernels from 'shcsa_kernels.py` in pure Triton, and call them
with `triton_call`. This allows us to use the 'mlir' branch of Triton, since pallas
support is still a work in progress.

In [2]:
import math
import pickle

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax_triton as jt
import torch
import triton
import triton.language as tl

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from nimblegpt.model import SingleHeadCausalSelfAttention, softmax
from nimblegpt import get_config_for

# Vanilla Softmax

Implement basic (no padding) softmax in Triton - to start simply.

Softmax from the Triton examples:

In [4]:
@triton.jit
def softmax_kernel(
    input_ptr, output_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Substract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

In [5]:
def torch_softmax(x):
    n_rows, n_cols = x.shape
    # The block size is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Allocate output
    y = torch.empty_like(x)
    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o
    # f the input matrix
    softmax_kernel[(n_rows,)](
        x,
        y,
        x.stride(0),
        y.stride(0),
        n_cols,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y

In [6]:
# torch.manual_seed(0)
# x = torch.randn(1823, 781, device='cuda')
# y_triton = torch_softmax(x)
# y_torch = torch.softmax(x, axis=1)
# assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

In [7]:
triton_dump_binary_path = "./triton-binary.pickle"

In [8]:
next_pow2 = lambda x: int(math.pow(2, math.ceil(math.log(x, 2))))


def jt_softmax(x: jnp.ndarray) -> jnp.ndarray:
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    block_size = next_pow2(x.shape[1])
    strides = jt.strides_from_shape(x.shape)
    return jt.triton_call(
        x,
        kernel=softmax_kernel,
        out_shape=out_shape,
        input_row_stride=strides[0],
        output_row_stride=strides[0],
        n_cols=x.shape[1],
        grid=x.shape[0],
        BLOCK_SIZE=block_size,
        dump_binary_path=triton_dump_binary_path,
    )


In [9]:
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (1823, 781))

In [10]:
y_jt = jt_softmax(x)
y_jax = jax.nn.softmax(x, axis=1)

In [11]:
(y_jt - y_jax).max()

Array(7.450581e-09, dtype=float32)

In [12]:
triton_dump = pickle.load(open(triton_dump_binary_path, "rb"))

In [13]:
triton_dump["asm"].keys()

dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

In [14]:
print(triton_dump["asm"]["ttir"])

module attributes {"triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 512 : i32} {
  llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
  llvm.func @softmax_kernel_0d1d(%arg0: !llvm.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !llvm.ptr<f32, 1> {tt.divisibility = 16 : i32}) attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32, sym_visibility = "public"} {
    %0 = nvvm.read.ptx.sreg.tid.x : i32
    %1 = llvm.mlir.constant(32 : i32) : i32
    %2 = llvm.urem %0, %1  : i32
    %3 = llvm.udiv %0, %1  : i32
    %4 = llvm.urem %3, %1  : i32
    %5 = llvm.mlir.constant(1024 : i32) : i32
    %6 = llvm.urem %2, %5  : i32
    %7 = llvm.mlir.constant(1 : i32) : i32
    %8 = llvm.mul %4, %1  : i32
    %9 = llvm.add %6, %8  : i32
    %10 = llvm.mul %9, %7  : i32
    %11 = llvm.mlir.constant(0 : i32) : i32
    %12 = llvm.add %10, %11  : i32
    %13 = llvm.mlir.constant(128 : i32) : i32
    %14 = llvm.add %10, %13  : i32
    %15 = llvm.mlir.constant(25

## SHCSA using the Softmax

In [15]:
class SHCSATritonSoftmax(nn.Module):

    n_feat: int

    @nn.compact
    def __call__(self, x):
        T, C = x.shape  # sequence length, embedding dimensionality (n_embd)

        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat] -> 3 * [T, n_feat]
        q, k, v = jnp.split(nn.Dense(features=3 * self.n_feat)(x), 3, axis=1)

        # [T, n_feat] @ [n_feat, T] -> [T, T].
        att = (q @ k.T) * (1.0 / jnp.sqrt(self.n_feat))
        causal_mask = jnp.tril(jnp.ones((T, T))).astype(bool)
        att = jnp.where(~causal_mask, -jnp.inf, att)
        att = jt_softmax(att)
        # att = jax.nn.softmax(att, axis=-1)

        y = att @ v  # [T, T] @ [T, n_feat] -> [T, n_feat]

        return y

### Testing

In [16]:
config = get_config_for('gpt2')
x = jax.random.normal(rng, (config.block_size, config.n_embd))
n_feat = config.n_embd // config.n_head

In [17]:
n_feat

64

In [18]:
y, _ = SingleHeadCausalSelfAttention(n_feat).init_with_output(rng, x)

In [19]:
ty, _ = SHCSATritonSoftmax(n_feat).init_with_output(rng, x)

In [20]:
(y - ty).max()

Array(1.7881393e-07, dtype=float32)

# Padded Softmax

In [21]:
@triton.jit
def padded_softmax_kernel(
    att_ptr,
    p_ptr,
    output_ptr,
    att_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    n_padd = tl.load(p_ptr)

    att_row_num = tl.program_id(0)
    att_row_start_ptr = att_ptr + att_row_num * att_row_stride

    att_col_idxs = tl.arange(0, BLOCK_SIZE)
    att_ptrs = att_row_start_ptr + att_col_idxs

    valid_mask = att_col_idxs < n_cols
    causal_mask = att_col_idxs <= att_row_num
    padd_mask = (att_col_idxs >= n_padd) & (att_row_num >= n_padd)
    read_mask = valid_mask & causal_mask & padd_mask 

    att_row = tl.load(att_ptrs, mask=read_mask, other=-float("inf"))

    numerator = tl.exp(att_row - tl.max(att_row, axis=0))
    sma_row = numerator / tl.sum(numerator, axis=0)

    output_row_start_ptr = output_ptr + att_row_num * output_row_stride
    output_ptrs = output_row_start_ptr + att_col_idxs

    tl.store(output_ptrs, sma_row, mask=att_col_idxs < n_cols)


In [22]:
def padded_softmax(att, n_padd):

    out_shape = jax.ShapeDtypeStruct(shape=att.shape, dtype=att.dtype)
    block_size = next_pow2(att.shape[1])
    strides = jt.strides_from_shape(att.shape)

    return jt.triton_call(
        att,
        jnp.array(n_padd),
        kernel=padded_softmax_kernel,
        out_shape=out_shape,
        att_row_stride=strides[0],
        output_row_stride=strides[0],
        n_cols=att.shape[1],
        grid=att.shape[0],
        BLOCK_SIZE=block_size,
    )

In [23]:
jt.strides_from_shape((3, 4))

(4, 1)

In [24]:
class SHCSATritonPaddedSoftmax(nn.Module):

    n_feat: int

    @nn.compact
    def __call__(self, x, n_padd: int = 0):
        T, C = x.shape  # sequence length, embedding dimensionality (n_embd)

        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat] -> 3 * [T, n_feat]
        q, k, v = jnp.split(nn.Dense(features=3 * self.n_feat)(x), 3, axis=1)

        # [T, n_feat] @ [n_feat, T] -> [T, T].
        att = (q @ k.T) * (1.0 / jnp.sqrt(self.n_feat))
        att = padded_softmax(att, n_padd)

        y = att @ v  # [T, T] @ [T, n_feat] -> [T, n_feat]

        return y

In [25]:
config = get_config_for('gpt2')
x = jax.random.normal(jax.random.PRNGKey(0), (config.block_size, config.n_embd))
n_feat = config.n_embd // config.n_head

In [26]:
y, _ = SingleHeadCausalSelfAttention(n_feat).init_with_output(rng, x)

In [27]:
ty, _ = SHCSATritonPaddedSoftmax(n_feat).init_with_output(rng, x)

In [28]:
(y - ty).max()

Array(1.7881393e-07, dtype=float32)

# Padded Softmax + v multiplication

In [29]:
(torch.arange(0, 10) + 5)[:, None] * 10 + (torch.arange(0, 10) + 5)[None, :]

tensor([[ 55,  56,  57,  58,  59,  60,  61,  62,  63,  64],
        [ 65,  66,  67,  68,  69,  70,  71,  72,  73,  74],
        [ 75,  76,  77,  78,  79,  80,  81,  82,  83,  84],
        [ 85,  86,  87,  88,  89,  90,  91,  92,  93,  94],
        [ 95,  96,  97,  98,  99, 100, 101, 102, 103, 104],
        [105, 106, 107, 108, 109, 110, 111, 112, 113, 114],
        [115, 116, 117, 118, 119, 120, 121, 122, 123, 124],
        [125, 126, 127, 128, 129, 130, 131, 132, 133, 134],
        [135, 136, 137, 138, 139, 140, 141, 142, 143, 144],
        [145, 146, 147, 148, 149, 150, 151, 152, 153, 154]])

In [30]:
v_row_stride = 64
v_col_start = 3 * 4
v_col_idxs = torch.arange(0, 4) + v_col_start
seq_idxs = torch.arange(config.block_size)
v_block_start_ptr = v_col_start
v_col_idxs

tensor([12, 13, 14, 15])

In [31]:
v_block_start_ptr + (seq_idxs[:, None]*v_row_stride + v_col_idxs[None, :])

tensor([[   24,    25,    26,    27],
        [   88,    89,    90,    91],
        [  152,   153,   154,   155],
        ...,
        [65368, 65369, 65370, 65371],
        [65432, 65433, 65434, 65435],
        [65496, 65497, 65498, 65499]])

In [32]:
v_col_start = 3 * 4
v_block_start_ptr = 0 + v_col_start
v_ptrs = v_block_start_ptr + (
    seq_idxs[:, None] * v_row_stride + torch.arange(0, 4)[None, :]
)
v_ptrs

tensor([[   12,    13,    14,    15],
        [   76,    77,    78,    79],
        [  140,   141,   142,   143],
        ...,
        [65356, 65357, 65358, 65359],
        [65420, 65421, 65422, 65423],
        [65484, 65485, 65486, 65487]])

In [33]:
torch.ravel(v_ptrs)

tensor([   12,    13,    14,  ..., 65485, 65486, 65487])

In [34]:
jnp.arange(config.block_size * n_feat).reshape((config.block_size, n_feat))[:, 12:]

Array([[   12,    13,    14, ...,    61,    62,    63],
       [   76,    77,    78, ...,   125,   126,   127],
       [  140,   141,   142, ...,   189,   190,   191],
       ...,
       [65356, 65357, 65358, ..., 65405, 65406, 65407],
       [65420, 65421, 65422, ..., 65469, 65470, 65471],
       [65484, 65485, 65486, ..., 65533, 65534, 65535]], dtype=int32)

In [35]:
n_padd = 3

In [36]:
v_data_mask = (torch.arange(config.block_size) >= n_padd)[:, None] + torch.zeros(4, dtype=torch.uint8)[None, :]

In [37]:
v_data_mask[:10]

tensor([[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]], dtype=torch.uint8)

In [38]:
sma_row = torch.randn(config.block_size)
v_block = torch.randn(config.block_size, 4)

In [39]:
v_block

tensor([[-1.2196, -2.0542,  1.1874, -1.0706],
        [ 1.3401,  0.2011,  0.4685,  0.4783],
        [ 0.9115, -0.1027, -0.6169,  0.6463],
        ...,
        [-0.0428,  0.8290,  1.2708, -1.0070],
        [ 1.1806,  0.0735,  0.2074,  2.0656],
        [-0.7963, -1.2081,  1.2733, -1.4956]])

In [40]:
sma_mat = torch.broadcast_to(sma_row[:, None], v_block.shape)
sma_mat

tensor([[-0.0366, -0.0366, -0.0366, -0.0366],
        [-1.3061, -1.3061, -1.3061, -1.3061],
        [-1.1283, -1.1283, -1.1283, -1.1283],
        ...,
        [ 0.8469,  0.8469,  0.8469,  0.8469],
        [ 0.4464,  0.4464,  0.4464,  0.4464],
        [ 1.2998,  1.2998,  1.2998,  1.2998]])

In [41]:
torch.matmul(sma_row, v_block)

tensor([ 31.9814,  41.2079, -35.4196,   7.3629])

In [42]:
torch.sum(sma_mat * v_block, axis=0)

tensor([ 31.9814,  41.2079, -35.4195,   7.3629])

In [43]:
torch.sum(sma_row[:, None] * v_block, axis=0)

tensor([ 31.9814,  41.2079, -35.4195,   7.3629])

In [44]:
@triton.jit
def padded_softmax_v_kernel(
    att_ptr,
    v_ptr,
    p_ptr,
    output_ptr,
    SEQ_LEN: tl.constexpr,
    N_FEAT: tl.constexpr,
    N_OCOLS: tl.constexpr,
):
    """
    Triton kernel for computing the softmax of an attention matrix which may have padding
    tokens, and then multiplying the result by a value matrix `v`.

    Kernel cell with coordinates (i, j) computes out[i, j*N_OCOLS: (j+1)*N_OCOLS]. To
    do this, it loads att[i, :] and v[:, j*N_OCOLS: (j+1)*N_OCOLS].

    Inputs
    ------
    att_ptr: [SEQ_LEN, SEQ_LEN]
    v_ptr: [SEQ_LEN, N_FEAT]

    Output
    ------
    out: [SEQ_LEN, N_FEAT]
        The output of self attention (`att @ v`).

    Constants
    ---------
    SEQ_LEN: (1024 for GPT-2)
    N_FEAT: (64 for GPT-2)
    N_OCOLS: Number of elements of output matrix computed per kernel instance.

    NOTE: Assumes all tensor sizes are powers of 2.
    """
    n_padd = tl.load(p_ptr)

    ## Load att[i, :] - with masking to avoid reading non-causal or padding tokens. ###
    seq_idxs = tl.arange(0, SEQ_LEN)
    att_row_num = tl.program_id(0)
    att_row_start_ptr = att_ptr + att_row_num * SEQ_LEN
    att_ptrs = att_row_start_ptr + seq_idxs

    att_causal_mask = seq_idxs <= att_row_num  # 0 for non-causal tokens.
    att_data_mask = (seq_idxs >= n_padd) & (
        att_row_num >= n_padd
    )  # 0 for padding tokens.
    att_read_mask = att_causal_mask & att_data_mask

    print("att_ptrs", att_ptrs)
    print("att_read_mask", att_read_mask)
    att_row = tl.load(att_ptrs, mask=att_read_mask, other=-float("inf"))

    ### Compute attention row softmax. ###

    numerator = tl.exp(att_row - tl.max(att_row, axis=0))
    sma_row = numerator / tl.sum(numerator, axis=0) # [SEQ_LEN,]

    ### Load v[:, j*N_OCOLS: (j+1)*N_OCOLS] - with masking to avoid reading padding tokens. ###

    v_col_start = tl.program_id(1) * N_OCOLS
    v_block_start_ptr = v_ptr + v_col_start
    v_ptrs = v_block_start_ptr + (
        seq_idxs[:, None] * N_FEAT + tl.arange(0, N_OCOLS)[None, :]
    )

    # v[:n_padd, :] are values of padding tokens. The attention matrix already has zeros
    # for elements corresponding to these. We use a mask to avoid unnecessary reads.
    v_data_mask = (seq_idxs >= n_padd)[:, None] + tl.zeros((4,), dtype=tl.int1)[
        None, :
    ]

    print("v_ptrs", v_ptrs)
    print("v_data_mask", v_data_mask)
    v_block = tl.load(v_ptrs, mask=v_data_mask, other = 0.0)

    ### Compute output row-block. ###

    # We want to compute sma_row @ v_block, but Trition doesn't support doing this
    # directly, so we roll our own matrix-vector multiplication.
    # [SEQ_LEN,] -> [SEQ_LEN, N_OCOLS] (the same column copied N_OCOLS times).
    out = tl.sum(sma_row[:, None] * v_block, axis=0) # [N_OCOLS,]

    ### Write output row-block. ###

    output_start_ptr = output_ptr + att_row_num * N_FEAT + v_col_start
    output_ptrs = output_start_ptr + tl.arange(0, N_OCOLS)

    out_mask = tl.zeros((N_OCOLS,), dtype=tl.uint8) + att_row_num >= n_padd

    # Don't bother writing outputs for padding tokens - just leave uninitialized.
    tl.store(output_ptrs, out, mask=out_mask)


In [45]:
def padded_softmax_v(att, v, n_padd, n_ocols: int = 4):

    out_shape = jax.ShapeDtypeStruct(shape=v.shape, dtype=v.dtype)
    grid = (att.shape[0], v.shape[1] // n_ocols)
    assert grid[1] * n_ocols == v.shape[1]

    return jt.triton_call(
        att,
        v,
        jnp.array(n_padd),
        kernel=padded_softmax_v_kernel,
        out_shape=out_shape,
        grid=grid,
        SEQ_LEN=att.shape[0],
        N_FEAT=v.shape[1],
        N_OCOLS=n_ocols,
    )


In [46]:
class SHCSATritonPaddedSoftmaxV(nn.Module):

    n_feat: int

    @nn.compact
    def __call__(self, x, n_padd: int = 0):
        T, C = x.shape  # sequence length, embedding dimensionality (n_embd)

        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat] -> 3 * [T, n_feat]
        q, k, v = jnp.split(nn.Dense(features=3 * self.n_feat)(x), 3, axis=1)

        # [T, n_feat] @ [n_feat, T] -> [T, T].
        att = (q @ k.T) * (1.0 / jnp.sqrt(self.n_feat))

        y = padded_softmax_v(att, v, n_padd)

        return y

In [47]:
config = get_config_for('gpt2')
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (config.block_size, config.n_embd))
n_feat = config.n_embd // config.n_head

In [48]:
y, _ = SingleHeadCausalSelfAttention(n_feat).init_with_output(rng, x)

In [49]:
ty, _ = SHCSATritonPaddedSoftmaxV(n_feat).init_with_output(rng, x)

att_ptrs pointer<fp32>[constexpr[1024]]
att_read_mask int1[constexpr[1024]]
v_ptrs pointer<fp32>[constexpr[1024],constexpr[4]]
v_data_mask int1[constexpr[1024],constexpr[4]]


In [50]:
jax.make_jaxpr(SHCSATritonPaddedSoftmaxV(n_feat).init_with_output)(rng, x)

att_ptrs pointer<fp32>[constexpr[1024]]
att_read_mask int1[constexpr[1024]]
v_ptrs pointer<fp32>[constexpr[1024],constexpr[4]]
v_data_mask int1[constexpr[1024],constexpr[4]]


{ lambda ; a:u32[2] b:f32[1024,768]. let
    c:key<fry>[] = random_wrap[impl=fry] a
    d:key<fry>[] = random_fold_in c 2998342421
    e:u32[2] = random_unwrap d
    f:f32[] = sqrt 0.0013020833721384406
    g:f32[] = div f 0.879625678062439
    h:key<fry>[] = random_wrap[impl=fry] e
    i:f32[] = div -2.0 1.4142135381698608
    j:f32[] = erf i
    k:f32[] = div 2.0 1.4142135381698608
    l:f32[] = erf k
    m:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] j
    n:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] l
    o:u32[768,192] = random_bits[bit_width=32 shape=(768, 192)] h
    p:u32[768,192] = shift_right_logical o 9
    q:u32[768,192] = or p 1065353216
    r:f32[768,192] = bitcast_convert_type[new_dtype=float32] q
    s:f32[768,192] = sub r 1.0
    t:f32[1,1] = sub n m
    u:f32[768,192] = mul s t
    v:f32[768,192] = add u m
    w:f32[768,192] = max m v
    x:f32[768,192] = erf_inv w
    y:f32[768,192] = mul 1.4142135381698608 x
    z:f32[] = 

In [51]:
(y - ty).max()

Array(2.3841858e-07, dtype=float32)

# Padded Attention

In [52]:
SEQ_LEN = config.block_size
N_FEAT = config.n_embd // config.n_head
N_OCOLS = 4

In [53]:
n_padd = 5
out_row_num = 7
out_col_start = 3

seq_idxs = torch.arange(0, SEQ_LEN)

data_row_num_mask = out_row_num >= n_padd
data_seq_mask = seq_idxs >= n_padd
data_block_mask = torch.broadcast_to(data_seq_mask[:, None], (SEQ_LEN, N_OCOLS))

In [54]:
data_block_mask

tensor([[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        ...,
        [ True,  True,  True,  True],
        [ True,  True,  True,  True],
        [ True,  True,  True,  True]])

In [90]:
k_mat = torch.randn((SEQ_LEN, N_FEAT))
k_mat

tensor([[ 0.2147,  0.6030,  1.1701,  ..., -0.7375, -1.4493,  0.8846],
        [-0.8668, -0.2372, -1.4559,  ...,  1.9100,  1.5226,  0.2802],
        [-0.3820, -0.8772,  0.2808,  ..., -0.4392,  0.5333,  0.2044],
        ...,
        [-0.0188,  1.5368,  0.0277,  ..., -0.3134, -0.7780,  1.0785],
        [ 0.5268,  0.5686,  1.2759,  ..., -0.7068, -2.0102, -0.6808],
        [-1.5466,  0.2807,  0.2595,  ...,  0.3604, -0.5766,  1.0186]])

In [88]:
q_row = torch.randn((N_FEAT,))
q_row

tensor([ 0.6096, -0.1009, -0.2483, -0.1537,  0.5777, -0.2142, -1.8165,  0.2700,
         0.9725, -0.2513, -0.7535, -0.2081, -0.9549, -0.7695, -1.8244,  0.4653,
         0.3364,  0.9533, -0.2115,  0.0030,  1.8236,  0.2614,  0.1525, -1.4166,
        -0.2983, -0.3724, -0.0177,  1.3408, -1.1383, -1.1147,  0.6211, -0.4133,
         0.0686,  0.6783, -0.0325,  1.6544,  0.2191, -0.2254, -0.2810, -0.0133,
         0.0300, -1.1392,  1.6929, -0.5713,  0.9627, -1.2250, -0.7900,  1.1811,
         1.1362,  0.0548,  0.5757, -0.3007,  1.0193, -1.1057, -0.0861, -1.3086,
        -0.8898, -0.1812, -0.2406, -0.4369, -2.0844, -0.3743,  1.7599,  0.7849])

In [91]:
k_mat.shape

torch.Size([1024, 64])

In [95]:
q_row[:, None].shape

torch.Size([64, 1])

In [97]:
k_mat * q_row[None, :]

tensor([[ 0.1309, -0.0609, -0.2906,  ...,  0.2760, -2.5506,  0.6944],
        [-0.5284,  0.0239,  0.3615,  ..., -0.7148,  2.6797,  0.2199],
        [-0.2328,  0.0885, -0.0697,  ...,  0.1644,  0.9386,  0.1604],
        ...,
        [-0.0115, -0.1551, -0.0069,  ...,  0.1173, -1.3692,  0.8466],
        [ 0.3211, -0.0574, -0.3168,  ...,  0.2645, -3.5378, -0.5344],
        [-0.9428, -0.0283, -0.0644,  ..., -0.1349, -1.0148,  0.7995]])

In [122]:
@triton.jit
def padded_attention_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    p_ptr,
    out_ptr,
    SM_SCALE: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    N_FEAT: tl.constexpr,
    N_OCOLS: tl.constexpr,
):
    """
    Triton kernel implementing single-headed attention with causal masking, where some
    tokens may be padding.

    Kernel cell with coordinates i, j computes out[i, j*N_OCOLS: (j+1)*N_OCOLS] by
    multiplying att[i, :] by v[:, j*N_OCOLS: (j+1)*N_OCOLS].

    We compute att[i, :] by multiplying q[i, :] by k[:, :]^T.

    Inputs
    ------
    q_ptr: [SEQ_LEN, N_FEAT]
    k_ptr: [SEQ_LEN, N_FEAT]
    v_ptr: [SEQ_LEN, N_FEAT]

    Output
    ------
    out: [SEQ_LEN, N_FEAT]
        The output of self attention (`att @ v`).

    Constants
    ---------
    SEQ_LEN: (1024 for GPT-2)
    N_FEAT: (64 for GPT-2)
    N_OCOLS: Number of elements of output matrix computed per kernel instance.

    NOTE: Assumes all tensor sizes are powers of 2.
    """
    n_padd = tl.load(p_ptr)

    out_row_num = tl.program_id(0)
    out_col_start = tl.program_id(1) * N_OCOLS
    # This cell computes out[out_row_num, out_col_start: out_col_start + N_OCOLS].

    seq_idxs = tl.arange(0, SEQ_LEN)
    feat_idxs = tl.arange(0, N_FEAT)
    ocols_idxs = tl.arange(0, N_OCOLS)

    # Shape (1,) mask. 0 if this instance is computing only padding elements of `out`.
    data_row_num_mask = out_row_num >= n_padd
    # Shape (SEQ_LEN,) mask. 0 for tokens of the sequence which are padding.
    data_seq_mask = seq_idxs >= n_padd
    # Shape (SEQ_LEN, N_OCOLS) mask. 0 for features of v corresponding to padding tokens.
    data_v_block_mask = tl.broadcast_to(data_seq_mask[:, None], (SEQ_LEN, N_OCOLS))
    # Shape (SEQ_LEN, N_FEAT) mask into k. 0 for features corresponding to padding tokens.
    data_k_mat_mask = tl.broadcast_to(data_seq_mask[:, None], (SEQ_LEN, N_FEAT))
    # Shape (SEQ_LEN,). 0 for non-causal elements of `att`.
    causal_mask = seq_idxs <= out_row_num

    ### Compute the softmax of att[out_row_num, :]. ###
    # This requires loading one row of Q and all of K.

    q_row_start_ptr = q_ptr + out_row_num * N_FEAT
    q_ptrs = q_row_start_ptr + feat_idxs
    q_row = tl.load(q_ptrs, mask=data_row_num_mask, other=0.0) # [N_FEAT,]

    k_ptrs = k_ptr + (
        seq_idxs[:, None] * N_FEAT + feat_idxs[None, :]
    )
    k_mat = tl.load(k_ptrs, mask=data_k_mat_mask, other=0.0) # [SEQ_LEN, N_FEAT]

    # ([N_FEAT,] -> [1, N_FEAT]) * [SEQ_LEN, N_FEAT] -> [SEQ_LEN, N_FEAT].
    att_row = tl.sum(q_row[None, :] * k_mat, axis=1) * SM_SCALE
    # padding and non-causal elements currenly have value 0. We need to set them to -inf
    # for the softmax.

    print("q_row", q_row)
    print("k_mat", k_mat)
    print("att_row", att_row)

    causal_att_row = tl.where(causal_mask & data_seq_mask, att_row, float("-inf"))
    sm_numerator = tl.exp(causal_att_row - tl.max(causal_att_row, axis=0))
    sm_att_row = sm_numerator / tl.sum(sm_numerator, axis=0) # [seq_len,]

    ### Multiply att[out_row_num, :] by v[:, out_col_start: out_col_start + N_OCOLS]. ###

    v_block_start_ptr = v_ptr + out_col_start
    v_ptrs = v_block_start_ptr + (
        seq_idxs[:, None] * N_FEAT + ocols_idxs[None, :]
    )
    v_block = tl.load(v_ptrs, mask=data_v_block_mask, other=0.0) # [SEQ_LEN, N_OCOLS]

    out = tl.sum(sm_att_row[:, None] * v_block, axis=0) # [N_OCOLS,]

    ### Write output row-block. ###

    out_row_start_ptr = out_ptr + out_row_num * N_FEAT
    out_ptrs = out_row_start_ptr + out_col_start + ocols_idxs

    tl.store(out_ptrs, out, mask=data_row_num_mask)


In [123]:
def padded_attention(q, k, v, n_padd, n_ocols: int = 4):

    out_shape = jax.ShapeDtypeStruct(shape=v.shape, dtype=v.dtype)
    grid = (q.shape[0], q.shape[1] // n_ocols)
    assert grid[1] * n_ocols == q.shape[1]

    return jt.triton_call(
        q,
        k,
        v,
        jnp.array(n_padd),
        kernel=padded_attention_kernel,
        out_shape=out_shape,
        grid=grid,
        SM_SCALE = 1.0 / k.shape[1]**0.5,
        SEQ_LEN=q.shape[0],
        N_FEAT=q.shape[1],
        N_OCOLS=n_ocols,
    )

In [124]:
class SHCSATriton(nn.Module):

    n_feat: int

    @nn.compact
    def __call__(self, x, n_padd: int = 0):
        T, C = x.shape  # sequence length, embedding dimensionality (n_embd)

        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat] -> 3 * [T, n_feat]
        q, k, v = jnp.split(nn.Dense(features=3 * self.n_feat)(x), 3, axis=1)

        y = padded_attention(q, k, v, n_padd)

        return y

In [116]:
config = get_config_for('gpt2')
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (config.block_size, config.n_embd))
n_feat = config.n_embd // config.n_head

In [101]:
y, _ = SingleHeadCausalSelfAttention(n_feat).init_with_output(rng, x)

In [125]:
ty, _ = SHCSATriton(n_feat).init_with_output(rng, x)

q_row fp32[constexpr[64]]
k_mat fp32[constexpr[1024],constexpr[64]]
att_row fp32[constexpr[1024]]


In [128]:
(y - ty).max()

Array(2.3841858e-07, dtype=float32)