Skip to content

Commit

Permalink
Add missing complex support for torch.norm and torch.linalg.norm (#48284)
Browse files Browse the repository at this point in the history

Summary:
**BC-breaking note:**

Previously, when given a complex input, `torch.linalg.norm` and `torch.norm` would return a complex output. `torch.linalg.cond` would return either a complex or a real output for a complex input, depending on its `p` argument. This PR changes this behavior to match `numpy.linalg.norm` and `numpy.linalg.cond`: a complex input now always produces an output of the corresponding real dtype.

**PR Summary:**

The following cases were previously unsupported for complex inputs, and this commit adds support:

- Frobenius norm
- Norm order 2 (vector and matrix)
- CUDA vector norm

Part of #47833

Pull Request resolved: #48284

Reviewed By: H-Huang

Differential Revision: D25420880

Pulled By: mruberry

fbshipit-source-id: 11f6a2f3cad57d66476d30921c3f6ab8f3cd4017
  • Loading branch information
kurtamohler authored and facebook-github-bot committed Dec 10, 2020
1 parent 25a8397 commit 54f0556
Show file tree
Hide file tree
Showing 9 changed files with 318 additions and 213 deletions.
33 changes: 16 additions & 17 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,14 +1463,13 @@ Tensor matrix_power(const Tensor& a, int64_t n) {
}

// Frobenius norm overload without `dim`/`keepdim` arguments: forwards directly
// to at::norm(self) and returns its result.
Tensor frobenius_norm(const Tensor& self) {
TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors");
return at::norm(self);
}

// Frobenius norm overload taking `dim` and `keepdim`: allocates an empty result
// tensor and forwards to at::native::frobenius_norm_out, returning its result.
Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
// NOTE: As frobenius_norm_out is currently implemented, it will always produce a
// strided tensor result, even if the input is sparse.
auto options = self.options().layout(c10::Layout::Strided);
// The result dtype is toValueType of self's scalar type, so the result tensor
// for a complex input is allocated with the corresponding real dtype.
auto options = self.options().layout(c10::Layout::Strided).dtype(toValueType(self.scalar_type()));
Tensor result = at::empty({0}, options);
return at::native::frobenius_norm_out(result, self, dim, keepdim);
}
Expand All @@ -1480,7 +1479,6 @@ Tensor &frobenius_norm_out(
const Tensor& self,
IntArrayRef dim,
bool keepdim) {
TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors");
TORCH_CHECK(
dim.size() <= 2,
"Expected at most 2 dimensions, but got ",
Expand Down Expand Up @@ -1524,7 +1522,7 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
}

// Nuclear norm overload taking `dim` and `keepdim`: allocates an empty result
// tensor and forwards to at::native::nuclear_norm_out, returning its result.
Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
Tensor result = at::empty({0}, self.options());
// The result keeps self's options except for dtype, which is toValueType of
// self's scalar type (the real counterpart when the input is complex).
Tensor result = at::empty({0}, self.options().dtype(toValueType(self.scalar_type())));
return at::native::nuclear_norm_out(result, self, dim, keepdim);
}

Expand Down Expand Up @@ -1679,7 +1677,7 @@ static Tensor& _linalg_norm_vector_out(Tensor& result, const Tensor& self, optio
// when the input contains extreme values (like nan or +/-inf) or if the input
// size is degenerate (like size(0), size(0, N), etc)
case_was_overridden = true;
self_ = self.abs();
self_ = self_.abs();
result_ = _norm_min_max(self_, ord, dim[0], keepdim);
} else if ((self_.numel() == 0) && (ord < 0)) {
// For negative orders with degenerate input sizes, at::norm's result does not
Expand All @@ -1698,7 +1696,7 @@ static Tensor& _linalg_norm_vector_out(Tensor& result, const Tensor& self, optio
}
if (!case_was_overridden) {
if (opt_dtype.has_value()) {
result_ = at::norm(self, opt_ord, dim, keepdim, opt_dtype.value());
result_ = at::norm(self.to(opt_dtype.value()), opt_ord, dim, keepdim);
} else {
result_ = at::norm(self, opt_ord, dim, keepdim);
}
Expand Down Expand Up @@ -1749,14 +1747,14 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional

// Numerical or None norms
// Allocates an empty result tensor on self's device and forwards every argument
// to at::native::linalg_norm_out.
// The result dtype is `opt_dtype` when it is provided; otherwise it is
// toValueType of self's scalar type (the real counterpart for complex input).
Tensor linalg_norm(const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())).device(self.device());
Tensor result = at::empty({0}, options);
return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
}

