
Merge pull request #418 from sony/feature/20190410-synced-batch-normalization

Synchronized Batch Normalization
TE-TakuyaNarihira committed May 8, 2019
2 parents 2e7b127 + abc21e1 commit 2caae81e29f884a782a376c7333045c39680eb0f
@@ -100,6 +100,9 @@ Sinc:
BatchNormalization:
float: [float]
half: [Half]
SyncBatchNormalization:
float: [float]
half: [Half]
MeanSubtraction:
float: [float]
half: [Half]
@@ -1200,6 +1200,67 @@ Normalization:
function_ids:
iIffB: 22
c_runtime: support
SyncBatchNormalization:
snake_name: sync_batch_normalization
doc: |2
Synchronized Batch Normalization:
For some tasks (e.g., semantic segmentation), the batch size will be too small for the BatchNormalization layer to work well.
The SyncBatchNormalization layer solves this problem by synchronizing the batch statistics (mean and variance) across multiple processes.
.. math::
\begin{eqnarray}
\mu &=& \frac{1}{M} \sum x_i \\
\sigma^2 &=& \frac{1}{M} \sum \left(x_i - \mu\right)^2 \\
\hat{x}_i &=& \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
y_i &=& \hat{x}_i \gamma + \beta.
\end{eqnarray}
References:
* Implementing Synchronized Multi-GPU Batch Normalization https://hangzhang.org/PyTorch-Encoding/notes/syncbn.html
inputs:
x:
doc: N-D array of input.
beta:
doc: N-D array of beta which is learned.
gamma:
doc: N-D array of gamma which is learned.
mean:
doc: N-D array of running mean (modified during forward execution).
variance:
doc: N-D array of running variance (modified during forward execution).
arguments:
comm:
doc: The communicator.
type: Communicator
group:
doc: The name of the communicator group.
type: string
default: world
axes:
doc: Axes along which mean and variance are taken.
type: repeated int64
default: (1,)
decay_rate:
doc: Decay rate of running mean and variance.
type: float
default: '0.9'
eps:
doc: Tiny value to avoid zero division by std.
type: float
default: 1e-05
batch_stat:
doc: Use mini-batch statistics rather than running ones.
type: bool
default: 'True'
outputs:
y:
doc: N-D array
c_runtime: not support
function_ids:
CiiIffB: 263
MeanSubtraction:
snake_name: mean_subtraction
doc: |2
@@ -36,6 +36,8 @@ def type_to_pack_format(typestring):
fmt = 'iI'
elif typestring == 'string':
fmt = 'i'
elif typestring == 'Communicator':
fmt = 'C'
return fmt
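
The single-character codes above feed the function-id signature strings listed in functions.yaml. As an illustration only (the 'f' and 'B' codes for float and bool are inferred from the pre-existing "iIffB" id of BatchNormalization and are not shown in this hunk), the new "CiiIffB" id of SyncBatchNormalization can be reproduced like this:

# Illustrative sketch, not part of the diff. Codes for 'string', 'repeated
# int64' and the new 'Communicator' come from the hunk above; 'float' -> 'f'
# and 'bool' -> 'B' are inferred from the existing "iIffB" id.
pack_format = {
    'Communicator': 'C',
    'string': 'i',
    'repeated int64': 'iI',
    'float': 'f',
    'bool': 'B',
}
# SyncBatchNormalization arguments in declaration order:
# comm, group, axes, decay_rate, eps, batch_stat
arg_types = ['Communicator', 'string', 'repeated int64', 'float', 'float', 'bool']
print(''.join(pack_format[t] for t in arg_types))  # -> 'CiiIffB'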

def generate_cpp_utils(function_info):
@@ -13,11 +13,12 @@
# limitations under the License.

type_from_proto = {
'Shape': {'cpp': 'const vector<int> &', 'cpp_var': 'const vector<int>', 'pyx': 'const vector[int]&'},
'int64': {'cpp': 'int', 'cpp_var': 'int', 'pyx': 'int'},
'bool': {'cpp': 'bool', 'cpp_var': 'bool', 'pyx': 'cpp_bool'},
'float': {'cpp': 'float', 'cpp_var': 'float', 'pyx': 'float'},
'double': {'cpp': 'double', 'cpp_var': 'double', 'pyx': 'double'},
'repeated int64': {'cpp': 'const vector<int> &', 'cpp_var': 'const vector<int>', 'pyx': 'const vector[int]&'},
'string': {'cpp': 'const string &', 'cpp_var': 'const string', 'pyx': 'const string&'}
'Shape': {'cpp': 'const vector<int> &', 'cpp_var': 'const vector<int>', 'pyx': 'const vector[int]&', 'pxd': 'const vector[int]&'},
'int64': {'cpp': 'int', 'cpp_var': 'int', 'pyx': 'int', 'pxd': 'int'},
'bool': {'cpp': 'bool', 'cpp_var': 'bool', 'pyx': 'cpp_bool', 'pxd': 'cpp_bool'},
'float': {'cpp': 'float', 'cpp_var': 'float', 'pyx': 'float', 'pxd': 'float'},
'double': {'cpp': 'double', 'cpp_var': 'double', 'pyx': 'double', 'pxd': 'double'},
'repeated int64': {'cpp': 'const vector<int> &', 'cpp_var': 'const vector<int>', 'pyx': 'const vector[int]&', 'pxd': 'const vector[int]&'},
'string': {'cpp': 'const string &', 'cpp_var': 'const string', 'pyx': 'const string&', 'pxd': 'const string&'},
'Communicator': {'cpp': 'const shared_ptr<Communicator> &', 'cpp_var': 'shared_ptr<const Communicator>', 'pyx': 'Communicator', 'pxd': 'shared_ptr[CCommunicator]&'}
}
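
The new 'pxd' entries let the generators use one type spelling in the cdef extern declarations and another at the Python-facing call site; for a Communicator argument the two differ (shared_ptr[CCommunicator]& versus the wrapper class Communicator). A minimal sketch of the lookup, with a hypothetical helper name (the real templates index type_from_proto directly inside Mako expressions, as shown further below):

# Hypothetical helper, for illustration only.
def cython_type(proto_type, target):
    # target is one of: 'cpp', 'cpp_var', 'pyx', 'pxd'
    return type_from_proto[proto_type][target]

cython_type('Communicator', 'pxd')  # 'shared_ptr[CCommunicator]&' in the extern declaration
cython_type('Communicator', 'pyx')  # 'Communicator' (the wrapper class) in the generated def signature
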
@@ -79,6 +79,7 @@ Normalization
-------------

.. autofunction:: batch_normalization
.. autofunction:: sync_batch_normalization
.. autofunction:: mean_subtraction
.. autofunction:: clip_by_value
.. autofunction:: clip_grad_by_value
@@ -71,6 +71,7 @@ Here is the list of parametric functions.
.. autofunction:: deconvolution
.. autofunction:: depthwise_deconvolution
.. autofunction:: batch_normalization
.. autofunction:: sync_batch_normalization
.. autofunction:: mean_subtraction

.. autofunction:: rnn
@@ -0,0 +1,106 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/** Synchronized Batch Normalization
*/
#ifndef __NBLA_FUNCTION_SYNC_BATCHNORM_HPP__
#define __NBLA_FUNCTION_SYNC_BATCHNORM_HPP__

#include <nbla/communicator.hpp>
#include <nbla/cpu.hpp>
#include <nbla/function.hpp>
#include <nbla/function/batch_normalization.hpp>
#include <nbla/function_registry.hpp>

#include <vector>

using std::vector;

namespace nbla {

NBLA_REGISTER_FUNCTION_HEADER(SyncBatchNormalization,
const std::shared_ptr<Communicator> &,
const std::string &, const vector<int> &, float,
float, bool);

/** Batch normalization with synchronization across multiple processes at training time,
defined as
@f[
\begin{array}{lcl}
\mu &=& \frac{1}{M} \sum x_i\\
\sigma^2 &=& \frac{1}{M} \sum \left(x_i - \mu\right)^2\\
\hat{x}_i &=& \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
y_i &=& \hat{x}_i \gamma + \beta.
\end{array}
@f]
Inputs:
- N-D array of input.
- N-D array of beta which is learned.
- N-D array of gamma which is learned.
- N-D array of running mean (modified during forward execution).
- N-D array of running variance (modified during forward execution).
Outputs (1 or 3):
- N-D array.
- (Optional) N-D array of batch mean.
- (Optional) N-D array of batch variance.
@tparam T Data type for computation.
@param comm The communicator.
@param group The name of the communicator group.
@param axes Axes along which mean and variance are taken.
@param decay_rate Decay rate of running mean and variance.
@param eps Tiny value to avoid zero division by std.
@sa Implementing Synchronized Multi-GPU Batch Normalization
https://hangzhang.org/PyTorch-Encoding/notes/syncbn.html
\ingroup FunctionImplGrp
*/
template <typename T>
class SyncBatchNormalization : public BatchNormalization<T> {
protected:
std::shared_ptr<Communicator> comm_;
std::string group_;
size_t num_processes_;

public:
SyncBatchNormalization(const Context &ctx,
const std::shared_ptr<Communicator> &comm,
const std::string &group, const vector<int> axes,
float decay_rate, float eps, bool batch_stat)
: BatchNormalization<T>(ctx, axes, decay_rate, eps, batch_stat),
comm_(comm), group_(group) {}
virtual ~SyncBatchNormalization() {}
virtual shared_ptr<Function> copy() const override {
return create_SyncBatchNormalization(this->ctx_, this->comm_, this->group_,
this->axes_, this->decay_rate_,
this->eps_, this->batch_stat_);
}
virtual string name() override { return "SyncBatchNormalization"; }

protected:
NBLA_API virtual void setup_impl(const Variables &inputs,
const Variables &outputs) override;
NBLA_API virtual void forward_impl_batch(const Variables &inputs,
const Variables &outputs) override;
NBLA_API virtual void backward_impl_batch(const Variables &inputs,
const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum) override;
};
}
#endif
@@ -26,6 +26,7 @@
#include <nbla/computation_graph/variable.hpp>
#include <nbla/context.hpp>
#include <nbla/global_context.hpp>
#include <nbla/communicator.hpp>

namespace nbla {
namespace functions {
@@ -142,4 +143,4 @@ NBLA_API nbla::CgVariablePtr operator+(const float &a, const nbla::CgVariablePtr
NBLA_API nbla::CgVariablePtr operator*(const float &a, const nbla::CgVariablePtr &b);
NBLA_API nbla::CgVariablePtr operator-(const float &a, const nbla::CgVariablePtr &b);
NBLA_API nbla::CgVariablePtr operator/(const float &a, const nbla::CgVariablePtr &b);
#endif
#endif
@@ -27,6 +27,9 @@ from libcpp cimport bool as cpp_bool
cimport _variable
from _variable cimport CVariable, CContext, dtypes, VariablePtr

cimport communicator
from communicator cimport CCommunicator

ctypedef vector[CVariable*] Variables

cdef extern from "nbla/function.hpp" namespace "nbla":
@@ -83,7 +86,7 @@ from utils.type_conv import type_from_proto
%for name, func in function_info.items():
cdef extern from "nbla/function/${func['snake_name']}.hpp" namespace "nbla":
shared_ptr[CFunction] create_${name}(
const CContext &${''.join([', {} {}'.format(type_from_proto[v['type']]['pyx'], k) for k, v in func.get('arguments', {}).items()])}) except +
const CContext &${''.join([', {} {}'.format(type_from_proto[v['type']]['pxd'], k) for k, v in func.get('arguments', {}).items()])}) except +
%endfor

cdef class Function:
@@ -34,6 +34,7 @@ from _nd_array cimport NdArray
from _nd_array import NdArray
from _imperative cimport *
from _computation_graph cimport connect
from communicator cimport Communicator
# Numpy
import numpy as np
cimport numpy as np
@@ -559,12 +560,24 @@ from utils.type_conv import type_from_proto
def ${name}(CContext ctx${''.join([', {} {}'.format(type_from_proto[v['type']]['pyx'], k) for k, v in func.get('arguments', {}).items()])}):
info = Info()
info.args = {}
%for arg in func.get('arguments', {}).keys():
%for arg, v in func.get('arguments', {}).items():
%if v['type'] == 'Communicator':
info.args['${arg}'] = '<Communicator>'
%else:
info.args['${arg}'] = ${arg}
%endif
%endfor
info.type_name = '${name}'
info.tags = {}
f = Function.create(create_${name}(ctx${''.join([', %s' % k for k in func.get('arguments', {}).keys()])}), info)
f = Function.create(create_${name}(ctx
%for k, v in func.get('arguments', {}).items():
%if v['type'] == 'Communicator':
, ${k}.communicator
%else:
, ${k}
%endif
%endfor
), info)
return f
%endfor
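
For a Communicator-typed argument, the template above records only a placeholder string in info.args and passes the wrapped C++ object (${k}.communicator) to the create function. Roughly, the code rendered for SyncBatchNormalization would look like the following sketch (approximate, for illustration; the actual generated .pyx differs in details):

# Approximate rendering of the template for SyncBatchNormalization.
def SyncBatchNormalization(CContext ctx, Communicator comm, const string& group,
                           const vector[int]& axes, float decay_rate, float eps,
                           cpp_bool batch_stat):
    info = Info()
    info.args = {}
    info.args['comm'] = '<Communicator>'  # the communicator object itself is not stored
    info.args['group'] = group
    info.args['axes'] = axes
    info.args['decay_rate'] = decay_rate
    info.args['eps'] = eps
    info.args['batch_stat'] = batch_stat
    info.type_name = 'SyncBatchNormalization'
    info.tags = {}
    # The underlying C++ communicator is what the C++ factory function expects.
    f = Function.create(create_SyncBatchNormalization(
        ctx, comm.communicator, group, axes, decay_rate, eps, batch_stat), info)
    return f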

@@ -380,6 +380,64 @@ def get_tranpose_args(ndim, axes):
return out, mean, variance


def sync_batch_normalization(x, beta, gamma, mean, variance, comm, group="world", axes=[1], decay_rate=0.9, eps=1e-05, batch_stat=True, output_stat=False, n_outputs=None):
r"""
Synchronized batch normalization.
For some tasks (e.g., semantic segmentation), the batch size will be too small for the BatchNormalization layer to work well.
The SyncBatchNormalization layer solves this problem by synchronizing the batch statistics (mean and variance) across multiple processes.
.. math::
\begin{eqnarray}
\mu &=& \frac{1}{M} \sum x_i \\
\sigma^2 &=& \frac{1}{M} \sum \left(x_i - \mu\right)^2 \\
\hat{x}_i &=& \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
y_i &=& \hat{x}_i \gamma + \beta.
\end{eqnarray}
References:
* Implementing Synchronized Multi-GPU Batch Normalization https://hangzhang.org/PyTorch-Encoding/notes/syncbn.html
Args:
x(~nnabla.Variable): N-D array of input.
beta(~nnabla.Variable): N-D array of beta which is learned.
gamma(~nnabla.Variable): N-D array of gamma which is learned.
mean(~nnabla.Variable): N-D array of running mean (modified during forward execution).
variance(~nnabla.Variable): N-D array of running variance (modified during forward execution).
comm(~nnabla.communicators.Communicator): The communicator.
group(string): The name of the communicator group.
axes(repeated int64): Axes along which mean and variance are taken.
decay_rate(float): Decay rate of running mean and variance.
eps(float): Tiny value to avoid zero division by std.
batch_stat(bool): Use mini-batch statistics rather than running ones.
output_stat(bool): If true, the batch statistics of mean and variance
will also be returned as Variables. They are also differentiable.
Returns:
Returns batch normalization output as :obj:`~nnabla.Variable`.
If ``output_stat=True``, it also returns the mean and variance
of the mini-batch.
* :obj:`~nnabla.Variable`: Output of the batch normalization
* :obj:`~nnabla.Variable`: Mean (if ``output_stat=True``)
* :obj:`~nnabla.Variable`: Variance (if ``output_stat=True``)
See Also:
``nnabla.function_bases.batch_normalization``.
"""
from .function_bases import sync_batch_normalization as batch_normalization_base
n_outputs = 3 if output_stat else 1
return batch_normalization_base(x, beta, gamma, mean, variance,
comm, group=group,
axes=axes,
decay_rate=decay_rate,
eps=eps,
batch_stat=batch_stat,
n_outputs=n_outputs)
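
A hedged usage sketch (not part of this diff): the intended call pattern from the functions API in a multi-process run. The communicator setup follows nnabla's usual data-parallel pattern; the extension name, device assignment, and parameter shapes below are assumptions, and the script must be launched with one process per GPU (e.g. via mpirun).

# Usage sketch only; assumes the cudnn extension and a multi-process launch.
import nnabla as nn
import nnabla.functions as F
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context

ctx = get_extension_context('cudnn')
comm = C.MultiProcessDataParallelCommunicator(ctx)
comm.init()
ctx.device_id = str(comm.rank)
nn.set_default_context(ctx)

x = nn.Variable((2, 16, 32, 32))
beta = nn.Variable((1, 16, 1, 1), need_grad=True)
gamma = nn.Variable((1, 16, 1, 1), need_grad=True)
mean = nn.Variable((1, 16, 1, 1))
variance = nn.Variable((1, 16, 1, 1))
# Batch statistics are synchronized across all processes in the "world" group.
y = F.sync_batch_normalization(x, beta, gamma, mean, variance, comm,
                               group="world", axes=[1], batch_stat=True)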


def mean_subtraction(x, mean, t, base_axis=1, update_running_mean=True):
r"""
It subtracts the mean of the elements of the input array,
