-
Notifications
You must be signed in to change notification settings - Fork 21.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[numpy] torch.angle
: promote integer inputs to float
#49163
Changes from 11 commits
a3321b8
4d0fac4
4132b4b
f794e33
aa66092
bc3d3a3
210e492
399c0e9
dd2e91d
ad5b880
8e69c74
75b90f6
76f7b73
d2d9bcb
96b752d
da1ce2b
6b2fdc9
f22b7ee
df0cb2a
9266941
d95ff19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -66,7 +66,7 @@ Tensor unary_op_impl_float(const Tensor& self, Stub& stub) { | |||||||||||
// Note: This is done by running the operation as usual and then copying the | ||||||||||||
// operation's result to the expected result type. | ||||||||||||
template <typename Stub> | ||||||||||||
static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, const Tensor& self, Stub& stub) { | ||||||||||||
static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, const Tensor& self, Stub& stub, bool promotes_integer_to_float) { | ||||||||||||
if (self.is_complex() && !result.is_complex()) { | ||||||||||||
// Checks if the corresponding float type can be cast to the desired dtype | ||||||||||||
const auto float_type = c10::toValueType(self.scalar_type()); | ||||||||||||
|
@@ -85,6 +85,10 @@ static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, co | |||||||||||
return result; | ||||||||||||
} | ||||||||||||
|
||||||||||||
if (promotes_integer_to_float) { | ||||||||||||
return unary_op_impl_float_out(result, self, stub); | ||||||||||||
} | ||||||||||||
|
||||||||||||
return unary_op_impl_out(result, self, stub); | ||||||||||||
} | ||||||||||||
|
||||||||||||
|
@@ -173,7 +177,7 @@ Tensor& arctan_(Tensor& self) { return self.atan_(); } | |||||||||||
// complex input. This makes sense mathematically since the absolute value | ||||||||||||
// and angle of a complex number has no imaginary part. | ||||||||||||
Tensor& abs_out(Tensor& result, const Tensor& self) { | ||||||||||||
return unary_op_impl_with_complex_to_float_out(result, self, abs_stub); | ||||||||||||
return unary_op_impl_with_complex_to_float_out(result, self, abs_stub, /*promotes_integer_to_float=*/false); | ||||||||||||
} | ||||||||||||
Tensor abs(const Tensor& self) { | ||||||||||||
return unary_op_impl_with_complex_to_float(self, at::abs_out); | ||||||||||||
|
@@ -195,10 +199,16 @@ Tensor& absolute_(Tensor& self) { | |||||||||||
} | ||||||||||||
|
||||||||||||
Tensor& angle_out(Tensor& result, const Tensor& self) { | ||||||||||||
return unary_op_impl_with_complex_to_float_out(result, self, angle_stub); | ||||||||||||
return unary_op_impl_with_complex_to_float_out(result, self, angle_stub, /*promotes_integer_to_float=*/true); | ||||||||||||
} | ||||||||||||
Tensor angle(const Tensor& self) { | ||||||||||||
return unary_op_impl_with_complex_to_float(self, at::angle_out); | ||||||||||||
if (self.is_complex()) { | ||||||||||||
const auto float_type = c10::toValueType(self.scalar_type()); | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't this run afoul of the safe casting logic in TensorIterator? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, sorry, it doesn't because it doesn't call tensor iterator, it calls angle_out. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is actually similar to what was already there. pytorch/aten/src/ATen/native/UnaryOps.cpp Lines 105 to 109 in f5b68e7
|
||||||||||||
Tensor result = at::empty({0}, self.options().dtype(float_type)); | ||||||||||||
return at::angle_out(result, self); | ||||||||||||
} | ||||||||||||
|
||||||||||||
return unary_op_impl_float(self, angle_stub); | ||||||||||||
} | ||||||||||||
|
||||||||||||
Tensor real(const Tensor& self) { | ||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,7 +167,7 @@ static void abs_kernel(TensorIterator& iter) { | |
} | ||
|
||
static void angle_kernel(TensorIterator& iter) { | ||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "angle_cpu", [&]() { | ||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "angle_cpu", [&]() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not dispatched on common_dtype but dtype? |
||
cpu_kernel_vec( | ||
iter, | ||
[=](scalar_t a) -> scalar_t { return angle_impl(a); }, | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -35,7 +35,10 @@ inline double zabs <c10::complex<double>, double> (c10::complex<double> z) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment explaining this function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the name pytorch/aten/src/ATen/native/cpu/zmath.h Lines 114 to 163 in 76f7b73
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess we're stuck with "_impl", then. If the name is consistent then this PR doesn't need to bother updating it. But a comment explaining it would still be good. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have added the following comment. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
inline VALUE_TYPE angle_impl (SCALAR_TYPE z) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (at::_isnan(z)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return z; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return z < 0 ? M_PI : 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
template<> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
#include <ATen/native/cuda/Loops.cuh> | ||
#include <ATen/Context.h> | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/NumericUtils.h> | ||
#include <ATen/native/DispatchStub.h> | ||
#include <ATen/native/TensorIterator.h> | ||
|
||
|
@@ -11,7 +12,10 @@ namespace at { namespace native { | |
// We manually overload angle because std::arg does not work with types other than c10::complex. | ||
template<typename scalar_t> | ||
__host__ __device__ static inline scalar_t angle_wrapper(scalar_t v) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The name "angle_impl" in zmath.h seems odd to me, is there a better name? Maybe one that would fit a descriptive comment better? |
||
return 0; | ||
if (at::_isnan(v)){ | ||
return v; | ||
} | ||
return v < 0 ? M_PI : 0; | ||
} | ||
|
||
template<typename T> | ||
|
@@ -20,7 +24,7 @@ __host__ __device__ static inline c10::complex<T> angle_wrapper(c10::complex<T> | |
} | ||
|
||
void angle_kernel_cuda(TensorIterator& iter) { | ||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "angle_cuda", [&]() { | ||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "angle_cuda", [&]() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note the dispatch is on the common_dtype here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kshitij12345 @mruberry why did we change it to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reference: |
||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { | ||
return angle_wrapper(a); | ||
}); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -625,6 +625,10 @@ def merge_dicts(*dicts): | |
Keyword args: | ||
{out} | ||
|
||
.. note:: From version 1.8 onwards, the angle function returns `PI` for negative real numbers, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Starting in PyTorch 1.8, angle returns pi for negative real numbers, zero for non-negative real numbers, and propagates NaNs. Previously the function would return zero for all real numbers and not propagate floating-point NaNs." |
||
and `0` for zero and positive real numbers, and propagates `NaN` inputs. Prior to version 1.8, the function would | ||
return `0` for all real numbers and would not propagate floating-point `NaN`s. | ||
|
||
Example:: | ||
|
||
>>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -250,6 +250,7 @@ def __init__(self, | |||
handles_large_floats=True, # whether the op correctly handles large float values (like 1e20) | ||||
handles_extremals=True, # whether the op correctly handles extremal values (like inf) | ||||
handles_complex_extremals=True, # whether the op correct handles complex extremals (like inf -infj) | ||||
supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle | ||||
sample_inputs_func=sample_inputs_unary, | ||||
**kwargs): | ||||
super(UnaryUfuncInfo, self).__init__(name, | ||||
|
@@ -264,6 +265,7 @@ def __init__(self, | |||
self.handles_large_floats = handles_large_floats | ||||
self.handles_extremals = handles_extremals | ||||
self.handles_complex_extremals = handles_complex_extremals | ||||
self.supports_complex_to_float = supports_complex_to_float | ||||
|
||||
# Epsilon to ensure grad and gradgrad checks don't test values | ||||
# outside a function's domain. | ||||
|
@@ -808,6 +810,21 @@ def sample_inputs(self, device, dtype, requires_grad=False): | |||
promotes_integers_to_float=True, | ||||
handles_complex_extremals=False, | ||||
test_complex_grad=False), | ||||
UnaryUfuncInfo('angle', | ||||
ref=np.angle, | ||||
dtypes=all_types_and_complex_and(torch.bool), | ||||
dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), | ||||
dtypesIfCUDA=all_types_and_complex_and(torch.bool), | ||||
dtypesIfROCM=all_types_and_complex_and(torch.bool), | ||||
decorators=(precisionOverride({torch.float16: 1e-2, | ||||
torch.bfloat16: 1e-2}),), | ||||
skips=( | ||||
# RuntimeError: "isfinite" not implemented for 'BFloat16' | ||||
kshitij12345 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
SkipInfo('TestCommon', 'test_variant_consistency_jit', | ||||
dtypes=[torch.bfloat16]),), | ||||
promotes_integers_to_float=True, | ||||
supports_complex_to_float=True, | ||||
test_inplace_grad=False), | ||||
] | ||||
|
||||
if TEST_SCIPY: | ||||
|
@@ -1151,8 +1168,6 @@ def method_tests(): | |||
('complex', (S, S, S), ((S, S, S),), ''), | ||||
('abs', (S, S, S), NO_ARGS, '', (True,)), | ||||
('abs', (), NO_ARGS, 'scalar', (True,)), | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reviewer's note: There is another test for angle but it also tests abs: pytorch/test/test_unary_ufuncs.py Line 538 in 65876d3
|
||||
('angle', (S, S, S), NO_ARGS, '', (True,)), | ||||
('angle', (), NO_ARGS, 'scalar', (True,)), | ||||
('clamp', (S, S, S), (0, 1), '', (True,)), | ||||
('clamp', (S, S, S), (None, 0.5), 'min', (True,)), | ||||
('clamp', (S, S, S), (0.5, None), 'max', (True,)), | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry, I am not very well versed with the AVX instructions ... what (NaN-related) information do we get by comparing a value to itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since AVX itself does not have a special instruction for masking `NaN`s, we use the property `NaN != NaN` (a `NaN` never compares equal to itself) to locate the `NaN` values.