In [1]:
# make sure to use the right version of pytorch when testing on nvidia/amd
import torch
import torch.nn.functional as F
torch.set_printoptions(sci_mode=False)

In [2]:
import math
M = 32


def flashattn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    assert Q.shape == K.shape == V.shape
    assert len(Q.shape) == 2

    K = K.T

    # seq length, inner dim of representations
    N, d = Q.shape

    # of kv vectors per tile
    bc = math.ceil(M/(4*d))
    tc = math.ceil(N/bc)
    K_shared = torch.empty((d, bc), dtype=Q.dtype)
    V_shared = torch.empty((bc, d), dtype=Q.dtype)

    # of q vectors per tile
    br = min(math.ceil(M/(4*d)), d)
    tr = math.ceil(N/br)
    Q_shared = torch.empty((br, d), dtype=Q.dtype)

    # print(f'bc={bc}, br={br}')
    # print(f'tc={tc}, tr={tr}')

    # output matrix
    O = torch.zeros_like(Q)
    O_shared = torch.empty_like(Q_shared)

    # intermediate rowmaxes
    m = torch.full((N,), -torch.inf, dtype=Q.dtype)
    m_shared = torch.empty((br, 1), dtype=Q.dtype)
    # intermediate normalization constants
    l = torch.full((N,), 0, dtype=Q.dtype)
    l_shared = torch.empty((br, 1), dtype=Q.dtype)

    for i in range(tc):
        # load k, v chunks
        # make sure we load in k as its transposed version
        K_shared[:, :] = K[:, i*bc:(i+1)*bc]
        V_shared[:, :] = V[i*bc:(i+1)*bc, :]

        for j in range(tr):
            # load in q, o, m, l
            Q_shared[:, :] = Q[j*br:(j+1)*br, :]
            # if i == 0: print(f'Q: {Q_shared}')
            O_shared[:, :] = O[j*br:(j+1)*br, :]
            m_shared[:, :] = m[j*br:(j+1)*br].unsqueeze(-1)
            l_shared[:, :] = l[j*br:(j+1)*br].unsqueeze(-1)
            
            S = Q_shared @ K_shared

            # get row-wise softmax statistics
            mt = S.max(dim=1).values.reshape(-1, 1)

            Pt = torch.exp(S - mt)
            lt = Pt.sum(dim=1).reshape(-1, 1)

            # compute new statistics
            m_new = torch.max(mt, m_shared)
            l_new = (torch.exp(m_shared - m_new) * l_shared) + (torch.exp(mt - m_new) * lt)


            # update chunk of output
            O_new = (l_shared * torch.exp(m_shared - m_new) * O_shared + torch.exp(mt - m_new) * Pt @ V_shared) / l_new
            O[j*br:(j+1)*br, :] = O_new
            
            m[j*br:(j+1)*br] = m_new.flatten()
            l[j*br:(j+1)*br] = l_new.flatten()

    return O






In [125]:
from math import sqrt
def dumb_attn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """equivalent to F.scaled_dot_product_attention(Q, K, V, scale=1)"""
    d = Q.shape[1]
    return torch.softmax((Q)@ K.T, dim=1) @ V
    # return F.scaled_dot_product_attention(Q, K, V, scale=1)


# have to support up to like d=4096
# we're giving up and only supporting fp16 (like flashattn)
Q = torch.rand((64, 256), dtype=torch.float16, device='cuda') 
K = torch.rand((64, 256), dtype=torch.float16, device='cuda') 
V = torch.rand((64, 256), dtype=torch.float16, device='cuda') 

# K[-1] = torch.zeros((2,))
# V[-1] = torch.zeros((2,))


# Q = torch.randint(1, 9, (8, 4), device='cuda').to(torch.float16)
# K = torch.randint(1, 9, (8, 4), device='cuda').to(torch.float16)
# V = torch.randint(1, 9, (8, 4), device='cuda').to(torch.float16)

# Q = torch.tensor([[0., 1.],
#          [2., 3.]])

# K = torch.tensor([[0., 1.5],
#          [2., 3.]])

# V = torch.tensor([[1., 0.],
#          [0., 1.]])
Q, K, V


