Skip to content

Commit

Permalink
Remove redundant code from test_old_cholesky_batched
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk committed Nov 13, 2020
1 parent 03b29a3 commit 88e23c3
Showing 1 changed file with 2 additions and 18 deletions.
20 changes: 2 additions & 18 deletions test/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,26 +211,10 @@ def cholesky_test_helper(n, batchsize, device, upper):
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_old_cholesky_batched(self, device, dtype):
from torch.testing._internal.common_utils import \
(random_symmetric_pd_matrix,
random_fullrank_matrix_distinct_singular_value)
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

def cholesky_test_helper(n, batch_dims, upper):
# This is a workaround while there is no support for batched complex random_symmetric_pd_matrix
if dtype.is_complex:
real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64
A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims,
dtype=real_dtype, device=device)
A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims,
dtype=real_dtype, device=device)
A = A_real + 1j * A_imag
# There is no support for complex batched matmul yet
matmul_list = []
for mat in A.contiguous().view(-1, n, n):
matmul_list.append(mat @ mat.t().conj())
A = torch.stack(matmul_list).view(*batch_dims, n, n)
else:
A = random_symmetric_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
cholesky_exp = cholesky_exp.reshape_as(A)
self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))
Expand Down

0 comments on commit 88e23c3

Please sign in to comment.