CPU-strided-complex support for compare and pointwise ops #28735

Closed

Conversation

@dylanbespalko (Contributor) commented Oct 26, 2019

In-tree changes to pytorch to support complex numbers are being submitted here.
Out-of-tree support for complex numbers is here: pytorch-cpu-strided-complex extension (https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex)

These changes optimize the complex Vec256 math kernels so that they are, on average, within 2X of real-number performance. Benchmarks are here: https://docs.google.com/spreadsheets/d/17pObcrSTpV4BOOX9FYf1vIX3QUlEgQhLvL1IBEyJyzs/edit#gid=0

Changes so far:

  • Added complex support for eq, neq, max, and min ops.
    • max/min ops need to compare the absolute value for complex numbers (using zabs).
  • Added complex support for is_nonzero and where.
    • where op compares the absolute value for complex numbers (using zabs).
  • Added complex support for linear interp and pointwise ops.
  • Added complex support for check_convert and Linspace/Logspace.
    • std::complex does not support the ++ operator.
    • All compilers tested (clang, g++, c++) on aarch64 and x86 produce the same assembly code when using `+= 1` instead of `++` (example for loop: https://godbolt.org/z/O6NW_p).
  • Added complex support for log, log2, log10.
  • Optimized Vec256 operators using various logarithmic identities.
    • asin(), acos(), and atan() are optimized using a ln() identity (see the scalar sketch after this list).
    • sqrt() is optimized by splitting the computation into real and imag parts.
    • several _mm256_mul_pd are avoided by using _mm256_xor_pd ops instead.
  • Added complex support for pow.
    • exp is cast to std::complex<double>.
    • no special optimization is added when exp is real, because std::pow() expects a std::complex argument.
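
For reference, the ln() identity mentioned above for asin() can be written out in scalar form as below. This is only a minimal illustration using std::complex, and the helper name asin_via_log is hypothetical; the actual Vec256 kernels implement the same algebra with AVX intrinsics.

#include <complex>
#include <iostream>

// Scalar sketch of the ln() identity behind the vectorized asin():
//   asin(z) = -i * ln(i*z + sqrt(1 - z^2))
// asin_via_log is a hypothetical helper used only for this illustration;
// the PR's Vec256 code performs the same algebra with AVX intrinsics.
template <typename T>
std::complex<T> asin_via_log(std::complex<T> z) {
  const std::complex<T> i(0, 1);
  return -i * std::log(i * z + std::sqrt(std::complex<T>(1, 0) - z * z));
}

int main() {
  std::complex<double> z(0.3, -0.4);
  std::cout << asin_via_log(z) << " vs std::asin: " << std::asin(z) << "\n";
  return 0;
}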

@dylanbespalko (Contributor, Author) commented:

@ezyang,

Please review this PR. Thanks.

@@ -42,7 +45,7 @@ struct Reduction {
     index_t result_index = 0;
     for (int64_t k = 0; k < n; k++) {
       scalar_t value = data[k];
-      bool cmp = greater ? (result > value) : (result < value);
+      bool cmp = greater ? (zabs_(result) > zabs_(value)) : (zabs_(result) < zabs_(value));
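
zabs_ is defined elsewhere in ATen and is not shown in this hunk; the idea is that it reduces a value to something comparable, taking the magnitude of a complex input and passing a real input through unchanged. A minimal stand-in under that assumption (zabs_sketch is a hypothetical name, not the ATen helper):

#include <complex>
#include <iostream>

// Hypothetical stand-in for ATen's zabs_: magnitude for complex values,
// identity for real values, so the same comparison expression covers both.
template <typename T>
T zabs_sketch(T x) { return x; }                          // real types: identity

template <typename T>
T zabs_sketch(std::complex<T> z) { return std::abs(z); }  // complex: |z|

int main() {
  std::cout << zabs_sketch(-2.0) << "\n";                        // -2 (reals pass through)
  std::cout << zabs_sketch(std::complex<double>(3, 4)) << "\n";  // 5 (magnitude)
  return 0;
}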

A contributor commented:
abs-based max/min; makes sense, I suppose

@gchanan (Contributor) commented Oct 31, 2019:
is that what numpy does? Why is there no comparison to what NumPy does in any of these?

I don't think it is, e.g.:

>>> np.max(np.array([0, -5, 2, 3]).astype(np.complex64))
(3+0j)

@dylanbespalko (Contributor, Author) commented:

Hi @gchanan,

The test cases are found here: pytorch-cpu-strided-complex extension.

The test/test_torch.py script compares PyTorch to NumPy for each math kernel:

import numpy as np
import torch as th
from numpy.testing import assert_allclose

device = th.device('cpu')

def t2n(t):
    return t.detach().numpy()

def test_tensor_compare(self):
    a = th.tensor([[1+1j, 1-1j], [0.51, 0.51+0.51j]], dtype=th.complex128, device=device)
    assert_allclose(th.max(a, -2)[0], np.max(t2n(a), axis=-2))

This test shows that th.max() is the same as np.max().

I can't generate random numbers with torch.rand() because it's still implemented in the legacy TH folder. If that kernel were ported over, I could simply modify the internal test scripts.

FYI, here is a list of legacy kernels that are causing me the most pain:

    print(a)  # _th_masked_select_bool (you can get print to work by removing formatting code)
    th.zeros_like(a) # th_zero
    th.randint_like(c8)  # _th_random_
    th.rand((4, 4), dtype=th.complex128, device=device)  # th_uniform
    th.randn((4, 4), dtype=th.complex128, device=device)  # th_normal

    th.mv(m2, v2)  # _th_mv not implemented
    th.mm(m2, m2)  # _th_mm not implemented

I'm not sure any of the indexing functions work. Eventually, I can work on porting kernels over from TH, but I'm currently under pressure to demonstrate CPU, GPU, FPGA interoperability.

@@ -91,6 +91,62 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
        );
      }
    });
+  } else if (isComplexType(iter.dtype())) {
+    const auto exp = exp_scalar.to<std::complex<double>>();

A contributor commented:
The copy-paste here is irritating. What's the delta between this and the code above?

@dylanbespalko (Contributor, Author) replied:

Yes, I know.

  1. The most visible difference is:

const auto exp = exp_scalar.to<double>();                 // real
const auto exp = exp_scalar.to<std::complex<double>>();   // complex

Is exp cast to the highest precision to maintain numerical accuracy?

  2. std::pow() was not very forgiving on various CI machines when base and exp were two different data types. I had to cast exp back to scalar_t when calling the generic pow:

[=](Vec base) -> Vec { return base.pow(exp); }             // real
[=](Vec base) -> Vec { return base.pow(scalar_t(exp)); }   // complex

  3. For complex numbers, the Vec256 optimizations of sqrt() and rsqrt() are tied to the performance of std::log(), not an AVX math function. I found a way to implement log() using AVX functions, but it was so complicated that it was slower than std::log(). Two situations could play out in the future:
    • There could be a need to add different accelerations to the pow kernel, because you pay a huge penalty for calling std::pow for complex numbers.
    • Some features in AVX512 could dramatically speed up the complex implementation of log(), which would cause things to be optimized differently.

I didn't add any new acceleration cases to std::pow right now, but I could add them if you want. E.g., I could multiply base 10 times and it would still be faster than calling std::pow.
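
To make the shape of the complex branch concrete, here is a simplified scalar sketch of the pattern from points 1 and 2: read the exponent out at std::complex<double> precision, then narrow it back to the tensor's scalar_t before calling pow. This is an illustration only; pow_scalar_sketch and the plain loop are hypothetical stand-ins, not the TensorIterator kernel.

#include <complex>
#include <vector>

// Simplified sketch of the complex pow-with-scalar-exponent path described
// above: the exponent arrives as std::complex<double> and is narrowed to the
// tensor's scalar_t so that std::pow sees two arguments of the same type.
// pow_scalar_sketch is a hypothetical stand-in, not the ATen kernel.
template <typename scalar_t>
void pow_scalar_sketch(std::vector<scalar_t>& base, std::complex<double> exp) {
  for (auto& b : base) {
    b = std::pow(b, scalar_t(exp));  // explicit cast keeps base and exp at one precision
  }
}

int main() {
  std::vector<std::complex<float>> v{{1.0f, 2.0f}, {3.0f, -1.0f}};
  pow_scalar_sketch(v, std::complex<double>(2.0, 0.0));
  return 0;
}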

A contributor replied:

Is exp cast to highest precision to maintain some numerical accuracy?

In this case I think it is just convenience. Scalar stores the exponent internally as double and you have to extract it at that precision to make sure you don't lose any precision. But it's interesting that you need to downcast for the complex overloads.

Maybe we can add the comment here as a comment to the code in question.

@dylanbespalko (Contributor, Author) replied:

Yeah, I added the comment. Note the spec does say that it supports mixed complex precision.

template< class T, class U > 
complex</*Promoted*/> pow( const complex<T>& x, const complex<U>& y);

However, when base and exp are std::complex values of two different precisions, the code does not compile and the error says there is no suitable overload of std::pow().

  scalar_t weight_val = weight.to<scalar_t>();
  at::native::cpu_kernel(
      iter,
      [weight_val](scalar_t self_val, scalar_t end_val) {
-       return (weight_val < 0.5)
+       return (zabs<scalar_t, value_t>(weight_val) < 0.5)
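
The lambda body is truncated in this hunk. For context, a scalar sketch of the guarded lerp formulation such a kernel typically uses is below; this is an assumption for illustration rather than code taken from the diff, and std::abs stands in for the zabs<scalar_t, value_t> helper.

#include <cmath>
#include <complex>

// Sketch of a guarded lerp: anchor the result at self for small weights and
// at end for large ones, which keeps the interpolation accurate near both
// endpoints. For complex weights the 0.5 test is applied to the magnitude,
// which is what zabs<scalar_t, value_t> provides in the hunk above.
template <typename scalar_t, typename value_t>
scalar_t lerp_sketch(scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
  return (std::abs(weight_val) < value_t(0.5))
      ? self_val + weight_val * (end_val - self_val)
      : end_val - (end_val - self_val) * (scalar_t(1) - weight_val);
}

int main() {
  // Real instantiation; a complex one would be lerp_sketch<std::complex<double>, double>.
  return lerp_sketch<double, double>(0.0, 10.0, 0.25) == 2.5 ? 0 : 1;
}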

A contributor commented:
abs-based lerp

@ezyang (Contributor) left a comment:
Only cursory review of the vectorization

@ezyang requested a review from cpuhrsch October 28, 2019 14:06

@facebook-github-bot left a comment:
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot left a comment:

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot:
@ezyang merged this pull request in d8c368b.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 29, 2019
Pull Request resolved: pytorch/pytorch#28735

Differential Revision: D18170691

Pulled By: ezyang

fbshipit-source-id: 6f167398e112cdeab02fcfde8b543cb6629c865a