In [None]:
import torch
import torch.nn.functional as F
# Show the installed PyTorch build (version plus CUDA variant suffix) as the cell's output.
torch.__version__

'2.1.0+cu118'

In [None]:
import time
import torch
import torch.nn.functional as F

# Benchmark configuration. Tensors use the (batch, heads, seq_len, head_dim) layout
# that F.scaled_dot_product_attention expects.
batch_size = 32
sequence_length = 2048
dimensions = 64  # per-head dimension (last axis of q, k, v)
number_of_heads = 8

# Seed torch's random number generators so the inputs (and dropout masks) are
# reproducible from run to run.
torch.manual_seed(0)

# Query (q), key (k) and value (v) tensors filled with standard-normal float16 values.
# Allocating directly on the GPU avoids building each 64 MiB tensor on the CPU and
# then copying it across with .cuda().
tensor_shape = (batch_size, number_of_heads, sequence_length, dimensions)
q = torch.randn(tensor_shape, dtype=torch.float16, device="cuda")
k = torch.randn(tensor_shape, dtype=torch.float16, device="cuda")
v = torch.randn(tensor_shape, dtype=torch.float16, device="cuda")

# Dropout rate applied to the attention probabilities, and timed iterations per benchmark
dropout_rate = 0.2
num_trials = 10

# Standard attention: the formula written out op by op, which materializes the full
# (batch, heads, seq_len, seq_len) score matrix on the GPU.
def standard_attention(q, k, v, dropout_p):
    """Return dropout(softmax(q @ k^T / sqrt(d))) @ v in q's (batch, heads, seq, dim) layout.

    The 1/sqrt(d) factor (d = q.size(-1)) is the default `scale` used by
    F.scaled_dot_product_attention; without it this reference would compute a
    different function from the flash benchmark it is compared against.
    """
    scale = q.size(-1) ** -0.5
    # Scale q (seq x dim) rather than the scores (seq x seq): same result, far less memory traffic.
    scores = (q * scale) @ k.transpose(-2, -1)
    probs = scores.softmax(dim=-1)  # normalize over the key axis
    # training=True keeps dropout active, matching SDPA, which applies dropout whenever dropout_p > 0.
    probs = F.dropout(probs, p=dropout_p, training=True)
    return probs @ v

# One untimed warm-up call so one-time costs (library/kernel initialization, allocator
# growth) are not charged to the timed loop.
standard_attention(q, k, v, dropout_rate)
torch.cuda.synchronize()  # drain queued GPU work before starting the clock
start = time.perf_counter()  # monotonic, high-resolution timer
for _ in range(num_trials):
    # Transposed to (batch, seq, heads, dim) as before; transpose is a view, not a kernel.
    x = standard_attention(q, k, v, dropout_rate).transpose(1, 2)
torch.cuda.synchronize()  # wait for every queued kernel before stopping the clock
end = time.perf_counter()
print('Standard attention took {} seconds for {} trials'.format(end - start, num_trials))

# FlashAttention through PyTorch's fused scaled_dot_product_attention. The math and
# memory-efficient backends are disabled so only the flash kernel is eligible.
# (Newer PyTorch releases provide torch.nn.attention.sdpa_kernel for this purpose.)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    # Untimed warm-up call so one-time setup costs are excluded from the measurement.
    F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)
    torch.cuda.synchronize()  # drain queued GPU work before starting the clock
    start = time.perf_counter()  # monotonic, high-resolution timer
    for _ in range(num_trials):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)
    torch.cuda.synchronize()  # wait for every queued kernel before stopping the clock
    end = time.perf_counter()
    print('Flash attention took {} seconds for {} trials'.format(end - start, num_trials))


Standard attention took 0.9052870273590088 seconds for 10 trials
Flash attention took 0.31307268142700195 seconds for 10 trials


In [None]:
!pip install flash-attn --no-build-isolation

Collecting flash-attn
  Downloading flash_attn-2.3.3.tar.gz (2.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 8.8 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting einops (from flash-attn)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 kB 6.2 MB/s eta 0:00:00
Collecting ninja (from flash-attn)
  Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 307.2/307.2 kB 15.3 MB/s eta 0:00:00
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... done
  Created wheel for flash-attn: filename=flash_attn-2.3.3-cp310-cp310-linux_x86_64.whl size=57075008 sha256=bcb63b64213ab61590b340b77de84e448a442e19c100480895194df39ad7673d
  Stor

In [None]:
import warnings

import torch
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

# Importing succeeds even on older GPUs, but the flash-attn CUDA kernels then fail at
# call time with "FlashAttention only supports Ampere GPUs or newer." Warn up front
# instead of failing deep inside the first call.
capability = torch.cuda.get_device_capability()
if capability < (8, 0):
    warnings.warn(
        f"flash-attn requires an Ampere (compute capability 8.0) or newer GPU, but "
        f"{torch.cuda.get_device_name()} reports {capability}; calls into flash_attn will fail."
    )


In [None]:
!git clone https://github.com/Dao-AILab/flash-attention.git

Cloning into 'flash-attention'...
remote: Enumerating objects: 4303, done.
remote: Counting objects: 100% (1785/1785), done.
remote: Compressing objects: 100% (191/191), done.
remote: Total 4303 (delta 1632), reused 1606 (delta 1590), pack-reused 2518
Receiving objects: 100% (4303/4303), 6.88 MiB | 17.61 MiB/s, done.
Resolving deltas: 100% (3007/3007), done.


In [None]:
import os

# Descend into the clone only if we are not already inside it. Re-running this cell
# (or Run All after a partial run) then does not try to enter a nonexistent
# flash-attention/flash-attention.
if os.path.basename(os.getcwd()) != "flash-attention":
    os.chdir("flash-attention")
os.getcwd()

/content/flash-attention


In [None]:
!pytest -q -s tests/test_flash_attn.py

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
>       out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
            q,
            k,
            v,
            None,
            dropout_p,
            softmax_scale,
            causal,
            window_size[0],
            window_size[1],
            return_softmax,
            None,
        )
E       RuntimeError: FlashAttention only supports Ampere GPUs or newer.

/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:47: RuntimeError
_____________________