<a href="https://colab.research.google.com/github/sandeepnmenon/FlashAttention_tests/blob/master/FlashAttention_Hacker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install flash-attn --no-build-isolation

Collecting flash-attn
  Downloading flash_attn-2.3.4.tar.gz (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting einops (from flash-attn)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
Collecting ninja (from flash-attn)
  Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.2/307.2 kB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.3.4-cp310-cp310-linux_x86_64.whl size=57449248 sha256=4ceff7f9ad6a6747975b568f3193b2416f598808b5e860fb09555dbfa32b42cf
  Sto

In [2]:
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
print(torch.__version__)

2.1.0+cu118


In [3]:
# Set up the basic parameters for the model
batch_size = 32
sequence_length = 2048
dimensions = 64
number_of_heads = 8

# Define the dropout rate and number of trials for benchmarking
dropout_rate = 0.0
num_trials = 10


# Time Benchmark

In [None]:
import time
import torch
import torch.nn.functional as F


# Creating query (q), key (k), and value (v) tensors
# These tensors are initialized with random values and moved to the GPU for faster processing
q = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
k = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()
v = torch.randn(batch_size, number_of_heads, sequence_length, dimensions, dtype=torch.float16).cuda()


# Standard Attention Computation
torch.cuda.synchronize()  # Synchronizes CPU and GPU to ensure accurate timing
start = time.time()  # Start timer
for i in range(num_trials):
    attn = q @ k.transpose(-2, -1)  # Compute attention scores
    attn = attn.softmax(dim=-1)  # Apply softmax to get probabilities
    attn = F.dropout(attn, p=dropout_rate, training=True)  # Apply dropout
    x = (attn @ v).transpose(1, 2)  # Apply attention to value and reshape
torch.cuda.synchronize()  # Ensure all GPU tasks are finished
end = time.time()  # End timer
standard_attention_time = end - start
print('Standard attention took {} seconds for {} trials'.format(standard_attention_time, num_trials))

# Flash Attention Computation
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
    start = time.time()  # Start timer
    for i in range(num_trials):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)  # Compute attention using FlashAttention
    torch.cuda.synchronize()  # Ensure completion of all GPU tasks
    end = time.time()  # End timer
    flash_attention_time = end - start
    print('Flash attention took {} seconds for {} trials'.format(end - start, num_trials))

# Speedup for Flash Attention
flash_attention_speedup = standard_attention_time / flash_attention_time
print('Speedup of Flash attention over standard attention: {:.2f}x'.format(flash_attention_speedup))

# FlashAttention V2 Computation (add this to your existing code)
torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
start = time.time()  # Start timer
for i in range(num_trials):
    # Replace 'flash_attn_func' with the actual FlashAttention V2 function if it has a different name
    out = flash_attn_func(q, k, v, dropout_p=dropout_rate)  # Compute attention using FlashAttention V2
torch.cuda.synchronize()  # Ensure completion of all GPU tasks
end = time.time()  # End timer
flash_attention_v2_time = end - start
print('FlashAttention V2 took {} seconds for {} trials'.format(end - start, num_trials))

# Speedup for Flash Attention V2
flash_attention_v2_speedup = standard_attention_time / flash_attention_v2_time
print('Speedup of Flash attention over standard attention: {:.2f}x'.format(flash_attention_v2_speedup))


# FlashAttention QKV Packed Computation (add this to your existing code)
torch.cuda.synchronize()  # Synchronizes CPU and GPU for accurate timing
# Prepare the qkv tensor by stacking q, k, v along the third dimension
qkv = torch.stack((q, k, v), dim=2)
start = time.time()  # Start timer
for i in range(num_trials):
    # Call the flash_attn_qkvpacked_func with the stacked qkv tensor
    out= flash_attn_qkvpacked_func(qkv, dropout_p=dropout_rate)
torch.cuda.synchronize()  # Ensure completion of all GPU tasks
end = time.time()  # End timer
flash_attention_v2_time = end - start

print('FlashAttention QKV Packed took {} seconds for {} trials'.format(end - start, num_trials))

