Revert D25574962: [pytorch][PR] Updated derivative rules for complex svd and pinverse

Test Plan: revert-hammer

Differential Revision: D25574962 (9955355)

Original commit changeset: 832b61303e88

fbshipit-source-id: d73f77f3e51b0f535dad6d21c5bebf8d41a6bfbd
Mike Ruberry authored and facebook-github-bot committed Dec 17, 2020
1 parent c18af03 commit f5b68e7
Showing 8 changed files with 81 additions and 170 deletions.
24 changes: 24 additions & 0 deletions test/test_autograd.py
@@ -3000,6 +3000,30 @@ def test_igammac(self):
        gradcheck(torch.igamma, (s, x))
        gradgradcheck(torch.igamma, (s, x))

    @skipIfNoLapack
    def test_pinverse(self):
        # Why is pinverse tested this way, and not ordinarily as other linear algebra methods?
        # 1. Pseudo-inverses are not generally continuous, which means that they are not differentiable
        # 2. Derivatives for pseudo-inverses typically exist only at constant rank (Golub et al., 1973);
        #    see the sketch after this test
        # 3. This method creates two orthogonal matrices and constructs a test case with large
        #    singular values (given by x to the function)
        # 4. This ensures that small perturbations don't affect the rank of the matrix, in which case
        #    the derivative exists
        # 5. This test exists since pinverse is implemented using SVD, and is hence a method through
        #    which gradients can be backpropagated
        m, n = 5, 10
        U = torch.randn(n, m).qr()[0].t()  # Orthogonal with dimensions m x n
        V = torch.randn(n, m).qr()[0].t()  # Orthogonal with dimensions m x n

        def func(x):
            S = torch.cat([x, torch.zeros(n - m)], 0)
            M = U.mm(torch.diag(S)).mm(V.t())
            return M.pinverse()

        gradcheck(func, [torch.rand(m).add_(1).requires_grad_()])
        gradcheck(func, [torch.rand(m).add_(10).requires_grad_()])
        gradgradcheck(func, [torch.rand(m).add_(1).requires_grad_()])
        gradgradcheck(func, [torch.rand(m).add_(10).requires_grad_()])
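For reference, here is a sketch (my notation, not taken from the test file) of the constant-rank differential of the Moore-Penrose pseudo-inverse that comment 2 above alludes to (Golub & Pereyra, 1973): for a real matrix \(A\) whose rank stays constant under the perturbation,

\[
\mathrm{d}A^{+} = -A^{+}\,\mathrm{d}A\,A^{+}
  + A^{+}A^{+\top}\,\mathrm{d}A^{\top}\,(I - AA^{+})
  + (I - A^{+}A)\,\mathrm{d}A^{\top}\,A^{+\top}A^{+}.
\]

If a perturbation changes the rank, no such differential exists, which is why the test keeps the singular values bounded away from zero.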

    def test_chain_matmul(self):
        def gen_matrices(p):
            matrices = []
6 changes: 1 addition & 5 deletions test/test_jit.py
@@ -15488,11 +15488,7 @@ def forward(self, x):
'test_slogdet_batched_pos_det',
'test_slogdet_batched_symmetric',
'test_slogdet_batched_symmetric_pd',
'test_slogdet_batched_distinct_singular_values',
'test_svd_check_grad_s',
'test_svd_check_grad_u',
'test_svd_check_grad_uv',
'test_svd_check_grad_v'
'test_slogdet_batched_distinct_singular_values'
}

# chunk returns a list in scripting and we don't unpack the list,
10 changes: 7 additions & 3 deletions test/test_linalg.py
@@ -1331,9 +1331,6 @@ def gen_error_message(input_size, ord, keepdim, dim=None):
# TODO: Fix autograd for matrix orders 'nuc', 2, and -2 by adding complex
# support to svd's backward method. Once this is done, these ords
# should be added to `matrix_ords` above
# Update: svd's backward now works with https://github.com/pytorch/pytorch/pull/47761
# However, run_test_case doesn't work for the 'matrix_ords_unsupported' cases
# because the singular values of 'x' and 'x_real' can differ, and so can the norms computed from them
# (see the sketch after the loop below)
matrix_ords_unsupported = ['nuc', 2, -2]

def run_test_case(x, ord, keepdim):
@@ -1360,6 +1357,13 @@ def run_test_case(x, ord, keepdim):
x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True)
run_test_case(x, ord, keepdim)

for ord in matrix_ords_unsupported:
x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True)
with self.assertRaisesRegex(
RuntimeError,
r'svd does not support automatic differentiation for outputs with complex dtype'):
res = torch.linalg.norm(x, ord, keepdim=keepdim)
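As context for the comment above, a quick sketch (my notation, not from the test file) of why exactly these matrix orders depend on svd's backward: each is a function of the singular values \(\sigma_1 \ge \dots \ge \sigma_k\) of the input,

\[
\|A\|_{\mathrm{nuc}} = \sum_{i} \sigma_i, \qquad
\|A\|_{2} = \sigma_1, \qquad
\|A\|_{-2} = \sigma_k,
\]

so their gradients flow through the SVD, and complex support for these orders hinges on svd_backward accepting complex dtypes.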

# Test that linalg.norm gives the same result as numpy when inputs
# contain extreme values (inf, -inf, nan)
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
23 changes: 7 additions & 16 deletions test/test_ops.py
@@ -29,15 +29,13 @@ class TestOpInfo(TestCase):
@onlyOnCPUAndCUDA
@ops(op_db, dtypes=OpDTypes.unsupported)
def test_unsupported_dtypes(self, device, dtype, op):
# sample_inputs can use a function for generating the input that doesn't work for the specified dtype
# https://github.com/pytorch/pytorch/issues/49024
with self.assertRaises(RuntimeError):
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")

# NOTE: only tests on first sample
sample = samples[0]
# NOTE: only tests on first sample
sample = samples[0]
with self.assertRaises(RuntimeError):
op(*sample.input, *sample.args, **sample.kwargs)

# Verifies that ops have their supported dtypes
@@ -76,14 +74,7 @@ def _check_helper(self, device, dtype, op, variant, check):

samples = op.sample_inputs(device, dtype, requires_grad=True)
for sample in samples:
if sample.output_process_fn_grad is not None:
out_fn = sample.output_process_fn_grad

def variant_out_fn(*args, **kwargs):
return out_fn(variant(*args, **kwargs))
else:
variant_out_fn = variant
partial_fn = partial(variant_out_fn, **sample.kwargs)
partial_fn = partial(variant, **sample.kwargs)
if check == 'gradcheck':
self.assertTrue(gradcheck(partial_fn, (*sample.input,) + sample.args,
check_grad_dtypes=True))
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -78,7 +78,7 @@
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr', 'svd',
'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr',
'_fft_c2c', '_fft_r2c',
}

9 changes: 1 addition & 8 deletions torch/_torch_docs.py
@@ -8010,7 +8010,7 @@ def merge_dicts(*dicts):
svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor)
This function returns a namedtuple ``(U, S, V)`` which is the singular value
decomposition of an input matrix or batches of matrices :attr:`input` such that
decomposition of an input real matrix or batches of real matrices :attr:`input` such that
:math:`input = U \times diag(S) \times V^T`.
If :attr:`some` is ``True`` (default), the method returns the reduced
@@ -8022,8 +8022,6 @@ def merge_dicts(*dicts):
If :attr:`compute_uv` is ``False``, the returned `U` and `V` matrices will be zero matrices
of shape :math:`(m \times m)` and :math:`(n \times n)` respectively. :attr:`some` will be ignored here.
Supports real-valued and complex-valued input.
.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices,
then the singular values of each matrix in the batch are returned in descending order.
@@ -8048,9 +8046,6 @@ def merge_dicts(*dicts):
.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V`
from the forward pass are required for the backward operation.
.. note:: With complex-valued input, the backward operation works correctly only
for gauge-invariant loss functions. Please look at `Gauge problem in AD`_ for more details.
Args:
input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
batch dimensions consisting of :math:`m \times n` matrices.
@@ -8088,8 +8083,6 @@ def merge_dicts(*dicts):
>>> u, s, v = torch.svd(a_big)
>>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.transpose(-2, -1)))
tensor(2.6503e-06)
.. _Gauge problem in AD: https://re-ra.xyz/Gauge-Problem-in-Automatic-Differentiation/
""")

add_docstr(torch.symeig,
47 changes: 14 additions & 33 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1824,35 +1824,28 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
auto gsigma = grads[1];

auto u = raw_u;
// Currently torch.svd for complex dtypes returns the conjugate of V,
// while the backward formula is derived with just V (without the conjugation)
// therefore here we need to conjugate the V output of SVD and grads[2].
// Once https://github.com/pytorch/pytorch/issues/45821 is resolved
// extra .conj(), that are marked below in the code, shall be removed.
auto v = raw_v.conj(); // TODO: remove .conj()
auto v = raw_v;
auto gu = grads[0];
auto gv = grads[2].conj(); // TODO: remove .conj()
auto gv = grads[2];

if (!some) {
// We ignore the free subspace here because possible base vectors cancel
// each other, e.g., both -v and +v are valid base for a dimension.
// Don't assume behavior of any particular implementation of svd.
u = raw_u.narrow(-1, 0, k);
v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj()
v = raw_v.narrow(-1, 0, k);
if (gu.defined()) {
gu = gu.narrow(-1, 0, k);
}
if (gv.defined()) {
gv = gv.narrow(-1, 0, k);
}
}
auto vh = v.conj().transpose(-2, -1);
auto vt = v.transpose(-2, -1);

Tensor sigma_term;
if (gsigma.defined()) {
gsigma = gsigma.to(self.dtype());
// computes u @ diag(gsigma) @ vh
sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh);
sigma_term = at::matmul(u, at::matmul(gsigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1), vt));
} else {
sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
@@ -1862,11 +1855,11 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
return sigma_term;
}

auto uh = u.conj().transpose(-2, -1);
auto ut = u.transpose(-2, -1);
auto im = at::eye(m, self.options());
auto in = at::eye(n, self.options());
auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype());
auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype());
auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
auto sigma_sq = sigma.pow(2);
auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1);
// The following two lines invert the values of F and fill the diagonal with 0s.
@@ -1878,38 +1871,26 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
Tensor u_term, v_term;

if (gu.defined()) {
auto guh = gu.conj().transpose(-2, -1);
u_term = at::matmul(u, at::matmul(F.mul(at::matmul(uh, gu) - at::matmul(guh, u)), sigma_mat));
u_term = at::matmul(u, at::matmul(F.mul(at::matmul(ut, gu) - at::matmul(gu.transpose(-2, -1), u)), sigma_mat));
if (m > k) {
u_term = u_term + at::matmul(im - at::matmul(u, uh), at::matmul(gu, sigma_mat_inv));
u_term = u_term + at::matmul(im - at::matmul(u, ut), at::matmul(gu, sigma_mat_inv));
}
u_term = at::matmul(u_term, vh);
u_term = at::matmul(u_term, vt);
} else {
u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}

if (gv.defined()) {
auto gvh = gv.conj().transpose(-2, -1);
v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh));
auto gvt = gv.transpose(-2, -1);
v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vt, gv) - at::matmul(gvt, v)), vt));
if (n > k) {
v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvh, in - at::matmul(v, vh)));
v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvt, in - at::matmul(v, vt)));
}
v_term = at::matmul(u, v_term);
} else {
v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}

// for complex-valued input there is an additional term
// https://giggleliu.github.io/2019/04/02/einsumbp.html
// https://arxiv.org/abs/1909.02659
if (self.is_complex() && gu.defined()) {
// computes L = Identity.mul(uh @ gu)
Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1).diag_embed(0, -2, -1);
L = L - L.conj().transpose(-2, -1);
Tensor imag_term = 0.5 * at::matmul(at::matmul(at::matmul(u, L), sigma_mat_inv), vh);
return u_term + sigma_term + v_term + imag_term;
}

return u_term + sigma_term + v_term;
}
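For orientation, here is a sketch (my notation, not from the source; compare widely cited derivations such as Townsend's note on differentiating the SVD) of the real-valued gradient that the reverted-to code above computes. With a thin SVD \(A = U\Sigma V^{\top}\), \(k = \min(m, n)\), incoming gradients \(\bar U\), \(\bar\sigma\), \(\bar V\), and \(F_{ij} = 1/(\sigma_j^2 - \sigma_i^2)\) for \(i \ne j\), \(F_{ii} = 0\):

\[
\bar A = U\,\mathrm{diag}(\bar\sigma)\,V^{\top}
  + \Big[ U\big(F \circ (U^{\top}\bar U - \bar U^{\top}U)\big)\,\Sigma
        + (I_m - UU^{\top})\,\bar U\,\Sigma^{-1} \Big] V^{\top}
  + U \Big[ \Sigma\big(F \circ (V^{\top}\bar V - \bar V^{\top}V)\big)\,V^{\top}
        + \Sigma^{-1}\,\bar V^{\top}\,(I_n - VV^{\top}) \Big],
\]

where the \((I_m - UU^{\top})\) and \((I_n - VV^{\top})\) terms are only added when \(m > k\) and \(n > k\), respectively. The reverted PR replaced the plain transposes with conjugate transposes and added the extra term \(\tfrac{1}{2}\,U L \Sigma^{-1} V^{\mathrm H}\), with \(L = \mathrm{diag}(U^{\mathrm H}\bar U) - \mathrm{diag}(U^{\mathrm H}\bar U)^{\mathrm H}\), to support complex inputs; that extension is what this commit removes.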

