Skip to content

Commit

Permalink
fix data type problem
Browse files Browse the repository at this point in the history
  • Loading branch information
TE-WoodyLi committed May 10, 2018
1 parent cb3b8f0 commit 6ca02d1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/test/function/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def test_embed_forward_backward(seed, shape_x, shape_w, ctx, func_name):
from nbla_test_utils import cap_ignore_region, function_tester
rng = np.random.RandomState(seed)
n_class = shape_w[0]
x = np.random.randint(0, n_class - 1, shape_x)
w = np.random.randn(*shape_w)
x = np.random.randint(0, n_class - 1, shape_x).astype(np.int32)
w = np.random.randn(*shape_w).astype(np.float32)
inputs = [x, w]
function_tester(rng, F.embed, lambda x, w: w[x], inputs,
ctx=ctx, func_name=func_name, atol_b=1e-2,
Expand Down
2 changes: 1 addition & 1 deletion python/test/function/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_logical_binary_forward_backward(seed, fname, ctx, func_name):
func = getattr(F, fname)
ref_func = getattr(np, fname)
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4)) for _ in range(2)]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32) for _ in range(2)]
function_tester(rng, func, ref_func, inputs,
ctx=ctx, backward=[False, False], func_name=func_name)

Expand Down

0 comments on commit 6ca02d1

Please sign in to comment.