In [None]:
# FASTA: Full Average Scaled Tiling Attention
# implement a sparse attention using triton using the following methods
# in the standard self attention, the attention weight is computed like this: attn_weight = query @ key.transpose(-2, -1) * scale_factor
# assume a function:
# def att_weight(Q,K_T):
#    return Q@K_T
# FASTA is a sparse approximation for the above function which works as follows:
# def att_weight(Q,K_T,n_chunks):
#    return Q@K_T # sparse approximation
# the Q and K are divided into equal sized chunks
# assume  QxK^T to be [Q0,Q1,....Qn-1]*[K0,K1,....Kn-1] where each of them are equal sized chunks from the initial embeddings.
# in the full product if Q0*K0 then you do the regular multiplication, but if Q0*K1 or whenever the indices are not same, do avg(Q0)*avg(K1) and then broadcast this value in the shape of that grid.
# create a triton kernel which implements the above operation if i==j then intra-index, if i!=j then inter-index
# generate code and test case for the kernels first before proceeding to the full implementation
# the overall time complexity should be O(n^2/c^2+n*d*c) where c is number of chunks

In [None]:
## standard torch self-attention
import torch

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

In [None]:
################################################################################

In [3]:
import torch
import triton
import triton.language as tl
import math
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm  # For progress bar

# ------------------- Triton MatMul Kernel -------------------

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        ACTIVATION: tl.constexpr,
    ):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    a_ptrs = a_ptr + offs_am[:, None] * stride_am + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak
    b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=(offs_am[:, None] < M)[:, None] & (tl.arange(0, BLOCK_SIZE_K)[None, :] < K - k * BLOCK_SIZE_K), other=0.0)
        b = tl.load(b_ptrs, mask=(tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K)[:, None] & (offs_bn[None, :] < N)[None, :], other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if ACTIVATION == "leaky_relu":
        accumulator = tl.where(accumulator >= 0, accumulator, 0.01 * accumulator)

    c = accumulator.to(tl.float16)

    c_ptrs = c_ptr + offs_am[:, None] * stride_cm + offs_bn[None, :] * stride_cn
    mask = (offs_am[:, None] < M)[:, None] & (offs_bn[None, :] < N)[None, :]
    tl.store(c_ptrs, c, mask=mask)

def matmul(a, b, activation=""):
    """
    Convenience wrapper for the Triton matmul kernel.

    Args:
        a (torch.Tensor): Tensor of shape (M, K)
        b (torch.Tensor): Tensor of shape (K, N)
        activation (str): Activation function to apply ('leaky_relu' or '')

    Returns:
        torch.Tensor: Result of the matrix multiplication (M, N)
    """
    M, K = a.shape
    K_b, N = b.shape
    assert K == K_b, "Incompatible dimensions for matmul"

    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=128, BLOCK_SIZE_N=256, BLOCK_SIZE_K=64,
        GROUP_SIZE_M=8,
        ACTIVATION=activation,
    )
    return c

