Skip to content

Commit

Permalink
Dispatch numpy.take_along_axis to torch.take_along_dim
Browse files Browse the repository at this point in the history
ghstack-source-id: 034d7f244bf1f60d6e82946bfcf95ceb9af605d6
Pull Request resolved: #108880
  • Loading branch information
guilhermeleobas committed Sep 8, 2023
1 parent 64800dc commit ddd08f9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
7 changes: 6 additions & 1 deletion test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,12 @@ def sample_to_args(s):

samples = list(
sample_inputs_gather(
None, "cpu", torch.float32, requires_grad=False, include_0d=False
None,
"cpu",
torch.float32,
requires_grad=False,
include_0d=False,
include_empty=False,
)
)
cnts = torch._dynamo.testing.CompileCounter()
Expand Down
2 changes: 1 addition & 1 deletion torch/_numpy/_funcs_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ def take(
def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
(arr,), axis = _util.axis_none_flatten(arr, axis=axis)
axis = _util.normalize_axis_index(axis, arr.ndim)
return torch.gather(arr, axis, indices)
return torch.take_along_dim(arr, indices, axis)


def put(
Expand Down
17 changes: 12 additions & 5 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2401,7 +2401,13 @@ def reference_unbind(t, dim):
"""A numpy implementation of torch.unbind"""
return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim))

def sample_inputs_gather(op_info, device, dtype, requires_grad, include_0d=True, **kwargs):
def sample_inputs_gather(op_info,
device,
dtype,
requires_grad,
include_0d=True,
include_empty=True,
**kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
yield SampleInput(
make_arg((M, S)),
Expand All @@ -2412,10 +2418,11 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, include_0d=True,
1,
gather_variable((M, S // 2), 0, S, True, device=device))
# Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006
yield SampleInput(
make_arg((S,)),
0,
torch.tensor([], dtype=torch.uint8, device=device))
if include_empty:
yield SampleInput(
make_arg((S,)),
0,
torch.tensor([], dtype=torch.uint8, device=device))
# 0D tensor case
if include_0d:
yield SampleInput(
Expand Down

0 comments on commit ddd08f9

Please sign in to comment.