Removed xfailing test; batched complex matmul on cuda now works.
IvanYashchuk committed Nov 16, 2020
1 parent e053666 commit 59a41e3
Showing 1 changed file with 2 additions and 38 deletions.
40 changes: 2 additions & 38 deletions test/test_linalg.py
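The lines deleted below worked around the lack of batched complex matrix multiplication on CUDA by multiplying one matrix of the batch at a time; with that support in place, the tests call torch.matmul on the whole batch directly. A minimal sketch of that operation (not part of the commit), assuming a CUDA build of PyTorch; the tensor shapes are illustrative, not taken from the diff:

import torch

# Batched complex matmul on CUDA -- the operation the updated tests rely on.
if torch.cuda.is_available():
    A = torch.randn(4, 5, 5, dtype=torch.complex64, device='cuda')  # batch of 4 matrices
    x = torch.randn(4, 5, 2, dtype=torch.complex64, device='cuda')  # batch of 4 right-hand sides
    Ax = torch.matmul(A, x)  # one batched call; no per-matrix loop needed
    print(Ax.shape)          # torch.Size([4, 5, 2])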
@@ -547,9 +547,6 @@ def sub_test(pivot):
             for k, n in zip([2, 3, 5], [3, 5, 7]):
                 b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
                 x = torch.lu_solve(b, LU_data, LU_pivots)
-                # TODO(@ivanyashchuk): remove this once 'norm_cuda' is avaiable for complex dtypes
-                if not self.device_type == 'cuda' and not dtype.is_complex:
-                    self.assertLessEqual(abs(b.dist(A.mm(x), p=1)), self.precision)
                 self.assertEqual(b, A.mm(x))
 
         sub_test(True)
@@ -571,17 +568,7 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
             x_exp = torch.stack(x_exp_list)  # Stacked output
             x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
             self.assertEqual(x_exp, x_act)  # Equality check
-            # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes
-            if self.device_type == 'cuda' and dtype.is_complex:
-                Ax_list = []
-                for A_i, x_i in zip(A, x_act):
-                    Ax_list.append(torch.matmul(A_i, x_i))
-                Ax = torch.stack(Ax_list)
-            else:
-                Ax = torch.matmul(A, x_act)
-                self.assertLessEqual(abs(b.dist(Ax, p=1)), self.precision)  # Correctness check
-            # In addition to the norm, check the individual entries
-            # 'norm_cuda' is not implemented for complex dtypes
+            Ax = torch.matmul(A, x_act)
             self.assertEqual(b, Ax)
 
         for batchsize in [1, 3, 4]:
@@ -605,35 +592,12 @@ def test_lu_solve_batched_many_batches(self, device, dtype):
         def run_test(A_dims, b_dims):
             b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
             x = torch.lu_solve(b, LU_data, LU_pivots)
-            # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes
-            if self.device_type == 'cuda' and dtype.is_complex:
-                Ax_list = []
-                for A_i, x_i in zip(A, x):
-                    Ax_list.append(torch.matmul(A_i, x_i))
-                Ax = torch.stack(Ax_list)
-            else:
-                Ax = torch.matmul(A, x)
+            Ax = torch.matmul(A, x)
             self.assertEqual(Ax, b.expand_as(Ax))
 
         run_test((5, 65536), (65536, 5, 10))
         run_test((5, 262144), (262144, 5, 10))
 
-    # TODO: once there is more support for complex dtypes on GPU, above tests should be updated
-    # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat
-    # and RuntimeError: "norm_cuda" not implemented for 'ComplexFloat' are fixed
-    @unittest.expectedFailure
-    @onlyCUDA
-    @skipCUDAIfNoMagma
-    @dtypes(torch.complex64, torch.complex128)
-    def test_lu_solve_batched_complex_xfailed(self, device, dtype):
-        A_dims = (3, 5)
-        b_dims = (5, 3, 2)
-        b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
-        x = torch.lu_solve(b, LU_data, LU_pivots)
-        b_ = torch.matmul(A, x)
-        self.assertEqual(b_, b.expand_as(b_))
-        self.assertLessEqual(abs(b.dist(torch.matmul(A, x), p=1)), 1e-4)
-
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)

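For reference, a sketch (not from the diff) contrasting the removed per-matrix workaround with the direct batched call the tests now use. The shapes loosely follow the deleted xfailed test, and the CPU fallback is only there so the snippet runs without a GPU:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
A = torch.randn(5, 3, 3, dtype=torch.complex64, device=device)
x = torch.randn(5, 3, 2, dtype=torch.complex64, device=device)

# Old workaround: multiply each matrix in the batch separately, then stack.
Ax_loop = torch.stack([torch.matmul(A_i, x_i) for A_i, x_i in zip(A, x)])

# New path: batched complex matmul in a single call (now also on CUDA).
Ax = torch.matmul(A, x)

assert torch.allclose(Ax, Ax_loop, atol=1e-5)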