Skip to content
Permalink
Browse files Browse the repository at this point in the history
TFLite: Error out when the graph has a recursion.
Recursion is currently unsupported.

PiperOrigin-RevId: 371708957
Change-Id: I8dfad0d85cbfe08e39ae8ea7bad21254ddee5003
  • Loading branch information
miaout17 authored and tensorflower-gardener committed May 3, 2021
1 parent 46b80bd commit c6173f5
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 2 deletions.
1 change: 1 addition & 0 deletions tensorflow/lite/BUILD
Expand Up @@ -643,6 +643,7 @@ cc_test(
"testdata/test_min_runtime.bin",
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
"testdata/unsupported_recursion.bin",
"testdata/while_op_with_forwarding_input.bin",
],
tags = [
Expand Down
46 changes: 46 additions & 0 deletions tensorflow/lite/core/subgraph.cc
Expand Up @@ -156,6 +156,42 @@ const char* GetTFLiteOpName(const TfLiteRegistration& op_reg) {
return tflite::EnumNamesBuiltinOperator()[op_reg.builtin_code];
}

// A utility RAII guard that detects if the subgraph is abused:
// 1. Detects if recursion exists in the graph (recursion is not currently
//    supported).
// 2. Detects if the interpreter / subgraph is used in multiple threads.
//    Note: It's clearly documented that the interpreter / subgraph are not
//    thread-safe, so this serves only as a best-effort check with possible
//    false negatives unless we switch to atomic boolean flags.
class SubgraphGuard {
 public:
  // Marks the subgraph as in-use for the guard's lifetime. If it is already
  // in use (recursion or concurrent use), logs an error and records
  // kTfLiteError without touching the flag.
  SubgraphGuard(TfLiteContext* context, bool* is_subgraph_in_use)
      : is_subgraph_in_use_(is_subgraph_in_use) {
    if (!*is_subgraph_in_use_) {
      *is_subgraph_in_use_ = true;
    } else {
      TF_LITE_KERNEL_LOG(
          context,
          "Subgraph is already in use. Using an interpreter or a subgraph in "
          "multiple threads is not supported. Recursion in the graph is not "
          "supported.");
      status_ = kTfLiteError;
    }
  }

  ~SubgraphGuard() {
    // Release the flag only if this guard successfully acquired it; a failed
    // guard must not clear the flag owned by the outer (original) caller.
    if (status_ == kTfLiteOk) {
      *is_subgraph_in_use_ = false;
    }
  }

  // kTfLiteOk if the guard acquired the subgraph, kTfLiteError otherwise.
  TfLiteStatus status() const { return status_; }

 private:
  TfLiteStatus status_ = kTfLiteOk;
  bool* is_subgraph_in_use_;
};

} // namespace

// A trivial implementation of GraphInfo around the Interpreter.
Expand Down Expand Up @@ -655,6 +691,7 @@ TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims,

TfLiteStatus Subgraph::AllocateTensors() {
TFLITE_SCOPED_TAGGED_DEFAULT_PROFILE(profiler_.get(), "AllocateTensors");

if (!consistent_) {
ReportError("AllocateTensors() called on inconsistent model.");
return kTfLiteError;
Expand All @@ -678,6 +715,12 @@ TfLiteStatus Subgraph::AllocateTensors() {
return kTfLiteOk;
}

// Note `AllocateTensors` sometimes calls itself recursively above
// for delegates. Therefore only the logic below needs to be guarded
// by `SubgraphGuard`.
SubgraphGuard guard(&context_, &is_subgraph_in_use_);
TF_LITE_ENSURE_OK(&context_, guard.status());

next_execution_plan_index_to_prepare_ = 0;
next_execution_plan_index_to_plan_allocation_ = 0;
next_original_execution_plan_index_to_prepare_ = 0;
Expand Down Expand Up @@ -1014,6 +1057,9 @@ TfLiteStatus Subgraph::PrepareOpsAndTensors() {
}

TfLiteStatus Subgraph::Invoke() {
SubgraphGuard guard(&context_, &is_subgraph_in_use_);
TF_LITE_ENSURE_OK(&context_, guard.status());

if (!consistent_) {
ReportError("Invoke called on model that is not consistent.");
return kTfLiteError;
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/lite/core/subgraph.h
Expand Up @@ -759,6 +759,10 @@ class Subgraph {
// Whether memory planner should be instantiated to retain intermediates for
// debugging.
bool preserve_all_tensors_ = false;

// Whether the subgraph is currently in use (e.g. running the `Invoke`
// or `AllocateTensors` functions).
bool is_subgraph_in_use_ = false;
};

} // namespace tflite
Expand Down
2 changes: 0 additions & 2 deletions tensorflow/lite/kernels/while.cc
Expand Up @@ -138,8 +138,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* subgraphs = this_subgraph->GetSubgraphs();
TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());
TF_LITE_ENSURE(context,
op_data->cond_subgraph_index != op_data->body_subgraph_index);

Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();
Expand Down
19 changes: 19 additions & 0 deletions tensorflow/lite/model_test.cc
Expand Up @@ -600,6 +600,25 @@ TEST(BasicFlatBufferModel, TestHandleMalformedModelReuseTensor) {
ASSERT_NE(interpreter->AllocateTensors(), kTfLiteOk);
}

// Recursion & reentrancy are not supported in TFLite.
// This test ensures the runtime fails gracefully instead of crashing with
// a stack overflow.
TEST(BasicFlatBufferModel, TestUnsupportedRecursion) {
  constexpr char kModelPath[] =
      "tensorflow/lite/testdata/unsupported_recursion.bin";

  // The model itself loads fine; the cycle is only detected at allocation.
  auto model = FlatBufferModel::BuildFromFile(kModelPath);
  ASSERT_NE(model, nullptr);

  std::unique_ptr<Interpreter> interpreter;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  InterpreterBuilder builder(*model, resolver);
  ASSERT_EQ(builder(&interpreter), kTfLiteOk);
  ASSERT_NE(interpreter, nullptr);
  // Allocation must error out (via SubgraphGuard) rather than recurse forever.
  ASSERT_NE(interpreter->AllocateTensors(), kTfLiteOk);
}

// The models here have a buffer index for a tensor pointing to a null buffer.
// This results in the tensor being interpreted as read-write, but the model
// assumes the tensor is read-only. As such, `interpreter->Invoke()` would
Expand Down
Binary file added tensorflow/lite/testdata/unsupported_recursion.bin
Binary file not shown.

0 comments on commit c6173f5

Please sign in to comment.