Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix undefined behavior in CollectiveReduceV2 and others
We should not call done after it's moved.

PiperOrigin-RevId: 400838185
Change-Id: Ifc979740054b8f8c6f4d50acc89472fe60c4fdb1
  • Loading branch information
crccw authored and tensorflower-gardener committed Oct 4, 2021
1 parent bd1a138 commit ca38dab
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 11 deletions.
25 changes: 14 additions & 11 deletions tensorflow/core/kernels/collective_ops.cc
Expand Up @@ -494,15 +494,17 @@ class CollectiveOpV2Kernel : public AsyncOpKernel {
const Tensor& group_size, const Tensor& group_key,
const Tensor& instance_key) {
if (group_size.dims() > 0) {
return errors::Internal("Unexpected dimensions on input group_size, got ",
group_size.shape().DebugString());
return errors::InvalidArgument(
"Unexpected dimensions on input group_size, got ",
group_size.shape().DebugString());
}
if (group_key.dims() > 0) {
return errors::Internal("Unexpected dimensions on input group_key, got ",
group_key.shape().DebugString());
return errors::InvalidArgument(
"Unexpected dimensions on input group_key, got ",
group_key.shape().DebugString());
}
if (instance_key.dims() > 0) {
return errors::Internal(
return errors::InvalidArgument(
"Unexpected dimensions on input instance_key, got ",
instance_key.shape().DebugString());
}
Expand Down Expand Up @@ -625,7 +627,7 @@ class CollectiveReduceV2OpKernel : public CollectiveOpV2Kernel {
/*group_size*/ c->input(1),
/*group_key*/ c->input(2),
/*instance_key*/ c->input(3)),
done);
done_with_cleanup);
col_params->instance.shape = c->input(0).shape();
col_params->merge_op = merge_op_.get();
col_params->final_op = final_op_.get();
Expand Down Expand Up @@ -855,14 +857,15 @@ class CollectiveInitializeCommunicatorOpKernel : public AsyncOpKernel {

Status CheckInputs(Tensor group_size_t, Tensor group_key_t) {
if (group_size_t.dims() > 0) {
return errors::Internal(
return errors::InvalidArgument(
"Unexpected dimensions on input group_size. "
"It shoulbe a scalar, got tensor with shape ",
group_size_t.shape().DebugString());
}
if (group_key_t.dims() > 0) {
return errors::Internal("Unexpected dimensions on input group_key, got ",
group_key_t.shape().DebugString());
return errors::InvalidArgument(
"Unexpected dimensions on input group_key, got ",
group_key_t.shape().DebugString());
}

auto group_size = group_size_t.unaligned_flat<int32>()(0);
Expand Down Expand Up @@ -1084,7 +1087,7 @@ class CollectiveReduceV3OpKernel : public CollectiveOpV3Kernel {
};
core::RefCountPtr<CollectiveGroupResource> resource;
OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
done);
done_with_cleanup);

Tensor group_assignment = c->input(2);

Expand Down Expand Up @@ -1134,7 +1137,7 @@ class CollectiveAllToAllV3OpKernel : public CollectiveOpV3Kernel {
};
core::RefCountPtr<CollectiveGroupResource> resource;
OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
done);
done_with_cleanup);

Tensor group_assignment = c->input(2);

Expand Down
63 changes: 63 additions & 0 deletions tensorflow/python/kernel_tests/collective_ops_test.py
Expand Up @@ -1182,6 +1182,69 @@ def f():
self.assertAllEqual(self.evaluate(f()), [[3.], [3.]])


@combinations.generate(
combinations.times(
combinations.combine(collective_op=[
combinations.NamedObject('all_reduce_v2',
CollectiveOpsV2.all_reduce),
combinations.NamedObject('all_gather_v2',
CollectiveOpsV2.all_gather)
]), device_combination))
class InvalidInputTest(test.TestCase, parameterized.TestCase):

def setUp(self):
_setup_context()
super().setUp()

def testInvalidGroupKey(self, collective_op, device, communication):
dev0 = '/device:%s:0' % device
group_size = 2
group_key = [100]
instance_key = 100
in_tensor = constant_op.constant([1.])

with self.assertRaises(errors.InvalidArgumentError):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)

def testInvalidGroupSize(self, collective_op, device, communication):
dev0 = '/device:%s:0' % device
group_size = -2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])

with self.assertRaises(errors.InvalidArgumentError):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)

def testInvalidInstanceKey(self, collective_op, device, communication):
dev0 = '/device:%s:0' % device
group_size = 2
group_key = 100
instance_key = [100]
in_tensor = constant_op.constant([1.])

with self.assertRaises(errors.InvalidArgumentError):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)


class CollectiveOpsV3Test(test.TestCase, parameterized.TestCase):

def setUp(self):
Expand Down

0 comments on commit ca38dab

Please sign in to comment.