From 6104f0d4091c260ce9352f9155f7e9b725eab012 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Tue, 7 Jun 2022 07:56:19 -0700 Subject: [PATCH] Strengthen input verification for SpecializeType by replacing DCHECK with explicit test/status return. PiperOrigin-RevId: 453436708 --- tensorflow/core/framework/full_type_util.cc | 6 +++++- tensorflow/core/framework/full_type_util_test.cc | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/framework/full_type_util.cc b/tensorflow/core/framework/full_type_util.cc index 97e9381df13da1..a8636420f11fa5 100644 --- a/tensorflow/core/framework/full_type_util.cc +++ b/tensorflow/core/framework/full_type_util.cc @@ -175,7 +175,11 @@ Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) { } Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { - DCHECK_EQ(t.args_size(), 3); + if (t.args_size() != 3) { + return Status(error::INVALID_ARGUMENT, + absl::StrCat("illegal FOR_EACH type, expected 3 args, got ", + t.args_size())); + } const auto& cont = t.args(0); const auto& tmpl = t.args(1); diff --git a/tensorflow/core/framework/full_type_util_test.cc b/tensorflow/core/framework/full_type_util_test.cc index 0324e64f96b0f9..6037879d069ff9 100644 --- a/tensorflow/core/framework/full_type_util_test.cc +++ b/tensorflow/core/framework/full_type_util_test.cc @@ -510,6 +510,19 @@ TEST(SpecializeType, ForEachOverridesTargetOfNestedForEach) { EXPECT_EQ(t_actual.args(1).args(0).args(0).args_size(), 0); } +TEST(SpecializeType, ForEachRejectsMalformedInput) { + OpDef op; + FullTypeDef* t = op.add_output_arg()->mutable_experimental_full_type(); + t->set_type_id(TFT_FOR_EACH); + t->add_args()->set_type_id(TFT_PRODUCT); + + NodeDef ndef; + AttrSlice attrs(ndef); + + FullTypeDef ft; + EXPECT_FALSE(SpecializeType(attrs, op, ft).ok()); +} + TEST(SpecializeType, RemovesLegacyVariant) { OpDef op; FullTypeDef* t = op.add_output_arg()->mutable_experimental_full_type();