From bb983dfcb1529f9f64a242f33cd0b30713c095b3 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 20 Oct 2020 21:49:16 -0500 Subject: [PATCH] Enable dtype arg for torch.linalg.norm with order 'fro' and 'nuc' --- aten/src/ATen/native/LinearAlgebra.cpp | 12 ++++++------ test/test_linalg.py | 9 +-------- torch/linalg/__init__.py | 3 +-- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index e3f50db11b6c..0dd727fb3197 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1391,9 +1391,8 @@ static std::vector make_dim_list(int64_t ndim) { } // Checks for valid arguments to linalg_norm when type(ord) == str -static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim, optional opt_dtype) { +static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim) { TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord); - TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument"); bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2); TORCH_CHECK(dims_valid, "order \"", str_ord, "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)"); @@ -1553,14 +1552,15 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional if (opt_str_ord.has_value()) { // 'ord' is string auto str_ord = opt_str_ord.value(); - check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype); + check_str_ord_valid(str_ord, opt_dim, ndim); + Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self; if (str_ord == "fro") { - at::frobenius_norm_out(result, self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); + at::frobenius_norm_out(result, self_, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); } else if (str_ord == "nuc") { if (opt_dim.has_value()) { - at::nuclear_norm_out(result, self, opt_dim.value(), keepdim); + at::nuclear_norm_out(result, self_, opt_dim.value(), keepdim); } else { - at::nuclear_norm_out(result, self, keepdim); + at::nuclear_norm_out(result, self_, keepdim); } } } else { diff --git a/test/test_linalg.py b/test/test_linalg.py index 0220dff476c1..127c674e5b05 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -202,7 +202,7 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype) self.assertEqual(result_converted, result_out_converted, msg=msg) ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] - ord_matrix = [1, -1, 2, -2, inf, -inf, None] + ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None] S = 10 test_cases = [ ((S, ), ord_vector), @@ -230,13 +230,6 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype) with self.assertRaisesRegex(RuntimeError, r'provided dtype must match dtype of result'): torch.linalg.norm(input, ord=ord, keepdim=keepdim, dtype=dtype, out=result) - # TODO: Once dtype arg is supported in nuclear and frobenius norms, remove the following test - # and add 'nuc' and 'fro' to ord_matrix above - for ord in ['nuc', 'fro']: - input = torch.randn(10, 10, device=device) - with self.assertRaisesRegex(RuntimeError, f"ord=\'{ord}\' does not yet support the dtype argument"): - torch.linalg.norm(input, ord, dtype=torch.float) - # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that # their vector norm results match @unittest.skipIf(not TEST_NUMPY, "NumPy not found") diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 5e2b59c45c80..074bb47faaac 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -68,8 +68,7 @@ :attr:`dtype` before performing the operation, and the returned tensor's type will be :attr:`dtype`. If this argument is used in conjunction with the :attr:`out` argument, the output tensor's type must match this argument or a - RuntimeError will be raised. This argument is not currently supported for - :attr:`ord='nuc'` or :attr:`ord='fro'`. Default: ``None`` + RuntimeError will be raised. Default: ``None`` Examples::