Skip to content

Commit

Permalink
Strengthen input verification for SpecializeType by replacing DCHECK …
Browse files Browse the repository at this point in the history
…with explicit test/status return.

PiperOrigin-RevId: 453436708
  • Loading branch information
Dan Moldovan authored and tensorflower-gardener committed Jun 7, 2022
1 parent c6c1755 commit 6104f0d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tensorflow/core/framework/full_type_util.cc
Expand Up @@ -175,7 +175,11 @@ Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) {
}

Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) {
DCHECK_EQ(t.args_size(), 3);
if (t.args_size() != 3) {
return Status(error::INVALID_ARGUMENT,
absl::StrCat("illegal FOR_EACH type, expected 3 args, got ",
t.args_size()));
}

const auto& cont = t.args(0);
const auto& tmpl = t.args(1);
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/core/framework/full_type_util_test.cc
Expand Up @@ -510,6 +510,19 @@ TEST(SpecializeType, ForEachOverridesTargetOfNestedForEach) {
EXPECT_EQ(t_actual.args(1).args(0).args(0).args_size(), 0);
}

TEST(SpecializeType, ForEachRejectsMalformedInput) {
OpDef op;
FullTypeDef* t = op.add_output_arg()->mutable_experimental_full_type();
t->set_type_id(TFT_FOR_EACH);
t->add_args()->set_type_id(TFT_PRODUCT);

NodeDef ndef;
AttrSlice attrs(ndef);

FullTypeDef ft;
EXPECT_FALSE(SpecializeType(attrs, op, ft).ok());
}

TEST(SpecializeType, RemovesLegacyVariant) {
OpDef op;
FullTypeDef* t = op.add_output_arg()->mutable_experimental_full_type();
Expand Down

0 comments on commit 6104f0d

Please sign in to comment.