In [1]:
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor

# torch device corresponding to the device Triton's active driver is bound to.
DEVICE = triton.runtime.driver.active.get_active_torch_device()

def is_hip():
    """Return True when Triton's active driver targets the "hip" backend."""
    backend_name = triton.runtime.driver.active.get_current_target().backend
    return backend_name == "hip"

def is_cuda():
    """Return True when Triton's active driver targets the "cuda" backend."""
    backend_name = triton.runtime.driver.active.get_current_target().backend
    return backend_name == "cuda"

def supports_host_descriptor():
    """Return True on CUDA devices whose compute-capability major version is >= 9.

    On these targets the launcher builds host-side TensorDescriptor objects.
    """
    if not is_cuda():
        return False
    major_version = torch.cuda.get_device_capability()[0]
    return major_version >= 9

def is_blackwell():
    """Return True on CUDA devices whose compute-capability major version is 10."""
    if not is_cuda():
        return False
    major_version = torch.cuda.get_device_capability()[0]
    return major_version == 10

def is_hopper():
    """Return True on CUDA devices whose compute-capability major version is 9."""
    if not is_cuda():
        return False
    major_version = torch.cuda.get_device_capability()[0]
    return major_version == 9

def get_num_stages_options():
    """Return the candidate `num_stages` values swept by the autotuner.

    HIP targets get a single option. Every other target gets [2, 3, 4]
    regardless of host-descriptor support (the original had two identical
    branches for that case), so no device-capability query is needed.
    """
    if is_hip():
        return [1]
    return [2, 3, 4]

def _host_descriptor_pre_hook(nargs):
    """Write the real tile block shapes into host TensorDescriptors before a launch.

    Host-side descriptors are created with a placeholder block shape because the
    tile sizes are only known once the autotuner picks BLOCK_M / BLOCK_N for a
    config. This hook reads those values from ``nargs`` and patches each
    descriptor. It does nothing when the Q argument is not a TensorDescriptor
    (i.e. raw tensors were passed and the kernel builds descriptors itself).

    Descriptor arguments are looked up as ``desc_q``/``desc_k``/``desc_v``/``desc_o``
    and, failing that, as ``q_desc``/``k_desc``/``v_desc``/``o_desc``, so the hook
    works with either kernel parameter naming instead of raising KeyError.
    """
    def _desc(name):
        key = f"desc_{name}"
        return nargs[key] if key in nargs else nargs[f"{name}_desc"]

    BLOCK_M = nargs["BLOCK_M"]
    BLOCK_N = nargs["BLOCK_N"]
    HEAD_DIM = nargs["HEAD_DIM"]
    desc_q = _desc("q")
    if not isinstance(desc_q, TensorDescriptor):
        return
    desc_q.block_shape = [BLOCK_M, HEAD_DIM]
    # With FP8_OUTPUT the kernel addresses V transposed, so its tile is [HEAD_DIM, BLOCK_N].
    if nargs["FP8_OUTPUT"]:
        _desc("v").block_shape = [HEAD_DIM, BLOCK_N]
    else:
        _desc("v").block_shape = [BLOCK_N, HEAD_DIM]
    _desc("k").block_shape = [BLOCK_N, HEAD_DIM]
    _desc("o").block_shape = [BLOCK_M, HEAD_DIM]
    
# Autotuning search space: the cross product of BLOCK_M, BLOCK_N, num_stages and
# num_warps. Every config carries the pre-hook that patches host TensorDescriptor
# block shapes for its tile sizes. The num_stages options are queried once here
# rather than once per (BLOCK_M, BLOCK_N) pair inside the comprehension.
_NUM_STAGES_OPTIONS = get_num_stages_options()
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook)
    for BM in [64, 128]
    for BN in [32, 64, 128]
    for s in _NUM_STAGES_OPTIONS
    for w in [4, 8]
]

def keep(conf):
    """Static filter applied to the autotune configs before they are registered.

    Rejects configs whose BLOCK_M is not a multiple of BLOCK_N and, on Hopper,
    8-warp configs whose BLOCK_M * BLOCK_N tile is smaller than 128 * 128.
    Uses the shared ``is_hopper`` helper instead of re-deriving the check.
    """
    block_m = conf.kwargs["BLOCK_M"]
    block_n = conf.kwargs["BLOCK_N"]
    if block_m % block_n != 0:
        return False
    small_tile = block_m * block_n < 128 * 128
    return not (is_hopper() and small_tile and conf.num_warps == 8)
    
