-
Notifications
You must be signed in to change notification settings - Fork 39
Labels
Description
Take the RMS norm kernel for the (2048, 2048) shape as an example. Attached at the end is Inductor's generated Triton kernel, which is faster than the autotuned Helion kernel. Most of the performance gap comes from eviction_policy. As a demonstration, here is a latency comparison of the Inductor-generated kernel with and without eviction_policy.
(M, H) inductor_rms-latency (default) inductor_rms-latency (no eviction_policy) helion_rms_norm_tritonbench-latency
------------ --------------------------------- --------------------------------------------- ---------------------------------------------
(2048, 2048) 0.013513 (±0.10%) 0.016820 (±0.05%) 0.014535 (±0.06%)
@triton.jit
def triton_red_fused_add_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    """Two-pass fused row reduction + rescale over a 2048 x 2048 input.

    Pass 1 accumulates the sum of squares of each row of ``in_ptr0`` and
    writes ``rsqrt(sum_sq / 2048 + 1e-06)`` for that row to ``in_out_ptr0``.
    Pass 2 re-reads the same row, multiplies it by that per-row factor and
    by the per-column vector at ``in_ptr1``, and stores the result to
    ``out_ptr0``.

    Args:
        in_out_ptr0: Receives one value per row (the reciprocal root of the
            mean square plus 1e-06).
        in_ptr0: Input, addressed as ``col + 2048 * row``.
        in_ptr1: Per-column multiplier, addressed by ``col`` only
            (presumably the RMS-norm weight -- confirm against the caller).
        out_ptr0: Output, addressed as ``col + 2048 * row``.
        xnumel: Row count; overwritten with 2048 inside the kernel.
        r0_numel: Reduction length; overwritten with 2048 inside the kernel.
        XBLOCK: Rows handled by one program instance.
        R0_BLOCK: Columns processed per loop iteration.
    """
    # The generated kernel is specialized to one shape: the incoming sizes
    # are replaced by literals, and the row stride / mean divisor below are
    # the same literal 2048.
    xnumel = 2048
    r0_numel = 2048

    # Row indices as a column vector [XBLOCK, 1]; column offsets as a row
    # vector [1, R0_BLOCK] so that their sum broadcasts to a 2-D tile.
    rows = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_mask = rows < xnumel
    col_base = tl.arange(0, R0_BLOCK)[None, :]

    # Pass 1: per-lane running sum of squares, collapsed after the loop.
    sq_acc = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for col_start in range(0, r0_numel, R0_BLOCK):
        cols = col_start + col_base
        tile_mask = (cols < r0_numel) & row_mask
        # 'evict_last' hint: the same addresses are loaded again in pass 2.
        vals = tl.load(in_ptr0 + (cols + 2048*rows), tile_mask, eviction_policy='evict_last', other=0.0)
        squared = tl.broadcast_to(vals * vals, [XBLOCK, R0_BLOCK])
        # Only in-bounds lanes update the accumulator.
        sq_acc = tl.where(tile_mask, sq_acc + squared, sq_acc)

    sum_sq = tl.sum(sq_acc, 1)[:, None]
    mean_sq = (sum_sq / 2048.0)
    inv_rms = libdevice.rsqrt(mean_sq + 1e-06)
    # Barrier emitted by the generator ahead of the in/out store; kept as-is.
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (rows), inv_rms, row_mask)

    # Pass 2: scale every element by the row factor and the column vector.
    for col_start in range(0, r0_numel, R0_BLOCK):
        cols = col_start + col_base
        col_mask = cols < r0_numel
        tile_mask = col_mask & row_mask
        col_scale = tl.load(in_ptr1 + (cols), col_mask, eviction_policy='evict_last', other=0.0)
        # 'evict_first' hint: this is the final read of the input tile.
        vals = tl.load(in_ptr0 + (cols + 2048*rows), tile_mask, eviction_policy='evict_first', other=0.0)
        scaled = col_scale * (vals * inv_rms)
        tl.store(out_ptr0 + (cols + 2048*rows), scaled, tile_mask)