tf.cond not working with dependencies #2062

Closed
alex04309 opened this issue Apr 22, 2016 · 2 comments

alex04309 commented Apr 22, 2016

tf.cond seems to have a bug when one of the branches has a control dependency: the dependency is run whether the tf.cond predicate is True or False.

To illustrate:

```python
import tensorflow as tf

a = tf.Variable(0)
incr = a.count_up_to(1)

def todo_if_true():
  with tf.control_dependencies([incr]):
    return tf.identity(a)
def todo_if_false():
  return tf.identity(a)

g = tf.cond(tf.constant(False), todo_if_true, todo_if_false)
init = tf.initialize_all_variables()

with tf.Session() as sess:
  sess.run(init)
  print(sess.run(g))
```

Output:

1 (but it should be 0)
prb12 (member) commented Apr 22, 2016

@yuanbyu - any ideas?

yuanbyu (contributor) commented Apr 22, 2016

You need to move the count_up_to op inside the conditional branch in which you want it to be executed.

```python
import tensorflow as tf

a = tf.Variable(0)

def todo_if_true():
  # Created inside the branch, so it only runs when this branch is taken.
  incr = a.count_up_to(1)
  with tf.control_dependencies([incr]):
    return tf.identity(a)
def todo_if_false():
  return tf.identity(a)

g = tf.cond(tf.constant(False), todo_if_true, todo_if_false)
```
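To see why moving the op matters, here is a toy model of the dataflow behaviour discussed in this thread, written in plain Python rather than TensorFlow. `Node`, `run`, and `build` are hypothetical helpers, not TensorFlow APIs: nodes created inside a branch are tagged with that branch and skipped when it is not taken, while untagged nodes run whenever the fetched value depends on them through either branch. The `merge` node stands in for the point where the two branch outputs are joined. One assumption to keep in mind: the toy runs nodes in a fixed dependency-first order, whereas TensorFlow's executor gives no ordering guarantee between the increment and the unrelated read in the false branch.

```python
from typing import Callable, List, Optional


class Node:
    """A toy graph op: a Python callable plus its data and control dependencies."""

    def __init__(self, name: str, fn: Callable, inputs=(), control_deps=(),
                 branch: Optional[bool] = None):
        self.name = name
        self.fn = fn
        self.inputs = list(inputs)
        self.control_deps = list(control_deps)
        # None: created outside the cond; True/False: created inside that branch.
        self.branch = branch


def reachable(fetch: Node) -> List[Node]:
    """Every node `fetch` depends on, dependencies first (post-order DFS)."""
    order, seen = [], set()

    def visit(node: Node) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        for dep in node.inputs + node.control_deps:
            visit(dep)
        order.append(node)

    visit(fetch)
    return order


def run(fetch: Node, pred: bool, log: List[str]):
    """Run each reachable node once; skip nodes inside the branch not taken."""
    values = {}
    for node in reachable(fetch):
        if node.branch is not None and node.branch != pred:
            continue  # "dead": this node belongs to the untaken branch
        # Only live inputs are passed, so merge sees just the taken branch.
        args = [values[id(i)] for i in node.inputs if id(i) in values]
        values[id(node)] = node.fn(*args)
        log.append(node.name)
    return values[id(fetch)]


def build(incr_inside_true_branch: bool) -> Node:
    """Toy version of the graph from this issue; returns the fetched node."""
    state = {"a": 0}  # stands in for tf.Variable(0)

    def count_up_to_1():
        if state["a"] >= 1:
            raise RuntimeError("count_up_to limit reached")
        state["a"] += 1
        return state["a"] - 1  # value before the increment

    incr = Node("count_up_to", count_up_to_1,
                branch=True if incr_inside_true_branch else None)
    true_out = Node("identity_true", lambda: state["a"],
                    control_deps=[incr], branch=True)
    false_out = Node("identity_false", lambda: state["a"], branch=False)
    # The merge node forwards whichever branch output is live.
    return Node("merge", lambda v: v, inputs=[true_out, false_out])


log_outside: List[str] = []
print(run(build(False), pred=False, log=log_outside), log_outside)
# 1 ['count_up_to', 'identity_false', 'merge']
log_inside: List[str] = []
print(run(build(True), pred=False, log=log_inside), log_inside)
# 0 ['identity_false', 'merge']
```

With the increment created outside the branch functions, `count_up_to` is reachable through the true branch's control dependency and runs even though the predicate is False, so the fetched value is 1. Tagging it as part of the true branch, as in the fix above, leaves it unexecuted and the value stays 0.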

I have added the following paragraph to the doc:

Note that the conditional execution applies only to the operations defined in
fn1 and fn2. Consider the following simple program:

  z = tf.mul(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

If x < y, the tf.add operation will be executed and tf.square
operation will not be executed. Since z is needed for at least one
branch of the cond, the tf.mul operation is always executed, unconditionally.
Although this behavior is consistent with the dataflow model of TensorFlow,
it has occasionally surprised some users who expected a lazier semantics.
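The same distinction can be illustrated with a plain-Python analogy (not TensorFlow code). `mul`, `add`, `square`, and `cond` are hypothetical stand-ins that record when they run, and this `cond` calls only the selected branch, so it models execution rather than graph construction. Computing `z` before the conditional corresponds to creating tf.mul outside fn1 and fn2; the lazier variant corresponds to creating it inside the lambda that uses it, e.g. `lambda: tf.add(x, tf.mul(a, b))`.

```python
calls = []  # records which "ops" actually ran


def mul(a, b):
    calls.append("mul")
    return a * b


def add(a, b):
    calls.append("add")
    return a + b


def square(v):
    calls.append("square")
    return v * v


def cond(pred, fn1, fn2):
    """Toy cond: calls only the selected branch function."""
    return fn1() if pred else fn2()


a, b, x, y = 2, 3, 5, 4  # x < y is False, so the square branch is taken

# Like the documented example: z is computed first, so mul always runs.
z = mul(a, b)
eager = cond(x < y, lambda: add(x, z), lambda: square(y))
eager_calls = list(calls)

# Lazier variant: mul is moved inside the only branch that needs it.
calls.clear()
lazy = cond(x < y, lambda: add(x, mul(a, b)), lambda: square(y))
lazy_calls = list(calls)

print(eager, eager_calls)  # 16 ['mul', 'square']
print(lazy, lazy_calls)    # 16 ['square']
```

Both variants return the same value; they differ only in whether the multiplication is paid for when its branch is not taken.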
