In [2]:
# make sure to use the right version of pytorch when testing on nvidia/amd
import torch
import torch.nn.functional as F
# Print tensors in plain decimal notation instead of scientific notation.
torch.set_printoptions(sci_mode=False)
# (disabled direct import; scaled_dot_product_attention is reached through the `F` alias)
# from torch.nn.functional import scaled_dot_product_attention

In [2]:
import math
M = 32  # simulated fast-memory budget (in elements); it determines the tile sizes below


def flashattn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Tiled reference attention: ``softmax(Q @ K.T, dim=1) @ V`` computed tile by tile.

    The full ``(N, N)`` score matrix is never materialized. Keys/values are walked
    in tiles of ``bc = ceil(M / (4 * d))`` rows and queries in tiles of
    ``br = min(ceil(M / (4 * d)), d)`` rows. For every query row a running
    row-maximum ``m`` and a running softmax normalizer ``l`` are kept, and the
    partial output is rescaled each time a new key/value tile is folded in.
    No ``1/sqrt(d)`` scaling is applied to the scores.

    Args:
        Q: query matrix of shape ``(N, d)``.
        K: key matrix of shape ``(N, d)``.
        V: value matrix of shape ``(N, d)``.

    Returns:
        Tensor of shape ``(N, d)`` with the dtype and device of ``Q``.

    Raises:
        AssertionError: if the three inputs do not share one 2-D shape.

    Notes:
        * Tiles are taken as slices, so a last tile that is shorter than
          ``bc`` / ``br`` (``N`` not a multiple of the tile size) is handled
          with its true size instead of being broadcast into a fixed buffer.
        * The running statistics are allocated on ``Q.device`` so the whole
          computation stays on the device of the inputs.
    """
    assert Q.shape == K.shape == V.shape
    assert len(Q.shape) == 2

    # seq length, inner dim of representations
    N, d = Q.shape

    # keys are consumed column-wise: (d, N)
    K_t = K.T

    # number of k/v vectors per tile, and number of k/v tiles
    bc = math.ceil(M / (4 * d))
    tc = math.ceil(N / bc)

    # number of q vectors per tile, and number of q tiles
    br = min(math.ceil(M / (4 * d)), d)
    tr = math.ceil(N / br)

    # output matrix plus per-row running softmax statistics
    O = torch.zeros_like(Q)
    m = torch.full((N,), -torch.inf, dtype=Q.dtype, device=Q.device)  # running row maxima
    l = torch.zeros((N,), dtype=Q.dtype, device=Q.device)             # running normalizers

    for i in range(tc):
        # "load" one k/v tile; slicing clamps at N, so the last tile may be short
        kv_rows = slice(i * bc, (i + 1) * bc)
        K_tile = K_t[:, kv_rows]
        V_tile = V[kv_rows, :]

        for j in range(tr):
            # "load" the matching q rows together with their running state
            q_rows = slice(j * br, (j + 1) * br)
            Q_tile = Q[q_rows, :]
            O_tile = O[q_rows, :]
            m_tile = m[q_rows].unsqueeze(-1)
            l_tile = l[q_rows].unsqueeze(-1)

            S = Q_tile @ K_tile

            # row-wise softmax statistics of this tile alone
            mt = S.max(dim=1, keepdim=True).values
            Pt = torch.exp(S - mt)
            lt = Pt.sum(dim=1, keepdim=True)

            # merge the tile statistics into the running ones
            m_new = torch.max(mt, m_tile)
            old_scale = torch.exp(m_tile - m_new)  # rescales what was accumulated so far
            new_scale = torch.exp(mt - m_new)      # rescales this tile's contribution
            l_new = (old_scale * l_tile) + (new_scale * lt)

            # un-normalize the old output, add the new contribution, re-normalize
            O_new = (l_tile * old_scale * O_tile + (new_scale * Pt) @ V_tile) / l_new

            # write back only after every read of the old state is done
            # (O_tile / m_tile / l_tile are views into O / m / l)
            O[q_rows, :] = O_new
            m[q_rows] = m_new.flatten()
            l[q_rows] = l_new.flatten()

    return O






In [6]:
from math import sqrt

def dumb_attn(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Naive, unscaled attention used as the reference result.

    Materializes the full score matrix ``Q @ K.T``, applies a row-wise softmax
    (``dim=1``) and multiplies the weights by ``V``. No ``1/sqrt(d)`` scaling is
    applied, so this is equivalent to
    ``F.scaled_dot_product_attention(Q, K, V, scale=1)``.

    Args:
        Q: query matrix, 2-D.
        K: key matrix, 2-D with the same inner dimension as ``Q``.
        V: value matrix with one row per row of ``K``.

    Returns:
        The attention output ``softmax(Q @ K.T, dim=1) @ V``.
    """
    return torch.softmax(Q @ K.T, dim=1) @ V


