Fix nuclear norm with requires_grad=True (#26303)
Summary:
Changelog:
- Selectively set compute_uv in the at::svd call used internally in the implementation of at::nuclear_norm
Pull Request resolved: #26303

Test Plan:
- Add tests in common_methods_invocations.py

Refixes: #18275
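
For reference, a minimal repro of the failing case (a sketch; the shapes and the exact torch.norm call with p='nuc' and a dim tuple are illustrative):

    import torch

    x = torch.randn(3, 5, 5, requires_grad=True)
    # Nuclear norm of each matrix in the batch, over the last two dims.
    out = torch.norm(x, p='nuc', dim=(1, 2)).sum()
    # Before this change, this backward raised from svd_backward because the
    # dim overload called at::svd with compute_uv=false unconditionally;
    # now compute_uv is enabled whenever gradients are required.
    out.backward()
    print(x.grad.shape)  # torch.Size([3, 5, 5])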

Differential Revision: D17605357

Pulled By: ezyang

fbshipit-source-id: d87d60afe678e2546dca6992ea66f2daeb6b0346
vishwakftw authored and facebook-github-bot committed Sep 26, 2019
1 parent 0e3389d commit 43b07ff
Showing 3 changed files with 16 additions and 3 deletions.
17 changes: 14 additions & 3 deletions aten/src/ATen/native/LinearAlgebra.cpp

@@ -6,6 +6,7 @@
 #include <ATen/TensorUtils.h>
 #include <ATen/Parallel.h>
 #include <ATen/LegacyTHFunctionsCPU.h>
+#include <ATen/core/grad_mode.h>
 #include <functional>
 #include <numeric>
 #include <vector>
@@ -549,29 +550,39 @@ Tensor nuclear_norm(const Tensor& self, bool keepdim) {
       self.dim() == 2,
       "Expected a tensor with 2 dimensions, but got a tensor with ",
       self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-  return at::sum(std::get<1>(at::svd(self)), 0, keepdim);
+  // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
+  // would end up throwing an error as a result if U and V aren't computed.
+  // Due to this, we have to compute U and V conditionally.
+  return at::sum(std::get<1>(at::svd(self, /*some=*/true,
+                 /*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), 0, keepdim);
 }
 
 Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
   TORCH_CHECK(
       self.dim() == 2,
       "Expected a tensor with 2 dimensions, but got a tensor with ",
       self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-  return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
+  return at::sum_out(result, std::get<1>(at::svd(self, /*some=*/true, /*compute_uv=*/false)), 0, keepdim);
+
 }
 
 Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
 
   Tensor p = _move_to_end(self, dim);
-  return at::sum(std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
+  // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
+  // would end up throwing an error as a result if U and V aren't computed.
+  // Due to this, we have to compute U and V conditionally.
+  return at::sum(std::get<1>(at::svd(p, /*some=*/true,
+                 /*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), -1, keepdim);
 }
 
 Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
 
   Tensor p = _move_to_end(self, dim);
   return at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
+
 }
 
 static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
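
Background on why U and V must be available under grad mode: for a full-rank matrix A with SVD A = U diag(S) V^T, the nuclear norm is sum(S) and its gradient with respect to A is U V^T, so svd_backward cannot form a gradient without U and V. A quick numerical check of that identity (a sketch; the square full-rank double input is an assumption for the formula to hold exactly):

    import torch

    a = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
    u, s, v = torch.svd(a)              # a = u @ s.diag() @ v.t()
    torch.norm(a, p='nuc').backward()   # gradient of the sum of singular values
    print(torch.allclose(a.grad, u @ v.t()))  # True for full-rank a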
1 change: 1 addition & 0 deletions test/common_methods_invocations.py

@@ -534,6 +534,7 @@ def method_tests():
         ('norm', (S, S), ('fro',), 'fro_default'),
         ('norm', (S, S), ('fro', [0, 1],), 'fro'),
         ('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipIfNoLapack]),
+        ('norm', (S, S, S), ('nuc', [1, 2]), 'nuc_batched', (), NO_ARGS, [skipIfNoLapack]),
         ('norm', (S, S), (-1,), 'neg_1'),
         ('norm', (S, S), (-2,), 'neg_2'),
         ('norm', (S, S), (-0.5,), 'neg_0_5'),
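
Roughly, each method_tests entry reads (method name, input size, call args, test-name suffix, ...), so the new 'nuc_batched' row makes the harness build an (S, S, S) input with requires_grad=True, call its norm method, and gradcheck it. A sketch of what the generated test exercises (the value of S, the double dtype, and the harness details are assumptions):

    import torch

    S = 5  # stand-in for the suite's small-dimension constant
    x = torch.randn(S, S, S, dtype=torch.double, requires_grad=True)
    # Exercises the conditional compute_uv path added above.
    assert torch.autograd.gradcheck(lambda t: t.norm('nuc', [1, 2]), (x,))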
1 change: 1 addition & 0 deletions test/test_jit.py

@@ -16334,6 +16334,7 @@ def forward(self, x, y):
     'test_norm_fro',
     'test_norm_fro_default',
     'test_norm_nuc',
+    'test_norm_nuc_batched',
 
     # aten op has additional cudnn argument
     'test_nn_unfold',
