
[inductor] Fix argmin/max with duplicate values #99920

Closed · wants to merge 15 commits

Conversation

peterbell10 (Collaborator) commented Apr 24, 2023

Stack from ghstack (oldest at bottom):

Fixes #99879

This adds `minimum_with_index` helper functions to compute the minimum
value and index simultaneously, preferring the smaller index, which is
required to match eager in the case of duplicates.

I also replace the mask-and-sum hack with a `tl.reduce` that uses the
helper above. This additionally fixes the indices being added together
in the case of duplicates.
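The tie-breaking rule above can be sketched as a plain NumPy reference model (an illustration of the described semantics, not the actual Triton helper; the function name mirrors the PR's `minimum_with_index`, and NaN handling is ignored):

```python
import numpy as np

def minimum_with_index(a_value, a_index, b_value, b_index):
    # Take the smaller value; on ties, prefer the smaller index so the
    # result matches eager-mode argmin when duplicates are present.
    take_b = (b_value < a_value) | ((b_value == a_value) & (b_index < a_index))
    return np.where(take_b, b_value, a_value), np.where(take_b, b_index, a_index)

vals = np.array([3.0, 1.0, 1.0, 2.0])  # duplicate minimum at indices 1 and 2
acc_v, acc_i = vals[0], 0
for i in range(1, len(vals)):
    acc_v, acc_i = minimum_with_index(acc_v, acc_i, vals[i], i)
print(acc_v, acc_i)  # -> 1.0 1  (index 1 wins over the duplicate at index 2)
```

Because the combine is associative, the same rule can presumably be applied pairwise by `tl.reduce` across the reduction dimension.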

As an example, this is the kernel generated for `torch.argmin(x, 1)`:

```python
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1028 # dynamic_shapes=False
    rnumel = 1028 # dynamic_shapes=False
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
    _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
            _tmp1, _tmp1_index, tmp0, rindex
        )
        _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
        _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
    _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
    tmp1 = tmp1_tmp[:, None]
    tl.store(out_ptr0 + x0, tmp1, xmask)
```

Or for a persistent reduction, it generates:

```python
    tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
    tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
    tmp3 = tl.broadcast_to(rindex, tmp2.shape)
    _, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
    tmp4 = tmp4_tmp[:, None]
```
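The bug in the replaced lowering is easiest to see in a NumPy sketch (illustrative only, not code from this PR): mask-and-sum adds the indices of duplicate minima together, whereas a genuine argmin reduction returns the first occurrence and therefore matches eager:

```python
import numpy as np

x = np.array([[3.0, 1.0, 1.0, 2.0]])  # duplicate minimum at indices 1 and 2

# Old mask-and-sum lowering: every position equal to the minimum
# contributes its index, so duplicate indices get summed together.
min_val = x.min(axis=1, keepdims=True)
mask = x == min_val
old = (mask * np.arange(x.shape[1])).sum(axis=1)
print(old)  # -> [3]  (1 + 2: the wrong answer)

# First-occurrence reduction, the behaviour eager mode has:
new = x.argmin(axis=1)
print(new)  # -> [1]
```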

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

pytorch-bot bot commented Apr 24, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99920

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV: if your PR is affected, please view it below.

❌ 1 New Failure: as of commit 54312fc, the following job has failed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

peterbell10 marked this pull request as ready for review April 25, 2023 15:42
peterbell10 requested a review from ngimel April 25, 2023 15:42
ngimel (Collaborator) left a comment:

This is great! I left minor comments; can you please post the generated code in the PR description?

atol=1e-5,
rtol=0.5,
)
self.common(fn, (torch.randn([144, 144]),))
Collaborator:
great that this test is fixed!

if self.device == "cpu":
raise unittest.SkipTest("broken on CPU")

t1 = torch.randn((10, 10))
Collaborator:
instead of (or in addition to) 10, 10 I'd also have (8,8), because reductions of size 8 or smaller are unrolled iirc

else:
result_var = self.cse.generate(
self.compute, final_reduction(masked_value)
)
elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
accumulator = f"_{result_var}"
default_value = f" + {default}" if default != 0 else ""
self.body.writeline(
f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
Collaborator: nit: can you replace this with `tl.full`?

peterbell10 (Collaborator, author): Turns out that doing this uncovers some bugs that I'd rather not dig into in this PR. Essentially, some boolean ops actually return int32, but this was masked by `tl.zeros() + x` promoting the accumulator to integer anyway.
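The promotion that hid those bugs can be mimicked in NumPy (an analogy assumed for illustration; the real issue is in the generated Triton):

```python
import numpy as np

# Suppose a "boolean" op actually returns int32 (the bug described above).
op_result = np.array([1, 0, 1, 1], dtype=np.int32)

# Old accumulator pattern, zeros(...) + x: the addition promotes the
# bool accumulator to int32, so the wrong dtype goes unnoticed.
acc_old = np.zeros(4, dtype=np.bool_) + op_result
print(acc_old.dtype)  # -> int32

# tl.full-style pattern: the accumulator keeps its declared dtype, so a
# dtype mismatch in the combine step surfaces immediately.
acc_new = np.full(4, False, dtype=np.bool_)
print(acc_new.dtype)  # -> bool
```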

peterbell10 (Collaborator, author):

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Apr 27, 2023
pytorchmergebot (Collaborator):

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

peterbell10 (Collaborator, author):

@pytorchbot merge

pytorchmergebot (Collaborator):

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).

ezyang (Contributor) commented May 3, 2023

Hey @peterbell10, sorry about this, but we need to revert some of the `tl.reduce` patches; internally we are still on an older version of Triton without `tl.reduce`, and rolling the Triton version forward is blocked on the new version regressing some numeric tests that need investigating.

It would be nice to still have this code in; how difficult do you think it would be to feature-flag these improvements on the Triton version? I read over the PRs and it wasn't immediately obvious that it would be easy, unfortunately.
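One possible shape for such a flag is to gate the lowering on whether the installed Triton exposes `tl.reduce` (a purely hypothetical sketch; the function names and the capability probe are assumptions, not the actual PyTorch code):

```python
def has_tl_reduce() -> bool:
    """Probe whether the installed Triton exposes tl.reduce."""
    try:
        import triton.language as tl
    except ImportError:
        return False
    return hasattr(tl, "reduce")

def argmin_lowering() -> str:
    # Keep the old mask-and-sum lowering as a fallback for older Triton.
    return "tl.reduce" if has_tl_reduce() else "mask-and-sum"

print(argmin_lowering())
```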

ezyang mentioned this pull request May 3, 2023
ezyang added a commit that referenced this pull request May 3, 2023
Revert "[inductor] Stop using `x + tl.zeros(...)` in generated triton (#100163)"

This reverts commit 5b98910.

Revert "[inductor] Fix argmin/max with duplicate values (#99920)"

This reverts commit 659dcc5.

Revert "[inductor] Fix nan-handling of max and min reductions (#99881)"

This reverts commit f9c3fcd.

[ghstack-poisoned]
ezyang added a commit that referenced this pull request May 3, 2023
ghstack-source-id: 85531baedfb245e48512be97c0ed90eba1685664
Pull Request resolved: #100517
@peterbell10 peterbell10 reopened this May 3, 2023
peterbell10 closed this May 3, 2023
facebook-github-bot deleted the gh/peterbell10/542/head branch June 8, 2023 18:27