# Speedup for Flash Attention V2 qkv stacked
flash_attention_v2_speedup = standard_attention_time / flash_attention_v2_time
print('Speedup of Flash attention over standard attention: {:.2f}x'.format(flash_attention_v2_speedup))


Standard attention took 0.31433939933776855 seconds for 10 trials
Flash attention took 0.10668373107910156 seconds for 10 trials
FlashAttention V2 took 0.02718377113342285 seconds for 10 trials
FlashAttention QKV Packed took 0.02114129066467285 seconds for 10 trials


# GPT Inference

In [4]:
!git clone https://github.com/graykode/gpt-2-Pytorch
%cd gpt-2-Pytorch
!curl --output gpt2-pytorch_model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin
!pip install -r requirements.txt

Cloning into 'gpt-2-Pytorch'...
remote: Enumerating objects: 130, done.[K
remote: Total 130 (delta 0), reused 0 (delta 0), pack-reused 130[K
Receiving objects: 100% (130/130), 2.39 MiB | 27.48 MiB/s, done.
Resolving deltas: 100% (48/48), done.
/content/gpt-2-Pytorch
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  522M  100  522M    0     0  15.3M      0  0:00:34  0:00:34 --:--:-- 16.7M
Collecting regex==2017.4.5 (from -r requirements.txt (line 1))
  Downloading regex-2017.04.05.tar.gz (601 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m601.6/601.6 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: regex
  Building wheel for regex (setup.py) ... [?25l[?25hdone
  Created wheel for regex: filename=regex-2017.4.5-cp310-cp310-linux_x86_64.whl size=657352 sha256

In [11]:
!python main.py --text "Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest." --length 100


Namespace(text='Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest.', quiet=False, nsamples=1, unconditional=False, batch_size=-1, length=100, temperature=0.7, top_k=40)
Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest.
  0% 0/100 [00:00<?, ?it/s]qkv:  torch.Size([1, 12, 28, 64]) torch.Size([1, 12, 64, 28]) torch.Size([1, 12, 28, 64])
  0% 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/content/gpt-2-Pytorch/main.py", line 79, in <module>
    text_generator(state_dict)
  File "/content/gpt-2-Pytorch/main.py", line 61, in text_generator
    out = sample_sequence(
  File "/content/gpt-2-Pytorch/GPT2/sample.py", line 29, in sample_sequence
    logits, past = model(prev, past=past)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*

In [7]:
import time
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
model.cuda()  # Move model to GPU

# Initial input text
input_text = "Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest."
input_ids = tokenizer.encode(input_text, return_tensors='pt').cuda()

# Function to generate 100 tokens autoregressively
def generate_100_tokens(model, input_ids):
    generated_text = input_ids
    for _ in range(100):
        outputs = model(generated_text)
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated_text = torch.cat([generated_text, next_token], dim=-1)
    return generated_text

# Measure inference time without FlashAttention
torch.cuda.synchronize()
start_time = time.time()
generated_text_normal = generate_100_tokens(model, input_ids)
torch.cuda.synchronize()
normal_inference_time = time.time() - start_time

print(f'Normal Inference Time: {normal_inference_time:.3f} seconds')

# Measure inference time with FlashAttention
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    torch.cuda.synchronize()
    start_time = time.time()
    generated_text_flash = generate_100_tokens(model, input_ids)
    torch.cuda.synchronize()
    flash_inference_time = time.time() - start_time
print(f'FlashAttention Inference Time: {flash_inference_time:.3f} seconds')

# Decode and print the generated text
print("Generated Text without FlashAttention:")
print(tokenizer.decode(generated_text_normal[0]))

print("\nGenerated Text with FlashAttention:")
print(tokenizer.decode(generated_text_flash[0]))


vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Normal Inference Time: 2.037 seconds
FlashAttention Inference Time: 1.520 seconds
Generated Text without FlashAttention:
Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest. It was a beautiful picture, and I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to

Generated Text with FlashAttention:
Once when I was six years old I saw a magnificent picture in a book, called True Stories from Nature, about the primeval forest. It was a beautiful picture, and I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so happy to see it. I was so