Interaction between tf.map_fn and tf.gradients #7643

Closed
daniellevy opened this issue Feb 18, 2017 · 5 comments
Labels
stat:awaiting response Status - Awaiting response from author type:bug Bug

Comments

@daniellevy

Hi,

I am using TensorFlow v0.11 and have tried this on Mac OS X and CentOS 6.

I am running into an error when running the following code:

import tensorflow as tf

W = tf.get_variable('W', (5, 3))
x = tf.placeholder(tf.float32, shape=(None, 5))
h = tf.matmul(x, W)
# One gradient of W per row of h (this is the line that fails)
grads = tf.map_fn(lambda x: tf.gradients(x, W)[0], h)

Basically, I want the following, but without a fixed batch size:
grads = [tf.gradients(h[t], W)[0] for t in range(batch_size)]
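For the linear model in this example, the content of each tf.gradients(h[t], W) term can be worked out by hand, which may be useful as a reference when checking a workaround. Because tf.gradients differentiates the sum of its ys, the per-row gradient is as follows (a sketch assuming x_t denotes row t of x as a 1×5 row vector and \mathbf{1}_3 a column of three ones; these symbols are not from the thread):

```latex
h_{tj} = \sum_{i=1}^{5} x_{ti} W_{ij}, \qquad
G_t = \frac{\partial}{\partial W} \sum_{j=1}^{3} h_{tj}
\;\Longrightarrow\;
(G_t)_{ij} = x_{ti}, \qquad
G_t = x_t^{\top} \mathbf{1}_3^{\top} \in \mathbb{R}^{5 \times 3}
```

So every column of G_t is a copy of x_t. This shortcut only holds for h = xW and is not a general replacement for computing per-example gradients.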

My error is:

Invalid argument: TensorArray map/TensorArray_1@map/while/gradients: Could not write to TensorArray index 3 because it has already been read.
[...]
tensorflow.python.framework.errors.InvalidArgumentError: TensorArray map/TensorArray_1@map/while/gradients: Could not write to TensorArray index 3 because it has already been read.
	 [[Node: map/while/gradients/map/while/TensorArrayRead_grad/TensorArrayWrite = TensorArrayWrite[T=DT_FLOAT, _class=["loc:@map/TensorArray"], _device="/job:localhost/replica:0/task:0/cpu:0"](map/while/gradients/map/while/TensorArrayRead_grad/TensorArrayGrad/TensorArrayGrad, map/while/Identity, map/while/gradients/Fill, map/while/gradients/map/while/TensorArrayRead_grad/TensorArrayGrad/gradient_flow)]]

I also tried a workaround using tf.scan instead of tf.map_fn, with a zero initializer, but to no avail:

import numpy as np

initializer = np.zeros((5, 3)).astype('float32')
grads = tf.scan(
    lambda a, x: tf.gradients(x, W)[0],
    h,
    initializer)

Is this a known issue?

@aselle
Contributor

aselle commented Feb 18, 2017

@yuanbyu, do you have any ideas? @daniellevy, could you please try 1.0 and see if this is still a problem?

@aselle aselle added stat:awaiting response Status - Awaiting response from author type:bug Bug labels Feb 18, 2017
@daniellevy
Author

Yes, the problem still occurs with v1.0, with the same error message.

@aselle aselle removed the stat:awaiting response Status - Awaiting response from author label Feb 19, 2017
@aselle
Contributor

aselle commented Feb 19, 2017

This looks like a known issue, and probably a duplicate of #3972, where @yuanbyu suggests using tf.while_loop instead of tf.map_fn. Give that a try and let us know whether it works. Thanks!
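As a rough mental model (a pure-Python sketch, not TensorFlow code; the while_loop, cond and body below are hypothetical stand-ins), tf.while_loop(cond, body, loop_vars) keeps calling body while cond holds, and body receives and returns every loop variable. In graph mode each loop variable must also keep the same structure (and, unless shape invariants are given, the same shape) across iterations, which is why per-row results are usually carried in fixed-size containers rather than appended to.

```python
def while_loop(cond, body, loop_vars):
    """Pure-Python model: apply body to all loop variables while cond holds.

    Unlike TensorFlow, nothing here checks that structures/shapes stay fixed.
    """
    loop_vars = tuple(loop_vars)
    while cond(*loop_vars):
        loop_vars = tuple(body(*loop_vars))
    return loop_vars


# Example: iterate over a batch whose size is only known when the loop runs,
# carrying a row index t and a fixed-structure running total.
rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]


def cond(t, total):
    return t < len(rows)


def body(t, total):
    return t + 1, total + sum(rows[t])


final_t, total = while_loop(cond, body, (0, 0.0))
```

Here final_t ends at 3 (the dynamic batch size) and total at 21.0.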

@aselle aselle added stat:awaiting tensorflower Status - Awaiting response from tensorflower stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Feb 19, 2017
@dillonalaird

I think this works:

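import numpy as np
import tensorflow as tf

max_seq_len = 10  # upper bound on the batch size
x = tf.placeholder(tf.float32, [None, 5])
W = tf.get_variable("W", [5, 3])
h = tf.matmul(x, W)

def body(old_g, t):
    # Gradient of row t of h with respect to W
    g = tf.gradients([h[t]], [W])[0]
    # Rebuild the whole tuple, replacing only slot t with the new gradient
    # (tf.select was renamed tf.where in TF 1.0)
    new_g = tuple(tf.where(tf.equal(ti, t), g, old_g[ti]) for ti in range(len(old_g)))
    return new_g, t + 1

def cond(_, t):
    # Loop over the dynamic batch dimension
    return tf.less(t, tf.shape(h)[0])

grads = tf.while_loop(cond, body, [(tf.zeros_like(W),) * max_seq_len, tf.constant(0)])

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    grads_out = sess.run(grads, feed_dict={x: np.random.randn(2 * 5).reshape(2, 5)})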

You basically build a large tuple of zeros, one slot per possible row (up to max_seq_len), and fill it in as you go. dynamic_rnn does something similar to this.
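The fill-as-you-go idea can be sketched without TensorFlow. This is a pure-Python model with a hypothetical fill_slots helper: tuples stand in for graph tensors, and a conditional expression stands in for tf.where/tf.select. Each iteration rebuilds all max_len slots and replaces only slot t, so unused slots keep their zero value and, under this model, rows beyond max_len are silently dropped rather than raising an error.

```python
def fill_slots(values, max_len, zero):
    """Model of the while_loop body above, one iteration per value.

    Every iteration rebuilds all max_len slots, keeping the old entry except
    at index t. If len(values) > max_len, no slot index ever equals t for the
    extra rows, so they are discarded.
    """
    slots = (zero,) * max_len
    t = 0
    while t < len(values):
        new = values[t]
        slots = tuple(new if ti == t else slots[ti] for ti in range(max_len))
        t += 1
    return slots


# Two rows into four slots: the trailing slots stay at the zero value,
# so a caller would slice the result to the actual batch size.
filled = fill_slots(["g0", "g1"], max_len=4, zero=None)

# Three rows into two slots: the third row is lost.
truncated = fill_slots([1, 2, 3], max_len=2, zero=0)
```

Rebuilding the entire tuple on every step keeps the loop variable's structure fixed, at the cost of O(max_len) select operations per iteration.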

@aselle
Contributor

aselle commented Mar 3, 2017

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!


3 participants