Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change default initial weight value of PReLU #474

Merged
merged 3 commits into from Jun 25, 2019
Merged
Changes from 1 commit
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

Prev

[test] Add PReLU parametric function test

  • Loading branch information...
TE-TakuyaNarihira committed Jun 21, 2019
commit b49d65fbf0a246aaa1b1c6105751fe17219a54cf
@@ -29,7 +29,7 @@ def g_rng(request):

def process_param_init(p_init, shape, rng):
if p_init is True:
p_init = rng.randn(*shape)
p_init = np.asarray(rng.randn(*shape))
return p_init


@@ -1030,4 +1030,47 @@ def test_pf_lstm_execution(g_rng, inshape, w0_init, w_init, b_init, num_layers,
assert np.allclose(b_init, b.d)


@pytest.mark.parametrize("inshape", [(8, 2, 2, 2), (16, 1, 8)])
@pytest.mark.parametrize("base_axis", [1, 2])
@pytest.mark.parametrize("shared", [False, True])
@pytest.mark.parametrize("slope_init", [None, I.NormalInitializer(), True])
@pytest.mark.parametrize("fix_parameters", [False, True])
def test_pf_prelu_execution(g_rng, inshape, base_axis, shared, slope_init, fix_parameters):

slope_shape = tuple() if shared else (inshape[base_axis],)
slope_init = process_param_init(slope_init, slope_shape, g_rng)

kw = {}
insert_if_not_none(kw, 'slope_init', slope_init)
insert_if_not_default(kw, 'base_axis', base_axis, 1)
insert_if_not_default(kw, 'shared', shared, True)
insert_if_not_default(kw, 'fix_parameters', fix_parameters, False)

x = nn.Variable.from_numpy_array(g_rng.randn(*inshape))

# Check execution
y = PF.prelu(x, **kw)
y.forward()
y.backward()

# Check values
# TODO

# Check args
assert y.parent.info.type_name == 'PReLU'
args = y.parent.info.args
assert args['base_axis'] == base_axis

# Check created parameters
assert y.parent.inputs[0] == x
assert len(y.parent.inputs) == 2
assert len(nn.get_parameters()) == 1
slope = nn.get_parameters()['prelu/slope']
assert slope.shape == slope_shape
assert slope.need_grad
assert y.parent.inputs[1].need_grad == (not fix_parameters)
if isinstance(slope_init, np.ndarray):
assert np.allclose(slope_init, slope.d)


# TODO: Test all parametric functions.
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.