Basic Element-wise Complex Number Calculations Not Available On GPU #3624

Closed
sbrodeur opened this issue Aug 3, 2016 · 31 comments
Labels: stat:contribution welcome, type:bug, type:feature

@sbrodeur (Author) commented Aug 3, 2016

Basic element-wise addition, subtraction, multiplication or division for any Tensor of type tf.complex64 is not implemented on GPU.

Environment info

Operating System: Centos 7, 3.10.0-327.22.2.el7.x86_64

Installed version of CUDA and cuDNN: CUDA 7.5 and cuDNN 7.0-v4
-rw-r--r--. 1 root root 189170 Jul 22 16:14 /usr/local/cuda-7.5/lib/libcudadevrt.a
lrwxrwxrwx. 1 root root 16 Jul 22 16:14 /usr/local/cuda-7.5/lib/libcudart.so -> libcudart.so.7.5
lrwxrwxrwx. 1 root root 19 Jul 22 16:14 /usr/local/cuda-7.5/lib/libcudart.so.7.5 -> libcudart.so.7.5.18
-rwxr-xr-x. 1 root root 311596 Jul 22 16:14 /usr/local/cuda-7.5/lib/libcudart.so.7.5.18
-rw-r--r--. 1 root root 558020 Jul 22 16:14 /usr/local/cuda-7.5/lib/libcudart_static.a

TensorFlow installed from source:

  1. Commit hash 00700f0
  2. Bazel information:
    Build label: 0.3.0-2016-07-22 (@ca36b06)
    Build target: bazel-out/local-fastbuild/bin/src/main/java/com/google/devtools/build/lib/bazel/BazelServer_deploy.jar
    Build time: Fri Jul 22 19:23:10 2016 (1469215390)
    Build timestamp: 1469215390
    Build timestamp as int: 1469215390

Steps to reproduce

  1. Add, subtract, multiply or divide any Tensor of type tf.complex64. A code example is shown here for element-wise addition:
import tensorflow as tf

if __name__ == '__main__':

    with tf.device('/gpu:0'):
        N = 100
        a = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
        b = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
        c = a + b

        with tf.Session() as sess:
            c = sess.run(c)

The code produces the following output when run on GPU (it works well on CPU):

I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcublas.so.7.5 locally
I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcudnn.so.4.0.7 locally
I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcufft.so.7.5 locally
I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcuda.so.1 locally
I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcurand.so.7.5 locally
I tensorflow/core/common_runtime/gpu/gpu_init.cc:102] Found device 0 with properties:
name: Tesla K40c
major: 3 minor: 5 memoryClockRate (GHz) 0.745
pciBusID 0000:02:00.0
Total memory: 12.00GiB
Free memory: 11.90GiB
W tensorflow/stream_executor/cuda/cuda_driver.cc:572] creating context when one is currently active; existing: 0x5168890
I tensorflow/core/common_runtime/gpu/gpu_init.cc:102] Found device 1 with properties:
name: GeForce GT 610
major: 2 minor: 1 memoryClockRate (GHz) 1.62
pciBusID 0000:01:00.0
Total memory: 1023.19MiB
Free memory: 396.98MiB
I tensorflow/core/common_runtime/gpu/gpu_init.cc:59] cannot enable peer access from device ordinal 0 to device ordinal 1
I tensorflow/core/common_runtime/gpu/gpu_init.cc:59] cannot enable peer access from device ordinal 1 to device ordinal 0
I tensorflow/core/common_runtime/gpu/gpu_init.cc:126] DMA: 0 1
I tensorflow/core/common_runtime/gpu/gpu_init.cc:136] 0: Y N
I tensorflow/core/common_runtime/gpu/gpu_init.cc:136] 1: N Y
I tensorflow/core/common_runtime/gpu/gpu_device.cc:839] Creating TensorFlow device (/gpu:0) -> (device: 0, name: Tesla K40c, pci bus id: 0000:02:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:814] Ignoring gpu device (device: 1, name: GeForce GT 610, pci bus id: 0000:01:00.0) with Cuda compute capability 2.1. The minimum required Cuda capability is 3.5.
E tensorflow/core/client/tensor_c_api.cc:485] Cannot assign a device to node 'add': Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
[[Node: add = Add[T=DT_COMPLEX64, _device="/device:GPU:0"](Complex, Complex_1)]]
Traceback (most recent call last):
File "test_div_gpu_prob.py", line 12, in
c = sess.run(c)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 382, in run
run_metadata_ptr)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 655, in _run
feed_dict_string, options, run_metadata)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 723, in _do_run
target_list, options, run_metadata)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 743, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors.InvalidArgumentError: Cannot assign a device to node 'add': Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
[[Node: add = Add[T=DT_COMPLEX64, _device="/device:GPU:0"](Complex, Complex_1)]]
Caused by op u'add', defined at:
File "test_div_gpu_prob.py", line 9, in
c = a + b
File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 755, in binary_op_wrapper
return func(x, y, name=name)
File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 70, in add
result = _op_def_lib.apply_op("Add", x=x, y=y, name=name)
File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 703, in apply_op
op_def=op_def)
File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2310, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1232, in init
self._traceback = _extract_stack()

What have you tried?

  1. An implementation using built-in TensorFlow functions works if the real and imaginary parts are handled separately. See the code below:
import numpy as np
import tensorflow as tf

def complex_add(x, y):
    xr, xi = tf.real(x), tf.imag(x)
    yr, yi = tf.real(y), tf.imag(y)
    return tf.complex(xr + yr, xi + yi)

def complex_sub(x, y):
    xr, xi = tf.real(x), tf.imag(x)
    yr, yi = tf.real(y), tf.imag(y)
    return tf.complex(xr - yr, xi - yi)

def complex_mul(x, y):
    xr, xi = tf.real(x), tf.imag(x)
    yr, yi = tf.real(y), tf.imag(y)
    return tf.complex(xr*yr - xi*yi, xr*yi + xi*yr)

def complex_div(x, y):
    xr, xi = tf.real(x), tf.imag(x)
    yr, yi = tf.real(y), tf.imag(y)
    d = tf.square(yr) + tf.square(yi)
    return tf.complex((xr*yr+xi*yi)/d, (xi*yr-xr*yi)/d)

if __name__ == '__main__':

    with tf.device('/gpu:0'):
        N = 100
        a = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
        b = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))

        with tf.Session() as sess:

            a_, b_, c = sess.run([a,b,complex_add(a,b)])
            assert np.allclose(c, a_ + b_)

            a_, b_, c = sess.run([a,b,complex_sub(a,b)])
            assert np.allclose(c, a_ - b_)

            a_, b_, c = sess.run([a,b,complex_mul(a,b)])
            assert np.allclose(c, a_ * b_)

            a_, b_, c = sess.run([a,b,complex_div(a,b)])
            assert np.allclose(c, a_ / b_)

It would be nice to have these operations work transparently on GPU, as they do with the built-in CPU implementations.

@sbrodeur (Author) commented Aug 3, 2016

Note: implementations using built-in TensorFlow functions as shown above don't solve the gradient issues caused by the handling of complex numbers:

import tensorflow as tf

def complex_mul(x, y):
    xr, xi = tf.real(x), tf.imag(x)
    yr, yi = tf.real(y), tf.imag(y)
    return tf.complex(xr*yr - xi*yi, xr*yi + xi*yr)

if __name__ == '__main__':

    with tf.device('/gpu:0'):
        N = 100
        a = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
        b = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
        c = complex_mul(a, b)

        grad = tf.gradients([c], [a])

        with tf.Session() as sess:
            grad = sess.run(grad)

This code will fail with the following error:

E tensorflow/core/client/tensor_c_api.cc:485] Cannot assign a device to node 'gradients/Shape': Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
[[Node: gradients/Shape = Shape[T=DT_COMPLEX64, _device="/device:GPU:0"](…)]]
Traceback (most recent call last):
File "test_div_gpu_grad_prob.py", line 19, in
grad = sess.run(grad)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 382, in run
run_metadata_ptr)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 655, in _run
feed_dict_string, options, run_metadata)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 723, in _do_run
target_list, options, run_metadata)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 743, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors.InvalidArgumentError: Cannot assign a device to node 'gradients/Shape': Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
[[Node: gradients/Shape = Shape[T=DT_COMPLEX64, _device="/device:GPU:0"](…)]]
Caused by op u'gradients/Shape', defined at:
File "test_div_gpu_grad_prob.py", line 16, in
grad = tf.gradients([c], [a])
File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/gradients.py", line 367, in gradients
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/gradients.py", line 230, in _DefaultGradYs
array_ops.shape(y),
File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 131, in shape
return gen_array_ops.shape(input, name=name)
File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1922, in shape
result = _op_def_lib.apply_op("Shape", input=input, name=name)
File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 703, in apply_op
op_def=op_def)
File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2310, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1232, in init
self._traceback = _extract_stack()

@poxvoculi (Contributor):

It seems that support for complex64 types is piecemeal, by op-type and device-type. Bringing in @martinwicke for a comment on the policy.

poxvoculi added the stat:awaiting tensorflower label on Aug 3, 2016.
martinwicke added the bug and stat:contribution welcome labels and removed the stat:awaiting tensorflower label on Aug 3, 2016.
@martinwicke (Member):

Yes, this is a good feature request bordering on a bug. Please check the Op registrations of the affected ops, and you'll probably find that the templates of many of them are not specialized for complex data types. It is a relatively simple thing to fix, and I'd love PRs that do it.

@ibab (Contributor) commented Aug 8, 2016

@sbrodeur: Are you currently working on this?
If not, I could go ahead and attempt a fix.

@sbrodeur (Author) commented Aug 8, 2016

@ibab: I have not yet attempted a fix. I've looked a little at Eigen:
"By default, Eigen currently supports standard floating-point types (float, double, std::complex<float>, std::complex<double>, long double), as well as all native integer types (e.g., int, unsigned int, short, etc.), and bool."
[https://eigen.tuxfamily.org/dox-devel/TopicCustomizingEigen.html]

Thus, for the simple calculations here, should I expect Eigen to provide compatible functors, e.g.:

template <typename T>
struct add : base<T, Eigen::internal::scalar_sum_op<T> > {
  static const bool use_bcast_optimization = true;
};

This code is in the file cwise_ops.h.

Does this mean the fix is similar to #2263, i.e. just adding the complex64 type when we register the kernels?

@ibab (Contributor) commented Aug 8, 2016

Yes, you won't have to implement the operations themselves; you just need to enable them.
For example, you can look at the supported types for the addition op here:

REGISTER4(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double,

You would need to add complex64 and complex128 to the macro (and change it into REGISTER6).

You should make sure that the GPU tests are enabled for complex64 and complex128 for each op that has been extended, for example here:

def testComplex64Basic(self):

@sbrodeur (Author) commented Aug 8, 2016

Thanks for the information @ibab! I will attempt a fix myself and send a PR soon!

@sbrodeur (Author) commented Aug 12, 2016

So far, I can make it work for some operations (add, sub) by simply adding the complex data types when registering the kernels, e.g.:

REGISTER4(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double,

DEFINE_BINARY6(add, Eigen::half, float, double, int64, complex64, complex128);

However, compilation errors occur for multiplication (and division), as seen below.
Searching the web, I found that CUDA may not support std::complex because of STL incompatibilities:
https://forum.kde.org/viewtopic.php?f=74&t=123919

It seems that, to solve this problem, people have been using reimplementations of the std::complex type (e.g. from Thrust, cuda_complex, or CUSP) so that it can be used in device code:
https://github.com/thrust/thrust/blob/2ef13096187b40a35a71451d09e49b14074b0859/thrust/complex.h
https://github.com/jtravs/cuda_complex/blob/master/cuda_complex.hpp
https://github.com/cusplibrary/cusplibrary/blob/master/cusp/complex.h

Would it solve the issue in TensorFlow if Eigen implemented something similar to what Thrust uses?

Compilation output

INFO: From Compiling tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc:
nvcc warning : option '--relaxed-constexpr' has been deprecated and replaced by option '--expt-relaxed-constexpr'.
In file included from /usr/local/cuda-7.5/include/host_config.h:161:0,
                 from /usr/local/cuda-7.5/include/cuda_runtime.h:76,
                 from <command-line>:0:
/usr/include/features.h:330:4: warning: #warning _FORTIFY_SOURCE requires compiling with optimization (-O) [-Wcpp]
[the nvcc deprecation warning and the _FORTIFY_SOURCE warning above are repeated several times]
external/eigen_archive/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h(136): error: a value of type "int" cannot be assigned to an entity of type "_ZNSt7complexIfE9_ComplexTE"
[the same error is repeated 12 times in total]

12 errors detected in the compilation of "/tmp/tmpxft_0000430f_00000000-12_cwise_op_gpu_mul.cu.compute_35.cpp2.i".

@ibab (Contributor) commented Aug 13, 2016

Strange: your errors seem to be caused by Eigen trying to assign values from an int Tensor into a complex Tensor.
I don't think that's supposed to happen.

I've tried enabling complex mul and div just now, and it successfully compiled and ran when I restricted the ops to run on the GPU.
I'm also using CUDA 7.5, not sure what else could be different between our setups.

@sbrodeur (Author):

Here is my configuration:

GPU: Tesla K40c
Operating System: CentOS Linux 7 (Core)
Kernel: Linux 3.10.0-327.22.2.el7.x86_64
Architecture: x86-64
C Compiler: gcc (GCC) 4.8.5 20150623 (Red Hat 4.8.5-4)
CUDA Compiler: Cuda compilation tools, release 7.5, V7.5.17

I will try with a more recent gcc (e.g. 4.9.2) to see if the compilation problem disappears.

@ibab (Contributor) commented Aug 13, 2016

I'm also using gcc 4.8.5 and Cuda compilation tools, release 7.5, V7.5.17.
Not sure what could be causing this if these two are the same.
I've uploaded my changes to ibab@8c3baae, so you might want to compare these with yours.
If that doesn't help we should try rebasing to the same TensorFlow commit.

@sbrodeur (Author):

Sadly, I obtain the same errors if I clone and compile the fork ibab/tensorflow@8c3baae without any modifications.
@ibab - What is your Linux distribution?

@ibab (Contributor) commented Aug 13, 2016

I'm running Scientific Linux 6, which should be virtually identical to Red Hat 6.
I get my compiler toolchain from anaconda, though.
I'll try to make sure it's not something weird on my end.

@sbrodeur (Author):

On my side, I'll try a build on my laptop, which runs the latest Debian 8 (Jessie). I don't have an Nvidia GPU, but I should nevertheless be able to compile with CUDA.

@ibab (Contributor) commented Aug 13, 2016

Okay, I've rebuilt TensorFlow after a bazel clean --expunge to make extra sure that I'm using the right Eigen version (I've changed it around a few times previously), but it still built successfully.

Edit: Btw, do you also get compiler warnings about calling __host__ functions from device code as in the link you posted above?

@sbrodeur (Author) commented Aug 14, 2016

I do not get compiler warnings about calling __host__ functions from device code.

I just tried to build on my laptop (Debian 8, up to date) with this configuration:
gcc (Debian 4.9.2-10) 4.9.2
Cuda compilation tools, release 7.5, V7.5.17

I obtained the same errors, so it does not seem related to gcc or the distribution. I also tried to build with the latest Eigen (3782cd1de9c4) on the CentOS 7 machine, and that did not help either. I will try building with CUDA 8, after which I will be clueless about these compilation issues.

Edit: same errors with CUDA 8.

@ibab (Contributor) commented Aug 14, 2016

I've tried compiling with different compute capabilities, but it still compiled without errors.
The fact that you reproduced it on two different systems makes me think that it's a problem with my setup, though.
Unfortunately nvcc doesn't give us a lot of information in the error message (like which template instantiations we are dealing with) :(

@iportillo:

@ibab, @sbrodeur, thank you so much for working on this. I think it would really speed up one of my projects. Is there any new progress? Are you planning to include this in the next release of TensorFlow? What about basic math functions such as tf.exp() and tf.complex_abs()?

Thanks again!

@sbrodeur (Author) commented Aug 24, 2016

@iportillo - I will give it another try today. It would also significantly accelerate my experiments, since everything could run on the GPU. I'll try to see if it would be easy to use cuBLAS directly (rather than Eigen) for the basic math functions on complex numbers.

tf.complex_abs is easy to implement on GPU right now:

import tensorflow as tf

def complex_abs(x):
    return tf.sqrt(tf.square(tf.real(x)) + tf.square(tf.imag(x)))

By tf.exp(), do you mean converting from the Cartesian to the complex exponential form (angle and norm)? To calculate the angle, this means implementing the atan2 function (for complex x + iy):

import numpy as np

def atan2(y, x):
    # x > 0: the angle is simply atan(y/x).
    angle = tf.select(tf.greater(x, 0.0), tf.atan(y/x), tf.zeros_like(x))
    angle = tf.select(tf.logical_and(tf.less(x, 0.0), tf.greater_equal(y, 0.0)), tf.atan(y/x) + np.pi, angle)
    angle = tf.select(tf.logical_and(tf.less(x, 0.0), tf.less(y, 0.0)), tf.atan(y/x) - np.pi, angle)
    angle = tf.select(tf.logical_and(tf.equal(x, 0.0), tf.greater(y, 0.0)), 0.5*np.pi * tf.ones_like(x), angle)
    angle = tf.select(tf.logical_and(tf.equal(x, 0.0), tf.less(y, 0.0)), -0.5*np.pi * tf.ones_like(x), angle)
    angle = tf.select(tf.logical_and(tf.equal(x, 0.0), tf.equal(y, 0.0)), np.nan * tf.zeros_like(x), angle)
    return angle

def complex_arg(x):
    return atan2(tf.imag(x), tf.real(x))

It's not optimized but works well on GPU.
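
For reference, a quick numerical check of these workarounds against NumPy (a minimal sketch; it assumes the complex_abs and atan2/complex_arg definitions above are in scope):

import numpy as np
import tensorflow as tf

with tf.device('/gpu:0'):
    N = 100
    x = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))

    with tf.Session() as sess:
        # Fetch the input and both derived values in a single run so they
        # share the same random sample.
        x_, mag, ang = sess.run([x, complex_abs(x), complex_arg(x)])
        assert np.allclose(mag, np.abs(x_))
        assert np.allclose(ang, np.angle(x_))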

@ibab (Contributor) commented Aug 24, 2016

@benoitsteiner: We're having some problems with implementing the product and div ops for std::complex using the Eigen Tensor library.
Do you see why we would get an error like the following when enabling them in TensorFlow?

external/eigen_archive/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h(136): error: a value of type "int" cannot be assigned to an entity of type "_ZNSt7complexIfE9_ComplexTE"
$ c++filt _ZNSt7complexIfE9_ComplexTE
std::complex<float>::_ComplexT

Maybe we would need to switch to something like thrust::complex as @sbrodeur suggested?

@sbrodeur (Author) commented Aug 24, 2016

I made some progress! I can make the multiplication and division ops work for complex numbers if I specialize the templates in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/cwise_ops.h#L432

template <typename T>
struct mul : base<T, Eigen::internal::scalar_product_op<T> > {};

template <typename T>
struct multiply_complex {
  typedef std::complex<T> result_type;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(std::complex<T> a,
                                                               std::complex<T> b) const {
    return std::complex<T>(a.real()*b.real() - a.imag()*b.imag(),
                           a.real()*b.imag() + a.imag()*b.real());
  }
};

template <>
struct mul<std::complex<float> > : base<std::complex<float>, multiply_complex<float> > {};

template <>
struct mul<std::complex<double> > : base<std::complex<double>, multiply_complex<double> > {};

It is more of a hack, but it doesn't require changes to Eigen for now.

Not sure what is wrong with nvcc using scalar_product_op in Eigen for complex numbers:
https://github.com/RLovelett/eigen/blob/master/Eigen/src/Core/functors/BinaryFunctors.h#L76

However, it seems tightly related to using the built-in * and / operators for std::complex types.
For instance, this fails with the same errors as in the previous posts:

template <typename T>
struct mul : base<T, Eigen::internal::scalar_product_op<T> > {};

template <typename T>
struct multiply_complex {
  typedef std::complex<T> result_type;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(std::complex<T> a,
                                                               std::complex<T> b) const {
    return a*b;
  }
};

template <>
struct mul<std::complex<float> > : base<std::complex<float>, multiply_complex<float> > {};

template <>
struct mul<std::complex<double> > : base<std::complex<double>, multiply_complex<double> > {};

@sbrodeur (Author):

I can confirm that with the above trick, I can get a lot of very useful functions for complex numbers working on GPU (e.g. square, neg, div, mul, abs). This brings support for complex gradient computation on GPU:

import tensorflow as tf

if __name__ == '__main__':

    with tf.device('/gpu:0'):
        N = 100
        a = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
        b = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
        c = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))

        d = c * tf.neg(tf.square(a + b))

        grad = tf.gradients([d], [a])

        with tf.Session() as sess:
            grad = sess.run(grad)

Should I make a PR, or should we investigate nvcc's handling of std::complex further?

@ibab (Contributor) commented Aug 29, 2016

I'm not a maintainer, but I think a PR would definitely be a good idea 👍
Maybe you can split it into two PRs, one that requires the extra scalar_product_op specialization, and one that doesn't?
In the long term, it would probably be best to get the specialization into Eigen itself, or to find another fix (like switching to thrust::complex).

martinwicke removed their assignment on Sep 7, 2016.
@woodshop:

I've been privately writing GPU-based complex-valued ops for TF and decided to make my repository public. I think that more general support for computation of complex numbers on the GPU will be valuable to the community. However, since my repository is in the early stages and isn't well tested, I'd like to develop it as a separate project and then port it as a TF pull request when it's more mature. Feel free to make contributions and/or suggestions.

https://github.com/woodshop/complex_tf

@benoitsteiner (Contributor):

In C++14, std::complex methods are marked as constexpr. This ensures that they can be used inside CUDA kernels, even though they're not marked as __device__ functions, provided that we compile with the --expt-relaxed-constexpr flag (which TensorFlow has been doing for some time now).

Unfortunately, nvcc doesn't yet support C++14, but we can ask NVIDIA to start adding partial support for it, starting with complex numbers.

benoitsteiner self-assigned this on Sep 14, 2016.
@rryan (Member) commented Sep 15, 2016

@iportillo ComplexAbs (and a few others) added here: f216420
and the corresponding Eigen change:
https://bitbucket.org/eigen/eigen/commits/6d4cd6e5cdd9c750b10cc4c6a374e4c513b267ed

@rryan (Member) commented Sep 28, 2016

After adding a workaround to Eigen:
https://bitbucket.org/eigen/eigen/commits/27f6140fa81c9fe83167d87e7aeb23031b42f344

We were able to enable addition, subtraction, division, and multiplication kernels for complex types on GPU: 93f15d4
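
For anyone who wants to check a build that includes these changes, a minimal sketch (it assumes a TensorFlow build containing 93f15d4; the ops are pinned to the GPU, so a missing kernel would fail placement exactly as in the original report):

import numpy as np
import tensorflow as tf

with tf.device('/gpu:0'):
    N = 100
    a = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
    b = tf.complex(tf.random_normal((N,)), tf.random_normal((N,)))
    # The four element-wise kernels that were enabled for complex types.
    ops = [a + b, a - b, a * b, a / b]

    with tf.Session() as sess:
        a_, b_, add_, sub_, mul_, div_ = sess.run([a, b] + ops)
        assert np.allclose(add_, a_ + b_)
        assert np.allclose(sub_, a_ - b_)
        assert np.allclose(mul_, a_ * b_)
        assert np.allclose(div_, a_ / b_)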

@benoitsteiner (Contributor):

@sbrodeur Does TensorFlow now support all the operations you need on complex types, or are there additional improvements we need to make?

@sbrodeur (Author):

@benoitsteiner TensorFlow now supports everything I need for handling complex numbers.

@benoitsteiner (Contributor):

Thanks, closing the issue.

@wzm2256 commented Jul 14, 2017

Is it possible to divide a complex number by a float number without a type cast?
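
For reference, a sketch of the usual workarounds, using only standard ops (tf.cast, tf.real, tf.imag, tf.complex; it assumes the Cast kernel to complex64 is available on your device):

import tensorflow as tf

c = tf.complex(tf.random_normal((4,)), tf.random_normal((4,)))
f = tf.random_normal((4,))

# Option 1: cast the real divisor to complex64 (its imaginary part is zero).
d1 = c / tf.cast(f, tf.complex64)

# Option 2: divide the real and imaginary parts separately, staying in real
# arithmetic throughout.
d2 = tf.complex(tf.real(c) / f, tf.imag(c) / f)

with tf.Session() as sess:
    print(sess.run([d1, d2]))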
