Commit eb25081

Added gradient to tf.mod

Fixes #6365

PiperOrigin-RevId: 170522376
tensorflower-gardener committed Sep 29, 2017
1 parent c41cae3 commit eb25081
Showing 2 changed files with 47 additions and 5 deletions.
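
For reference, the formula the new gradient implements follows from the identity floormod(x, y) = x - floor(x/y) * y. Because floor(x/y) is piecewise constant, away from the discontinuities (the points where x/y is an integer) the partial derivatives are simply 1 with respect to x and -floor(x/y) with respect to y, which is exactly the pair the new _FloorModGrad returns below.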
38 changes: 33 additions & 5 deletions tensorflow/python/ops/math_grad.py
@@ -216,8 +216,8 @@ def _SegmentMinOrMaxGrad(op, grad, is_sorted):
     num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
                                         op.inputs[1])
   else:
-    num_selected = math_ops.unsorted_segment_sum(math_ops.cast(is_selected, grad.dtype),
-                                                 op.inputs[1], op.inputs[2])
+    num_selected = math_ops.unsorted_segment_sum(
+        math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2])
 
   # Compute the gradient for each segment. The gradient for the ith segment is
   # divided evenly among the selected elements in that segment.
@@ -315,7 +315,9 @@ def _SquareGrad(op, grad):
 @ops.RegisterGradient("Sqrt")
 def _SqrtGrad(op, grad):
   y = op.outputs[0]  # y = x^(1/2)
+  # pylint: disable=protected-access
   return gen_math_ops._sqrt_grad(y, grad)
+  # pylint: enable=protected-access
 
 
 @ops.RegisterGradient("SqrtGrad")
@@ -331,7 +333,9 @@ def _SqrtGradGrad(op, grad):
 def _RsqrtGrad(op, grad):
   """Returns -0.5 * grad * conj(y)^3."""
   y = op.outputs[0]  # y = x^(-1/2)
+  # pylint: disable=protected-access
   return gen_math_ops._rsqrt_grad(y, grad)
+  # pylint: enable=protected-access
 
 
 @ops.RegisterGradient("RsqrtGrad")
@@ -499,7 +503,9 @@ def _IgammaGrad(op, grad):
   x = op.inputs[1]
   sa = array_ops.shape(a)
   sx = array_ops.shape(x)
+  # pylint: disable=protected-access
   unused_ra, rx = gen_array_ops._broadcast_gradient_args(sa, sx)
+  # pylint: enable=protected-access
 
   # Perform operations in log space before summing, because Gamma(a)
   # and Gamma'(a) can grow large.
@@ -552,7 +558,9 @@ def _ZetaGrad(op, grad):
   # Broadcast gradients
   sx = array_ops.shape(x)
   sq = array_ops.shape(q)
+  # pylint: disable=protected-access
   unused_rx, rq = gen_array_ops._broadcast_gradient_args(sx, sq)
+  # pylint: enable=protected-access
   # Evaluate gradient
   with ops.control_dependencies([grad]):
     x = math_ops.conj(x)
Expand All @@ -572,7 +580,9 @@ def _PolygammaGrad(op, grad):
# Broadcast gradients
sn = array_ops.shape(n)
sx = array_ops.shape(x)
# pylint: disable=protected-access
unused_rn, rx = gen_array_ops._broadcast_gradient_args(sn, sx)
# pylint: enable=protected-access
# Evaluate gradient
with ops.control_dependencies([grad]):
n = math_ops.conj(n)
@@ -700,7 +710,9 @@ def _AddGrad(op, grad):
   y = op.inputs[1]
   sx = array_ops.shape(x)
   sy = array_ops.shape(y)
+  # pylint: disable=protected-access
   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+  # pylint: enable=protected-access
   return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
           array_ops.reshape(math_ops.reduce_sum(grad, ry), sy))

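A note on the recurring pattern in these hunks: gen_array_ops._broadcast_gradient_args(sx, sy) returns, for each input, the axes along which that input was broadcast, so summing the incoming gradient over those axes and reshaping restores the input's original shape. For example, if x has shape [2, 3] and y has shape [3], then ry selects axis 0 and grad (shape [2, 3]) is reduced down to y's shape [3]. The new _FloorModGrad below reuses this same pattern.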
@@ -711,7 +723,9 @@ def _SubGrad(op, grad):
   y = op.inputs[1]
   sx = array_ops.shape(x)
   sy = array_ops.shape(y)
+  # pylint: disable=protected-access
   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+  # pylint: enable=protected-access
   return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
           array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))

@@ -724,7 +738,9 @@ def _MulGrad(op, grad):
   assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
   sx = array_ops.shape(x)
   sy = array_ops.shape(y)
+  # pylint: disable=protected-access
   rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+  # pylint: enable=protected-access
   x = math_ops.conj(x)
   y = math_ops.conj(y)
   return (array_ops.reshape(math_ops.reduce_sum(grad * y, rx), sx),
@@ -756,9 +772,21 @@ def _FloorDivGrad(_, unused_grad):
 
 
 @ops.RegisterGradient("FloorMod")
-def _FloorModGrad(_, unused_grad):
-  """The gradient for the FloorMod operator."""
-  return None, None
+def _FloorModGrad(op, grad):
+  """Returns grad * (1, -floor(x/y))."""
+  x = math_ops.conj(op.inputs[0])
+  y = math_ops.conj(op.inputs[1])
+
+  sx = array_ops.shape(x)
+  sy = array_ops.shape(y)
+  # pylint: disable=protected-access
+  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+  # pylint: enable=protected-access
+  floor_xy = math_ops.floor_div(x, y)
+  gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
+  gy = array_ops.reshape(
+      math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy)
+  return gx, gy
 
 
 @ops.RegisterGradient("TruncateDiv")
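
Beyond the diff itself: with this gradient registered, tf.gradients can differentiate through tf.mod. A minimal TF 1.x sketch (not part of the commit; the constants mirror the new test below, where floor(131/17) = 7):

import tensorflow as tf

x = tf.constant([131.], dtype=tf.float32)
y = tf.constant([17.], dtype=tf.float32)
m = tf.mod(x, y)  # dispatches to the FloorMod op changed above
# Before this commit, the registered FloorMod gradient was (None, None).
gx, gy = tf.gradients(m, [x, y])
with tf.Session() as sess:
  print(sess.run([gx, gy]))  # [array([1.], ...), array([-7.], ...)]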
14 changes: 14 additions & 0 deletions tensorflow/python/ops/math_grad_test.py
@@ -177,5 +177,19 @@ def testSegmentMaxGradientWithTies(self):
       self.assertLess(error, 1e-4)
 
 
+class FloorModGradientTest(test.TestCase):
+
+  def testFloorModGradient(self):
+    # Making sure the input is not near the discontinuity point where
+    # x/y == floor(x/y)
+    ns = constant_op.constant([17.], dtype=dtypes.float32)
+    inputs = constant_op.constant([131.], dtype=dtypes.float32)
+    floor_mod = math_ops.floormod(inputs, ns)
+    with self.test_session():
+      error = gradient_checker.compute_gradient_error(inputs, [1],
+                                                      floor_mod, [1])
+      self.assertLess(error, 1e-4)
+
+
 if __name__ == "__main__":
   test.main()
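
About the new test: gradient_checker.compute_gradient_error computes the analytic Jacobian of floor_mod with respect to inputs (via the registered _FloorModGrad) as well as a numerical finite-difference estimate, and returns the maximum difference between the two; the assertion requires that error to stay below 1e-4. The constants 131 and 17 keep x/y well away from integer values, where floor(x/y) jumps and a finite-difference estimate would be meaningless.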