// Frobenius and nuclear norms
// Overload selected when `ord` is given as a string. Allocates an empty result
// tensor on self's device and forwards every argument to
// at::native::linalg_norm_out.
// The result dtype is `opt_dtype` when it is provided; otherwise it is
// toValueType of self's scalar type (the real counterpart for complex input).
Tensor linalg_norm(const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())).device(self.device());
Tensor result = at::empty({0}, options);
return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
}
Expand All @@ -1781,7 +1779,8 @@ Tensor _linalg_cond_exception_helper(const Tensor& self) {
"linalg_cond does not support yet this case.");
}
auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
Tensor result = at::full(result_shape, INFINITY, self.options());
TensorOptions options = self.options().dtype(toValueType(self.scalar_type()));
Tensor result = at::full(result_shape, INFINITY, options);
return result;
}

Expand Down Expand Up @@ -1816,7 +1815,8 @@ Tensor _linalg_cond_helper(const Tensor& self, c10::variant<Scalar, std::string>
// Return zero for each matrix in the batch
// The result shape is self.sizes() with the last two dimensions dropped, so one
// zero is produced per leading (batch) index.
// NOTE(review): the toValueType-based lines below derive the result dtype from
// self's scalar type and do not read the `dtype` parameter.
Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) {
auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
return at::zeros(result_shape, self.options().dtype(dtype));
TensorOptions options = self.options().dtype(toValueType(self.scalar_type()));
return at::zeros(result_shape, options);
}

