Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Fix argmin/max with duplicate values #99920

Closed
wants to merge 15 commits into from

Commits on Apr 24, 2023

  1. [inductor] Fix argmin/max with duplicate values

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    [ghstack-poisoned]
    peterbell10 committed Apr 24, 2023
    Configuration menu
    Copy the full SHA
    cdf8cff View commit details
    Browse the repository at this point in the history
  2. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 24, 2023
    Configuration menu
    Copy the full SHA
    a18bada View commit details
    Browse the repository at this point in the history
  3. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 24, 2023
    Configuration menu
    Copy the full SHA
    8cdafd4 View commit details
    Browse the repository at this point in the history

Commits on Apr 25, 2023

  1. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 25, 2023
    Configuration menu
    Copy the full SHA
    ebe7b16 View commit details
    Browse the repository at this point in the history
  2. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 25, 2023
    Configuration menu
    Copy the full SHA
    cef8a3f View commit details
    Browse the repository at this point in the history
  3. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 25, 2023
    Configuration menu
    Copy the full SHA
    ed19f97 View commit details
    Browse the repository at this point in the history
  4. Fix unrolled argmin/argmax on "[inductor] Fix argmin/max with duplica…

    …te values"
    
    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 25, 2023
    Configuration menu
    Copy the full SHA
    4e5c903 View commit details
    Browse the repository at this point in the history
  5. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 25, 2023
    Configuration menu
    Copy the full SHA
    90356de View commit details
    Browse the repository at this point in the history
  6. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    As an example, this is the kernel generated for `torch.argmin(x, 1)`:
    ```python
    def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        xnumel = 1028 # dynamic_shapes=False
        rnumel = 1028 # dynamic_shapes=False
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
        _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
            _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
                _tmp1, _tmp1_index, tmp0, rindex
            )
            _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
            _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
        _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
        tmp1 = tmp1_tmp[:, None]
        tl.store(out_ptr0 + x0, tmp1, xmask)
    ```
    
    Or for a persistent reduction, it generates:
    ```python
        tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
        tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
        tmp3 = tl.broadcast_to(rindex, tmp2.shape)
        _, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
        tmp4 = tmp4_tmp[:, None]
    ```
    
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 25, 2023
    Configuration menu
    Copy the full SHA
    d83317d View commit details
    Browse the repository at this point in the history
  7. Revert tl.full usage on "[inductor] Fix argmin/max with duplicate val…

    …ues"
    
    
    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    As an example, this is the kernel generated for `torch.argmin(x, 1)`:
    ```python
    def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        xnumel = 1028 # dynamic_shapes=False
        rnumel = 1028 # dynamic_shapes=False
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
        _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
            _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
                _tmp1, _tmp1_index, tmp0, rindex
            )
            _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
            _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
        _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
        tmp1 = tmp1_tmp[:, None]
        tl.store(out_ptr0 + x0, tmp1, xmask)
    ```
    
    Or for a persistent reduction, it generates:
    ```python
        tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
        tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
        tmp3 = tl.broadcast_to(rindex, tmp2.shape)
        _, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
        tmp4 = tmp4_tmp[:, None]
    ```
    
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 25, 2023
    Configuration menu
    Copy the full SHA
    7ecfb27 View commit details
    Browse the repository at this point in the history

Commits on Apr 26, 2023

  1. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    As an example, this is the kernel generated for `torch.argmin(x, 1)`:
    ```python
    def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        xnumel = 1028 # dynamic_shapes=False
        rnumel = 1028 # dynamic_shapes=False
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
        _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
            _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
                _tmp1, _tmp1_index, tmp0, rindex
            )
            _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
            _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
        _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
        tmp1 = tmp1_tmp[:, None]
        tl.store(out_ptr0 + x0, tmp1, xmask)
    ```
    
    Or for a persistent reduction, it generates:
    ```python
        tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
        tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
        tmp3 = tl.broadcast_to(rindex, tmp2.shape)
        _, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
        tmp4 = tmp4_tmp[:, None]
    ```
    
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 26, 2023
    Configuration menu
    Copy the full SHA
    ba7b2a1 View commit details
    Browse the repository at this point in the history

