From 59a41e35c72ff72f13acf08d264aa0b394a198a1 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Mon, 16 Nov 2020 13:09:11 -0600
Subject: [PATCH] Removed xfailing test; batched complex matmul on cuda now works.

---
 test/test_linalg.py | 40 ++--------------------------------------
 1 file changed, 2 insertions(+), 38 deletions(-)

diff --git a/test/test_linalg.py b/test/test_linalg.py
index 4d3b83d37660..313bfc6b3aa2 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -547,9 +547,6 @@ def sub_test(pivot):
             for k, n in zip([2, 3, 5], [3, 5, 7]):
                 b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
                 x = torch.lu_solve(b, LU_data, LU_pivots)
-                # TODO(@ivanyashchuk): remove this once 'norm_cuda' is avaiable for complex dtypes
-                if not self.device_type == 'cuda' and not dtype.is_complex:
-                    self.assertLessEqual(abs(b.dist(A.mm(x), p=1)), self.precision)
                 self.assertEqual(b, A.mm(x))
 
         sub_test(True)
@@ -571,17 +568,7 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
             x_exp = torch.stack(x_exp_list)  # Stacked output
             x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
             self.assertEqual(x_exp, x_act)  # Equality check
-            # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes
-            if self.device_type == 'cuda' and dtype.is_complex:
-                Ax_list = []
-                for A_i, x_i in zip(A, x_act):
-                    Ax_list.append(torch.matmul(A_i, x_i))
-                Ax = torch.stack(Ax_list)
-            else:
-                Ax = torch.matmul(A, x_act)
-            self.assertLessEqual(abs(b.dist(Ax, p=1)), self.precision)  # Correctness check
-            # In addition to the norm, check the individual entries
-            # 'norm_cuda' is not implemented for complex dtypes
+            Ax = torch.matmul(A, x_act)
             self.assertEqual(b, Ax)
 
         for batchsize in [1, 3, 4]:
@@ -605,35 +592,12 @@ def test_lu_solve_batched_many_batches(self, device, dtype):
         def run_test(A_dims, b_dims):
             b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
             x = torch.lu_solve(b, LU_data, LU_pivots)
-            # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes
-            if self.device_type == 'cuda' and dtype.is_complex:
-                Ax_list = []
-                for A_i, x_i in zip(A, x):
-                    Ax_list.append(torch.matmul(A_i, x_i))
-                Ax = torch.stack(Ax_list)
-            else:
-                Ax = torch.matmul(A, x)
+            Ax = torch.matmul(A, x)
             self.assertEqual(Ax, b.expand_as(Ax))
 
         run_test((5, 65536), (65536, 5, 10))
         run_test((5, 262144), (262144, 5, 10))
 
-    # TODO: once there is more support for complex dtypes on GPU, above tests should be updated
-    # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat
-    # and RuntimeError: "norm_cuda" not implemented for 'ComplexFloat' are fixed
-    @unittest.expectedFailure
-    @onlyCUDA
-    @skipCUDAIfNoMagma
-    @dtypes(torch.complex64, torch.complex128)
-    def test_lu_solve_batched_complex_xfailed(self, device, dtype):
-        A_dims = (3, 5)
-        b_dims = (5, 3, 2)
-        b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
-        x = torch.lu_solve(b, LU_data, LU_pivots)
-        b_ = torch.matmul(A, x)
-        self.assertEqual(b_, b.expand_as(b_))
-        self.assertLessEqual(abs(b.dist(torch.matmul(A, x), p=1)), 1e-4)
-
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
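
Note (not part of the patch): the branches removed above worked around the lack of batched complex matmul on CUDA by multiplying each matrix in the batch separately and stacking the results. Below is a minimal standalone sketch contrasting that workaround with the single batched call the simplified tests now rely on. It assumes a CUDA build of PyTorch with complex dtype support; the shapes and variable names are illustrative and not taken from test_linalg.py.

# Standalone sketch, illustrative only.
import torch

if torch.cuda.is_available():
    A = torch.randn(4, 5, 5, dtype=torch.complex64, device='cuda')
    x = torch.randn(4, 5, 3, dtype=torch.complex64, device='cuda')

    # Old workaround removed by this patch: multiply each matrix in the
    # batch separately, then stack the per-matrix results.
    Ax_loop = torch.stack([torch.matmul(A_i, x_i) for A_i, x_i in zip(A, x)])

    # Path the simplified tests now exercise: one batched complex matmul on CUDA.
    Ax_batched = torch.matmul(A, x)

    # The two results should agree up to floating-point rounding.
    print((Ax_loop - Ax_batched).abs().max())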