# Design notes for the kernel under test:
# - the head dimension has to be supported up to roughly d=128
# - a model dim of 4096 is divided ACROSS HEADS
# - we're giving up and only supporting fp16 (like flashattn)

# Alternative input sets kept for manual experiments (all disabled):
# Q = torch.rand((8196, 128), dtype=torch.float16, device='cuda')
# K = torch.rand((8196, 128), dtype=torch.float16, device='cuda')
# V = torch.rand((8196, 128), dtype=torch.float16, device='cuda')

# K[-1] = torch.zeros((2,))
# V[-1] = torch.zeros((2,))


# Q = torch.randint(1, 9, (32, 16), device='cuda').to(torch.float16)
# K = torch.randint(1, 9, (32, 16), device='cuda').to(torch.float16)
# V = torch.randint(1, 9, (32, 16), device='cuda').to(torch.float16)

# Active inputs: (64, 128) uniform random values created on the GPU, then cast to fp16.
# No random seed is set, so the values differ on every run.
Q = torch.rand((64, 128), device='cuda').to(torch.float16)
K = torch.rand((64, 128), device='cuda').to(torch.float16)
V = torch.rand((64, 128), device='cuda').to(torch.float16)

# Tiny hand-checkable example (disabled):
# Q = torch.tensor([[0., 1.],
#          [2., 3.]])

# K = torch.tensor([[0., 1.5],
#          [2., 3.]])

# V = torch.tensor([[1., 0.],
#          [0., 1.]])

# Display the three inputs as the cell output.
Q, K, V


