Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eager scatter_nd forward works with incorrect code #18648

Closed
mbosnjak opened this issue Apr 18, 2018 · 3 comments
Closed

eager scatter_nd forward works with incorrect code #18648

mbosnjak opened this issue Apr 18, 2018 · 3 comments
Assignees

Comments

@mbosnjak
Copy link

mbosnjak commented Apr 18, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
    code, see below
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    OSX 10.12.6
  • TensorFlow installed from (source or binary):
    binary
  • TensorFlow version (use command below):
    v1.7.0-3-g024aecf414 1.7.0
  • Python version:
    3.6.1
  • Bazel version (if compiling from source):
    N/A
  • GCC/Compiler version (if compiling from source):
    N/A
  • CUDA/cuDNN version:
    N/A
  • GPU model and memory:
    N/A
  • Exact command to reproduce:
    see code below

Source code / logs

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

x_var = tfe.Variable(tf.random_uniform([10]))
y_var = tfe.Variable(tf.random_uniform([10]))

with tfe.GradientTape() as tape:
    dot = tf.scatter_nd(indices=[0],  # this should be indices=[[0]]
                        updates=[tf.einsum('i,i->', x_var, y_var)],
                        shape=[1])
print(dot)
gradient = tape.gradient(dot, [x_var, y_var])
print(gradient)

Error traceback:

Traceback (most recent call last):
  File "/Users/m/workspace/bug/experimental/bug.py", line 31, in <module>
    gradient = tape.gradient(dot, [x_var, y_var])
  File "/Users/m/workspace/bug/venv/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py", line 764, in gradient
    output_gradients=output_gradients)
  File "/Users/m/workspace/bug/venv/lib/python3.6/site-packages/tensorflow/python/eager/imperative_grad.py", line 65, in imperative_grad
    tape._tape, vspace, target, sources, output_gradients, status)  # pylint: disable=protected-access
  File "/Users/m/workspace/bug/venv/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py", line 141, in grad_fn
    op_inputs, op_outputs, orig_outputs)
  File "/Users/m/workspace/bug/venv/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py", line 109, in _magic_gradient_function
    return grad_fn(mock_op, *out_grads)
  File "/Users/m/workspace/bug/venv/lib/python3.6/site-packages/tensorflow/python/ops/array_grad.py", line 39, in _PackGrad
    return array_ops.unstack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))
  File "/Users/m/workspace/bug/venv/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1084, in unstack
    return gen_array_ops.unpack(value, num=num, axis=axis, name=name)
  File "/Users/m/workspace/bug/venv/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 8741, in unpack
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: axis = 0 not in [0, 0) [Op:Unpack] name: unstack

Describe the problem

This is erroneous code which runs half-way. It will still calculate the forward pass, but fail on the backward pass. This does not happen with static graph tensorflow, where you get a correct error that tensor shapes are not matching.

@mbosnjak mbosnjak changed the title eager works half-way with uncorrect code in scatter_nd eager scatter_nd forward works with incorrect code Apr 18, 2018
@tensorflowbutler tensorflowbutler added the stat:awaiting response Status - Awaiting response from author label Apr 18, 2018
@tensorflowbutler
Copy link
Member

Thank you for your post. We noticed you have not filled out the following field in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.
Bazel version

@skye skye assigned alextp and unassigned skye Apr 18, 2018
@tensorflowbutler
Copy link
Member

It has been 14 days with no activity and the awaiting response label was assigned. Is this still an issue?

@mbosnjak
Copy link
Author

mbosnjak commented May 4, 2018

Checked tf1.8, it's still there.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label May 5, 2018
@yifeif yifeif closed this as completed in 0775f68 May 15, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants