Skip to content
Permalink
Browse files Browse the repository at this point in the history
[tfg] Fix null type attribute
Type attributes can be null and the input should be `dyn_cast_or_null`.

PiperOrigin-RevId: 449889247
  • Loading branch information
tensorflower-gardener committed May 20, 2022
1 parent 4b04497 commit 3a75474
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tensorflow/core/ir/importexport/graphdef_import.cc
Expand Up @@ -699,7 +699,8 @@ StatusOr<unsigned> GraphDefImporter::ArgNumType(const NamedAttrList &attrs,
SmallVectorImpl<Type> &types) {
// Check whether a type list attribute is specified.
if (!arg_def.type_list_attr().empty()) {
if (auto v = attrs.get(arg_def.type_list_attr()).dyn_cast<ArrayAttr>()) {
if (auto v =
attrs.get(arg_def.type_list_attr()).dyn_cast_or_null<ArrayAttr>()) {
for (Attribute attr : v) {
if (auto dtype = attr.dyn_cast<TypeAttr>()) {
types.push_back(UnrankedTensorType::get(dtype.getValue()));
Expand All @@ -716,7 +717,8 @@ StatusOr<unsigned> GraphDefImporter::ArgNumType(const NamedAttrList &attrs,
unsigned num = 1;
// Check whether a number attribute is specified.
if (!arg_def.number_attr().empty()) {
if (auto v = attrs.get(arg_def.number_attr()).dyn_cast<IntegerAttr>()) {
if (auto v =
attrs.get(arg_def.number_attr()).dyn_cast_or_null<IntegerAttr>()) {
num = v.getValue().getZExtValue();
} else {
return NotFound("Type attr not found: ", arg_def.number_attr());
Expand All @@ -731,7 +733,7 @@ StatusOr<unsigned> GraphDefImporter::ArgNumType(const NamedAttrList &attrs,
return InvalidArgument("Arg '", arg_def.name(),
"' has invalid type and no type attribute");
} else {
if (auto v = attrs.get(arg_def.type_attr()).dyn_cast<TypeAttr>()) {
if (auto v = attrs.get(arg_def.type_attr()).dyn_cast_or_null<TypeAttr>()) {
dtype = v.getValue();
} else {
return NotFound("Type attr not found: ", arg_def.type_attr());
Expand Down
@@ -0,0 +1,21 @@
# RUN: not tfg-translate -graphdef-to-mlir %s 2>&1 | FileCheck %s

# CHECK: Type attr not found

library {
function {
signature {
name: "\344\264\264"
description: "value"
is_distributed_communication: true
}
node_def {
op: "Const"
input: "|"
}
control_ret {
key: ""
value: ""
}
}
}

0 comments on commit 3a75474

Please sign in to comment.