
Updating ExponentialMovingAverage based on a condition #1724

Closed
bayerj opened this issue Mar 31, 2016 · 7 comments

@bayerj commented Mar 31, 2016

Using TensorFlow 0.7.1.

I followed issue #804 to use batch normalization with TensorFlow. However, I could not get the snippet there to work as expected, and I have narrowed the problem down to something more concise.

Let us assume we want to maintain a moving average over incoming scalars; sometimes we want to update the statistics of that moving average, and sometimes we don't. We model that with a placeholder do_update, which we can set to True or False in the feed_dict passed to sess.run(...). This is basically the code from #804, greatly simplified.

import tensorflow as tf
import numpy as np

inpt = tf.Variable(np.array([1.]))
do_update = tf.placeholder(tf.bool)

ema = tf.train.ExponentialMovingAverage(.9)
ema_assign = ema.apply([inpt])


def update():
    with tf.control_dependencies([ema_assign]):
        return tf.identity(ema.average(inpt))      # note the identity.

def no_update():
    return ema.average(inpt)

run = tf.python.control_flow_ops.cond(do_update, update, no_update)

However, when I execute run, the update always happens, no matter what value do_update has.

print run.eval({inpt: np.array([2.]), do_update: True})
print run.eval({inpt: np.array([2.]), do_update: True})
print run.eval({inpt: np.array([2.]), do_update: True})

# prints:
# [ 1.10000002]
# [ 1.19000004]
# [ 1.27100006]

print run.eval({inpt: np.array([2.]), do_update: False})
print run.eval({inpt: np.array([2.]), do_update: False})
print run.eval({inpt: np.array([2.]), do_update: False})

# prints:
# [ 1.34390007]
# [ 1.40951008]
# [ 1.46855908]

Curiously, if I remove the tf.identity above in the definition of update, neither branch performs an update after starting a new session.
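
Spelled out, that variant of update would look like this (a sketch of the change just described, not additional code from the report):

def update():
    with tf.control_dependencies([ema_assign]):
        # ema.average(inpt) just returns the already-existing shadow variable,
        # so no new op is created inside this block and the dependency on
        # ema_assign never attaches to anything.
        return ema.average(inpt)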

print run.eval({inpt: np.array([2.]), do_update: True})
print run.eval({inpt: np.array([2.]), do_update: True})
print run.eval({inpt: np.array([2.]), do_update: True})

# prints:
# [ 1.]
# [ 1.]
# [ 1.]

print run.eval({inpt: np.array([2.]), do_update: False})
print run.eval({inpt: np.array([2.]), do_update: False})
print run.eval({inpt: np.array([2.]), do_update: False})

# prints:
# [ 1.]
# [ 1.]
# [ 1.]

This seems like unintended behaviour to me, but maybe I am missing something.

@rdipietro (Contributor)

Reproduced this. I'm compiling from source, commit b4b276e.

Even more confusing to me:

import tensorflow as tf
import numpy as np

x = tf.Variable(1.)
update_x = tf.assign_add(x, 1.0)

do_update = tf.placeholder(tf.bool)

ema = tf.train.ExponentialMovingAverage(.9)
ema_assign = ema.apply([x])

avg_without_update = ema.average(x)

with tf.control_dependencies([ema_assign]):
    avg_with_update = tf.identity(avg_without_update)

avg = tf.python.control_flow_ops.cond(do_update, lambda: avg_with_update, lambda: avg_without_update)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(sess.run([update_x, avg], {do_update: False}))
    print(sess.run([update_x, avg], {do_update: False}))
    print(sess.run([update_x, avg], {do_update: False}))
    print(sess.run([update_x, avg], {do_update: False}))
    print(sess.run([update_x, avg], {do_update: False}))

results in

[2.0, 1.0]
[3.0, 1.1]
[4.0, 1.3900001]
[5.0, 1.7510002]
[6.0, 2.0759003]

@mikowals (Contributor) commented Apr 1, 2016

I think dependencies set in a tf.cond() branch function always get executed. For example:

dummy = tf.Print(1, ['dummy'])

def true_fn():
  with tf.control_dependencies([dummy]):
    return tf.Print(True,['in true_fn'])

def false_fn():
  return tf.Print(False,['in false_fn'])

run = tf.python.control_flow_ops.cond(tf.constant(False), true_fn, false_fn)

with tf.Session() as sess:
  print run.eval()

will output:

I tensorflow/core/kernels/logging_ops.cc:79] [in false_fn]
I tensorflow/core/kernels/logging_ops.cc:79] [dummy]
False
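
For contrast, here is a minimal sketch (assuming the same TF 0.7-era API; not part of the original comment) in which dummy is created inside true_fn. Ops created inside the branch function are gated by the condition, so only the false_fn line should be printed here:

import tensorflow as tf

def true_fn():
  # dummy is now created inside the branch function, so it is gated by the
  # condition and only runs when the condition is True.
  dummy = tf.Print(1, ['dummy'])
  with tf.control_dependencies([dummy]):
    return tf.Print(True, ['in true_fn'])

def false_fn():
  return tf.Print(False, ['in false_fn'])

run = tf.python.control_flow_ops.cond(tf.constant(False), true_fn, false_fn)

with tf.Session() as sess:
  print run.eval()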

@rdipietro (Contributor)

Interesting. This might imply that the sequence_length complications in python.ops.rnn._rnn_step are actually saving no computation time.

@aymericdamien (Contributor)

Is there any update regarding this issue?

@mufan-li commented Apr 24, 2016

@mikowals is correct: tf.cond() seems to execute both fn1 and fn2 regardless of the condition. I have figured out a workaround that does not use tf.train.ExponentialMovingAverage:

import tensorflow as tf
import numpy as np

inpt = tf.Variable(np.array([1.0]))
do_update = tf.placeholder(tf.bool)

prev_ema = tf.Variable(np.array([0.0]))
decay = 0.9

new_ema = (1-decay)*inpt + decay*prev_ema

ema = tf.cond(do_update, lambda: new_ema, lambda: prev_ema)

assign_op = prev_ema.assign(ema)

with tf.control_dependencies([assign_op]):
    cur_ema = tf.identity(prev_ema)

Notice that the update is now an assign_op that executes every time, but with a value that depends on the condition, so the side effect no longer needs to be gated by tf.cond at all. Now I can execute the following with the desired behavior:

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(cur_ema, {inpt:np.array([2.]), do_update: True})
    print prev_ema.eval()
    sess.run(cur_ema, {inpt:np.array([2.]), do_update: True})
    print prev_ema.eval()
    sess.run(cur_ema, {inpt:np.array([2.]), do_update: True})
    print prev_ema.eval()

    sess.run(cur_ema, {inpt:np.array([2.]), do_update: False})
    print prev_ema.eval()
    sess.run(cur_ema, {inpt:np.array([2.]), do_update: False})
    print prev_ema.eval()
    sess.run(cur_ema, {inpt:np.array([2.]), do_update: False})
    print prev_ema.eval()

    sess.run(cur_ema, {inpt:np.array([2.]), do_update: True})
    print prev_ema.eval()
    sess.run(cur_ema, {inpt:np.array([2.]), do_update: True})
    print prev_ema.eval()

This results in:

[ 0.2]
[ 0.38]
[ 0.542]
[ 0.542]
[ 0.542]
[ 0.542]
[ 0.6878]
[ 0.81902]

@mikowals (Contributor)

There was an explanation / workaround posted to issue #2062 a couple of days ago: the dependency node needs to be created inside the conditional function.

So the EMA update from this issue works as expected if the condition is written like this:

import tensorflow as tf
import numpy as np

inpt = tf.Variable(np.array([1.]))
do_update = tf.placeholder(tf.bool)

ema = tf.train.ExponentialMovingAverage(.9)

def update():
    # ema.apply() is now created inside the branch function, so the assign
    # op is gated by the condition.
    ema_assign = ema.apply([inpt])
    with tf.control_dependencies([ema_assign]):
        return tf.identity(ema.average(inpt))      # note the identity.

def no_update():
    return ema.average(inpt)

run = tf.python.control_flow_ops.cond(do_update, update, no_update)

@yuanbyu (Contributor) commented Apr 25, 2016

Thanks for adding a reference to issue #2062! I meant to do that but forgot.
