fix batch normalization (see #7)
aymericdamien committed Apr 24, 2016
1 parent d0cd9d8 commit dc56b7e
Showing 1 changed file with 28 additions and 14 deletions: tflearn/layers/normalization.py
@@ -2,14 +2,15 @@
 from __future__ import division, print_function, absolute_import
 
 import tensorflow as tf
+from tensorflow.python.training import moving_averages
 
 import tflearn
 from .. import utils
 from .. import variables as vs
 
 
 def batch_normalization(incoming, beta=0.0, gamma=1.0, epsilon=1e-5,
-                        decay=0.999, trainable=True, restore=True,
+                        decay=0.9, trainable=True, restore=True,
                         stddev=0.002, name="BatchNormalization"):
     """ Batch Normalization.
@@ -55,32 +56,45 @@ def batch_normalization(incoming, beta=0.0, gamma=1.0, epsilon=1e-5,
         tf.add_to_collection(tf.GraphKeys.EXCL_RESTORE_VARS, beta)
         tf.add_to_collection(tf.GraphKeys.EXCL_RESTORE_VARS, gamma)
 
-        ema = tf.train.ExponentialMovingAverage(decay=decay)
-        axis = [i for i in range(input_ndim - 1)]
-        if len(axis) < 1: axis = [0]
-        batch_mean, batch_var = tf.nn.moments(incoming, axis,
-                                              name='moments')
-        ema_apply_op = ema.apply([batch_mean, batch_var])
-        ema_mean, ema_var = ema.average(batch_mean), ema.average(
-            batch_var)
+        axis = list(range(input_ndim - 1))
+        moving_mean = vs.variable(scope + 'moving_mean',
+                                  input_shape[-1:],
+                                  initializer=tf.zeros_initializer,
+                                  trainable=False,
+                                  restore=restore)
+        moving_variance = vs.variable(scope + 'moving_variance',
+                                      input_shape[-1:],
+                                      initializer=tf.ones_initializer,
+                                      trainable=False,
+                                      restore=restore)
 
         # Define a function to update mean and variance
         def update_mean_var():
-            with tf.control_dependencies([ema_apply_op]):
-                return tf.identity(ema_mean), tf.identity(ema_var)
+            mean, variance = tf.nn.moments(incoming, axis)
+            update_moving_mean = moving_averages.assign_moving_average(
+                moving_mean, mean, decay)
+            update_moving_variance = moving_averages.assign_moving_average(
+                moving_variance, variance, decay)
+            with tf.control_dependencies(
+                    [update_moving_mean, update_moving_variance]):
+                return tf.identity(mean), tf.identity(variance)
 
         # Retrieve variable managing training mode
         is_training = tflearn.get_training_mode()
         mean, var = tf.python.control_flow_ops.cond(
-            is_training, update_mean_var, lambda: (ema_mean, ema_var))
+            is_training, update_mean_var, lambda: (moving_mean, moving_variance))
 
         try:
             inference = tf.nn.batch_normalization(
                 incoming, mean, var, beta, gamma, epsilon)
             inference.set_shape(input_shape)
         # Fix for old Tensorflow
         except Exception as e:
             inference = tf.nn.batch_norm_with_global_normalization(
                 incoming, mean, var, beta, gamma, epsilon,
                 scale_after_normalization=True,
             )
             inference.set_shape(input_shape)
 
         # Add attributes for easy access
         inference.scope = scope
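
For readers skimming the diff: the fix keeps explicit moving_mean / moving_variance variables, updates them with moving_averages.assign_moving_average during training, and switches to those stored statistics at inference via the is_training condition; the lower decay (0.9 instead of 0.999) makes the running averages track recent batches faster. Below is a minimal standalone sketch of that scheme in plain NumPy — not part of this commit and not tflearn's API; the class name BatchNorm1D and the toy data are illustrative only.

import numpy as np

class BatchNorm1D:
    """Toy batch normalization with running statistics (illustrative only)."""

    def __init__(self, num_features, beta=0.0, gamma=1.0, epsilon=1e-5, decay=0.9):
        self.beta, self.gamma = beta, gamma
        self.epsilon, self.decay = epsilon, decay
        # Non-trainable running statistics, analogous to moving_mean / moving_variance.
        self.moving_mean = np.zeros(num_features)
        self.moving_variance = np.ones(num_features)

    def __call__(self, x, training):
        if training:
            # Training path: normalize with the current batch's statistics...
            mean = x.mean(axis=0)
            variance = x.var(axis=0)
            # ...and fold them into the running averages, the same update rule
            # as assign_moving_average: v = decay * v + (1 - decay) * value.
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_variance = self.decay * self.moving_variance + (1 - self.decay) * variance
        else:
            # Inference path: use only the accumulated statistics.
            mean, variance = self.moving_mean, self.moving_variance
        return self.gamma * (x - mean) / np.sqrt(variance + self.epsilon) + self.beta

# After a few training batches the running statistics approach the data's true
# mean/variance, so inference-time outputs are roughly standardized.
bn = BatchNorm1D(num_features=3)
rng = np.random.default_rng(0)
for _ in range(50):
    bn(rng.normal(loc=2.0, scale=3.0, size=(64, 3)), training=True)
print(bn.moving_mean)      # roughly [2, 2, 2]
print(bn.moving_variance)  # roughly [9, 9, 9]

With decay=0.999 the same 50-step run would still be dominated by the zero/one initial values (0.999^50 ≈ 0.95 of their weight remains), which illustrates why a large decay needs many more steps to warm up.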