Commits on Apr 27, 2023

  1. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    As an example, this is the kernel generated for `torch.argmin(x, 1)`:
    ```python
    def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        xnumel = 1028 # dynamic_shapes=False
        rnumel = 1028 # dynamic_shapes=False
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
        _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
            _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
                _tmp1, _tmp1_index, tmp0, rindex
            )
            _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
            _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
        _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
        tmp1 = tmp1_tmp[:, None]
        tl.store(out_ptr0 + x0, tmp1, xmask)
    ```
    
    Or for a persistent reduction, it generates:
    ```python
        tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
        tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
        tmp3 = tl.broadcast_to(rindex, tmp2.shape)
        _, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
        tmp4 = tmp4_tmp[:, None]
    ```
    
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 27, 2023
    Configuration menu
    Copy the full SHA
    0ae1451 View commit details
    Browse the repository at this point in the history
  2. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    As an example, this is the kernel generated for `torch.argmin(x, 1)`:
    ```python
    def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        xnumel = 1028 # dynamic_shapes=False
        rnumel = 1028 # dynamic_shapes=False
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
        _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
            _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
                _tmp1, _tmp1_index, tmp0, rindex
            )
            _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
            _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
        _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
        tmp1 = tmp1_tmp[:, None]
        tl.store(out_ptr0 + x0, tmp1, xmask)
    ```
    
    Or for a persistent reduction, it generates:
    ```python
        tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
        tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
        tmp3 = tl.broadcast_to(rindex, tmp2.shape)
        _, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
        tmp4 = tmp4_tmp[:, None]
    ```
    
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed Apr 27, 2023
    Configuration menu
    Copy the full SHA
    902ea5e View commit details
    Browse the repository at this point in the history

Commits on May 3, 2023

  1. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    As an example, this is the kernel generated for `torch.argmin(x, 1)`:
    ```python
    def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        xnumel = 1028 # dynamic_shapes=False
        rnumel = 1028 # dynamic_shapes=False
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
        _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
            _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
                _tmp1, _tmp1_index, tmp0, rindex
            )
            _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
            _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
        _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
        tmp1 = tmp1_tmp[:, None]
        tl.store(out_ptr0 + x0, tmp1, xmask)
    ```
    
    Or for a persistent reduction, it generates:
    ```python
        tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
        tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
        tmp3 = tl.broadcast_to(rindex, tmp2.shape)
        _, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
        tmp4 = tmp4_tmp[:, None]
    ```
    
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed May 3, 2023
    Configuration menu
    Copy the full SHA
    1a0025b View commit details
    Browse the repository at this point in the history
  2. Update on "[inductor] Fix argmin/max with duplicate values"

    Fixes #99879
    
    This adds `minimum_with_index` helper functions to compute the minimum
    value and index simultaneously, with a preference for the smaller
    index which is required to match eager in case of duplicates.
    
    I also remove the mask-and-sum hack with a `tl.reduce` using
    the previously mentioned helper. This additionally fixes the indices
    being added together in the case of duplicates.
    
    As an example, this is the kernel generated for `torch.argmin(x, 1)`:
    ```python
    def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        xnumel = 1028 # dynamic_shapes=False
        rnumel = 1028 # dynamic_shapes=False
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
        _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
            _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
                _tmp1, _tmp1_index, tmp0, rindex
            )
            _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
            _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
        _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
        tmp1 = tmp1_tmp[:, None]
        tl.store(out_ptr0 + x0, tmp1, xmask)
    ```
    
    Or for a persistent reduction, it generates:
    ```python
        tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
        tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
        tmp3 = tl.broadcast_to(rindex, tmp2.shape)
        _, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
        tmp4 = tmp4_tmp[:, None]
    ```
    
    
    cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire
    
    [ghstack-poisoned]
    peterbell10 committed May 3, 2023
    Configuration menu
    Copy the full SHA
    54312fc View commit details
    Browse the repository at this point in the history