standarize arg names in LayerNorm/InstanceNorm
ppwwyyxx committed Mar 8, 2020
1 parent 2ff9a5f commit 07e464d
Showing 3 changed files with 49 additions and 25 deletions.
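In short: the commit renames LayerNorm's use_bias/use_scale/gamma_init and InstanceNorm's use_affine/gamma_init to the center/scale/gamma_initializer convention that BatchNorm and tf.layers already use, makes the optional arguments keyword-only (the bare `*`), and keeps the old spellings working through a name mapping plus a deprecation warning.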
2 changes: 2 additions & 0 deletions tensorpack/input_source/input_source.py
@@ -471,6 +471,8 @@ def _setup(self, input_signature):
         self._spec = input_signature
         if self._dataset is not None:
             types = self._dataset.output_types
+            if len(types) == 1:
+                types = (types,)
             spec_types = tuple(k.dtype for k in input_signature)
             assert len(types) == len(spec_types), \
                 "Dataset and input signature have different length! {} != {}".format(
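A hedged sketch of what the two added lines guard against (assumptions mine, not from the commit: input_signature is always a sequence of TensorSpecs, while a one-component dataset may report its dtypes as a bare object rather than a matching 1-tuple; the variant below uses isinstance in place of the len check):

    import tensorflow as tf

    signature = (tf.TensorSpec([None, 4], tf.float32),)  # always a sequence of specs
    types = tf.float32                # hypothetical dtypes of a one-component dataset
    if not isinstance(types, tuple):
        types = (types,)              # normalize to a 1-tuple, as the added lines do
    spec_types = tuple(k.dtype for k in signature)
    assert len(types) == len(spec_types)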
4 changes: 2 additions & 2 deletions tensorpack/models/batch_norm.py
@@ -71,7 +71,7 @@ def internal_update_bn_ema(xn, batch_mean, batch_var,
         'use_local_stat': 'training'
     })
 @disable_autograph()
-def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
+def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
               center=True, scale=True,
               beta_initializer=tf.zeros_initializer(),
               gamma_initializer=tf.ones_initializer(),
@@ -376,7 +376,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         'gamma_init': 'gamma_initializer',
         'decay': 'momentum'
     })
-def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
+def BatchRenorm(x, rmax, dmax, *, momentum=0.9, epsilon=1e-5,
                 center=True, scale=True, gamma_initializer=None,
                 data_format='channels_last'):
     """
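The only functional change in this file is the bare `*`, which makes every argument after it keyword-only. A self-contained illustration of the Python semantics (plain Python, not tensorpack code):

    def f(inputs, axis=None, *, training=None, momentum=0.9):
        return inputs

    f("x", axis=None, training=True)   # fine: options spelled out by keyword
    # f("x", None, True)               # TypeError: training can no longer be positional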
68 changes: 45 additions & 23 deletions tensorpack/models/layer_norm.py
@@ -5,24 +5,34 @@
 from ..compat import tfv1 as tf  # this should be avoided first in model code

 from ..utils.argtools import get_data_format
+from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
+from .tflayer import convert_to_tflayer_args

 __all__ = ['LayerNorm', 'InstanceNorm']


 @layer_register()
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+    })
 def LayerNorm(
-        x, epsilon=1e-5,
-        use_bias=True, use_scale=True,
-        gamma_init=None, data_format='channels_last'):
+        x, epsilon=1e-5, *,
+        center=True, scale=True,
+        gamma_initializer=tf.ones_initializer(),
+        data_format='channels_last'):
     """
     Layer Normalization layer, as described in the paper:
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.

     Args:
         x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
         epsilon (float): epsilon to avoid divide-by-zero.
-        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
+        center, scale (bool): whether to use the extra affine transformation or not.
     """
     data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()
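With the name_mapping above, existing call sites keep working while new code uses the standardized names. A hedged sketch (assuming the usual tensorpack call style, where a registered layer takes the variable-scope name as its first argument):

    y = LayerNorm('ln', x, center=True, scale=True)        # new, standardized names
    y = LayerNorm('ln', x, use_bias=True, use_scale=True)  # old names, remapped by the decorator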
@@ -40,31 +40,36 @@ def LayerNorm(
     if ndims == 2:
         new_shape = [1, chan]

-    if use_bias:
+    if center:
         beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
         beta = tf.reshape(beta, new_shape)
     else:
         beta = tf.zeros([1] * ndims, name='beta')
-    if use_scale:
-        if gamma_init is None:
-            gamma_init = tf.constant_initializer(1.0)
-        gamma = tf.get_variable('gamma', [chan], initializer=gamma_init)
+    if scale:
+        gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
         gamma = tf.reshape(gamma, new_shape)
     else:
         gamma = tf.ones([1] * ndims, name='gamma')

     ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

     vh = ret.variables = VariableHolder()
-    if use_scale:
+    if scale:
         vh.gamma = gamma
-    if use_bias:
+    if center:
         vh.beta = beta
     return ret


 @layer_register()
-def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'gamma_init': 'gamma_initializer',
+    })
+def InstanceNorm(x, epsilon=1e-5, *, center=True, scale=True,
+                 gamma_initializer=tf.ones_initializer(),
+                 data_format='channels_last', use_affine=None):
     """
     Instance Normalization, as in the paper:
     `Instance Normalization: The Missing Ingredient for Fast Stylization
Expand All @@ -73,12 +88,17 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
Args:
x (tf.Tensor): a 4D tensor.
epsilon (float): avoid divide-by-zero
use_affine (bool): whether to apply learnable affine transformation
center, scale (bool): whether to use the extra affine transformation or not.
use_affine: deprecated. Don't use.
"""
data_format = get_data_format(data_format, keras_mode=False)
shape = x.get_shape().as_list()
assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"

if use_affine is not None:
log_deprecated("InstanceNorm(use_affine=)", "Use center= or scale= instead!", "2020-06-01")
center = scale = use_affine

if data_format == 'NHWC':
axis = [1, 2]
ch = shape[3]
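So use_affine survives only as a deprecated alias: passing any value logs a deprecation (removal scheduled for 2020-06-01 in the call above) and folds it into both center and scale. A hedged usage sketch:

    y = InstanceNorm('inorm', x, use_affine=False)         # warns, then center = scale = False
    y = InstanceNorm('inorm', x, center=True, scale=True)  # preferred spelling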
@@ -91,19 +91,21 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=

     mean, var = tf.nn.moments(x, axis, keep_dims=True)

-    if not use_affine:
-        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')
-
-    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
-    beta = tf.reshape(beta, new_shape)
-    if gamma_init is None:
-        gamma_init = tf.constant_initializer(1.0)
-    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
-    gamma = tf.reshape(gamma, new_shape)
+    if center:
+        beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
+        beta = tf.reshape(beta, new_shape)
+    else:
+        beta = tf.zeros([1, 1, 1, 1], name='beta', dtype=x.dtype)
+    if scale:
+        gamma = tf.get_variable('gamma', [ch], initializer=gamma_initializer)
+        gamma = tf.reshape(gamma, new_shape)
+    else:
+        gamma = tf.ones([1, 1, 1, 1], name='gamma', dtype=x.dtype)
     ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

     vh = ret.variables = VariableHolder()
-    if use_affine:
+    if scale:
         vh.gamma = gamma
+    if center:
         vh.beta = beta
     return ret
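Net effect for InstanceNorm: beta and gamma are now created independently, so center-only or scale-only affine variants become possible, and the VariableHolder exposes only the variables actually created. A hedged sketch:

    y = InstanceNorm('inorm', x, center=False, scale=True)  # gamma only
    y.variables.gamma                                       # set by the `if scale:` branch
    # y.variables.beta is not set here, since center=False skips the `if center:` branch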