def prune_invalid_configs(configs, named_args, **kwargs):
    """Early-prune autotune configs whose tiles do not evenly cover the sequence.

    The kernel addresses Q/K/V/O through one flattened [Bs * NH * S, HEAD_DIM]
    view, so a tile that runs past the end of one (batch, head) sequence reads
    rows of the next one (and, for O, overwrites them). A config is therefore
    kept only if BLOCK_M <= S and S is a multiple of both BLOCK_M and BLOCK_N.
    A block size missing from a config (treated as 0) is left unconstrained.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments by name (unused).
        **kwargs: launch keyword arguments; must contain ``S`` (sequence length).
    Returns:
        The surviving configs, in their original order.
    Raises:
        ValueError: if configs were given but none can tile S.
    """
    S = kwargs["S"]

    def _fits(conf):
        block_m = conf.kwargs.get("BLOCK_M", 0)
        block_n = conf.kwargs.get("BLOCK_N", 0)
        if block_m > S:
            return False
        return all(S % block == 0 for block in (block_m, block_n) if block)

    pruned = [conf for conf in configs if _fits(conf)]
    if configs and not pruned:
        raise ValueError(
            f"no autotune config can tile sequence length S={S}; "
            "S must be a multiple of an available BLOCK_M (and BLOCK_N)"
        )
    return pruned

In [2]:
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    desc_k, desc_v,  #
                    offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr):
    # Accumulate one query tile's attention over a range of key/value tiles with
    # the online-softmax recurrence, returning the updated (acc, l_i, m_i):
    #   acc - unnormalised output accumulator, rescaled whenever the max grows
    #   l_i - running per-row sum of exp2(score - max)
    #   m_i - running per-row max of the scaled scores
    # STAGE selects the key range for query tile `start_m`:
    #   1 -> key tiles strictly before the query tile (no mask needed)
    #   2 -> the diagonal tile, masked so key index <= query index
    #   3 -> all N_CTX keys (non-causal)
    # `offset_y` is the first row of the current (batch, head) in the flattened
    # K/V view. `qk_scale` is used with exp2, so it is expected to already
    # include the 1/ln(2) factor (the calling kernel multiplies it in).
    # IS_HOPPER is accepted for signature compatibility and unused here.
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    else:
        lo, hi = 0, N_CTX

    # Starting rows of the first K and V tiles. For fp8e5, V is read through a
    # transposed [HEAD_DIM, ...] descriptor (row stride N_CTX, as the kernel
    # builds it), so the (batch, head) base sits on the column axis scaled by HEAD_DIM.
    offsetk_y = offset_y + lo
    if dtype == tl.float8e5:
        offsetv_y = offset_y * HEAD_DIM + lo
    else:
        offsetv_y = offset_y + lo

    # tl.range (rather than the builtin range) is what carries the warp_specialize hint.
    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- scaled scores and new running max --
        k = desc_k.load([offsetk_y, 0]).T
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
        else:
            qk = qk * qk_scale
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]

        p = tl.math.exp2(qk)
        # -- correction factor for state accumulated under the old max --
        alpha = tl.math.exp2(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        acc = acc * alpha[:, None]

        # -- accumulate P @ V --
        if dtype == tl.float8e5:
            v = desc_v.load([0, offsetv_y]).T
        else:
            v = desc_v.load([offsetv_y, 0])
        p = p.to(dtype)
        # note that this non transposed v for FP8 is only supported on Blackwell
        acc = tl.dot(p, v, acc)

        # update m_i and l_i last to reduce register pressure
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        offsetk_y += BLOCK_N
        offsetv_y += BLOCK_N
    return acc, l_i, m_i


@triton.jit
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
    # Return `desc_or_ptr` unchanged if it is already a tensor descriptor (e.g. a
    # host-side TensorDescriptor passed by the launcher); otherwise build a
    # device-side descriptor over it with the given shape, strides and block shape.
    # NOTE(review): the isinstance check appears to be resolved while the kernel is
    # traced, so each specialization keeps only one branch — confirm against Triton.
    if isinstance(desc_or_ptr, tl.tensor_descriptor):
        return desc_or_ptr
    else:
        return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)


