Skip to content

Gated delta net kernel correctness issue #1067

@yf225

Description

@yf225

Reported by @v0i0 :

Repro:

"""
Gated Delta Net Fwd H Kernel
============================
This code implements a fwd_h kernel as used in gated delta net
"""

# %%
# Imports
# -------
from __future__ import annotations
import functools
import math
import torch
import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl

# %%
# Helion Kernel Implementation
# ----------------------------
@helion.kernel()
def helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c):
    """
    Compute the chunk-wise hidden states of the gated delta net forward pass.

    For every (batch, head, value tile) the kernel walks the chunks in order.
    It stores the running state as it is at the *start* of each chunk into
    ``h`` and then folds that chunk's contribution into the running state.
    The state after the last chunk is not written anywhere.

    Argument:
        k_c: (batch, nchunks, chunk_size, nheads, dhead)
        w_c: (batch, nchunks, chunk_size, nheads, dhead)
        u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
        g_c: (batch, nchunks, chunk_size, nheads)
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    batch, nchunks, chunk_size, nheads, dhead = k_c.shape
    # dhead and chunk_size go through hl.specialize; in the generated code
    # quoted in this issue they show up as literal constants.
    dhead = hl.specialize(dhead)
    chunk_size = hl.specialize(chunk_size)
    dstate = u_c.shape[-1]
    acc_dtype = torch.float32
    dtype = k_c.dtype
    h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device)
    block_v = hl.register_block_size(dstate)
    # The inputs are already chunked, so seqlen is an exact multiple of chunk_size.
    seqlen = chunk_size * nchunks
    # One tile per (batch element, head, block of the value dimension).
    for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
        # Running state for this tile, accumulated in float32.
        b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
        for i_t in hl.grid(nchunks):
            # Emit the state as it is before chunk i_t is applied.
            h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
            print("b_h at the start", b_h)
            b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
            # The state is rounded to the input dtype before the matmul with w.
            c_h = b_h.to(dtype)
            b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
            print("b_v at the start", b_v)
            p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
            # Residual value: u minus the part already explained by the state.
            b_v = p_v - b_v
            print("b_v after subtraction", b_v)
            # Flat index of the last token of this chunk.  Because
            # seqlen == chunk_size * nchunks the min() never clamps here.
            last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
            # Mask of in-range rows of this chunk.
            m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
            # NOTE: this is the access the issue is about.  The generated Triton
            # code loads g_c at a constant offset here, with no dependence on i_t.
            b_g_last = g_c[tile_b.begin, last_idx // chunk_size, last_idx % chunk_size, tile_h.begin].to(acc_dtype)
            print("b_g_last", b_g_last)
            b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
            print("b_g", b_g)
            # Scale every row by exp(g_last - g_row); masked rows become zero.
            b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
            print("b_v after multiplication", b_v)
            # Decay the running state by exp(g_last) ...
            b_g_last = torch.exp(b_g_last)
            b_h *= b_g_last
            print("b_h after multiplication", b_h)
            b_v = b_v.to(dtype)
            p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
            # ... and accumulate k^T @ b_v on top of it.
            b_h = hl.dot(p_k.T, b_v, acc=b_h)
            print("b_h after addition", b_h)
    return h

def helion_gdn_fwd_h(k, w, u, g, chunk_size):
    """
    Chunk the sequence axis of the inputs and run the Helion fwd_h kernel.

    The resulting states are printed before being returned.

    Argument:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    batch, seqlen, nheads, dhead = k.shape
    nchunks = (seqlen + chunk_size - 1) // chunk_size
    # Leading dimensions shared by all chunked views: seqlen -> (nchunks, chunk_size).
    lead = (batch, nchunks, chunk_size, nheads)
    states = helion_gdn_fwd_h_kernel(
        k.reshape(*lead, dhead),
        w.reshape(*lead, dhead),
        u.reshape(*lead, u.shape[-1]),
        g.reshape(*lead),
    )
    print(states)
    return states

# %%
# Reference Function
# -------------
def ref_gdn_fwd_h(k, w, u, g, chunk_size):
    """
    Eager PyTorch reference for the fwd_h kernel, with debug printing.

    Processes all batches and heads at once, one chunk per loop iteration,
    and prints the same intermediate values as the Helion kernel so the two
    traces can be compared.

    Argument:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    batch, seqlen, nheads, dhead = k.shape
    dstate = (u.shape[-1] // dhead) * dhead
    nchunks = (seqlen + chunk_size - 1) // chunk_size
    io_dtype = k.dtype
    acc_dtype = torch.float32
    device = k.device

    out = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=io_dtype, device=device)
    state = torch.zeros(batch, nheads, dhead, dstate, dtype=acc_dtype, device=device)

    # Chunked views: seqlen -> (nchunks, chunk_size).
    lead = (batch, nchunks, chunk_size, nheads)
    k_chunks = k.reshape(*lead, dhead)
    w_chunks = w.reshape(*lead, dhead)
    u_chunks = u.reshape(*lead, dstate)
    g_chunks = g.reshape(*lead)

    for chunk in range(nchunks):
        # State as seen at the start of this chunk.
        out[:, chunk] = state.to(io_dtype)
        print("b_h at the start", state)

        w_blk = w_chunks[:, chunk].to(acc_dtype)
        # Round-trip through the I/O dtype to mimic the kernel's rounding.
        rounded_state = state.to(io_dtype).to(acc_dtype)
        resid = torch.einsum("bchk,bhkv->bchv", w_blk, rounded_state)
        print("b_v at the start", resid)

        resid = u_chunks[:, chunk].to(acc_dtype) - resid
        print("b_v after subtraction", resid)

        last = min((chunk + 1) * chunk_size, seqlen) - 1
        offsets = torch.arange(0, chunk_size, device=device)
        valid = (chunk * chunk_size + offsets) < seqlen
        g_last = g[:, last, :].to(acc_dtype)
        print("b_g_last", g_last)
        g_blk = g_chunks[:, chunk].to(acc_dtype)  # batch, chunk, nheads
        print("b_g", g_blk)

        gate = torch.where(valid[None, :, None], torch.exp(g_last[:, None, :] - g_blk), 0)
        resid = resid * gate[..., None]
        print("b_v after multiplication", resid)

        state = state * torch.exp(g_last)[:, :, None, None]
        print("b_h after multiplication", state)

        resid = resid.to(io_dtype).to(acc_dtype)
        k_blk = k_chunks[:, chunk].to(acc_dtype)
        state = state + torch.einsum("bchk,bchv->bhkv", k_blk, resid)
        print("b_h after saddition", state)
    print(out)
    return out

# %%
# Testing Function
# -------------
def test(
    batch: int,
    nheads: int,
    seqlen: int,
    chunk_size: int,
    dhead: int,
    dstate: int,
    dtype: torch.dtype = torch.float16,
) -> None:
    """
    Build inputs on CUDA and run the Helion kernel and the reference back to back.

    Nothing is compared programmatically: both implementations print their
    intermediate values and the outputs are meant to be diffed by eye.

    NOTE(review): ``dtype`` is accepted but never used; every tensor below is
    created with a hard-coded bfloat16 or float32 dtype.
    """
    # NOTE(review): this random k is discarded, the next line overwrites it
    # with a constant tensor filled with 1/dhead.
    k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device="cuda")
    k = torch.ones(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device="cuda") / dhead
    # w: draw a random (chunk_size x dhead) matrix per (batch, chunk, head) and
    # replace it by U @ Vh from its reduced SVD (singular values `ws` are dropped).
    w = torch.randn(batch, seqlen // chunk_size, chunk_size, nheads, dhead, dtype=torch.float32, device='cuda')
    wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
    w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
    # Undo the permute and flatten the chunks back into the sequence axis.
    w = w.permute(0, 1, 3, 2, 4).reshape(batch, seqlen, nheads, dhead).to(torch.bfloat16)
    u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device="cuda")
    # u = torch.ones(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device="cuda") // dstate
    # g: running sum along the sequence of increments 0.5 * log(1/dhead) * U[0, 1),
    # which are <= 0 whenever dhead >= 1.
    g = torch.cumsum(0.5*math.log(1/dhead)*torch.rand(batch, seqlen, nheads, dtype=torch.float32, device='cuda'), dim=1)
    # g = torch.zeros(batch, seqlen, nheads, device='cuda', dtype=torch.float32)
    args = (k, w, u, g, chunk_size)
    # Print the shapes of the tensor arguments (chunk_size is skipped).
    print([a.shape for a in args if isinstance(a, torch.Tensor)])
    # run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)
    helion_gdn_fwd_h(*args)
    ref_gdn_fwd_h(*args)

# %%
# Main Function
# -----------
def main() -> None:
    """
    Run the gated delta net fwd_h repro on a tiny problem.

    One batch element, one head, a sequence of four tokens split into two
    chunks of two, with dhead = dstate = 2.
    """
    test(batch=1, nheads=1, seqlen=4, chunk_size=2, dhead=2, dstate=2)
    # Larger configuration, currently disabled:
    # test(8, 80, 4096, 256, 64, 128)
if __name__ == "__main__":
    # Echo this script's own source first so the captured log is
    # self-contained.  The context manager closes the file handle right away
    # instead of leaving it open until garbage collection.
    with open(__file__, encoding="utf-8") as source_file:
        print(source_file.read())
    print("======================")
    main()

The generated Triton kernel is:

# Triton kernel generated for the repro kernel above, quoted verbatim in the
# issue.  The `# src[...]` comments reference the gdn_fwd_h.py line each
# statement was generated from.  In this build dhead, chunk_size and the chunk
# count all appear as the literal constant 2.
@triton.jit
def _helion_helion_gdn_fwd_h_kernel(h, w_c, u_c, g_c, k_c, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_4: tl.constexpr):
    # src[gdn_fwd_h.py:51]: for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
    num_blocks_0 = 1
    num_blocks_1 = 1
    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
    # indices_0 addresses the tile_v (dstate) axis; indices_5 spans a full
    # axis of length _RDIM_SIZE_4.
    offset_0 = pid_2 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    indices_5 = tl.arange(0, _RDIM_SIZE_4).to(tl.int32)
    # src[gdn_fwd_h.py:52]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
    b_h = tl.full([2, _BLOCK_SIZE_0], 0.0, tl.float32)
    # src[gdn_fwd_h.py:53]: for i_t in hl.grid(nchunks):
    # src[gdn_fwd_h.py:54]:     h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
    # src[gdn_fwd_h.py:55]:     print("b_h at the start", b_h)
    # src[gdn_fwd_h.py:53-77]: ...
    # offset_4 is the chunk index i_t of the source loop.
    for offset_4 in tl.range(0, 2):
        b_h_copy = b_h
        b_h_copy_0 = b_h_copy
        # src[gdn_fwd_h.py:54]: h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
        v_0 = tl.cast(b_h_copy_0, tl.bfloat16)
        tl.store(h + (offset_4 * 4 + indices_5[:, None] * 2 + indices_0[None, :] * 1), v_0, None)
        # src[gdn_fwd_h.py:55]: print("b_h at the start", b_h)
        tl.device_print('b_h at the start', b_h_copy_0)
        # src[gdn_fwd_h.py:56]: b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
        b_w = tl.load(w_c + (offset_4 * 4 + indices_5[:, None] * 2 + indices_5[None, :] * 1), None)
        # src[gdn_fwd_h.py:57]: c_h = b_h.to(dtype)
        v_1 = tl.cast(b_h_copy_0, tl.bfloat16)
        # src[gdn_fwd_h.py:58]: b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
        # Both operands are zero-padded along the contraction axis by repeated
        # tl.join/permute/reshape (2 -> 4 -> 8 -> 16) before the tl.dot.
        b_v = tl.dot(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(b_w, tl.bfloat16), tl.zeros_like(tl.cast(b_w, tl.bfloat16))), [0, 2, 1]), [2, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(b_w, tl.bfloat16), tl.zeros_like(tl.cast(b_w, tl.bfloat16))), [0, 2, 1]), [2, 4]))), [0, 2, 1]), [2, 8]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(b_w, tl.bfloat16), tl.zeros_like(tl.cast(b_w, tl.bfloat16))), [0, 2, 1]), [2, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(b_w, tl.bfloat16), tl.zeros_like(tl.cast(b_w, tl.bfloat16))), [0, 2, 1]), [2, 4]))), [0, 2, 1]), [2, 8]))), [0, 2, 1]), [2, 16]), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(v_1, tl.bfloat16), tl.zeros_like(tl.cast(v_1, tl.bfloat16))), [2, 0, 1]), [4, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(v_1, tl.bfloat16), tl.zeros_like(tl.cast(v_1, tl.bfloat16))), [2, 0, 1]), [4, 2]))), [2, 0, 1]), [8, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(v_1, tl.bfloat16), tl.zeros_like(tl.cast(v_1, tl.bfloat16))), [2, 0, 1]), [4, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(v_1, tl.bfloat16), tl.zeros_like(tl.cast(v_1, tl.bfloat16))), [2, 0, 1]), [4, 2]))), [2, 0, 1]), [8, 2]))), [2, 0, 1]), [16, 2]), input_precision='tf32', out_dtype=tl.float32)
        # src[gdn_fwd_h.py:59]: print("b_v at the start", b_v)
        tl.device_print('b_v at the start', b_v)
        # src[gdn_fwd_h.py:60]: p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
        load_1 = tl.load(u_c + (offset_4 * 4 + indices_5[:, None] * 2 + indices_0[None, :] * 1), None)
        v_2 = tl.cast(load_1, tl.float32)
        # src[gdn_fwd_h.py:61]: b_v = p_v - b_v
        v_3 = v_2 - b_v
        # src[gdn_fwd_h.py:62]: print("b_v after subtraction", b_v)
        tl.device_print('b_v after subtraction', v_3)
        # src[gdn_fwd_h.py:64]: m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
        # The mask still depends on the loop variable: 2 * offset_4 + iota < 4.
        mul_1 = 2 * offset_4
        iota = tl.arange(0, 2)
        v_4 = tl.cast(mul_1, tl.int32)
        v_5 = iota + v_4
        v_6 = tl.full([], 4, tl.int32)
        v_7 = v_5 < v_6
        # src[gdn_fwd_h.py:65]: b_g_last = g_c[tile_b.begin, last_idx // chunk_size, last_idx % chunk_size, tile_h.begin].to(acc_dtype)
        # Reported bug: the address is the constant 1 * 2 + 1 * 1, so every
        # iteration reads the same element; compare with the b_g load below,
        # whose address uses offset_4.
        b_g_last = tl.load(g_c + (1 * 2 + 1 * 1), None)  # <----- this line
        # src[gdn_fwd_h.py:66]: print("b_g_last", b_g_last)
        tl.device_print('b_g_last', b_g_last)
        # src[gdn_fwd_h.py:67]: b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
        b_g = tl.load(g_c + (offset_4 * 2 + indices_5 * 1), None)
        # src[gdn_fwd_h.py:68]: print("b_g", b_g)
        tl.device_print('b_g', b_g)
        # src[gdn_fwd_h.py:69]: b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
        v_8 = b_g_last[None]
        v_9 = v_8 - b_g
        v_10 = libdevice.exp(v_9)
        v_11 = 0.0
        v_12 = v_11[None]
        v_13 = tl.where(v_7, v_10, v_12)
        subscript = v_13[:, None]
        v_14 = v_3 * subscript
        # src[gdn_fwd_h.py:70]: print("b_v after multiplication", b_v)
        tl.device_print('b_v after multiplication', v_14)
        # src[gdn_fwd_h.py:71]: b_g_last = torch.exp(b_g_last)
        v_15 = libdevice.exp(b_g_last)
        # src[gdn_fwd_h.py:72]: b_h *= b_g_last
        v_16 = v_15[None, None]
        v_17 = b_h_copy_0 * v_16
        # src[gdn_fwd_h.py:73]: print("b_h after multiplication", b_h)
        tl.device_print('b_h after multiplication', v_17)
        # src[gdn_fwd_h.py:74]: b_v = b_v.to(dtype)
        v_18 = tl.cast(v_14, tl.bfloat16)
        # src[gdn_fwd_h.py:75]: p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
        p_k = tl.load(k_c + (offset_4 * 4 + indices_5[:, None] * 2 + indices_5[None, :] * 1), None)
        # src[gdn_fwd_h.py:76]: b_h = hl.dot(p_k.T, b_v, acc=b_h)
        permute = tl.permute(p_k, [1, 0])
        # Same zero-padding scheme as the first dot; the decayed state v_17 is
        # passed as the accumulator.
        b_h = tl.dot(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(permute, tl.bfloat16), tl.zeros_like(tl.cast(permute, tl.bfloat16))), [0, 2, 1]), [2, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(permute, tl.bfloat16), tl.zeros_like(tl.cast(permute, tl.bfloat16))), [0, 2, 1]), [2, 4]))), [0, 2, 1]), [2, 8]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(permute, tl.bfloat16), tl.zeros_like(tl.cast(permute, tl.bfloat16))), [0, 2, 1]), [2, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(permute, tl.bfloat16), tl.zeros_like(tl.cast(permute, tl.bfloat16))), [0, 2, 1]), [2, 4]))), [0, 2, 1]), [2, 8]))), [0, 2, 1]), [2, 16]), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(v_18, tl.bfloat16), tl.zeros_like(tl.cast(v_18, tl.bfloat16))), [2, 0, 1]), [4, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(v_18, tl.bfloat16), tl.zeros_like(tl.cast(v_18, tl.bfloat16))), [2, 0, 1]), [4, 2]))), [2, 0, 1]), [8, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(v_18, tl.bfloat16), tl.zeros_like(tl.cast(v_18, tl.bfloat16))), [2, 0, 1]), [4, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(v_18, tl.bfloat16), tl.zeros_like(tl.cast(v_18, tl.bfloat16))), [2, 0, 1]), [4, 2]))), [2, 0, 1]), [8, 2]))), [2, 0, 1]), [16, 2]), acc=v_17, input_precision='tf32', out_dtype=tl.float32)
        # src[gdn_fwd_h.py:77]: print("b_h after addition", b_h)
        tl.device_print('b_h after addition', b_h)

def helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c, *, _launcher=_default_launcher):
    """
    Generated host-side wrapper: allocate ``h`` and launch the Triton kernel.

    Argument:
        k_c: (batch, nchunks, chunk_size, nheads, dhead)
        w_c: (batch, nchunks, chunk_size, nheads, dhead)
        u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
        g_c: (batch, nchunks, chunk_size, nheads)
        _launcher: callable used to launch the Triton kernel
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    # src[gdn_fwd_h.py:39]: batch, nchunks, chunk_size, nheads, dhead = k_c.shape
    batch, nchunks, chunk_size, nheads, dhead = k_c.shape
    # src[gdn_fwd_h.py:40]: dhead = hl.specialize(dhead)
    # The specialized value is written out as a literal.
    dhead = 2
    # src[gdn_fwd_h.py:42]: dstate = u_c.shape[-1]
    dstate = u_c.shape[-1]
    # src[gdn_fwd_h.py:44]: acc_dtype = torch.float32
    acc_dtype = torch.float32
    # src[gdn_fwd_h.py:45]: dtype = k_c.dtype
    dtype = k_c.dtype
    # src[gdn_fwd_h.py:47]: h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device)
    h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device)
    # src[gdn_fwd_h.py:51]: for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
    # Compile-time sizes passed to the kernel as tl.constexpr arguments.
    _BLOCK_SIZE_0 = 2
    _RDIM_SIZE_4 = 2
    # src[gdn_fwd_h.py:51]: for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
    # src[gdn_fwd_h.py:52]:     b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
    # src[gdn_fwd_h.py:53]:     for i_t in hl.grid(nchunks):
    # src[gdn_fwd_h.py:51-77]: ...
    # 1-D launch grid of 1 * 1 * cdiv(2, _BLOCK_SIZE_0) programs; all extents
    # are literals here.
    _launcher(_helion_helion_gdn_fwd_h_kernel, (1 * 1 * triton.cdiv(2, _BLOCK_SIZE_0),), h, w_c, u_c, g_c, k_c, _BLOCK_SIZE_0, _RDIM_SIZE_4, num_warps=4, num_stages=1)
    # src[gdn_fwd_h.py:78]: return h
    return h

def call():
    """Invoke the generated wrapper on random inputs with the repro's shapes and strides."""
    from torch._dynamo.testing import rand_strided
    # src[gdn_fwd_h.py:28]: def helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c):
    # src[gdn_fwd_h.py:29]:     """
    # src[gdn_fwd_h.py:30]:     Argument:
    # src[gdn_fwd_h.py:28-78]: ...
    # Sizes correspond to batch=1, nchunks=2, chunk_size=2, nheads=1, dhead=dstate=2.
    k_c = rand_strided(size=(1, 2, 2, 1, 2), stride=(8, 4, 2, 2, 1), dtype=torch.bfloat16, device='cuda:0')
    w_c = rand_strided(size=(1, 2, 2, 1, 2), stride=(8, 4, 2, 2, 1), dtype=torch.bfloat16, device='cuda:0')
    u_c = rand_strided(size=(1, 2, 2, 1, 2), stride=(8, 4, 2, 2, 1), dtype=torch.bfloat16, device='cuda:0')
    g_c = rand_strided(size=(1, 2, 2, 1), stride=(4, 2, 1, 1), dtype=torch.float32, device='cuda:0')
    # The returned tensor is discarded; only the launch itself is exercised.
    helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c)

# Run the generated code standalone when executed as a script.
if __name__ == '__main__':
    call()

Notice that the load marked `# <----- this line` has no dependence on `offset_4` or any other loop variable, which is wrong: `b_g_last` should be read from the current chunk. It looks like the compiler mishandles the `//` and `%` ops and simplifies them away into constants.

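But if we write it as follows — indexing `g_c` directly with `i_t` and `chunk_size - 1` instead of `last_idx // chunk_size` and `last_idx % chunk_size`, and looping with `range(nchunks)` instead of `hl.grid(nchunks)` — it works: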

"""
Gated Delta Net Fwd H Kernel
============================
This code implements a fwd_h kernel as used in gated delta net
"""
# %%
# Imports
# -------
from __future__ import annotations
import functools
import math
import torch
import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl

# %%
# Helion Kernel Implementation
# ----------------------------
@helion.kernel()
def helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c):
    """
    Compute the chunk-wise hidden states of the gated delta net forward pass.

    For every (batch, head, value tile) the kernel walks the chunks in order.
    It stores the running state as it is at the *start* of each chunk into
    ``h`` and then folds that chunk's contribution into the running state.
    The state after the last chunk is not written anywhere.

    Argument:
        k_c: (batch, nchunks, chunk_size, nheads, dhead)
        w_c: (batch, nchunks, chunk_size, nheads, dhead)
        u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
        g_c: (batch, nchunks, chunk_size, nheads)
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    batch, nchunks, chunk_size, nheads, dhead = k_c.shape
    dhead = hl.specialize(dhead)
    chunk_size = hl.specialize(chunk_size)
    dstate = u_c.shape[-1]
    acc_dtype = torch.float32
    dtype = k_c.dtype
    h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device)
    block_v = hl.register_block_size(dstate)
    # The inputs are already chunked, so seqlen is an exact multiple of chunk_size.
    seqlen = chunk_size * nchunks
    # One tile per (batch element, head, block of the value dimension).
    for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
        # Running state for this tile, accumulated in float32.
        b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
        for i_t in range(nchunks):
            # Emit the state as it is before chunk i_t is applied.
            h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
            b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
            # The state is rounded to the input dtype before the matmul with w.
            c_h = b_h.to(dtype)
            b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
            p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
            # Residual value: u minus the part already explained by the state.
            b_v = p_v - b_v
            # Mask of in-range rows; always true here since seqlen == chunk_size * nchunks.
            m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
            # Workaround: address the last row of chunk i_t directly.  With
            # seqlen == chunk_size * nchunks this is the same element that the
            # failing version reaches through last_idx // chunk_size and
            # last_idx % chunk_size.
            b_g_last = g_c[tile_b.begin, i_t, chunk_size - 1, tile_h.begin].to(acc_dtype)
            b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
            # Scale every row by exp(g_last - g_row); masked rows become zero.
            b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
            # Decay the running state by exp(g_last) ...
            b_g_last = torch.exp(b_g_last)
            b_h *= b_g_last
            b_v = b_v.to(dtype)
            p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
            # ... and accumulate k^T @ b_v on top of it.
            b_h = hl.dot(p_k.T, b_v, acc=b_h)
    return h

def helion_gdn_fwd_h(k, w, u, g, chunk_size):
    """
    Chunk the sequence axis of the inputs and run the Helion fwd_h kernel.

    Argument:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    batch, seqlen, nheads, dhead = k.shape
    nchunks = (seqlen + chunk_size - 1) // chunk_size
    # Leading dimensions shared by all chunked views: seqlen -> (nchunks, chunk_size).
    lead = (batch, nchunks, chunk_size, nheads)
    return helion_gdn_fwd_h_kernel(
        k.reshape(*lead, dhead),
        w.reshape(*lead, dhead),
        u.reshape(*lead, u.shape[-1]),
        g.reshape(*lead),
    )

# %%
# Reference Function
# -------------
def ref_gdn_fwd_h(k, w, u, g, chunk_size):
    """
    Eager PyTorch reference for the gated delta net fwd_h kernel.

    Processes all batches and heads at once, one chunk per loop iteration.
    The state seen at the start of each chunk is stored in ``h``; the state
    after the last chunk is not returned.  The function produces no output
    on stdout, so it can be used as the baseline of ``run_example``.

    Argument:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    batch, seqlen, nheads, dhead = k.shape
    expand_v = u.shape[-1] // dhead
    nchunks = (seqlen + chunk_size - 1) // chunk_size
    acc_dtype = torch.float32
    dtype = k.dtype
    h = torch.empty(batch, nchunks, nheads, dhead, expand_v*dhead, dtype=k.dtype, device=k.device)
    # Running state, kept in float32 across chunks.
    b_h = torch.zeros(batch, nheads, dhead, expand_v*dhead, dtype=acc_dtype, device=k.device)
    k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
    w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
    u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v*dhead)
    g_c = g.reshape(batch, nchunks, chunk_size, nheads)
    for i_t in range(nchunks):
        # State as seen at the start of chunk i_t.
        h[:, i_t, :, :, :] = b_h.to(dtype)
        b_w = w_c[:, i_t, :, :, :].to(acc_dtype)
        # Round-trip through the I/O dtype to mimic the kernel's rounding.
        c_h = b_h.to(dtype).to(acc_dtype)
        b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h)
        p_v = u_c[:, i_t, :, :, :].to(acc_dtype)
        # Residual value: u minus the part already explained by the state.
        b_v = p_v - b_v
        # Flat index of the last token of this chunk and mask of in-range rows.
        last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
        m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen
        b_g_last = g[:, last_idx, :].to(acc_dtype)
        b_g = g_c[:, i_t, :, :].to(acc_dtype) # batch, chunk, nheads
        # Scale every row by exp(g_last - g_row); masked rows become zero.
        b_v *= torch.where(m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0).unsqueeze(-1)
        # Decay the running state by exp(g_last), then add k^T @ b_v.
        b_g_last = torch.exp(b_g_last)
        b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1)
        b_v = b_v.to(dtype).to(acc_dtype)
        p_k = k_c[:, i_t, :, :, :].to(acc_dtype)
        b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v)
    return h

# %%
# Testing Function
# -------------
def test(
    batch: int,
    nheads: int,
    seqlen: int,
    chunk_size: int,
    dhead: int,
    dstate: int,
    dtype: torch.dtype = torch.float16,
) -> None:
    """
    Build random CUDA inputs and compare the Helion kernel with the reference.

    ``k``, ``w`` and ``u`` are bfloat16, ``g`` is float32; ``w`` and ``u`` are
    RMS-normalized over their last dimension.  The comparison itself is
    delegated to ``run_example``.  ``dtype`` is accepted but not used.
    """
    device = "cuda"
    nchunks = seqlen // chunk_size
    rms_norm = torch.nn.functional.rms_norm

    k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=device)

    w_chunked = torch.randn(batch, nchunks, chunk_size, nheads, dhead, dtype=torch.float32, device=device)
    w_chunked = rms_norm(w_chunked, (dhead,))
    # wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
    # w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
    # The permute belongs to the disabled SVD variant above and is kept as is.
    w = w_chunked.permute(0, 1, 3, 2, 4).reshape(batch, seqlen, nheads, dhead).to(torch.bfloat16)

    u = rms_norm(
        torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=device),
        (dstate,),
    )

    # Per-token log-gate increments, accumulated along the sequence axis.
    gate_scale = 0.5 * math.log(1 / dhead)
    increments = gate_scale * torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=device)
    g = torch.cumsum(increments, dim=1)

    run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, (k, w, u, g, chunk_size))

# %%
# Main Function
# -----------
def main() -> None:
    """
    Run the gated delta net fwd_h test on one fixed, large configuration.
    """
    test(batch=8, nheads=80, seqlen=4096, chunk_size=256, dhead=64, dstate=128)

# Run the test when the file is executed as a script.
if __name__ == "__main__":
    main()

Branch: fix_1067

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions