<a href="https://colab.research.google.com/github/sandeepnmenon/FlashAttention_tests/blob/master/FlashAttention_Hacker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install flash-attn --no-build-isolation

Collecting flash-attn
  Downloading flash_attn-2.3.3.tar.gz (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m28.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting einops (from flash-attn)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting ninja (from flash-attn)
  Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.2/307.2 kB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.3.3-cp310-cp310-linux_x86_64.whl size=57075008 sha256=bcb63b64213ab61590b340b77de84e448a442e19c100480895194df39ad7673d
  Sto

In [2]:
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
print(torch.__version__)

2.1.0+cu118


In [3]:
# Set up the basic parameters for the model
batch_size = 32
sequence_length = 2048
dimensions = 64
number_of_heads = 8

# Define the dropout rate and number of trials for benchmarking
dropout_rate = 0.0
num_trials = 10


# FlashAttention might be flashy but it is not an Approximation

In [13]:
# Generate a single set of random q, k, v tensors
q_single = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
k_single = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
v_single = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
qkv_single = torch.stack((q_single, k_single, v_single), dim=2)

# Standard Attention Computation
attn = q_single @ k_single.transpose(-2, -1)  # Compute attention scores
attn = attn.softmax(dim=-1)  # Apply softmax to get probabilities
attn = F.dropout(attn, p=dropout_rate, training=True)  # Apply dropout
x_standard = (attn @ v_single)  # Apply attention to value and reshape

# Flash Attention Computation using scaled_dot_product_attention
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    # Note: Assuming 'enable_flash' triggers FlashAttention internally in scaled_dot_product_attention
    out_flash_sdp = F.scaled_dot_product_attention(q_single, k_single, v_single, dropout_p=dropout_rate)

# # FlashAttention V2 Computation
# out_flash_v2, _, _ = flash_attn_func(q_single, k_single, v_single, dropout_p=dropout_rate, return_attn_probs=True)

# # FlashAttention QKV Packed Computation
# out_flash_qkv_packed, _, _ = flash_attn_qkvpacked_func(qkv_single, dropout_p=dropout_rate, return_attn_probs=True)

# Compare all the output4
tolerance = 1e-1  # Tolerance level for floating-point comparisons
if (torch.allclose(x_standard, out_flash_sdp, atol=tolerance)):
    print('All attention implementations produce close enough results.')
else:
    print('There is a discrepancy between the attention implementations.')

# if (torch.allclose(x_standard, out_flash_sdp, atol=tolerance) and
#     torch.allclose(out_flash_sdp, out_flash_v2, atol=tolerance) and
#     torch.allclose(out_flash_v2, out_flash_qkv_packed, atol=tolerance)):
#     print('All attention implementations produce close enough results.')
# else:
#     print('There is a discrepancy between the attention implementations.')

There is a discrepancy between the attention implementations.


# Time Benchmark

In [4]:
import time
import torch
import torch.nn.functional as F


# Creating query (q), key (k), and value (v) tensors
# These tensors are initialized with random values and moved to the GPU for faster processing
q = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
k = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
v = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()


# Standard Attention Computation
torch.cuda.synchronize()  # Synchronizes CPU and GPU to ensure accurate timing
start = time.time()  # Start timer
for i in range(num_trials):
    attn = q @ k.transpose(-2, -1)  # Compute attention scores
    attn = attn.softmax(dim=-1)  # Apply softmax to get probabilities
    attn = F.dropout(attn, p=dropout_rate, training=True)  # Apply dropout
    x = (attn @ v).transpose(1, 2)  # Apply attention to value and reshape
torch.cuda.synchronize()  # Ensure all GPU tasks are finished
end = time.time()  # End timer
print('Standard attention took {} seconds for {} trials'.format(end - start, num_trials))

# Flash Attention Computation
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
    start = time.time()  # Start timer
    for i in range(num_trials):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)  # Compute attention using FlashAttention
    torch.cuda.synchronize()  # Ensure completion of all GPU tasks
    end = time.time()  # End timer
    print('Flash attention took {} seconds for {} trials'.format(end - start, num_trials))

# FlashAttention V2 Computation (add this to your existing code)
torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
start = time.time()  # Start timer
for i in range(num_trials):
    # Replace 'flash_attn_func' with the actual FlashAttention V2 function if it has a different name
    out = flash_attn_func(q, k, v, dropout_p=dropout_rate)  # Compute attention using FlashAttention V2
torch.cuda.synchronize()  # Ensure completion of all GPU tasks
end = time.time()  # End timer
print('FlashAttention V2 took {} seconds for {} trials'.format(end - start, num_trials))

# FlashAttention QKV Packed Computation (add this to your existing code)
torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
# Prepare the qkv tensor by stacking q, k, v along the third dimension
qkv = torch.stack((q, k, v), dim=2)
start = time.time()  # Start timer
for i in range(num_trials):
    # Call the flash_attn_qkvpacked_func with the stacked qkv tensor
    out= flash_attn_qkvpacked_func(qkv, dropout_p=dropout_rate)
torch.cuda.synchronize()  # Ensure completion of all GPU tasks
end = time.time()  # End timer
print('FlashAttention QKV Packed took {} seconds for {} trials'.format(end - start, num_trials))


Standard attention took 0.31433939933776855 seconds for 10 trials
Flash attention took 0.10668373107910156 seconds for 10 trials
FlashAttention V2 took 0.02718377113342285 seconds for 10 trials
FlashAttention QKV Packed took 0.02114129066467285 seconds for 10 trials
