diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index bd59fe7d28b9..f0b36d0fdbac 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1411,7 +1411,7 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
   if (compute_uv) {
     if (some) {
-      VT_working_copy = VT_working_copy.narrow(-1, 0, k);
+      VT_working_copy = VT_working_copy.narrow(-2, 0, k);
     }
   } else {
     VT_working_copy.zero_();
@@ -1421,24 +1421,71 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
     U_working_copy.zero_();
     VT_working_copy.zero_();
   }
+  // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
+  VT_working_copy.transpose_(-2, -1);
   return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
 }
 
 std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
   TORCH_CHECK(self.dim() >= 2,
-              "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+              "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   return at::_svd_helper(self, some, compute_uv);
 }
 
-std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& VT,
+std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
                                               const Tensor& self, bool some, bool compute_uv) {
   TORCH_CHECK(self.dim() >= 2,
-              "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
-  Tensor U_tmp, S_tmp, VT_tmp;
-  std::tie(U_tmp, S_tmp, VT_tmp) = at::_svd_helper(self, some, compute_uv);
+              "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+  Tensor U_tmp, S_tmp, V_tmp;
+  std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv);
   U.resize_as_(U_tmp).copy_(U_tmp);
   S.resize_as_(S_tmp).copy_(S_tmp);
-  VT.resize_as_(VT_tmp).copy_(VT_tmp);
+  V.resize_as_(V_tmp).copy_(V_tmp);
+  return std::tuple<Tensor&, Tensor&, Tensor&>(U, S, V);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/* torch.linalg.svd, implemented in terms of torch.svd. There are two main
+   differences:
+
+   1. the 2nd parameter is bool some=True, which is effectively the opposite
+      of full_matrices=True
+
+   2. svd returns V, while linalg.svd returns VT. To accommodate the
+      difference, we transpose() V upon return
+*/
+
+std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) {
+  TORCH_CHECK(self.dim() >= 2,
+              "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+
+  bool some = !full_matrices;
+  Tensor U, S, V;
+  std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv);
+  if (compute_uv) {
+    Tensor VT = V.transpose(-2, -1);
+    return std::make_tuple(U, S, VT);
+  } else {
+    Tensor empty_U = at::empty({0}, self.options());
+    Tensor empty_VT = at::empty({0}, self.options());
+    return std::make_tuple(empty_U, S, empty_VT);
+  }
+}
+
+static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst) {
+  TORCH_CHECK(src.device() == dst.device(), "svd output tensor ", name, " is on the wrong device: expected ", src.device(), " got ", dst.device());
+  at::native::resize_output(dst, src.sizes());
+  dst.copy_(src);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT,
+                                                     const Tensor& self, bool full_matrices, bool compute_uv) {
+  Tensor U_tmp, S_tmp, VT_tmp;
+  std::tie(U_tmp, S_tmp, VT_tmp) = at::linalg_svd(self, full_matrices, compute_uv);
+  svd_resize_and_copy("U", U_tmp, U);
+  svd_resize_and_copy("S", S_tmp, S);
+  svd_resize_and_copy("V", VT_tmp, VT);
   return std::tuple<Tensor&, Tensor&, Tensor&>(U, S, VT);
 }
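For readers following the C++ above: linalg_svd is a thin wrapper over _svd_helper in which `some` is the negation of `full_matrices`, `V` is transposed into `VT` on return, and empty tensors are handed back when `compute_uv` is false. The following minimal Python sketch (not part of the patch; `linalg_svd_sketch` is a made-up name) mirrors that mapping using the public torch.svd:

    import torch

    def linalg_svd_sketch(a, full_matrices=True, compute_uv=True):
        # `some` is the opposite of `full_matrices`, as in the C++ above.
        some = not full_matrices
        U, S, V = torch.svd(a, some=some, compute_uv=compute_uv)
        if compute_uv:
            # torch.svd hands back V; linalg.svd exposes its transpose.
            return U, S, V.transpose(-2, -1)
        # compute_uv=False: return empty placeholders rather than zero-filled U/V
        # (the C++ allocates two separate empty tensors; one is reused here for brevity).
        empty = torch.empty(0, dtype=a.dtype, device=a.device)
        return empty, S, empty

    a = torch.randn(5, 3)
    U, S, Vh = linalg_svd_sketch(a, full_matrices=False)
    print(torch.dist(a, U @ torch.diag(S) @ Vh))  # small, ~1e-6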
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index 607e201ebe8d..e97637d4c5e4 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -261,18 +261,21 @@ static inline std::tuple<Tensor, Tensor, Tensor> _create_U_S_VT(const Tensor& in
     U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU));
   }
 
+  // VT should be a column-major or a batch of column-major matrices
   sizes[input.dim() - 2] = n;
   sizes[input.dim() - 1] = n;
-  // VT should be a row-major or a batch of row-major matrices
+  strides = at::detail::defaultStrides(sizes);
+  strides[input.dim() - 1] = n;
+  strides[input.dim() - 2] = 1;
   Tensor VT_empty;
   if (!input.is_cuda()) {
-    VT_empty = at::empty(sizes, input.options());
+    VT_empty = at::empty_strided(sizes, strides, input.options());
   } else {
     // NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd
     // (which is the driver routine for the divide and conquer SVD operation)
     // takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that
     // moves the inputs between devices internally.
-    VT_empty = at::empty(sizes, input.options().device(at::kCPU));
+    VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU));
   }
 
   sizes.pop_back();
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index 3fbd693d17b1..379847d76ff4 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2194,7 +2194,7 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
   if (compute_uv) {
     if (some) {
-      VT_working_copy = VT_working_copy.narrow(-1, 0, k);
+      VT_working_copy = VT_working_copy.narrow(-2, 0, k);
     }
   } else {
     VT_working_copy.zero_();
@@ -2205,6 +2205,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
     S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device()));
     VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_();
   }
+  // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
+ VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b474d435398c..e8e3efa307f8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5820,14 +5820,14 @@ - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - DefaultBackend: svd_out + Math: svd_out - func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) variants: method, function dispatch: - DefaultBackend: svd + Math: svd -- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) +- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) variants: function dispatch: CPU: _svd_helper_cpu @@ -8962,6 +8962,15 @@ python_module: linalg variants: function +- func: linalg_svd.U(Tensor self, bool full_matrices=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + +- func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + python_module: linalg + use_c10_dispatcher: full + variants: function + - func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor python_module: linalg variants: function diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index d6de2373ad57..991205688df3 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -19,6 +19,7 @@ Functions .. autofunction:: eigvalsh .. autofunction:: matrix_rank .. autofunction:: norm +.. autofunction:: svd .. autofunction:: solve .. autofunction:: tensorinv .. 
autofunction:: tensorsolve diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 2d5d50096c81..4332916fef6b 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -37,6 +37,7 @@ ("aten::ifft", datetime.date(2021, 1, 31)), ("aten::irfft", datetime.date(2021, 1, 31)), ("aten::rfft", datetime.date(2021, 1, 31)), + ("aten::_svd_helper", datetime.date(2021, 1, 31)), ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)), ("aten::_cudnn_rnn", datetime.date(2020, 12, 31)), ("aten::_cudnn_rnn_backward", datetime.date(2020, 12, 31)), diff --git a/test/test_linalg.py b/test/test_linalg.py index 6ebfb03256f3..f7cea014458c 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -11,6 +11,7 @@ from math import inf, nan, isnan import random from random import randrange +from itertools import product from functools import reduce from torch.testing._internal.common_utils import \ @@ -1864,6 +1865,277 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + # ~~~ tests for torch.svd ~~~ + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd(self, device, dtype): + def run_test(dims, some, compute_uv): + x = torch.randn(*dims, dtype=dtype, device=device) + outu = torch.empty(0, dtype=dtype, device=device) + outs = torch.empty(0, dtype=dtype, device=device) + outv = torch.empty(0, dtype=dtype, device=device) + torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) + + if compute_uv: + if some: + x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = outu[..., :min(*dims[-2:])] + narrow_v = outv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, outs, msg='Singular values mismatch') + self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') + self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') + + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') + self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') + self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') + + # test non-contiguous + x = torch.randn(*dims, dtype=dtype, device=device) + n_dim = len(dims) + # Reverse the batch dimensions and the matrix dimensions and then concat them + x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) + assert not x.is_contiguous(), "x is intentionally non-contiguous" + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + if compute_uv: + if some: + x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = resu[..., :min(*dims[-2:])] + narrow_v = resv[..., 
:min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, ress, msg='Singular values mismatch') + self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') + self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') + + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices + for dims, some, compute_uv in product(shapes, [True, False], [True, False]): + run_test(dims, some, compute_uv) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_svd_no_singularvectors(self, device, dtype): + for size in [(5, 5), (5, 20), (20, 5)]: + a = torch.randn(*size, device=device, dtype=dtype) + u, s_expect, v = torch.svd(a) + u, s_actual, v = torch.svd(a, compute_uv=False) + self.assertEqual(s_expect, s_actual, msg="Singular values don't match") + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd_lowrank(self, device, dtype): + from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix + + def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): + density = options.pop('density', 1) + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + if density == 1: + a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) + a = a_input + else: + assert batches == () + a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) + a = a_input.to_dense() + + q = min(*size) + u, s, v = svd_lowrank(a_input, q=q, **options) + + # check if u, s, v is a SVD + u, s, v = u[..., :q], s[..., :q], v[..., :q] + A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) + self.assertEqual(A, a) + + # check if svd_lowrank produces same singular values as torch.svd + U, S, V = torch.svd(a) + self.assertEqual(s.shape, S.shape) + self.assertEqual(u.shape, U.shape) + self.assertEqual(v.shape, V.shape) + self.assertEqual(s, S) + + if density == 1: + # actual_rank is known only for dense inputs + # + # check if pairs (u, U) and (v, V) span the same + # subspaces, respectively + u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] + U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] + self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + + all_batches = [(), (1,), (3,), (2, 3)] + for actual_rank, size, all_batches in [ + (2, (17, 4), all_batches), + (4, (17, 4), all_batches), + (4, (17, 17), all_batches), + (10, (100, 40), all_batches), + (7, (1000, 1000), [()]), + ]: + # dense input + for batches in all_batches: + run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) + if size != size[::-1]: + run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) + + # sparse input + for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: + for density in [0.005, 0.1]: + run_subtest(None, size, (), device, torch.svd_lowrank, density=density) + + # jitting support + jitted = 
torch.jit.script(torch.svd_lowrank) + actual_rank, size, batches = 2, (17, 4), () + run_subtest(actual_rank, size, batches, device, jitted) + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.cfloat) + def test_svd_complex(self, device, dtype): + t = torch.randn((10, 10), dtype=dtype, device=device) + U, S, V = torch.svd(t, some=False) + # note: from the math point of view, it is weird that we need to use + # V.T instead of V.T.conj(): torch.svd has a buggy behavior for + # complex numbers and it's deprecated. You should use torch.linalg.svd + # instead. + t2 = U @ torch.diag(S).type(dtype) @ V.T + self.assertEqual(t, t2) + + def _test_svd_helper(self, shape, some, col_maj, device, dtype): + cpu_tensor = torch.randn(shape, device='cpu').to(dtype) + device_tensor = cpu_tensor.to(device=device) + if col_maj: + cpu_tensor = cpu_tensor.t() + device_tensor = device_tensor.t() + cpu_result = torch.svd(cpu_tensor, some=some) + device_result = torch.svd(device_tensor, some=some) + m = min(cpu_tensor.shape[-2:]) + # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). + # - When some==False, U[..., m:] can be arbitrary. + # - When some==True, U shape: [..., m], V shape: [m, m] + # - Signs are not deterministic. If the sign of a column of U is changed + # then the corresponding column of the V has to be changed. + # Thus here we only compare result[..., :m].abs() from CPU and device. + for x, y in zip(cpu_result, device_result): + self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_and_complex_types()) + def test_svd_square(self, device, dtype): + self._test_svd_helper((10, 10), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_square_col_maj(self, device, dtype): + self._test_svd_helper((10, 10), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_some(self, device, dtype): + self._test_svd_helper((20, 5), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_all(self, device, dtype): + self._test_svd_helper((20, 5), False, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_some_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_all_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), False, True, device, dtype) + + # ~~~ tests for torch.linalg.svd ~~~ + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_linalg_svd_compute_uv(self, device, dtype): + """ + Test the default case, compute_uv=True. Here we have the very same behavior as + numpy + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + for full_matrices in (True, False): + # check linalg.svd vs numpy + expected = np.linalg.svd(np_t, full_matrices, compute_uv=True) + actual = torch.linalg.svd(t, full_matrices, compute_uv=True) + self.assertEqual(actual, expected) + # check linalg.svd vs linalg.svd(out=...) 
+ out = (torch.empty_like(actual[0]), + torch.empty_like(actual[1]), + torch.empty_like(actual[2])) + out2 = torch.linalg.svd(t, full_matrices, compute_uv=True, out=out) + self.assertEqual(actual, out) + self.assertEqual(actual, out2) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_linalg_svd_no_compute_uv(self, device, dtype): + """ + Test the compute_uv=False case. Here we have a different return type than + numpy: numpy returns S, we return (empty, S, empty) + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + + def is_empty(x): + return x.numel() == 0 and x.dtype == t.dtype and x.device == t.device + + for full_matrices in (True, False): + # check linalg.svd vs numpy + np_s = np.linalg.svd(np_t, full_matrices, compute_uv=False) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False) + assert is_empty(USV.U) + self.assertEqual(USV.S, np_s) + assert is_empty(USV.V) + # check linalg.svd vs linalg.svd(out=...) + out = (torch.empty_like(USV.U), torch.empty_like(USV.S), torch.empty_like(USV.V)) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out) + assert USV.U is out[0] + assert USV.S is out[1] + assert USV.V is out[2] + self.assertEqual(USV.S, np_s) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @onlyCUDA + @dtypes(torch.float) + def test_linalg_svd_out_different_device(self, device, dtype): + t = torch.randn(5, 7, device=device, dtype=dtype) # this is on cuda + u = torch.empty((5, 5), device='cpu', dtype=dtype) + s = torch.empty((5,), device='cpu', dtype=dtype) + v = torch.empty((7, 7), device='cpu', dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'svd output tensor U is on the wrong device: expected cuda:.* got cpu'): + torch.linalg.svd(t, out=(u, s, v)) + def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix @@ -4602,60 +4874,6 @@ def test_solve_methods_arg_device(self, device): "Expected LU_pivots and LU_data to be on the same device"): torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) - def _test_svd_helper(self, shape, some, col_maj, device, dtype): - cpu_tensor = torch.randn(shape, device='cpu').to(dtype) - device_tensor = cpu_tensor.to(device=device) - if col_maj: - cpu_tensor = cpu_tensor.t() - device_tensor = device_tensor.t() - cpu_result = torch.svd(cpu_tensor, some=some) - device_result = torch.svd(device_tensor, some=some) - m = min(cpu_tensor.shape[-2:]) - # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). - # - When some==False, U[..., m:] can be arbitrary. - # - When some==True, U shape: [..., m], V shape: [m, m] - # - Signs are not deterministic. If the sign of a column of U is changed - # then the corresponding column of the V has to be changed. - # Thus here we only compare result[..., :m].abs() from CPU and device. 
- for x, y in zip(cpu_result, device_result): - self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_and_complex_types()) - def test_svd_square(self, device, dtype): - self._test_svd_helper((10, 10), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_square_col_maj(self, device, dtype): - self._test_svd_helper((10, 10), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_some(self, device, dtype): - self._test_svd_helper((20, 5), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_all(self, device, dtype): - self._test_svd_helper((20, 5), False, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_some_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_all_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), False, True, device, dtype) - @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3}) @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -5597,145 +5815,6 @@ def run_test(dims, eigenvectors, upper): for batch_dims, eigenvectors, upper in itertools.product(batch_dims_set, (True, False), (True, False)): run_test((5,) + batch_dims, eigenvectors, upper) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_svd(self, device, dtype): - def run_test(dims, some, compute_uv): - x = torch.randn(*dims, dtype=dtype, device=device) - outu = torch.tensor((), dtype=dtype, device=device) - outs = torch.tensor((), dtype=dtype, device=device) - outv = torch.tensor((), dtype=dtype, device=device) - torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) - - if compute_uv: - if some: - x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = outu[..., :min(*dims[-2:])] - narrow_v = outv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, outs, msg='Singular values mismatch') - self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') - self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') - - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') - self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') - self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') - - # test non-contiguous - x = torch.randn(*dims, dtype=dtype, device=device) - n_dim = len(dims) - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally non-contiguous" - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - if compute_uv: - if some: - x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), 
resv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = resu[..., :min(*dims[-2:])] - narrow_v = resv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, ress, msg='Singular values mismatch') - self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') - self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') - - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices - for dims, some, compute_uv in itertools.product(shapes, [True, False], [True, False]): - run_test(dims, some, compute_uv) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_no_singularvectors(self, device): - for size in [(5, 5), (5, 20), (20, 5)]: - a = torch.randn(*size, device=device) - u, s_expect, v = torch.svd(a) - u, s_actual, v = torch.svd(a, compute_uv=False) - self.assertEqual(s_expect, s_actual, msg="Singular values don't match") - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_lowrank(self, device): - import torch - from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix - - dtype = torch.double - - def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): - density = options.pop('density', 1) - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if density == 1: - a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) - a = a_input - else: - assert batches == () - a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) - a = a_input.to_dense() - - q = min(*size) - u, s, v = svd_lowrank(a_input, q=q, **options) - - # check if u, s, v is a SVD - u, s, v = u[..., :q], s[..., :q], v[..., :q] - A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) - self.assertEqual(A, a) - - # check if svd_lowrank produces same singular values as torch.svd - U, S, V = torch.svd(a) - self.assertEqual(s.shape, S.shape) - self.assertEqual(u.shape, U.shape) - self.assertEqual(v.shape, V.shape) - self.assertEqual(s, S) - - if density == 1: - # actual_rank is known only for dense inputs - # - # check if pairs (u, U) and (v, V) span the same - # subspaces, respectively - u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] - U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] - self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - - all_batches = [(), (1,), (3,), (2, 3)] - for actual_rank, size, all_batches in [ - (2, (17, 4), all_batches), - (4, (17, 4), all_batches), - (4, (17, 17), all_batches), - (10, (100, 40), all_batches), - (7, (1000, 1000), [()]), - ]: - # dense input - for batches in all_batches: - run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) - if size != size[::-1]: - run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) - - # sparse input - for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 
1000)]: - for density in [0.005, 0.1]: - run_subtest(None, size, (), device, torch.svd_lowrank, density=density) - - # jitting support - jitted = torch.jit.script(torch.svd_lowrank) - actual_rank, size, batches = 2, (17, 4), () - run_subtest(actual_rank, size, batches, device, jitted) - @skipCUDAIfNoMagma @skipCPUIfNoLapack def test_pca_lowrank(self, device): diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 1906b83ca8d6..a5d8b8179207 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -13,6 +13,7 @@ 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual", 'linalg_qr', + '_svd_helper', 'linalg_svd', } @@ -56,7 +57,7 @@ def test_namedtuple_return(self): names=('values', 'indices'), hasout=True), op(operators=['kthvalue'], input=(1, 0), names=('values', 'indices'), hasout=True), - op(operators=['svd'], input=(), names=('U', 'S', 'V'), hasout=True), + op(operators=['svd', '_svd_helper', 'linalg_svd'], input=(), names=('U', 'S', 'V'), hasout=True), op(operators=['slogdet'], input=(), names=('sign', 'logabsdet'), hasout=False), op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True), op(operators=['solve'], input=(a,), names=('solution', 'LU'), hasout=True), @@ -65,26 +66,38 @@ def test_namedtuple_return(self): op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True), op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True), op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True), - op(operators=['unpack_dual'], input=(a, 0), names=('primal', 'tangent'), hasout=False), + op(operators=['unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False), ] + def get_func(f): + "Return either torch.f or torch.linalg.f, where 'f' is a string" + if f.startswith('linalg_'): + return getattr(torch.linalg, f[7:]) + return getattr(torch, f, None) + + def check_namedtuple(tup, names): + "Check that the namedtuple 'tup' has the given names" + for i, name in enumerate(names): + self.assertIs(getattr(tup, name), tup[i]) + for op in operators: for f in op.operators: - if 'linalg_' in f: - ret = getattr(torch.linalg, f[7:])(a, *op.input) - ret1 = getattr(torch.linalg, f[7:])(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - else: - # Handle op that are not methods - func = getattr(a, f) if hasattr(a, f) else getattr(torch, f) - ret = func(*op.input) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - if op.hasout: - ret1 = getattr(torch, f)(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) + # 1. check the namedtuple returned by calling torch.f + func = get_func(f) + if func: + ret1 = func(a, *op.input) + check_namedtuple(ret1, op.names) + # + # 2. check the out= variant, if it exists + if func and op.hasout: + ret2 = func(a, *op.input, out=tuple(ret1)) + check_namedtuple(ret2, op.names) + # + # 3. 
check the Tensor.f method, if it exists + meth = getattr(a, f, None) + if meth: + ret3 = meth(*op.input) + check_namedtuple(ret3, op.names) all_covered_operators = set([x for y in operators for x in y.operators]) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9bf266da394d..a59581d63bbd 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1055,7 +1055,7 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -- name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) +- name: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) self: svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index c78e1e5f66cc..e4337e9de855 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -82,7 +82,7 @@ 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', - 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', 'svd', '_fft_c2c', '_fft_r2c', + 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c', 'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_' } diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 4a1c36df7497..d204afdb286e 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -8142,18 +8142,45 @@ def merge_dicts(*dicts): r""" svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) -This function returns a namedtuple ``(U, S, V)`` which is the singular value -decomposition of a input matrix or batches of matrices :attr:`input` such that -:math:`input = U \times diag(S) \times V^T`. - -If :attr:`some` is ``True`` (default), the method returns the reduced -singular value decomposition i.e., if the last two dimensions of -:attr:`input` are ``m`` and ``n``, then the returned `U` matrix will -contain only :math:`min(n, m)` orthonormal columns and the size of `V` -will be :math:`(*, n, n)`. - -If :attr:`compute_uv` is ``False``, the returned `U` and `V` matrices will be zero matrices -of shape :math:`(m \times m)` and :math:`(n \times n)` respectively. :attr:`some` will be ignored here. +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`." The singular value decomposition is represented as a +namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times +V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch +of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch +dimensions as :attr:`input`. + +If :attr:`some` is ``True`` (default), the method returns the reduced singular +value decomposition i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. 
+
+If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be
+zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)`
+respectively, and the same device as :attr:`input`. The :attr:`some`
+argument has no effect when :attr:`compute_uv` is False.
+
+The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will
+always be real-valued, even if :attr:`input` is complex.
+
+.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.``
+             :func:`~torch.linalg.svd` instead, which is similar to NumPy's
+             ``numpy.linalg.svd``.
+
+.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`:
+
+          * :attr:`some` is the opposite of ``torch.linalg.``
+            :func:`~torch.linalg.svd`'s :attr:`full_matrices`. Note that the
+            default value for both is ``True``, so the default behavior is
+            effectively the opposite.
+
+          * it returns ``V``, whereas ``torch.linalg.``
+            :func:`~torch.linalg.svd` returns ``Vh``. The result is that
+            when using ``svd`` you need to manually transpose
+            ``V`` in order to reconstruct the original matrix.
+
+          * If :attr:`compute_uv=False`, it returns zero-filled tensors for
+            ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns
+            empty tensors.
 
 Supports real-valued and complex-valued input.
@@ -8164,22 +8191,18 @@ def merge_dicts(*dicts):
 algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine
 `gesdd` as well.
 
-.. note:: Irrespective of the original strides, the returned matrix `U`
-          will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()`
+.. note:: The returned matrix `U` will be transposed, i.e. with strides
+          :code:`U.contiguous().transpose(-2, -1).stride()`.
 
-.. note:: Extra care needs to be taken when backward through `U` and `V`
-          outputs. Such operation is really only stable when :attr:`input` is
-          full rank with all distinct singular values. Otherwise, ``NaN`` can
-          appear as the gradients are not properly defined. Also, notice that
-          double backward will usually do an additional backward through `U` and
-          `V` even if the original backward is only on `S`.
+.. note:: Gradients computed using `U` and `V` may be unstable if
+          :attr:`input` is not full rank or has non-unique singular values.
 
 .. note:: When :attr:`some` = ``False``, the gradients on :code:`U[..., :, min(m, n):]`
           and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors
           can be arbitrary bases of the subspaces.
 
-.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V`
-          from the forward pass is required for the backward operation.
+.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True.
+
 .. note:: With the complex-valued input the backward operation works correctly only for gauge invariant loss functions.
           Please look at `Gauge problem in AD`_ for more details.
@@ -8187,8 +8210,9 @@ def merge_dicts(*dicts):
 Args:
     input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
                     batch dimensions consisting of :math:`m \times n` matrices.
-    some (bool, optional): controls the shape of returned `U` and `V`
-    compute_uv (bool, optional): option whether to compute `U` and `V` or not
+    some (bool, optional): controls whether to compute the reduced or full decomposition, and
+                           consequently the shape of returned ``U`` and ``V``. Defaults to True.
+    compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True.
 
 Keyword args:
     out (tuple, optional): the output tuple of tensors
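As an illustrative aside (not patch content), here is a small sketch of the three differences listed in the note above, using an arbitrary 5x3 real input; sign-dependent quantities are deliberately not compared:

    import torch

    a = torch.randn(5, 3)

    # 1. some=True corresponds to full_matrices=False (and vice versa).
    U1, S1, V1 = torch.svd(a, some=True)
    U2, S2, Vh2 = torch.linalg.svd(a, full_matrices=False)
    print(torch.allclose(S1, S2))  # True: identical singular values

    # 2. torch.svd returns V, torch.linalg.svd returns Vh, so only the old API
    #    needs a manual transpose to rebuild the input.
    print(torch.dist(a, U1 @ torch.diag(S1) @ V1.t()))  # ~1e-6
    print(torch.dist(a, U2 @ torch.diag(S2) @ Vh2))     # ~1e-6

    # 3. With compute_uv=False, torch.svd gives zero-filled (m, m) and (n, n)
    #    tensors while torch.linalg.svd gives empty tensors.
    Uz, _, Vz = torch.svd(a, compute_uv=False)
    Ue, _, Ve = torch.linalg.svd(a, compute_uv=False)
    print(tuple(Uz.shape), tuple(Vz.shape), Ue.numel(), Ve.numel())  # (5, 5) (3, 3) 0 0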
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 0f99def6c7fe..4c724b0b7e4c 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -403,6 +403,93 @@
 (tensor(3.7417), tensor(11.2250))
 """)
 
+svd = _add_docstr(_linalg.linalg_svd, r"""
+linalg.svd(input, full_matrices=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor)
+
+Computes the singular value decomposition of either a matrix or batch of
+matrices :attr:`input`. The singular value decomposition is represented as a
+namedtuple ``(U, S, Vh)``, such that :math:`input = U \mathbin{@} diag(S) \times
+Vh`. If :attr:`input` is a batch of tensors, then ``U``, ``S``, and ``Vh`` are
+also batched with the same batch dimensions as :attr:`input`.
+
+If :attr:`full_matrices` is ``False``, the method returns the reduced singular
+value decomposition, i.e., if the last two dimensions of :attr:`input` are
+``m`` and ``n``, then the returned `U` and `V` matrices will contain only
+:math:`min(n, m)` orthonormal columns.
+
+If :attr:`compute_uv` is ``False``, the returned `U` and `Vh` will be empty
+tensors with no elements and the same device as :attr:`input`. The
+:attr:`full_matrices` argument has no effect when :attr:`compute_uv` is False.
+
+The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will
+always be real-valued, even if :attr:`input` is complex.
+
+.. note:: Unlike NumPy's ``linalg.svd``, this always returns a namedtuple of
+          three tensors, even when :attr:`compute_uv=False`.
+
+.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices,
+          then the singular values of each matrix in the batch are returned in descending order.
+
+.. note:: The implementation of SVD on CPU uses the LAPACK routine `?gesdd` (a divide-and-conquer
+          algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine
+          `gesdd` as well.
+
+.. note:: The returned matrix `U` will be transposed, i.e. with strides
+          :code:`U.contiguous().transpose(-2, -1).stride()`.
+
+.. note:: Gradients computed using `U` and `Vh` may be unstable if
+          :attr:`input` is not full rank or has non-unique singular values.
+
+.. note:: When :attr:`full_matrices` = ``True``, the gradients on :code:`U[..., :, min(m, n):]`
+          and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors
+          can be arbitrary bases of the subspaces.
+
+.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True.
+
+
+Args:
+    input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
+                    batch dimensions consisting of :math:`m \times n` matrices.
+    full_matrices (bool, optional): controls whether to compute the full or reduced decomposition, and
+                                    consequently the shape of returned ``U`` and ``V``. Defaults to True.
+    compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True.
+    out (tuple, optional): a tuple of three tensors to use for the outputs. If compute_uv=False,
+                           the 1st and 3rd arguments must be tensors, but they are ignored. E.g. you can
+                           pass `(torch.Tensor(), out_S, torch.Tensor())`
+
+Example::
+
+    >>> import torch
+    >>> a = torch.randn(5, 3)
+    >>> a
+    tensor([[-0.3357, -0.2987, -1.1096],
+            [ 1.4894,  1.0016, -0.4572],
+            [-1.9401,  0.7437,  2.0968],
+            [ 0.1515,  1.3812,  1.5491],
+            [-1.8489, -0.5907, -2.5673]])
+    >>>
+    >>> # reconstruction in the full_matrices=False case
+    >>> u, s, vh = torch.linalg.svd(a, full_matrices=False)
+    >>> u.shape, s.shape, vh.shape
+    (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3]))
+    >>> torch.dist(a, u @ torch.diag(s) @ vh)
+    tensor(1.0486e-06)
+    >>>
+    >>> # reconstruction in the full_matrices=True case
+    >>> u, s, vh = torch.linalg.svd(a)
+    >>> u.shape, s.shape, vh.shape
+    (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3]))
+    >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh)
+    tensor(1.0486e-06)
+    >>>
+    >>> # extra dimensions
+    >>> a_big = torch.randn(7, 5, 3)
+    >>> u, s, vh = torch.linalg.svd(a_big, full_matrices=False)
+    >>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh)
+    tensor(3.0957e-06)
+""")
+
 cond = _add_docstr(_linalg.linalg_cond, r"""
 linalg.cond(input, p=None, *, out=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index f8d9f2e152f6..a6ec0de5ffb5 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -803,6 +803,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.nansum: lambda input, dim=None: -1,
     torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
     torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
+    torch.linalg.svd: lambda input, full_matrices=True, compute_uv=True, out=None: -1,
     torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1,
     torch.swapaxes: lambda input, dim0, dim1: -1,
     torch.swapdims: lambda input, axis0, axis1: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 119750396f1e..c658dbef10e6 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -533,7 +533,7 @@ def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False):
 
     return out
 
-def sample_inputs_svd(op_info, device, dtype, requires_grad=False):
+def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False):
     """
     This function generates input for torch.svd with distinct singular values so that autograd is always stable.
     Matrices of different size:
@@ -546,6 +546,16 @@
     """
    from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
 
+    # svd and linalg.svd return V and V.T, respectively. So we need to slice
+    # along different dimensions when needed (this is used by
+    # test_cases2:wide_all and wide_all_batched below)
+    if is_linalg_svd:
+        def slice_V(v):
+            return v[..., :(S - 2), :]
+    else:
+        def slice_V(v):
+            return v[..., :, :(S - 2)]
+
     test_cases1 = (  # some=True (default)
         # loss functions for complex-valued svd have to be "gauge invariant",
         # i.e. loss functions shouldn't change when sign of the singular vectors change.
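The "gauge invariant" comment above refers to the sign ambiguity of singular vectors, which is why the sample inputs below post-process `U` and `V` with `abs()`. A brief sketch of that ambiguity (illustrative only, real square input assumed):

    import torch

    a = torch.randn(4, 4, dtype=torch.double)
    U, S, V = torch.svd(a)

    # Flipping the sign of one left singular vector together with the matching
    # right singular vector yields an equally valid SVD, so only abs(U), S and
    # abs(V) are stable quantities for a test to compare.
    U2, V2 = U.clone(), V.clone()
    U2[:, 0] *= -1
    V2[:, 0] *= -1
    print(torch.dist(a, U2 @ torch.diag(S) @ V2.t()))                      # ~0
    print(torch.equal(U.abs(), U2.abs()), torch.equal(V.abs(), V2.abs()))  # True True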
@@ -575,11 +585,11 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): ) test_cases2 = ( # some=False (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)], - lambda usv: (abs(usv[0]), usv[1], abs(usv[2][:, :(S - 2)]))), # 'wide_all' + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)], lambda usv: (abs(usv[0][:, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all' (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :(S - 2), :], - lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide_all_batched' + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all_batched' (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :, :(S - 2)], lambda usv: (abs(usv[0][..., :, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all_batched' ) @@ -587,15 +597,27 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): out = [] for a, out_fn in test_cases1: a.requires_grad = requires_grad - out.append(SampleInput(a, output_process_fn_grad=out_fn)) + if is_linalg_svd: + kwargs = {'full_matrices': False} + else: + kwargs = {'some': True} + out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) for a, out_fn in test_cases2: a.requires_grad = requires_grad - kwargs = {'some': False} + if is_linalg_svd: + kwargs = {'full_matrices': True} + else: + kwargs = {'some': False} out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) return out +def sample_inputs_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=False) + +def sample_inputs_linalg_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=True) def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): """ @@ -1175,6 +1197,20 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): # cuda gradchecks are very slow # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + OpInfo('linalg.svd', + op=torch.linalg.svd, + aten_name='linalg_svd', + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_linalg_svd, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # gradgrad checks are slow + SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + # cuda gradchecks are very slow + # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), OpInfo('pinverse', op=torch.pinverse, dtypes=floating_and_complex_types(),
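To close, a minimal sketch of the compute_uv=False convention that test_linalg_svd_no_compute_uv above checks: NumPy returns only `S`, while the torch.linalg.svd added in this patch still returns a three-element namedtuple with empty `U` and `V` (illustrative, not patch content):

    import numpy as np
    import torch

    t = torch.randn(10, 11)

    # NumPy returns only the singular values when compute_uv=False ...
    np_s = np.linalg.svd(t.numpy(), full_matrices=True, compute_uv=False)

    # ... whereas the torch.linalg.svd added in this patch still returns a
    # (U, S, V) namedtuple whose U and V are empty tensors.
    U, S, V = torch.linalg.svd(t, full_matrices=True, compute_uv=False)
    print(U.numel(), V.numel())                     # 0 0
    print(np.allclose(S.numpy(), np_s, atol=1e-5))  # True (up to float32 noise)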