In [1]:
# make sure to use the right version of pytorch when testing on nvidia/amd
import torch
import torch.nn.functional as F

In [2]:
import math
M = 32


def flashattn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    assert Q.shape == K.shape == V.shape
    assert len(Q.shape) == 2

    K = K.T

    # seq length, inner dim of representations
    N, d = Q.shape

    # of kv vectors per tile
    bc = math.ceil(M/(4*d))
    tc = math.ceil(N/bc)
    K_shared = torch.empty((d, bc), dtype=Q.dtype)
    V_shared = torch.empty((bc, d), dtype=Q.dtype)

    # of q vectors per tile
    br = min(math.ceil(M/(4*d)), d)
    tr = math.ceil(N/br)
    Q_shared = torch.empty((br, d), dtype=Q.dtype)

    # print(f'bc={bc}, br={br}')
    # print(f'tc={tc}, tr={tr}')

    # output matrix
    O = torch.zeros_like(Q)
    O_shared = torch.empty_like(Q_shared)

    # intermediate rowmaxes
    m = torch.full((N,), -torch.inf, dtype=Q.dtype)
    m_shared = torch.empty((br, 1), dtype=Q.dtype)
    # intermediate normalization constants
    l = torch.full((N,), 0, dtype=Q.dtype)
    l_shared = torch.empty((br, 1), dtype=Q.dtype)

    for i in range(tc):
        # load k, v chunks
        # make sure we load in k as its transposed version
        K_shared[:, :] = K[:, i*bc:(i+1)*bc]
        V_shared[:, :] = V[i*bc:(i+1)*bc, :]

        for j in range(tr):
            # load in q, o, m, l
            Q_shared[:, :] = Q[j*br:(j+1)*br, :]
            # if i == 0: print(f'Q: {Q_shared}')
            O_shared[:, :] = O[j*br:(j+1)*br, :]
            m_shared[:, :] = m[j*br:(j+1)*br].unsqueeze(-1)
            l_shared[:, :] = l[j*br:(j+1)*br].unsqueeze(-1)
            
            S = Q_shared @ K_shared

            # get row-wise softmax statistics
            mt = S.max(dim=1).values.reshape(-1, 1)

            Pt = torch.exp(S - mt)
            lt = Pt.sum(dim=1).reshape(-1, 1)

            # compute new statistics
            m_new = torch.max(mt, m_shared)
            l_new = (torch.exp(m_shared - m_new) * l_shared) + (torch.exp(mt - m_new) * lt)


            # update chunk of output
            O_new = (l_shared * torch.exp(m_shared - m_new) * O_shared + torch.exp(mt - m_new) * Pt @ V_shared) / l_new
            O[j*br:(j+1)*br, :] = O_new
            
            m[j*br:(j+1)*br] = m_new.flatten()
            l[j*br:(j+1)*br] = l_new.flatten()

    return O






In [16]:
def dumb_attn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """equivalent to F.scaled_dot_product_attention(Q, K, V, scale=1)"""
    # P = (Q @ K.T)
    # m = P.max(dim=1).values.reshape(-1, 1)
    # print((P - m).exp())
    # return torch.softmax(Q @ K.T, dim=1) @ V
    return F.scaled_dot_product_attention(Q, K, V, scale=1)


Q = torch.rand((1024, 32), dtype=torch.float32, device='cuda')
K = torch.rand((1024, 32), dtype=torch.float32, device='cuda')
V = torch.rand((1024, 32), dtype=torch.float32, device='cuda')

# Q = torch.randint(1, 9, (4, 2), device='cuda').float()
# K = torch.randint(1, 9, (4, 2), device='cuda').float()
# V = torch.randint(1, 9, (4, 2), device='cuda').float()

# Q = torch.tensor([[0., 1.],
#          [2., 3.]])

# K = torch.tensor([[0., 1.5],
#          [2., 3.]])

# V = torch.tensor([[1., 0.],
#          [0., 1.]])
Q, K, V


