-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix node provenance tracking #95901
Fix node provenance tracking #95901
Commits on Mar 2, 2023
-
Configuration menu - View commit details
-
Copy full SHA for 0095ad2 - Browse repository at this point
Copy the full SHA 0095ad2View commit details -
Update on "Fix node provenance tracking"
cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for 24dfafe - Browse repository at this point
Copy the full SHA 24dfafeView commit details -
Update on "Fix node provenance tracking"
cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for b9a7407 - Browse repository at this point
Copy the full SHA b9a7407View commit details -
Update on "Fix node provenance tracking"
cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for db335ad - Browse repository at this point
Copy the full SHA db335adView commit details
Commits on Mar 3, 2023
-
Update on "Fix node provenance tracking"
Before: ``` triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14 ``` After: ``` triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14 ``` For this kernel ``` persistent_reduction( size_hints=[512, 64], reduction_hint=ReductionHint.INNER, filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]} ) triton.jit def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): xnumel = 512 rnumel = 49 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:, None] xmask = xindex < xnumel rindex = tl.arange(0, RBLOCK)[None, :] rmask = rindex < rnumel r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0) tmp8 = tl.load(in_ptr1 + (x0), xmask) tmp22 = tl.load(in_ptr2 + (x0), xmask) tmp24 = tl.load(in_ptr3 + (x0), xmask) tmp30 = tl.load(in_ptr4 + (x0), xmask) tmp2 = tl.where(rmask & xmask, tmp0, 0) tmp3 = tl.sum(tmp2, 1)[:, None] tmp4 = 49.0 tmp5 = tmp3 / tmp4 tmp6 = 0.1 tmp7 = tmp5 * tmp6 tmp9 = 0.9 tmp10 = tmp8 * tmp9 tmp11 = tmp7 + tmp10 tmp12 = tmp0 - tmp5 tmp13 = tmp12 * tmp12 tmp15 = tl.where(rmask & xmask, tmp13, 0) tmp16 = tl.sum(tmp15, 1)[:, None] tmp17 = tmp16 / tmp4 tmp18 = 1e-05 tmp19 = tmp17 + tmp18 tmp20 = tl.libdevice.rsqrt(tmp19) tmp21 = tmp12 * tmp20 tmp23 = tmp21 * tmp22 tmp25 = tmp23 + tmp24 tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25)) tmp27 = 1.0208333333333333 tmp28 = tmp17 * tmp27 tmp29 = tmp28 * tmp6 tmp31 = tmp30 * tmp9 tmp32 = tmp29 + tmp31 tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask) tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask) tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask) tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask) tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask) ``` Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions. cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for 6f3c63d - Browse repository at this point
Copy the full SHA 6f3c63dView commit details
Commits on Mar 5, 2023
-
Update on "Fix node provenance tracking"
Before: ``` triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14 ``` After: ``` triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14 ``` For this kernel ``` persistent_reduction( size_hints=[512, 64], reduction_hint=ReductionHint.INNER, filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]} ) triton.jit def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): xnumel = 512 rnumel = 49 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:, None] xmask = xindex < xnumel rindex = tl.arange(0, RBLOCK)[None, :] rmask = rindex < rnumel r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0) tmp8 = tl.load(in_ptr1 + (x0), xmask) tmp22 = tl.load(in_ptr2 + (x0), xmask) tmp24 = tl.load(in_ptr3 + (x0), xmask) tmp30 = tl.load(in_ptr4 + (x0), xmask) tmp2 = tl.where(rmask & xmask, tmp0, 0) tmp3 = tl.sum(tmp2, 1)[:, None] tmp4 = 49.0 tmp5 = tmp3 / tmp4 tmp6 = 0.1 tmp7 = tmp5 * tmp6 tmp9 = 0.9 tmp10 = tmp8 * tmp9 tmp11 = tmp7 + tmp10 tmp12 = tmp0 - tmp5 tmp13 = tmp12 * tmp12 tmp15 = tl.where(rmask & xmask, tmp13, 0) tmp16 = tl.sum(tmp15, 1)[:, None] tmp17 = tmp16 / tmp4 tmp18 = 1e-05 tmp19 = tmp17 + tmp18 tmp20 = tl.libdevice.rsqrt(tmp19) tmp21 = tmp12 * tmp20 tmp23 = tmp21 * tmp22 tmp25 = tmp23 + tmp24 tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25)) tmp27 = 1.0208333333333333 tmp28 = tmp17 * tmp27 tmp29 = tmp28 * tmp6 tmp31 = tmp30 * tmp9 tmp32 = tmp29 + tmp31 tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask) tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask) tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask) tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask) tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask) ``` Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions. cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for d26cdcf - Browse repository at this point
Copy the full SHA d26cdcfView commit details