Fix node provenance tracking #95901
Conversation
Before:
```
triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
```
After:
```
triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
```
For this kernel:
```python
@persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
```
Admittedly, this still isn't great provenance tracking, since ops like layer norms are decomposed. I might add some extra provenance tracking during decompositions.

cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire
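The renaming above amounts to filtering a fused kernel's origin nodes: ops whose outputs were realized by some other kernel (here, `convolution_15` and `relu_12`) no longer contribute to the name. A minimal sketch of that idea, using a hypothetical helper rather than Inductor's actual naming code:

```python
# Hypothetical sketch of provenance-based kernel naming: keep only the
# origin ops whose results are actually computed inside this fused kernel.
def fused_kernel_name(origin_ops, realized_elsewhere):
    """Build a triton_fused_* name from the ops fused into one kernel."""
    kept = [op for op in origin_ops if op not in realized_elsewhere]
    return "triton_fused_" + "_".join(kept)

before = ["add_83", "add_84", "convolution_15", "relu_12",
          "relu_13", "squeeze_46", "var_mean_15"]
# The convolution and the first relu come from other kernels, so they
# are dropped from this kernel's provenance.
print(fused_kernel_name(before, {"convolution_15", "relu_12"}))
# triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15
```
(The real names also carry a trailing uniquifying suffix such as `_14`, which this sketch omits.)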
ghstack-source-id: 5a6f27f09248ec72e913daedbb5b3680d66883b3 Pull Request resolved: #95901
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
@pytorchbot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Successfully rebased.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch/pytorch#95901 Approved by: https://github.com/jansel, https://github.com/mlazos
Removing from milestone since it hasn't been cherry-picked and is not critical for 2.0.1.
```python
if isinstance(n, ir.StorageBox):
    return is_unrealized_node(n.data)
return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)
```
Why do we only test Pointwise here; shouldn't Reduction also count?
Technically yes, but Reductions are always materialized during the lowering phase.
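The check being discussed can be sketched with stand-in classes (not the real `torch._inductor.ir` types): unwrap box wrappers until the underlying IR node is reached, then treat only `Pointwise` nodes as unrealized, on the assumption that reductions have already been materialized by lowering.

```python
# Stand-ins for torch._inductor.ir types, purely to illustrate the check.
class IRNode: pass
class Pointwise(IRNode): pass
class Reduction(IRNode): pass

class StorageBox:
    def __init__(self, data): self.data = data

class TensorBox:
    def __init__(self, data): self.data = data

def is_unrealized_node(n):
    # Unwrap box wrappers until we reach the underlying IR node.
    if isinstance(n, (TensorBox, StorageBox)):
        return is_unrealized_node(n.data)
    # Only pointwise ops can still be unrealized here; reductions are
    # materialized during lowering, so this returns False for them.
    return isinstance(n, IRNode) and isinstance(n, Pointwise)

print(is_unrealized_node(TensorBox(StorageBox(Pointwise()))))  # True
print(is_unrealized_node(StorageBox(Reduction())))             # False
```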