Updated derivative rules for complex svd and pinverse #47761

Closed
IvanYashchuk wants to merge 47 commits

Changes from 45 commits

Commits (47)
ecc1f09
wip svd autograd update
IvanYashchuk Oct 20, 2020
eff6c22
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Oct 30, 2020
bbca9c6
Added missing imaginary term when input is complex
IvanYashchuk Oct 30, 2020
9179f91
Added tests; gradgradcheck doesn't work yet
IvanYashchuk Oct 30, 2020
546d7e5
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Nov 5, 2020
f6fcab7
Updated tests
IvanYashchuk Nov 5, 2020
f8cefe5
Added alternative imaginary term
IvanYashchuk Nov 5, 2020
bd0c997
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Nov 11, 2020
cff4acd
Fixed backward svd, finally all tests pass!
IvanYashchuk Nov 11, 2020
320f6b9
Remove unnecessary imag_term when only gv is defined
IvanYashchuk Nov 11, 2020
7417886
Fixed batched input case; chain_matmul works only for non-batched input
IvanYashchuk Nov 11, 2020
50c39ee
Enable complex common_methods tests
IvanYashchuk Nov 11, 2020
b606de2
Added svd test cases that check grad for each returned tensor separately
IvanYashchuk Nov 11, 2020
ca9607b
Removed tmp tests from test_autograd.py; they're part of common_metho…
IvanYashchuk Nov 11, 2020
22438a3
Enabled complex tests for pinverse
IvanYashchuk Nov 11, 2020
a3c88d7
flake8
IvanYashchuk Nov 11, 2020
d73359c
Added a note on gauge invariance requirement to docs
IvanYashchuk Nov 11, 2020
02d666b
Remove commented line of code
IvanYashchuk Nov 11, 2020
a030319
Exclude type checks for svd grad tests
IvanYashchuk Nov 11, 2020
e81b84c
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Nov 11, 2020
3d5efd7
Merge branch 'master' into autograd-svd
IvanYashchuk Nov 12, 2020
a112f30
Merge branch 'master' into autograd-svd
IvanYashchuk Nov 12, 2020
73a5f4e
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Nov 16, 2020
fd3353c
Don't skip batched tests; complex matmul on cuda now works
IvanYashchuk Nov 16, 2020
863658c
Merge branch 'master' into autograd-svd
IvanYashchuk Nov 23, 2020
ab2d8a6
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Dec 1, 2020
8732883
Updated comment on torch.svd returning conj(V)
IvanYashchuk Dec 1, 2020
7130485
Added a link for gauge invariance topic
IvanYashchuk Dec 1, 2020
6085d78
Added OpInfo-based tests
IvanYashchuk Dec 1, 2020
abd66ad
flake8
IvanYashchuk Dec 1, 2020
d4a5bcd
Fixed typo in test_ops.py
IvanYashchuk Dec 2, 2020
878336c
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Dec 8, 2020
c6e5ac0
Added descriptions to sample_inputs functions for svd and pinverse
IvanYashchuk Dec 8, 2020
5a7d67c
Added a link to github issue about test_unsupported_dtypes
IvanYashchuk Dec 8, 2020
9e2270b
Set supports_tensor_out=False
IvanYashchuk Dec 8, 2020
31cde02
Fixed error due to merge
IvanYashchuk Dec 8, 2020
861d8b9
Merge branch 'master' into autograd-svd
IvanYashchuk Dec 9, 2020
b07389f
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Dec 15, 2020
82a8007
Removed svd and pinverse entries from methods_tests;
IvanYashchuk Dec 15, 2020
1a446ff
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Dec 15, 2020
72ce395
Removed assertRaises for not implemented svd with complex dtype
IvanYashchuk Dec 15, 2020
d2b438d
Merge remote-tracking branch 'upstream/master' into autograd-svd
IvanYashchuk Dec 18, 2020
f301bfc
Skip cuda gradchecks
IvanYashchuk Dec 18, 2020
e6e3fa3
Mark test_fn_gradgrad as slowtest
IvanYashchuk Dec 18, 2020
18651b4
Use smaller arrays for tests
IvanYashchuk Dec 18, 2020
30b11ea
Undo changes to test/test_jit.py
IvanYashchuk Dec 20, 2020
cb7590e
Don't skip cuda grad checks
IvanYashchuk Dec 20, 2020
24 changes: 0 additions & 24 deletions test/test_autograd.py
@@ -3000,30 +3000,6 @@ def test_igammac(self):
gradcheck(torch.igamma, (s, x))
gradgradcheck(torch.igamma, (s, x))

@skipIfNoLapack
def test_pinverse(self):
# Why is pinverse tested this way, and not ordinarily as other linear algebra methods?
# 1. Pseudo-inverses are not generally continuous, which means that they are not differentiable
# 2. Derivatives for pseudo-inverses exist typically for constant rank (Golub et al, 1973)
# 3. This method creates two orthogonal matrices, and a constructs a test case with large
# singular values (given by x to the function).
# 4. This will ensure that small perturbations don't affect the rank of matrix, in which case
# a derivative exists.
# 5. This test exists since pinverse is implemented using SVD, and is hence a backpropable method
m, n = 5, 10
U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n
V = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n

def func(x):
S = torch.cat([x, torch.zeros(n - m)], 0)
M = U.mm(torch.diag(S)).mm(V.t())
return M.pinverse()

gradcheck(func, [torch.rand(m).add_(1).requires_grad_()])
gradcheck(func, [torch.rand(m).add_(10).requires_grad_()])
gradgradcheck(func, [torch.rand(m).add_(1).requires_grad_()])
gradgradcheck(func, [torch.rand(m).add_(10).requires_grad_()])

def test_chain_matmul(self):
def gen_matrices(p):
matrices = []
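For context: the removed test_pinverse is replaced elsewhere in this PR by OpInfo-based checks. As a quick illustration of the same rank-stability idea for the newly supported complex dtypes, a standalone check might look like the sketch below; it is my own example, not part of the PR, and simply uses a matrix close to the identity so that the small perturbations applied by gradcheck cannot change its rank.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

torch.manual_seed(0)

# A well-conditioned complex matrix: identity plus a small perturbation.
# Its rank is stable under gradcheck's finite-difference perturbations,
# which is the precondition for pinverse to be differentiable.
A = torch.eye(4, dtype=torch.complex128) + 0.1 * torch.randn(4, 4, dtype=torch.complex128)
A.requires_grad_()

gradcheck(torch.pinverse, (A,))
gradgradcheck(torch.pinverse, (A,))
```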
6 changes: 5 additions & 1 deletion test/test_jit.py
@@ -15522,7 +15522,11 @@ def forward(self, x):
'test_slogdet_batched_pos_det',
'test_slogdet_batched_symmetric',
'test_slogdet_batched_symmetric_pd',
'test_slogdet_batched_distinct_singular_values'
'test_slogdet_batched_distinct_singular_values',
'test_svd_check_grad_s',
'test_svd_check_grad_u',
'test_svd_check_grad_uv',
'test_svd_check_grad_v'
}

# chunk returns a list in scripting and we don't unpack the list,
10 changes: 3 additions & 7 deletions test/test_linalg.py
@@ -1331,6 +1331,9 @@ def gen_error_message(input_size, ord, keepdim, dim=None):
# TODO: Fix autograd for matrix orders 'nuc', 2, and -2 by adding complex
# support to svd's backward method. Once this is done, these ords
# should be added to `matrix_ords` above
# Update: svd's backward now works with https://github.com/pytorch/pytorch/pull/47761
# However run_test_case doesn't work for 'matrix_ords_unsupported' cases
# because the singular values of 'x' and 'x_real' can differ, and so can their norms computed from singular values
matrix_ords_unsupported = ['nuc', 2, -2]

def run_test_case(x, ord, keepdim):
@@ -1357,13 +1360,6 @@ def run_test_case(x, ord, keepdim):
x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True)
run_test_case(x, ord, keepdim)

for ord in matrix_ords_unsupported:
x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True)
with self.assertRaisesRegex(
RuntimeError,
r'svd does not support automatic differentiation for outputs with complex dtype'):
res = torch.linalg.norm(x, ord, keepdim=keepdim)

# Test that linal.norm gives the same result as numpy when inputs
# contain extreme values (inf, -inf, nan)
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
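The removed assertRaises block covered the matrix orders 'nuc', 2 and -2, which previously failed for complex inputs because svd had no complex backward. All three are functions of the singular values only, so they are gauge invariant and the new backward applies. A minimal sketch of a direct check (my own, not the run_test_case path; it assumes a generic random input with distinct singular values):

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)
x = torch.randn(4, 4, dtype=torch.complex128, requires_grad=True)

# Nuclear norm, largest singular value, smallest singular value: all depend
# only on the singular values, so the complex svd backward from this PR applies.
for ord in ('nuc', 2, -2):
    gradcheck(lambda t, o=ord: torch.linalg.norm(t, o), (x,))
```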
23 changes: 16 additions & 7 deletions test/test_ops.py
@@ -29,13 +29,15 @@ class TestOpInfo(TestCase):
@onlyOnCPUAndCUDA
@ops(op_db, dtypes=OpDTypes.unsupported)
def test_unsupported_dtypes(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")

# NOTE: only tests on first sample
sample = samples[0]
# sample_inputs can have a function for generating the input that doesn't work for specified dtype
# https://github.com/pytorch/pytorch/issues/49024
with self.assertRaises(RuntimeError):
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")

# NOTE: only tests on first sample
sample = samples[0]
op(*sample.input, *sample.args, **sample.kwargs)

# Verifies that ops have their supported dtypes
@@ -74,7 +76,14 @@ def _check_helper(self, device, dtype, op, variant, check):

samples = op.sample_inputs(device, dtype, requires_grad=True)
for sample in samples:
partial_fn = partial(variant, **sample.kwargs)
if sample.output_process_fn_grad is not None:
out_fn = sample.output_process_fn_grad

def variant_out_fn(*args, **kwargs):
return out_fn(variant(*args, **kwargs))
else:
variant_out_fn = variant
partial_fn = partial(variant_out_fn, **sample.kwargs)
if check == 'gradcheck':
self.assertTrue(gradcheck(partial_fn, (*sample.input,) + sample.args,
check_grad_dtypes=True))
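To make the _check_helper change above concrete, here is a hedged sketch of how an OpInfo sample's output_process_fn_grad gets wrapped around the variant before gradcheck. The post-processing function shown (reconstructing the matrix from U, S, V) is a hypothetical choice of mine that happens to be gauge invariant; the actual sample_inputs definitions for svd in this PR may post-process differently.

```python
import torch
from functools import partial
from torch.autograd import gradcheck

# Hypothetical post-processing: map (U, S, V) to something gauge invariant
# (here, the reconstructed matrix) so that the phase ambiguity of the
# singular vectors cannot affect the check.
def output_process_fn_grad(usv):
    u, s, v = usv
    return u @ torch.diag_embed(s).to(u.dtype) @ v.conj().transpose(-2, -1)

variant = torch.svd

def variant_out_fn(*args, **kwargs):
    # Mirrors the wrapping added to _check_helper above.
    return output_process_fn_grad(variant(*args, **kwargs))

a = torch.randn(4, 3, dtype=torch.complex128, requires_grad=True)
partial_fn = partial(variant_out_fn)
gradcheck(partial_fn, (a,), check_grad_dtypes=True)
```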
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -78,7 +78,7 @@
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr',
'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr', 'svd',
'_fft_c2c', '_fft_r2c',
}

9 changes: 8 additions & 1 deletion torch/_torch_docs.py
@@ -8010,7 +8010,7 @@ def merge_dicts(*dicts):
svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor)

This function returns a namedtuple ``(U, S, V)`` which is the singular value
decomposition of a input real matrix or batches of real matrices :attr:`input` such that
decomposition of a input matrix or batches of matrices :attr:`input` such that
Review comment (Collaborator): Future work: "of an input matrix or batch of matrices"

:math:`input = U \times diag(S) \times V^T`.

If :attr:`some` is ``True`` (default), the method returns the reduced
@@ -8022,6 +8022,8 @@ def merge_dicts(*dicts):
If :attr:`compute_uv` is ``False``, the returned `U` and `V` matrices will be zero matrices
of shape :math:`(m \times m)` and :math:`(n \times n)` respectively. :attr:`some` will be ignored here.

Supports real-valued and complex-valued input.
Review comment (Collaborator): Future work: specify the dtypes explicitly since "real-valued" would be every non-bool non-complex dtype


.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices,
then the singular values of each matrix in the batch is returned in descending order.

@@ -8046,6 +8048,9 @@ def merge_dicts(*dicts):
.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V`
from the forward pass is required for the backward operation.

.. note:: With the complex-valued input the backward operation works correctly only
Review comment (Collaborator): Future work: this is too vague and referencing other documentation is never ideal

for gauge invariant loss functions. Please look at `Gauge problem in AD`_ for more details.

Args:
input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
batch dimensions consisting of :math:`m \times n` matrices.
Expand Down Expand Up @@ -8083,6 +8088,8 @@ def merge_dicts(*dicts):
>>> u, s, v = torch.svd(a_big)
>>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.transpose(-2, -1)))
tensor(2.6503e-06)

.. _Gauge problem in AD: https://re-ra.xyz/Gauge-Problem-in-Automatic-Differentiation/
""")

add_docstr(torch.symeig,
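As a concrete illustration of the gauge-invariance note added to the docs above: the singular vectors of a complex matrix are only defined up to a per-column phase, so gradients through torch.svd are meaningful only for losses that do not depend on that phase. The sketch below is my own example of such a loss; with a generic random input (distinct singular values, no zero entries in U or V) the check is expected to pass.

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)
a = torch.randn(4, 4, dtype=torch.complex128, requires_grad=True)

def gauge_invariant_loss(t):
    u, s, v = torch.svd(t)
    # |U|, S and |V| are unchanged if a column of U (and the matching column
    # of V) is multiplied by a unit-modulus phase, so this loss does not
    # depend on the arbitrary phase convention chosen by the SVD routine.
    return u.abs().sum() + s.sum() + v.abs().sum()

gradcheck(gauge_invariant_loss, (a,))

# By contrast, a loss such as u.real.sum() does depend on that phase
# convention, and its gradient through torch.svd is not well-defined.
```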
47 changes: 33 additions & 14 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1824,28 +1824,35 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
auto gsigma = grads[1];

auto u = raw_u;
auto v = raw_v;
// Currently torch.svd for complex dtypes returns the conjugate of V,
// while the backward formula is derived with just V (without the conjugation)
// therefore here we need to conjugate the V output of SVD and grads[2].
// Once https://github.com/pytorch/pytorch/issues/45821 is resolved
// extra .conj(), that are marked below in the code, shall be removed.
auto v = raw_v.conj(); // TODO: remove .conj()
auto gu = grads[0];
auto gv = grads[2];
auto gv = grads[2].conj(); // TODO: remove .conj()

if (!some) {
// We ignore the free subspace here because possible base vectors cancel
// each other, e.g., both -v and +v are valid base for a dimension.
// Don't assume behavior of any particular implementation of svd.
u = raw_u.narrow(-1, 0, k);
v = raw_v.narrow(-1, 0, k);
v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj()
if (gu.defined()) {
gu = gu.narrow(-1, 0, k);
}
if (gv.defined()) {
gv = gv.narrow(-1, 0, k);
}
}
auto vt = v.transpose(-2, -1);
auto vh = v.conj().transpose(-2, -1);

Tensor sigma_term;
if (gsigma.defined()) {
sigma_term = at::matmul(u, at::matmul(gsigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1), vt));
gsigma = gsigma.to(self.dtype());
// computes u @ diag(gsigma) @ vh
sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh);
} else {
sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
@@ -1855,11 +1862,11 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
return sigma_term;
}

auto ut = u.transpose(-2, -1);
auto uh = u.conj().transpose(-2, -1);
auto im = at::eye(m, self.options());
auto in = at::eye(n, self.options());
auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype());
auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype());
auto sigma_sq = sigma.pow(2);
auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1);
// The following two lines invert values of F, and fills the diagonal with 0s.
@@ -1871,26 +1878,38 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
Tensor u_term, v_term;

if (gu.defined()) {
u_term = at::matmul(u, at::matmul(F.mul(at::matmul(ut, gu) - at::matmul(gu.transpose(-2, -1), u)), sigma_mat));
auto guh = gu.conj().transpose(-2, -1);
u_term = at::matmul(u, at::matmul(F.mul(at::matmul(uh, gu) - at::matmul(guh, u)), sigma_mat));
if (m > k) {
u_term = u_term + at::matmul(im - at::matmul(u, ut), at::matmul(gu, sigma_mat_inv));
u_term = u_term + at::matmul(im - at::matmul(u, uh), at::matmul(gu, sigma_mat_inv));
}
u_term = at::matmul(u_term, vt);
u_term = at::matmul(u_term, vh);
} else {
u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}

if (gv.defined()) {
auto gvt = gv.transpose(-2, -1);
v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vt, gv) - at::matmul(gvt, v)), vt));
auto gvh = gv.conj().transpose(-2, -1);
v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh));
if (n > k) {
v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvt, in - at::matmul(v, vt)));
v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvh, in - at::matmul(v, vh)));
}
v_term = at::matmul(u, v_term);
} else {
v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}

// for complex-valued input there is an additional term
// https://giggleliu.github.io/2019/04/02/einsumbp.html
// https://arxiv.org/abs/1909.02659
if (self.is_complex() && gu.defined()) {
// computes L = Identity.mul(uh @ gu)
Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1).diag_embed(0, -2, -1);
L = L - L.conj().transpose(-2, -1);
Tensor imag_term = 0.5 * at::matmul(at::matmul(at::matmul(u, L), sigma_mat_inv), vh);
return u_term + sigma_term + v_term + imag_term;
}

return u_term + sigma_term + v_term;
}

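For readers who prefer to see the updated formula outside of C++, below is a Python transcription of the single-matrix case (some=True, all three output grads defined, distinct nonzero singular values). The function name and the convention that V is already un-conjugated are my choices; this is a reference sketch for reading along, not code the PR adds.

```python
import torch

def svd_backward_ref(gU, gS, gV, U, S, V):
    # Transcription of the formula above for one matrix, some=True, all three
    # grads defined. V is the un-conjugated factor, i.e.
    # A = U @ diag(S) @ V.conj().T, so the .conj() adjustments applied to
    # raw_v and grads[2] in the C++ are assumed to be done already.
    k = S.numel()
    m, n = U.shape[0], V.shape[0]
    Uh, Vh = U.conj().transpose(-2, -1), V.conj().transpose(-2, -1)
    dtype = U.dtype

    sigma_mat = torch.diag(S).to(dtype)
    sigma_inv = torch.diag(S.reciprocal()).to(dtype)

    s2 = S * S
    F = s2.unsqueeze(0) - s2.unsqueeze(1)   # F[i, j] = S[j]**2 - S[i]**2
    F.fill_diagonal_(float('inf'))          # 1/inf -> 0 on the diagonal
    F = F.reciprocal().to(dtype)            # assumes distinct singular values

    sigma_term = (U * gS.to(dtype)) @ Vh    # U @ diag(gS) @ V^H

    u_term = U @ ((F * (Uh @ gU - gU.conj().transpose(-2, -1) @ U)) @ sigma_mat)
    if m > k:
        u_term = u_term + (torch.eye(m, dtype=dtype) - U @ Uh) @ (gU @ sigma_inv)
    u_term = u_term @ Vh

    v_term = sigma_mat @ ((F * (Vh @ gV - gV.conj().transpose(-2, -1) @ V)) @ Vh)
    if n > k:
        v_term = v_term + sigma_inv @ (gV.conj().transpose(-2, -1) @ (torch.eye(n, dtype=dtype) - V @ Vh))
    v_term = U @ v_term

    total = u_term + sigma_term + v_term

    if U.is_complex():
        # Extra term for complex input (see the references cited above).
        L = torch.diag(torch.diagonal(Uh @ gU))
        L = L - L.conj().transpose(-2, -1)
        total = total + 0.5 * U @ L @ sigma_inv @ Vh
    return total
```

The F matrix encodes the 1/(sigma_j^2 - sigma_i^2) coupling between singular subspaces, and the extra anti-Hermitian L term fixes the otherwise undetermined imaginary part of the diagonal of U^H dU that only appears for complex inputs, which is exactly the correction the cited references derive.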