Skip to content

Commit

Permalink
Merge branch 'feature/20180510-fix-sample-data' into 'feature/2018040…
Browse files Browse the repository at this point in the history
…1-file-format-converter'

Feature/20180510 fix sample data

See merge request nnabla/nnabla!165
  • Loading branch information
YukioOobuchi committed May 28, 2018
2 parents 1263aa9 + 2d68631 commit 1e14865
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
3 changes: 2 additions & 1 deletion python/src/nnabla/utils/converter/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,9 @@ Neural Network Layer:
x0:
doc: Indices with shape :math:`(I_0, ..., I_N)`
template: TI
x1:
w:
doc: Weights with shape :math:`(W_0, ..., W_M)`
parameter: true
outputs:
y:
doc: Output with shape :math:`(I_0, ..., I_N, W_1, ..., W_M)`
Expand Down
4 changes: 2 additions & 2 deletions python/test/function/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def test_embed_forward_backward(seed, shape_x, shape_w, ctx, func_name):
from nbla_test_utils import cap_ignore_region, function_tester
rng = np.random.RandomState(seed)
n_class = shape_w[0]
x = np.random.randint(0, n_class - 1, shape_x)
w = np.random.randn(*shape_w)
x = np.random.randint(0, n_class - 1, shape_x).astype(np.int32)
w = np.random.randn(*shape_w).astype(np.float32)
inputs = [x, w]
function_tester(rng, F.embed, lambda x, w: w[x], inputs,
ctx=ctx, func_name=func_name, atol_b=1e-2,
Expand Down
13 changes: 8 additions & 5 deletions python/test/function/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_logical_scalar_forward_backward(val, seed, fname, ctx, func_name):
func = getattr(F, fname)
ref_func = getattr(np, fname.replace('_scalar', ''))
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4))]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32)]
function_tester(rng, func, ref_func, inputs, [val],
ctx=ctx, backward=[False], func_name=func_name)

Expand All @@ -42,6 +42,7 @@ def test_logical_scalar_forward_backward(val, seed, fname, ctx, func_name):
'greater': '>',
'greater_equal': '>=',
'less': '<',
'less_equal': '<=',
'equal': '==',
'not_equal': '!='}

Expand All @@ -51,6 +52,7 @@ def test_logical_scalar_forward_backward(val, seed, fname, ctx, func_name):
list_ctx_and_func_name(['greater_scalar',
'greater_equal_scalar',
'less_scalar',
'less_equal_scalar',
'equal_scalar',
'not_equal_scalar']))
@pytest.mark.parametrize("val", [-0.5, 0., 1.])
Expand All @@ -59,7 +61,7 @@ def test_logical_scalar_compare_forward_backward(val, seed, fname, ctx, func_nam
func = getattr(F, fname)
ref_func = eval('lambda x, y: x {} y'.format(opstr))
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4)) for _ in range(1)]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32) for _ in range(1)]
inputs[0][..., :2] = val
function_tester(rng, func, ref_func, inputs, [val],
ctx=ctx, backward=[False, False], func_name=func_name)
Expand All @@ -76,7 +78,7 @@ def test_logical_binary_forward_backward(seed, fname, ctx, func_name):
func = getattr(F, fname)
ref_func = getattr(np, fname)
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4)) for _ in range(2)]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32) for _ in range(2)]
function_tester(rng, func, ref_func, inputs,
ctx=ctx, backward=[False, False], func_name=func_name)

Expand All @@ -86,14 +88,15 @@ def test_logical_binary_forward_backward(seed, fname, ctx, func_name):
list_ctx_and_func_name(['greater',
'greater_equal',
'less',
'less_equal',
'equal',
'not_equal']))
def test_logical_binary_compare_forward_backward(seed, fname, ctx, func_name):
func = getattr(F, fname)
opstr = opstrs[fname]
ref_func = eval('lambda x, y: x {} y'.format(opstr))
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4)) for _ in range(2)]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32) for _ in range(2)]
inputs[0][..., :2] = inputs[1][..., :2]
function_tester(rng, func, ref_func, inputs,
ctx=ctx, backward=[False, False], func_name=func_name)
Expand All @@ -108,6 +111,6 @@ def test_logical_not_forward_backward(seed, fname, ctx, func_name):
func = getattr(F, fname)
ref_func = getattr(np, fname)
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4))]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32)]
function_tester(rng, func, ref_func, inputs,
ctx=ctx, backward=[False], func_name=func_name)

0 comments on commit 1e14865

Please sign in to comment.