Updates nonzero's as_tuple behavior to no longer warn. (#45413)
Summary:
Fixes #44284.

[torch.nonzero](https://pytorch.org/docs/master/generated/torch.nonzero.html?highlight=nonzero#torch.nonzero) is distinct from [numpy.nonzero](https://numpy.org/doc/1.18/reference/generated/numpy.nonzero.html?highlight=nonzero#numpy.nonzero). The former returns a tensor by default, while the latter returns a tuple of index arrays. The `as_tuple` argument was added as part of an intended deprecation process to make torch.nonzero consistent with numpy.nonzero, but this was a confusing change for users. A better deprecation path would be to offer torch.argwhere, consistent with [numpy.argwhere](https://numpy.org/doc/stable/reference/generated/numpy.argwhere.html?highlight=argwhere#numpy.argwhere), which is equivalent to the default torch.nonzero behavior. Once that is available, a change to torch.nonzero should be more straightforward and less disruptive to users, if we decide that is the correct change to pursue.
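For illustration, a minimal sketch of the behaviors being compared (the example tensor and array are made up for this summary; torch.argwhere does not exist yet, so numpy.argwhere stands in for the proposed counterpart):

import torch
import numpy as np

t = torch.tensor([[1, 0], [0, 2]])
a = t.numpy()

# torch.nonzero returns a single 2-D tensor of coordinates by default,
# which lines up with numpy.argwhere rather than numpy.nonzero.
torch.nonzero(t)                 # tensor([[0, 0], [1, 1]])
np.argwhere(a)                   # array([[0, 0], [1, 1]])

# numpy.nonzero returns a tuple of per-dimension index arrays,
# which is what as_tuple=True opts into.
np.nonzero(a)                    # (array([0, 1]), array([0, 1]))
torch.nonzero(t, as_tuple=True)  # (tensor([0, 1]), tensor([0, 1]))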

Pull Request resolved: #45413

Reviewed By: ngimel

Differential Revision: D23975015

Pulled By: mruberry

fbshipit-source-id: b59237d0d8c2df984e952b62d0a7c247b49d84dc
Mike Ruberry authored and facebook-github-bot committed Sep 29, 2020
1 parent 0df99ad commit b66ac1e
Showing 2 changed files with 46 additions and 28 deletions.
44 changes: 31 additions & 13 deletions test/test_torch.py
@@ -10795,15 +10795,6 @@ def assert_tuple_empty(tup, dim):
         self.assertEqual(1, len(z))
         self.assertEqual(torch.empty(0, dtype=torch.long), z[0])

-    @onlyOnCPUAndCUDA
-    def test_nonzero_deprecated(self, device):
-        x = torch.randn((2, 3), device=device)
-        with self.maybeWarnsRegex(UserWarning, "This overload of nonzero is deprecated"):
-            x.nonzero()
-
-        with self.maybeWarnsRegex(UserWarning, "This overload of nonzero is deprecated"):
-            torch.nonzero(x)
-
     # TODO: add torch.complex64, torch.complex128
     @dtypes(torch.float, torch.double)
     def test_normal(self, device, dtype):
@@ -13070,10 +13061,6 @@ def gen_nontrivial_input(shape, dtype, device):
             dst2 = tensor.nonzero(as_tuple=False)
             dst3 = torch.empty([], dtype=torch.long, device=device)
             torch.nonzero(tensor, out=dst3)
-            self.assertRaisesRegex(
-                TypeError,
-                "received an invalid combination of arguments",
-                lambda: torch.nonzero(tensor, as_tuple=True, out=dst3))
             if self.device_type != 'xla':
                 # xla does not raise runtime error
                 self.assertRaisesRegex(
@@ -13099,6 +13086,37 @@ def gen_nontrivial_input(shape, dtype, device):
             self.assertEqual(tup1, np_result, atol=0, rtol=0)
             self.assertEqual(tup2, np_result, atol=0, rtol=0)

+    def test_nonzero_astuple_out(self, device):
+        t = torch.randn((3, 3, 3), device=device)
+        out = torch.empty_like(t, dtype=torch.long)
+
+        with self.assertRaises(RuntimeError):
+            torch.nonzero(t, as_tuple=True, out=out)
+
+        self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))
+
+        # Verifies that JIT script cannot handle the as_tuple kwarg
+        # See Issue https://github.com/pytorch/pytorch/issues/45499.
+        def _foo(t):
+            tuple_result = torch.nonzero(t, as_tuple=True)
+            nontuple_result = torch.nonzero(t, as_tuple=False)
+            out = torch.empty_like(nontuple_result)
+            torch.nonzero(t, as_tuple=False, out=out)
+            return tuple_result, nontuple_result, out
+
+        with self.assertRaises(RuntimeError):
+            scripted_foo = torch.jit.script(_foo)
+
+        # Verifies that JIT tracing works fine
+        traced_foo = torch.jit.trace(_foo, t)
+        traced_tuple, traced_nontuple, traced_out = traced_foo(t)
+        expected_tuple = torch.nonzero(t, as_tuple=True)
+        expected_nontuple = torch.nonzero(t)
+
+        self.assertEqual(traced_tuple, expected_tuple)
+        self.assertEqual(traced_nontuple, expected_nontuple)
+        self.assertEqual(traced_out, expected_nontuple)
+
     @onlyOnCPUAndCUDA
     def test_nonzero_discontiguous(self, device):
         shape = (4, 4)
30 changes: 15 additions & 15 deletions tools/autograd/templates/python_torch_functions.cpp
@@ -583,29 +583,29 @@ static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject*
 {
   HANDLE_TH_ERRORS
   static PythonArgParser parser({
-    "nonzero(Tensor input, *, Tensor out=None)|deprecated",
-    "nonzero(Tensor input, *, bool as_tuple)",
+    "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)",
   });
-  ParsedArgs<2> parsed_args;
+  ParsedArgs<3> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);

   if(r.has_torch_function()){
     return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
   }

-  if (r.idx == 0) {
-    if (r.isNone(1)) {
-      return wrap(dispatch_nonzero(r.tensor(0)));
-    } else {
-      return wrap(dispatch_nonzero(r.tensor(0), r.tensor(1)));
-    }
-  } else {
-    if (r.toBool(1)) {
-      return wrap(dispatch_nonzero_numpy(r.tensor(0)));
-    } else {
-      return wrap(dispatch_nonzero(r.tensor(0)));
-    }
-  }
+  const auto as_tuple = r.toBool(1);
+  const auto has_out = !r.isNone(2);
+
+  if (as_tuple) {
+    TORCH_CHECK(!has_out, "nonzero does not support the out kwarg when as_tuple is True");
+    return wrap(dispatch_nonzero_numpy(r.tensor(0)));
+  }
+
+  if (has_out) {
+    return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2)));
+  }
+
+  return wrap(dispatch_nonzero(r.tensor(0)));

   END_HANDLE_TH_ERRORS
 }

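Taken together, a rough sketch of the user-visible behavior after this change (variable names are illustrative; this mirrors what test_nonzero_astuple_out exercises):

import torch

t = torch.randn(3, 3)
out = torch.empty(0, dtype=torch.long)

torch.nonzero(t)                           # returns a tensor, no deprecation warning
torch.nonzero(t, as_tuple=True)            # returns a tuple of tensors
torch.nonzero(t, as_tuple=False, out=out)  # writes the indices into out

# Combining as_tuple=True with out= now trips the TORCH_CHECK above,
# so it raises a RuntimeError (previously a TypeError from the parser).
try:
    torch.nonzero(t, as_tuple=True, out=out)
except RuntimeError as e:
    print(e)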
