Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable dtype arg for torch.linalg.norm with order 'fro' and 'nuc' #46637

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 6 additions & 6 deletions aten/src/ATen/native/LinearAlgebra.cpp
Expand Up @@ -1391,9 +1391,8 @@ static std::vector<int64_t> make_dim_list(int64_t ndim) {
}

// Checks for valid arguments to linalg_norm when type(ord) == str
static void check_str_ord_valid(const std::string& str_ord, optional<IntArrayRef> opt_dim, int64_t ndim, optional<ScalarType> opt_dtype) {
static void check_str_ord_valid(const std::string& str_ord, optional<IntArrayRef> opt_dim, int64_t ndim) {
TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord);
TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument");
bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2);
TORCH_CHECK(dims_valid, "order \"", str_ord,
"\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)");
Expand Down Expand Up @@ -1553,14 +1552,15 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional
if (opt_str_ord.has_value()) {
// 'ord' is string
auto str_ord = opt_str_ord.value();
check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype);
check_str_ord_valid(str_ord, opt_dim, ndim);
Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self;
if (str_ord == "fro") {
at::frobenius_norm_out(result, self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim);
at::frobenius_norm_out(result, self_, opt_dim.value_or(IntArrayRef({0, 1})), keepdim);
} else if (str_ord == "nuc") {
if (opt_dim.has_value()) {
at::nuclear_norm_out(result, self, opt_dim.value(), keepdim);
at::nuclear_norm_out(result, self_, opt_dim.value(), keepdim);
} else {
at::nuclear_norm_out(result, self, keepdim);
at::nuclear_norm_out(result, self_, keepdim);
}
}
} else {
Expand Down
9 changes: 1 addition & 8 deletions test/test_linalg.py
Expand Up @@ -202,7 +202,7 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype)
self.assertEqual(result_converted, result_out_converted, msg=msg)

ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
ord_matrix = [1, -1, 2, -2, inf, -inf, None]
ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None]
S = 10
test_cases = [
((S, ), ord_vector),
Expand Down Expand Up @@ -230,13 +230,6 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype)
with self.assertRaisesRegex(RuntimeError, r'provided dtype must match dtype of result'):
torch.linalg.norm(input, ord=ord, keepdim=keepdim, dtype=dtype, out=result)

# TODO: Once dtype arg is supported in nuclear and frobenius norms, remove the following test
# and add 'nuc' and 'fro' to ord_matrix above
for ord in ['nuc', 'fro']:
input = torch.randn(10, 10, device=device)
with self.assertRaisesRegex(RuntimeError, f"ord=\'{ord}\' does not yet support the dtype argument"):
torch.linalg.norm(input, ord, dtype=torch.float)

# This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
# their vector norm results match
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
Expand Down
3 changes: 1 addition & 2 deletions torch/linalg/__init__.py
Expand Up @@ -68,8 +68,7 @@
:attr:`dtype` before performing the operation, and the returned tensor's type
will be :attr:`dtype`. If this argument is used in conjunction with the
:attr:`out` argument, the output tensor's type must match this argument or a
RuntimeError will be raised. This argument is not currently supported for
:attr:`ord='nuc'` or :attr:`ord='fro'`. Default: ``None``
RuntimeError will be raised. Default: ``None``

Examples::

Expand Down