
Update tf.contrib.layers.batch_norm() docs #4361

Closed
b3nk4n opened this issue Sep 13, 2016 · 16 comments
Labels
stat:awaiting response · type:bug

Comments

@b3nk4n

b3nk4n commented Sep 13, 2016

TensorFlow version that I use: 0.10 (pip package)


I have made heavy use of tf.contrib.layers.batch_norm() over the last few weeks.

After facing some problems with how to use it correctly, I noticed that many other developers out there are confused as well.

I would suggest the following improvements to make things clearer:

1) Update example in doc-string:

The example in the doc-string says that if we leave updates_collections at its default, we have to include this:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    updates = tf.group(update_ops)
    total_loss = control_flow_ops.with_dependencies([updates], total_loss)

But this does not actually work (it throws errors); it seems to be outdated. Only a couple of tiny changes are needed. I would suggest updating the docs as follows:

from tensorflow.python import control_flow_ops

# the moving-average update ops registered by batch_norm
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    updates = tf.tuple(update_ops)
    # make total_loss depend on the updates so they run on every training step
    total_loss = control_flow_ops.with_dependencies(updates, total_loss)

As a side question: why do we apply the dependency to total_loss, and not to the train_op directly, as described in the doc-string text? Adding the dependency to total_loss works, but grouping the updates with the train_op would make the example clearer in my opinion, because batch-statistic updates should only happen during training.
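For example, something along these lines would make the training-only nature explicit (just a sketch; optimizer and total_loss stand for whatever the model defines):

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# gradient step for the model loss
train_step = optimizer.minimize(total_loss)
# run the moving-average updates together with the gradient step
train_op = tf.group(train_step, *update_ops)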

2) UPDATE_OPS in combination with reuse varscope:

This is related to the question above. Let's say we have a model which reuses a convolutional encoder (and also its batch-norm layers) several times. Even when we reuse these layers, the update operation for the batch statistics is added to UPDATE_OPS every time. Personally, I'm not sure if this is a bug or if this is really intended.
Or is it required to filter the update ops after collecting them with update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS), so that each one is executed just once?

To sum this up: am I wrong that lines 213-215 should not be executed when reuse=True? That is, changing them to:

if not reuse:
    # Collect the updates to be computed later.
    ops.add_to_collections(updates_collections, update_moving_mean)
    ops.add_to_collections(updates_collections, update_moving_variance)

In my case, I'm using a Conv-LSTM-Conv_tp architecture, where I reuse the Conv/Conv_tp layers for each timestep. When I increase the number of timesteps in the LSTM, the number of update ops increases proportionally, while the number of model parameters stays constant because they are reused. Currently, I get 420 update ops when calling tf.get_collection(tf.GraphKeys.UPDATE_OPS). Since performance feels extremely slow when I use batch norm, I suspect this high number of update ops cannot be right.
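A quick way to see the imbalance (rough diagnostic sketch, run after the graph is built):

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# grows with the number of batch_norm calls (i.e. with the number of timesteps) ...
print('number of update ops:', len(update_ops))
# ... while this stays constant because the variables are reused
print('number of trainable variables:', len(tf.trainable_variables()))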

3) Handling of is_training parameter:

I have seen many examples where people do something like this in their code to handle the is_training parameter:

def batch_norm_layer(x, train_phase, scope_bn):
    bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
                          updates_collections=None, is_training=True)
    bn_inference = batch_norm(x, decay=0.999, center=True, scale=True,
                              updates_collections=None, is_training=False)
    bn = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
    return bn

As far as I know, this was really required in the past, because is_training was just a Python boolean. But since the parameter can now also be a boolean Tensor, this workaround is no longer necessary. Since many devs are still using it, adding a note to the doc-string that it is not required anymore could be helpful.
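For example (a minimal sketch), the conditional wrapper above becomes unnecessary:

is_training = tf.placeholder(tf.bool, name='is_training')
# batch_norm switches between batch statistics and moving averages based on the tensor value
net = tf.contrib.layers.batch_norm(x, decay=0.999, center=True, scale=True,
                                   is_training=is_training)
# feed {is_training: True} during training and {is_training: False} at inference time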

4) Usage on Multi-GPU configuration

a) When I optimize my code for multi-GPU systems (as in the CIFAR-10 example), the number of update ops increases by a factor of num_gpus (this might be related to 2)).

b) When I use tf.contrib.layers.batch_norm() within a multi-GPU setup, I get an error like this:

InvalidArgumentError: Cannot assign a device to node 'tower_1/inference/ConvStack/x_bn_9/moments/sufficient_statistics/SparseToDense': 
Could not satisfy explicit device specification '/device:GPU:1' because no supported kernel 
for GPU devices is available.
...

Hence, do we have to wrap every batch_norm() call with tf.device("/cpu:0")? I guess this might have a bad impact on performance, right?
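That is, something like this (sketch; net and is_training stand in for whatever the tower defines):

# pin only the batch-norm ops to the CPU, leave the rest of the tower on the GPU
with tf.device("/cpu:0"):
    net = tf.contrib.layers.batch_norm(net, is_training=is_training)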

Thanks!

PS: Sorry in case this question fits better on StackOverflow, as it is a combination of suggested improvements and questions. Just let me know...

@argman

argman commented Sep 15, 2016

Agree, I believe there is a bug in batch_norm.

@b3nk4n
Author

b3nk4n commented Sep 15, 2016

By "bug in batch_norm", which points of my list do you actually mean? And could you propose any workaround?

@argman

argman commented Sep 17, 2016

I don't know why, but I could not do multi-GPU training when the batch_norm moving average was applied. However, after updating my TF to the master version and updating my CUDA/cuDNN, the problem went away.

@jmchen-g
Contributor

@shlens Could you take a look at this? Thanks.

@jmchen-g added the stat:awaiting tensorflower label Sep 23, 2016
@shlens
Contributor

shlens commented Sep 23, 2016

@bsautermeister would you have a suggested edit on the docstring that would make the layer more clear?

@argman, it sounds like your error is fixed, correct?

@argman

argman commented Sep 24, 2016

@shlens, yes, I just updated TF to the newest version.

@dasabir

dasabir commented Oct 20, 2016

Is reuse=True working? Whenever I try reuse=True I get errors like: "Variable norm0/beta does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?" I'm following the docstring and providing the scope too. As far as I understand, when a variable is to be created with tf.get_variable() and reused, it first has to be created, and only then can reuse be enabled via tf.get_variable_scope().reuse_variables().
Without reuse=True in tf.contrib.layers.batch_norm(), I think the right moving mean and variance will not be restored.
I'm using TensorFlow version 0.11.

Please inform me if this is not the right place to raise this issue. I got to it from #1122

@jfsantos
Contributor

I have the same issue as @dasabir when trying to reuse a batch_norm layer within a variable scope.

@wookayin
Contributor

wookayin commented Nov 1, 2016

For (2), I agree with @bsautermeister: adding the dependencies to train_op seems sound. For some use cases, one may compute the loss value (i.e. forward-prop) for validation datapoints, but with the dependencies attached to the loss, the batch-normalization statistics would also be updated from the validation set.

For (3), don't we need to share the BN parameters between bn_train and bn_inference? (In the original code, separate BN variables like beta and gamma are created for the two branches.)

 def batch_norm_layer(x, train_phase, scope_bn):
   bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
-  updates_collections=None, is_training=True)
+  updates_collections=None, is_training=True, scope=scope_bn)
   bn_inference = batch_norm(x, decay=0.999, center=True, scale=True,
-  updates_collections=None, is_training=False)
+  updates_collections=None, is_training=False, scope=scope_bn, reuse=True)
   bn = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
   return bn

NOTE: I simply ignored the invalid moving average/variance update in the code for simplicity.
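For reference, the helper with the diff above applied would read (same caveat about the moving-average updates):

def batch_norm_layer(x, train_phase, scope_bn):
    # both branches point at the same variables; the second call only reuses them
    bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
                          updates_collections=None, is_training=True, scope=scope_bn)
    bn_inference = batch_norm(x, decay=0.999, center=True, scale=True,
                              updates_collections=None, is_training=False, scope=scope_bn, reuse=True)
    bn = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
    return bn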

@wjiangcmu

wjiangcmu commented Nov 2, 2016

@dasabir and @jfsantos I had the same issue, but specifying the scope name for batch_norm fixed it. Under a scope with reuse=True, tf.contrib.layers.batch_norm(x) will always try to create new norm variables and mark them for reuse, which gives you the error. One thing you can do is specify the norm scope name like this: tf.contrib.layers.batch_norm(x, scope="name"). When you reuse this norm layer, just do tf.contrib.layers.batch_norm(x, scope="name", reuse=True), or use tf.contrib.layers.batch_norm(x, scope="name") under a reusing scope. Hope this is helpful.
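In code, the pattern looks roughly like this (sketch; x1 and x2 are arbitrary inputs):

# first call creates the beta/gamma/moving_* variables under the given scope
h1 = tf.contrib.layers.batch_norm(x1, scope="shared_bn")
# later calls reuse those variables instead of trying to create them again
h2 = tf.contrib.layers.batch_norm(x2, scope="shared_bn", reuse=True)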

@RuiShu

RuiShu commented Dec 27, 2016

I noticed that the docs haven't been updated yet. Would it be useful if the docs instead said:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# run the moving-average updates as part of every training step
with tf.control_dependencies(update_ops):
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(total_loss)

As for proper reuse across multiple data streams, it looks like a shareable version is still in the works.

As an aside, to the best of my understanding, the notion of a shareable BN layer should be treated with some care. Depending on the use-case, I think there should be an option to distinguish sharing of the moving averages from the sharing of the beta/gamma parameters as noted here.

@aselle removed the stat:awaiting tensorflower label Jan 27, 2017
@drpngx
Contributor

drpngx commented Jan 27, 2017

Is this still a problem with tf.nn.batch_norm?

@drpngx added the stat:awaiting response and type:bug labels Jan 27, 2017
@aselle
Contributor

aselle commented Feb 13, 2017

Closing due to lack of recent activity. Please update the issue if it persists and we will reopen.

@aselle closed this as completed Feb 13, 2017
@jianlong-yuan

When you use batch normalization across multiple GPUs, how do you update the moving variance?

@drewanye

drewanye commented Apr 6, 2018

I solved the problem of reusing batch normalization by specifying reuse=False when first creating the BN layer (I use slim, but it's the same for tf.layers.batch_normalization):

scope = tf.get_variable_scope()
# first call creates the batch-norm variables in the current scope
bn1 = slim.batch_norm(input1, decay=0.9, reuse=False, scope=scope, is_training=is_training)
# second call reuses the variables created above
bn2 = slim.batch_norm(input2, decay=0.9, reuse=True, scope=scope, is_training=is_training)

You have to specify reuse=False the first time so that the batch-norm parameters are created. Otherwise you will get an error like:
Variable cnn/block1/conv1/beta does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

@qingchenwuhou

qingchenwuhou commented Jun 7, 2018

I followed @wjiangcmu's advice and it works.
The code:

self.is_training = tf.placeholder(tf.bool, name='MODE')

# first use:
self.img_bn1 = tf.cond(self.is_training,
    lambda: batch_norm(self.img_fc1, is_training=self.is_training, center=True, scale=True,
                       activation_fn=None, decay=0.9, scope='discriminator/img_bn1', reuse=False),
    lambda: batch_norm(self.img_fc1, is_training=self.is_training, center=True, scale=True,
                       activation_fn=None, decay=0.9, scope='discriminator/img_bn1', reuse=True))

# collect update_ops before the second use, and filter out unrelated update_ops
# (unrelated moving means and variances)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print('update_ops')
for key in update_ops:
    print(key)
i2t_update_extra_ops = [elem for elem in update_ops if 'text_feature/attention' not in elem.name]

# second use:
self.img_neg_bn1 = batch_norm(self.img_neg_fc1, is_training=self.is_training, center=True, scale=True,
                              activation_fn=None, decay=0.9, scope='discriminator/img_bn1', reuse=True)

# weight update and the dependent extra ops (moving mean and variance)
self.i2t_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
i2t_update_grads = self.i2t_optimizer.minimize(self.i2t_loss)

i2t_train_ops = [i2t_update_grads] + i2t_update_extra_ops
self.i2t_updates = tf.group(*i2t_train_ops)

In addition, in order to update each batch_norm only once (see @bsautermeister's point 2, "UPDATE_OPS in combination with reuse varscope"), I collect the update_ops before the second use of each batch_norm and filter out the unrelated ones.

Hope this will be helpful for others.
