In [1]:
# make sure to use the right version of pytorch when testing on nvidia/amd
import torch
import torch.nn.functional as F
torch.set_printoptions(sci_mode=False)

In [2]:
import math
M = 32


def flashattn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    assert Q.shape == K.shape == V.shape
    assert len(Q.shape) == 2

    K = K.T

    # seq length, inner dim of representations
    N, d = Q.shape

    # of kv vectors per tile
    bc = math.ceil(M/(4*d))
    tc = math.ceil(N/bc)
    K_shared = torch.empty((d, bc), dtype=Q.dtype)
    V_shared = torch.empty((bc, d), dtype=Q.dtype)

    # of q vectors per tile
    br = min(math.ceil(M/(4*d)), d)
    tr = math.ceil(N/br)
    Q_shared = torch.empty((br, d), dtype=Q.dtype)

    # print(f'bc={bc}, br={br}')
    # print(f'tc={tc}, tr={tr}')

    # output matrix
    O = torch.zeros_like(Q)
    O_shared = torch.empty_like(Q_shared)

    # intermediate rowmaxes
    m = torch.full((N,), -torch.inf, dtype=Q.dtype)
    m_shared = torch.empty((br, 1), dtype=Q.dtype)
    # intermediate normalization constants
    l = torch.full((N,), 0, dtype=Q.dtype)
    l_shared = torch.empty((br, 1), dtype=Q.dtype)

    for i in range(tc):
        # load k, v chunks
        # make sure we load in k as its transposed version
        K_shared[:, :] = K[:, i*bc:(i+1)*bc]
        V_shared[:, :] = V[i*bc:(i+1)*bc, :]

        for j in range(tr):
            # load in q, o, m, l
            Q_shared[:, :] = Q[j*br:(j+1)*br, :]
            # if i == 0: print(f'Q: {Q_shared}')
            O_shared[:, :] = O[j*br:(j+1)*br, :]
            m_shared[:, :] = m[j*br:(j+1)*br].unsqueeze(-1)
            l_shared[:, :] = l[j*br:(j+1)*br].unsqueeze(-1)
            
            S = Q_shared @ K_shared

            # get row-wise softmax statistics
            mt = S.max(dim=1).values.reshape(-1, 1)

            Pt = torch.exp(S - mt)
            lt = Pt.sum(dim=1).reshape(-1, 1)

            # compute new statistics
            m_new = torch.max(mt, m_shared)
            l_new = (torch.exp(m_shared - m_new) * l_shared) + (torch.exp(mt - m_new) * lt)


            # update chunk of output
            O_new = (l_shared * torch.exp(m_shared - m_new) * O_shared + torch.exp(mt - m_new) * Pt @ V_shared) / l_new
            O[j*br:(j+1)*br, :] = O_new
            
            m[j*br:(j+1)*br] = m_new.flatten()
            l[j*br:(j+1)*br] = l_new.flatten()

    return O






In [23]:
from math import sqrt
def dumb_attn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """equivalent to F.scaled_dot_product_attention(Q, K, V, scale=1)"""
    # P = (Q @ K.T)
    # m = P.max(dim=1).values.reshape(-1, 1)
    # print((P - m).exp())
    # return torch.softmax(Q @ K.T, dim=1) @ V
    d = Q.shape[1]
    return F.scaled_dot_product_attention(Q, K, V, scale=d**-0.5)


# have to support up to like d=4096
# we're giving up and only supporting fp16 (like flashattn)
Q = torch.rand((8, 128), dtype=torch.float16, device='cuda') * 10
K = torch.rand((8, 128), dtype=torch.float16, device='cuda') * 10
V = torch.rand((8, 128), dtype=torch.float16, device='cuda') * 10


# Q = torch.randint(1, 9, (8, 4), device='cuda').to(torch.float16)
# K = torch.randint(1, 9, (8, 4), device='cuda').to(torch.float16)
# V = torch.randint(1, 9, (8, 4), device='cuda').to(torch.float16)

# Q = torch.tensor([[0., 1.],
#          [2., 3.]])

# K = torch.tensor([[0., 1.5],
#          [2., 3.]])

# V = torch.tensor([[1., 0.],
#          [0., 1.]])
Q, K, V


