From ddd08f9843421592d00407bc670d4db3486b0cd2 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 8 Sep 2023 16:32:35 -0300 Subject: [PATCH] Dispatch `numpy.take_along_axis` to `torch.take_along_dim` ghstack-source-id: 034d7f244bf1f60d6e82946bfcf95ceb9af605d6 Pull Request resolved: https://github.com/pytorch/pytorch/pull/108880 --- test/dynamo/test_misc.py | 7 ++++++- torch/_numpy/_funcs_impl.py | 2 +- .../_internal/common_methods_invocations.py | 17 ++++++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 2385319bb7fce..add2ddb501e9b 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1288,7 +1288,12 @@ def sample_to_args(s): samples = list( sample_inputs_gather( - None, "cpu", torch.float32, requires_grad=False, include_0d=False + None, + "cpu", + torch.float32, + requires_grad=False, + include_0d=False, + include_empty=False, ) ) cnts = torch._dynamo.testing.CompileCounter() diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py index f3232548a0a6c..e0b8bbc0d0e0d 100644 --- a/torch/_numpy/_funcs_impl.py +++ b/torch/_numpy/_funcs_impl.py @@ -886,7 +886,7 @@ def take( def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): (arr,), axis = _util.axis_none_flatten(arr, axis=axis) axis = _util.normalize_axis_index(axis, arr.ndim) - return torch.gather(arr, axis, indices) + return torch.take_along_dim(arr, indices, axis) def put( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 75cc6ad57c6dd..58fcc834108df 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2401,7 +2401,13 @@ def reference_unbind(t, dim): """A numpy implementation of torch.unbind""" return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim)) -def sample_inputs_gather(op_info, device, dtype, requires_grad, include_0d=True, **kwargs): +def sample_inputs_gather(op_info, + device, + dtype, + requires_grad, + include_0d=True, + include_empty=True, + **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) yield SampleInput( make_arg((M, S)), @@ -2412,10 +2418,11 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, include_0d=True, 1, gather_variable((M, S // 2), 0, S, True, device=device)) # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006 - yield SampleInput( - make_arg((S,)), - 0, - torch.tensor([], dtype=torch.uint8, device=device)) + if include_empty: + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([], dtype=torch.uint8, device=device)) # 0D tensor case if include_0d: yield SampleInput(