Why does the accuracy of a CNN with BatchNormLayer change slightly after restoring? #57
Comments
Please follow the latest implementation.

--- previous answer ---

```python
import tensorflow as tf
from tensorlayer.layers import Layer


class BatchNormLayer5(Layer):
    """
    The :class:`BatchNormLayer` class is a normalization layer, see ``tf.nn.batch_normalization`` and ``tf.nn.moments``.

    Batch normalization on fully-connected or convolutional maps.

    Parameters
    -----------
    layer : a :class:`Layer` instance
        The `Layer` class feeding into this layer.
    decay : float
        A decay factor for ExponentialMovingAverage.
    epsilon : float
        A small float number to avoid dividing by 0.
    act : activation function.
    is_train : boolean
        Whether train or inference.
    beta_init : beta initializer
        The initializer for initializing beta.
    gamma_init : gamma initializer
        The initializer for initializing gamma.
    name : a string or None
        An optional name to attach to this layer.

    References
    ----------
    - `Source <https://github.com/ry/tensorflow-resnet/blob/master/resnet.py>`_
    - `stackoverflow <http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow>`_
    """
    def __init__(
        self,
        layer=None,
        decay=0.999,
        epsilon=0.00001,
        act=tf.identity,
        is_train=False,
        beta_init=tf.zeros_initializer,
        # gamma_init=tf.ones_initializer,
        gamma_init=tf.random_normal_initializer(mean=1.0, stddev=0.002),
        name='batchnorm_layer',
    ):
        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer:Instantiate BatchNormLayer %s: decay: %f, epsilon: %f, act: %s, is_train: %s" %
              (self.name, decay, epsilon, act.__name__, is_train))
        x_shape = self.inputs.get_shape()
        params_shape = x_shape[-1:]

        from tensorflow.python.training import moving_averages
        from tensorflow.python.ops import control_flow_ops

        with tf.variable_scope(name) as vs:
            axis = list(range(len(x_shape) - 1))

            ## 1. beta, gamma
            beta = tf.get_variable('beta', shape=params_shape,
                                   initializer=beta_init,
                                   trainable=is_train)  # , restore=restore)
            gamma = tf.get_variable('gamma', shape=params_shape,
                                    initializer=gamma_init,
                                    trainable=is_train)  # , restore=restore)

            ## 2. moving variables updated during training (not updated by gradients!)
            moving_mean = tf.get_variable('moving_mean',
                                          params_shape,
                                          initializer=tf.zeros_initializer,
                                          trainable=False)  # , restore=restore)
            moving_variance = tf.get_variable('moving_variance',
                                              params_shape,
                                              initializer=tf.constant_initializer(1.),
                                              trainable=False)  # , restore=restore)

            batch_mean, batch_var = tf.nn.moments(self.inputs, axis)

            ## 3. These ops will only be performed when training.
            def mean_var_with_update():
                try:  # TF12
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, batch_mean, decay, zero_debias=False)  # if zero_debias=True, has bias
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, batch_var, decay, zero_debias=False)  # if zero_debias=True, has bias
                    # print("TF12 moving")
                except Exception:  # TF11 has no zero_debias argument
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, batch_mean, decay)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, batch_var, decay)
                    # print("TF11 moving")
                with tf.control_dependencies([update_moving_mean, update_moving_variance]):
                    # return tf.identity(update_moving_mean), tf.identity(update_moving_variance)
                    return tf.identity(batch_mean), tf.identity(batch_var)

            if is_train:
                mean, var = mean_var_with_update()
            else:
                # NOTE: this branch normalizes with the *current batch's*
                # statistics at inference, not the stored moving averages.
                mean, var = (batch_mean, batch_var)  # hao

            normed = tf.nn.batch_normalization(
                x=self.inputs,
                mean=mean,
                variance=var,
                offset=beta,
                scale=gamma,
                variance_epsilon=epsilon,
                name="tf_bn",
            )
            self.outputs = act(normed)

            # NOTE: only beta and gamma are collected here, so moving_mean and
            # moving_variance never reach all_params (and thus never reach the npz).
            variables = [beta, gamma]

        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_layers.extend([self.outputs])
        self.all_params.extend(variables)
```
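Two details of this older version, flagged in the NOTE comments above, explain the behavior asked about below: at inference it normalizes with the current batch's statistics rather than the stored moving averages, and `all_params` contains only `beta` and `gamma`, so the moving statistics are never saved to (or restored from) the npz file. A minimal NumPy sketch of the first point, with all numbers invented for illustration, shows why the same sample can normalize differently depending on which batch it lands in:

```python
import numpy as np

np.random.seed(0)
eps = 1e-5

def bn(x, mean, var):
    # the core of tf.nn.batch_normalization, without beta/gamma
    return (x - mean) / np.sqrt(var + eps)

sample = np.array([0.5, -1.2, 2.0])               # one sample, evaluated twice
batch_a = np.concatenate([sample, np.random.randn(13)])
batch_b = np.concatenate([sample, np.random.randn(29)])

# Inference with batch statistics: the output depends on the batch mates.
print(bn(sample, batch_a.mean(), batch_a.var()))
print(bn(sample, batch_b.mean(), batch_b.var()))

# Inference with stored moving statistics: deterministic for the sample.
moving_mean, moving_var = 0.1, 0.9                # whatever training accumulated
print(bn(sample, moving_mean, moving_var))
print(bn(sample, moving_mean, moving_var))
```

The latest implementation avoids both problems by using the moving averages in the inference branch and by collecting all four variables of the scope, so they round-trip through save/restore.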
Hi everyone, I found an interesting thing but I don't know the reason. When I restore a CNN network with `BatchNormLayer` from an npz file, the accuracy is slightly different; my code is attached. Hope someone can help me, thanks in advance.

`is_test_only = False`: (note: I set `n_epoch=1` and use only a small part of the training data for fast debugging.)

`is_test_only = True`:

`variables = tf.GraphKeys.GLOBAL_VARIABLES` in `BatchNormLayer` (layers.py line 1825), but I found it will get 8 parameters ... are you sure it is correct? I tried the following setting, but the accuracy still has a slight difference ...

@boscotsang as I discussed with you in pull/42, the testing and training costs all drop normally, but I really don't understand why the accuracies are different after restoring, or which variables should be included in the `BatchNormLayer`.

My code

Environment: TensorFlow 0.12 and TensorLayer 1.3.0
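For completeness, a minimal sketch of the save/restore side of the question, assuming the TensorLayer 1.3-era `tl.files` API; `network` and `sess` are hypothetical stand-ins for the attached code:

```python
import tensorflow as tf
import tensorlayer as tl

# With the old layer quoted above, network.all_params holds only beta and
# gamma, so this npz never contains the moving statistics:
tl.files.save_npz(network.all_params, name='model.npz', sess=sess)

# Saving every global variable also captures moving_mean/moving_variance,
# matching the tf.GraphKeys.GLOBAL_VARIABLES collection mentioned above
# (four variables per BatchNormLayer, hence the 8 parameters observed):
params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
tl.files.save_npz(params, name='model_all.npz', sess=sess)

# Restore by assigning the loaded arrays back in collection order; this
# assumes the graph is rebuilt identically, so the order matches the save:
loaded = tl.files.load_npz(name='model_all.npz')
for var, val in zip(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), loaded):
    sess.run(tf.assign(var, val))
```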