Skip to content

Commit

Permalink
fix data type problem for all logic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
TE-WoodyLi committed May 10, 2018
1 parent dcb2c00 commit 2d68631
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions python/test/function/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_logical_scalar_forward_backward(val, seed, fname, ctx, func_name):
func = getattr(F, fname)
ref_func = getattr(np, fname.replace('_scalar', ''))
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4))]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32)]
function_tester(rng, func, ref_func, inputs, [val],
ctx=ctx, backward=[False], func_name=func_name)

Expand All @@ -42,6 +42,7 @@ def test_logical_scalar_forward_backward(val, seed, fname, ctx, func_name):
'greater': '>',
'greater_equal': '>=',
'less': '<',
'less_equal': '<=',
'equal': '==',
'not_equal': '!='}

Expand All @@ -51,6 +52,7 @@ def test_logical_scalar_forward_backward(val, seed, fname, ctx, func_name):
list_ctx_and_func_name(['greater_scalar',
'greater_equal_scalar',
'less_scalar',
'less_equal_scalar',
'equal_scalar',
'not_equal_scalar']))
@pytest.mark.parametrize("val", [-0.5, 0., 1.])
Expand All @@ -59,7 +61,7 @@ def test_logical_scalar_compare_forward_backward(val, seed, fname, ctx, func_nam
func = getattr(F, fname)
ref_func = eval('lambda x, y: x {} y'.format(opstr))
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4)) for _ in range(1)]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32) for _ in range(1)]
inputs[0][..., :2] = val
function_tester(rng, func, ref_func, inputs, [val],
ctx=ctx, backward=[False, False], func_name=func_name)
Expand All @@ -86,14 +88,15 @@ def test_logical_binary_forward_backward(seed, fname, ctx, func_name):
list_ctx_and_func_name(['greater',
'greater_equal',
'less',
'less_equal',
'equal',
'not_equal']))
def test_logical_binary_compare_forward_backward(seed, fname, ctx, func_name):
func = getattr(F, fname)
opstr = opstrs[fname]
ref_func = eval('lambda x, y: x {} y'.format(opstr))
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4)) for _ in range(2)]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32) for _ in range(2)]
inputs[0][..., :2] = inputs[1][..., :2]
function_tester(rng, func, ref_func, inputs,
ctx=ctx, backward=[False, False], func_name=func_name)
Expand All @@ -108,6 +111,6 @@ def test_logical_not_forward_backward(seed, fname, ctx, func_name):
func = getattr(F, fname)
ref_func = getattr(np, fname)
rng = np.random.RandomState(seed)
inputs = [rng.randint(0, 2, size=(2, 3, 4))]
inputs = [rng.randint(0, 2, size=(2, 3, 4)).astype(np.float32)]
function_tester(rng, func, ref_func, inputs,
ctx=ctx, backward=[False], func_name=func_name)

0 comments on commit 2d68631

Please sign in to comment.