NotFoundError when using an optimizer on complex variables #44834

Closed
davidho95 opened this issue Nov 13, 2020 · 13 comments
Labels
comp:ops OPs related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.2 Issues related to TF 2.2 type:bug Bug

Comments

@davidho95

davidho95 commented Nov 13, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes (minimal working example provided)
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux-3.10.0-1127.19.1.el7.x86_64-x86_64-with-glibc2.10
  • TensorFlow installed from (source or binary): Binary (installed using conda)
  • TensorFlow version (use command below): 2.2.0
  • Python version: 3.8.5
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 10.1/7.6.5
  • GPU model and memory: Quadro P1000 96GB

Describe the current behavior
When I attempt to optimise a loss function in complex variables, I get a NotFoundError from the apply_gradients function. The error persists for all optimisers I have tried (SGD is shown in the example). If I replace the complex variable with a float, there are no issues.

Describe the expected behavior
apply_gradients should carry out an SGD step.

Standalone code to reproduce the issue

import tensorflow as tf
import numpy as np

print(tf.config.list_physical_devices('GPU'))

# Initialise a complex matrix
mat = tf.random.uniform([1000, 1000], dtype=tf.float64)
mat = tf.complex(mat, mat)

var = tf.Variable(mat, trainable=True)

# Return the squared norm of this matrix as the loss function
def lossFn():
    return tf.math.abs(tf.linalg.trace(var @ tf.linalg.adjoint(var)))

# SGD optimizer
opt = tf.keras.optimizers.SGD(learning_rate=0.01)

numSteps = 0
while numSteps < 100:
    with tf.GradientTape() as tape:
        loss = lossFn()
    grads = tape.gradient(loss, [var])

    # This is the step that fails
    opt.apply_gradients(zip(grads, [var]))
    numSteps += 1
    print(loss.numpy())

Other info / logs

Traceback (most recent call last):
File "gpuTest.py", line 25, in
opt.apply_gradients(zip(grads, [var]))
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 504, in apply_gradients
return distribute_ctx.get_replica_context().merge_call(
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2420, in merge_call
return self._merge_call(merge_fn, args, kwargs)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2427, in _merge_call
return merge_fn(self._strategy, *args, **kwargs)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 282, in wrapper
return func(*args, **kwargs)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 591, in _distributed_apply
update_ops.extend(distribution.extended.update(
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2013, in update
return self._update(var, fn, args, kwargs, group)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2659, in _update
return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2665, in _update_non_slot
result = fn(*args, **kwargs)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 282, in wrapper
return func(*args, **kwargs)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 567, in apply_grad_to_update_var
update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/gradient_descent.py", line 143, in _resource_apply_dense
return training_ops.resource_apply_gradient_descent(
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/training/gen_training_ops.py", line 1908, in resource_apply_gradient_descent
_ops.raise_from_not_ok_status(e, name)
File "/rds/general/user/dlh16/home/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 6653, in raise_from_not_ok_status
six.raise_from(core._status_to_exception(e.code, message), None)
File "", line 3, in raise_from
tensorflow.python.framework.errors_impl.NotFoundError: No registered 'ResourceApplyGradientDescent' OpKernel for 'GPU' devices compatible with node {{node ResourceApplyGradientDescent}}
(OpKernel was found, but attributes didn't match) Requested Attributes: T=DT_COMPLEX128, use_locking=true
. Registered: device='XLA_GPU'; T in [DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16, DT_COMPLEX128, DT_HALF]
device='XLA_CPU'; T in [DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16, DT_COMPLEX128, DT_HALF]
device='XLA_CPU_JIT'; T in [DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16, DT_COMPLEX128, DT_HALF]
device='GPU'; T in [DT_DOUBLE]
device='GPU'; T in [DT_FLOAT]
device='GPU'; T in [DT_HALF]
device='CPU'; T in [DT_COMPLEX128]
device='CPU'; T in [DT_COMPLEX64]
device='CPU'; T in [DT_DOUBLE]
device='CPU'; T in [DT_FLOAT]
device='CPU'; T in [DT_BFLOAT16]
device='CPU'; T in [DT_HALF]
device='XLA_GPU_JIT'; T in [DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16, DT_COMPLEX128, DT_HALF]
[Op:ResourceApplyGradientDescent]

@Saduf2019
Contributor

@davidho95
I ran the code on tf-nightly and do not face any error; please find the gist here.

@Saduf2019 Saduf2019 added the stat:awaiting response Status - Awaiting response from author label Nov 13, 2020
@davidho95
Author

Hi Saduf,

Thank you for your reply.

When I run the notebook you link, the line print(tf.config.list_physical_devices('GPU')) outputs an empty list, and the calculation is much slower than expected on a GPU. I think the process is running on the CPU rather than the GPU; is there any way to test this?
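One way to check is to turn on TensorFlow's device-placement logging; the snippet below is a minimal sketch using standard TF debugging utilities (not code from this thread), and every op will then print the device it actually executes on.

import tensorflow as tf

# Print the device each op runs on, e.g.
# "Executing op MatMul in device /job:localhost/replica:0/task:0/device:GPU:0"
tf.debugging.set_log_device_placement(True)

print(tf.config.list_physical_devices('GPU'))  # an empty list means no GPU is visible

a = tf.random.uniform([1000, 1000])
b = tf.matmul(a, a)  # the log line shows whether this ran on the GPU or the CPU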

@Saduf2019
Contributor

@davidho95
I ran the code on GPU and it is much faster, but with the first output blank; please find the gist here.

@davidho95
Author

When I try to run the notebook, the log files show:

2020-11-13 17:18:05.423840: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia

along with other errors indicating that no GPU is being used.

@Saduf2019
Contributor

@davidho95
Please click the "Runtime" > "Change runtime" > "GPU" option and re-run.

@davidho95
Author

This still occurs even with the "Runtime" > "Change runtime" > "GPU" option set.

@davidho95
Author

davidho95 commented Nov 14, 2020

Apparently this is an issue that others have experienced with tf-nightly builds on Colab.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Nov 16, 2020
@Saduf2019 Saduf2019 added TF 2.2 Issues related to TF 2.2 comp:gpu GPU related issues comp:keras Keras related issues labels Nov 17, 2020
@Saduf2019 Saduf2019 assigned ymodak and unassigned Saduf2019 Nov 17, 2020
@ymodak ymodak added comp:ops OPs related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed comp:gpu GPU related issues comp:keras Keras related issues labels Nov 17, 2020
@rohan100jain
Member

The issue here is that we don't have good complex128 support on GPU for these kernels: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/training_ops.cc#L838

This depends on some support in Eigen (https://gitlab.com/libeigen/eigen/-/issues/1905) that hasn't been prioritized for some time, and I'm not sure how quickly we'll be able to get to it.

Is this blocking your work?
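One possible interim workaround is to keep the complex variable on the CPU, where a complex128 kernel for ResourceApplyGradientDescent is registered, while the rest of the computation can still run on the GPU. The following is a minimal sketch, assuming that resource-update ops are placed on the device holding the variable (names mirror the reproduction script above); whether the host-device copies are acceptable depends on the workload.

import tensorflow as tf

mat = tf.random.uniform([1000, 1000], dtype=tf.float64)
mat = tf.complex(mat, mat)

# Pin the variable to the CPU; the optimizer's resource-update op should then
# run there, where the complex128 kernel exists.
with tf.device('/CPU:0'):
    var = tf.Variable(mat, trainable=True)

opt = tf.keras.optimizers.SGD(learning_rate=0.01)

with tf.GradientTape() as tape:
    loss = tf.math.abs(tf.linalg.trace(var @ tf.linalg.adjoint(var)))
grads = tape.gradient(loss, [var])
opt.apply_gradients(zip(grads, [var]))  # update applied via the CPU kernel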

@davidho95
Author


Hi Rohan,

Thank you for this, it's good to know it is not a problem with my build. Initially I thought this would block my work, but I have since realised that the calculations I am doing don't benefit from GPU acceleration (they involve batched multiplication of large numbers of small matrices), so my current work no longer requires GPU support.

In the future, though, calculations with complex numbers on GPUs could be very useful: I work in theoretical physics, where complex numbers are ubiquitous. I think TensorFlow's tools are well suited to calculations in my field, and if this support were available there would be a significant potential user base.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 21, 2020
@ymodak ymodak added type:bug Bug stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed type:bug Bug labels Nov 24, 2020
@shandilya1998

Hey! I am working on complex-valued neural networks that require complex-valued backpropagation. Is there any chance this will be released anytime soon?

@xusky69

xusky69 commented Mar 14, 2021

In case this is useful to anyone, I found a workaround for this issue: define the real and imaginary parts of each weight separately and join them into a complex tensor using tf.complex() during the forward pass.

However, I'm not entirely sure how this affects the computation of the momenta of optimizers such as Adam. It worked as expected in my experiments, though.
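A minimal sketch of this workaround, reusing the loss from the reproduction script above (variable names here are illustrative): the trainable parameters are stored as two real-valued variables and combined into a complex tensor only inside the forward pass, so the optimizer only ever applies float64 updates, for which GPU kernels are registered.

import tensorflow as tf

# Trainable parameters kept as real-valued (float64) variables.
init = tf.random.uniform([1000, 1000], dtype=tf.float64)
var_re = tf.Variable(init, trainable=True)
var_im = tf.Variable(init, trainable=True)

def lossFn():
    # Build the complex tensor only during the forward pass.
    var = tf.complex(var_re, var_im)
    return tf.math.abs(tf.linalg.trace(var @ tf.linalg.adjoint(var)))

opt = tf.keras.optimizers.SGD(learning_rate=0.01)

for _ in range(100):
    with tf.GradientTape() as tape:
        loss = lossFn()
    # Gradients with respect to the real variables are float64, so the
    # GPU float64 optimizer kernels can be used.
    grads = tape.gradient(loss, [var_re, var_im])
    opt.apply_gradients(zip(grads, [var_re, var_im]))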

@sachinprasadhs
Contributor

I was able to run the code without any error on TensorFlow GPU 2.5; please find the gist here.
Closing the issue since it is resolved in the latest version; feel free to reopen if needed. Thanks!

@google-ml-butler

Are you satisfied with the resolution of your issue?
