In [None]:
# FASTA: Full Average Scaled Tiling Attention
# implement a sparse attention using triton using the following methods
# in the standard self attention, the attention weight is computed like this: attn_weight = query @ key.transpose(-2, -1) * scale_factor
# assume a function:
# def att_weight(Q,K_T):
#    return Q@K_T
# FASTA is a sparse approximation for the above function which works as follows:
# the Q and K are divided into equal sized chunks
# assume  QxK^T to be [Q0,Q1,....Qn-1]*[K0,K1,....Kn-1] where each of them are equal sized chunks from the initial embeddings.
# in the full product if Q0*K0 then you do the regular multiplication, but if Q0*K1 or whenever the indices are not same, do avg(Q0)*avg(K1) and then broadcast this value in the shape of that grid.
# create a triton kernel which implements the above operation if i==j then intra-index, if i!=j then inter-index
# generate code and test case for the kernels first before proceeding to the full implementation

In [5]:
## standard torch self-attention
import torch

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

In [6]:
################################################################################

In [14]:
import torch
import triton
import triton.language as tl

# Triton Kernel for FASTA Attention
@triton.jit
def fasta_attn_kernel(
    Q_ptr, K_ptr, attn_ptr,
    N, D, BLOCK_SIZE,
    stride_q0, stride_q1,
    stride_k0, stride_k1,
    stride_attn0, stride_attn1
):
    # Calculate program ID
    pid = tl.program_id(0)
    
    # Compute row using floor division and then compute col without using tl.mod
    row = tl.cast(tl.floor(tl.cast(pid, tl.float32) / tl.cast(N, tl.float32)), tl.int32)
    col = pid - row * N
    
    # Compute the base pointers for the current Q and K chunks
    q_base = Q_ptr + row * BLOCK_SIZE * stride_q0
    k_base = K_ptr + col * BLOCK_SIZE * stride_k0
    
    # Initialize shared memory for Q and K chunks
    # Shape: (BLOCK_SIZE, D)
    Q_chunk = tl.load(
        q_base + tl.arange(0, BLOCK_SIZE)[:, None] * stride_q0 + tl.arange(0, D)[None, :] * stride_q1,
        mask=True, 
        other=0.0
    )
    K_chunk = tl.load(
        k_base + tl.arange(0, BLOCK_SIZE)[:, None] * stride_k0 + tl.arange(0, D)[None, :] * stride_k1,
        mask=True, 
        other=0.0
    )
    
    if row == col:
        # Intra-chunk multiplication: Q_chunk @ K_chunk^T
        # Compute the dot product manually
        acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
        for k in range(D):
            q = Q_chunk[:, k]  # Shape: (BLOCK_SIZE,)
            k_ = K_chunk[:, k]  # Shape: (BLOCK_SIZE,)
            acc += q[:, None] * k_[None, :]
        attn = acc
    else:
        # Inter-chunk average multiplication
        # Compute the average of Q_chunk and K_chunk manually
        Q_sum = tl.zeros((D,), dtype=tl.float32)
        K_sum = tl.zeros((D,), dtype=tl.float32)
        for i in range(BLOCK_SIZE):
            Q_sum += Q_chunk[i]
            K_sum += K_chunk[i]
        Q_avg = Q_sum / BLOCK_SIZE
        K_avg = K_sum / BLOCK_SIZE
        # Compute the dot product of averages
        attn_val = tl.dot(Q_avg, K_avg)
        # Broadcast the scalar to a BLOCK_SIZE x BLOCK_SIZE matrix
        attn = tl.broadcast(attn_val, (BLOCK_SIZE, BLOCK_SIZE))
    
    # Compute the base pointer for the current attention block
    attn_base = attn_ptr + row * BLOCK_SIZE * stride_attn0 + col * BLOCK_SIZE * stride_attn1
    
    # Store the attention weights
    tl.store(
        attn_base + tl.arange(0, BLOCK_SIZE)[:, None] * stride_attn0 + tl.arange(0, BLOCK_SIZE)[None, :] * stride_attn1,
        attn, 
        mask=True
    )

