Reported by @v0i0:
Repro:
"""
Gated Delta Net Fwd H Kernel
============================
This code implements a fwd_h kernel as used in gated delta net
"""
# %%
# Imports
# -------
from __future__ import annotations
import functools
import math
import torch
import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl
# %%
# Helion Kernel Implementation
# ----------------------------
@helion.kernel()
def helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c):
"""
Argument:
k_c: (batch, nchunks, chunk_size, nheads, dhead)
w_c: (batch, nchunks, chunk_size, nheads, dhead)
u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
g_c: (batch, nchunks, chunk_size, nheads)
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
batch, nchunks, chunk_size, nheads, dhead = k_c.shape
dhead = hl.specialize(dhead)
chunk_size = hl.specialize(chunk_size)
dstate = u_c.shape[-1]
acc_dtype = torch.float32
dtype = k_c.dtype
h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device)
block_v = hl.register_block_size(dstate)
seqlen = chunk_size * nchunks
for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
for i_t in hl.grid(nchunks):
h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
print("b_h at the start", b_h)
b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
c_h = b_h.to(dtype)
b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
print("b_v at the start", b_v)
p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
b_v = p_v - b_v
print("b_v after subtraction", b_v)
last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
b_g_last = g_c[tile_b.begin, last_idx // chunk_size, last_idx % chunk_size, tile_h.begin].to(acc_dtype)
print("b_g_last", b_g_last)
b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
print("b_g", b_g)
b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
print("b_v after multiplication", b_v)
b_g_last = torch.exp(b_g_last)
b_h *= b_g_last
print("b_h after multiplication", b_h)
b_v = b_v.to(dtype)
p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
b_h = hl.dot(p_k.T, b_v, acc=b_h)
print("b_h after addition", b_h)
return h
def helion_gdn_fwd_h(k, w, u, g, chunk_size):
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
batch, seqlen, nheads, dhead = k.shape
dstate = u.shape[-1]
nchunks = (seqlen + chunk_size - 1) // chunk_size
k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
u_c = u.reshape(batch, nchunks, chunk_size, nheads, dstate)
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
h = helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c)
print(h)
return h
# %%
# Reference Function
# -------------
def ref_gdn_fwd_h(k, w, u, g, chunk_size):
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
batch, seqlen, nheads, dhead = k.shape
expand_v = u.shape[-1] // dhead
nchunks = (seqlen + chunk_size - 1) // chunk_size
acc_dtype = torch.float32
dtype = k.dtype
h = torch.empty(batch, nchunks, nheads, dhead, expand_v*dhead, dtype=k.dtype, device=k.device)
b_h = torch.zeros(batch, nheads, dhead, expand_v*dhead, dtype=acc_dtype, device=k.device)
k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v*dhead)
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
for i_t in range(nchunks):
h[:, i_t, :, :, :] = b_h.to(dtype)
print("b_h at the start", b_h)
b_w = w_c[:, i_t, :, :, :].to(acc_dtype)
c_h = b_h.to(dtype).to(acc_dtype)
b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h)
print("b_v at the start", b_v)
p_v = u_c[:, i_t, :, :, :].to(acc_dtype)
b_v = p_v - b_v
print("b_v after subtraction", b_v)
last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen
b_g_last = g[:, last_idx, :].to(acc_dtype)
print("b_g_last", b_g_last)
b_g = g_c[:, i_t, :, :].to(acc_dtype) # batch, chunk, nheads
print("b_g", b_g)
b_v *= torch.where(m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0).unsqueeze(-1)
print("b_v after multiplication", b_v)
b_g_last = torch.exp(b_g_last)
b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1)
print("b_h after multiplication", b_h)
b_v = b_v.to(dtype).to(acc_dtype)
p_k = k_c[:, i_t, :, :, :].to(acc_dtype)
b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v)
print("b_h after saddition", b_h)
print(h)
return h
# %%
# Testing Function
# -------------
def test(
batch: int,
nheads: int,
seqlen: int,
chunk_size: int,
dhead: int,
dstate: int,
dtype: torch.dtype = torch.float16,
) -> None:
k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device="cuda")
k = torch.ones(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device="cuda") / dhead
w = torch.randn(batch, seqlen // chunk_size, chunk_size, nheads, dhead, dtype=torch.float32, device='cuda')
wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
w = w.permute(0, 1, 3, 2, 4).reshape(batch, seqlen, nheads, dhead).to(torch.bfloat16)
u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device="cuda")
# u = torch.ones(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device="cuda") // dstate
g = torch.cumsum(0.5*math.log(1/dhead)*torch.rand(batch, seqlen, nheads, dtype=torch.float32, device='cuda'), dim=1)
# g = torch.zeros(batch, seqlen, nheads, device='cuda', dtype=torch.float32)
args = (k, w, u, g, chunk_size)
print([a.shape for a in args if isinstance(a, torch.Tensor)])
# run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)
helion_gdn_fwd_h(*args)
ref_gdn_fwd_h(*args)
# %%
# Main Function
# -----------
def main() -> None:
"""
Main entry point that runs the attention kernel test with specific parameters.
"""
test(1, 1, 4, 2, 2, 2)
#test(8, 80, 4096, 256, 64, 128)
if __name__ == "__main__":
print(open(__file__).read())
print("======================")
    main()

The generated Triton kernel is:
@triton.jit
def _helion_helion_gdn_fwd_h_kernel(h, w_c, u_c, g_c, k_c, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_4: tl.constexpr):
# src[gdn_fwd_h.py:51]: for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
num_blocks_0 = 1
num_blocks_1 = 1
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
offset_0 = pid_2 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
indices_5 = tl.arange(0, _RDIM_SIZE_4).to(tl.int32)
# src[gdn_fwd_h.py:52]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
b_h = tl.full([2, _BLOCK_SIZE_0], 0.0, tl.float32)
# src[gdn_fwd_h.py:53]: for i_t in hl.grid(nchunks):
# src[gdn_fwd_h.py:54]: h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
# src[gdn_fwd_h.py:55]: print("b_h at the start", b_h)
# src[gdn_fwd_h.py:53-77]: ...
for offset_4 in tl.range(0, 2):
b_h_copy = b_h
b_h_copy_0 = b_h_copy
# src[gdn_fwd_h.py:54]: h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
v_0 = tl.cast(b_h_copy_0, tl.bfloat16)
tl.store(h + (offset_4 * 4 + indices_5[:, None] * 2 + indices_0[None, :] * 1), v_0, None)
# src[gdn_fwd_h.py:55]: print("b_h at the start", b_h)
tl.device_print('b_h at the start', b_h_copy_0)
# src[gdn_fwd_h.py:56]: b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
b_w = tl.load(w_c + (offset_4 * 4 + indices_5[:, None] * 2 + indices_5[None, :] * 1), None)
# src[gdn_fwd_h.py:57]: c_h = b_h.to(dtype)
v_1 = tl.cast(b_h_copy_0, tl.bfloat16)
# src[gdn_fwd_h.py:58]: b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
b_v = tl.dot(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(b_w, tl.bfloat16), tl.zeros_like(tl.cast(b_w, tl.bfloat16))), [0, 2, 1]), [2, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(b_w, tl.bfloat16), tl.zeros_like(tl.cast(b_w, tl.bfloat16))), [0, 2, 1]), [2, 4]))), [0, 2, 1]), [2, 8]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(b_w, tl.bfloat16), tl.zeros_like(tl.cast(b_w, tl.bfloat16))), [0, 2, 1]), [2, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(b_w, tl.bfloat16), tl.zeros_like(tl.cast(b_w, tl.bfloat16))), [0, 2, 1]), [2, 4]))), [0, 2, 1]), [2, 8]))), [0, 2, 1]), [2, 16]), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(v_1, tl.bfloat16), tl.zeros_like(tl.cast(v_1, tl.bfloat16))), [2, 0, 1]), [4, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(v_1, tl.bfloat16), tl.zeros_like(tl.cast(v_1, tl.bfloat16))), [2, 0, 1]), [4, 2]))), [2, 0, 1]), [8, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(v_1, tl.bfloat16), tl.zeros_like(tl.cast(v_1, tl.bfloat16))), [2, 0, 1]), [4, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(v_1, tl.bfloat16), tl.zeros_like(tl.cast(v_1, tl.bfloat16))), [2, 0, 1]), [4, 2]))), [2, 0, 1]), [8, 2]))), [2, 0, 1]), [16, 2]), input_precision='tf32', out_dtype=tl.float32)
# src[gdn_fwd_h.py:59]: print("b_v at the start", b_v)
tl.device_print('b_v at the start', b_v)
# src[gdn_fwd_h.py:60]: p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
load_1 = tl.load(u_c + (offset_4 * 4 + indices_5[:, None] * 2 + indices_0[None, :] * 1), None)
v_2 = tl.cast(load_1, tl.float32)
# src[gdn_fwd_h.py:61]: b_v = p_v - b_v
v_3 = v_2 - b_v
# src[gdn_fwd_h.py:62]: print("b_v after subtraction", b_v)
tl.device_print('b_v after subtraction', v_3)
# src[gdn_fwd_h.py:64]: m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
mul_1 = 2 * offset_4
iota = tl.arange(0, 2)
v_4 = tl.cast(mul_1, tl.int32)
v_5 = iota + v_4
v_6 = tl.full([], 4, tl.int32)
v_7 = v_5 < v_6
# src[gdn_fwd_h.py:65]: b_g_last = g_c[tile_b.begin, last_idx // chunk_size, last_idx % chunk_size, tile_h.begin].to(acc_dtype)
b_g_last = tl.load(g_c + (1 * 2 + 1 * 1), None) # <----- this line
# src[gdn_fwd_h.py:66]: print("b_g_last", b_g_last)
tl.device_print('b_g_last', b_g_last)
# src[gdn_fwd_h.py:67]: b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
b_g = tl.load(g_c + (offset_4 * 2 + indices_5 * 1), None)
# src[gdn_fwd_h.py:68]: print("b_g", b_g)
tl.device_print('b_g', b_g)
# src[gdn_fwd_h.py:69]: b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
v_8 = b_g_last[None]
v_9 = v_8 - b_g
v_10 = libdevice.exp(v_9)
v_11 = 0.0
v_12 = v_11[None]
v_13 = tl.where(v_7, v_10, v_12)
subscript = v_13[:, None]
v_14 = v_3 * subscript
# src[gdn_fwd_h.py:70]: print("b_v after multiplication", b_v)
tl.device_print('b_v after multiplication', v_14)
# src[gdn_fwd_h.py:71]: b_g_last = torch.exp(b_g_last)
v_15 = libdevice.exp(b_g_last)
# src[gdn_fwd_h.py:72]: b_h *= b_g_last
v_16 = v_15[None, None]
v_17 = b_h_copy_0 * v_16
# src[gdn_fwd_h.py:73]: print("b_h after multiplication", b_h)
tl.device_print('b_h after multiplication', v_17)
# src[gdn_fwd_h.py:74]: b_v = b_v.to(dtype)
v_18 = tl.cast(v_14, tl.bfloat16)
# src[gdn_fwd_h.py:75]: p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
p_k = tl.load(k_c + (offset_4 * 4 + indices_5[:, None] * 2 + indices_5[None, :] * 1), None)
# src[gdn_fwd_h.py:76]: b_h = hl.dot(p_k.T, b_v, acc=b_h)
permute = tl.permute(p_k, [1, 0])
b_h = tl.dot(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(permute, tl.bfloat16), tl.zeros_like(tl.cast(permute, tl.bfloat16))), [0, 2, 1]), [2, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(permute, tl.bfloat16), tl.zeros_like(tl.cast(permute, tl.bfloat16))), [0, 2, 1]), [2, 4]))), [0, 2, 1]), [2, 8]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(permute, tl.bfloat16), tl.zeros_like(tl.cast(permute, tl.bfloat16))), [0, 2, 1]), [2, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(permute, tl.bfloat16), tl.zeros_like(tl.cast(permute, tl.bfloat16))), [0, 2, 1]), [2, 4]))), [0, 2, 1]), [2, 8]))), [0, 2, 1]), [2, 16]), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(v_18, tl.bfloat16), tl.zeros_like(tl.cast(v_18, tl.bfloat16))), [2, 0, 1]), [4, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(v_18, tl.bfloat16), tl.zeros_like(tl.cast(v_18, tl.bfloat16))), [2, 0, 1]), [4, 2]))), [2, 0, 1]), [8, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(v_18, tl.bfloat16), tl.zeros_like(tl.cast(v_18, tl.bfloat16))), [2, 0, 1]), [4, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(v_18, tl.bfloat16), tl.zeros_like(tl.cast(v_18, tl.bfloat16))), [2, 0, 1]), [4, 2]))), [2, 0, 1]), [8, 2]))), [2, 0, 1]), [16, 2]), acc=v_17, input_precision='tf32', out_dtype=tl.float32)
# src[gdn_fwd_h.py:77]: print("b_h after addition", b_h)
tl.device_print('b_h after addition', b_h)
def helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c, *, _launcher=_default_launcher):
"""
Argument:
k_c: (batch, nchunks, chunk_size, nheads, dhead)
w_c: (batch, nchunks, chunk_size, nheads, dhead)
u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
g_c: (batch, nchunks, chunk_size, nheads)
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
# src[gdn_fwd_h.py:39]: batch, nchunks, chunk_size, nheads, dhead = k_c.shape
batch, nchunks, chunk_size, nheads, dhead = k_c.shape
# src[gdn_fwd_h.py:40]: dhead = hl.specialize(dhead)
dhead = 2
# src[gdn_fwd_h.py:42]: dstate = u_c.shape[-1]
dstate = u_c.shape[-1]
# src[gdn_fwd_h.py:44]: acc_dtype = torch.float32
acc_dtype = torch.float32
# src[gdn_fwd_h.py:45]: dtype = k_c.dtype
dtype = k_c.dtype
# src[gdn_fwd_h.py:47]: h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device)
h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device)
# src[gdn_fwd_h.py:51]: for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
_BLOCK_SIZE_0 = 2
_RDIM_SIZE_4 = 2
# src[gdn_fwd_h.py:51]: for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
# src[gdn_fwd_h.py:52]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
# src[gdn_fwd_h.py:53]: for i_t in hl.grid(nchunks):
# src[gdn_fwd_h.py:51-77]: ...
_launcher(_helion_helion_gdn_fwd_h_kernel, (1 * 1 * triton.cdiv(2, _BLOCK_SIZE_0),), h, w_c, u_c, g_c, k_c, _BLOCK_SIZE_0, _RDIM_SIZE_4, num_warps=4, num_stages=1)
# src[gdn_fwd_h.py:78]: return h
return h
def call():
from torch._dynamo.testing import rand_strided
# src[gdn_fwd_h.py:28]: def helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c):
# src[gdn_fwd_h.py:29]: """
# src[gdn_fwd_h.py:30]: Argument:
# src[gdn_fwd_h.py:28-78]: ...
k_c = rand_strided(size=(1, 2, 2, 1, 2), stride=(8, 4, 2, 2, 1), dtype=torch.bfloat16, device='cuda:0')
w_c = rand_strided(size=(1, 2, 2, 1, 2), stride=(8, 4, 2, 2, 1), dtype=torch.bfloat16, device='cuda:0')
u_c = rand_strided(size=(1, 2, 2, 1, 2), stride=(8, 4, 2, 2, 1), dtype=torch.bfloat16, device='cuda:0')
g_c = rand_strided(size=(1, 2, 2, 1), stride=(4, 2, 1, 1), dtype=torch.float32, device='cuda:0')
helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c)
if __name__ == '__main__':
    call()

Notice how for the line marked # <----- this line there is no dependence on offset_4 or any other loop variable, which is wrong: b_g_last is loaded from a fixed offset on every chunk iteration. It looks like the codegen dislikes the // and % ops and manages to simplify them out.
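For reference, the kernel-side changes in the working version below are: index the last gate of the chunk directly as g_c[tile_b.begin, i_t, chunk_size - 1, tile_h.begin] instead of going through last_idx // chunk_size and last_idx % chunk_size, and drive the inner loop with range(nchunks) instead of hl.grid(nchunks) (the debug prints are also dropped). The rewrite does not change the math: the wrapper's reshape already requires seqlen to be a multiple of chunk_size, so the last valid row of chunk i_t is always chunk_size - 1. Here is a minimal plain-PyTorch sketch (outside Helion, with tiny hypothetical shapes chosen only for illustration) checking that the two indexing schemes agree:

import torch

# Sanity check: the // and % indexing used in the failing kernel and the direct
# (i_t, chunk_size - 1) indexing used in the working kernel select the same gate
# element whenever seqlen == nchunks * chunk_size.
batch, nchunks, chunk_size, nheads = 1, 2, 2, 1
seqlen = nchunks * chunk_size
g_c = torch.arange(batch * nchunks * chunk_size * nheads, dtype=torch.float32)
g_c = g_c.reshape(batch, nchunks, chunk_size, nheads)
for i_t in range(nchunks):
    last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
    # Indexing from the failing kernel: decompose the flat index with // and %.
    a = g_c[0, last_idx // chunk_size, last_idx % chunk_size, 0]
    # Indexing from the working kernel: chunk i_t, last row of the chunk.
    b = g_c[0, i_t, chunk_size - 1, 0]
    assert torch.equal(a, b)

So the two kernels should compute the same thing; the difference is purely in the Triton code Helion generates for the g_c load.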
But if we write it this way, it works:
"""
Gated Delta Net Fwd H Kernel
============================
This code implements a fwd_h kernel as used in gated delta net
"""
# %%
# Imports
# -------
from __future__ import annotations
import functools
import math
import torch
import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl
# %%
# Helion Kernel Implementation
# ----------------------------
@helion.kernel()
def helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c):
"""
Argument:
k_c: (batch, nchunks, chunk_size, nheads, dhead)
w_c: (batch, nchunks, chunk_size, nheads, dhead)
u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
g_c: (batch, nchunks, chunk_size, nheads)
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
batch, nchunks, chunk_size, nheads, dhead = k_c.shape
dhead = hl.specialize(dhead)
chunk_size = hl.specialize(chunk_size)
dstate = u_c.shape[-1]
acc_dtype = torch.float32
dtype = k_c.dtype
h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device)
block_v = hl.register_block_size(dstate)
seqlen = chunk_size * nchunks
for tile_b, tile_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[1, 1, block_v]):
b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
for i_t in range(nchunks):
h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
c_h = b_h.to(dtype)
b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
b_v = p_v - b_v
m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
b_g_last = g_c[tile_b.begin, i_t, chunk_size - 1, tile_h.begin].to(acc_dtype)
b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
b_g_last = torch.exp(b_g_last)
b_h *= b_g_last
b_v = b_v.to(dtype)
p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
b_h = hl.dot(p_k.T, b_v, acc=b_h)
return h
def helion_gdn_fwd_h(k, w, u, g, chunk_size):
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
batch, seqlen, nheads, dhead = k.shape
dstate = u.shape[-1]
nchunks = (seqlen + chunk_size - 1) // chunk_size
k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
u_c = u.reshape(batch, nchunks, chunk_size, nheads, dstate)
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
h = helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c)
return h
# %%
# Reference Function
# -------------
def ref_gdn_fwd_h(k, w, u, g, chunk_size):
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
batch, seqlen, nheads, dhead = k.shape
expand_v = u.shape[-1] // dhead
nchunks = (seqlen + chunk_size - 1) // chunk_size
acc_dtype = torch.float32
dtype = k.dtype
h = torch.empty(batch, nchunks, nheads, dhead, expand_v*dhead, dtype=k.dtype, device=k.device)
b_h = torch.zeros(batch, nheads, dhead, expand_v*dhead, dtype=acc_dtype, device=k.device)
k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v*dhead)
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
for i_t in range(nchunks):
h[:, i_t, :, :, :] = b_h.to(dtype)
b_w = w_c[:, i_t, :, :, :].to(acc_dtype)
c_h = b_h.to(dtype).to(acc_dtype)
b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h)
p_v = u_c[:, i_t, :, :, :].to(acc_dtype)
b_v = p_v - b_v
last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen
b_g_last = g[:, last_idx, :].to(acc_dtype)
b_g = g_c[:, i_t, :, :].to(acc_dtype) # batch, chunk, nheads
print("b_g", b_g)
b_v *= torch.where(m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0).unsqueeze(-1)
b_g_last = torch.exp(b_g_last)
b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1)
b_v = b_v.to(dtype).to(acc_dtype)
p_k = k_c[:, i_t, :, :, :].to(acc_dtype)
b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v)
return h
# %%
# Testing Function
# -------------
def test(
batch: int,
nheads: int,
seqlen: int,
chunk_size: int,
dhead: int,
dstate: int,
dtype: torch.dtype = torch.float16,
) -> None:
k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device="cuda")
w = torch.randn(batch, seqlen // chunk_size, chunk_size, nheads, dhead, dtype=torch.float32, device='cuda')
w = torch.nn.functional.rms_norm(w, (dhead,))
# wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
# w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
w = w.permute(0, 1, 3, 2, 4).reshape(batch, seqlen, nheads, dhead).to(torch.bfloat16)
u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device="cuda")
u = torch.nn.functional.rms_norm(u, (dstate,))
g = torch.cumsum(0.5*math.log(1/dhead)*torch.rand(batch, seqlen, nheads, dtype=torch.float32, device='cuda'), dim=1)
args = (k, w, u, g, chunk_size)
run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)
# %%
# Main Function
# -----------
def main() -> None:
"""
Main entry point that runs the attention kernel test with specific parameters.
"""
test(8, 80, 4096, 256, 64, 128)
if __name__ == "__main__":
    main()

Branch: fix_1067