In [None]:
# Implement a block-sparse attention in Triton using the following method.
# Q and K are each divided into n equal-sized chunks of the original embeddings:
#   Q = [Q_0, Q_1, ..., Q_{n-1}],  K = [K_0, K_1, ..., K_{n-1}],
# so the full product Q @ K^T is an n x n grid of blocks, where block (i, j) is Q_i @ K_j^T.
# - If i == j (e.g. Q_0 * K_0), compute the block with a regular matrix multiplication ("intra-chunk").
# - If i != j (e.g. Q_0 * K_1), compute avg(Q_i) * avg(K_j) and broadcast that single value over the
#   whole (i, j) block ("inter-chunk").
# Create a Triton kernel that implements both cases: intra-chunk when i == j, inter-chunk when i != j.
# Write the code and test cases for the kernels first, before proceeding to the full implementation.
# The arguments and return value must exactly match those of torch.nn.functional.scaled_dot_product_attention.

In [None]:
# def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
#         is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
#     L, S = query.size(-2), key.size(-2)
#     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
#     attn_bias = torch.zeros(L, S, dtype=query.dtype)
#     if is_causal:
#         assert attn_mask is None
#         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
#         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
#         attn_bias.to(query.dtype)

#     if attn_mask is not None:
#         if attn_mask.dtype == torch.bool:
#             attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
#         else:
#             attn_bias += attn_mask

#     if enable_gqa:
#         key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
#         value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

#     attn_weight = query @ key.transpose(-2, -1) * scale_factor
#     attn_weight += attn_bias
#     attn_weight = torch.softmax(attn_weight, dim=-1)
#     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
#     return attn_weight @ value

In [7]:
import math
import torch
import triton
import triton.language as tl

@triton.jit
def chunked_attention_kernel(
    Q_ptr,                   # ptr to Q,  shape=[M, D]
    K_ptr,                   # ptr to K,  shape=[D, N]
    Out_ptr,                 # ptr to Out,shape=[M, N]
    # offsets for sub-chunk (i, j):
    q_row_offset,            # integer
    k_col_offset,            # integer
    out_row_offset,          # integer
    out_col_offset,          # integer

    # strides in memory for Q, K, Out:
    q_stride_m,              # typically D
    q_stride_d,              # typically 1
    k_stride_d,              # typically N
    k_stride_n,              # typically 1
    out_stride_m,            # typically N
    out_stride_n,            # typically 1

    M, N, D,                 # full matrix sizes
    chunk_size_m,            # sub-chunk M size
    chunk_size_n,            # sub-chunk N size

    avg_q_val,               # precomputed scalar avg(Q_i)
    avg_k_val,               # precomputed scalar avg(K_j)
    is_same_chunk,           # 1 => i == j (intra-chunk), 0 => i != j (inter-chunk)

    BLOCK_M: tl.constexpr,   # block size along M
    BLOCK_N: tl.constexpr,   # block size along N
    BLOCK_K: tl.constexpr    # block size along K (reduction dimension)
):
    """
    Compute one (i, j) sub-block of the chunked score matrix Out = Q @ K.

    - is_same_chunk == 1 (i == j, intra-chunk): exact tiled matmul
          Out[out_row_offset + r, out_col_offset + c]
              = sum_d Q[q_row_offset + r, d] * K[d, k_col_offset + c]
      accumulated in float32 and cast to Out's element type on store.
    - any other value (i != j, inter-chunk): every element of the sub-block
      is set to avg_q_val * avg_k_val.

    Launch grid: (cdiv(chunk_size_m, BLOCK_M), cdiv(chunk_size_n, BLOCK_N)).
    Program (pid_m, pid_n) handles local rows [pid_m*BLOCK_M, (pid_m+1)*BLOCK_M)
    and local cols [pid_n*BLOCK_N, (pid_n+1)*BLOCK_N) of the sub-chunk.
    A 1-D grid of (1,) reproduces the original single-block behaviour and is
    complete whenever chunk_size_m <= BLOCK_M and chunk_size_n <= BLOCK_N.

    Only elements whose local row < chunk_size_m, local col < chunk_size_n,
    and whose global row/col lie inside [M, N] are read or written; the rest
    of Out is left untouched.

    NOTE(review): the kernel only multiplies the two scalars it receives.
    Whether "avg(Q_i) * avg(K_j)" means a product of scalar means or a dot
    product of D-dim means is decided by the caller (e.g. pass the dot
    product as avg_q_val and 1.0 as avg_k_val).
    """
    # Which tile of the sub-chunk this program owns (both are 0 for a (1,) grid).
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Local (within-sub-chunk) row / column indices covered by this program.
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    row_in_chunk = rows < chunk_size_m
    col_in_chunk = cols < chunk_size_n

    # Destination tile in Out; masked to the sub-chunk and to Out's [M, N] bounds.
    out_rows = out_row_offset + rows
    out_cols = out_col_offset + cols
    out_row_mask = row_in_chunk & (out_rows < M)
    out_col_mask = col_in_chunk & (out_cols < N)
    out_mask = out_row_mask[:, None] & out_col_mask[None, :]
    out_ptrs = Out_ptr + out_rows[:, None] * out_stride_m + out_cols[None, :] * out_stride_n

    # ------------------------------------------------------------------
    # CASE 1: Intra-chunk => matmul with blocking along D
    # ------------------------------------------------------------------
    if is_same_chunk == 1:
        q_rows = q_row_offset + rows
        k_cols = k_col_offset + cols
        q_row_mask = row_in_chunk & (q_rows < M)
        k_col_mask = col_in_chunk & (k_cols < N)
        red_idx = tl.arange(0, BLOCK_K)

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for kk in range(0, D, BLOCK_K):
            d_idx = kk + red_idx
            d_mask = d_idx < D  # handles the ragged last slice of D

            # [BLOCK_M, BLOCK_K] tile of Q; out-of-range entries read as 0.
            q_tile = tl.load(
                Q_ptr + q_rows[:, None] * q_stride_m + d_idx[None, :] * q_stride_d,
                mask=q_row_mask[:, None] & d_mask[None, :],
                other=0.0,
            ).to(tl.float32)

            # [BLOCK_K, BLOCK_N] tile of K; out-of-range entries read as 0.
            k_tile = tl.load(
                K_ptr + d_idx[:, None] * k_stride_d + k_cols[None, :] * k_stride_n,
                mask=d_mask[:, None] & k_col_mask[None, :],
                other=0.0,
            ).to(tl.float32)

            # Broadcast-multiply and reduce over the BLOCK_K axis. Unlike tl.dot,
            # this has no minimum block size (tl.dot needs every dim >= 16).
            acc += tl.sum(q_tile[:, :, None] * k_tile[None, :, :], axis=1)

        tl.store(out_ptrs, acc.to(Out_ptr.dtype.element_ty), mask=out_mask)

    # ------------------------------------------------------------------
    # CASE 2: Inter-chunk => broadcast avg_q_val * avg_k_val
    # ------------------------------------------------------------------
    else:
        prod_scalar = avg_q_val * avg_k_val
        fill = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + prod_scalar
        tl.store(out_ptrs, fill.to(Out_ptr.dtype.element_ty), mask=out_mask)
