CPU-strided-complex support for compare and pointwise ops #28735
Conversation
Please review this PR. Thanks.
```diff
@@ -42,7 +45,7 @@ struct Reduction {
   index_t result_index = 0;
   for (int64_t k = 0; k < n; k++) {
     scalar_t value = data[k];
-    bool cmp = greater ? (result > value) : (result < value);
+    bool cmp = greater ? (zabs_(result) > zabs_(value)) : (zabs_(result) < zabs_(value));
```
abs-based max/min; makes sense, I suppose
Is that what NumPy does? Why is there no comparison to what NumPy does in any of these?
I don't think it is, e.g.:

```python
>>> np.max(np.array([0, -5, 2, 3]).astype(np.complex64))
(3+0j)
```
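For illustration, here is a minimal standalone sketch (not the actual ATen code; `zabs_` here is a simplified stand-in) of the magnitude-based comparison in the diff above, run on the same values as the NumPy example:

```cpp
#include <complex>
#include <iostream>
#include <vector>

// Simplified stand-in for ATen's zabs_ helper: complex values compare by magnitude.
static double zabs_(const std::complex<double>& v) { return std::abs(v); }

// Magnitude-based max reduction mirroring the "greater" branch of the diff above.
std::complex<double> max_by_zabs(const std::vector<std::complex<double>>& data) {
    std::complex<double> result = data[0];
    for (size_t k = 1; k < data.size(); k++) {
        const std::complex<double> value = data[k];
        if (!(zabs_(result) > zabs_(value))) {
            result = value;  // replace when |value| >= |result|
        }
    }
    return result;
}

int main() {
    // Same values as the NumPy example: np.max picks (3+0j),
    // while a magnitude-based reduction picks (-5+0j) because |-5| is the largest magnitude.
    std::vector<std::complex<double>> v{{0.0, 0.0}, {-5.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}};
    std::cout << max_by_zabs(v) << "\n";
    return 0;
}
```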
Hi @gchanan,
The test cases are found here: [pytorch-cpu-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex).
The `test/test_torch.py` there compares the PyTorch result to NumPy for each math kernel:
```python
import numpy as np
import torch as th
from numpy.testing import assert_allclose  # assuming NumPy's assert_allclose is the one in scope

device = th.device('cpu')

def t2n(t):
    return t.detach().numpy()

def test_tensor_compare(self):
    a = th.tensor([[1+1j, 1-1j], [0.51, 0.51+0.51j]], dtype=th.complex128, device=device)
    assert_allclose(th.max(a, -2)[0], np.max(t2n(a), axis=-2))
```
This test shows that `th.max()` matches `np.max()`.

I can't generate random numbers with `torch.rand()` because it's still implemented in the legacy TH folder. If that kernel were ported over, I could simply modify the internal test scripts.
FYI, here is a list of legacy kernels that are causing me the most pain:
```python
print(a)                                              # _th_masked_select_bool (you can get print to work by removing formatting code)
th.zeros_like(a)                                      # th_zero
th.randint_like(c8)                                   # _th_random_
th.rand((4, 4), dtype=th.complex128, device=device)   # th_uniform
th.randn((4, 4), dtype=th.complex128, device=device)  # th_normal
th.mv(m2, v2)                                         # _th_mv not implemented
th.mm(m2, m2)                                         # _th_mm not implemented
```
I'm not sure any of the indexing functions work. Eventually, I can work on porting kernels over from TH, but I'm currently under pressure to demonstrate CPU, GPU, FPGA interoperability.
```diff
@@ -91,6 +91,62 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
       );
     }
   });
+  } else if (isComplexType(iter.dtype())) {
+    const auto exp = exp_scalar.to<std::complex<double>>();
```
The copy-paste here is irritating. What's the delta between this and the above?
Yes, I know.

- The most visible difference is:

```cpp
const auto exp = exp_scalar.to<double>();                // real
const auto exp = exp_scalar.to<std::complex<double>>();  // complex
```

  Is exp cast to highest precision to maintain some numerical accuracy? `std::pow()` was not very generous on various CI machines when base and exp were two different data types, so I had to cast exp back to `scalar_t` when calling the generic pow:

```cpp
[=](Vec base) -> Vec { return base.pow(exp); }            // real
[=](Vec base) -> Vec { return base.pow(scalar_t(exp)); }  // complex
```

- For complex numbers, the Vec256 optimizations of `sqrt()` and `rsqrt()` are tied to the performance of `std::log()`, not an AVX math function. I found a way to implement `log()` using AVX functions, but it was so complicated that it was slower than `std::log()`. Two situations could play out in the future:
  - There could be a need to add different accelerations to the pow kernel, because you pay a huge penalty for calling std::pow for complex numbers.
  - Some features in AVX512 could dramatically speed up the complex implementation of `log()`, which would cause things to be optimized differently.

I didn't add any new acceleration cases to std::pow right now, but I could add them if you want. E.g. I could multiply base 10 times and it would still be faster than calling std::pow.
> Is exp cast to highest precision to maintain some numerical accuracy?
In this case I think it is just convenience. Scalar stores the exponent internally as double and you have to extract it at that precision to make sure you don't lose any precision. But it's interesting that you need to downcast for the complex overloads.
Maybe we can add the comment here as a comment to the code in question.
Yeah, I added the comment. Note that the spec does say it supports mixed complex precision:

```cpp
template< class T, class U >
complex</* Promoted */> pow( const complex<T>& x, const complex<U>& y );
```

However, when base and exp are std::complex values of different precisions, the code does not compile and says there is no suitable overload of std::pow()???
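For illustration, a minimal standalone sketch (values and types here are hypothetical) of the workaround mentioned above: cast the double-precision exponent back down to the base's complex type so the generic `std::pow` sees a single complex type:

```cpp
#include <complex>
#include <iostream>

int main() {
    // A Scalar-style exponent: stored and extracted at double precision.
    const std::complex<double> exp(2.5, 0.0);

    // Hypothetical tensor element type: complex<float>.
    const std::complex<float> base(1.0f, 2.0f);

    // Cast exp back to the base's type before calling the generic std::pow.
    // (std::complex<float>'s constructor from complex<double> is explicit.)
    const std::complex<float> result = std::pow(base, std::complex<float>(exp));

    std::cout << result << "\n";
    return 0;
}
```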
```diff
   scalar_t weight_val = weight.to<scalar_t>();
   at::native::cpu_kernel(
       iter,
       [weight_val](scalar_t self_val, scalar_t end_val) {
-        return (weight_val < 0.5)
+        return (zabs<scalar_t, value_t>(weight_val) < 0.5)
```
abs-based lerp
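As a rough illustration of what abs-based lerp means here, a standalone sketch (not the ATen kernel; the formula and names are simplified assumptions) that branches on the magnitude of a complex weight, since complex numbers themselves have no ordering:

```cpp
#include <complex>
#include <iostream>

// Lerp that, for complex weights, branches on |weight| rather than on the weight itself.
std::complex<double> lerp_complex(std::complex<double> self,
                                  std::complex<double> end,
                                  std::complex<double> weight) {
    return (std::abs(weight) < 0.5)
        ? self + weight * (end - self)
        : end - (end - self) * (std::complex<double>(1.0) - weight);
}

int main() {
    std::complex<double> a(1.0, 1.0), b(3.0, -1.0), w(0.25, 0.0);
    // Expected: a + 0.25 * (b - a) = (1.5, 0.5)
    std::cout << lerp_complex(a, b, w) << "\n";
    return 0;
}
```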
Only cursory review of the vectorization
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
In-tree changes to pytorch to support complex numbers are being submitted here. Out-of-tree support for complex numbers is here: [pytorch-cpu-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex)

These changes optimize complex Vec256 math kernels so that they are within 2X real number performance on average. [Benchmarks are here](https://docs.google.com/spreadsheets/d/17pObcrSTpV4BOOX9FYf1vIX3QUlEgQhLvL1IBEyJyzs/edit#gid=0)

Changes so far:
- [x] Added complex support for eq, neq, max, and min ops.
  - max/min ops need to compare the absolute value for complex numbers (using zabs).
- [x] Added complex support for is_nonzero and where.
  - where op compares the absolute value for complex numbers (using zabs).
- [x] Added complex support for linear interp and pointwise ops.
- [x] Added complex support for check_convert and Linspace/Logspace.
  - std::complex does not support the ++ operator.
  - All compilers from clang, g++, c++ on aarch64, x86 produce the same assembly code when using `+= 1` instead of `++`. [example for loop](https://godbolt.org/z/O6NW_p)
- [x] Added complex support for log, log2, log10.
- [x] Optimized Vec256 operators using various logarithmic identities.
  - `asin()`, `acos()`, `atan()` are optimized using a `ln()` identity.
  - `sqrt()` is optimized by splitting the computation into real and imag parts.
  - several `_mm256_mul_pd` ops are avoided by using `_mm256_xor_pd` ops instead.
- [x] Added complex support for pow.
  - exp is cast to `std::complex<double>`.
  - no special optimization is added when the `exp` is real because the `std::pow()` operator expects a std::complex number.

Pull Request resolved: pytorch/pytorch#28735
Differential Revision: D18170691
Pulled By: ezyang
fbshipit-source-id: 6f167398e112cdeab02fcfde8b543cb6629c865a