384

In [7]:
# Probe: does the FLASH_ATTENTION SDPA backend accept q/k/v that are strided
# views of one packed qkv tensor (no .contiguous() copies)?
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn import functional as F
import torch
# batch, tokens, model dim, heads, per-head dim (D == H * d)
B, N, D, H, d =  32,  8,  384, 6,  64
qkv = torch.randn(B, N, D * 3) 
# (B, N, 3*D) -> (B, N, 3, H, d) -> (3, B, H, N, d); permute only changes strides.
qkv = qkv.reshape(B, N, 3, H, d).permute(2, 0, 3, 1, 4)
# Split along the leading axis of size 3; each of q, k, v has shape (B, H, N, d).
q, k, v = qkv.unbind(0)

q.is_contiguous()        # False (view)
q.stride(-1) == 1        # True  ✅ last-dim contiguous

# Restrict SDPA to the flash backend so an unsupported layout surfaces as an error.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    F.scaled_dot_product_attention(q, k, v)  # raises if layout not supported by FlashAttention


In [1]:
import sys, os, torch
from getpass import getpass

# SECURITY: never commit credentials to a notebook. The key is taken from the
# environment when present and prompted for otherwise. Any key that was
# previously stored in this notebook should be considered leaked and rotated.
if not os.environ.get("COMET_API_KEY"):
    os.environ["COMET_API_KEY"] = getpass("Comet API key: ")

# The fork has to be on sys.path before `train` is imported from it.
sys.path.insert(0, "/notebooks/pytorch-image-models")  # your fork
import train as timm_train

# Only allow the flash SDPA backend; the memory-efficient and math backends are disabled.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

# NOTE(review): the allow_tf32 assignment below may override the "medium"
# matmul precision requested here — confirm against the PyTorch docs.
torch.set_float32_matmul_precision("medium")
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# TorchDynamo settings used when the model is compiled.
dynamo_config = torch._dynamo.config
dynamo_config.compiled_autograd = True
dynamo_config.capture_scalar_outputs = False
dynamo_config.cache_size_limit = 512


# Triton

In [1]:

import torch
from torch import Tensor
import triton
import triton.language as tl
from torch import nn
from torch.nn import functional as F

# Autotune search space shared by the kernels below. Each list currently holds
# a single value, so only the block sizes are actually swept.
GROUP_NM_SWEEP = [4]  # candidate GROUP_M / GROUP_N values (tile-swizzle group size)
NUM_STAGES_SWEEP = [7]  # candidate num_stages values
NUM_WARPS_SWEEP = [8]  # candidate num_warps values
# Argument names handed to triton.autotune as `key`.
# NOTE(review): no kernel in this file declares a BATCH_SIZE argument — confirm
# how the installed Triton version treats key names that are not kernel arguments.
KEY_CACHE = ["BATCH_SIZE", "NUM_HEADS", "SEQ_LEN", "HEAD_DIM"]

def _sdpa_comp_dtype(x: torch.Tensor) -> torch.dtype:
    return torch.get_autocast_dtype('cuda') if torch.is_autocast_enabled() else x.dtype

def _triton_compute_dtype(dtype: torch.dtype):
    if dtype is torch.float16:
        return tl.float16
    if dtype is torch.bfloat16:
        return tl.bfloat16
    if dtype is torch.float32:
        return tl.float32
    raise ValueError(f"Unsupported compute dtype for Triton SDPA: {dtype}")

@triton.jit
def _attn_fwd_inner(
    O_block, l_i, m_i, Q_block,
    K_block_ptr, V_block_ptr,
    softmax_scale: tl.constexpr, BLOCK_KV: tl.constexpr,
    SEQ_LEN: tl.constexpr, DTYPE: tl.constexpr,
):
    # Online-softmax accumulation for one query tile.
    # Walks the key/value sequence in BLOCK_KV-wide steps while maintaining
    #   m_i     : running row-wise max of the scaled scores
    #   l_i     : running row-wise sum of exp(score - m_i)
    #   O_block : un-normalised output accumulator
    # and returns (O_block / l_i, l_i, m_i).
    # K_block_ptr is advanced along its second axis and V_block_ptr along its
    # first, i.e. K is expected as a (HEAD_DIM, SEQ_LEN) view (see _attn_fwd).

    # Scale Q once up front instead of scaling every score tile.
    s = tl.full([1], softmax_scale, dtype=DTYPE)
    Q_block = Q_block * s
    offs_kv = tl.arange(0, BLOCK_KV)
    for start_kv in range(0, SEQ_LEN, BLOCK_KV):
        K_block = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")
        S = tl.dot(Q_block, K_block) 

        # Columns past SEQ_LEN were zero-padded by the load; force them to
        # -inf so they get zero weight after the exponential.
        kv_idx  = start_kv + offs_kv
        kv_valid = kv_idx < SEQ_LEN
        S = tl.where(kv_valid[None, :], S, -float("inf"))

        # New running max and the probabilities relative to it.
        m_ij = tl.maximum(m_i, tl.max(S, axis=1))
        P_block = tl.exp(S - m_ij[:, None])
        l_ij = tl.sum(P_block, axis=1)

        # alpha rescales everything accumulated under the previous max.
        alpha = tl.exp(m_i - m_ij)
        l_i = l_i * alpha + l_ij

        V_block = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option="zero")
        # Both operands of the second matmul are cast to the compute dtype.
        P_block = P_block.to(DTYPE)
        V_block = V_block.to(DTYPE)

        O_block = O_block * alpha[:, None]
        O_block = tl.dot(P_block, V_block, O_block)

        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_KV, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_KV))
    
    # Final softmax normalisation.
    O_block = O_block / l_i[:, None]
    return O_block, l_i, m_i