# ------------------- Triton Inter-Block Attention Kernel -------------------

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 32, 'GROUP_SIZE': 4}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE': 64, 'GROUP_SIZE': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE': 16, 'GROUP_SIZE': 2}, num_stages=4, num_warps=4),
    ],
    key=['BLOCK_SIZE'],
)
@triton.jit
def inter_block_attn_kernel(
    Q_ptr, K_ptr, attn_ptr,
    N, D: tl.constexpr, BLOCK_SIZE: tl.constexpr,
    sigma: tl.constexpr,
    stride_q0, stride_q1,
    stride_k0, stride_k1,
    stride_attn0, stride_attn1,
):
    # Program ID corresponds to (row_block, col_block)
    pid = tl.program_id(0)
    n_blocks = tl.cdiv(N, BLOCK_SIZE)
    row_block = pid // n_blocks
    col_block = pid % n_blocks

    # Skip intra-blocks
    if row_block == col_block:
        return

    row_start = row_block * BLOCK_SIZE
    col_start = col_block * BLOCK_SIZE

    # Calculate offsets for Q and K
    offs_q = row_start + tl.arange(0, BLOCK_SIZE)[:, None]  # Shape: (BLOCK_SIZE, 1)
    offs_k = col_start + tl.arange(0, BLOCK_SIZE)[None, :]  # Shape: (1, BLOCK_SIZE)

    # Masks
    q_mask = offs_q < N  # Shape: (BLOCK_SIZE, 1)
    k_mask = offs_k < N  # Shape: (1, BLOCK_SIZE)

    # Broadcast masks to match Q and K blocks
    q_mask_broadcast = q_mask & (tl.arange(0, D)[None, :] < D)  # Shape: (BLOCK_SIZE, D)
    k_mask_broadcast = k_mask & (tl.arange(0, D)[None, :] < D)  # Shape: (BLOCK_SIZE, D)

    # Load Q and K blocks
    Q_block = tl.load(Q_ptr + offs_q * stride_q0 + tl.arange(0, D)[None, :] * stride_q1, mask=q_mask_broadcast, other=0.0)  # Shape: (BLOCK_SIZE, D)
    K_block = tl.load(K_ptr + offs_k * stride_k0 + tl.arange(0, D)[None, :] * stride_k1, mask=k_mask_broadcast, other=0.0)  # Shape: (BLOCK_SIZE, D)

    # Compute vector averages
    avg_q = tl.sum(Q_block, axis=1) / D  # Shape: (BLOCK_SIZE,)
    avg_k = tl.sum(K_block, axis=1) / D  # Shape: (BLOCK_SIZE,)

    # Compute outer product
    outer = avg_q[:, None] * avg_k[None, :]  # Shape: (BLOCK_SIZE, BLOCK_SIZE)

    # Compute Gaussian weights
    center = (BLOCK_SIZE - 1) / 2.0
    i = tl.cast(tl.arange(0, BLOCK_SIZE)[:, None], tl.float32)  # Shape: (BLOCK_SIZE, 1)
    j = tl.cast(tl.arange(0, BLOCK_SIZE)[None, :], tl.float32)  # Shape: (1, BLOCK_SIZE)
    distance_sq = (i - center) * (i - center) + (j - center) * (j - center)  # Shape: (BLOCK_SIZE, BLOCK_SIZE)
    gaussian_weights = tl.exp(-distance_sq / (2.0 * sigma * sigma))  # Shape: (BLOCK_SIZE, BLOCK_SIZE)
    gaussian_weights /= tl.sum(gaussian_weights)  # Normalize

    # Apply Gaussian weights
    acc = outer * gaussian_weights  # Shape: (BLOCK_SIZE, BLOCK_SIZE)

    # Calculate attention pointers
    offs_attn_i = row_start + tl.arange(0, BLOCK_SIZE)[:, None]  # Shape: (BLOCK_SIZE, 1)
    offs_attn_j = col_start + tl.arange(0, BLOCK_SIZE)[None, :]  # Shape: (1, BLOCK_SIZE)

    attn_ptrs = attn_ptr + offs_attn_i * stride_attn0 + offs_attn_j * stride_attn1  # Shape: (BLOCK_SIZE, BLOCK_SIZE)

    # Mask for boundaries
    mask = (offs_attn_i < N) & (offs_attn_j < N)  # Shape: (BLOCK_SIZE, BLOCK_SIZE)

    # Store the attention weights with masking
    tl.store(attn_ptrs, acc, mask=mask)

def compute_inter_block_attn(Q, K, attn, block_size=32, sigma=1.5):
    """
    Computes inter-block attention using Triton kernel.

    Args:
        Q (torch.Tensor): Query tensor of shape (N, D)
        K (torch.Tensor): Key tensor of shape (N, D)
        attn (torch.Tensor): Attention tensor of shape (N, N) to store results
        block_size (int): Size of attention blocks
        sigma (float): Standard deviation for Gaussian spread
    """
    N, D = Q.shape
    n_blocks = triton.cdiv(N, block_size)
    grid = (n_blocks * n_blocks, )

    inter_block_attn_kernel[grid](
        Q, K, attn,
        N, D, block_size, sigma,
        Q.stride(0), Q.stride(1),
        K.stride(0), K.stride(1),
        attn.stride(0), attn.stride(1),
    )

