Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add uint16 support for py_func #18659

Merged
merged 9 commits into from
Apr 19, 2018
32 changes: 32 additions & 0 deletions tensorflow/python/kernel_tests/py_func_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,38 @@ class PyFuncTest(test.TestCase):
"""Encapsulates tests for py_func and eager_py_func."""

# ----- Tests for py_func -----
def testRealDataTypes(self):
def sum_func(x, y):
return x + y
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
dtypes.int32, dtypes.int64]:
with self.test_session():
x = constant_op.constant(1, dtype=dtype)
y = constant_op.constant(2, dtype=dtype)
z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
self.assertEqual(z, 3)

def testComplexDataTypes(self):
def sub_func(x, y):
return x - y
for dtype in [dtypes.complex64, dtypes.complex128]:
with self.test_session():
x = constant_op.constant(1 + 1j, dtype=dtype)
y = constant_op.constant(2 - 2j, dtype=dtype)
z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
self.assertEqual(z, -1 + 3j)

def testBoolDataTypes(self):
def and_func(x, y):
return x and y
dtype = dtypes.bool
with self.test_session():
x = constant_op.constant(True, dtype=dtype)
y = constant_op.constant(False, dtype=dtype)
z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
self.assertEqual(z, False)

def testSingleType(self):
with self.test_session():
x = constant_op.constant(1.0, dtypes.float32)
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/python/lib/core/py_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ Status NumericNpDTypeToTfDType(const int np, DataType* tf) {
case NPY_INT8:
*tf = DT_INT8;
break;
case NPY_UINT16:
*tf = DT_UINT16;
break;
case NPY_INT16:
*tf = DT_INT16;
break;
Expand Down