[numpy] torch.angle: promote integer inputs to float (#49163)
Summary:
**BC-Breaking Note:**

This PR updates PyTorch's `angle` operator to be consistent with NumPy's. Previously `angle` returned zero for all real inputs (including NaN). Now `angle` returns `pi` for negative real values, zero for non-negative real values, and propagates NaNs. In addition, integer (and bool) inputs are promoted to the default floating-point dtype, matching NumPy.
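
A small illustrative sketch of the change (output values are approximate and not part of this PR):

```python
import torch

x = torch.tensor([-2.0, 0.0, 3.0, float("nan")])

# New behavior: pi for negative values, zero for non-negative values, NaN propagated.
print(torch.angle(x))   # ~ tensor([3.1416, 0.0000, 0.0000, nan])

# Old behavior (before this PR): every real input, including NaN, mapped to zero.
# print(torch.angle(x)) -> tensor([0., 0., 0., 0.])
```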

**PR Summary:**

Reference: #42515

TODO:

* [x] Add BC-breaking note (previously all real inputs returned `0`, even `nan`); now fixed to match NumPy's behavior.

Pull Request resolved: #49163

Reviewed By: ngimel

Differential Revision: D25681758

Pulled By: mruberry

fbshipit-source-id: 54143fe6bccbae044427ff15d8daaed3596f9685
kshitij12345 authored and facebook-github-bot committed Dec 23, 2020
1 parent 46b8321 commit 461aafe
Showing 13 changed files with 94 additions and 32 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_base.h
@@ -251,7 +251,7 @@ struct Vec256 {
Vec256<T> angle() const {
// other_t_angle is for SFINAE and clarity. Make sure it is not changed.
static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
return Vec256(0);
return map(at::native::angle_impl<T>); // compiler is unable to resolve the overload without <T>
}
template <typename complex_t_angle = T,
typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
18 changes: 17 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_bfloat16.h
@@ -203,7 +203,23 @@ template <> class Vec256<BFloat16> {
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> angle() const {
return _mm256_set1_epi16(0);
__m256 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto angle_lambda = [](__m256 values) {
const auto zero_vec = _mm256_set1_ps(0.f);
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(M_PI);

const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
return angle;
};
auto o1 = angle_lambda(lo);
auto o2 = angle_lambda(hi);
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> real() const {
return *this;
11 changes: 10 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_double.h
@@ -108,7 +108,16 @@ template <> class Vec256<double> {
return _mm256_andnot_pd(mask, values);
}
Vec256<double> angle() const {
return _mm256_set1_pd(0);
const auto zero_vec = _mm256_set1_pd(0.f);
const auto nan_vec = _mm256_set1_pd(NAN);
const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_pd(M_PI);

const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask);
angle = _mm256_blendv_pd(angle, nan_vec, nan_mask);
return angle;
}
Vec256<double> real() const {
return *this;
11 changes: 10 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_float.h
@@ -115,7 +115,16 @@ template <> class Vec256<float> {
return _mm256_andnot_ps(mask, values);
}
Vec256<float> angle() const {
return _mm256_set1_ps(0);
const auto zero_vec = _mm256_set1_ps(0.f);
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(M_PI);

const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
return angle;
}
Vec256<float> real() const {
return *this;
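
The three vectorized kernels above share the same pattern: the `_CMP_EQ_OQ` self-comparison is false only in NaN lanes, negative lanes are blended to `pi`, and NaN lanes are written back last so they propagate. A scalar Python sketch of that logic, for reference only (not part of the commit):

```python
import math

def angle_real_reference(x: float) -> float:
    """Scalar model of the AVX blend used in the kernels above (reference only)."""
    if x != x:          # true only for NaN; mirrors the failed self-comparison
        return x        # propagate NaN
    return math.pi if x < 0 else 0.0
```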
12 changes: 0 additions & 12 deletions aten/src/ATen/cpu/vec256/vec256_int.h
@@ -121,9 +121,6 @@ class Vec256<int64_t> : public Vec256i {
auto inverse = _mm256_xor_si256(values, is_larger);
return _mm256_sub_epi64(inverse, is_larger);
}
Vec256<int64_t> angle() const {
return _mm256_set1_epi64x(0);
}
Vec256<int64_t> real() const {
return *this;
}
@@ -250,9 +247,6 @@ class Vec256<int32_t> : public Vec256i {
Vec256<int32_t> abs() const {
return _mm256_abs_epi32(values);
}
Vec256<int32_t> angle() const {
return _mm256_set1_epi32(0);
}
Vec256<int32_t> real() const {
return *this;
}
@@ -467,9 +461,6 @@ class Vec256<int16_t> : public Vec256i {
Vec256<int16_t> abs() const {
return _mm256_abs_epi16(values);
}
Vec256<int16_t> angle() const {
return _mm256_set1_epi16(0);
}
Vec256<int16_t> real() const {
return *this;
}
@@ -719,9 +710,6 @@ class Vec256<int8_t> : public Vec256i {
Vec256<int8_t> abs() const {
return _mm256_abs_epi8(values);
}
Vec256<int8_t> angle() const {
return _mm256_set1_epi8(0);
}
Vec256<int8_t> real() const {
return *this;
}
18 changes: 14 additions & 4 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -66,7 +66,7 @@ Tensor unary_op_impl_float(const Tensor& self, Stub& stub) {
// Note: This is done by running the operation as usual and then copying the
// operation's result to the expected result type.
template <typename Stub>
static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, const Tensor& self, Stub& stub) {
static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, const Tensor& self, Stub& stub, bool promotes_integer_to_float) {
if (self.is_complex() && !result.is_complex()) {
// Checks if the corresponding float type can be cast to the desired dtype
const auto float_type = c10::toValueType(self.scalar_type());
@@ -85,6 +85,10 @@ static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, co
return result;
}

if (promotes_integer_to_float) {
return unary_op_impl_float_out(result, self, stub);
}

return unary_op_impl_out(result, self, stub);
}

@@ -173,7 +177,7 @@ Tensor& arctan_(Tensor& self) { return self.atan_(); }
// complex input. This makes sense mathematically since the absolute value
// and angle of a complex number has no imaginary part.
Tensor& abs_out(Tensor& result, const Tensor& self) {
return unary_op_impl_with_complex_to_float_out(result, self, abs_stub);
return unary_op_impl_with_complex_to_float_out(result, self, abs_stub, /*promotes_integer_to_float=*/false);
}
Tensor abs(const Tensor& self) {
return unary_op_impl_with_complex_to_float(self, at::abs_out);
@@ -195,10 +199,16 @@ Tensor& absolute_(Tensor& self) {
}

Tensor& angle_out(Tensor& result, const Tensor& self) {
return unary_op_impl_with_complex_to_float_out(result, self, angle_stub);
return unary_op_impl_with_complex_to_float_out(result, self, angle_stub, /*promotes_integer_to_float=*/true);
}
Tensor angle(const Tensor& self) {
return unary_op_impl_with_complex_to_float(self, at::angle_out);
if (self.is_complex()) {
const auto float_type = c10::toValueType(self.scalar_type());
Tensor result = at::empty({0}, self.options().dtype(float_type));
return at::angle_out(result, self);
}

return unary_op_impl_float(self, angle_stub);
}

Tensor real(const Tensor& self) {
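
With `promotes_integer_to_float` set, the real-input path now runs through `unary_op_impl_float`, so integer and bool inputs yield a floating-point result, while complex inputs continue to return the corresponding real dtype. A hedged usage sketch (assumes the default float dtype is `torch.float32`):

```python
import torch

# Integer inputs are promoted to the default floating-point dtype.
print(torch.angle(torch.tensor([1, -2, 3])).dtype)          # expected: torch.float32

# Complex inputs still return the matching real dtype (complex64 -> float32).
print(torch.angle(torch.tensor([1 + 1j, -1 + 0j])).dtype)   # expected: torch.float32
```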
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -167,7 +167,7 @@ static void abs_kernel(TensorIterator& iter) {
}

static void angle_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "angle_cpu", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "angle_cpu", [&]() {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return angle_impl(a); },
11 changes: 10 additions & 1 deletion aten/src/ATen/native/cpu/zmath.h
@@ -33,9 +33,18 @@ inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
return std::abs(z);
}

// This overload corresponds to non-complex dtypes.
// The function is consistent with its NumPy equivalent
// for non-complex dtypes where `pi` is returned for
// negative real numbers and `0` is returned for 0 or positive
// real numbers.
// Note: `nan` is propagated.
template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
return 0;
if (at::_isnan(z)) {
return z;
}
return z < 0 ? M_PI : 0;
}

template<>
8 changes: 6 additions & 2 deletions aten/src/ATen/native/cuda/UnaryComplexKernels.cu
@@ -3,6 +3,7 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>

@@ -11,7 +12,10 @@ namespace at { namespace native {
// We manually overload angle because std::arg does not work with types other than c10::complex.
template<typename scalar_t>
__host__ __device__ static inline scalar_t angle_wrapper(scalar_t v) {
return 0;
if (at::_isnan(v)){
return v;
}
return v < 0 ? M_PI : 0;
}

template<typename T>
@@ -20,7 +24,7 @@ __host__ __device__ static inline c10::complex<T> angle_wrapper(c10::complex<T>
}

void angle_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "angle_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "angle_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return angle_wrapper(a);
});
1 change: 0 additions & 1 deletion test/test_torch.py
@@ -6655,7 +6655,6 @@ def inner(self, device, dtype):
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True,
[_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, 1e-5, _types, _types_no_half),
('angle', '', _small_3d, lambda t, d: [], 0, 0, 0, _types_no_half, [torch.bfloat16], False),
('fmod', 'value', _small_3d, lambda t, d: [3], 1e-3),
('fmod', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-3),
('chunk', '', _medium_2d, lambda t, d: [4], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
12 changes: 7 additions & 5 deletions test/test_unary_ufuncs.py
@@ -212,8 +212,7 @@ def _fn(t):
for alt, inplace in ((op.get_method(), False), (op.get_inplace(), True),
(torch.jit.script(_fn), False)):
if alt is None:
with self.assertRaises(RuntimeError):
alt(t.clone())
continue

if inplace and op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,):
# Assert that RuntimeError is raised
@@ -426,9 +425,12 @@ def compare_out(op, input, out):
if out_dtype.is_floating_point and not dtype.is_complex:
compare_out(op, input, output)
elif out_dtype.is_floating_point and dtype.is_complex:
# Can't cast complex to float
with self.assertRaises(RuntimeError):
op(input, out=output)
if op.supports_complex_to_float:
compare_out(op, input, output)
else:
# Can't cast complex to float
with self.assertRaises(RuntimeError):
op(input, out=output)
elif out_dtype.is_complex:
compare_out(op, input, output)
else:
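
The new `supports_complex_to_float` branch covers ops like `angle` that can safely write a complex input's result into a real-valued `out=` tensor. A minimal sketch of the case it now exercises (values are approximate):

```python
import torch

x = torch.tensor([1 + 1j, -1 + 0j])        # complex64 input
out = torch.empty(2, dtype=torch.float32)  # real-valued out tensor
torch.angle(x, out=out)                    # expected to succeed rather than raise
print(out)                                 # ~ tensor([0.7854, 3.1416]), i.e. pi/4 and pi
```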
5 changes: 5 additions & 0 deletions torch/_torch_docs.py
@@ -625,6 +625,11 @@ def merge_dicts(*dicts):
Keyword args:
{out}
.. note:: Starting in PyTorch 1.8, angle returns pi for negative real numbers,
zero for non-negative real numbers, and propagates NaNs. Previously
the function would return zero for all real numbers and not propagate
floating-point NaNs.
Example::
>>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159
15 changes: 13 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -253,6 +253,7 @@ def __init__(self,
handles_large_floats=True, # whether the op correctly handles large float values (like 1e20)
handles_extremals=True, # whether the op correctly handles extremal values (like inf)
handles_complex_extremals=True, # whether the op correct handles complex extremals (like inf -infj)
supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle
sample_inputs_func=sample_inputs_unary,
**kwargs):
super(UnaryUfuncInfo, self).__init__(name,
Expand All @@ -267,6 +268,7 @@ def __init__(self,
self.handles_large_floats = handles_large_floats
self.handles_extremals = handles_extremals
self.handles_complex_extremals = handles_complex_extremals
self.supports_complex_to_float = supports_complex_to_float

# Epsilon to ensure grad and gradgrad checks don't test values
# outside a function's domain.
@@ -1011,6 +1013,17 @@ def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False):
dtypes=[torch.bfloat16])),
promotes_integers_to_float=True,
handles_complex_extremals=False),
UnaryUfuncInfo('angle',
ref=np.angle,
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool),
dtypesIfROCM=all_types_and_complex_and(torch.bool),
decorators=(precisionOverride({torch.float16: 1e-2,
torch.bfloat16: 1e-2}),),
promotes_integers_to_float=True,
supports_complex_to_float=True,
test_inplace_grad=False),
OpInfo('linalg.solve',
aten_name='linalg_solve',
op=torch.linalg.solve,
@@ -1389,8 +1402,6 @@ def method_tests():
('complex', (S, S, S), ((S, S, S),), ''),
('abs', (S, S, S), NO_ARGS, '', (True,)),
('abs', (), NO_ARGS, 'scalar', (True,)),
('angle', (S, S, S), NO_ARGS, '', (True,)),
('angle', (), NO_ARGS, 'scalar', (True,)),
('clamp', (S, S, S), (0, 1), '', (True,)),
('clamp', (S, S, S), (None, 0.5), 'min', (True,)),
('clamp', (S, S, S), (0.5, None), 'max', (True,)),
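
Since the new OpInfo entry uses `np.angle` as its reference, here is a quick manual consistency check in the same spirit (illustrative only; the real harness in test_unary_ufuncs.py is far more thorough):

```python
import numpy as np
import torch

x = np.array([-3.0, 0.0, 2.5, np.nan])
t = torch.from_numpy(x)

print(np.angle(x))               # reference: [pi, 0, 0, nan]
print(torch.angle(t).numpy())    # expected to match elementwise, with NaNs in the same positions
```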
