Skip to content

Commit

Permalink
Enable dtype arg for torch.linalg.norm with order 'fro' and 'nuc'
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler committed Oct 21, 2020
1 parent e8fbe54 commit bb983df
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 16 deletions.
12 changes: 6 additions & 6 deletions aten/src/ATen/native/LinearAlgebra.cpp
Expand Up @@ -1391,9 +1391,8 @@ static std::vector<int64_t> make_dim_list(int64_t ndim) {
}

// Checks for valid arguments to linalg_norm when type(ord) == str
static void check_str_ord_valid(const std::string& str_ord, optional<IntArrayRef> opt_dim, int64_t ndim, optional<ScalarType> opt_dtype) {
static void check_str_ord_valid(const std::string& str_ord, optional<IntArrayRef> opt_dim, int64_t ndim) {
TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord);
TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument");
bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2);
TORCH_CHECK(dims_valid, "order \"", str_ord,
"\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)");
Expand Down Expand Up @@ -1553,14 +1552,15 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional
if (opt_str_ord.has_value()) {
// 'ord' is string
auto str_ord = opt_str_ord.value();
check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype);
check_str_ord_valid(str_ord, opt_dim, ndim);
Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self;
if (str_ord == "fro") {
at::frobenius_norm_out(result, self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim);
at::frobenius_norm_out(result, self_, opt_dim.value_or(IntArrayRef({0, 1})), keepdim);
} else if (str_ord == "nuc") {
if (opt_dim.has_value()) {
at::nuclear_norm_out(result, self, opt_dim.value(), keepdim);
at::nuclear_norm_out(result, self_, opt_dim.value(), keepdim);
} else {
at::nuclear_norm_out(result, self, keepdim);
at::nuclear_norm_out(result, self_, keepdim);
}
}
} else {
Expand Down
9 changes: 1 addition & 8 deletions test/test_linalg.py
Expand Up @@ -202,7 +202,7 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype)
self.assertEqual(result_converted, result_out_converted, msg=msg)

ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
ord_matrix = [1, -1, 2, -2, inf, -inf, None]
ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None]
S = 10
test_cases = [
((S, ), ord_vector),
Expand Down Expand Up @@ -230,13 +230,6 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype)
with self.assertRaisesRegex(RuntimeError, r'provided dtype must match dtype of result'):
torch.linalg.norm(input, ord=ord, keepdim=keepdim, dtype=dtype, out=result)

# TODO: Once dtype arg is supported in nuclear and frobenius norms, remove the following test
# and add 'nuc' and 'fro' to ord_matrix above
for ord in ['nuc', 'fro']:
input = torch.randn(10, 10, device=device)
with self.assertRaisesRegex(RuntimeError, f"ord=\'{ord}\' does not yet support the dtype argument"):
torch.linalg.norm(input, ord, dtype=torch.float)

# This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
# their vector norm results match
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
Expand Down
3 changes: 1 addition & 2 deletions torch/linalg/__init__.py
Expand Up @@ -68,8 +68,7 @@
:attr:`dtype` before performing the operation, and the returned tensor's type
will be :attr:`dtype`. If this argument is used in conjunction with the
:attr:`out` argument, the output tensor's type must match this argument or a
RuntimeError will be raised. This argument is not currently supported for
:attr:`ord='nuc'` or :attr:`ord='fro'`. Default: ``None``
RuntimeError will be raised. Default: ``None``
Examples::
Expand Down

0 comments on commit bb983df

Please sign in to comment.