@triton.autotune(configs=list(filter(keep, configs)), key=["S", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
                 prune_configs_by={'early_config_prune': prune_invalid_configs})
@triton.jit
def _attn_fwd_kernel(
    sm_scale,
    M,
    Bs,
    NH,
    desc_q,
    desc_k,
    desc_v,
    desc_o,
    S,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    FP8_OUTPUT: tl.constexpr,
    STAGE: tl.constexpr,
    warp_specialize: tl.constexpr,
    IS_HOPPER: tl.constexpr,
):
    # Fused attention forward. One program handles one BLOCK_M-row query tile
    # (grid axis 0) of one flattened (batch, head) pair (grid axis 1).
    # Q, K, O (and V unless FP8_OUTPUT) are addressed as a row-major
    # [Bs * NH * S, HEAD_DIM] matrix; with FP8_OUTPUT, V is addressed through a
    # transposed [HEAD_DIM, ...] view with row stride S. The desc_* arguments may
    # be host TensorDescriptors (block shapes patched by the autotune pre-hook)
    # or raw tensors, in which case device-side descriptors are built here.
    # STAGE == 1 runs one unmasked pass over all keys; STAGE == 3 runs the
    # off-diagonal causal pass followed by the masked diagonal pass.
    # Writes O and, into M, the per-row base-2 log-sum-exp of the scaled scores
    # (M assumed to be a contiguous float32 (Bs, NH, S) buffer — as allocated by the launcher).
    dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
    tl.static_assert(BLOCK_N <= HEAD_DIM)

    start_m = tl.program_id(axis=0)
    offset_bsh = tl.program_id(axis=1)

    # Row-major [y_dim, HEAD_DIM] view: the row stride is HEAD_DIM.
    y_dim = Bs * NH * S
    desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                     block_shape=[BLOCK_M, HEAD_DIM])
    desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                     block_shape=[BLOCK_N, HEAD_DIM])
    desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                     block_shape=[BLOCK_M, HEAD_DIM])
    if FP8_OUTPUT:
        desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim], strides=[S, 1],
                                         block_shape=[HEAD_DIM, BLOCK_N])
    else:
        desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                         block_shape=[BLOCK_N, HEAD_DIM])

    # First row of this (batch, head) in the flattened view, and of this query tile.
    offset_y = offset_bsh * S
    qo_offset_y = offset_y + start_m * BLOCK_M

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # Online-softmax state: running row max, running row sum, output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # Fold 1/ln(2) into the scale so the inner loop can use exp2.
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)

    q = desc_q.load([qo_offset_y, 0])

    # stage 1: off-band keys (causal) or all keys (non-causal)
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
                                        desc_k, desc_v,  #
                                        offset_y, dtype, start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, S,  #
                                        warp_specialize, IS_HOPPER)
    # stage 2: on-band (diagonal) keys with causal mask
    if STAGE & 2:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
                                        desc_k, desc_v,  #
                                        offset_y, dtype, start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        2, offs_m, offs_n, S,  #
                                        warp_specialize, IS_HOPPER)

    # epilogue: per-row log-sum-exp into M, normalised output into O
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + offset_bsh * S + offs_m
    tl.store(m_ptrs, m_i, mask=offs_m < S)
    desc_o.store([qo_offset_y, 0], acc.to(dtype))
    

In [3]:
def attn_fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float, warp_specialize: bool = False):
    """
    q.shape: (Bs, NHead, SeqLen, HeadDim)
    """
    assert q.dim() == 4
    Bs, NH, S, HEAD_DIM = q.shape
    HEAD_DIM_K, HEAD_DIM_V = k.shape[-1], v.shape[-1]
    assert HEAD_DIM == HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}

    o = torch.empty_like(q)
    stage = 3 if causal else 1

    M = torch.empty((Bs, NH, S), device=q.device, dtype=torch.float32)
    if supports_host_descriptor() and not (is_hopper() and warp_specialize):
        y_dim = Bs * NH * S
        dummy_shape = [1, 1]
        q_desc = TensorDescriptor(q, shape=[y_dim, HEAD_DIM], strides=[y_dim, 1], block_shape=dummy_shape)
        k_desc = TensorDescriptor(k, shape=[y_dim, HEAD_DIM], strides=[y_dim, 1], block_shape=dummy_shape)
        o_desc = TensorDescriptor(o, shape=[y_dim, HEAD_DIM], strides=[y_dim, 1], block_shape=dummy_shape)
        if q.dtype == torch.float8_e5m2:
            v_desc = TensorDescriptor(v, shape=[HEAD_DIM, y_dim], strides=[S, 1], block_shape=dummy_shape)
        else:
            v_desc = TensorDescriptor(v, shape=[y_dim, HEAD_DIM], strides=[y_dim, 1], block_shape=dummy_shape)
    else:
        q_desc = q
        k_desc = k
        v_desc = v
        o_desc = o

    def alloc_fn(size: int, align: int, _):
        return torch.empty(size, dtype=torch.int8, device="cuda")
    triton.set_allocator(alloc_fn)
    
    grid = lambda META: (triton.cdiv(S, META["BLOCK_M"]), Bs * NH, 1)
    _attn_fwd_kernel[grid](
        sm_scale,
        M,
        Bs,
        NH,
        q_desc,
        k_desc,
        v_desc,
        o_desc,
        S=S,
        HEAD_DIM=HEAD_DIM,
        FP8_OUTPUT=q.dtype == torch.float8_e5m2,
        STAGE=stage,
        warp_specialize=warp_specialize,  #
        IS_HOPPER=is_hopper(),  #
    )
    return o