unsafe_div op for division by zero #19105

Closed · wants to merge 7 commits · changes from all commits
15 changes: 7 additions & 8 deletions tensorflow/cc/gradients/math_grad.cc
```diff
@@ -441,21 +441,20 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
 
-Status UnsafeDivGrad(const Scope& scope, const Operation& op,
-                     const std::vector<Output>& grad_inputs,
-                     std::vector<Output>* grad_outputs) {
+Status DivNoNanGrad(const Scope& scope, const Operation& op,
+                    const std::vector<Output>& grad_inputs,
+                    std::vector<Output>* grad_outputs) {
   auto x_1 = ConjugateHelper(scope, op.input(0));
   auto x_2 = ConjugateHelper(scope, op.input(1));
   // y = x_1 / x_2
   // dy/dx_1 = 1/x_2
   // dy/dx_2 = -x_1/x_2^2
-  auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
-  auto gx_2 =
-      Mul(scope, grad_inputs[0],
-          UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
+  auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
+  auto gx_2 = Mul(scope, grad_inputs[0],
+                  DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
   return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
 }
-REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
+REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);
 
 Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
```
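For reference (not part of this diff), a minimal NumPy sketch of the gradient rule registered above: both partials, dy/dx_1 = 1/x_2 and dy/dx_2 = -x_1/x_2^2, are computed with the same zero-propagating division, so they come out as 0 wherever x_2 == 0.

```python
import numpy as np

def div_no_nan(x1, x2):
    # Forward op: x1 / x2 elementwise, but 0 wherever x2 == 0.
    safe_x2 = np.where(x2 != 0, x2, 1)         # guard so we never divide by 0
    return np.where(x2 != 0, x1 / safe_x2, 0.0)

def div_no_nan_grads(x1, x2, dz):
    # Mirrors the registered gradient: gx_1 = dz / x_2 and
    # gx_2 = dz * (-x_1 / x_2^2), both forced to 0 where x_2 == 0.
    gx1 = div_no_nan(dz, x2)
    gx2 = dz * div_no_nan(div_no_nan(-x1, x2), x2)
    return gx1, gx2

x1 = np.array([1.0, -2.0, 3.0])
x2 = np.array([2.0, 0.0, 4.0])
gx1, gx2 = div_no_nan_grads(x1, x2, np.ones_like(x1))
print(gx1)  # [ 0.5    0.     0.25  ]
print(gx2)  # [-0.25   0.    -0.1875]
```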
8 changes: 4 additions & 4 deletions tensorflow/cc/gradients/math_grad_test.cc
```diff
@@ -33,6 +33,7 @@ using ops::AddN;
 using ops::BatchMatMul;
 using ops::Const;
 using ops::Div;
+using ops::DivNoNan;
 using ops::MatMul;
 using ops::Max;
 using ops::Maximum;
@@ -48,7 +49,6 @@ using ops::SegmentSum;
 using ops::SquaredDifference;
 using ops::Sub;
 using ops::Sum;
-using ops::UnsafeDiv;
 
 // TODO(andydavis) Test gradient function against numeric gradients output.
 // TODO(andydavis) As more gradients are added move common test functions
@@ -854,21 +854,21 @@ TEST_F(NaryGradTest, RealDiv) {
   RunTest({x}, {x_shape}, {y}, {x_shape});
 }
 
-TEST_F(NaryGradTest, UnsafeDiv) {
+TEST_F(NaryGradTest, DivNoNan) {
   {
     TensorShape x_shape({3, 2, 5});
     const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
     // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
    // division errors in the numeric estimator used by the gradient checker.
-    const auto y = UnsafeDiv(
+    const auto y = DivNoNan(
         scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
     RunTest({x}, {x_shape}, {y}, {x_shape});
   }
   {
     // Return 0 gradient (rather than NaN) for division by zero.
     const auto x = Placeholder(scope_, DT_FLOAT);
     const auto zero = Const<float>(scope_, 0.0);
-    const auto y = UnsafeDiv(scope_, x, zero);
+    const auto y = DivNoNan(scope_, x, zero);
 
     std::vector<Output> grad_outputs;
     TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
```
9 changes: 9 additions & 0 deletions tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
```diff
@@ -0,0 +1,9 @@
+op {
+  graph_op_name: "DivNoNan"
+  summary: "Returns 0 if the denominator is zero."
+  description: <<END
+
+*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+END
+}
```
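A quick illustration (not part of this diff) of the documented semantics, using the `tf.div_no_nan` wrapper exported by the Python changes below with the TF 1.x session API: the denominator broadcasts across each row, and any zero-denominator column yields 0 instead of inf/nan.

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])
y = tf.constant([2.0, 0.0, 3.0])  # shape (3,) broadcasts against each row of x

with tf.Session() as sess:
    print(sess.run(tf.div_no_nan(x, y)))
    # [[0.5 0.  1. ]
    #  [2.  0.  2. ]]
```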
5 changes: 0 additions & 5 deletions tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt

This file was deleted.

4 changes: 4 additions & 0 deletions tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
```diff
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DivNoNan"
+  visibility: HIDDEN
+}
```
4 changes: 0 additions & 4 deletions tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt

This file was deleted.

3 changes: 1 addition & 2 deletions tensorflow/core/kernels/cwise_op_div.cc
```diff
@@ -24,8 +24,7 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
           int32, int64);
 REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
           bfloat16, complex64, complex128);
-REGISTER5(BinaryOp, CPU, "UnsafeDiv", functor::unsafe_div, float, double, int16,
-          int32, int64);
+REGISTER2(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, float, double);
 
 #if GOOGLE_CUDA
 REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
```
8 changes: 4 additions & 4 deletions tensorflow/core/kernels/cwise_ops.h
```diff
@@ -154,8 +154,8 @@ struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
 };
 
 template <typename T>
-struct unsafe_div_op {
-  EIGEN_EMPTY_STRUCT_CTOR(unsafe_div_op)
+struct div_no_nan_op {
+  EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op)
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
                                                            const T& b) const {
     if (b != 0) {
@@ -167,7 +167,7 @@ struct unsafe_div_op {
 };
 
 template <typename T>
-struct functor_traits<unsafe_div_op<T>> {
+struct functor_traits<div_no_nan_op<T>> {
   enum {
     Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
     PacketAccess = false,
@@ -742,7 +742,7 @@ struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
 };
 
 template <typename T>
-struct unsafe_div : base<T, Eigen::internal::unsafe_div_op<T>> {};
+struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
 
 template <typename T>
 struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
```
8 changes: 4 additions & 4 deletions tensorflow/core/ops/math_grad.cc
```diff
@@ -495,18 +495,18 @@ Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
 }
 REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
 
-Status UnsafeDivGrad(const AttrSlice& attrs, FunctionDef* g) {
+Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
   // clang-format off
   return GradForBinaryCwise(g, {
-      {{"gx"}, "UnsafeDiv", {"dz", "y"}},
+      {{"gx"}, "DivNoNan", {"dz", "y"}},
       {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
       {{"y2"}, "Square", {"y"}, {}, {"dz"}},
-      {{"nx_y2"}, "UnsafeDiv", {"nx", "y2"}},
+      {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
       {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
   });
   // clang-format on
 }
-REGISTER_OP_GRADIENT("UnsafeDiv", UnsafeDivGrad);
+REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);
 
 Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
   // clang-format off
```
6 changes: 3 additions & 3 deletions tensorflow/core/ops/math_grad_test.cc
```diff
@@ -753,14 +753,14 @@ TEST_F(MathGradTest, Div) {
   }
 }
 
-TEST_F(MathGradTest, UnsafeDiv) {
+TEST_F(MathGradTest, DivNoNan) {
   auto x = test::AsTensor<float>(
       {0.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 0.f}, TensorShape({3, 3}));
   auto y = test::AsTensor<float>({-10.f, 0.f, 10.f}, TensorShape({3, 1}));
   Tensor dx;
   Tensor dy;
   {
-    SymGrad("UnsafeDiv", x, y, &dx, &dy);
+    SymGrad("DivNoNan", x, y, &dx, &dy);
     {
       auto g = [](float x, float y) {
         if (y == 0.f) {
@@ -792,7 +792,7 @@ TEST_F(MathGradTest, UnsafeDiv) {
     }
   }
   {  // Swap x and y.
-    SymGrad("UnsafeDiv", y, x, &dy, &dx);
+    SymGrad("DivNoNan", y, x, &dy, &dx);
     {
       auto g = [](float x, float y) {
         if (y == 0.f) {
```
2 changes: 1 addition & 1 deletion tensorflow/core/ops/math_ops.cc
```diff
@@ -392,7 +392,7 @@ Returns x * y element-wise.
 REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
     shape_inference::BroadcastBinaryOpShapeFn);
 
-REGISTER_OP("UnsafeDiv")
+REGISTER_OP("DivNoNan")
     .BINARY_MORE()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
```

Review comment from the author on `.BINARY_MORE()`: I forgot to replace `BINARY_MORE` with float32/float64 only; I'll push that change in a later commit.
2 changes: 1 addition & 1 deletion tensorflow/core/ops/math_ops_test.cc
```diff
@@ -121,7 +121,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
                          "Mod", "Mul",
                          "NotEqual", "Pow",
                          "Sub", "SquaredDifference",
-                         "UnsafeDiv"}) {
+                         "DivNoNan"}) {
     ShapeInferenceTestOp op(op_name);
     INFER_OK(op, "?;?", "?");
     INFER_OK(op, "[1,2];?", "?");
```
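These `INFER_OK` cases exercise `BroadcastBinaryOpShapeFn`, which now also covers `DivNoNan`. The rule being checked is standard right-aligned NumPy broadcasting; a Python sketch (illustration only, not TensorFlow's implementation):

```python
def broadcast_shape(s1, s2):
    """Right-aligned broadcast of two shapes; dims must match or be 1."""
    n = max(len(s1), len(s2))
    s1 = (1,) * (n - len(s1)) + tuple(s1)  # pad shorter shape with leading 1s
    s2 = (1,) * (n - len(s2)) + tuple(s2)
    out = []
    for a, b in zip(s1, s2):
        if a != 1 and b != 1 and a != b:
            raise ValueError("incompatible shapes: %r vs %r" % (s1, s2))
        out.append(max(a, b))
    return tuple(out)

assert broadcast_shape((1, 2), (3, 2, 1)) == (3, 2, 2)
assert broadcast_shape((2, 2), (2, 2)) == (2, 2)
```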
10 changes: 5 additions & 5 deletions tensorflow/python/ops/math_grad.py
```diff
@@ -972,9 +972,9 @@ def _RealDivGrad(op, grad):
       grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
 
 
-@ops.RegisterGradient("UnsafeDiv")
-def _UnsafeDivGrad(op, grad):
-  """UnsafeDiv op gradient."""
+@ops.RegisterGradient("DivNoNan")
+def _DivNoNanGrad(op, grad):
+  """DivNoNan op gradient."""
   x = op.inputs[0]
   y = op.inputs[1]
   sx = array_ops.shape(x)
@@ -983,10 +983,10 @@ def _UnsafeDivGrad(op, grad):
   x = math_ops.conj(x)
   y = math_ops.conj(y)
   return (array_ops.reshape(
-      math_ops.reduce_sum(math_ops.unsafe_div(grad, y), rx), sx),
+      math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
           array_ops.reshape(
               math_ops.reduce_sum(
-                  grad * math_ops.unsafe_div(math_ops.unsafe_div(-x, y), y),
+                  grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),
                   ry), sy))
 
 
```
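The `reduce_sum`/`reshape` pattern in this gradient undoes broadcasting: wherever a size-1 (or missing) input dimension was stretched in the forward op, the incoming gradient must be summed back along that axis. A simplified NumPy stand-in (not part of this diff) for the `BroadcastGradientArgs` + `reduce_sum` + `reshape` combination:

```python
import numpy as np

def reduce_broadcast_grad(grad, input_shape):
    # Sum the incoming gradient over every axis along which the input
    # was broadcast, then reshape back to the input's original shape.
    padded = (1,) * (grad.ndim - len(input_shape)) + tuple(input_shape)
    axes = tuple(i for i, (g, s) in enumerate(zip(grad.shape, padded))
                 if s == 1 and g != 1)
    return grad.sum(axis=axes).reshape(input_shape)

# Forward: x with shape (3, 1) was broadcast against y with shape (1, 4).
dz = np.ones((3, 4))
print(reduce_broadcast_grad(dz, (3, 1)))  # [[4.] [4.] [4.]] - 4 grads summed per row
print(reduce_broadcast_grad(dz, (1, 4)))  # [[3. 3. 3. 3.]]  - 3 grads summed per col
```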
15 changes: 9 additions & 6 deletions tensorflow/python/ops/math_grad_test.py
```diff
@@ -231,11 +231,12 @@ def testFloorModGradient(self):
     self.assertLess(error, 1e-4)
 
 
-class UnsafeDivGradientTest(test.TestCase):
+class DivNoNanGradientTest(test.TestCase):
 
   def testBasicGradient(self):
-    inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
-    outputs = math_ops.unsafe_div(inputs, 1 + math_ops.abs(inputs))
+    inputs = constant_op.constant(np.arange(-3, 3),
+                                  dtype=dtypes.float32)
+    outputs = math_ops.div_no_nan(inputs, 1 + math_ops.abs(inputs))
     with self.test_session():
       error = gradient_checker.compute_gradient_error(
           inputs,
@@ -244,9 +245,11 @@ def testBasicGradient(self):
     self.assertLess(error, 1e-4)
 
   def testGradientWithDenominatorIsZero(self):
-    x = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
-    y = array_ops.zeros_like(x, dtype=dtypes.float32)
-    outputs = math_ops.unsafe_div(x, y)
+    x = constant_op.constant(np.arange(-3, 3),
+                             dtype=dtypes.float32)
+    y = array_ops.zeros_like(x,
+                             dtype=dtypes.float32)
+    outputs = math_ops.div_no_nan(x, y)
     with self.test_session():
       dx, dy = gradients.gradients(outputs, [x, y])
       self.assertAllClose(dx.eval(), np.zeros(x.shape.as_list()))
```
16 changes: 7 additions & 9 deletions tensorflow/python/ops/math_ops.py
```diff
@@ -1038,29 +1038,27 @@ def div(x, y, name=None):
   return _div_python2(x, y, name)
 
 
-def unsafe_div(x, y, name=None):
+@tf_export("div_no_nan")
+def div_no_nan(x, y, name=None):
   """Computes an unsafe divide which returns 0 if the y is zero.
 
   Note that the function uses Python 3 division operator semantics.
 
   Args:
-    x: A `Tensor`. Must be one of the following types:
-      `float32`, `float64`, `int16`, `int32`, `int64`.
+    x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
     y: A `Tensor` whose dtype is compatible with `x`.
     name: A name for the operation (optional).
   Returns:
     The element-wise value of the x divided by y.
   """
 
-  with ops.name_scope(name, "unsafe_div", [x, y]) as name:
+  with ops.name_scope(name, "div_no_nan", [x, y]) as name:
     x = ops.convert_to_tensor(x, name="x")
     y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
     x_dtype = x.dtype.base_dtype
     y_dtype = y.dtype.base_dtype
     if x_dtype != y_dtype:
-      raise TypeError(
-          "x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype))
-    return gen_math_ops.unsafe_div(x, y, name=name)
+      raise TypeError("x and y must have the same dtype, got %r != %r" %
+                      (x_dtype, y_dtype))
+    return gen_math_ops.div_no_nan(x, y, name=name)
 
 
 # TODO(aselle): This should be removed
```
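A short usage sketch of the new wrapper (illustrative values, TF 1.x session API), contrasting it with ordinary division at zero denominators; both inputs must share a float dtype, since the kernel is registered for float32/float64 only:

```python
import tensorflow as tf

x = tf.constant([3.0, -2.0, 0.0])
zero = tf.zeros_like(x)

with tf.Session() as sess:
    print(sess.run(tf.div_no_nan(x, zero)))  # [0. 0. 0.]
    print(sess.run(tf.divide(x, zero)))      # [ inf -inf  nan] with plain division
```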
17 changes: 9 additions & 8 deletions tensorflow/python/ops/math_ops_test.py
```diff
@@ -473,18 +473,19 @@ def testConsistent(self):
     self.assertAllEqual(tf_result, expanded_nums)
 
 
-class UnsafeDivTest(test_util.TensorFlowTestCase):
+class DivNoNanTest(test_util.TensorFlowTestCase):
 
   def testBasic(self):
-    nums = np.arange(-10, 10, .25).reshape(80, 1)
-    divs = np.arange(-3, 3, .25).reshape(1, 24)
+    for dtype in [np.float32, np.float64]:
+      nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+      divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
 
-    np_result = np.true_divide(nums, divs)
-    np_result[:, divs[0] == 0] = 0
+      np_result = np.true_divide(nums, divs)
+      np_result[:, divs[0] == 0] = 0
 
-    with self.test_session():
-      tf_result = math_ops.unsafe_div(nums, divs).eval()
-      self.assertAllEqual(tf_result, np_result)
+      with self.test_session():
+        tf_result = math_ops.div_no_nan(nums, divs).eval()
+        self.assertAllEqual(tf_result, np_result)
 
 
 if __name__ == "__main__":
```
4 changes: 4 additions & 0 deletions tensorflow/tools/api/golden/v2/tensorflow.pbtxt
```diff
@@ -1000,6 +1000,10 @@ tf_module {
     name: "div"
     argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "div_no_nan"
+    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "divide"
     argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
```