Skip to content

Commit

Permalink
Added device checks and test_pinv_errors_and_warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk committed Dec 16, 2020
1 parent e758afd commit 0800641
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
6 changes: 6 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
&& input.dim() >= 2,
"linalg_pinv(", input.scalar_type(), "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
"of float, double, cfloat or cdouble types");
TORCH_CHECK(rcond.device() == input.device(),
"Expected rcond and input to be on the same device, but found rcond on ",
rcond.device(), " and input on ", input.device(), " instead.");
if (input.numel() == 0) {
// The implementation below uses operations that do not work for zero numel tensors
// therefore we need this early return for 'input.numel() == 0' case
Expand Down Expand Up @@ -143,6 +146,9 @@ Tensor linalg_pinv(const Tensor& input, double rcond, bool hermitian) {
Tensor& linalg_pinv_out(Tensor& result, const Tensor& input, const Tensor& rcond, bool hermitian) {
TORCH_CHECK(result.scalar_type() == input.scalar_type(),
"result dtype ", result.scalar_type(), " does not match the expected dtype ", input.scalar_type());
TORCH_CHECK(result.device() == input.device(),
"Expected result and input to be on the same device, but found result on ",
result.device(), " and input on ", input.device(), " instead.");

Tensor result_tmp = at::linalg_pinv(input, rcond, hermitian);
at::native::resize_output(result, result_tmp.sizes());
Expand Down
36 changes: 36 additions & 0 deletions test/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2149,6 +2149,42 @@ def run_test_numpy(A, hermitian):
run_test_main(A, hermitian)
run_test_numpy(A, hermitian)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_pinv_errors_and_warnings(self, device, dtype):
    """Verify torch.linalg.pinv error/warning behavior for invalid inputs:
    a tensor with fewer than 2 dims, a wrongly-shaped/typed/placed `out=`
    tensor, and an `rcond` tensor on the wrong device."""
    # pinv requires at least a 2D tensor
    a = torch.randn(1, device=device, dtype=dtype)
    with self.assertRaisesRegex(RuntimeError, "expected a tensor with 2 or more dimensions"):
        torch.linalg.pinv(a)

    # if non-empty out tensor with wrong shape is passed a warning is given
    a = torch.randn(3, 3, dtype=dtype, device=device)
    out = torch.empty(7, 7, dtype=dtype, device=device)
    with warnings.catch_warnings(record=True) as w:
        # "always" defeats the per-module warning registry: without it the
        # resize warning may have been deduplicated by an earlier emission
        # in this process, making len(w) == 0 and the test flaky.
        warnings.simplefilter("always")
        # Trigger warning
        torch.linalg.pinv(a, out=out)
        # Check warning occurs exactly once
        self.assertEqual(len(w), 1)
        self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

    # dtypes of out and input should match
    out = torch.empty_like(a).to(torch.int)
    with self.assertRaisesRegex(RuntimeError, "dtype Int does not match the expected dtype"):
        torch.linalg.pinv(a, out=out)

    # Compute the mismatched device once; both device checks below use it.
    # (This test runs for any device_type, so pick "the other one".)
    wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'

    # device of out and input should match
    out = torch.empty_like(a).to(wrong_device)
    with self.assertRaisesRegex(RuntimeError, "Expected result and input to be on the same device"):
        torch.linalg.pinv(a, out=out)

    # device of rcond and input should match
    rcond = torch.full((), 1e-2, device=wrong_device)
    with self.assertRaisesRegex(RuntimeError, "Expected rcond and input to be on the same device"):
        torch.linalg.pinv(a, rcond=rcond)

# TODO: RuntimeError: svd does not support automatic differentiation for outputs with complex dtype.
# See https://github.com/pytorch/pytorch/pull/47761
@unittest.expectedFailure
Expand Down

0 comments on commit 0800641

Please sign in to comment.