move the svd tests from test_torch.py to test_linalg.py
antocuni committed Nov 26, 2020
1 parent 4b4076d commit 5fa4efe
Showing 2 changed files with 213 additions and 208 deletions.
215 changes: 213 additions & 2 deletions test/test_linalg.py
@@ -4,6 +4,7 @@
import warnings
from math import inf, nan, isnan
from random import randrange
from itertools import product

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_NUMPY, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor)
@@ -1019,10 +1020,220 @@ def test_nuclear_norm_exceptions_old(self, device):
        self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0))
        self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))

    # ~~~ tests for torch.svd ~~~

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_svd(self, device, dtype):
        def run_test(dims, some, compute_uv):
            x = torch.randn(*dims, dtype=dtype, device=device)
            outu = torch.tensor((), dtype=dtype, device=device)
            outs = torch.tensor((), dtype=dtype, device=device)
            outv = torch.tensor((), dtype=dtype, device=device)
            torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv))

            if compute_uv:
                if some:
                    x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1)))
                    self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T')
                else:
                    narrow_u = outu[..., :min(*dims[-2:])]
                    narrow_v = outv[..., :min(*dims[-2:])]
                    x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1)))
                    self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T')
            else:
                _, singvals, _ = torch.svd(x, compute_uv=True)
                self.assertEqual(singvals, outs, msg='Singular values mismatch')
                self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero')
                self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero')

            resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv)
            self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ')
            self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ')
            self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ')

            # test non-contiguous
            x = torch.randn(*dims, dtype=dtype, device=device)
            n_dim = len(dims)
            # Reverse the batch dimensions and the matrix dimensions and then concat them
            x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2))
            assert not x.is_contiguous(), "x is intentionally non-contiguous"
            resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv)
            if compute_uv:
                if some:
                    x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1)))
                    self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T')
                else:
                    narrow_u = resu[..., :min(*dims[-2:])]
                    narrow_v = resv[..., :min(*dims[-2:])]
                    x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1)))
                    self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T')
            else:
                _, singvals, _ = torch.svd(x, compute_uv=True)
                self.assertEqual(singvals, ress, msg='Singular values mismatch')
                self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero')
                self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero')

        shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3),  # square matrices
                  (7, 3), (5, 7, 3), (7, 5, 7, 3),  # fat matrices
                  (3, 7), (5, 3, 7), (7, 5, 3, 7)]  # thin matrices
        for dims, some, compute_uv in product(shapes, [True, False], [True, False]):
            run_test(dims, some, compute_uv)
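For readers skimming the diff, here is a minimal standalone sketch (not taken from the test above) of the reconstruction identity that the assertions check: torch.svd returns U, S, V such that x is recovered as U @ diag(S) @ V.T, up to floating-point error.

import torch

x = torch.randn(5, 3, dtype=torch.double)
U, S, V = torch.svd(x)  # some=True by default: U is 5x3, S has 3 entries, V is 3x3
x_recon = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
assert torch.allclose(x, x_recon, atol=1e-8)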

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    def test_svd_no_singularvectors(self, device):
        for size in [(5, 5), (5, 20), (20, 5)]:
            a = torch.randn(*size, device=device)
            u, s_expect, v = torch.svd(a)
            u, s_actual, v = torch.svd(a, compute_uv=False)
            self.assertEqual(s_expect, s_actual, msg="Singular values don't match")

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    def test_svd_lowrank(self, device):
        import torch
        from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix

        dtype = torch.double

        def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options):
            density = options.pop('density', 1)
            if isinstance(matrix_size, int):
                rows = columns = matrix_size
            else:
                rows, columns = matrix_size
            if density == 1:
                a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
                a = a_input
            else:
                assert batches == ()
                a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
                a = a_input.to_dense()

            q = min(rows, columns)
            u, s, v = svd_lowrank(a_input, q=q, **options)

            # check that u, s, v is an SVD
            u, s, v = u[..., :q], s[..., :q], v[..., :q]
            A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1))
            self.assertEqual(A, a)

            # check that svd_lowrank produces the same singular values as torch.svd
            U, S, V = torch.svd(a)
            self.assertEqual(s.shape, S.shape)
            self.assertEqual(u.shape, U.shape)
            self.assertEqual(v.shape, V.shape)
            self.assertEqual(s, S)

            if density == 1:
                # actual_rank is known only for dense inputs
                #
                # check that the pairs (u, U) and (v, V) span the same
                # subspaces, respectively
                u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank]
                U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank]
                self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype))
                self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype))

        all_batches = [(), (1,), (3,), (2, 3)]
        for actual_rank, size, batches_to_test in [
                (2, (17, 4), all_batches),
                (4, (17, 4), all_batches),
                (4, (17, 17), all_batches),
                (10, (100, 40), all_batches),
                (7, (1000, 1000), [()]),
        ]:
            # dense input
            for batches in batches_to_test:
                run_subtest(actual_rank, size, batches, device, torch.svd_lowrank)
                if size != size[::-1]:
                    run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank)

        # sparse input
        for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]:
            for density in [0.005, 0.1]:
                run_subtest(None, size, (), device, torch.svd_lowrank, density=density)

        # jitting support
        jitted = torch.jit.script(torch.svd_lowrank)
        actual_rank, size, batches = 2, (17, 4), ()
        run_subtest(actual_rank, size, batches, device, jitted)
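As a usage aside, a hedged sketch of the torch.svd_lowrank call exercised above; the rank-2 input built here is an illustrative assumption, not one of the matrices generated by the test.

import torch

# Build an exactly rank-2 17x4 matrix, then ask for a rank-4 approximation.
a = torch.randn(17, 2, dtype=torch.double) @ torch.randn(2, 4, dtype=torch.double)
u, s, v = torch.svd_lowrank(a, q=4)
a_recon = u @ torch.diag_embed(s) @ v.transpose(-2, -1)
# The true rank is below q, so the reconstruction is numerically exact.
assert torch.allclose(a, a_recon, atol=1e-8)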

    @onlyCPU
    @skipCPUIfNoLapack
    @dtypes(torch.cfloat)
    def test_svd_complex(self, device, dtype):
        t = torch.randn((10, 10), dtype=dtype, device=device)
        U, S, V = torch.svd(t, some=False)
        # note: mathematically we would expect to use V.T.conj() here rather than
        # V.T, but torch.svd has buggy behavior for complex inputs and is
        # deprecated; use torch.linalg.svd instead.
        t2 = U @ torch.diag(S).type(dtype) @ V.T
        self.assertEqual(t, t2)
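For contrast with the note above, a hedged sketch using torch.linalg.svd, whose third output is already the conjugate-transposed factor (Vh), so no manual conjugation is needed for complex inputs; the full_matrices=False call below assumes the NumPy-style torch.linalg.svd signature.

import torch

t = torch.randn(10, 10, dtype=torch.cfloat)
U, S, Vh = torch.linalg.svd(t, full_matrices=False)
# S is real-valued, so cast it before the complex matmul.
t2 = U @ torch.diag(S).to(t.dtype) @ Vh
assert torch.allclose(t, t2, atol=1e-5)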

    def _test_svd_helper(self, shape, some, col_maj, device, dtype):
        cpu_tensor = torch.randn(shape, device='cpu').to(dtype)
        device_tensor = cpu_tensor.to(device=device)
        if col_maj:
            cpu_tensor = cpu_tensor.t()
            device_tensor = device_tensor.t()
        cpu_result = torch.svd(cpu_tensor, some=some)
        device_result = torch.svd(device_tensor, some=some)
        m = min(cpu_tensor.shape[-2:])
        # torch.svd returns torch.return_types.svd, which is a tuple of (U, S, V).
        # - When some==False, U[..., m:] can be arbitrary.
        # - When some==True, U shape: [..., m], V shape: [m, m]
        # - Signs are not deterministic. If the sign of a column of U is changed
        #   then the corresponding column of V has to be changed as well.
        # Thus here we only compare result[..., :m].abs() from CPU and device.
        for x, y in zip(cpu_result, device_result):
            self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0)
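A small hedged illustration of the sign ambiguity mentioned in the comment above: negating a column of U together with the matching column of V leaves U @ diag(S) @ V.T unchanged, which is why only absolute values are compared.

import torch

a = torch.randn(6, 4, dtype=torch.double)
U, S, V = torch.svd(a)
U2, V2 = U.clone(), V.clone()
U2[:, 0] *= -1  # flip the sign of the first left singular vector...
V2[:, 0] *= -1  # ...and of the matching right singular vector
recon = U @ torch.diag(S) @ V.t()
recon_flipped = U2 @ torch.diag(S) @ V2.t()
assert torch.allclose(recon, recon_flipped)
assert torch.allclose(recon, a, atol=1e-8)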

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_svd_square(self, device, dtype):
        self._test_svd_helper((10, 10), True, False, device, dtype)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    def test_svd_square_col_maj(self, device, dtype):
        self._test_svd_helper((10, 10), True, True, device, dtype)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    def test_svd_tall_some(self, device, dtype):
        self._test_svd_helper((20, 5), True, False, device, dtype)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    def test_svd_tall_all(self, device, dtype):
        self._test_svd_helper((20, 5), False, False, device, dtype)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    def test_svd_tall_some_col_maj(self, device, dtype):
        self._test_svd_helper((5, 20), True, True, device, dtype)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    def test_svd_tall_all_col_maj(self, device, dtype):
        self._test_svd_helper((5, 20), False, True, device, dtype)

    # ~~~ tests for torch.linalg.svd ~~~

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
-    def test_svd_compute_uv(self, device, dtype):
+    def test_linalg_svd_compute_uv(self, device, dtype):
        """
        Test the default case, compute_uv=True. Here we have the very same behavior as
        numpy
@@ -1045,7 +1256,7 @@ def test_svd_compute_uv(self, device, dtype):
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
-    def test_svd_no_compute_uv(self, device, dtype):
+    def test_linalg_svd_no_compute_uv(self, device, dtype):
        """
        Test the compute_uv=False case. Here we have a different return type than
        numpy: numpy returns S, we return (empty, S, empty)
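To make the contrast in the docstring above concrete, here is a hedged NumPy-side sketch: np.linalg.svd with compute_uv=False returns only the singular values, whereas the docstring states that this torch.linalg.svd returned an (empty, S, empty) triple.

import numpy as np

a = np.random.randn(5, 3)
s = np.linalg.svd(a, compute_uv=False)  # a 1-D array of singular values
print(s.shape)  # (3,)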
