Skip to content

Commit

Permalink
Refactor BN multiple axes and test in PF
Browse files Browse the repository at this point in the history
  • Loading branch information
TakuyaNarihira committed May 22, 2019
1 parent bee31be commit db91a15
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 46 deletions.
46 changes: 15 additions & 31 deletions python/src/nnabla/normalization_functions.py
Expand Up @@ -73,8 +73,7 @@ def inv(self, y):
return transpose(transposed, self.inv_transpose_axes)


def batch_normalization(x, beta, gamma, mean, variance, axes=[1], decay_rate=0.9, eps=1e-05, batch_stat=True,
output_stat=False, n_outputs=None):
def batch_normalization(x, beta, gamma, mean, variance, axes=[1], decay_rate=0.9, eps=1e-05, batch_stat=True, output_stat=False, n_outputs=None):
r"""
Batch normalization.
Expand Down Expand Up @@ -135,47 +134,32 @@ def batch_normalization(x, beta, gamma, mean, variance, axes=[1], decay_rate=0.9
batch_stat=batch_stat,
n_outputs=n_outputs)

def transpose_and_reshape(x, axes):
    """
    Move the normalized axes to the front and merge them into one axis.

    The permutation is taken from ``transpose_axes`` in the enclosing
    scope. The first ``len(axes)`` dimensions of the transposed variable
    are multiplied together into a single leading dimension.

    Returns:
        A pair ``(reshaped, transposed_shape)``; ``transposed_shape`` is
        the shape right after the transpose, before the leading
        dimensions were merged.
    """
    n_axes = len(axes)
    moved = transpose(x, transpose_axes)
    moved_shape = moved.shape
    # Product of the leading (normalized) dimensions.
    merged = rd(lambda a, b: a * b, moved_shape[:n_axes])
    trailing = list(moved_shape[n_axes:])
    return reshape(moved, [merged] + trailing), moved_shape

def inverse_transpose_and_reshape(x, axes, variable_shape):
    """
    Undo ``transpose_and_reshape``.

    ``x`` is reshaped to ``variable_shape`` (the shape recorded right
    after the forward transpose) and then permuted with
    ``inv_transpose_axes`` taken from the enclosing scope.
    """
    n_axes = len(axes)
    # Leading and trailing parts are re-joined into the full shape.
    restored_shape = list(variable_shape[:n_axes] + variable_shape[n_axes:])
    restored = reshape(x, restored_shape)
    return transpose(restored, inv_transpose_axes)

def get_tranpose_args(ndim, axes):
    """
    Build the permutation that moves ``axes`` to the front, and its inverse.

    Args:
        ndim (int): Number of dimensions of the array to be transposed.
        axes (iterable of int): Axes to place first, in the given order.

    Returns:
        tuple: ``(transpose_axes, inv_transpose_axes)``, both lists of int.
        ``transpose_axes`` is ``axes`` followed by the remaining axes in
        ascending order; ``inv_transpose_axes`` is the permutation that
        undoes it.
    """
    # Materialize once instead of rebuilding the list for every candidate axis.
    leading = list(axes)
    trailing = [i for i in range(ndim) if i not in leading]
    transpose_axes = leading + trailing
    # argsort of a permutation yields its inverse permutation.
    inv_transpose_axes = np.argsort(transpose_axes).tolist()
    return transpose_axes, inv_transpose_axes

transpose_axes, inv_transpose_axes = get_tranpose_args(len(x.shape), axes)
inp, transposed_inp_shape = transpose_and_reshape(x, axes)
beta, transposed_beta_shape = transpose_and_reshape(beta, axes)
gamma, transposed_gamma_shape = transpose_and_reshape(gamma, axes)
mean, transposed_mean_shape = transpose_and_reshape(mean, axes)
variance, transposed_variance_shape = transpose_and_reshape(variance, axes)
in_adapter = BatchNormalizationInOutAdapter(x.ndim, axes)
param_adapter = BatchNormalizationInOutAdapter(x.ndim, axes)
inp = in_adapter(x)
beta = param_adapter(beta)
gamma = param_adapter(gamma)
mean = param_adapter(mean)
variance = param_adapter(variance)
axis = x.ndim - len(axes)

if n_outputs == 1:
out = batch_normalization_base(inp, beta, gamma, mean, variance,
axes=[0],
axes=[axis],
decay_rate=decay_rate,
eps=eps,
batch_stat=batch_stat,
n_outputs=n_outputs)
return inverse_transpose_and_reshape(out, axes, transposed_inp_shape)
return in_adapter.inv(out)
out, mean, variance = batch_normalization_base(inp, beta, gamma, mean, variance,
axes=[0],
axes=[axis],
decay_rate=decay_rate,
eps=eps,
batch_stat=batch_stat,
n_outputs=n_outputs)
out = inverse_transpose_and_reshape(out, axes, transposed_inp_shape)
mean = inverse_transpose_and_reshape(mean, axes, transposed_mean_shape)
variance = inverse_transpose_and_reshape(
variance, axes, transposed_variance_shape)
out = in_adapter.inv(out)
mean = param_adapter.inv(mean)
variance = param_adapter.inv(variance)
return out, mean, variance


Expand Down
34 changes: 19 additions & 15 deletions python/test/test_parametric_functions.py
Expand Up @@ -187,23 +187,28 @@ def test_pf_convolution_execution(g_rng, inshape, outmaps, kernel, pad, stride,
assert np.allclose(b_init, b.d)


@pytest.mark.parametrize("inshape, decay_rate, eps", [
((1, 2, 1, 4), 0.9, 1e-5),
((8, 8), 0.99, 1e-3),
def _get_bn_parameter_shape(inshape, axes):
'''
Helper function which gets parameter shape of Batch Normalization.
'''
return tuple(size if i in axes else 1 for (i, size) in enumerate(inshape))


@pytest.mark.parametrize("inshape, decay_rate, eps, axes", [
((1, 2, 1, 4), 0.9, 1e-5, [3]),
((8, 8), 0.99, 1e-3, [1]),
])
@pytest.mark.parametrize('batch_stat, output_stat', [(False, False), (True, False), (True, True)])
@pytest.mark.parametrize('param_init', [None, True])
@pytest.mark.parametrize("fix_parameters", [False, True])
@pytest.mark.parametrize("rng", [None, True])
def test_pf_batch_normalization_execution(g_rng, inshape, decay_rate, eps, batch_stat, output_stat, param_init, fix_parameters, rng):

axis = 1 # Assume axes=[1]
p_shape = [1] * len(inshape)
p_shape[axis] = inshape[axis]
p_shape = tuple(p_shape)
def test_pf_batch_normalization_execution(
g_rng, inshape, axes, decay_rate, eps, batch_stat, output_stat,
param_init, fix_parameters, rng):

p_shape = _get_bn_parameter_shape(inshape, axes)
if param_init:
beta_init = np.ones(p_shape) * 1
beta_init = np.ones(p_shape)
gamma_init = np.ones(p_shape) * 2
mean_init = np.ones(p_shape) * 0.5
var_init = np.ones(p_shape) * 1.5
Expand All @@ -217,6 +222,7 @@ def test_pf_batch_normalization_execution(g_rng, inshape, decay_rate, eps, batch
x = nn.Variable.from_numpy_array(g_rng.randn(*inshape))

kw = {}
insert_if_not_default(kw, 'axes', axes, [1])
insert_if_not_default(kw, 'decay_rate', decay_rate, 0.9)
insert_if_not_default(kw, 'eps', eps, 1e-5)
insert_if_not_default(kw, 'batch_stat', batch_stat, True)
Expand Down Expand Up @@ -258,6 +264,7 @@ def test_pf_batch_normalization_execution(g_rng, inshape, decay_rate, eps, batch
args = h.parent.info.args
assert np.isclose(args['decay_rate'], decay_rate)
assert np.isclose(args['eps'], eps)
assert args['axes'] == axes
assert args['batch_stat'] == batch_stat

# Check created parameters
Expand Down Expand Up @@ -298,13 +305,10 @@ def test_pf_fused_batch_normalization_execution(
g_rng, inshape, axes, decay_rate, eps, batch_stat, nonlinearity,
output_stat, param_init, fix_parameters, with_z, rng):

p_shape = [1 for _ in inshape]
for i in range(len(axes)):
p_shape[axes[i]] = inshape[axes[i]]
p_shape = tuple(p_shape)
p_shape = _get_bn_parameter_shape(inshape, axes)

if param_init:
beta_init = np.ones(p_shape) * 1
beta_init = np.ones(p_shape)
gamma_init = np.ones(p_shape) * 2
mean_init = np.ones(p_shape) * 0.5
var_init = np.ones(p_shape) * 1.5
Expand Down

0 comments on commit db91a15

Please sign in to comment.