Skip to content

Commit

Permalink
Don't call tl.view in arg{min,max}
Browse files Browse the repository at this point in the history
A small oversight in triton-lang#1305, since `view` can rearrange elements it
should be avoided here. Instead I use indexing with `None` to create
new dimensions.
  • Loading branch information
peterbell10 committed Apr 13, 2023
1 parent e152183 commit f8baa4f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,9 +1255,10 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
index = arange(0, n, _builder=_builder)

if len(input.shape) > 1:
new_shape = [constexpr(1)] * len(input.shape)
new_shape[axis] = constexpr(n)
index = view(index, new_shape, _builder=_builder)
# Broadcast index across the non-reduced axes
expand_dims_index = [None] * len(input.shape)
expand_dims_index[axis] = slice(None)
index = index.__getitem__(expand_dims_index, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)

rvalue, rindices = reduction((input, index), axis, combine_fn,
Expand Down

0 comments on commit f8baa4f

Please sign in to comment.