Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3739,6 +3739,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
# but found at least two devices, cuda:0 and cpu!
xfail('ge', device_type='cuda'),
xfail('_upsample_bilinear2d_aa'),
xfail('argsort'), # aten::argsort.stable hit the vmap fallback which is currently disabled
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OpInfo now tests with the stable kwarg, so this got uncovered.

}))
def test_op_has_batch_rule(self, device, dtype, op):
# needs to be fixed
Expand Down
5 changes: 5 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4716,6 +4716,11 @@ def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=
return v, i


@register_meta(aten.argsort.stable)
def meta_argsort(self, *, stable, dim=-1, descending=False):
    """Meta kernel for aten::argsort.stable.

    argsort is defined as the indices component of torch.sort with the same
    arguments, so delegate to meta_sort and discard the values.
    """
    _values, indices = meta_sort(self, stable=stable, dim=dim, descending=descending)
    return indices


def rnn_cell_checkSizes(
input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
):
Expand Down
6 changes: 1 addition & 5 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3127,10 +3127,6 @@ def sample_inputs_threshold(op_info, device, dtype, requires_grad, **kwargs):
# threshold and values args must be numbers
yield SampleInput(make_arg(x_size), make_arg(()).item(), make_arg(()).item())

def sample_inputs_argsort(*args, **kwargs):
    # argsort shares sort's sample inputs, except those that pass the
    # `stable` kwarg, which argsort's default overload does not accept.
    for sample_input in sample_inputs_sort(*args, **kwargs):
        if "stable" not in sample_input.kwargs:
            yield sample_input

def sample_inputs_unique(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
Expand Down Expand Up @@ -18269,7 +18265,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
"argsort",
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_argsort,
sample_inputs_func=sample_inputs_sort,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

argsort is the second value returned by torch.sort(). Same API as sort, not sure why samples weren't reused before.

torch.argsort(input, dim=-1, descending=False, stable=False) → Tensor
torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)

Note: on the aten side, stable has no default:

func: argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor

supports_out=False,
supports_autograd=False,
skips=(
Expand Down