Merge pull request #427 from sony/feature/20190328-add-normalizations
Add normalizations
TakuyaNarihira committed May 16, 2019
2 parents b8bbebe + 1e4c408 commit 44e2667
Showing 10 changed files with 1,254 additions and 166 deletions.
4 changes: 4 additions & 0 deletions doc/python/api/function.rst
@@ -85,6 +85,10 @@ Normalization
.. autofunction:: clip_grad_by_value
.. autofunction:: clip_by_norm
.. autofunction:: clip_grad_by_norm
.. autofunction:: layer_normalization
.. autofunction:: instance_normalization
.. autofunction:: group_normalization
.. autofunction:: weight_standardization


Reduction
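For orientation, here is a minimal usage sketch of the newly documented functional interface. The exact signatures and parameter shapes shown here (the broadcasting beta/gamma shapes, batch_axis, channel_axis) are assumptions inferred from the entries above, not authoritative definitions.

import nnabla as nn
import nnabla.functions as F

x = nn.Variable((4, 16, 32, 32))                        # (N, C, H, W) input
beta = nn.Variable((1, 16, 32, 32), need_grad=True)     # assumed shape: size 1 on the batch axis
gamma = nn.Variable((1, 16, 32, 32), need_grad=True)
y = F.layer_normalization(x, beta, gamma, batch_axis=0)  # assumed keyword

w = nn.Variable((16, 3, 3, 3), need_grad=True)          # conv weight (O, I, Kh, Kw)
w_std = F.weight_standardization(w, channel_axis=0)     # assumed keyword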
3 changes: 3 additions & 0 deletions doc/python/api/parametric_function.rst
@@ -73,6 +73,9 @@ Here is the list of parametric functions.
.. autofunction:: batch_normalization
.. autofunction:: sync_batch_normalization
.. autofunction:: mean_subtraction
.. autofunction:: layer_normalization
.. autofunction:: instance_normalization
.. autofunction:: group_normalization

.. autofunction:: rnn
.. autofunction:: lstm
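The parametric counterparts create beta and gamma internally under the current parameter scope. A minimal sketch, assuming the argument names (num_groups, and the defaults for batch_axis/channel_axis) match the entries above:

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((4, 16, 32, 32))
with nn.parameter_scope("ln"):
    y_ln = PF.layer_normalization(x)                # beta/gamma created internally
with nn.parameter_scope("gn"):
    y_gn = PF.group_normalization(x, num_groups=4)  # 16 channels split into 4 groups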
166 changes: 2 additions & 164 deletions python/src/nnabla/functions.py
@@ -14,9 +14,10 @@

from __future__ import absolute_import
from .function_bases import *
from six.moves import reduce as rd

import nnabla as nn
import numpy as np
from .normalization_functions import *


def sum(x, axis=None, keepdims=False):
@@ -267,169 +268,6 @@ def slice(ctx, x, start=None, stop=None, step=None, n_outputs=-1, outputs=None):
return slice_base(x, start, stop, step, n_outputs, outputs)


def batch_normalization(x, beta, gamma, mean, variance, axes=[1], decay_rate=0.9, eps=1e-05, batch_stat=True, output_stat=False, n_outputs=None):
r"""
Batch normalization.
.. math::
\begin{eqnarray}
\mu &=& \frac{1}{M} \sum x_i \\
\sigma^2 &=& \frac{1}{M} \sum \left(x_i - \mu\right)^2 \\
\hat{x}_i &=& \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
y_i &=& \hat{x}_i \gamma + \beta.
\end{eqnarray}
At testing time, the mean and variance values used are those that were computed during training by moving average.
References:
* `Ioffe and Szegedy, Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
<https://arxiv.org/abs/1502.03167>`_
Args:
x(~nnabla.Variable): N-D array of input.
beta(~nnabla.Variable): N-D array of beta which is learned.
gamma(~nnabla.Variable): N-D array of gamma which is learned.
mean(~nnabla.Variable): N-D array of running mean (modified during forward execution).
variance(~nnabla.Variable): N-D array of running variance (modified during forward execution).
axes(repeated int64): Axes along which separate mean and variance statistics are kept (statistics are computed over all remaining axes).
decay_rate(float): Decay rate of running mean and variance.
eps(float): Tiny value to avoid zero division by std.
batch_stat(bool): Use mini-batch statistics rather than running ones.
output_stat(bool): If true, the batch statistics of mean and variance
will be returned as Variables. They are also differentiable.
Returns:
Returns batch normalization output as :obj:`~nnabla.Variable`.
If ``output_stat=True``, it also returns the mean and variance
of the mini-batch.
* :obj:`~nnabla.Variable`: Output of the batch normalization
* :obj:`~nnabla.Variable`: Mean (if ``output_stat=True``)
* :obj:`~nnabla.Variable`: Variance (if ``output_stat=True``)
See Also:
``nnabla.function_bases.batch_normalization``.
"""
from .function_bases import batch_normalization as batch_normalization_base
n_outputs = 3 if output_stat else 1
assert batch_stat or (not output_stat)
if batch_stat and (mean.parent or variance.parent) is not None:
raise ValueError(
"if batch_stat is True, mean and variable must not have a parent function")

if len(axes) == 1:
return batch_normalization_base(x, beta, gamma, mean, variance,
axes=axes,
decay_rate=decay_rate,
eps=eps,
batch_stat=batch_stat,
n_outputs=n_outputs)

def transpose_and_reshape(x, axes):
transposed = transpose(x, transpose_axes)
return reshape(transposed, [rd(lambda x, y: x * y, transposed.shape[:len(axes)])] + list(
transposed.shape[len(axes):])), transposed.shape

def inverse_transpose_and_reshape(x, axes, variable_shape):
un_reshaped = reshape(
x, list(variable_shape[:len(axes)] + variable_shape[len(axes):]))
return transpose(un_reshaped, inv_transpose_axes)

def get_tranpose_args(ndim, axes):
transpose_axes = [i for i in list(
axes)] + [i for i in range(ndim) if i not in list(axes)]
inv_transpose_axes = np.argsort(transpose_axes).tolist()
return transpose_axes, inv_transpose_axes

transpose_axes, inv_transpose_axes = get_tranpose_args(len(x.shape), axes)
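# For illustration: with a 4-D input and axes=[1, 2], transpose_axes == [1, 2, 0, 3]
# and inv_transpose_axes == [2, 0, 1, 3]. The axes listed in `axes` are moved to
# the front and flattened into a single leading axis by transpose_and_reshape,
# normalization is applied with axes=[0], and the original layout is restored by
# inverse_transpose_and_reshape.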
inp, transposed_inp_shape = transpose_and_reshape(x, axes)
beta, transposed_beta_shape = transpose_and_reshape(beta, axes)
gamma, transposed_gamma_shape = transpose_and_reshape(gamma, axes)
mean, transposed_mean_shape = transpose_and_reshape(mean, axes)
variance, transposed_variance_shape = transpose_and_reshape(variance, axes)

if n_outputs == 1:
out = batch_normalization_base(inp, beta, gamma, mean, variance,
axes=[0],
decay_rate=decay_rate,
eps=eps,
batch_stat=batch_stat,
n_outputs=n_outputs)
return inverse_transpose_and_reshape(out, axes, transposed_inp_shape)
out, mean, variance = batch_normalization_base(inp, beta, gamma, mean, variance,
axes=[0],
decay_rate=decay_rate,
eps=eps,
batch_stat=batch_stat,
n_outputs=n_outputs)
out = inverse_transpose_and_reshape(out, axes, transposed_inp_shape)
mean = inverse_transpose_and_reshape(mean, axes, transposed_mean_shape)
variance = inverse_transpose_and_reshape(
variance, axes, transposed_variance_shape)
return out, mean, variance
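# Usage sketch (illustrative, not taken from the library): when batch_stat is True,
# the running statistics must be plain Variables without a parent function, with
# size 1 on every axis not listed in `axes`.
import nnabla as nn
import nnabla.functions as F

x = nn.Variable((8, 16, 32, 32))                # (N, C, H, W)
stat_shape = (1, 16, 1, 1)                      # one statistic per channel (axes=[1])
beta = nn.Variable(stat_shape, need_grad=True)
gamma = nn.Variable(stat_shape, need_grad=True)
mean = nn.Variable(stat_shape)                  # running mean, no parent function
variance = nn.Variable(stat_shape)              # running variance
y = F.batch_normalization(x, beta, gamma, mean, variance,
                          axes=[1], batch_stat=True)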


def sync_batch_normalization(x, beta, gamma, mean, variance, comm, group="world", axes=[1], decay_rate=0.9, eps=1e-05, batch_stat=True, output_stat=False, n_outputs=None):
r"""
Synchronized batch normalization.
For some tasks (e.g., semantic segmentation), the batch size is too small and the BatchNormalization layer might not work well.
The SyncBatchNormalization layer solves this problem by synchronizing the batch statistics (mean and variance) across multiple processes.
.. math::
\begin{eqnarray}
\mu &=& \frac{1}{M} \sum x_i \\
\sigma^2 &=& \frac{1}{M} \sum \left(x_i - \mu\right)^2 \\
\hat{x}_i &=& \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
y_i &=& \hat{x}_i \gamma + \beta.
\end{eqnarray}
References:
* `Implementing Synchronized Multi-GPU Batch Normalization <https://hangzhang.org/PyTorch-Encoding/notes/syncbn.html>`_
Args:
x(~nnabla.Variable): N-D array of input.
beta(~nnabla.Variable): N-D array of beta which is learned.
gamma(~nnabla.Variable): N-D array of gamma which is learned.
mean(~nnabla.Variable): N-D array of running mean (modified during forward execution).
variance(~nnabla.Variable): N-D array of running variance (modified during forward execution).
comm(~nnabla.communicators.Communicator): The communicator.
group(string): The name of the communicator group.
axes(repeated int64): Axes along which separate mean and variance statistics are kept (statistics are computed over all remaining axes).
decay_rate(float): Decay rate of running mean and variance.
eps(float): Tiny value to avoid zero division by std.
batch_stat(bool): Use mini-batch statistics rather than running ones.
output_stat(bool): If true, the batch statistics of mean and variance
will be returned as Variables. They are also differentiable.
Returns:
Returns batch normalization output as :obj:`~nnabla.Variable`.
If ``output_stat=True``, it also returns the mean and variance
of the mini-batch.
* :obj:`~nnabla.Variable`: Output of the batch normalization
* :obj:`~nnabla.Variable`: Mean (if ``output_stat=True``)
* :obj:`~nnabla.Variable`: Variance (if ``output_stat=True``)
See Also:
``nnabla.function_bases.batch_normalization``.
"""
from .function_bases import sync_batch_normalization as batch_normalization_base
n_outputs = 3 if output_stat else 1
return batch_normalization_base(x, beta, gamma, mean, variance,
comm, group=group,
axes=axes,
decay_rate=decay_rate,
eps=eps,
batch_stat=batch_stat,
n_outputs=n_outputs)
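# Usage sketch (illustrative): assumes `comm` is an already-initialized
# nnabla.communicators communicator shared by all workers, and reuses x, beta,
# gamma, mean, variance and the F import from the batch_normalization sketch above.
y = F.sync_batch_normalization(x, beta, gamma, mean, variance,
                               comm, group="world",
                               axes=[1], batch_stat=True)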


def mean_subtraction(x, mean, t, base_axis=1, update_running_mean=True):
r"""
It subtracts the mean of the elements of the input array,
