
Incorrect behavior from tf.layers.batch_normalization() when training=0 #10118

Closed
somewacko opened this issue May 22, 2017 · 8 comments

@somewacko

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16
  • TensorFlow installed from (source or binary): pip
  • TensorFlow version (use command below): ('v1.1.0-rc0-61-g1ec6ed5', '1.1.0')
  • Bazel version (if compiling from source): N/A
  • CUDA/cuDNN version: 8/5
  • GPU model and memory: Nvidia Titan X
  • Exact command to reproduce: gist

Describe the problem

I've noticed that tf.layers.batch_normalization doesn't seem to give reasonable results when training=0 (i.e. normalize with the accumulated population statistics instead of the current batch's statistics), especially if you apply BN before activations (e.g. ResNet-like architectures).

Using the Gist above, if you try to fit a model to noise with SGD (lr=0.01) using repeated applications of dense matrix multiplication -> batch normalization -> ReLU activations, you get this loss for the same inputs over time: (blue: training=1, green: training=0)

[image: SGD loss curves, blue: training=1, green: training=0]

Using an Adam (lr=0.001) optimizer instead gets even weirder results:

[image: Adam loss curves]

However, if I use my own implementation of batch norm (included in gist) I get reasonable results, with the loss for each being similar to each other: (Adam has similar behavior)

[image: loss curves with the custom batch norm implementation]

(Interestingly, this doesn't seem to be as much of a problem if you apply ReLU before BN; I haven't thought too deeply about why.)

Am I seeing things and just have some misunderstanding about what that function is doing, or is this actually a bug?
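For context on what training=0 relies on: batch normalization maintains exponential moving averages of the batch mean and variance, and inference mode normalizes with those averages. If the ops that update them never run, the averages stay at their initial values (mean 0, variance 1), so inference normalizes with the wrong statistics. A pure-Python sketch (hypothetical, not TensorFlow code) of the moving-average update that the UPDATE_OPS collection performs:

```python
momentum = 0.99  # default momentum of tf.layers.batch_normalization

def update(moving, batch_stat):
    # The update performed by the ops in tf.GraphKeys.UPDATE_OPS:
    # moving <- moving * momentum + batch_stat * (1 - momentum)
    return moving * momentum + batch_stat * (1.0 - momentum)

moving_mean, moving_var = 0.0, 1.0  # initial values of the moving stats

# Suppose the activations in every batch have mean ~5 and variance ~4.
for _ in range(1000):
    moving_mean = update(moving_mean, 5.0)
    moving_var = update(moving_var, 4.0)

print(round(moving_mean, 2), round(moving_var, 2))  # converges toward 5.0 4.0
```

If the update ops are never run, the moving stats stay at (0, 1) no matter how long you train, which would produce exactly this kind of divergence between training=1 and training=0.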

@ppwwyyxx
Contributor

ppwwyyxx commented May 23, 2017

The documentation for tf.layers.batch_normalization says this:

Note: the operations which update the moving_mean and moving_variance variables will not be added as dependencies of your training operation and so must be run separately. For example:
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
sess.run([train_op, extra_update_ops], ...)

But you didn't use it in your code.

@somewacko
Author

Where did you find that? The API reference says nothing about needing to run extra operations or adding extra operations to UPDATE_OPS.

The entire description:

Defined in tensorflow/python/layers/normalization.py.

Functional interface for the batch normalization layer.

Reference: http://arxiv.org/abs/1502.03167

"Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift"

Sergey Ioffe, Christian Szegedy

@ppwwyyxx
Contributor

Sorry, I was reading an old version of the doc. The way to run the update may have changed.

@ppwwyyxx
Contributor

No, I was actually reading a newer version of the doc. See here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/layers/normalization.py#L338

@somewacko
Author

Oh wow, surprised that's not on the website... I'll close this then, since it looks like the change to the doc has already been made.

@3rd3

3rd3 commented Jun 16, 2017

Just for the sake of completeness, here is the recommended code from the docs:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

@akssieg

akssieg commented Oct 25, 2017

I am using
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
then
sess.run(train_step, update_ops, feed_dict={x: batch[0], y_: batch[1], istrain: True})

but it is showing an error:
TypeError: run() got multiple values for argument 'feed_dict'

If I don't include update_ops in the sess.run() then it works fine, but I need to create the dependency.
Could anyone tell me what's wrong?

@somewacko
Author

@akssieg 1. You should make update_ops a dependency as mentioned above, and 2. you're not calling sess.run() correctly – all of your fetches/ops should be a single list/dict, not multiple arguments.

Try using:

sess.run([train_step, update_ops], feed_dict={...})

(notice the square brackets)
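To spell out why that TypeError occurs, here is a plain-Python sketch (a stand-in, not TensorFlow itself): Session.run's signature is roughly run(fetches, feed_dict=None, ...), so passing two positional arguments and feed_dict= as a keyword binds feed_dict twice.

```python
def run(fetches, feed_dict=None):
    # Stand-in for tf.Session.run: one fetches argument, then feed_dict.
    return fetches, feed_dict

try:
    # train_step fills `fetches`, update_ops fills `feed_dict` positionally,
    # then the feed_dict= keyword collides with it.
    run("train_step", "update_ops", feed_dict={"x": 1})
except TypeError as err:
    print(err)  # run() got multiple values for argument 'feed_dict'

# Correct: pass both ops as a single fetches list, then feed_dict.
fetches, feed = run(["train_step", "update_ops"], feed_dict={"x": 1})
```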
