Enhance diagonal (fixes #6479) #6718

Conversation
This patch

- adds Tensor.diagonal to complement torch.diagonal
- implements diagonal natively in ATen
- makes diagonal a view
- implements taking arbitrary diagonals
- implements diagonal backward instead of referring to the (more limited) diag
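As a quick illustration of the behavior described above — a minimal sketch; the shapes and values are illustrative, not taken from the patch's own tests:

import torch

# Tensor.diagonal now complements torch.diagonal
x = torch.zeros(3, 3)
d = x.diagonal()          # a view, not a copy
d.fill_(1.0)              # writing through the view mutates x
print(x)                  # identity matrix, since diagonal is a view

# arbitrary diagonals: an offset plus a choice of the two dimensions
y = torch.randn(2, 5, 4)
print(torch.diagonal(y, offset=1, dim1=1, dim2=2).shape)  # torch.Size([2, 3])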
diag_size = std::min(self.size(dim1) + offset, self.size(dim2));
  storage_offset -= offset * self.stride(dim1);
}
AT_ASSERT(diag_size > 0, "invalid diagonal offset %zd", offset);  // the diagonal offset was too large in magnitude
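This hunk shows the negative-offset case. A minimal Python sketch of the full size computation — the positive-offset branch is my assumption, symmetric to the branch shown:

def diag_size(shape, offset, dim1, dim2):
    # mirrors the C++ above: a negative offset moves down dim1,
    # a positive offset (assumed branch) moves along dim2
    if offset >= 0:
        size = min(shape[dim1], shape[dim2] - offset)
    else:
        size = min(shape[dim1] + offset, shape[dim2])
    assert size > 0, "invalid diagonal offset %d" % offset
    return size

print(diag_size((5, 4), -1, 0, 1))  # 4 == min(5 - 1, 4)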
@@ -1909,6 +1909,26 @@ def _test_diagonal(self, dtype, device):
    def test_diagonal(self):
        self._test_diagonal(self, dtype=torch.float32, device='cpu')

    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_diagonal_multidim(self):
        x = torch.randn(10, 11, 12, 13)
Tensor diagonal_backward(const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  auto grad_input = at::zeros(grad.type(), input_sizes);
  auto diag = at::diagonal(grad_input, offset, dim1, dim2);
  diag.copy_(grad);
  return grad_input;
}
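In Python terms, the backward above amounts to the following sketch — allocate a zero gradient for the whole input, then copy the incoming gradient through the diagonal view; the function name is mine:

import torch

def diagonal_backward_sketch(grad, input_sizes, offset, dim1, dim2):
    # zero gradient for the full input ...
    grad_input = torch.zeros(input_sizes, dtype=grad.dtype)
    # ... then write the incoming gradient into the diagonal view
    grad_input.diagonal(offset, dim1, dim2).copy_(grad)
    return grad_input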
test/test_torch.py
result = torch.diagonal(x, 0, -2, -1)
expected = xn.diagonal(0, -2, -1)
self.assertEqual(expected.shape, result.shape)
self.assertTrue(np.allclose(expected, result.numpy()))
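For context, a standalone version of what this assertion exercises: negative dims count from the end, matching numpy.ndarray.diagonal's axis handling, and both libraries append the diagonal as the last output dimension (tensor sizes here are illustrative):

import numpy as np
import torch

x = torch.randn(10, 11, 12, 13)
xn = x.numpy()
# dim1=-2, dim2=-1 select the last two dimensions in both libraries
assert np.allclose(x.diagonal(0, -2, -1).numpy(), xn.diagonal(0, -2, -1))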
test/test_torch.py
# test that the backward requires grad
# we do this because diagonal_backward uses in-place
# operations and gradgradcheck does not catch whether
# they work as expected
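A minimal sketch of the check this comment describes — my own example, not the PR's test: run backward with create_graph=True through a nonlinear use of diagonal and assert the resulting gradient is itself attached to the autograd graph, which it would not be if the in-place copy_ in the backward broke the graph:

import torch

x = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)
loss = (x.diagonal(offset=1) ** 2).sum()
(g,) = torch.autograd.grad(loss, x, create_graph=True)
assert g.requires_grad  # the backward of diagonal stayed differentiable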
// copies the diagonal code in aten/src/ATen/native/TensorShape.cpp,
// which would be equivalent to
//   auto diag = grad_input.diagonal(offset, dim1, dim2);
// but when using diagonal directly, the output is not differentiable twice
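To make the double-differentiability concern concrete, a small sketch (sizes and the quadratic are my choice) where a second backward through diagonal must succeed:

import torch

x = torch.randn(3, 3, requires_grad=True)
y = (x.diagonal() ** 2).sum()
(g,) = torch.autograd.grad(y, x, create_graph=True)  # g has 2 * x[i, i] on the diagonal
(gg,) = torch.autograd.grad(g.sum(), x)              # second backward: 2 on the diagonal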
torch/_torch_docs.py
@@ -1305,6 +1313,17 @@
-0.2239
[torch.FloatTensor of size 2]

>>> x = torch.randn(2,5,4,2)
>>> torch.diagonal(x, -1, 1, 2)
@pytorchbot retest this please
@ezyang is that my patch or the macOS CI that has the bug? I didn't know that anything I did affected the Fourier tests...
Hi, is there anything I can do to move this forward?
@pytorchbot retest this please

I don't see how this PR could have triggered this. Let's try again.
@pytorchbot retest this please
>>> x = torch.randn(2, 5, 4, 2)
>>> torch.diagonal(x, offset=-1, dim1=1, dim2=2)

(0 ,.,.) =
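The printed values depend on the random input; what is stable is the shape. A worked example under the sizing rule above: the diagonal dimension of size min(5 - 1, 4) = 4 is appended last, while dims 1 and 2 are removed:

>>> x = torch.randn(2, 5, 4, 2)
>>> torch.diagonal(x, offset=-1, dim1=1, dim2=2).shape
torch.Size([2, 2, 4])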
* Enhance diagonal

  This patch
  - adds Tensor.diagonal to complement torch.diagonal
  - implements diagonal natively in ATen
  - makes diagonal a view
  - implements taking arbitrary diagonals
  - implements diagonal backward instead of referring to the (more limited) diag

* add tests, copy diagonal code to backward for double differentiability
* improve tests and doc comment. Thank you, Adam!
* Mark diagonal as view function in gen_autograd.py, use simple backward.
There is some discussion in #6479.