(tensor([[5.7188, 1.9053, 0.9224,  ..., 2.7441, 8.8125, 5.1797],
         [9.4453, 1.5811, 3.0957,  ..., 2.2012, 4.0469, 3.1641],
         [5.3516, 2.8887, 4.3047,  ..., 3.9160, 0.3467, 9.6953],
         ...,
         [1.6982, 7.3516, 5.7500,  ..., 5.4141, 8.0547, 2.2109],
         [7.9922, 7.8477, 9.7422,  ..., 1.3555, 9.0938, 5.0977],
         [5.7188, 6.3359, 8.4141,  ..., 6.4102, 0.7988, 9.5391]],
        device='cuda:0', dtype=torch.float16),
 tensor([[6.0898, 5.8828, 2.2285,  ..., 9.0312, 0.1387, 5.3438],
         [4.9258, 2.2617, 8.1641,  ..., 8.6641, 2.4609, 0.3540],
         [9.9297, 1.9629, 9.6797,  ..., 1.0459, 9.9609, 2.9980],
         ...,
         [1.0469, 7.7656, 8.8984,  ..., 1.7734, 1.9453, 0.4707],
         [8.2578, 5.8555, 5.2539,  ..., 4.5078, 8.3281, 8.9531],
         [5.5117, 4.9258, 3.1738,  ..., 9.5859, 6.6758, 2.5801]],
        device='cuda:0', dtype=torch.float16),
 tensor([[5.6250, 2.0410, 8.8438,  ..., 5.9688, 8.1719, 9.8750],
         [3.6934, 6.8398, 6.113

In [24]:
from torch.utils.cpp_extension import load
module = load(
    name='m',
    # sources=['cuda/main.cpp', 'cuda/flash_attention.cu'],
    sources=['rocm/flash_attention.hip',],
    extra_cflags=['--offload-arch="gfx1100"',],
    build_directory='build',
    verbose=True
)

/home/seb/Code/flash-attention/rocm/flash_attention.hip -> /home/seb/Code/flash-attention/rocm/flash_attention_hip.hip [ok]
Total number of unsupported CUDA function calls: 0


Total number of replaced kernel launches: 1


The input conditions for extension module m have changed. Bumping to version 2 and re-building as m_v2...
[92mSuccessfully preprocessed all matching files.[0m
Detected CUDA files, patching ldflags
Emitting ninja build file build/build.ninja...
Building extension module m_v2...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/2] /opt/rocm-6.1.3/bin/hipcc  -DWITH_HIP -DTORCH_EXTENSION_NAME=m_v2 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/TH -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THC -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THH -isystem /opt/rocm-6.1.3/include -isystem /home/seb/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 --offload-arch="gfx1100" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -fno-gpu-rdc -c /home/seb/Code/flash-attention/rocm/flash_at

Loading extension module m_v2...


In [25]:
# torch.isclose(dumb_attn(Q, K, V), flashattn(Q, K, V))
# dumb_attn(Q, K, V), flashattn(Q, K, V)

# F.scaled_dot_product_attention(Q, K, V, scale=1), flashattn(Q, K, V)
# module.flash_attn(Q, K, V)
r1 = dumb_attn(Q, K, V)
r2 = module.flash_attn(Q, K, V)
(r1-r2).abs().mean()
r1, r2

  Shared Memory per Block: 64 KB
  Shared Memory Banks: 65536
  Warp Size: 32
  Max Threads per Block: 1024
  Max Threads per Multiprocessor: 2048

tc=1, tr=1
scaling QK^T by 0.0883789
Launching kernel with smem size 8320
[[0.5054, 0.1683, 0.0815, 0.0861, 0.0204, 0.2272, 0.0088, 0.5264, 0.5474, 0.2737, 0.6104, 0.6304, 0.5884, 0.2969, 0.2302, 0.4695, 0.1288, 0.4771, 0.5322, 0.2292, 0.1032, 0.8672, 0.6909, 0.6782, 0.8740, 0.5117, 0.8760, 0.3018, 0.6792, 0.0432, 0.6548, 0.4841, 0.5815, 0.2654, 0.1848, 0.7524, 0.1581, 0.6875, 0.0547, 0.4585, 0.2417, 0.6685, 0.8350, 0.8706, 0.2649, 0.6147, 0.1885, 0.6489, 0.5967, 0.6831, 0.2634, 0.6802, 0.5073, 0.0037, 0.4426, 0.0249, 0.6099, 0.1449, 0.6792, 0.8193, 0.1809, 0.5127, 0.4836, 0.5000, 0.5903, 0.6416, 0.7749, 0.1964, 0.0405, 0.4458, 0.6196, 0.0450, 0.7866, 0.3191, 0.7090, 0.0061, 0.8452, 0.2219, 0.0727, 0.4958, 0.0664, 0.1964, 0.2603, 0.7485, 0.0652, 0.7759, 0.3967, 0.5562, 0.7466, 0.5923, 0.5088, 0.7300, 0.7759, 0.1827, 0.0218, 0.0089, 0.7554, 

(tensor([[0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938],
         [0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938],
         [0.7891, 4.6562, 2.3262,  ..., 4.3438, 7.5039, 9.5938],
         ...,
         [0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938],
         [0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938],
         [0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938]],
        device='cuda:0', dtype=torch.float16),
 tensor([[0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938],
         [0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938],
         [0.7886, 4.6562, 2.3262,  ..., 4.3438, 7.5000, 9.5938],
         ...,
         [0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938],
         [0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938],
         [0.7861, 4.6562, 2.3203,  ..., 4.3398, 7.5078, 9.5938]],
        device='cuda:0', dtype=torch.float16))

In [6]:
a = torch.randint(1, 9, (5,)).float()
b = torch.hstack((a, torch.zeros((5,)).float()))
a, b

(tensor([4., 4., 3., 1., 5.]),
 tensor([4., 4., 3., 1., 5., 0., 0., 0., 0., 0.]))

In [7]:
a.softmax(0), b.softmax(0)

(tensor([0.1947, 0.1947, 0.0716, 0.0097, 0.5293]),
 tensor([0.1913, 0.1913, 0.0704, 0.0095, 0.5200, 0.0035, 0.0035, 0.0035, 0.0035,
         0.0035]))

In [8]:
a = torch.tensor([65.000000, 42.000000, 80.000000, 84.000000], dtype=torch.float32)
(a - a.max()).exp()

tensor([    0.0000,     0.0000,     0.0183,     1.0000])