Set proper output differentiability for unique function #47930

Closed · wants to merge 1 commit
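In plain terms, the patch marks only the first output of the `unique` family (the unique values) as differentiable, so the integer outputs (inverse indices and counts) no longer come back with `requires_grad=True`. A minimal sketch of the behavior the change is meant to enforce, assuming a floating-point input with `requires_grad=True` (the `inp` name here is just for illustration, not taken from the diff):

    import torch

    inp = torch.randn(4, requires_grad=True)

    values, inverse, counts = torch.unique(
        inp, sorted=True, return_inverse=True, return_counts=True)

    # The unique values stay in the autograd graph ...
    print(values.requires_grad)   # True
    # ... while the integer outputs must not require grad, which is what this PR enforces.
    print(inverse.requires_grad)  # False
    print(counts.requires_grad)   # False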
50 changes: 49 additions & 1 deletion test/test_autograd.py
@@ -4879,12 +4879,60 @@ def test_integer_outputs(self):
        self.assertFalse(out.dtype.is_floating_point)
        self.assertFalse(out.requires_grad)

-       bins = torch.linspace(0, 1.0, requires_grad=True)
+       bins = torch.linspace(0, 1.0, steps=100, requires_grad=True)
        vals = torch.rand(5, 5, requires_grad=True)
        out = torch.bucketize(vals, bins)
        self.assertFalse(out.dtype.is_floating_point)
        self.assertFalse(out.requires_grad)

        def assert_only_first_requires_grad(res):
            if not isinstance(res, tuple):
                res = (res,)
            self.assertTrue(res[0].requires_grad)
            for out in res[1:]:
                if out is not None:
                    self.assertFalse(out.requires_grad)

        for sort in [True, False]:
            for return_inverse in [True, False]:
                for return_counts in [True, False]:
                    res = torch.unique(inp, sorted=sort, return_inverse=return_inverse,
                                       return_counts=return_counts)
                    assert_only_first_requires_grad(res)

                    res = torch.unique(inp, sorted=sort, return_inverse=return_inverse,
                                       return_counts=return_counts, dim=0)
                    assert_only_first_requires_grad(res)

                    res = torch.unique_consecutive(inp, return_inverse=return_inverse,
                                                   return_counts=return_counts)
                    assert_only_first_requires_grad(res)

                    res = torch.unique_consecutive(inp, return_inverse=return_inverse,
                                                   return_counts=return_counts, dim=0)
                    assert_only_first_requires_grad(res)

                    # Here we test the internal functions to make sure all of them are
                    # covered on top of the public API
                    res = torch._unique(inp, sorted=sort, return_inverse=return_inverse)
                    assert_only_first_requires_grad(res)

                    # This looks public but is actually manually deleted from the
                    # torch namespace in torch/functional.py
                    res = torch._VF.unique_dim(inp, dim=0, sorted=sort, return_inverse=return_inverse,
                                               return_counts=return_counts)
                    assert_only_first_requires_grad(res)

                    # We don't test `unique_dim_consecutive` here.
                    # It looks public but the python binding is actually manually disabled in
                    # tools/autograd/gen_python_functions.py

                    res = torch._unique2(inp, sorted=sort, return_inverse=return_inverse,
                                         return_counts=return_counts)
                    assert_only_first_requires_grad(res)


def index_variable(shape, max_indices):
    if not isinstance(shape, tuple):
        shape = (shape,)
17 changes: 17 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1106,8 +1106,25 @@
  self: zeros_like(grad)

- name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
  output_differentiability: [True, False]
  self: not_implemented("_unique")

- name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
  output_differentiability: [True, False, False]
  self: not_implemented("unique_dim")

- name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
  output_differentiability: [True, False, False]
  self: not_implemented("unique_consecutive")

- name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
  output_differentiability: [True, False, False]
  self: not_implemented("unique_dim_consecutive")

- name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
  output_differentiability: [True, False, False]
  self: not_implemented("_unique2")

- name: _unsafe_view(Tensor self, int[] size) -> Tensor
  self: grad.reshape(self.sizes())

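For context on the entries above: the output_differentiability list tells the autograd code generator which outputs may receive a grad_fn, while a not_implemented(...) formula still registers a backward that raises when it is actually reached. So after this change the unique values are flagged as differentiable, but calling backward through them errors out. A small sketch of that behavior (the exact error text is an assumption, not copied from the codegen):

    import torch

    x = torch.tensor([1.0, 2.0, 2.0, 3.0], requires_grad=True)
    values = torch.unique(x)
    print(values.requires_grad)  # True: the first output is marked differentiable

    try:
        values.sum().backward()
    except RuntimeError as err:
        # The not_implemented(...) formula raises once backward actually runs.
        print(err)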
1 change: 0 additions & 1 deletion tools/autograd/gen_variable_type.py
@@ -840,7 +840,6 @@ def emit_increment_version():
    # set_flags has to appear after version_counter, because rebase_history
    # requires that the counter is incremented before it is called
    body.append(emit_history())
    if requires_derivative:
        body.append(emit_save_outputs())
    body.extend(emit_check_if_in_complex_autograd_allowlist())
    if base_name in RESET_GRAD_ACCUMULATOR: