Skip to content

Commit

Permalink
Merge pull request #155 from sony/hotfix/20180606-mean-subtraction
Browse files Browse the repository at this point in the history
change a variable name of mean_subtraction, "rmean" to "mean"
  • Loading branch information
TE-YoshiyukiKobayashi committed Jun 7, 2018
2 parents 08c3c0c + 06670eb commit 97669f9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions python/src/nnabla/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def batch_normalization(x, beta, gamma, mean, variance, axes=[1], decay_rate=0.9
n_outputs=n_outputs)


def mean_subtraction(x, rmean, t, base_axis=1, update_running_mean=True):
def mean_subtraction(x, mean, t, base_axis=1, update_running_mean=True):
r"""
It subtracts the mean of the elements of the input array,
and normalizes it to :math:`0`. Preprocessing arrays with this function has the effect of improving accuracy
Expand All @@ -244,7 +244,7 @@ def mean_subtraction(x, rmean, t, base_axis=1, update_running_mean=True):
Args:
x(~nnabla.Variable): N-D array of input.
rmean(~nnabla.Variable): N-D array of running mean (modified during forward execution).
mean(~nnabla.Variable): N-D array of running mean (modified during forward execution).
t(~nnabla.Variable): Scalar of num of iteration of running mean (modified during forward execution).
base_axis(int): Base axis of Mean Subtraction operation. Dimensions up to base_axis is treated as sample dimension.
[default=``1``]
Expand All @@ -259,7 +259,7 @@ def mean_subtraction(x, rmean, t, base_axis=1, update_running_mean=True):
"""
from .function_bases import mean_subtraction as mean_subtraction_base
return mean_subtraction_base(x, rmean, t,
return mean_subtraction_base(x, mean, t,
base_axis=base_axis,
update_running_mean=update_running_mean)

Expand Down
6 changes: 3 additions & 3 deletions python/src/nnabla/parametric_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,11 +1297,11 @@ def mean_subtraction(inp, base_axis=1, update_running_mean=True, fix_parameters=
"""
assert len(inp.shape) >= base_axis
shape = inp.shape[base_axis:]
rmean = get_parameter_or_create(
"rmean", shape, ConstantInitializer(0), False)
mean = get_parameter_or_create(
"mean", shape, ConstantInitializer(0), False)
t = get_parameter_or_create(
"t", (1, ), ConstantInitializer(0), False)
return F.mean_subtraction(inp, rmean, t, base_axis=base_axis, update_running_mean=update_running_mean)
return F.mean_subtraction(inp, mean, t, base_axis=base_axis, update_running_mean=update_running_mean)


@parametric_function_api("embed")
Expand Down

0 comments on commit 97669f9

Please sign in to comment.