void _linalg_cond_check_ord(c10::variant<Scalar, std::string> ord_variant) {
Expand Down Expand Up @@ -1849,8 +1849,7 @@ Tensor linalg_cond(const Tensor& self, optional<Scalar> opt_ord) {
// NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input
if (self.numel() == 0) {
auto real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type();
return _linalg_cond_empty_matrix(self, expected_dtype);
return _linalg_cond_empty_matrix(self, real_dtype);
}

// If ord == None or ord == ±2
Expand Down Expand Up @@ -1883,10 +1882,9 @@ Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional<Scalar> opt
// the result is always real-valued; for all other orders the result is also real-valued, even for complex-valued input.
ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;
auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type();

TORCH_CHECK(result.scalar_type() == expected_dtype,
"result dtype ", result.scalar_type(), " does not match the expected dtype ", expected_dtype);
TORCH_CHECK(result.scalar_type() == real_dtype,
"result dtype ", result.scalar_type(), " does not match the expected dtype ", real_dtype);

Tensor result_tmp = at::linalg_cond(self, opt_ord);
at::native::resize_output(result, result_tmp.sizes());
Expand Down Expand Up @@ -1916,8 +1914,9 @@ Tensor linalg_cond(const Tensor& self, std::string ord) {

// TODO: implement _out variant avoiding copy and using already allocated storage directly
Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) {
TORCH_CHECK(result.scalar_type() == self.scalar_type(),
"result dtype ", result.scalar_type(), " does not match the expected dtype ", self.scalar_type());
ScalarType real_type = toValueType(self.scalar_type());
TORCH_CHECK(result.scalar_type() == real_type,
"result dtype ", result.scalar_type(), " does not match the expected dtype ", real_type);

Tensor result_tmp = at::linalg_cond(self, ord);
at::native::resize_output(result, result_tmp.sizes());
Expand Down
17 changes: 9 additions & 8 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,22 +660,23 @@ Tensor& logsumexp_out(Tensor& result, const Tensor& self, DimnameList dims, bool

static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto p = opt_p.value_or(2.0);
TORCH_CHECK(!(p.toDouble() == 2 && self.is_complex()), "norm with p=2 not supported for complex tensors");
auto p = opt_p.value_or(2.0).to<double>();
TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
"norm only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided,
"norm only supports strided layout, got: ", self.layout());

ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
ScalarType in_dtype = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
TORCH_CHECK(
at::isFloatingType(scalarType) || at::isComplexType(scalarType),
"Can only calculate the mean of floating types. Got ",
toString(scalarType),
at::isFloatingType(in_dtype) || at::isComplexType(in_dtype),
"Can only calculate the norm of floating point and complex dtypes. Got ",
toString(in_dtype),
" instead.");

ScalarType dtype = get_dtype(result, self, opt_dtype, true);
auto iter = make_reduction("norm", result, self, dim, keepdim, dtype);
ScalarType out_dtype = result.defined() ? result.scalar_type() : (opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type()));

auto iter = make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype);

if (iter.numel() == 0) {
result.zero_();
} else {
Expand Down
108 changes: 77 additions & 31 deletions aten/src/ATen/native/SharedReduceOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Please note that this file is
// used across both CPU and GPU.

#include <type_traits>
#include <complex>
#include <c10/macros/Macros.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/NumericUtils.h>
Expand Down Expand Up @@ -157,11 +159,15 @@ struct MeanOps {
}
};

template <typename acc_t>
// This accumulator template is used to calculate the minimum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t=scalar_t>
struct AbsMinOps {

inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const {
return MIN(acc, acc_t(std::abs(data)));
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return MIN(acc, static_cast<acc_t>(std::abs(data)));
}

inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
Expand All @@ -177,17 +183,21 @@ struct AbsMinOps {
}

#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};

template <typename acc_t>
// This accumulator template is used to calculate the maximum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t=scalar_t>
struct AbsMaxOps {

inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const {
return MAX(acc, acc_t(std::abs(data)));
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return MAX(acc, static_cast<acc_t>(std::abs(data)));
}

inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
Expand All @@ -203,46 +213,54 @@ struct AbsMaxOps {
}

#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};

template <typename acc_t>
// This accumulator template is used to calculate the norm of order `norm_` of a
// set of numbers: `reduce` adds |x|^norm_ for each input element, `combine` adds
// two partial sums, and `project` raises the final sum to the power 1/norm_.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t=scalar_t>
struct NormOps {
// Order of the norm; used as the exponent in `reduce` and inverted in `project`.
acc_t norm_;

inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const {
return acc + compat_pow(std::abs(data), norm_);
// Folds one input element into the running sum. The absolute value is cast to
// `acc_t` before being raised to `norm_`.
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
}

// Merges two partial sums.
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return a + b;
}

// Converts the accumulated sum into the final result by taking the 1/norm_ power.
inline C10_DEVICE acc_t project(acc_t a) const {
return compat_pow(a, acc_t(1.0)/norm_);
return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
}

// The accumulator carries no index information, so it is returned unchanged.
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}

#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
// Forwards the accumulator and `offset` to WARP_SHFL_DOWN; only compiled when
// building with CUDA or HIP.
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif

// Stores the order of the norm.
NormOps(acc_t norm_): norm_(norm_) {
}
};

template <typename acc_t>
// This accumulator template is used to calculate the order zero norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t=scalar_t>
struct NormZeroOps {
inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const {
return acc + (data==acc_t(0) ? acc_t(0) : acc_t(1));
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
}

inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
Expand All @@ -259,16 +277,20 @@ struct NormZeroOps {


#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};

template <typename acc_t>
// This accumulator template is used to calculate the order one norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t=scalar_t>
struct NormOneOps {
inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const {
return acc + std::abs(data);
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return acc + static_cast<acc_t>(std::abs(data));
}

inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
Expand All @@ -284,16 +306,40 @@ struct NormOneOps {
}

#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};

template <typename acc_t>

// Empty tag type. It carries the accumulator type `acc_t` into overload
// resolution for `abs_if_complex` so that `acc_t` can be deduced from an
// argument instead of being spelled out as an explicit template argument.
template <typename acc_t>
struct AbsSwitch {};

// Generic overload: used for any input type that does not match one of the
// complex overloads below. The value is converted to `acc_t` as-is.
template <typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t> /*tag*/) {
  return static_cast<acc_t>(data);
}

// std::complex overload: the absolute value of the input is taken first and
// then converted to `acc_t`.
template <typename value_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<value_t> data, AbsSwitch<acc_t> /*tag*/) {
  const auto magnitude = std::abs(data);
  return static_cast<acc_t>(magnitude);
}

// c10::complex overload: same as the std::complex overload, for c10's complex
// type.
template <typename value_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<value_t> data, AbsSwitch<acc_t> /*tag*/) {
  const auto magnitude = std::abs(data);
  return static_cast<acc_t>(magnitude);
}

// This accumulator template is used to calculate the order two norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t=scalar_t>
struct NormTwoOps {
inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const {
return acc + data * data;
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
return acc + data_ * data_;
}

inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
Expand All @@ -309,8 +355,8 @@ struct NormTwoOps {
}

#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};
Expand Down

0 comments on commit 54f0556

Please sign in to comment.