(tensor([[0.9692, 0.8667, 0.2769,  ..., 0.4790, 0.0756, 0.5806],
         [0.3923, 0.5615, 0.7637,  ..., 0.9009, 0.4824, 0.7891],
         [0.6646, 0.3252, 0.7012,  ..., 0.2593, 0.7358, 0.0787],
         ...,
         [0.1360, 0.9292, 0.6895,  ..., 0.6885, 0.7935, 0.1575],
         [0.6240, 0.0032, 0.4893,  ..., 0.9106, 0.8154, 0.3074],
         [0.5620, 0.2218, 0.2180,  ..., 0.9136, 0.2301, 0.7876]],
        device='cuda:0', dtype=torch.float16),
 tensor([[0.0101, 0.1547, 0.5068,  ..., 0.3699, 0.1704, 0.1504],
         [0.0531, 0.0361, 0.3989,  ..., 0.4812, 0.9629, 0.5723],
         [0.8521, 0.2830, 0.4976,  ..., 0.5796, 0.7998, 0.8433],
         ...,
         [0.7920, 0.8970, 0.3716,  ..., 0.8301, 0.1196, 0.1761],
         [0.6099, 0.3625, 0.9199,  ..., 0.4976, 0.0604, 0.6001],
         [0.1112, 0.6943, 0.0289,  ..., 0.2869, 0.7617, 0.9707]],
        device='cuda:0', dtype=torch.float16),
 tensor([[0.5015, 0.7148, 0.1860,  ..., 0.1707, 0.5869, 0.5840],
         [0.9722, 0.1265, 0.362

In [141]:
from torch.utils.cpp_extension import load
module = load(
    name='m',
    # sources=['cuda/main.cpp', 'cuda/flash_attention.cu'],
    sources=['rocm/flash_attention.hip',],
    extra_cflags=['--offload-arch="gfx1100"',],
    build_directory='build',
    verbose=True
)

/home/seb/Code/flash-attention/rocm/flash_attention.hip -> /home/seb/Code/flash-attention/rocm/flash_attention_hip.hip [ok]
Total number of unsupported CUDA function calls: 0


Total number of replaced kernel launches: 1


The input conditions for extension module m have changed. Bumping to version 20 and re-building as m_v20...
[92mSuccessfully preprocessed all matching files.[0m
Detected CUDA files, patching ldflags
Emitting ninja build file build/build.ninja...
Building extension module m_v20...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/2] /opt/rocm-6.1.3/bin/hipcc  -DWITH_HIP -DTORCH_EXTENSION_NAME=m_v20 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/TH -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THC -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THH -isystem /opt/rocm-6.1.3/include -isystem /home/seb/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 --offload-arch="gfx1100" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -fno-gpu-rdc -c /home/seb/Code/flash-attention/rocm/flash_a

Loading extension module m_v20...


In [110]:
# torch.isclose(dumb_attn(Q, K, V), flashattn(Q, K, V))
# dumb_attn(Q, K, V), flashattn(Q, K, V)

# F.scaled_dot_product_attention(Q, K, V, scale=1), flashattn(Q, K, V)
# module.flash_attn(Q, K, V)
torch.set_printoptions(profile='full', sci_mode=False)
r1 = dumb_attn(Q, K, V)
r2 = module.flash_attn(Q, K, V)
(r1-r2).abs().mean()
# r1, r2


# # r1, dumb_attn(Q.float(), K.float(), V.float()), r2
# # r1, r2[:50]
# dumb_attn(Q, K, V), module.flash_attn(Q, K, V)

tc=8, tr=8
Q:
[[0.2263, 0.8062, 0.5610, 0.5864, 0.2240, 0.1256, 0.2905, 0.5786, 0.0806, 0.0186, 0.0041, 0.9883, 0.2656, 0.7109, 0.4702, 0.1790, 0.3147, 0.4700, 0.8726, 0.0129, 0.1383, 0.8765, 0.9248, 0.5820, 0.0030, 0.6089, 0.0879, 0.9785, 0.0206, 0.9170, 0.2590, 0.9111, 0.4058, 0.0702, 0.0619, 0.7021, 0.7764, 0.7290, 0.9497, 0.6519, 0.8887, 0.9346, 0.1448, 0.8726, 0.4983, 0.9072, 0.9863, 0.3220, 0.9624, 0.1345, 0.3250, 0.7100, 0.0387, 0.8130, 0.5396, 0.1399, 0.3662, 0.5425, 0.1506, 0.3972, 0.5220, 0.3684, 0.8857, 0.1910, 0.3687, 0.4619, 0.2306, 0.3833, 0.1346, 0.3931, 0.8198, 0.4258, 0.8188, 0.4353, 0.8340, 0.5137, 0.7925, 0.2195, 0.1945, 0.7764, 0.4858, 0.5781, 0.6514, 0.0022, 0.3240, 0.1368, 0.5532, 0.5005, 0.7114, 0.5161, 0.8296, 0.6011, 0.7104, 0.7715, 0.4373, 0.7983, 0.2413, 0.4258, 0.8740, 0.8105, 0.9375, 0.5332, 0.4858, 0.9829, 0.3726, 0.4617, 0.3958, 0.3335, 0.3958, 0.9741, 0.1256, 0.1310, 0.0270, 0.3623, 0.9990, 0.2546, 0.6738, 0.9146, 0.4150, 0.3020, 0.3286, 0.1879, 0.8955, 

tensor(    0.0000, device='cuda:0', dtype=torch.float16)

In [164]:

def test_flashattn():
    diff = lambda r1, r2: (r1-r2).abs().mean()
    dims = ((48, 32), (64, 256), (1024, 256), (4096, 1024))
    for N, d in dims:
        Q = torch.rand((N, d), dtype=torch.float16, device='cuda') 
        K = torch.rand((N, d), dtype=torch.float16, device='cuda') 
        V = torch.rand((N, d), dtype=torch.float16, device='cuda') 

        # print(f'Difference on ({N}, {d}) tensors: {diff(dumb_attn(Q, K, V), F.scaled_dot_product_attention(Q, K, V, scale=1))}\n\n')
        print(f'Difference on ({N}, {d}) tensors: {diff(module.flash_attn(Q, K, V), F.scaled_dot_product_attention(Q, K, V, scale=1))}\n\n')
        # print(f'Difference on ({N}, {d}) tensors: {diff(module.flash_attn(Q, K, V), dumb_attn(Q, K, V))}\n\n')

torch.set_printoptions(profile='short', sci_mode=False)
test_flashattn()
torch.set_printoptions(profile='default', sci_mode=False)

tc=12, tr=6
Launching kernel with smem size 1600
Difference on (48, 32) tensors: 0.00038886070251464844


tc=16, tr=8
Launching kernel with smem size 12352
Difference on (64, 256) tensors: 0.001415252685546875


tc=256, tr=128
Launching kernel with smem size 12352
Difference on (1024, 256) tensors: 0.002254486083984375


tc=1024, tr=512
Launching kernel with smem size 49216
Difference on (4096, 1024) tensors: 0.00667572021484375


