[inductor] Fix argmin/max with duplicate values #99920
Conversation
Dr. CI: See artifacts and rendered test results at hud.pytorch.org/pr/99920. Note: links to docs will display an error until the docs builds have been completed. As of commit 54312fc there is 1 new failure. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Fixes #99879. This adds `minimum_with_index` helper functions to compute the minimum value and index simultaneously, preferring the smaller index, which is required to match eager mode in the case of duplicates. I also replace the mask-and-sum hack with a `tl.reduce` that uses the previously mentioned helper. This additionally fixes the indices being added together in the case of duplicates.
This is great! I left minor comments; can you please post the generated code in the PR description?
atol=1e-5,
rtol=0.5,
)
self.common(fn, (torch.randn([144, 144]),))
Great that this test is fixed!
test/inductor/test_torchinductor.py
if self.device == "cpu":
    raise unittest.SkipTest("broken on CPU")

t1 = torch.randn((10, 10))
Instead of (or in addition to) (10, 10), I'd also test (8, 8), because reductions of size 8 or smaller are unrolled, IIRC.
            else:
                result_var = self.cse.generate(
                    self.compute, final_reduction(masked_value)
                )
        elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
            self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
            accumulator = f"_{result_var}"
            default_value = f" + {default}" if default != 0 else ""
            self.body.writeline(
                f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
nit: can you replace this with `tl.full`?
Turns out that doing this uncovers some bugs that I'd rather not dig into in this PR. Essentially, some boolean ops actually return `int32`, but this was masked by `tl.zeros() + x` promoting the accumulator to integer anyway.
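The dtype masking described here can be illustrated with a small NumPy analogue (NumPy stands in for Triton in this sketch; the arrays and dtypes are made up for the example):

```python
import numpy as np

# A "boolean" op that actually returns int32, as described above.
x = np.array([0, 1, 1, 0], dtype=np.int32)

# Old pattern: `zeros(...) + x` silently promotes the accumulator to int32,
# so the wrong element type never surfaces.
acc_old = np.zeros(4, dtype=np.bool_) + x
print(acc_old.dtype)  # int32

# A `full`-style initialization keeps the declared dtype, so any mismatch
# between the bool accumulator and the int32 update becomes visible.
acc_new = np.full(4, False, dtype=np.bool_)
print(acc_new.dtype)  # bool
```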
Fixes #99879. This adds `minimum_with_index` helper functions to compute the minimum value and index simultaneously, preferring the smaller index, which is required to match eager mode in the case of duplicates. I also replace the mask-and-sum hack with a `tl.reduce` that uses the previously mentioned helper. This additionally fixes the indices being added together in the case of duplicates.

As an example, this is the kernel generated for `torch.argmin(x, 1)`:

```python
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1028  # dynamic_shapes=False
    rnumel = 1028  # dynamic_shapes=False
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
    _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
            _tmp1, _tmp1_index, tmp0, rindex
        )
        _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
        _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
    _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
    tmp1 = tmp1_tmp[:, None]
    tl.store(out_ptr0 + x0, tmp1, xmask)
```

Or for a persistent reduction, it generates:

```python
tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
tmp3 = tl.broadcast_to(rindex, tmp2.shape)
_, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
tmp4 = tmp4_tmp[:, None]
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
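The tie-breaking rule the helper implements can be sketched in plain Python (an illustrative model only, using the same +inf/int64-max initial state as the kernel; the real helper operates on Triton tensors):

```python
import math

def minimum_with_index(a_value, a_index, b_value, b_index):
    """Combine two (value, index) pairs, preferring the smaller value and,
    on equal values, the smaller index -- matching eager-mode argmin."""
    if b_value < a_value or (b_value == a_value and b_index < a_index):
        return b_value, b_index
    return a_value, a_index

def argmin(xs):
    # Same initial state as the kernel: +inf value, int64-max index sentinel.
    best_value, best_index = math.inf, 2**63 - 1
    for i, x in enumerate(xs):
        best_value, best_index = minimum_with_index(best_value, best_index, x, i)
    return best_index

print(argmin([3.0, 1.0, 1.0, 2.0]))  # 1, not 2: ties go to the smaller index
```

The mask-and-sum approach this PR removes effectively added the indices of all tied minima together; folding the tie-break into the combine function avoids that.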
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 jobs have failed; the first few of them are: trunk / linux-bionic-cuda11.8-py3.10-gcc7 / test (nogpu_AVX512, 1, 1, linux.2xlarge) and trunk / linux-bionic-cuda11.8-py3.10-gcc7 / test (nogpu_NO_AVX2, 1, 1, linux.2xlarge). Raised by workflow job (details for the Dev Infra team).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hey @peterbell10, sorry about this, but we need to revert some of the `tl.reduce` patches; internally we are still on an older version of Triton without `tl.reduce`, and rolling the Triton version forward is blocked on the new version regressing some numeric tests that need investigating. It would be nice to still have this code in; how difficult do you think it would be to feature-flag these improvements based on the Triton version? I read over the PRs and it wasn't immediately obvious that it would be easy, unfortunately.
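A version gate along these lines could be as simple as a capability check (a hedged sketch; the helper name and the exact gating point in inductor are assumptions, not code from this PR):

```python
import importlib

def module_supports(module_name: str, attr: str) -> bool:
    """Return True if `module_name` is importable and exposes `attr`."""
    try:
        mod = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(mod, attr)

# Codegen could then pick the tl.reduce path only when it exists, e.g.:
# USE_TL_REDUCE = module_supports("triton.language", "reduce")
```

Gating on a capability probe rather than a version string avoids hard-coding release numbers that internal forks may not follow.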
Revert "[inductor] Stop using `x + tl.zeros(...)` in generated triton (#100163)": reverts commit 5b98910. Revert "[inductor] Fix argmin/max with duplicate values (#99920)": reverts commit 659dcc5. Revert "[inductor] Fix nan-handling of max and min reductions (#99881)": reverts commit f9c3fcd. Pull Request resolved: #100517
Stack from ghstack (oldest at bottom):

- #100163 [inductor] Stop using `x + tl.zeros(...)` in generated triton
- #99920 [inductor] Fix argmin/max with duplicate values (this PR)

Fixes #99879

This adds `minimum_with_index` helper functions to compute the minimum value and index simultaneously, preferring the smaller index, which is required to match eager mode in the case of duplicates.

I also replace the mask-and-sum hack with a `tl.reduce` that uses the previously mentioned helper. This additionally fixes the indices being added together in the case of duplicates.

Example kernels generated for `torch.argmin(x, 1)`, including the persistent-reduction variant, are shown in the PR description above.

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire
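The persistent-reduction codegen quoted in the PR description can be mirrored in plain Python to show why masked-out lanes are filled with `float("inf")` (an illustrative model; the function name is mine, not inductor's):

```python
import math

def masked_argmin(values, mask):
    """argmin over the lanes where mask is True. Masked-out lanes are filled
    with +inf so they can never win, mirroring tl.where(..., float("inf"))
    in the generated kernel; ties go to the smaller index."""
    filled = [v if m else math.inf for v, m in zip(values, mask)]
    return min(range(len(filled)), key=lambda i: (filled[i], i))

print(masked_argmin([2.0, 1.0, 1.0, 0.5], [True, True, True, False]))  # 1
```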