Skip to content

Commit

Permalink
tf_saved_model: Disallow duplicate bound inputs.
Browse files Browse the repository at this point in the history
This is a useful invariant because (together with a local check that resource Value's are not passed twice in the same argument list to e.g. a called function) guarantees that resource variables don't alias in a module with tf_saved_model semantics.

PiperOrigin-RevId: 282003375
Change-Id: I7ba0dbda9a6ee3c734b4503fc7f68b09b505a758
  • Loading branch information
silvasean authored and tensorflower-gardener committed Nov 22, 2019
1 parent 2f889d7 commit f44a805
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 23 deletions.
66 changes: 43 additions & 23 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
Expand Down Expand Up @@ -186,6 +187,46 @@ static LogicalResult VerifySavedModelModule(
return success();
}

LogicalResult VerifyExportedFunc(FuncOp func) {
bool reached_bound_inputs = false;
for (int i = 0, e = func.getNumArguments(); i < e; i++) {
if (func.getArgAttr(i, "tf_saved_model.bound_input")) {
reached_bound_inputs = true;
continue;
}
if (func.getArgAttr(i, "tf_saved_model.index_path")) {
if (reached_bound_inputs) {
return func.emitError()
<< "all 'tf_saved_model.index_path' arg attributes should "
"precede all 'tf_saved_model.bound_input' arg attributes";
}
continue;
}
return func.emitError()
<< "all arguments should have 'tf_saved_model.index_path' or "
"'tf_saved_model.bound_input' attributes";
}
llvm::SmallDenseSet<StringRef, 8> unique_bound_inputs;
for (int i = 0, e = func.getNumArguments(); i < e; i++) {
if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
i, "tf_saved_model.bound_input")) {
if (!unique_bound_inputs.insert(attr.getValue()).second) {
return func.emitError()
<< "duplicate 'tf_saved_model.bound_input' binding";
}
}
}

for (int i = 0, e = func.getNumResults(); i < e; i++) {
if (!func.getResultAttr(i, "tf_saved_model.index_path")) {
return func.emitError() << "all results should have "
"'tf_saved_model.index_path' attributes";
}
}

return success();
}

LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
Operation *op, NamedAttribute named_attr) {
if (named_attr.first == "tf_saved_model.exported_names") {
Expand All @@ -204,29 +245,8 @@ LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
"'tf_saved_model.semantics'";
}
if (auto func = dyn_cast<FuncOp>(op)) {
bool reached_bound_inputs = false;
for (int i = 0, e = func.getNumArguments(); i < e; i++) {
if (func.getArgAttr(i, "tf_saved_model.bound_input")) {
reached_bound_inputs = true;
continue;
}
if (func.getArgAttr(i, "tf_saved_model.index_path")) {
if (reached_bound_inputs) {
return op->emitError()
<< "all 'tf_saved_model.index_path' arg attributes should "
"precede all 'tf_saved_model.bound_input' arg attributes";
}
continue;
}
return op->emitError()
<< "all arguments should have 'tf_saved_model.index_path' or "
"'tf_saved_model.bound_input' attributes";
}
for (int i = 0, e = func.getNumResults(); i < e; i++) {
if (!func.getResultAttr(i, "tf_saved_model.index_path")) {
return op->emitError() << "all results should have "
"'tf_saved_model.index_path' attributes";
}
if (failed(VerifyExportedFunc(func))) {
return failure();
}
}
return success();
Expand Down
Expand Up @@ -212,3 +212,16 @@ module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{'type' and 'value' attributes should have compatible tensor types}}
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v0", type = tensor<3xf32>, value = dense<42.0> : tensor<9xf32> } : () -> ()
}

// -----

module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
// expected-error@+1 {{duplicate 'tf_saved_model.bound_input' binding}}
func @f(
%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v},
%arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}
) attributes {tf_saved_model.exported_names = ["f"]} {
return
}
}

0 comments on commit f44a805

Please sign in to comment.