# Attention Function
def fasta_attention(Q, K, BLOCK_SIZE=128):
    """
    Computes FASTA attention using Triton.

    Args:
        Q (torch.Tensor): Query tensor of shape (N*BLOCK_SIZE, D)
        K (torch.Tensor): Key tensor of shape (N*BLOCK_SIZE, D)
        BLOCK_SIZE (int): Size of each chunk

    Returns:
        torch.Tensor: Attention weights of shape (N*BLOCK_SIZE, N*BLOCK_SIZE)
    """
    assert Q.shape == K.shape, "Q and K must have the same shape"
    total_size, D = Q.shape
    N = total_size // BLOCK_SIZE
    assert N * BLOCK_SIZE == total_size, "Total size must be divisible by BLOCK_SIZE"
    
    attn = torch.empty((total_size, total_size), device=Q.device, dtype=Q.dtype)
    
    # Define strides
    stride_q0, stride_q1 = Q.stride()
    stride_k0, stride_k1 = K.stride()
    stride_attn0, stride_attn1 = attn.stride()
    
    # Launch Triton kernel
    grid = (N * N,)
    fasta_attn_kernel[grid](
        Q, K, attn,
        N, D, BLOCK_SIZE,
        stride_q0, stride_q1,
        stride_k0, stride_k1,
        stride_attn0, stride_attn1,
        num_warps=4,
        num_stages=3
    )
    return attn

def test_fasta_attention():
    """
    Tests the FASTA attention implementation by comparing it against a reference PyTorch implementation.
    """
    # Configuration
    N = 4           # Number of chunks
    BLOCK_SIZE = 16 # Chunk size
    D = 32          # Dimension
    torch.manual_seed(0)

    # Initialize Q and K with random values
    Q = torch.randn(N * BLOCK_SIZE, D, device='cuda', dtype=torch.float32)
    K = torch.randn(N * BLOCK_SIZE, D, device='cuda', dtype=torch.float32)

    # Compute attention using FASTA
    attn_fasta = fasta_attention(Q, K, BLOCK_SIZE=BLOCK_SIZE)

    # Reference computation using PyTorch
    attn_ref = torch.zeros((N * BLOCK_SIZE, N * BLOCK_SIZE), device='cuda', dtype=torch.float32)
    for i in range(N):
        for j in range(N):
            Q_i = Q[i*BLOCK_SIZE:(i+1)*BLOCK_SIZE]
            K_j = K[j*BLOCK_SIZE:(j+1)*BLOCK_SIZE]
            if i == j:
                # Intra-chunk multiplication
                ref = torch.matmul(Q_i, K_j.T)
            else:
                # Inter-chunk average multiplication
                Q_avg = Q_i.mean(dim=0, keepdim=True)
                K_avg = K_j.mean(dim=0, keepdim=True)
                ref = torch.matmul(Q_avg, K_avg.T).expand(BLOCK_SIZE, BLOCK_SIZE)
            attn_ref[i*BLOCK_SIZE:(i+1)*BLOCK_SIZE, j*BLOCK_SIZE:(j+1)*BLOCK_SIZE] = ref

    # Verify the results
    if torch.allclose(attn_fasta, attn_ref, atol=1e-4):
        print("Test passed! FASTA attention matches the reference implementation.")
    else:
        max_diff = (attn_fasta - attn_ref).abs().max()
        print(f"Test failed! Maximum difference: {max_diff}")

# Run the test case
if __name__ == "__main__":
    test_fasta_attention()


CompilationError: at 22:17:
    # Compute row using floor division and then compute col without using tl.mod
    row = tl.cast(tl.floor(tl.cast(pid, tl.float32) / tl.cast(N, tl.float32)), tl.int32)
    col = pid - row * N

    # Compute the base pointers for the current Q and K chunks
    q_base = Q_ptr + row * BLOCK_SIZE * stride_q0
    k_base = K_ptr + col * BLOCK_SIZE * stride_k0

    # Initialize shared memory for Q and K chunks
    # Shape: (BLOCK_SIZE, D)
    Q_chunk = tl.load(
        q_base + tl.arange(0, BLOCK_SIZE)[:, None] * stride_q0 + tl.arange(0, D)[None, :] * stride_q1,
                 ^