From 6550c528b7ea57cd1d415a5bf0f2860a8a86e8f2 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Thu, 13 Apr 2023 07:32:23 +0000 Subject: [PATCH] [FRONTEND] don't call `tl.view` in `arg{min,max}` (#1518) A small oversight in #1305, since `view` can rearrange elements it should be avoided here. Instead I use indexing with `None` to create new dimensions. Co-authored-by: Philippe Tillet --- python/triton/language/core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 1431da509f42..7193cb6cc012 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1255,9 +1255,10 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None): index = arange(0, n, _builder=_builder) if len(input.shape) > 1: - new_shape = [constexpr(1)] * len(input.shape) - new_shape[axis] = constexpr(n) - index = view(index, new_shape, _builder=_builder) + # Broadcast index across the non-reduced axes + expand_dims_index = [None] * len(input.shape) + expand_dims_index[axis] = slice(None) + index = index.__getitem__(expand_dims_index, _builder=_builder) index = broadcast_to(index, input.shape, _builder=_builder) rvalue, rindices = reduction((input, index), axis, combine_fn,