In [21]:
# FASTA: Full Average Scaled Tiling Attention
# implement a sparse attention kernel in Triton using the following method
# in the standard self attention, the attention weight is computed like this: attn_weight = query @ key.transpose(-2, -1) * scale_factor
# assume a function:
# def att_weight(Q,K_T):
#    return Q@K_T
# FASTA is a sparse approximation for the above function which works as follows:
# def att_weight(Q,K_T,n_chunks):
#    return Q@K_T # sparse approximation
# the Q and K are divided into equal sized chunks
# assume  QxK^T to be [Q0,Q1,....Qn-1]*[K0,K1,....Kn-1] where each of them are equal sized chunks from the initial embeddings.
# in the full product, a diagonal tile Qi*Ki (same chunk index) uses the regular multiplication; for an off-diagonal tile Qi*Kj (i != j), compute avg(Qi)*avg(Kj) and broadcast that single value over the whole tile.
# create a triton kernel which implements the above operation if i==j then intra-index, if i!=j then inter-index
# generate code and test case for the kernels first before proceeding to the full implementation
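# the overall time complexity should be O(n^2/c^2 + n*d*c), where c is the chunk size (so there are n/c chunks): n^2/c^2 inter-chunk tiles plus n/c intra-chunk tiles costing c^2*d each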

In [22]:
## standard torch self-attention
import torch

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """Reference (non-fused) scaled dot-product attention.

    Computes ``softmax(query @ key^T * scale + bias) @ value`` where ``bias``
    encodes the optional causal mask and/or ``attn_mask``.

    Args:
        query: Tensor of shape ``(..., L, E)``.
        key: Tensor of shape ``(..., S, E)``.
        value: Tensor of shape ``(..., S, Ev)``.
        attn_mask: Optional mask broadcastable to ``(..., L, S)``. A boolean
            mask keeps positions that are ``True``; any other dtype is added
            to the attention scores.
        dropout_p: Dropout probability applied to the attention weights.
            Dropout is always applied in training mode.
        is_causal: If ``True``, apply a lower-triangular causal mask.
            Must not be combined with ``attn_mask``.
        scale: Multiplier for the scores; defaults to ``1 / sqrt(E)``.
        enable_gqa: If ``True``, repeat key/value heads (dim ``-3``) so that
            they match the number of query heads.

    Returns:
        Tensor of shape ``(..., L, Ev)``.
    """
    # Local import: the surrounding cell only imports torch, so relying on a
    # module-level ``math`` raises NameError on a fresh kernel.
    import math

    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Build the bias on the same device as the inputs; otherwise CUDA queries
    # fail when the bias is added to the scores.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Convert the boolean mask to an additive one in its own shape and
            # add out-of-place so masks with leading batch dims broadcast.
            additive_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
            additive_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + additive_mask
        else:
            attn_bias = attn_bias + attn_mask

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    # Out-of-place add: the bias may carry more leading dims than the scores.
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

In [6]:
################################################################################

In [109]:
import torch
import triton
import triton.language as tl

@triton.jit
def sparse_attn_kernel(
    Q_ptr, K_ptr, attn_ptr,
    N, D: tl.constexpr, BLOCK_SIZE: tl.constexpr,
    stride_q0, stride_q1,
    stride_k0, stride_k1,
    stride_attn0, stride_attn1,
):
    """Compute one BLOCK_SIZE x BLOCK_SIZE tile of the FASTA score matrix.

    The launch grid is 1-D with ``cdiv(N, BLOCK_SIZE) ** 2`` programs; each
    program id is decoded into a (row block, column block) pair.

    * Diagonal tiles (row block == column block) hold the exact ``Q @ K^T``
      values, accumulated one feature column at a time.
    * Off-diagonal tiles are filled with a single scalar: the mean over all
      valid elements of the Q block times the mean over all valid elements
      of the K block.

    NOTE(review): the off-diagonal value is a product of two scalar means,
    not the dot product of the per-feature mean vectors. Confirm this is the
    intended reading of "avg(Qi) * avg(Kj)".

    Args:
        Q_ptr, K_ptr: Pointers to the (N, D) query / key matrices.
        attn_ptr: Pointer to the (N, N) output matrix.
        N: Number of rows in Q and K.
        D: Feature dimension (compile-time constant, used in ``tl.arange``).
        BLOCK_SIZE: Tile edge length (compile-time constant).
        stride_*: Element strides of the corresponding tensors.
    """
    # Decode the flat program id into a 2-D tile coordinate.
    pid = tl.program_id(0)
    n_blocks = tl.cdiv(N, BLOCK_SIZE)
    row_block_idx = pid // n_blocks
    col_block_idx = pid % n_blocks

    row_start = row_block_idx * BLOCK_SIZE
    col_start = col_block_idx * BLOCK_SIZE

    # Absolute row indices handled by this tile, for Q (rows) and K (columns).
    offs_q = row_start + tl.arange(0, BLOCK_SIZE)
    offs_k = col_start + tl.arange(0, BLOCK_SIZE)

    # Bounds masks must use the absolute indices: the last tile is ragged
    # whenever N is not a multiple of BLOCK_SIZE.
    row_in_bounds = offs_q < N
    col_in_bounds = offs_k < N

    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)

    if row_block_idx == col_block_idx:
        # Intra-block: exact scores as a sum of rank-1 outer products,
        # one feature column per iteration.
        for d in range(D):
            q_vals = tl.load(
                Q_ptr + offs_q * stride_q0 + d * stride_q1,
                mask=row_in_bounds, other=0.0,
            ).to(tl.float32)
            k_vals = tl.load(
                K_ptr + offs_k * stride_k0 + d * stride_k1,
                mask=col_in_bounds, other=0.0,
            ).to(tl.float32)
            acc += q_vals[:, None] * k_vals[None, :]
    else:
        # Inter-block: only these tiles need the full (BLOCK_SIZE, D) blocks,
        # so the loads live inside this branch.
        offs_d = tl.arange(0, D)
        q_ptrs = Q_ptr + offs_q[:, None] * stride_q0 + offs_d[None, :] * stride_q1
        k_ptrs = K_ptr + offs_k[:, None] * stride_k0 + offs_d[None, :] * stride_k1
        q_block = tl.load(q_ptrs, mask=row_in_bounds[:, None], other=0.0).to(tl.float32)
        k_block = tl.load(k_ptrs, mask=col_in_bounds[:, None], other=0.0).to(tl.float32)

        # Sum over features, then over rows, to get one scalar per block.
        total_sum_q = tl.sum(tl.sum(q_block, axis=1), axis=0)
        total_sum_k = tl.sum(tl.sum(k_block, axis=1), axis=0)

        # Divide by the number of valid rows, not BLOCK_SIZE: masked rows
        # contribute zeros to the sum and would otherwise bias the mean of a
        # ragged edge block. Every launched tile has at least one valid row.
        n_valid_q = tl.sum(row_in_bounds.to(tl.float32), axis=0)
        n_valid_k = tl.sum(col_in_bounds.to(tl.float32), axis=0)
        avg_q = total_sum_q / (D * n_valid_q)
        avg_k = total_sum_k / (D * n_valid_k)

        # Broadcast the scalar product over the whole tile.
        acc += avg_q * avg_k

    # Store the tile, masking rows/columns that fall outside the N x N output.
    attn_ptrs = attn_ptr + offs_q[:, None] * stride_attn0 + offs_k[None, :] * stride_attn1
    mask = row_in_bounds[:, None] & col_in_bounds[None, :]
    tl.store(attn_ptrs, acc, mask=mask)

def get_attn_weight(Q, K, block_size=128):
    """
    Computes FASTA attention scores using Triton.

    Diagonal ``block_size x block_size`` tiles of the result hold the exact
    ``Q @ K.T`` values; every off-diagonal tile is filled with a single
    averaged value computed by ``sparse_attn_kernel``.

    Args:
        Q (torch.Tensor): Query tensor of shape (N, D)
        K (torch.Tensor): Key tensor of shape (N, D), on the same device as Q
        block_size (int): Size of attention blocks; must be a power of two

    Returns:
        torch.Tensor: Attention weights of shape (N, N), with Q's dtype/device

    Raises:
        ValueError: If Q/K are not 2-D, differ in shape or device, or if
            ``block_size`` or ``D`` is not a positive power of two.
    """
    if Q.dim() != 2 or K.dim() != 2:
        raise ValueError(
            f"Q and K must be 2-D (N, D) tensors, got {tuple(Q.shape)} and {tuple(K.shape)}"
        )
    # The kernel indexes both matrices with the same N and D, so a mismatch
    # would read out of bounds instead of failing.
    if Q.shape != K.shape:
        raise ValueError(
            f"Q and K must have the same shape, got {tuple(Q.shape)} and {tuple(K.shape)}"
        )
    if Q.device != K.device:
        raise ValueError(f"Q and K must be on the same device, got {Q.device} and {K.device}")
    # tl.arange requires power-of-two extents; fail early with a clear message.
    if block_size <= 0 or (block_size & (block_size - 1)) != 0:
        raise ValueError(f"block_size must be a positive power of two, got {block_size}")

    N, D = Q.shape
    if D <= 0 or (D & (D - 1)) != 0:
        raise ValueError(f"feature dimension D must be a positive power of two, got {D}")

    # Create output tensor
    attn = torch.empty((N, N), device=Q.device, dtype=Q.dtype)
    if N == 0:
        # Nothing to compute; avoid launching an empty grid.
        return attn

    # Ensure tensors are contiguous
    Q = Q.contiguous()
    K = K.contiguous()

    # One program per (row block, column block) tile, flattened to 1-D.
    n_blocks = triton.cdiv(N, block_size)
    grid = (n_blocks * n_blocks,)
    print(f"grid: {grid}")
    # Launch kernel
    sparse_attn_kernel[grid](
        Q, K, attn,
        N, D, block_size,
        Q.stride(0), Q.stride(1),
        K.stride(0), K.stride(1),
        attn.stride(0), attn.stride(1),
    )

    return attn

def test_fasta_attention():
    """
    Test function for FASTA attention implementation.

    Runs ``get_attn_weight`` on random (N, D) inputs and checks, tile by tile:

    * diagonal (intra-block) tiles match ``Q @ K.T`` within a tight tolerance;
    * off-diagonal (inter-block) tiles equal the product of the two blocks'
      element-wise means, which is the value the kernel broadcasts.

    The deviation of the inter-block approximation from the exact product is
    printed for information only. The test is skipped with a message when
    CUDA is not available.
    """
    # Test parameters
    N = 64  # Sequence length
    D = 64   # Hidden dimension
    block_size = 8
    device = 'cuda'

    # The kernel is launched on CUDA tensors; skip instead of crashing with
    # an opaque error on CPU-only machines.
    if not torch.cuda.is_available():
        print("CUDA is not available; skipping FASTA attention test.")
        return

    # Generate random inputs
    torch.manual_seed(0)
    Q = torch.randn(N, D, device=device, dtype=torch.float32)
    K = torch.randn(N, D, device=device, dtype=torch.float32)

    # Compute attention using FASTA
    attn_fasta = get_attn_weight(Q, K, block_size=block_size)

    # Compute reference attention
    attn_ref = Q @ K.T

    # Calculate grid parameters
    n_blocks = (N + block_size - 1) // block_size

    # Initialize difference trackers
    max_diff_intra = 0.0
    max_diff_inter = 0.0
    mean_diff_intra = 0.0
    mean_diff_inter = 0.0
    # Largest deviation of an inter-block tile from its expected scalar.
    max_inter_spec_err = 0.0

    for i in range(n_blocks):
        for j in range(n_blocks):
            row_start = i * block_size
            col_start = j * block_size
            row_end = min(row_start + block_size, N)
            col_end = min(col_start + block_size, N)

            fasta_block = attn_fasta[row_start:row_end, col_start:col_end]
            ref_block = attn_ref[row_start:row_end, col_start:col_end]

            diff = (fasta_block - ref_block).abs()
            block_max_diff = diff.max().item()
            block_mean_diff = diff.mean().item()

            if i == j:
                # Intra-block: should match closely
                max_diff_intra = max(max_diff_intra, block_max_diff)
                mean_diff_intra += block_mean_diff
            else:
                # Inter-block: approximation error vs. the exact product
                # (informational only).
                max_diff_inter = max(max_diff_inter, block_max_diff)
                mean_diff_inter += block_mean_diff

                # Inter-block: the tile must be the constant
                # mean(Q block) * mean(K block).
                expected_scalar = Q[row_start:row_end].mean() * K[col_start:col_end].mean()
                spec_err = (fasta_block - expected_scalar).abs().max().item()
                max_inter_spec_err = max(max_inter_spec_err, spec_err)

    # Average the mean differences across all blocks
    mean_diff_intra /= n_blocks
    n_inter_blocks = n_blocks * (n_blocks - 1)
    if n_inter_blocks > 0:  # a single block has no inter-block tiles
        mean_diff_inter /= n_inter_blocks

    print(f"Maximum intra-block difference: {max_diff_intra}")
    print(f"Mean intra-block difference: {mean_diff_intra}")
    print(f"Maximum inter-block difference: {max_diff_inter}")
    print(f"Mean inter-block difference: {mean_diff_inter}")

    # Intra-block differences should be very small (close to 0)
    # Inter-block differences will be higher due to approximation

    # Assertions for intra-block differences
    intra_block_tolerance = 1e-5  # Tight tolerance since intra-block is exact
    assert max_diff_intra < intra_block_tolerance, f"Intra-block differences exceed tolerance: {max_diff_intra}"

    # Assertion for inter-block tiles: compare against the averaged value the
    # kernel is meant to broadcast, not against the exact product.
    inter_block_tolerance = 1e-5
    assert max_inter_spec_err < inter_block_tolerance, (
        f"Inter-block tiles deviate from mean(Q block) * mean(K block): {max_inter_spec_err}"
    )

    print("Intra-block attention matches the reference within the acceptable tolerance.")
    print("Inter-block attention has been approximated, resulting in differences as expected.")
    print("Test completed successfully!")

# Run the FASTA self-test when this cell/script is executed directly
# (not when the code is imported as a module).
if __name__ == "__main__":
    test_fasta_attention()


grid: (64,)
Maximum intra-block difference: 7.62939453125e-06
Mean intra-block difference: 8.333481673616916e-07
Maximum inter-block difference: 29.288122177124023
Mean inter-block difference: 6.261228357042585
Intra-block attention matches the reference within the acceptable tolerance.
Inter-block attention has been approximated, resulting in differences as expected.
Test completed successfully!