# ------------------- FASTA Attention Function -------------------

def fasta_attention(Q, K, block_size=128, sigma=1.5):
    """
    Computes FASTA attention using Triton for intra-blocks and Triton for inter-blocks.

    Args:
        Q (torch.Tensor): Query tensor of shape (N, D)
        K (torch.Tensor): Key tensor of shape (N, D)
        block_size (int): Size of attention blocks
        sigma (float): Standard deviation for Gaussian spread

    Returns:
        torch.Tensor: Attention weights of shape (N, N)
    """
    N, D = Q.shape
    attn = torch.zeros((N, N), device=Q.device, dtype=torch.float32)

    # Compute intra-block attention using Triton's matmul kernel
    n_blocks = triton.cdiv(N, block_size)
    for i in range(n_blocks):
        row_start = i * block_size
        row_end = min(row_start + block_size, N)
        Q_block = Q[row_start:row_end].half()  # Convert to float16 for the matmul kernel
        K_block = K[row_start:row_end].half()
        # Transpose K_block for correct matmul dimensions
        K_block_T = K_block.transpose(-2, -1)
        # Perform matmul and convert back to float32
        matmul_result = matmul(Q_block, K_block_T).float()
        attn[row_start:row_end, row_start:row_end] = matmul_result

    # Compute inter-block attention using Triton kernel
    compute_inter_block_attn(Q.half(), K.half(), attn, block_size=block_size, sigma=sigma)

    return attn

# ------------------- Standard Self-Attention Function -------------------

def standard_self_attention(Q, K):
    """
    Computes standard self-attention using PyTorch's optimized matrix multiplication.

    Args:
        Q (torch.Tensor): Query tensor of shape (N, D)
        K (torch.Tensor): Key tensor of shape (N, D)

    Returns:
        torch.Tensor: Attention weights of shape (N, N)
    """
    return Q @ K.T

# ------------------- Visualization Function -------------------

def visualize_attention_blocks(attn_fasta, attn_ref, block_size=32, block_i=0, block_j=1):
    """
    Visualizes specific attention blocks for comparison.

    Args:
        attn_fasta (torch.Tensor): FASTA attention weights.
        attn_ref (torch.Tensor): Reference (standard) attention weights.
        block_size (int): Size of the attention block to visualize.
        block_i (int): Block row index.
        block_j (int): Block column index.
    """
    row_start = block_i * block_size
    col_start = block_j * block_size
    row_end = min(row_start + block_size, attn_fasta.shape[0])
    col_end = min(col_start + block_size, attn_fasta.shape[1])

    fasta_block = attn_fasta[row_start:row_end, col_start:col_end].cpu().numpy()
    ref_block = attn_ref[row_start:row_end, col_start:col_end].cpu().numpy()
    diff_block = fasta_block - ref_block

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    sns.heatmap(ref_block, annot=False, fmt=".2f", cmap="viridis")
    plt.title("Reference Attention Block")

    plt.subplot(1, 3, 2)
    sns.heatmap(fasta_block, annot=False, fmt=".2f", cmap="viridis")
    plt.title("FASTA Attention Block (Inter)")

    plt.subplot(1, 3, 3)
    sns.heatmap(diff_block, annot=False, fmt=".2f", cmap="coolwarm")
    plt.title("Difference Block")

    plt.show()

# ------------------- Benchmarking Function -------------------

