Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix for check failure in CollectiveGather op.
The fix is required only for Eager mode; in graph mode the input shape is already checked during the shape inference pass.

PiperOrigin-RevId: 460801136
  • Loading branch information
ishark authored and tensorflower-gardener committed Jul 13, 2022
1 parent 7741dc5 commit c1f4918
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/collective_ops.cc
Expand Up @@ -176,6 +176,10 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
auto output_shape = c->input(0).shape();
OP_REQUIRES_ASYNC(c, output_shape.dims() > 0,
errors::InvalidArgument("input should have rank > 0, ",
"recieved ", output_shape.dims()),
done);
output_shape.set_dim(
0, output_shape.dim_size(0) * col_params_->group.group_size);
col_params_->instance.shape = output_shape;
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/ops/collective_ops_test.py
Expand Up @@ -451,6 +451,20 @@ def testCollectiveGroupSizeMismatch(self):
])
context.ensure_initialized()

@test_util.run_v2_only
def testCollectiveGatherShapeCheckFailure(self):
  """Verifies the eager CollectiveGather kernel rejects a rank-0 input."""
  # A scalar input has no dimension 0 to concatenate along, so the kernel
  # must fail with InvalidArgumentError instead of crashing a CHECK.
  op_kwargs = dict(
      input=1,
      group_size=1,
      group_key=1,
      instance_key=1,
      shape=(3, 3, 3),
      communication_hint='auto',
      timeout_seconds=0,
      name='')
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              'input should have rank > 0'):
    collective_ops.gen_collective_ops.CollectiveGather(**op_kwargs)

@def_function.function
def run_all_reduce():
group_key = 10
Expand Down

0 comments on commit c1f4918

Please sign in to comment.