Skip to content
Permalink
Browse files Browse the repository at this point in the history
Update TPU AllToAll op to avoid divide by 0.
PiperOrigin-RevId: 400259638
Change-Id: Ic4cfe4fe7159da38caed8044ee005f898e42cd86
  • Loading branch information
bfontain authored and tensorflower-gardener committed Oct 1, 2021
1 parent c847822 commit a8ad3e5
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tensorflow/core/ops/tpu_cross_replica_ops.cc
Expand Up @@ -32,6 +32,7 @@ REGISTER_OP("AllToAll")
.Attr("split_count: int")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle group_assignment = c->input(1);
if (!c->RankKnown(input)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
Expand All @@ -42,6 +43,21 @@ REGISTER_OP("AllToAll")
int split_dimension;
int split_count;
TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count));
if (split_count < 1) {
return errors::InvalidArgument("split_count ", split_count,
" must at least be one.");
}
if (c->RankKnown(group_assignment) && c->Rank(group_assignment) != 2) {
return errors::InvalidArgument("group_assignment must have rank 2.");
}
DimensionHandle num_replicas_per_group = c->Dim(group_assignment, 1);
if (c->ValueKnown(num_replicas_per_group) &&
(c->Value(num_replicas_per_group) != split_count)) {
return errors::InvalidArgument(
"split_count ", split_count,
" must equal the size of the second dimension of group_assignment ",
c->Value(num_replicas_per_group));
}

TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));

Expand All @@ -65,6 +81,12 @@ REGISTER_OP("AllToAll")
dims[i] = c->MakeDim(c->Value(dims[i]) * split_count);
}
if (i == split_dimension) {
if (c->ValueKnown(dims[i]) &&
(c->Value(dims[i]) % split_count != 0)) {
return errors::InvalidArgument(
"input dimension ", c->Value(dims[i]),
" not divisible by split_count ", split_count);
}
dims[i] = c->MakeDim(c->Value(dims[i]) / split_count);
}
}
Expand Down
46 changes: 46 additions & 0 deletions tensorflow/python/tpu/tpu_test.py
Expand Up @@ -32,6 +32,7 @@
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops


class TPUContextTest(test.TestCase):
Expand Down Expand Up @@ -165,6 +166,51 @@ def test_prune_unconnected_ops(self):
graph.get_operation_by_name("import/y").get_attr(
tpu._TPU_REPLICATE_ATTR)


class TPUOpsTest(test.TestCase):

def test_all_to_all_zero_split_count(self):
with self.assertRaisesRegex(
ValueError, "split_count 0 must at least be one"):
tpu_ops.all_to_all(
x=[0.0, 0.1652, 0.6543],
group_assignment=[1, -1],
concat_dimension=0,
split_dimension=0,
split_count=0)

def test_all_to_all_group_assignment_wrong_shape(self):
with self.assertRaisesRegex(
ValueError, "group_assignment must have rank 2"):
tpu_ops.all_to_all(
x=[0.0, 0.1652, 0.6543],
group_assignment=[1, -1],
concat_dimension=0,
split_dimension=0,
split_count=2)

def test_all_to_all_split_count_not_equal_to_group_assignment_shape(self):
with self.assertRaisesRegex(
ValueError, "split_count 1 must equal the size of the second dimension "
"of group_assignment 2"):
tpu_ops.all_to_all(
x=[0.0, 0.1652, 0.6543],
group_assignment=[[0, 1], [2, 3]],
concat_dimension=0,
split_dimension=0,
split_count=1)

def test_all_to_all_split_count_not_divide_input_shape(self):
with self.assertRaisesRegex(
ValueError, "input dimension 3 not divisible by split_count 2"):
tpu_ops.all_to_all(
x=[[0.0], [0.1652], [0.6543]],
group_assignment=[[0, 1], [2, 3]],
concat_dimension=1,
split_dimension=0,
split_count=2)


def do_einsum():
a = array_ops.placeholder(dtype=dtypes.float32, name="a", shape=[2, 3, 4])
b = array_ops.placeholder(dtype=dtypes.float32, name="b", shape=[2, 4, 5])
Expand Down

0 comments on commit a8ad3e5

Please sign in to comment.