<a href="https://colab.research.google.com/github/sandeepnmenon/FlashAttention_tests/blob/master/FlashAttention_Hacker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install flash-attn --no-build-isolation

In [1]:
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
print(torch.__version__)

2.1.0+cu118


In [7]:
import time
import torch
import torch.nn.functional as F

# Set up the basic parameters for the model
batch_size = 32
sequence_length = 2048
dimensions = 64
number_of_heads = 8

# Creating query (q), key (k), and value (v) tensors
# These tensors are initialized with random values and moved to the GPU for faster processing
q = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
k = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
v = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()

# Define the dropout rate and number of trials for benchmarking
dropout_rate = 0.2
num_trials = 10

# Standard Attention Computation
torch.cuda.synchronize()  # Synchronizes CPU and GPU to ensure accurate timing
start = time.time()  # Start timer
for i in range(num_trials):
    attn = q @ k.transpose(-2, -1)  # Compute attention scores
    attn = attn.softmax(dim=-1)  # Apply softmax to get probabilities
    attn = F.dropout(attn, p=dropout_rate, training=True)  # Apply dropout
    x = (attn @ v).transpose(1, 2)  # Apply attention to value and reshape
torch.cuda.synchronize()  # Ensure all GPU tasks are finished
end = time.time()  # End timer
print('Standard attention took {} seconds for {} trials'.format(end - start, num_trials))

# Flash Attention Computation
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
    start = time.time()  # Start timer
    for i in range(num_trials):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)  # Compute attention using FlashAttention
    torch.cuda.synchronize()  # Ensure completion of all GPU tasks
    end = time.time()  # End timer
    print('Flash attention took {} seconds for {} trials'.format(end - start, num_trials))

# FlashAttention V2 Computation (add this to your existing code)
torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
start = time.time()  # Start timer
for i in range(num_trials):
    # Replace 'flash_attn_func' with the actual FlashAttention V2 function if it has a different name
    out = flash_attn_func(q, k, v, dropout_p=dropout_rate)  # Compute attention using FlashAttention V2
torch.cuda.synchronize()  # Ensure completion of all GPU tasks
end = time.time()  # End timer
print('FlashAttention V2 took {} seconds for {} trials'.format(end - start, num_trials))

# FlashAttention QKV Packed Computation (add this to your existing code)
torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
# Prepare the qkv tensor by stacking q, k, v along the third dimension
qkv = torch.stack((q, k, v), dim=2)
start = time.time()  # Start timer
for i in range(num_trials):
    # Call the flash_attn_qkvpacked_func with the stacked qkv tensor
    out= flash_attn_qkvpacked_func(qkv, dropout_p=dropout_rate)
torch.cuda.synchronize()  # Ensure completion of all GPU tasks
end = time.time()  # End timer
print('FlashAttention QKV Packed took {} seconds for {} trials'.format(end - start, num_trials))


Standard attention took 0.7057886123657227 seconds for 10 trials
Flash attention took 0.2725660800933838 seconds for 10 trials


RuntimeError: ignored