(tensor([[0.1744, 0.9305, 0.7961,  ..., 0.4479, 0.5641, 0.9814],
         [0.1761, 0.9618, 0.4083,  ..., 0.5819, 0.9369, 0.2934],
         [0.4902, 0.1542, 0.9067,  ..., 0.3008, 0.8327, 0.1166],
         ...,
         [0.0483, 0.5184, 0.7864,  ..., 0.6715, 0.1462, 0.0273],
         [0.9808, 0.8851, 0.3725,  ..., 0.0609, 0.7436, 0.6391],
         [0.1496, 0.4694, 0.3354,  ..., 0.6871, 0.9578, 0.6432]],
        device='cuda:0'),
 tensor([[0.4482, 0.2526, 0.1135,  ..., 0.2954, 0.4629, 0.0603],
         [0.7837, 0.0863, 0.2138,  ..., 0.3781, 0.7255, 0.4389],
         [0.7534, 0.3034, 0.3971,  ..., 0.3404, 0.4329, 0.4226],
         ...,
         [0.8353, 0.1152, 0.5016,  ..., 0.2093, 0.1920, 0.2022],
         [0.1042, 0.5667, 0.6792,  ..., 0.7182, 0.0785, 0.1686],
         [0.5082, 0.0505, 0.3705,  ..., 0.1122, 0.6319, 0.1609]],
        device='cuda:0'),
 tensor([[0.1417, 0.0555, 0.5136,  ..., 0.5614, 0.0462, 0.7551],
         [0.0668, 0.8369, 0.6660,  ..., 0.9265, 0.3773, 0.5001],
        

In [12]:
from torch.utils.cpp_extension import load
module = load(
    name='m',
    # sources=['cuda/main.cpp', 'cuda/flash_attention.cu'],
    sources=['rocm/flash_attention.hip',],
    extra_cflags=['--offload-arch="gfx1100"',],
    build_directory='build',
    verbose=True
)

/home/seb/Code/flash-attention/rocm/flash_attention.hip -> /home/seb/Code/flash-attention/rocm/flash_attention_hip.hip [ok]
Total number of unsupported CUDA function calls: 0


Total number of replaced kernel launches: 1


The input conditions for extension module m have changed. Bumping to version 2 and re-building as m_v2...
[92mSuccessfully preprocessed all matching files.[0m
Detected CUDA files, patching ldflags
Emitting ninja build file build/build.ninja...
Building extension module m_v2...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/2] /opt/rocm-6.1.3/bin/hipcc  -DWITH_HIP -DTORCH_EXTENSION_NAME=m_v2 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/TH -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THC -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THH -isystem /opt/rocm-6.1.3/include -isystem /home/seb/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 --offload-arch="gfx1100" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -fno-gpu-rdc -c /home/seb/Code/flash-attention/rocm/flash_at

Loading extension module m_v2...


In [18]:
# torch.isclose(dumb_attn(Q, K, V), flashattn(Q, K, V))
# dumb_attn(Q, K, V), flashattn(Q, K, V)
dumb_attn(Q, K, V), module.flash_attn(Q, K, V)
# F.scaled_dot_product_attention(Q, K, V, scale=1), flashattn(Q, K, V)
# module.flash_attn(Q, K, V)

Launching kernel with smem size20736


  return F.scaled_dot_product_attention(Q, K, V, scale=1)


(tensor([[0.4916, 0.5135, 0.5137,  ..., 0.4970, 0.4800, 0.5180],
         [0.5019, 0.5060, 0.5127,  ..., 0.5001, 0.4732, 0.4964],
         [0.4861, 0.5106, 0.5078,  ..., 0.4990, 0.4702, 0.5149],
         ...,
         [0.5016, 0.5114, 0.5027,  ..., 0.5043, 0.4689, 0.5047],
         [0.4975, 0.5129, 0.5025,  ..., 0.5051, 0.4840, 0.5110],
         [0.4961, 0.5070, 0.5096,  ..., 0.4908, 0.4708, 0.5107]],
        device='cuda:0'),
 tensor([[0.4916, 0.5135, 0.5137,  ..., 0.4970, 0.4800, 0.5180],
         [0.5019, 0.5060, 0.5127,  ..., 0.5001, 0.4732, 0.4964],
         [0.4861, 0.5106, 0.5078,  ..., 0.4990, 0.4702, 0.5149],
         ...,
         [0.5016, 0.5114, 0.5027,  ..., 0.5043, 0.4689, 0.5047],
         [0.4975, 0.5129, 0.5025,  ..., 0.5051, 0.4840, 0.5110],
         [0.4961, 0.5070, 0.5096,  ..., 0.4908, 0.4708, 0.5107]],
        device='cuda:0'))