In [75]:
import torch
import triton.language as tl
import triton

In [76]:
sm_scale = 0.2
Z = 64
H = 48
N_CTX = 1024
D_HEAD = 64
dtype=torch.float16

In [77]:
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)

In [78]:
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
    for h in range(H):
        p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)

In [79]:
tri_out = triton.ops.attention(q, k, v, sm_scale)

In [80]:
(ref_out - tri_out).abs().max()

tensor(0.0005, device='cuda:0', dtype=torch.float16)

In [81]:
tri_out.shape

torch.Size([64, 48, 1024, 64])

In [89]:
@triton.jit
def fa_kernel(
    Q, K, V, sm_scale,
    # L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * BLOCK_N * stride_kn)
        # qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # qk += tl.dot(q, k)
        qk = tl.dot(q, k)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        # l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(tl.float16)
        v = tl.load(v_ptrs + start_n * BLOCK_N * stride_vk)
        acc += tl.dot(p, v)
        # update m_i and l_i
        # l_prev = l_curr
        # m_prev = m_curr
        # update pointers
        # k_ptrs += BLOCK_N * stride_kn
        # v_ptrs += BLOCK_N * stride_vk
    # rematerialize offsets to save registers
    # start_m = tl.program_id(0)
    # offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    # l_ptrs = L + off_hz * N_CTX + offs_m
    # m_ptrs = M + off_hz * N_CTX + offs_m
    # tl.store(l_ptrs, l_prev)
    # tl.store(m_ptrs, m_prev)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)

In [90]:
def triton_fa(q, k, v, sm_scale):
    # only support for Ampere now
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8:
        raise RuntimeError("Flash attention currently only supported for compute capability < 80")
    BLOCK = 128
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # assert Lk in {16, 32, 64, 128}
    assert Lk in {64}  # TODO: fix other cases
    o = torch.empty_like(q)
    grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
    # L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    # m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    num_warps = 4 if Lk <= 64 else 8

    fa_kernel[grid](
        q, k, v, sm_scale,
        # L, m,
        o,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        q.shape[0], q.shape[1], q.shape[2],
        BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        BLOCK_DMODEL=Lk, num_warps=num_warps,
        num_stages=2,
    )

    return o

In [91]:
fa_out = triton_fa(q, k, v, sm_scale)

In [85]:
(ref_out - fa_out).abs().max()

tensor(1.0293, device='cuda:0', dtype=torch.float16)

# Benchmarking

In [86]:
from triton.testing import do_bench
from functools import partial

In [87]:
do_bench(partial(triton.ops.attention, q, k, v, sm_scale), warmup=100, rep=100)[0] * 1000 # us

4307.456016540527

In [93]:
do_bench(partial(triton_fa, q, k, v, sm_scale), warmup=100, rep=1000)[0] * 1000 # us

3777.535915374756