In [None]:
# FASTA: Full Average Scaled Tiling Attention
# Implement sparse attention in Triton using the method described below.
# In standard self-attention, the attention weights are computed as: attn_weight = query @ key.transpose(-2, -1) * scale_factor
# assume a function:
# def att_weight(Q,K_T):
#    return Q@K_T
# FASTA is a sparse approximation for the above function which works as follows:
# def att_weight(Q,K_T,n_chunks):
#    return Q@K_T # sparse approximation
# Q and K are divided (along the sequence dimension) into equal-sized chunks.
# Write Q x K^T as [Q0, Q1, ..., Qn-1] * [K0, K1, ..., Kn-1], where each Qi and Kj is an equal-sized chunk of the original embeddings.
# In the full product, a diagonal block (e.g. Q0*K0) is computed with the regular multiplication. For an off-diagonal block (e.g. Q0*K1, i.e. whenever the indices differ), compute avg(Q0)*avg(K1) instead and broadcast that value over the shape of the block.
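# Create a Triton kernel that implements the operation above: blocks with i == j are "intra-index" (computed exactly), blocks with i != j are "inter-index" (averaged).
# Generate code and test cases for the kernels first, before proceeding to the full implementation.
# The overall time complexity should be O(n^2*d/c^2 + n*d*c), where c is the chunk (block) size, so there are n/c chunks:
# the n/c diagonal blocks cost O(n*d*c) in total, and the (n/c)^2 averaged off-diagonal blocks cost O(n^2*d/c^2).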

In [None]:
################################################################################

In [None]:
import torch
import math
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from fasta import fasta_attn  # Ensure this is the Triton kernel
import numpy as np

def _cuda_time_ms(fn):
    """Call ``fn()`` once between two CUDA timing events.

    Returns:
        A ``(elapsed_ms, result)`` tuple: the GPU time measured between the
        two events, and whatever ``fn`` returned.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()
    return start.elapsed_time(end), result


def _warmup(fn, iterations):
    """Call ``fn()`` ``iterations`` times without timing, then wait for the GPU.

    The first call of a Triton kernel JIT-compiles it (and first-use library
    initialisation typically lands on the first matmul), which would otherwise
    be counted in the first timed sample.
    """
    for _ in range(iterations):
        fn()
    torch.cuda.synchronize()


def _error_metrics(ref, approx, eps=1e-6):
    """Return ``(mae, max_abs_error, mean_relative_error)`` of ``approx`` vs ``ref``.

    The relative error divides by ``|ref| + eps`` (epsilon outside the abs),
    so reference entries close to zero of either sign cannot yield a zero
    denominator. All reductions happen on the tensors' device; only three
    Python floats are returned.
    """
    abs_diff = (ref - approx).abs()
    mae = abs_diff.mean().item()
    max_err = abs_diff.max().item()
    rel_err = (abs_diff / (ref.abs() + eps)).mean().item()
    return mae, max_err, rel_err


def test_fasta_attention_benchmark(
    *,
    seq_len=4096,
    head_dim=64,
    block_size=64,
    num_iterations=100,
    warmup=10,
    seed=0,
):
    """Benchmark FASTA attention scores against the dense ``Q @ K^T`` reference.

    Each implementation first runs ``warmup`` untimed calls, then
    ``num_iterations`` calls timed with CUDA events. Every FASTA output is
    compared against the dense reference. The function prints timing and
    error statistics and shows a histogram of the per-call timings.

    Only the most recent output of each implementation is kept (on the GPU),
    and error metrics are reduced to scalars per iteration, so memory use
    does not grow with ``num_iterations``.

    All parameters are keyword-only with defaults, so calling with no
    arguments reproduces the original configuration (and pytest does not
    treat them as fixtures).

    Args:
        seq_len: Sequence length N (number of rows of Q and K).
        head_dim: Hidden dimension D.
        block_size: Chunk size forwarded to ``fasta_attn``.
        num_iterations: Timed calls per implementation; must be >= 1.
        warmup: Untimed calls per implementation before timing; must be >= 0.
        seed: Seed passed to ``torch.manual_seed`` before generating Q and K.

    Raises:
        ValueError: If ``num_iterations`` < 1 or ``warmup`` < 0.
        AssertionError: If CUDA is not available.
    """
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")

    N = seq_len
    D = head_dim
    device = 'cuda'

    # Ensure CUDA is available (CUDA events are used for timing).
    assert torch.cuda.is_available(), "CUDA is not available. Please run on a CUDA-enabled device."

    # Generate random inputs
    torch.manual_seed(seed)
    Q = torch.randn(N, D, device=device, dtype=torch.float32)
    K = torch.randn(N, D, device=device, dtype=torch.float32)

    # Reshape Q and K for FASTA attention: (N, D) -> (1, 1, N, D).
    Q_fasta = Q.unsqueeze(0).unsqueeze(0)
    K_fasta = K.unsqueeze(0).unsqueeze(0)

    def standard():
        # Dense, unscaled score matrix: att_weight(Q, K^T) = Q @ K^T.
        return Q @ K.T

    def fasta():
        # Drop the two leading singleton dims so the result lines up with the
        # (N, N) reference. Assumes fasta_attn returns the same unscaled
        # scores -- NOTE(review): confirm against the kernel.
        return fasta_attn(Q_fasta, K_fasta, block_size=block_size).squeeze(0).squeeze(0)

    print("Computing standard self-attention for all iterations...")
    _warmup(standard, warmup)
    standard_times = []
    attn_ref = None
    for _ in tqdm(range(num_iterations), desc="Standard Self-Attention"):
        elapsed_std, attn_ref = _cuda_time_ms(standard)
        standard_times.append(elapsed_std)

    print("Computing optimized FASTA attention for all iterations...")
    _warmup(fasta, warmup)
    fasta_times = []
    mae_list = []
    max_error_list = []
    relative_error_list = []
    for _ in tqdm(range(num_iterations), desc="FASTA Attention"):
        elapsed_fasta, attn_fasta = _cuda_time_ms(fasta)
        fasta_times.append(elapsed_fasta)
        # Error metrics are computed after the timed region. The reference
        # inputs never change, so the last reference output stands in for
        # every iteration's (assumes the dense matmul is deterministic).
        mae, max_err, rel_err = _error_metrics(attn_ref, attn_fasta)
        mae_list.append(mae)
        max_error_list.append(max_err)
        relative_error_list.append(rel_err)

    # Convert timing lists to NumPy arrays
    standard_times = np.array(standard_times)
    fasta_times = np.array(fasta_times)

    # Print results
    print("Benchmarking completed!")
    print("\nTiming Statistics:")
    print(f"Standard Self-Attention - Mean: {standard_times.mean():.4f} ms, Std: {standard_times.std():.4f} ms")
    print(f"FASTA Attention - Mean: {fasta_times.mean():.4f} ms, Std: {fasta_times.std():.4f} ms")
    print(f"Percentage Improvement: {100 * (1 - fasta_times.mean() / standard_times.mean()):.2f}%")

    print("\nError Metrics (over all iterations):")
    print(f"Mean Absolute Error (MAE): {np.mean(mae_list):.6f}")
    print(f"Maximum Absolute Error: {np.max(max_error_list):.6f}")
    print(f"Mean Relative Error: {np.mean(relative_error_list):.6f}")

    # Plot timing distributions
    plt.figure(figsize=(12, 6))
    sns.histplot(fasta_times, color='blue', label='FASTA Attention', kde=True, stat="density", bins=50, alpha=0.6)
    sns.histplot(standard_times, color='orange', label='Standard Attention', kde=True, stat="density", bins=50, alpha=0.6)
    plt.title('Timing Distributions')
    plt.xlabel('Time (ms)')
    plt.ylabel('Density')
    plt.legend()
    plt.grid(True)
    plt.show()

if __name__ == "__main__":
    # Script entry point: run the benchmark with its default configuration.
    test_fasta_attention_benchmark()