# Forward attention kernel: one program per (query tile, batch*head) pair.
# Autotuned over BLOCK_Q x BLOCK_KV (plus the single-valued sweeps above).
@triton.autotune(
    [
        triton.Config(
            {"BLOCK_Q": BLOCK_Q, "BLOCK_KV": BLOCK_KV, "GROUP_M": GROUP_M},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_Q in [64, 128]
        for BLOCK_KV in [32, 64]
        for GROUP_M in GROUP_NM_SWEEP
        for num_stages in NUM_STAGES_SWEEP
        for num_warps in NUM_WARPS_SWEEP
    ],
    key=KEY_CACHE,
)
@triton.jit
def _attn_fwd(
    Q, K, V, M, O,
    # Q strides
    sqb, sqh, sqs, sqd,
    # K strides
    skb, skh, sks, skd,
    # V strides
    svb, svh, svs, svd,
    # O strides
    sob, soh, sos, sod,
    # compile-time problem sizes and tuning parameters
    NUM_HEADS: tl.constexpr, SEQ_LEN: tl.constexpr, HEAD_DIM: tl.constexpr,
    softmax_scale:tl.constexpr, BLOCK_Q: tl.constexpr, BLOCK_KV: tl.constexpr, 
    DTYPE: tl.constexpr, GROUP_M: tl.constexpr,
):
    # Writes O (attention output) and M (per-query log-sum-exp of the scaled
    # scores, consumed by the backward kernels).
    # Strides are passed per tensor as (batch, head, seq, head_dim), so the
    # inputs do not have to be contiguous.
    # NOTE(review): the reason for requiring BLOCK_KV <= HEAD_DIM is not evident
    # from this kernel — confirm whether the guard is still needed.
    tl.static_assert(BLOCK_KV <= HEAD_DIM)

    # --- program ids ---
    pid_m  = tl.program_id(0)
    pid_bh = tl.program_id(1)

    # Tile swizzle: inside each group of up to GROUP_M consecutive query tiles
    # the tile index is rotated by pid_bh, so different (batch, head) programs
    # visit the tiles of a group in a different order. The mapping is a
    # permutation within the group, so every tile is still processed once.
    num_tiles_m   = tl.cdiv(SEQ_LEN, BLOCK_Q)                       # ceil_div
    group_id      = pid_m // GROUP_M
    tiles_in_this = tl.minimum(GROUP_M, num_tiles_m - group_id*GROUP_M)

    m_in_grp      = pid_m - group_id*GROUP_M                        # 0..GROUP_M-1
    m_in_grp_eff  = m_in_grp % tiles_in_this                        # clamp to tail size
    rot           = pid_bh % tiles_in_this
    m_swizzled    = group_id*GROUP_M + ((m_in_grp_eff + rot) % tiles_in_this)

    start_q       = m_swizzled * BLOCK_Q
    if start_q >= SEQ_LEN:
        return

    # Unpack the packed (batch, head) program index.
    b = pid_bh // NUM_HEADS
    h  = pid_bh %  NUM_HEADS

    # 64-bit base offsets of this (batch, head) slice in each tensor.
    off_bh_k  = (b * skb   + h * skh  ).to(tl.int64)
    off_bh_v  = (b * svb   + h * svh  ).to(tl.int64)
    off_bh_q  = (b * sqb   + h * sqh  ).to(tl.int64)
    off_bh_o = (b * sob   + h * soh  ).to(tl.int64)
    
    # --- block pointers ---
    # K is addressed as a (HEAD_DIM, SEQ_LEN) view (strides swapped) so that
    # Q @ K needs no explicit transpose.
    Q_block_ptr = tl.make_block_ptr(
        Q + off_bh_q, (SEQ_LEN, HEAD_DIM), (sqs, sqd), (start_q, 0), (BLOCK_Q, HEAD_DIM), (1, 0)
    )
    V_block_ptr = tl.make_block_ptr(
        V + off_bh_v, (SEQ_LEN, HEAD_DIM), (svs, svd), (0, 0), (BLOCK_KV, HEAD_DIM), (1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        K + off_bh_k, (HEAD_DIM, SEQ_LEN), (skd, sks), (0, 0), (HEAD_DIM, BLOCK_KV), (0, 1)
    )
    O_block_ptr = tl.make_block_ptr(
        O + off_bh_o, (SEQ_LEN, HEAD_DIM), (sos, sod), (start_q, 0), (BLOCK_Q, HEAD_DIM), (1, 0)
    )

    # --- per-row running stats + output tile ---
    # l_i starts at 1; with m_i = -inf the first rescale factor in the inner
    # loop is exp(-inf) = 0, so this initial value never reaches the result.
    m_i = tl.full((BLOCK_Q,), -float("inf"), dtype=tl.float32)
    l_i = tl.full((BLOCK_Q,),  1,          dtype=tl.float32)
    O_block = tl.zeros([BLOCK_Q, HEAD_DIM], dtype=tl.float32)
    Q_block = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")

    # --- inner loop over KV tiles (online softmax) ---
    O_block, l_i, m_i = _attn_fwd_inner(
        O_block, l_i, m_i, Q_block,
        K_block_ptr, V_block_ptr, softmax_scale,
        BLOCK_KV, SEQ_LEN, DTYPE
    )

    # --- write back: store log-sum-exp (for bwd) and O ---
    offs_q  = start_q + tl.arange(0, BLOCK_Q)
    # max + log(sum) is the log-sum-exp; the 1e-20 keeps log() away from zero.
    m_i += tl.math.log(l_i + 1e-20)
    # M is indexed as a dense (batch*head, SEQ_LEN) array.
    m_ptrs = M + pid_bh * SEQ_LEN + offs_q
    tl.store(m_ptrs, m_i, mask=offs_q < SEQ_LEN)
    tl.store(O_block_ptr, O_block.to(O.type.element_ty), boundary_check=(0, 1))
# Backward pre-pass: D[bh, q] = sum over head_dim of dO[bh, q, :] * O[bh, q, :].
@triton.autotune(
    [triton.Config({"BLOCK_Q": bq}, num_stages=ns, num_warps=nw)
     for bq in [32, 64, 128]
     for ns in NUM_STAGES_SWEEP
     for nw in NUM_WARPS_SWEEP],
    key=KEY_CACHE,
)
@triton.jit
def _attn_bwd_preprocess(
    O, dO, D,
    sOb, sOh, sOs, sOd,          # O strides
    sdb, sdh, sds, sdd,          # dO strides
    NUM_HEADS: tl.constexpr, SEQ_LEN: tl.constexpr,
    BLOCK_Q: tl.constexpr, HEAD_DIM: tl.constexpr,
):
    # One program handles one BLOCK_Q-row tile of one (batch, head) slice.
    # D is written as a dense (batch*head, SEQ_LEN) float array.
    pid_q  = tl.program_id(0)                          # Q-tile id
    pid_bh = tl.program_id(1)                          # packed (batch, head)
    start_q = pid_q * BLOCK_Q
    if start_q >= SEQ_LEN:
        return

    b = pid_bh // NUM_HEADS
    h = pid_bh %  NUM_HEADS
    # 64-bit base offsets of this (batch, head) slice.
    off_bh_O  = (b * sOb  + h * sOh ).to(tl.int64)
    off_bh_dO = (b * sdb  + h * sdh ).to(tl.int64)

    # use block_ptr so arbitrary strides are OK
    O_blk = tl.make_block_ptr(
        O + off_bh_O, (SEQ_LEN, HEAD_DIM), (sOs, sOd),
        (start_q, 0), (BLOCK_Q, HEAD_DIM), (1, 0)
    )
    dO_blk = tl.make_block_ptr(
        dO + off_bh_dO, (SEQ_LEN, HEAD_DIM), (sds, sdd),
        (start_q, 0), (BLOCK_Q, HEAD_DIM), (1, 0)
    )

    # Out-of-range rows are zero-padded and the product is reduced in fp32.
    O_block  = tl.load(O_blk,  boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    dO_block = tl.load(dO_blk, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    D_block  = tl.sum(dO_block * O_block, axis=1)

    # Masked store so the tail tile does not write past SEQ_LEN.
    offs_q = start_q + tl.arange(0, BLOCK_Q)
    tl.store(D + pid_bh * SEQ_LEN + offs_q, D_block, mask=offs_q < SEQ_LEN)


# Backward kernel for dK and dV: each program owns one KV tile of one
# (batch, head) slice and loops over every query tile.
@triton.autotune(
    [
        triton.Config(
            {"BLOCK_Q": BLOCK_Q, "BLOCK_KV": BLOCK_KV, "GROUP_N": GROUP_N},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_Q in [32, 64]
        for BLOCK_KV in [64, 128]
        for GROUP_N in GROUP_NM_SWEEP
        for num_stages in NUM_STAGES_SWEEP
        for num_warps in NUM_WARPS_SWEEP
    ],
    key=KEY_CACHE,
)
@triton.jit
def _attn_bwd_dk_dv(
    Q, K, V, dO, dK, dV, M, D,
    # Q strides
    sqb, sqh, sqs, sqd,
    # K strides
    skb, skh, sks, skd,
    # V strides
    svb, svh, svs, svd,
    # dO strides
    sob, soh, sos, sod,
    # dK strides
    s_dkb, s_dkh, s_dks, s_dkd,
    # dV strides
    s_dvb, s_dvh, s_dvs, s_dvd,
    NUM_HEADS: tl.constexpr, SEQ_LEN: tl.constexpr,
    BLOCK_Q: tl.constexpr, BLOCK_KV: tl.constexpr, softmax_scale: tl.constexpr,
    HEAD_DIM: tl.constexpr, DTYPE: tl.constexpr, GROUP_N: tl.constexpr
):
    # M holds the per-query log-sum-exp written by _attn_fwd and D the row sums
    # written by _attn_bwd_preprocess; both are indexed as dense
    # (batch*head, SEQ_LEN) arrays.
    # --- program ids ---
    pid_kv = tl.program_id(0)                 # which KV block
    pid_bh = tl.program_id(1)                 # packed (batch, head)
    b = pid_bh // NUM_HEADS
    h = pid_bh %  NUM_HEADS

    # --- base offsets for this (batch, head) slice ---
    off_bh_seq = (pid_bh * SEQ_LEN).to(tl.int64)
    M  += off_bh_seq
    D  += off_bh_seq

    # Tile swizzle over KV tiles: within each group of up to GROUP_N tiles the
    # tile index is rotated by pid_bh (a permutation inside the group).
    num_tiles_kv = tl.cdiv(SEQ_LEN, BLOCK_KV)
    group_id     = pid_kv // GROUP_N
    group_start  = group_id * GROUP_N
    if group_start >= num_tiles_kv:
        return
    
    tiles_in_this = tl.minimum(GROUP_N, num_tiles_kv - group_start)
    kv_in_grp     = pid_kv - group_start
    kv_eff        = kv_in_grp % tiles_in_this
    rot           = pid_bh % tiles_in_this
    kv_tile_id    = group_start + ((kv_eff + rot) % tiles_in_this)

    start_kv = kv_tile_id * BLOCK_KV
    if start_kv >= SEQ_LEN:
        return

    # 64-bit base offsets of this (batch, head) slice in each tensor.
    off_bh_k  = (b * skb   + h * skh  ).to(tl.int64)
    off_bh_v  = (b * svb   + h * svh  ).to(tl.int64)
    off_bh_dk = (b * s_dkb + h * s_dkh).to(tl.int64)
    off_bh_dv = (b * s_dvb + h * s_dvh).to(tl.int64)
    off_bh_q  = (b * sqb   + h * sqh  ).to(tl.int64)
    off_bh_do = (b * sob   + h * soh  ).to(tl.int64)
    
    # K/V/dK/dV pointers are fixed at this program's KV tile; Q (viewed
    # transposed as (HEAD_DIM, SEQ_LEN)) and dO start at row 0 and are advanced
    # inside the loop.
    K_blk = tl.make_block_ptr( 
        K + off_bh_k, (SEQ_LEN, HEAD_DIM), (sks, skd),(start_kv, 0),(BLOCK_KV, HEAD_DIM),(1, 0)
    ) # base,        shape,               strides,                  offsets,       block_shape,          order
    V_blk = tl.make_block_ptr( 
        V + off_bh_v,(SEQ_LEN, HEAD_DIM),(svs, svd),(start_kv, 0),(BLOCK_KV, HEAD_DIM),(1, 0)
    )
    dK_blk = tl.make_block_ptr( 
        dK + off_bh_dk, (SEQ_LEN, HEAD_DIM), (s_dks, s_dkd), (start_kv, 0), (BLOCK_KV, HEAD_DIM), (1, 0)
    )
    dV_blk = tl.make_block_ptr( 
        dV + off_bh_dv,(SEQ_LEN, HEAD_DIM),(s_dvs, s_dvd),(start_kv, 0),(BLOCK_KV, HEAD_DIM),(1, 0)
    )
    Q_T_blk = tl.make_block_ptr( 
        Q + off_bh_q,(HEAD_DIM, SEQ_LEN),(sqd, sqs),(0, 0),(HEAD_DIM, BLOCK_Q),(0, 1)
    )
    dO_blk = tl.make_block_ptr( 
        dO + off_bh_do,(SEQ_LEN, HEAD_DIM),(sos, sod),(0, 0),(BLOCK_Q, HEAD_DIM),(1, 0)
    )

    # fp32 accumulators; K is pre-scaled by softmax_scale once, outside the loop.
    dV_acc = tl.zeros((BLOCK_KV, HEAD_DIM), dtype=tl.float32)
    dK_acc = tl.zeros((BLOCK_KV, HEAD_DIM), dtype=tl.float32)
    s = tl.full([1], softmax_scale, dtype=DTYPE)
    K_block = tl.load(K_blk, boundary_check=(0, 1), padding_option="zero").to(DTYPE) * s
    V_block = tl.load(V_blk, boundary_check=(0, 1), padding_option="zero").to(DTYPE)
    offs_kv  = start_kv + tl.arange(0, BLOCK_KV)
    
    # Loop over Q tiles
    num_steps = tl.cdiv(SEQ_LEN, BLOCK_Q)
    for qi in range(0, num_steps):
        qT_block = tl.load(Q_T_blk, boundary_check=(0, 1), padding_option="zero").to(DTYPE)
        dO_block = tl.load(dO_blk, boundary_check=(0, 1), padding_option="zero").to(DTYPE)
        
        start_q = qi * BLOCK_Q
        offs_q  = start_q + tl.arange(0, BLOCK_Q)
        # Out-of-range query columns read 0 here; their dO rows are zero-padded
        # as well, so they add nothing to dV_acc or dK_acc.
        m  = tl.load(M + offs_q, mask=offs_q < SEQ_LEN, other=0.0).to(tl.float32)
        Di = tl.load(D + offs_q, mask=offs_q < SEQ_LEN, other=0.0).to(tl.float32)

        # Transposed scaled scores, shape (BLOCK_KV, BLOCK_Q).
        QK_T = tl.dot(K_block, qT_block) 
        # Padded KV rows are pushed to -inf so their probabilities become 0.
        kv_valid = offs_kv < SEQ_LEN
        QK_T = tl.where(kv_valid[:, None], QK_T, -float("inf"))
        # Recompute the softmax probabilities from the stored log-sum-exp.
        P_T = tl.exp(QK_T.to(tl.float32) - m[None, :])

        # --- dV += Pᵀ @ dO  (match operand dtypes) ---
        dV_acc += tl.dot(P_T.to(DTYPE), dO_block)

        # --- dpᵀ = V @ dOᵀ, then dSᵀ = Pᵀ * (dpᵀ - Di) ---
        dpT = tl.dot(V_block, tl.trans(dO_block)).to(tl.float32)
        dS_T = (P_T * (dpT - Di[None, :])).to(DTYPE)
        dK_acc = tl.dot(dS_T, tl.trans(qT_block), dK_acc)

        Q_T_blk = tl.advance(Q_T_blk, (0, BLOCK_Q))
        dO_blk = tl.advance(dO_blk, (BLOCK_Q, 0))

    # Tail-safe stores
    # The softmax scale is applied to dK once, after accumulation.
    dK_acc *= s 
    tl.store(dV_blk, dV_acc.to(dV.type.element_ty), boundary_check=(0, 1))
    tl.store(dK_blk, dK_acc.to(dK.type.element_ty), boundary_check=(0, 1))
    

# Backward kernel for dQ: each program owns one query tile of one
# (batch, head) slice and loops over every KV tile.
@triton.autotune(
    [
        triton.Config(
            {"BLOCK_Q": BLOCK_Q, "BLOCK_KV": BLOCK_KV, "GROUP_N": GROUP_N},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_Q in [64, 128]
        for BLOCK_KV in [32, 64]
        for GROUP_N in GROUP_NM_SWEEP
        for num_stages in NUM_STAGES_SWEEP
        for num_warps in NUM_WARPS_SWEEP
    ],
    key=KEY_CACHE,
)
@triton.jit
def _attn_bwd_dq(
    Q, K, V, dO, dQ, M, D,
    # Q strides
    sqb, sqh, sqs, sqd,
    # K strides
    skb, skh, sks, skd,
    # V strides
    svb, svh, svs, svd,
    # dO strides
    sob, soh, sos, sod,
    # dQ strides
    s_dqb, s_dqh, s_dqs, s_dqd,
    NUM_HEADS: tl.constexpr , SEQ_LEN: tl.constexpr,
    BLOCK_Q: tl.constexpr, BLOCK_KV: tl.constexpr, 
    HEAD_DIM: tl.constexpr, DTYPE: tl.constexpr,
    GROUP_N: tl.constexpr, softmax_scale: tl.constexpr,
):
    # M holds the per-query log-sum-exp written by _attn_fwd and D the row sums
    # written by _attn_bwd_preprocess; both are indexed as dense
    # (batch*head, SEQ_LEN) arrays.
    pid_bh = tl.program_id(1)
    b = pid_bh // NUM_HEADS
    h = pid_bh %  NUM_HEADS
    
    off_bh_seq = (pid_bh * SEQ_LEN).to(tl.int64)
    M += off_bh_seq
    D += off_bh_seq

    # --- GROUP_N swizzle over Q tiles (tail-safe) ---
    pid_q = tl.program_id(0)
    num_tiles_m   = tl.cdiv(SEQ_LEN, BLOCK_Q)
    group_id      = pid_q // GROUP_N
    group_start   = group_id * GROUP_N
    # if this CTA's group starts past the last tile, exit early
    if group_start >= num_tiles_m:
        return
    tiles_in_this = tl.minimum(GROUP_N, num_tiles_m - group_start)
    m_in_grp      = pid_q - group_start
    m_eff         = m_in_grp % tiles_in_this
    rot           = pid_bh % tiles_in_this
    m_swizzled    = group_start + ((m_eff + rot) % tiles_in_this)

    start_q = m_swizzled * BLOCK_Q
    if start_q >= SEQ_LEN:
        return
    
    # 64-bit base offsets of this (batch, head) slice in each tensor.
    off_bh_k  = (b * skb   + h * skh  ).to(tl.int64)
    off_bh_v  = (b * svb   + h * svh  ).to(tl.int64)
    off_bh_dq = (b * s_dqb + h * s_dqh).to(tl.int64)
    off_bh_q  = (b * sqb   + h * sqh  ).to(tl.int64)
    off_bh_do = (b * sob   + h * soh  ).to(tl.int64)
    # ---------- block pointers ----------
    # Q/dO/dQ are fixed at this program's query tile; K and V are addressed as
    # transposed (HEAD_DIM, SEQ_LEN) views and advanced inside the loop.
    Q_blk = tl.make_block_ptr(
        Q + off_bh_q,(SEQ_LEN, HEAD_DIM),(sqs, sqd),(start_q, 0),(BLOCK_Q, HEAD_DIM),(1, 0),
    )
    dO_blk = tl.make_block_ptr(
        dO + off_bh_do,(SEQ_LEN, HEAD_DIM),(sos, sod),(start_q, 0),(BLOCK_Q, HEAD_DIM),(1, 0),
    )
    K_T_blk = tl.make_block_ptr(
        K + off_bh_k,(HEAD_DIM, SEQ_LEN),(skd, sks),(0, 0),(HEAD_DIM, BLOCK_KV),(0, 1),
    )
    V_T_blk = tl.make_block_ptr(
        V + off_bh_v,(HEAD_DIM, SEQ_LEN),(svd, svs),(0, 0),(HEAD_DIM, BLOCK_KV),(0, 1),
    )
    dQ_blk = tl.make_block_ptr(
        dQ + off_bh_dq,(SEQ_LEN, HEAD_DIM),(s_dqs, s_dqd),(start_q, 0),(BLOCK_Q, HEAD_DIM),(1, 0),
    )

    # ---------- indices & constants ----------
    offs_q = start_q + tl.arange(0, BLOCK_Q)
    offs_kv = tl.arange(0, BLOCK_KV)

    # row-wise scalars
    m  = tl.load(M + offs_q, mask=offs_q < SEQ_LEN, other=0.0)[:, None]  # [BLOCK_Q, 1]
    Di = tl.load(D + offs_q, mask=offs_q < SEQ_LEN, other=0.0)           # [BLOCK_Q]
    # Q is pre-scaled by softmax_scale once, outside the loop.
    s = tl.full([1], softmax_scale, dtype=DTYPE)
    Q_block  = tl.load(Q_blk,  boundary_check=(0, 1), padding_option="zero") * s
    dO_block = tl.load(dO_blk, boundary_check=(0, 1), padding_option="zero")
    dQ_block = tl.zeros((BLOCK_Q, HEAD_DIM), dtype=tl.float32)

    # ---------- loop over KV tiles ----------
    num_steps = tl.cdiv(SEQ_LEN, BLOCK_KV)
    for step in range(num_steps):
        K_T_block = tl.load(K_T_blk, boundary_check=(0, 1), padding_option="zero")
        V_T_block = tl.load(V_T_blk, boundary_check=(0, 1), padding_option="zero")
        
        start_kv = step * BLOCK_KV
        kv_idx   = start_kv + offs_kv
        kv_valid = kv_idx < SEQ_LEN
        S = tl.dot(Q_block, K_T_block)                     # [BLOCK_Q, BLOCK_KV]
        # Padded KV columns are pushed to -inf so their probabilities become 0.
        S = tl.where(kv_valid[None, :], S, -float("inf"))
        # Recompute the softmax probabilities from the stored log-sum-exp.
        P = tl.exp(S - m)                                  # [BLOCK_Q, BLOCK_KV]

        # dP = dO @ Vᵀ  (match dtypes for dot)
        dP = tl.dot(dO_block.to(DTYPE), V_T_block.to(DTYPE)).to(tl.float32)
        dS = (P * (dP - Di[:, None])).to(DTYPE)
        dQ_block = tl.dot(dS, tl.trans(K_T_block.to(DTYPE)), dQ_block)

        K_T_blk = tl.advance(K_T_blk, (0, BLOCK_KV))
        V_T_blk = tl.advance(V_T_blk, (0, BLOCK_KV))
    
    # The softmax scale is applied to dQ once, after accumulation.
    dQ_block *= s
    tl.store(dQ_blk, dQ_block.to(dQ.type.element_ty), boundary_check=(0, 1))



class TritonAttention(torch.autograd.Function):
    """Autograd wrapper around the Triton attention kernels defined in this file.

    ``forward`` launches ``_attn_fwd``; ``backward`` launches
    ``_attn_bwd_preprocess``, ``_attn_bwd_dk_dv`` and ``_attn_bwd_dq``.
    Every kernel receives explicit strides, so Q, K and V may be
    non-contiguous views.

    All kernels index K and V using the batch, head, sequence and head-dim
    sizes taken from Q. The three inputs therefore must share one shape; this
    is validated up front because a mismatch would otherwise read outside the
    smaller tensor or silently ignore part of the larger one.
    """

    @staticmethod
    def forward(ctx, Q, K, V):
        """Compute softmax(Q @ K^T / sqrt(head_dim)) @ V.

        Args:
            Q, K, V: tensors of identical shape (batch, heads, seq_len, head_dim).

        Returns:
            Tensor allocated with ``torch.empty_like(Q)`` holding the attention output.

        Raises:
            ValueError: if Q is not 4-D or K / V do not have Q's shape.
        """
        if Q.dim() != 4:
            raise ValueError(
                "TritonAttention expects 4-D (batch, heads, seq_len, head_dim) inputs, "
                f"got Q with shape {tuple(Q.shape)}"
            )
        if K.shape != Q.shape or V.shape != Q.shape:
            raise ValueError(
                "TritonAttention requires Q, K and V to share one shape, got "
                f"Q={tuple(Q.shape)}, K={tuple(K.shape)}, V={tuple(V.shape)}"
            )

        BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.size()
        # Compute dtype follows autocast when it is active, else Q's dtype.
        comp_torch = _sdpa_comp_dtype(Q)
        comp_triton = _triton_compute_dtype(comp_torch)

        softmax_scale = 1 / (HEAD_DIM**0.5)
        O = torch.empty_like(Q)

        # Axis 0: query tiles (size depends on the autotuned BLOCK_Q);
        # axis 1: one program per (batch, head) pair.
        grid = lambda args: (
            triton.cdiv(SEQ_LEN, args["BLOCK_Q"]),
            BATCH_SIZE * NUM_HEADS,
        )
        # M is the logsumexp for the backward pass, one for each query
        M = torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32
        )
        _attn_fwd[grid](
            Q, K, V, M, O,
            *Q.stride(), *K.stride(), *V.stride(), *O.stride(),
            NUM_HEADS=NUM_HEADS, SEQ_LEN=SEQ_LEN, HEAD_DIM=HEAD_DIM,
            softmax_scale=softmax_scale, DTYPE=comp_triton,
        )

        ctx.save_for_backward(Q, K, V, O, M)
        ctx.softmax_scale = softmax_scale
        ctx.HEAD_DIM = HEAD_DIM
        ctx.comp_triton = comp_triton
        return O

    @staticmethod
    def backward(ctx, dO):
        """Return (dQ, dK, dV) for the upstream gradient ``dO``.

        ``dO`` is passed to the kernels together with its own strides, so it
        does not have to be contiguous.
        """
        Q, K, V, O, M = ctx.saved_tensors
        dQ = torch.empty_like(Q)
        dK = torch.empty_like(K)
        dV = torch.empty_like(V)

        BATCH_SIZE, NUM_HEADS, SEQ_LEN, _ = Q.size()

        # D[b, h, q] = sum_d dO[b, h, q, d] * O[b, h, q, d]
        D = torch.empty_like(M)
        pre_grid = lambda meta: (
            triton.cdiv(SEQ_LEN, meta["BLOCK_Q"]),
            BATCH_SIZE * NUM_HEADS,
        )
        _attn_bwd_preprocess[pre_grid](
            O, dO, D,
            *O.stride(),
            *dO.stride(),
            NUM_HEADS=NUM_HEADS, SEQ_LEN=SEQ_LEN, HEAD_DIM=ctx.HEAD_DIM,
        )

        # Fix KV and iterate through all the Q blocks
        dkdv_grid = lambda meta: (
            triton.cdiv(SEQ_LEN, meta["BLOCK_KV"]),
            BATCH_SIZE * NUM_HEADS,
        )
        _attn_bwd_dk_dv[dkdv_grid](
            Q, K, V, dO, dK, dV, M, D,
            *Q.stride(), *K.stride(), *V.stride(), *dO.stride(),
            *dK.stride(), *dV.stride(),
            NUM_HEADS=NUM_HEADS, SEQ_LEN=SEQ_LEN, HEAD_DIM=ctx.HEAD_DIM, DTYPE=ctx.comp_triton,
            softmax_scale=ctx.softmax_scale
        )

        # Fix Q and iterate through all the KV blocks
        dq_grid = lambda meta: (
            triton.cdiv(SEQ_LEN, meta["BLOCK_Q"]),
            BATCH_SIZE * NUM_HEADS,
        )
        _attn_bwd_dq[dq_grid](
            Q, K, V, dO, dQ, M, D,
            *Q.stride(), *K.stride(), *V.stride(), *dO.stride(),
            *dQ.stride(),
            NUM_HEADS=NUM_HEADS, SEQ_LEN=SEQ_LEN, HEAD_DIM=ctx.HEAD_DIM, DTYPE=ctx.comp_triton,
            softmax_scale=ctx.softmax_scale
        )
        return dQ, dK, dV
    
    
def sdpa_triton_fa(Q: Tensor, K: Tensor, V: Tensor):
    """Run attention through the ``TritonAttention`` autograd function.

    The inputs are forwarded unchanged (no ``.contiguous()`` copy is made) and
    the result of ``TritonAttention.apply`` is returned.
    """
    attn_out = TritonAttention.apply(Q, K, V)
    return attn_out

class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The attention itself is computed either by the Triton kernel
    (``sdpa_triton_fa``) or by ``F.scaled_dot_product_attention``, selected by
    ``triton_kernel``. Optional per-head normalisation of Q and K, a
    normalisation of the merged heads, attention dropout and projection
    dropout are supported.

    The Triton kernel has no dropout support. When attention dropout is active
    (``attn_drop > 0`` in training mode) the module therefore uses
    ``F.scaled_dot_product_attention`` even if ``triton_kernel`` is True, so
    that the configured dropout is actually applied.
    """

    def __init__(
            self,
            dim:int = 384,
            num_heads: int = 6,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer = nn.LayerNorm,
            device="cuda",
            dtype=torch.float32,
            triton_kernel=True
            ) -> None:
        """Initialize the Attention module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in the query, key, value projections
            qk_norm: Whether to apply normalization to query and key vectors
            scale_norm: Whether to normalize the merged heads before the output projection
            proj_bias: Whether to use bias in the output projection
            attn_drop: Dropout rate applied to the attention weights
            proj_drop: Dropout rate applied after the output projection
            norm_layer: Normalization layer constructor used when qk_norm or scale_norm is enabled
            device: Device on which the parameters are created
            dtype: Parameter dtype
            triton_kernel: Use the Triton attention kernel instead of
                F.scaled_dot_product_attention (only while attention dropout is inactive)
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        if qk_norm or scale_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Not used by forward (both attention paths apply their own scaling).
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
        # Kept as a module so the rate is visible in repr(); forward reads `.p`.
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
        self.proj_drop = nn.Dropout(proj_drop)
        self.triton_kernel = triton_kernel

    def forward(
            self,
            x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape (B, N, C) and return the same shape."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim); q, k, v stay strided views.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # Attention dropout is only active in training mode.
        dropout_p = self.attn_drop.p if self.training else 0.
        if self.triton_kernel and dropout_p == 0.:
            x = sdpa_triton_fa(q, k, v)
        else:
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C).
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden width is ``int(dim * mlp_ratio)`` and a single dropout module
    is shared by both dropout positions.
    """

    def __init__(self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (..., dim) to a tensor of the same shape."""
        expanded = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(expanded))

class TransformerBlock(nn.Module):
    """Residual attention + MLP block built from ``Attention`` and ``MLP``.

    ``norm1`` is created but is not applied in ``forward``: the attention
    branch receives the raw block input. Only the MLP branch is pre-normalised
    (with ``norm2``).
    """

    def __init__(
        self,
        dim: int = 384,
        num_heads: int = 6,
        mlp_ratio: float = 4.0,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        scale_norm: bool = False,
        norm_layer=nn.LayerNorm,
        triton=True
    ):
        super().__init__()
        attn_kwargs = dict(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            scale_norm=scale_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            triton_kernel=triton,
        )
        # Sub-modules are created in the order norm1, attn, norm2, mlp.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(**attn_kwargs)
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, drop=proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual branches."""
        # Attention branch on the un-normalised input (norm1 is skipped).
        attn_out = self.attn(x)
        x = x + attn_out
        # Debug guard: abort as soon as the attention branch produces a NaN.
        assert torch.isnan(x).sum() == 0
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out

class ToyTransformer(nn.Module):
    """
    Small Transformer encoder: a stack of ``TransformerBlock`` modules followed
    by a final norm, token pooling and a linear classification head.

    Takes inputs shaped [B, N, C] and returns logits [B, num_classes].
    ``cls_pool="first"`` uses token 0; any other value mean-pools over N.
    ``max_len`` is accepted but not used by this module.
    """
    def __init__(
        self,
        dim: int = 384,
        depth: int = 4,
        num_heads: int = 6,
        mlp_ratio: float = 4.0,
        num_classes: int = 1000,
        max_len: int = 197,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        scale_norm: bool = False,
        norm_layer=nn.LayerNorm,
        cls_pool: str = "mean",  # "mean" or "first"
        triton=True
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.cls_pool = cls_pool

        # Every block is built from the same configuration.
        block_kwargs = dict(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            scale_norm=scale_norm,
            norm_layer=norm_layer,
            triton=triton,
        )
        self.blocks = nn.ModuleList(
            [TransformerBlock(**block_kwargs) for _ in range(depth)]
        )
        self.norm = norm_layer(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` [B, N, C], pool over tokens and return logits."""
        for idx, block in enumerate(self.blocks):
            x = block(x)
            # Debug guard naming the first block whose output contains a NaN.
            assert torch.isnan(x).sum() == 0, f"block {idx}"

        x = self.norm(x)
        # [B, N, C] -> [B, C]
        pooled = x[:, 0] if self.cls_pool == "first" else x.mean(dim=1)
        return self.head(pooled)

In [19]:
from copy import deepcopy

B, N, D = 32, 196, 384

# Fix the RNG so that the initial weights, the random inputs and therefore the
# printed differences are reproducible across Restart & Run All.
torch.manual_seed(0)

# Two models with identical initial weights: one uses the Triton attention
# kernel, the copy uses F.scaled_dot_product_attention.
model_triton = ToyTransformer(dim=D, depth=16, num_heads=6, num_classes=100).cuda()  # fp32 weights
model_torch = deepcopy(model_triton)
for b in model_torch.blocks:
    b.attn.triton_kernel = False


opt_triton = torch.optim.AdamW(model_triton.parameters(), lr=1e-3)
opt_torch = torch.optim.AdamW(model_torch.parameters(), lr=1e-3)

# Train both models on the same random batches and report, per step, how far
# the losses (l_diff) and the input gradients (g_diff) drift apart.
for i  in range(10):
    x = torch.randn(B, N, D, device="cuda")  # fp32 inputs; autocast handles the bf16 compute
    x_triton = x.clone().detach().requires_grad_(True)
    x_torch = x.clone().detach().requires_grad_(True)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        y_triton = model_triton(x_triton)
        y_torch = model_torch(x_torch)
        loss_triton = y_triton.mean()
        loss_torch = y_torch.mean()
        l_diff = torch.abs(loss_triton - loss_torch)
        print("l_diff", l_diff.max().item(), l_diff.mean().item())


    loss_triton.backward()
    loss_torch.backward()
    g_diff = torch.abs(x_triton.grad - x_torch.grad)
    print("g_diff", g_diff.max().item(), g_diff.mean().item())

    opt_triton.step()
    opt_torch.step()

    opt_triton.zero_grad()
    opt_torch.zero_grad()

l_diff 0.0 0.0
g_diff 1.865180365712149e-08 2.765358608769475e-09
l_diff 0.0 0.0
g_diff 3.820164238277357e-07 1.1808817035330321e-08
l_diff 0.0 0.0
g_diff 7.067016667861026e-06 8.927385408696864e-08
l_diff 0.0 0.0
g_diff 4.4637479732045904e-05 3.429436787882878e-07
l_diff 0.0 0.0
g_diff 0.00734396418556571 1.0270650818711147e-05
l_diff 0.0 0.0
g_diff 0.18744412064552307 0.00034679876989685
l_diff 0.015625 0.015625
g_diff 0.007186077069491148 1.8102227841154672e-05
l_diff 0.078125 0.078125
g_diff 0.0013323574094101787 5.336064987204736e-06
l_diff 0.109375 0.109375
g_diff 9.639918425818905e-05 7.131186521291966e-07
l_diff 0.078125 0.078125
g_diff 0.0001267079933313653 3.625587225997151e-07


# Test

In [1]:
import sys, os
from getpass import getpass

# SECURITY: never commit credentials to a notebook. The key is taken from the
# environment when present and prompted for otherwise. Any key that was
# previously stored in this notebook should be considered leaked and rotated.
if not os.environ.get("COMET_API_KEY"):
    os.environ["COMET_API_KEY"] = getpass("Comet API key: ")
# Tool configuration is set through the environment before the libraries are imported.
os.environ["COMET_DISABLE_AUTO_LOGGING"]="1"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"]="1"
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"]="1"
os.environ["TRITON_PRINT_AUTOTUNING"]="1"

# The fork must be on sys.path BEFORE timm / train are imported; otherwise a
# fresh kernel either fails to find `train` or imports a different timm.
sys.path.insert(0, "/notebooks/pytorch-image-models")  # your fork

from pathlib import Path
import yaml
import torch
import timm
import train as timm_train

# Only allow the flash SDPA backend; the memory-efficient and math backends are disabled.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# TorchDynamo settings used when the model is compiled.
dynamo_config = torch._dynamo.config
dynamo_config.compiled_autograd = True
dynamo_config.capture_scalar_outputs = False
dynamo_config.cache_size_limit = 512

# Environment report: interpreter, torch / CUDA versions and which timm is in use.
print("Python exe:", sys.executable)
print("Torch:", torch.__version__, "CUDA:", torch.version.cuda)
print(sys.executable)                             # your venv python
print(Path(timm.__file__).resolve())      # -> /notebooks/pytorch-image-models/timm/__init__.py

# Training configuration consumed by timm_train.main.
with open("/notebooks/params_timm.yaml", 'r', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)




Python exe: /notebooks/venvs/pt27cu118/bin/python
Torch: 2.7.1+cu118 CUDA: 11.8
/notebooks/venvs/pt27cu118/bin/python
/notebooks/pytorch-image-models/timm/__init__.py


In [11]:
# Build the model by calling the fork's train.main with the YAML config, then move it to the GPU.
# NOTE(review): assumes main() returns the nn.Module (it is passed to .cuda() and torch.compile) — confirm in the fork.
model = timm_train.main(cfg).cuda()

[1;38;5;39mCOMET INFO:[0m An experiment with the same configuration options is already running and will be reused.
Training with a single process on 1 device (cuda).
Training with a single process on 1 device (cuda).


Setting model to deit3_small_patch16_224
Setting num_classes to 1000
Setting img_size to 224
Setting in_chans to None
Setting dataset to hfds-disk:/notebooks/data/imagenet_1k_resized_256
Setting data_dir to notebooks/data/imagenet_1k_resized_256
Setting train_split to train
Setting val_split to val
Setting interpolation to bicubic
Setting train_interpolation to random
Setting crop_pct to None
Setting batch_size to 1024
Setting validation_batch_size to None
Setting workers to 10
Setting pin_mem to True
Setting channels_last to True
Setting amp to True
Setting amp_dtype to bfloat16
Setting amp_impl to native
Setting opt to adamw
Setting weight_decay to 0.05
Setting lr to None
Setting lr_base to 0.0005
Setting lr_base_size to 1024
Setting lr_base_scale to linear
Setting momentum to 0.9
Setting sched to cosine
Setting epochs to 400
Setting warmup_epochs to 5
Setting warmup_lr to 1e-05
Setting min_lr to 1e-06
Setting lr_k_decay to 1.0
Setting cooldown_epochs to 0
Setting lr_cycle_limit to 1

In [18]:
# Display the module tree of the model built above.
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=Fal

In [16]:
# Compile the model with TorchInductor in max-autotune mode.
# fullgraph=True makes any graph break an error; dynamic=False keeps shapes static.
model_comp = torch.compile(
                model,
                backend="inductor",
                mode="max-autotune",
                fullgraph=True,
                dynamic=False
            )

In [17]:
# Smoke test of the compiled model: ten optimizer steps on random images
# under bf16 autocast, using the mean logit as a stand-in loss.
B, C, H, W = 2, 3, 224, 224

opt = torch.optim.AdamW(model_comp.parameters(), lr=1e-3)

for i  in range(10):
    x = torch.randn(B, C, H, W, device="cuda", dtype=torch.float32)
    # Forward pass and loss run under autocast; backward and the optimizer step do not.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        y = model_comp(x)
        loss = y.mean()
        
    loss.backward()
    
    opt.step()    
    opt.zero_grad()

In [None]:
"""
Train: 54 [ 100/1251 (  8%)]  Loss: 0.00292 (0.00420)  Time: 0.300s, 3416.47/s  (0.581s, 1763.26/s)  LR: 4.779e-04  Data: 0.019 (0.141)
Train: 54 [ 150/1251 ( 12%)]  Loss: 0.00472 (0.00422)  Time: 0.297s, 3452.68/s  (0.526s, 1946.67/s)  LR: 4.779e-04  Data: 0.027 (0.141)
Train: 54 [ 200/1251 ( 16%)]  Loss: 0.00318 (0.00418)  Time: 0.300s, 3412.30/s  (0.502s, 2040.79/s)  LR: 4.779e-04  Data: 0.019 (0.144)
Train: 54 [ 250/1251 ( 20%)]  Loss: 0.00510 (0.00414)  Time: 0.282s, 3625.44/s  (0.482s, 2125.26/s)  LR: 4.779e-04  Data: 0.013 (0.141)
Train: 54 [ 300/1251 ( 24%)]  Loss: 0.00453 (0.00416)  Time: 0.292s, 3502.42/s  (0.470s, 2179.30/s)  LR: 4.779e-04  Data: 0.022 (0.140)
Train: 54 [ 350/1251 ( 28%)]  Loss: 0.00467 (0.00417)  Time: 0.283s, 3619.12/s  (0.463s, 2212.44/s)  LR: 4.779e-04  Data: 0.013 (0.141)
"""

# End