Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix node provenance tracking #95901

Closed
wants to merge 6 commits into from
Closed

Conversation

Chillee
Copy link
Contributor

@Chillee Chillee commented Mar 2, 2023

Stack from ghstack (oldest at bottom):

Before:

triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14

After:

triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14

For this kernel

@persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)

Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions.

cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Mar 2, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95901

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d26cdcf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Chillee added a commit that referenced this pull request Mar 2, 2023
ghstack-source-id: f06e100fb139e907b4d4d59c065a92ca5e37c207
Pull Request resolved: #95901
cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
Chillee added a commit that referenced this pull request Mar 2, 2023
ghstack-source-id: 8e77cfaf2d2783020113bce9af45148e50443c37
Pull Request resolved: #95901
@Chillee Chillee added the topic: not user facing topic category label Mar 2, 2023
@albanD albanD removed their request for review March 2, 2023 19:58
cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
Chillee added a commit that referenced this pull request Mar 2, 2023
ghstack-source-id: b7a1ec16ccb3288063ba18087b71b6659635e489
Pull Request resolved: #95901
cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
Chillee added a commit that referenced this pull request Mar 2, 2023
ghstack-source-id: 31a4232eef40b9712b559badb75dbc38377fa9d8
Pull Request resolved: #95901
Before:
```
triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
```

After:
```
triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
```

For this kernel
```
persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
```

Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions.



cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
Chillee added a commit that referenced this pull request Mar 3, 2023
ghstack-source-id: 5a6f27f09248ec72e913daedbb5b3680d66883b3
Pull Request resolved: #95901
@Chillee
Copy link
Contributor Author

Chillee commented Mar 3, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 3, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled. If you believe this is a mistake,then you can re trigger it through pytorch-bot.

@Chillee
Copy link
Contributor Author

Chillee commented Mar 4, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled. If you believe this is a mistake,then you can re trigger it through pytorch-bot.

@Chillee
Copy link
Contributor Author

Chillee commented Mar 4, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled. If you believe this is a mistake,then you can re trigger it through pytorch-bot.

@Chillee
Copy link
Contributor Author

Chillee commented Mar 5, 2023

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

Before:
```
triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
```

After:
```
triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
```

For this kernel
```
persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
```

Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions.



cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Successfully rebased gh/chillee/191/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/95901)

pytorchmergebot pushed a commit that referenced this pull request Mar 5, 2023
ghstack-source-id: c1face34a8dc02724840fb8202ff4525533e5de9
Pull Request resolved: #95901
@Chillee
Copy link
Contributor Author

Chillee commented Mar 5, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

ydwu4 pushed a commit to ydwu4/pytorch that referenced this pull request Mar 10, 2023
Before:
```
triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
```

After:
```
triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
```

For this kernel
```
@persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
```

Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions.

Pull Request resolved: pytorch#95901
Approved by: https://github.com/jansel, https://github.com/mlazos
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
Before:
```
triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
```

After:
```
triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
```

For this kernel
```
@persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
```

Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions.

Pull Request resolved: pytorch/pytorch#95901
Approved by: https://github.com/jansel, https://github.com/mlazos
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
Before:
```
triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
```

After:
```
triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
```

For this kernel
```
@persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
```

Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions.

Pull Request resolved: pytorch/pytorch#95901
Approved by: https://github.com/jansel, https://github.com/mlazos
ydwu4 added a commit to ydwu4/pytorch that referenced this pull request Mar 13, 2023
Before:
```
triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
```

After:
```
triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
```

For this kernel
```
@persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
```

Tbh this still isn't super great provenance tracking, since ops like layernorms are decomposed. I might add some extra provenance tracking during decompositions.

Pull Request resolved: pytorch#95901
Approved by: https://github.com/jansel, https://github.com/mlazos
@ezyang ezyang added this to the 2.0.1 milestone Mar 20, 2023
@Chillee Chillee removed this from the 2.0.1 milestone Apr 24, 2023
@Chillee
Copy link
Contributor Author

Chillee commented Apr 24, 2023

Removing from milestone since it hasn't been cherry picked and is not critical for 2.0.1

return is_unrealized_node(n.data)
if isinstance(n, ir.StorageBox):
return is_unrealized_node(n.data)
return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we only test Pointwise here; shouldn't Reduction also count?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, but Reductions are always materialized during lowering phase.

@facebook-github-bot facebook-github-bot deleted the gh/chillee/191/head branch June 8, 2023 15:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants