Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Computing gradients within tf.while_loop #14101

Closed
pramodkaushik opened this issue Oct 30, 2017 · 11 comments
Closed

Computing gradients within tf.while_loop #14101

pramodkaushik opened this issue Oct 30, 2017 · 11 comments
Assignees
Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower

Comments

@pramodkaushik
Copy link

I am posting this here because a similar question on stackoverflow is still unanswered, so I suspect it might be a bug.

Adding gradient ops within a tf.while_loop for computing gradients of loop variables w.r.t external variables results in an error.

Program reproducing the error:

import numpy as np
import tensorflow as tf
tf.reset_default_graph()
F = lambda x: tf.cumsum(x)
G = lambda x: x[-1]
H = lambda x: x
encoder_emb_inp = tf.placeholder(dtype=tf.float32, shape=[4])
encoder_outputs = F(encoder_emb_inp)
decoder_initial_state = G(encoder_outputs)
decoder_initial_output = H(decoder_initial_state)
def cond(time, unused_state, unused_output):
    return tf.less(time, 3)

def body(time, state, inputs):
    step = lambda s, i: (tf.multiply(s,s), tf.multiply(s,i))
    (next_state, next_output) = step(state, inputs)
    next_grads = tf.gradients(next_output, decoder_initial_state)
    tf.Print(next_grads, next_grads)
    return (time + 1, next_state, next_output)

initial_time = tf.constant(0, dtype=tf.int32)


final_time, final_state, final_outputs = tf.while_loop(cond, body, loop_vars = [initial_time, decoder_initial_state, decoder_initial_output])

Error message:

<ipython-input-126-5205901211cc> in body(time, state, inputs)
      6     (next_state, next_output) = step(state, inputs)
      7 #    next_grads = tf.gradients(next_output, state)
----> 8     next_grads = tf.gradients(next_output, decoder_initial_state)
      9     tf.Print(next_grads, next_grads)
     10     return (time + 1, next_state, next_output)

/home/pramodkm/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.pyc in gradients(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients)
    591                 out_grads[i] = loop_state.ZerosLike(op, i)
    592               else:
--> 593                 out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
    594           with ops.name_scope(op.name + "_grad"):
    595             # pylint: disable=protected-access

/home/pramodkm/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in ZerosLikeOutsideLoop(op, index)
   1342     if op_ctxt:
   1343       # We are in a cond context. Use a switch to create zeros only when needed.
-> 1344       pred = op_ctxt.pred
   1345       branch = op_ctxt.branch
   1346       switch_val = switch(op.inputs[0], pred)[1 - branch]

AttributeError: 'WhileContext' object has no attribute 'pred'

I am using tf-nightly-gpu 1.5.0-dev20171026

Thanks!

@tensorflowbutler
Copy link
Member

It has been 14 days with no activity and this issue has an assignee.Please update the label and/or status accordingly.

@skye
Copy link
Member

skye commented Jan 2, 2018

No update at this time. I'll try to take a look soon but have a lot of other stuff on my plate currently.

@tensorflowbutler
Copy link
Member

Nagging Assignee: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@mholzel
Copy link
Contributor

mholzel commented Jan 23, 2018

It seems like this is the same as #9450.

@tensorflowbutler
Copy link
Member

Nagging Assignee: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

1 similar comment
@tensorflowbutler
Copy link
Member

Nagging Assignee: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@tensorflowbutler
Copy link
Member

Nagging Assignee @skye: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@skye skye added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 23, 2018
@skye
Copy link
Member

skye commented Mar 31, 2018

Hey, I finally took a look at this. Unfortunately I believe this is impossible to compute without a major change to how we compute gradients :( I'll at least work on providing a better error message though.

@pramodkaushik
Copy link
Author

Thanks!

@tensorflowbutler
Copy link
Member

Nagging Assignee @skye: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

1 similar comment
@tensorflowbutler
Copy link
Member

Nagging Assignee @skye: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@skye skye closed this as completed May 1, 2018
@skye skye reopened this May 1, 2018
@yifeif yifeif closed this as completed in 482ed8e May 9, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower
Projects
None yet
Development

No branches or pull requests

4 participants