In [1]:
import torch
import torch.nn as nn

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

In [2]:
# All tensors in this notebook live on the first CUDA device.
gpu_spec = "cuda:0"
device = torch.device(gpu_spec)

In [3]:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

In [4]:
# Problem size for the packed attention input, which is laid out as
#   qkv: (batch_size, seqlen, 3, nheads, headdim)
batch_size, nheads, headdim = 64, 8, 32
seqlen = 1 << 16  # 2**16 tokens per sequence
print("Seqlen", seqlen)

Seqlen 65536


In [5]:
# Packed Q/K/V input filled with random fp16 values. The tensor is allocated
# directly on `device`, so no follow-up `.to(device)` transfer is needed.
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  device=device, dtype=torch.float16)

qkv.shape

torch.Size([64, 65536, 3, 8, 32])

In [6]:
%%timeit -r 4 -n 100 

flash_attn_qkvpacked_func(qkv)

60.3 µs ± 14.1 µs per loop (mean ± std. dev. of 4 runs, 100 loops each)


In [7]:
# Sequence length and its square root, written as powers of two.
1 << 16, 1 << 8

(65536, 256)

In [8]:
## Reshape to block sparse
# Split every sequence into sqrt(seqlen) chunks of length `block_len` and fold
# the chunk index into the batch dimension, giving
#   (batch_size * block_len, block_len, 3, nheads, headdim).
block_len = int(np.sqrt(seqlen))
# The view below only lines up when seqlen == block_len ** 2; fail with a clear
# message instead of an opaque size-mismatch error from `view`.
assert block_len * block_len == seqlen, "seqlen must be a perfect square"

qkv_ = qkv.view(batch_size * block_len, block_len, 3, nheads, headdim)
qkv_.shape

torch.Size([16384, 256, 3, 8, 32])

In [9]:
%%timeit -r 4 -n 100 

flash_attn_qkvpacked_func(qkv_)

44.2 µs ± 6.01 µs per loop (mean ± std. dev. of 4 runs, 100 loops each)


In [11]:
# NOTE(review): exit(0) ends the session at this point, so nothing below runs
# automatically. Presumably a deliberate guard in front of the manual timing
# cells that follow — confirm, and skip this cell to run them by hand.
exit(0)

## Manual test

In [None]:
# Single-shot wall-clock timing of attention over the full sequence.
# Synchronize before starting (drain any pending GPU work) and after the call
# (wait for the asynchronously launched kernel to finish).
torch.cuda.synchronize()
start = time.perf_counter()
flash_attn_qkvpacked_func(qkv)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
# perf_counter() returns seconds; multiply by 1000 (not 100) for milliseconds.
print(elapsed * 1000, "ms")

In [None]:
# Single-shot wall-clock timing of attention over the block-folded view.
# Synchronize before starting (drain any pending GPU work) and after the call
# (wait for the asynchronously launched kernel to finish).
torch.cuda.synchronize()
start = time.perf_counter()
flash_attn_qkvpacked_func(qkv_)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
# perf_counter() returns seconds; multiply by 1000 (not 100) for milliseconds.
print(elapsed * 1000, "ms")

In [None]:
# Summarise the per-call timings in `time_taken` (seconds -> milliseconds).
# NOTE(review): `time_taken` is only assigned by the timing-loop cell that
# appears after this one, so on a fresh kernel run that loop first.
ms = np.mean(time_taken) * 1000
print("Time (ms)", ms)

In [None]:
# Collect 10 per-call latencies (in seconds) for attention over the full sequence.
time_taken = []
with torch.no_grad():
    for _ in range(10):
        # Drain pending GPU work so it is not billed to this iteration.
        torch.cuda.synchronize()
        start = time.perf_counter()
        flash_attn_qkvpacked_func(qkv)
        # Kernel launches are asynchronous; wait for completion before stopping the clock.
        torch.cuda.synchronize()
        time_taken.append(time.perf_counter() - start)

In [None]:
# Mean latency over the recorded runs. `time_taken` holds seconds, so scale
# by 1000 (not 100) to report milliseconds.
ms = np.mean(time_taken) * 1000
print("Time (ms)", ms)

In [None]:
## Test on model training -> (Inconclusive results)