(tensor([[0.5146, 0.6387, 0.1764,  ..., 0.8188, 0.5225, 0.0366],
         [0.8057, 0.5679, 0.5073,  ..., 0.1008, 0.3245, 0.8804],
         [0.2388, 0.1415, 0.2081,  ..., 0.2712, 0.3889, 0.1394],
         ...,
         [0.6426, 0.1974, 0.2449,  ..., 0.7329, 0.0385, 0.7354],
         [0.4426, 0.7339, 0.6108,  ..., 0.8257, 0.0258, 0.1665],
         [0.7559, 0.3784, 0.3472,  ..., 0.0353, 0.7681, 0.1801]],
        device='cuda:0', dtype=torch.float16),
 tensor([[0.3005, 0.1859, 0.6689,  ..., 0.2739, 0.4033, 0.1785],
         [0.9087, 0.4207, 0.3491,  ..., 0.4517, 0.8340, 0.4834],
         [0.3782, 0.1959, 0.6440,  ..., 0.1592, 0.6025, 0.6606],
         ...,
         [0.8877, 0.2012, 0.2751,  ..., 0.4204, 0.0923, 0.9102],
         [0.2922, 0.1763, 0.1906,  ..., 0.6074, 0.1990, 0.7222],
         [0.5889, 0.9321, 0.4036,  ..., 0.7461, 0.0770, 0.1331]],
        device='cuda:0', dtype=torch.float16),
 tensor([[0.9624, 0.7954, 0.6157,  ..., 0.9609, 0.1188, 0.7988],
         [0.4856, 0.3457, 0.234

In [14]:
from torch.utils.cpp_extension import load

# JIT-compile the HIP source with torch's C++ extension loader and bind the result
# to `module`; later cells call `module.flash_attn` and `module.matmul`.
module = load(
    name='m',
    sources=['rocm/flash_attention.hip',],  # path is relative to the notebook's working directory
    # The GPU architecture is hard-coded here; adjust it when building for other hardware.
    extra_cflags=['--offload-arch="gfx1100"'],
    build_directory='build',  # NOTE(review): presumably must already exist — confirm
    verbose=True  # echo the full build log into the cell output
)

/home/seb/Code/flash-attention/rocm/flash_attention.hip -> /home/seb/Code/flash-attention/rocm/flash_attention_hip.hip [ok]
Total number of unsupported CUDA function calls: 0


Total number of replaced kernel launches: 3


The input conditions for extension module m have changed. Bumping to version 4 and re-building as m_v4...
[92mSuccessfully preprocessed all matching files.[0m
Detected CUDA files, patching ldflags
Emitting ninja build file build/build.ninja...
Building extension module m_v4...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/2] /opt/rocm-6.1.3/bin/hipcc  -DWITH_HIP -DTORCH_EXTENSION_NAME=m_v4 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/TH -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THC -isystem /home/seb/Code/pyenvs/rocmenv/lib/python3.12/site-packages/torch/include/THH -isystem /opt/rocm-6.1.3/include -isystem /home/seb/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 --offload-arch="gfx1100" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -fno-gpu-rdc -c /home/seb/Code/flash-attention/rocm/flash_at

Loading extension module m_v4...


In [13]:
# Earlier comparison variants, kept for manual debugging:
# torch.isclose(dumb_attn(Q, K, V), flashattn(Q, K, V))
# dumb_attn(Q, K, V), flashattn(Q, K, V)

# F.scaled_dot_product_attention(Q, K, V, scale=1), flashattn(Q, K, V)
# module.flash_attn(Q, K, V)

# Compare the naive reference (r1) with the compiled extension kernel (r2) on the
# current Q, K, V and show the mean absolute element-wise difference.
torch.set_printoptions(profile='default', sci_mode=False)
r1 = dumb_attn(Q, K, V)
r2 = module.flash_attn(Q, K, V)
(r1-r2).abs().mean()
# r1, r2
# Q @ K.T


# # r1, dumb_attn(Q.float(), K.float(), V.float()), r2
# # r1, r2[:50]
# dumb_attn(Q, K, V), module.flash_attn(Q, K, V)

  Shared Memory per Block: 65536
  Shared Memory Banks: 65536
  Warp Size: 32
  Max Threads per Block: 1024
  Constant Memory: 2147483647
  L2 Cache: 6291456

tc=1, tr=2
Launching kernel with smem size 62464
Q block 0:
[[0.5146, 0.6387, 0.1764, 0.5830, 0.3484, 0.3323, 0.6870, 0.2629, 0.4407, 0.1721, 0.9561, 0.4558, 0.4758, 0.7612, 0.5225, 0.7583, 0.3865, 0.2295, 0.2673, 0.5986, 0.7153, 0.7095, 0.6191, 0.7905, 0.6577, 0.9014, 0.6147, 0.5645, 0.8750, 0.2246, 0.8423, 0.2944, 0.9399, 0.3923, 0.3494, 0.7969, 0.2783, 0.0282, 0.9370, 0.0115, 0.8818, 0.0003, 0.9697, 0.1086, 0.3425, 0.1771, 0.3479, 0.1914, 0.5107, 0.8896, 0.4192, 0.4849, 0.8159, 0.3359, 0.5107, 0.4189, 0.3450, 0.9062, 0.7769, 0.6997, 0.8047, 0.2130, 0.1022, 0.3484, 0.5469, 0.5210, 0.7729, 0.5400, 0.4697, 0.7485, 0.0077, 0.8872, 0.7451, 0.1366, 0.3970, 0.5757, 0.9756, 0.3250, 0.0341, 0.4883, 0.9106, 0.9502, 0.8540, 0.6953, 0.8789, 0.6431, 0.4707, 0.6616, 0.4363, 0.4280, 0.6860, 0.4858, 0.8433, 0.3560, 0.8682, 0.4180, 0.7969, 0.5

tensor(0.0001, device='cuda:0', dtype=torch.float16)

In [15]:

def test_flashattn():
    """Print how far ``module.flash_attn`` is from PyTorch's built-in attention.

    For each ``(N, d)`` problem size, fresh random fp16 inputs are created on the
    GPU, pushed through the extension kernel, and compared against
    ``F.scaled_dot_product_attention(Q, K, V, scale=1)`` using the mean absolute
    element-wise difference. Nothing is asserted; the differences are only printed.
    """
    def diff(r1, r2):
        # mean absolute element-wise difference between two results
        return (r1-r2).abs().mean()

    dims = [(48, 32), (64, 128), (1024, 128), (4096, 128)]
    for N, d in dims:
        # three independent draws, consumed in the order Q, K, V
        Q, K, V = (
            torch.rand((N, d), dtype=torch.float16, device='cuda')
            for _ in range(3)
        )

        res = module.flash_attn(Q, K, V)
        print(f'Difference on ({N}, {d}) tensors: {diff(res, F.scaled_dot_product_attention(Q, K, V, scale=1))}\n\n')

# Use the 'full' print profile for the duration of the test run, then restore the
# default profile so later cells print as before.
torch.set_printoptions(profile='full', sci_mode=False)
test_flashattn()
torch.set_printoptions(profile='default', sci_mode=False)

Difference on (48, 32) tensors: 0.0001538991928100586


Difference on (64, 128) tensors: 0.0005078315734863281


Difference on (1024, 128) tensors: 0.000492095947265625


Difference on (4096, 128) tensors: 0.0007801055908203125




In [41]:
# Alternative operands (disabled):
# A = torch.rand((16, 32), dtype=torch.float16, device='cuda')
# B = torch.rand((32, 16), dtype=torch.float16, device='cuda')

# Active operands: fp16 matrices whose entries are 0 or 1 (randint upper bound is
# exclusive); A is (16, 16) and B is (16, 32).
A = torch.randint(0, 2, (16, 16), dtype=torch.float16, device='cuda')
B = torch.randint(0, 2, (16, 32), dtype=torch.float16, device='cuda')

# A = torch.eye(16, dtype=torch.float16, device='cuda')
# B = torch.eye(16, dtype=torch.float16, device='cuda')

# A = torch.ones((16, 16), dtype=torch.float16, device='cuda')
# B = torch.zeros((16, 16), dtype=torch.float16, device='cuda')

# Display both operands as the cell output.
A, B

(tensor([[1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 1., 1., 1., 0., 0., 1.],
         [0., 0., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 0., 0.],
         [0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 0.],
         [1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1.],
         [0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 0., 1., 1., 1.],
         [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1.],
         [1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 0.],
         [1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 1., 1., 0.],
         [0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0.],
         [0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0.],
         [0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
         [1., 0., 0., 1., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 1., 1.],
         [1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0.],
         [1., 1., 0., 0.,

In [48]:
# print(module.matmul(A, B), (A @ B))
torch.set_printoptions(profile='full', sci_mode=False)
# Check the extension matmul against PyTorch's `A @ B`, then display the element-wise
# difference. The kernel is invoked twice, so its debug prints appear twice.
# NOTE(review): the third argument presumably flags B as already transposed — confirm
# against the extension source.
print(torch.allclose(module.matmul(A, B, False), A @ B))
module.matmul(A, B, False) - A @ B

Warp 0 computing tile 0 of matrix, coords (0, 0), will do 1 loops
Warp 0 loading from idx 0 of a and idx 0 of b
Warp 0 computing tile 1 of matrix, coords (0, 1), will do 1 loops
Warp 0 loading from idx 0 of a and idx 16 of b
True
Warp 0 computing tile 0 of matrix, coords (0, 0), will do 1 loops
Warp 0 loading from idx 0 of a and idx 0 of b
Warp 0 computing tile 1 of matrix, coords (0, 1), will do 1 loops
Warp 0 loading from idx 0 of a and idx 16 of b


tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 

In [45]:
# Zero-pad Q and K on the right (last dim, +12) and at the bottom (rows, +27 / +43).
# NOTE(review): these amounts, together with the recorded shapes [32, 16] and [16, 48],
# imply Q and K were (5, 4) when this cell ran. The Q/K defined earlier in this notebook
# are (64, 128), so this cell depends on kernel state that is not reproduced here — confirm.
Qp = F.pad(Q, (0, 16-4, 0, 32-5))
Kp = F.pad(K, (0, 16-4, 0, 48-5))

# Qp, Kp
print(Qp.shape, Kp.T.shape)
# Show the extension result next to PyTorch's result for visual comparison.
(module.matmul(Qp, Kp.T.contiguous(), False), (Qp @ Kp.T))
# torch.allclose(module.matmul(Qp, Kp.T, False), module.matmul(Qp, Kp, True))

torch.Size([32, 16]) torch.Size([16, 48])
Warp 0 computing tile 0 of matrix, coords (0, 0), will do 1 loops
Warp 0 loading from idx 0 of a and idx 0 of b
Warp 0 computing tile 1 of matrix, coords (0, 1), will do 1 loops
Warp 0 loading from idx 0 of a and idx 16 of b
Warp 0 computing tile 2 of matrix, coords (0, 2), will do 1 loops
Warp 0 loading from idx 0 of a and idx 32 of b
Warp 0 computing tile 3 of matrix, coords (1, 0), will do 1 loops
Warp 0 loading from idx 256 of a and idx 0 of b
Warp 0 computing tile 4 of matrix, coords (1, 1), will do 1 loops
Warp 0 loading from idx 256 of a and idx 16 of b
Warp 0 computing tile 5 of matrix, coords (1, 2), will do 1 loops
Warp 0 loading from idx 256 of a and idx 32 of b


(tensor([[138.,  45.,  47., 115.,  83.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [196.,  90.,  74., 156., 118.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [129.,  66.,  45., 106., 101.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [176

In [9]:
# Compare multiplying by an explicitly transposed, contiguous Kp (flag False) with
# passing Kp untransposed and the flag set to True.
# NOTE(review): the recorded output of this cell is False, i.e. the two calls disagreed
# on that run — presumably the flag path is (or was) incorrect; verify.
torch.allclose(module.matmul(Qp, Kp.T.contiguous(), False), module.matmul(Qp, Kp, True))

Warp 0 computing tile 0 of matrix, coords (0, 0), will do 1 loops
Warp 0 loading from idx 0 of a and idx 0 of b
Warp 0 computing tile 1 of matrix, coords (0, 1), will do 1 loops
Warp 0 loading from idx 0 of a and idx 16 of b
Warp 0 computing tile 2 of matrix, coords (0, 2), will do 1 loops
Warp 0 loading from idx 0 of a and idx 32 of b
Warp 0 computing tile 3 of matrix, coords (1, 0), will do 1 loops
Warp 0 loading from idx 256 of a and idx 0 of b
Warp 0 computing tile 4 of matrix, coords (1, 1), will do 1 loops
Warp 0 loading from idx 256 of a and idx 16 of b
Warp 0 computing tile 5 of matrix, coords (1, 2), will do 1 loops
Warp 0 loading from idx 256 of a and idx 32 of b
Warp 0 computing tile 0 of matrix, coords (0, 0), will do 1 loops
Warp 0 loading from idx 0 of a and idx 0 of b
Warp 0 computing tile 1 of matrix, coords (0, 1), will do 1 loops
Warp 0 loading from idx 0 of a and idx 16 of b
Warp 0 computing tile 2 of matrix, coords (0, 2), will do 1 loops
Warp 0 loading from idx 0 o

False

In [2]:
# module.matmul(B, Kp, False), module.matmul(B, Kp.T.contiguous(), False)
# Scratch arithmetic (evaluates to 32768).
# NOTE(review): its purpose is not evident from this notebook — confirm or delete.
8192*4

32768