Skip to content

Commit

Permalink
Cast int8 to bool for lax.not in jax2tf.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomhennigan committed Jun 22, 2020
1 parent 8f4ba7e commit 98e28c5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -495,7 +495,6 @@ def _shift_right_logical(x, y):
tf_impl[lax.shift_right_logical_p] = _shift_right_logical

tf_impl[lax.shift_left_p] = tf.bitwise.left_shift
tf_impl[lax.not_p] = tf.bitwise.invert


def bool_to_int8(f, argnums):
Expand All @@ -514,6 +513,7 @@ def wrapper(*args, **kwargs):
tf_impl[lax.or_p] = bool_to_int8(tf.bitwise.bitwise_or, argnums=(0, 1))
tf_impl[lax.and_p] = bool_to_int8(tf.bitwise.bitwise_and, argnums=(0, 1))
tf_impl[lax.xor_p] = bool_to_int8(tf.bitwise.bitwise_xor, argnums=(0, 1))
tf_impl[lax.not_p] = bool_to_int8(tf.bitwise.invert, argnums=(0,))

tf_impl[lax.eq_p] = wrap_binary_op(tf.math.equal)
tf_impl[lax.ne_p] = wrap_binary_op(tf.math.not_equal)
Expand Down
14 changes: 14 additions & 0 deletions jax/experimental/jax2tf/tests/primitives_test.py
Expand Up @@ -223,6 +223,20 @@ def test_binary_logical_elementwise(self, f_jax):
self.assertAllClose(r_jax[np.isfinite(r_jax)],
r_tf[np.isfinite(r_tf)], atol=1e-4)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
f_jax=f_jax)
for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
def test_binary_logical_elementwise_bool(self, f_jax):
if f_jax == lax.shift_left:
self.skipTest("Shift of bool not supported")
a = np.array([0, 0, 1, 1, 0, 0, 1, 1], dtype=np.bool_)
b = np.array([0, 1, 0, 1, 0, 1, 0, 1], dtype=np.bool_)
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(a, b)
r_tf = f_tf(a, b)
self.assertAllClose(r_jax, r_tf)

# TODO(necula): combine tests that are identical except for the harness
# wait until we get more experience with using harnesses.
@primitive_harness.parameterized(primitive_harness.lax_shift_left)
Expand Down

0 comments on commit 98e28c5

Please sign in to comment.