add gradient for broadcast_to #22083

Merged
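This PR registers a Python gradient function for the BroadcastTo op in tensorflow/python/ops/array_grad.py and adds gradient-checker tests covering scalar, same-rank, rank-increasing, and all-dimension broadcasts.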
44 changes: 44 additions & 0 deletions tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -21,8 +21,10 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.platform import test as test_lib


@@ -81,5 +83,47 @@ def testBroadcastToShapeTypeAndInference(self):
    # check shape inference when shape input is constant
    self.assertAllEqual(shape, v_np.shape)

  def testGradientForScalar(self):
    # TODO(alextp): There is a bug with broadcast_to on GPU from scalars,
    # hence we make this test cpu-only.
    with ops.device("cpu:0"):
      x = constant_op.constant(1, dtype=dtypes.float32)
      v = array_ops.broadcast_to(x, [2, 4, 3])
      out = 2 * v
      with self.test_session():
        err = gradient_checker.compute_gradient_error(x, x.get_shape(),
                                                      out, out.get_shape())
    self.assertLess(err, 1e-4)

  def testGradientWithSameRank(self):
    x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)),
                             dtype=dtypes.float32)
    v = array_ops.broadcast_to(x, [2, 5, 3])
    out = 2 * v
    with self.test_session():
      err = gradient_checker.compute_gradient_error(x, x.get_shape(),
                                                    out, out.get_shape())
    self.assertLess(err, 1e-4)

  def testGradientWithIncreasingRank(self):
    x = constant_op.constant([[1], [2]],
                             dtype=dtypes.float32)
    v = array_ops.broadcast_to(x, [5, 2, 3])
    out = 2 * v
    with self.test_session():
      err = gradient_checker.compute_gradient_error(x, x.get_shape(),
                                                    out, out.get_shape())
    self.assertLess(err, 1e-4)

  def testGradientWithBroadcastAllDimensions(self):
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
    v = array_ops.broadcast_to(x, [5, 4, 6])
    out = 2 * v
    with self.test_session():
      err = gradient_checker.compute_gradient_error(x, x.get_shape(),
                                                    out, out.get_shape())
    self.assertLess(err, 1e-4)


if __name__ == "__main__":
  test_lib.main()
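For intuition, a minimal sketch (not part of the PR) of what the registered gradient evaluates to for the scalar test above, written against the same internal modules the test file uses: out = 2 * broadcast_to(x, [2, 4, 3]) reads x at 24 positions with an upstream gradient of 2 at each, so d(out)/dx = 48.

# Minimal sketch (not part of the PR): the gradient of a broadcast sums the
# upstream gradient over every broadcast position.
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl

x = constant_op.constant(1.0)
v = array_ops.broadcast_to(x, [2, 4, 3])  # x is read at 2*4*3 = 24 positions
out = 2 * v
grad = gradients_impl.gradients(out, x)[0]  # d(out)/dv is 2 everywhere

with session.Session() as sess:
  print(sess.run(grad))  # 48.0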
19 changes: 19 additions & 0 deletions tensorflow/python/ops/array_grad.py
@@ -805,3 +805,22 @@ def _ScatterNdNonAliasingAddGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [grad, None, updates_grad]


@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  # Assign a unique id to each position in input_value.
  input_value_shape = array_ops.shape(input_value)
  input_value_size = array_ops.size(input_value)
  ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape)
  broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape)
  # Group the upstream gradient by id and sum within each group, which
  # accumulates the gradient of every broadcast copy back onto its source.
  grad_flatten = array_ops.reshape(grad, [-1])
  broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1])
  updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten,
                                                       broadcast_ids_flatten,
                                                       input_value_size)
  updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape)
  return [updates_grad, None]
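For reference, a NumPy sketch (not part of the PR; the helper name is illustrative) of the id-based reduction above: each input position gets a unique id, the ids are broadcast exactly like the values, and summing the upstream gradient per id undoes the broadcast. np.add.at plays the role of unsorted_segment_sum.

import numpy as np

def broadcast_to_grad_numpy(grad, input_shape):
  """NumPy sketch (hypothetical helper) of _BroadcastToGrad's id trick."""
  size = int(np.prod(input_shape))
  # Label every position of the input with a unique id.
  ids = np.arange(size).reshape(input_shape)
  # Broadcast the ids exactly as the forward op broadcast the values.
  broadcast_ids = np.broadcast_to(ids, grad.shape)
  # Sum the upstream gradient for each id (the "group by id" step).
  out = np.zeros(size, dtype=grad.dtype)
  np.add.at(out, broadcast_ids.ravel(), grad.ravel())
  return out.reshape(input_shape)

# Example: broadcasting (2, 1) to (2, 3) reads each input element three
# times, so its gradient is the sum of three upstream entries.
g = np.ones((2, 3), dtype=np.float32)
print(broadcast_to_grad_numpy(g, (2, 1)))  # [[3.], [3.]]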