Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added operation tf.math.reciprocal_no_nan() #30888

Merged
merged 9 commits into from Jul 29, 2019
30 changes: 30 additions & 0 deletions tensorflow/python/ops/math_ops.py
Expand Up @@ -4003,3 +4003,33 @@ def polyval(coeffs, x, name=None):
for c in coeffs[1:]:
p = c + p * x
return p

@tf_export("math.reciprocal_no_nan")
def reciprocal_no_nan(x, name=None):
"""Performs a safe reciprocal operation, element wise.
If a particular element is zero, the reciprocal for that element is
also set to zero.

For example:
```python
x = tf.constant([2.0, 0.5, 0, 1], dtype=tf.float32)
tf.math.reciprocal_no_nan(x) # [ 0.5, 2, 0.0, 1.0 ]
```

Args:
x: A `Tensor` of type `float16`, `float32`, `float64`
`complex64` or `complex128`.
name: A name for the operation (optional).

Returns:
A `Tensor` of same shape and type as `x`.

Raises:
TypeError: x must be of a valid dtype.

"""

with ops.name_scope(name, "reciprocal_no_nan", [x]) as scope:
x = ops.convert_to_tensor(x, name="x")
one = constant_op.constant(1, dtype=x.dtype.base_dtype, name="one")
return gen_math_ops.div_no_nan(one, x, name=scope)
29 changes: 29 additions & 0 deletions tensorflow/python/ops/math_ops_test.py
Expand Up @@ -699,5 +699,34 @@ def testErrorReceivedIfDtypeMismatchFromOp(self):
a = array_ops.ones([1], dtype=dtypes.int32) + 1.0
self.evaluate(a)


class ReciprocalNoNanTest(test_util.TensorFlowTestCase):

allowed_dtypes = [dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128]

@test_util.run_in_graph_and_eager_modes
def testBasic(self):
for dtype in self.allowed_dtypes:
x = constant_op.constant([1.0, 2.0, 0.0, 4.0], dtype=dtype)

y = math_ops.reciprocal_no_nan(x)

target = constant_op.constant([1.0, 0.5, 0.0, 0.25], dtype=dtype)

self.assertAllEqual(y, target)
self.assertEqual(y.dtype.base_dtype, target.dtype.base_dtype)

@test_util.run_in_graph_and_eager_modes
def testInverse(self):
for dtype in self.allowed_dtypes:
x = np.random.choice([0, 1, 2, 4, 5], size=(5, 5, 5))
x = constant_op.constant(x, dtype=dtype)

y = math_ops.reciprocal_no_nan(math_ops.reciprocal_no_nan(x))

self.assertAllClose(y, x)
self.assertEqual(y.dtype.base_dtype, x.dtype.base_dtype)

if __name__ == "__main__":
googletest.main()
4 changes: 4 additions & 0 deletions tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
Expand Up @@ -308,6 +308,10 @@ tf_module {
name: "reciprocal"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reciprocal_no_nan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_all"
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
Expand Up @@ -308,6 +308,10 @@ tf_module {
name: "reciprocal"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reciprocal_no_nan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_all"
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
Expand Down