def test_fasta_attention_gaussian_benchmark():
    """
    Benchmark function for FASTA attention implementation using Gaussian-like spread
    and standard self-attention. Runs each attention computation 100 times and
    plots the time distributions.
    """
    # Test parameters
    N = 1024  # Sequence length
    D = 128   # Hidden dimension
    block_size = 128
    sigma = 1.5  # Standard deviation for Gaussian spread
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    print(f"Using device: {device}")
    
    # Generate random inputs
    torch.manual_seed(0)
    Q = torch.randn(N, D, device=device, dtype=torch.float32)
    K = torch.randn(N, D, device=device, dtype=torch.float32)
    
    # Warm-up runs to stabilize GPU performance
    print("Warming up the GPU...")
    for _ in range(10):
        attn_fasta = fasta_attention(Q, K, block_size=block_size, sigma=sigma)
        attn_ref = standard_self_attention(Q, K)
    torch.cuda.synchronize() if device == 'cuda' else None
    
    # Number of benchmarking iterations
    num_iterations = 100
    
    # Initialize lists to store timing data
    fasta_times = []
    standard_times = []
    
    print("Starting benchmarking...")
    # Use tqdm for progress visualization
    for _ in tqdm(range(num_iterations), desc="Benchmarking"):
        # Benchmark FASTA attention
        if device == 'cuda':
            torch.cuda.synchronize()
            start_fasta = torch.cuda.Event(enable_timing=True)
            end_fasta = torch.cuda.Event(enable_timing=True)

            start_fasta.record()
            attn_fasta = fasta_attention(Q, K, block_size=block_size, sigma=sigma)
            end_fasta.record()

            torch.cuda.synchronize()
            elapsed_fasta = start_fasta.elapsed_time(end_fasta)  # Time in milliseconds
        else:
            import time
            start_fasta = time.time()
            attn_fasta = fasta_attention(Q, K, block_size=block_size, sigma=sigma)
            end_fasta = time.time()
            elapsed_fasta = (end_fasta - start_fasta) * 1000  # Convert to milliseconds

        fasta_times.append(elapsed_fasta)

        # Benchmark standard self-attention
        if device == 'cuda':
            torch.cuda.synchronize()
            start_std = torch.cuda.Event(enable_timing=True)
            end_std = torch.cuda.Event(enable_timing=True)

            start_std.record()
            attn_ref = standard_self_attention(Q, K)
            end_std.record()

            torch.cuda.synchronize()
            elapsed_std = start_std.elapsed_time(end_std)  # Time in milliseconds
        else:
            import time
            start_std = time.time()
            attn_ref = standard_self_attention(Q, K)
            end_std = time.time()
            elapsed_std = (end_std - start_std) * 1000  # Convert to milliseconds

        standard_times.append(elapsed_std)
    
    print("Benchmarking completed!")
    
    # Convert timing lists to NumPy arrays for easier handling
    fasta_times = torch.tensor(fasta_times).cpu().numpy()
    standard_times = torch.tensor(standard_times).cpu().numpy()
    
    # Plotting the time distributions
    plt.figure(figsize=(12, 6))
    
    sns.histplot(fasta_times, color='blue', label='FASTA Attention (Gaussian Spread)', kde=True, stat="density", bins=50, alpha=0.6)
    sns.histplot(standard_times, color='orange', label='Standard Self-Attention', kde=True, stat="density", bins=50, alpha=0.6)
    
    plt.title('Time Distribution of FASTA vs Standard Self-Attention')
    plt.xlabel('Time (ms)')
    plt.ylabel('Density')
    plt.legend()
    plt.grid(True)
    plt.show()
    
    # Optional: Print summary statistics
    print("Summary Statistics:")
    print(f"FASTA Attention - Mean: {fasta_times.mean():.4f} ms, Std: {fasta_times.std():.4f} ms")
    print(f"Standard Self-Attention - Mean: {standard_times.mean():.4f} ms, Std: {standard_times.std():.4f} ms")
    
    # Optional: Visualize a specific attention block
    # visualize_attention_blocks(attn_fasta, attn_ref, block_size=32, block_i=0, block_j=1)

# ------------------- Main Execution -------------------

if __name__ == "__main__":
    test_fasta_attention_gaussian_benchmark()


Using device: cuda
Warming up the GPU...


ValueError: Conflicting meta-parameters: BLOCK_SIZE_K, GROUP_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_M. Make sure that you don't re-define auto-tuned symbols.