Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix node provenance tracking (#95901)
Before: ``` triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14 ``` After: ``` triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14 ``` For this kernel ``` @persistent_reduction( size_hints=[512, 64], reduction_hint=ReductionHint.INNER, filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]} ) @triton.jit def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): xnumel = 512 rnumel = 49 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:, None] xmask = xindex < xnumel rindex = tl.arange(0, RBLOCK)[None, :] rmask = rindex < rnumel r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0) tmp8 = tl.load(in_ptr1 + (x0), xmask) tmp22 = tl.load(in_ptr2 + (x0), xmask) tmp24 = tl.load(in_ptr3 + (x0), xmask) tmp30 = tl.load(in_ptr4 + (x0), xmask) tmp2 = tl.where(rmask & xmask, tmp0, 0) tmp3 = tl.sum(tmp2, 1)[:, None] tmp4 = 49.0 tmp5 = tmp3 / tmp4 tmp6 = 0.1 tmp7 = tmp5 * tmp6 tmp9 = 0.9 tmp10 = tmp8 * tmp9 tmp11 = tmp7 + tmp10 tmp12 = tmp0 - tmp5 tmp13 = tmp12 * tmp12 tmp15 = tl.where(rmask & xmask, tmp13, 0) tmp16 = tl.sum(tmp15, 1)[:, None] tmp17 = tmp16 / tmp4 tmp18 = 1e-05 tmp19 = tmp17 + tmp18 tmp20 = tl.libdevice.rsqrt(tmp19) tmp21 = tmp12 * tmp20 tmp23 = tmp21 * tmp22 tmp25 = tmp23 + tmp24 tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25)) tmp27 = 1.0208333333333333 tmp28 = tmp17 * tmp27 tmp29 = tmp28 * tmp6 tmp31 = tmp30 * tmp9 tmp32 = tmp29 + tmp31 tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask) tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask) tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask) tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask) tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask) ``` Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions. Pull Request resolved: #95901 Approved by: https://github.com/jansel, https://github.com/mlazos
- Loading branch information