Skip to content

Commit

Permalink
Merge pull request #293 from sony/feature/20181107-pf-batch-norm-init…
Browse files Browse the repository at this point in the history
…ializers

 [python] Initialization option in batch normalization parametric function
  • Loading branch information
AkioHayakawa-sony committed Nov 7, 2018
2 parents 9294509 + fcf4067 commit 43ab231
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions python/src/nnabla/parametric_functions.py
Expand Up @@ -1312,7 +1312,8 @@ def depthwise_deconvolution(inp, kernel, pad=None, stride=None, dilation=None,
('var', 'Moving average of batch variance', '<see above>', False),
])
def batch_normalization(inp, axes=[1], decay_rate=0.9, eps=1e-5,
batch_stat=True, output_stat=False, fix_parameters=False):
batch_stat=True, output_stat=False, fix_parameters=False,
param_init=None):
"""
Batch normalization layer.
Expand Down Expand Up @@ -1341,6 +1342,12 @@ def batch_normalization(inp, axes=[1], decay_rate=0.9, eps=1e-5,
batch_stat (bool): Use mini-batch statistics rather than running ones.
output_stat (bool): Output batch mean and variance.
fix_parameters (bool): When set to `True`, the beta and gamma will not be updated.
param_init (dict):
Parameter initializers can be set with a dict. A key of the dict must
be ``'beta'``, ``'gamma'``, ``'mean'`` or ``'var'``.
A value of the dict must be an :obj:`~nnabla.initializer.Initializer`
or a :obj:`numpy.ndarray`.
E.g. ``{'beta': ConstantIntializer(0), 'gamma': np.ones(gamma_shape) * 2}``.
Returns:
:class:`~nnabla.Variable`: N-D array.
Expand All @@ -1359,14 +1366,21 @@ def batch_normalization(inp, axes=[1], decay_rate=0.9, eps=1e-5,
assert len(axes) == 1
shape_stat = [1 for _ in inp.shape]
shape_stat[axes[0]] = inp.shape[axes[0]]

if param_init is None:
param_init = {}
beta_init = param_init.get('beta', ConstantInitializer(0))
gamma_init = param_init.get('gamma', ConstantInitializer(1))
mean_init = param_init.get('mean', ConstantInitializer(0))
var_init = param_init.get('var', ConstantInitializer(0))
beta = get_parameter_or_create(
"beta", shape_stat, ConstantInitializer(0), True, not fix_parameters)
"beta", shape_stat, beta_init, True, not fix_parameters)
gamma = get_parameter_or_create(
"gamma", shape_stat, ConstantInitializer(1), True, not fix_parameters)
"gamma", shape_stat, gamma_init, True, not fix_parameters)
mean = get_parameter_or_create(
"mean", shape_stat, ConstantInitializer(0), False)
"mean", shape_stat, mean_init, False)
var = get_parameter_or_create(
"var", shape_stat, ConstantInitializer(0), False)
"var", shape_stat, var_init, False)
return F.batch_normalization(inp, beta, gamma, mean, var, axes,
decay_rate, eps, batch_stat, output_stat)

Expand Down

0 comments on commit 43ab231

Please sign in to comment.