In [1]:
# NOTE: install the PyTorch build that matches the GPU vendor before running (CUDA build for NVIDIA, ROCm build for AMD)
import torch
import torch.nn.functional as F

In [2]:
import math
# Budget (in elements) from which the tile sizes are derived: bc and br are
# both computed as ceil(M / (4 * d)) below.
# NOTE(review): presumably stands in for the on-chip SRAM size of the
# FlashAttention paper -- confirm against the kernel sources.
M = 32


def flashattn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Tiled, FlashAttention-style computation of ``softmax(Q @ K.T, dim=1) @ V``.

    No ``1/sqrt(d)`` scaling is applied to the scores. The result is built up
    tile by tile: for every key/value tile, each query tile updates its running
    row maximum ``m``, running softmax denominator ``l`` and the matching rows
    of the output, so the full ``N x N`` score matrix is never materialised.

    The sequence length does not have to be a multiple of the tile sizes; the
    final (short) tile along either axis only touches its valid rows/columns.

    Args:
        Q: query matrix of shape ``(N, d)``.
        K: key matrix of shape ``(N, d)``.
        V: value matrix of shape ``(N, d)``.

    Returns:
        Tensor of shape ``(N, d)`` with the dtype and device of ``Q``.

    Raises:
        AssertionError: if the three inputs differ in shape or are not 2-D.
    """
    assert Q.shape == K.shape == V.shape
    assert len(Q.shape) == 2

    # seq length, inner dim of representations
    N, d = Q.shape
    # keep every scratch buffer on the same device as the inputs
    device = Q.device

    # keys are consumed as columns of K^T
    Kt = K.T

    # of kv vectors per tile
    bc = math.ceil(M/(4*d))
    tc = math.ceil(N/bc)
    K_shared = torch.empty((d, bc), dtype=Q.dtype, device=device)
    V_shared = torch.empty((bc, d), dtype=Q.dtype, device=device)

    # of q vectors per tile
    br = min(math.ceil(M/(4*d)), d)
    tr = math.ceil(N/br)
    Q_shared = torch.empty((br, d), dtype=Q.dtype, device=device)

    # output matrix
    O = torch.zeros_like(Q)
    O_shared = torch.empty_like(Q_shared)

    # intermediate rowmaxes
    m = torch.full((N,), -torch.inf, dtype=Q.dtype, device=device)
    m_shared = torch.empty((br, 1), dtype=Q.dtype, device=device)
    # intermediate normalization constants
    l = torch.full((N,), 0, dtype=Q.dtype, device=device)
    l_shared = torch.empty((br, 1), dtype=Q.dtype, device=device)

    for i in range(tc):
        # column range of this k/v tile; the last tile may be shorter than bc
        c0 = i*bc
        c1 = min(c0 + bc, N)
        nc = c1 - c0

        # load k, v chunks (k in its transposed layout) into the valid part
        # of the buffers and only ever read that part back
        K_shared[:, :nc] = Kt[:, c0:c1]
        V_shared[:nc, :] = V[c0:c1, :]
        K_tile = K_shared[:, :nc]
        V_tile = V_shared[:nc, :]

        for j in range(tr):
            # row range of this q tile; the last tile may be shorter than br
            r0 = j*br
            r1 = min(r0 + br, N)
            nr = r1 - r0

            # load in q, o, m, l
            Q_shared[:nr, :] = Q[r0:r1, :]
            O_shared[:nr, :] = O[r0:r1, :]
            m_shared[:nr, :] = m[r0:r1].unsqueeze(-1)
            l_shared[:nr, :] = l[r0:r1].unsqueeze(-1)
            Q_tile = Q_shared[:nr, :]
            O_tile = O_shared[:nr, :]
            m_tile = m_shared[:nr, :]
            l_tile = l_shared[:nr, :]

            S = Q_tile @ K_tile

            # get row-wise softmax statistics of this tile
            mt = S.max(dim=1).values.reshape(-1, 1)
            Pt = torch.exp(S - mt)
            lt = Pt.sum(dim=1).reshape(-1, 1)

            # merge the tile statistics into the running ones
            m_new = torch.max(mt, m_tile)
            l_new = (torch.exp(m_tile - m_new) * l_tile) + (torch.exp(mt - m_new) * lt)

            # update chunk of output: un-normalise the previous rows, rescale
            # both contributions to the new maximum, then re-normalise
            O_new = (
                l_tile * torch.exp(m_tile - m_new) * O_tile
                + (torch.exp(mt - m_new) * Pt) @ V_tile
            ) / l_new
            O[r0:r1, :] = O_new

            m[r0:r1] = m_new.flatten()
            l[r0:r1] = l_new.flatten()

    return O






In [3]:
def dumb_attn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """equivalent to F.scaled_dot_product_attention(Q, K, V, scale=1)"""
    # P = (Q @ K.T)
    # m = P.max(dim=1).values.reshape(-1, 1)
    # print((P - m).exp())
    # return torch.softmax(Q @ K.T, dim=1) @ V
    return F.scaled_dot_product_attention(Q, K, V, scale=1)


# Fixed seed so the random test inputs (and any mismatch they expose) are
# reproducible after a kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is visible to PyTorch (the ROCm build also exposes it
# as 'cuda'); fall back to the CPU so this cell still runs without one.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Small integer-valued entries in [1, 8]: the Q.K dot products are then exact
# integers, which makes intermediate values easy to verify by hand.
# For non-integer inputs use torch.rand((4, 2), dtype=torch.float32, device=DEVICE).
Q = torch.randint(1, 9, (4, 2), device=DEVICE).float()
K = torch.randint(1, 9, (4, 2), device=DEVICE).float()
V = torch.randint(1, 9, (4, 2), device=DEVICE).float()

Q, K, V


(tensor([[3., 5.],
         [8., 8.],
         [6., 8.],
         [4., 1.]], device='cuda:0'),
 tensor([[2., 2.],
         [2., 6.],
         [6., 2.],
         [4., 7.]], device='cuda:0'),
 tensor([[8., 8.],
         [4., 1.],
         [4., 7.],
         [1., 3.]], device='cuda:0'))

In [8]:
from torch.utils.cpp_extension import load

# JIT-build the HIP implementation of the kernel and import the result as a
# Python extension module named 'm'.
# For the CUDA variant, build ['cuda/main.cpp', 'cuda/flash_attention.cu'] instead.
module = load(
    'm',
    ['rocm/flash_attention.hip'],
    verbose=True,  # print the build log
    extra_cflags=['--offload-arch="gfx1100"'],  # offload target passed to the compiler
)

/home/seb/Code/flash-attention/rocm/flash_attention.hip -> /home/seb/Code/flash-attention/rocm/flash_attention_hip.hip [skipped, already hipified]
Total number of unsupported CUDA function calls: 0


Total number of replaced kernel launches: 1


Using /home/seb/.cache/torch_extensions/py312_cpu as PyTorch extensions root...
The input conditions for extension module m have changed. Bumping to version 2 and re-building as m_v2...
[92mSuccessfully preprocessed all matching files.[0m
Detected CUDA files, patching ldflags
Emitting ninja build file /home/seb/.cache/torch_extensions/py312_cpu/m/build.ninja...
Building extension module m_v2...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/2] /opt/rocm-6.1.3/bin/hipcc  -DWITH_HIP -DTORCH_EXTENSION_NAME=m_v2 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/TH -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THC -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THH -isystem /opt/rocm-6.1.3/include -isystem /home/seb/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 --offload-arch="gfx1100" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -fno-gpu-rdc -c /home/seb/Code/flash-attention/rocm/flash_at

Loading extension module m_v2...


In [7]:
# Run the reference implementation and the compiled kernel on the same inputs
# and fail loudly on a mismatch instead of relying on a visual comparison.
# To check the pure-PyTorch prototype, call flashattn(Q, K, V) in place of
# module.flash_attn(Q, K, V).
reference_out = dumb_attn(Q, K, V)
kernel_out = module.flash_attn(Q, K, V)
torch.testing.assert_close(kernel_out, reference_out, rtol=1e-4, atol=1e-4)

# still display both results side by side
reference_out, kernel_out

sum 16.000000
sum 36.000000
sum 28.000000
sum 47.000000
sum 32.000000
sum 64.000000
sum 64.000000
sum 88.000000
sum 28.000000
sum 60.000000
sum 52.000000
sum 80.000000
sum 10.000000
sum 14.000000
sum 26.000000
sum 23.000000
sum 1.000067
sum 3.000017
sum 1.000000
sum 3.000000
sum 1.000000
sum 3.000000
sum 4.049812
sum 7.149368


  return F.scaled_dot_product_attention(Q, K, V, scale=1)


(tensor([[1.0001, 3.0000],
         [1.0000, 3.0000],
         [1.0000, 3.0000],
         [3.8577, 6.8103]], device='cuda:0'),
 tensor([[1.0001, 3.0000],
         [1.0000, 3.0000],
         [1.0000, 3.0000],
         [3.8577, 6.8103]], device='cuda:0'))