Skip to content

Commit

Permalink
Improve flatbuffer verification.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 189668634
  • Loading branch information
tensorflower-gardener committed Mar 20, 2018
1 parent 2bd7f5e commit 41335ab
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 30 deletions.
2 changes: 2 additions & 0 deletions tensorflow/contrib/lite/toco/tflite/BUILD
Expand Up @@ -115,9 +115,11 @@ cc_library(
deps = [
":operator",
":types",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/contrib/lite/toco:tooling_util",
"//tensorflow/contrib/lite/tools:verifier",
"@flatbuffers",
],
)
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/contrib/lite/toco/tflite/import.cc
Expand Up @@ -15,10 +15,12 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/tflite/import.h"

#include "flatbuffers/flexbuffers.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/contrib/lite/tools/verifier.h"

namespace toco {

Expand Down Expand Up @@ -171,10 +173,11 @@ bool Verify(const void* buf, size_t len) {

std::unique_ptr<Model> Import(const ModelFlags& model_flags,
const string& input_file_contents) {
if (!Verify(input_file_contents.data(), input_file_contents.size())) {
::tflite::AlwaysTrueResolver r;
if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(),
r, ::tflite::DefaultErrorReporter())) {
LOG(FATAL) << "Invalid flatbuffer.";
}

const ::tflite::Model* input_model =
::tflite::GetModel(input_file_contents.data());

Expand Down
106 changes: 88 additions & 18 deletions tensorflow/contrib/lite/toco/tflite/import_test.cc
Expand Up @@ -36,12 +36,13 @@ class ImportTest : public ::testing::Test {
return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
sizeof(T) * data.size());
}

Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers() {
auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
auto buf1 =
::tflite::CreateBuffer(builder_, CreateDataVector<float>({1.0f, 2.0f}));
auto buf1 = ::tflite::CreateBuffer(
builder_, CreateDataVector<float>({1.0f, 2.0f, 3.0f, 4.0f}));
auto buf2 =
::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f}));
::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f, 4.0f}));
return builder_.CreateVector(
std::vector<Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
}
Expand All @@ -53,10 +54,10 @@ class ImportTest : public ::testing::Test {
/*max=*/builder_.CreateVector<float>({0.2f}),
/*scale=*/builder_.CreateVector<float>({0.3f}),
/*zero_point=*/builder_.CreateVector<int64_t>({100ll}));
auto t1 = ::tflite::CreateTensor(builder_,
builder_.CreateVector<int>({1, 2, 3, 4}),
::tflite::TensorType_FLOAT32, 1,
builder_.CreateString("tensor_one"), q);
auto t1 =
::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 2, 2}),
::tflite::TensorType_FLOAT32, 1,
builder_.CreateString("tensor_one"), q);
auto t2 =
::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
::tflite::TensorType_FLOAT32, 2,
Expand All @@ -65,18 +66,26 @@ class ImportTest : public ::testing::Test {
std::vector<Offset<::tflite::Tensor>>({t1, t2}));
}

Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes(
std::initializer_list<::tflite::BuiltinOperator> op_codes) {
std::vector<Offset<::tflite::OperatorCode>> op_codes_vector;
for (auto op : op_codes) {
op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, op, 0));
}
return builder_.CreateVector(op_codes_vector);
}

Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes() {
auto c1 = ::tflite::CreateOperatorCode(
builder_, ::tflite::BuiltinOperator_MAX_POOL_2D, 0);
auto c2 = ::tflite::CreateOperatorCode(
builder_, ::tflite::BuiltinOperator_CONV_2D, 0);
return builder_.CreateVector(
std::vector<Offset<::tflite::OperatorCode>>({c1, c2}));
return BuildOpCodes({::tflite::BuiltinOperator_MAX_POOL_2D,
::tflite::BuiltinOperator_CONV_2D});
}

Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
auto is = builder_.CreateVector<int>({0});
auto os = builder_.CreateVector<int>({1});
Offset<Vector<Offset<::tflite::Operator>>> BuildOperators(
std::initializer_list<int> inputs, std::initializer_list<int> outputs) {
auto is = builder_.CreateVector<int>(inputs);
if (inputs.size() == 0) is = 0;
auto os = builder_.CreateVector<int>(outputs);
if (outputs.size() == 0) os = 0;
auto op = ::tflite::CreateOperator(
builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1,
Expand All @@ -87,6 +96,10 @@ class ImportTest : public ::testing::Test {
return builder_.CreateVector(std::vector<Offset<::tflite::Operator>>({op}));
}

Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
return BuildOperators({0}, {1});
}

Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
Offset<Vector<Offset<::tflite::Tensor>>> tensors,
Offset<Vector<Offset<::tflite::Operator>>> operators,
Expand Down Expand Up @@ -154,9 +167,9 @@ TEST_F(ImportTest, Tensors) {
Array& a1 = model->GetArray("tensor_one");
EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
ElementsAre(1.0f, 2.0f));
ElementsAre(1.0f, 2.0f, 3.0f, 4.0f));
ASSERT_TRUE(a1.has_shape());
EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4));
EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 2));

const auto& mm = a1.minmax;
ASSERT_TRUE(mm.get());
Expand All @@ -169,6 +182,63 @@ TEST_F(ImportTest, Tensors) {
EXPECT_EQ(100, q->zero_point);
}

TEST_F(ImportTest, NoBuffers) {
auto buffers = 0;
auto tensors = BuildTensors();
auto opcodes = BuildOpCodes();
auto operators = BuildOperators();
auto subgraphs = BuildSubGraphs(tensors, operators);
auto comment = builder_.CreateString("");
::tflite::FinishModelBuffer(
builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
subgraphs, comment, buffers));
EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
"Missing 'buffers' section.");
}

TEST_F(ImportTest, NoInputs) {
auto buffers = BuildBuffers();
auto tensors = BuildTensors();
auto opcodes = BuildOpCodes();
auto operators = BuildOperators({}, {1});
auto subgraphs = BuildSubGraphs(tensors, operators);
auto comment = builder_.CreateString("");
::tflite::FinishModelBuffer(
builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
subgraphs, comment, buffers));
EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
"Missing 'inputs' for operator.");
}

TEST_F(ImportTest, NoOutputs) {
auto buffers = BuildBuffers();
auto tensors = BuildTensors();
auto opcodes = BuildOpCodes();
auto operators = BuildOperators({0}, {});
auto subgraphs = BuildSubGraphs(tensors, operators);
auto comment = builder_.CreateString("");
::tflite::FinishModelBuffer(
builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
subgraphs, comment, buffers));
EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
"Missing 'outputs' for operator.");
}

TEST_F(ImportTest, InvalidOpCode) {
auto buffers = BuildBuffers();
auto tensors = BuildTensors();
auto opcodes = BuildOpCodes({static_cast<::tflite::BuiltinOperator>(-1),
::tflite::BuiltinOperator_CONV_2D});
auto operators = BuildOperators();
auto subgraphs = BuildSubGraphs(tensors, operators);
auto comment = builder_.CreateString("");
::tflite::FinishModelBuffer(
builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
subgraphs, comment, buffers));
EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
"Operator id '-1' is out of range.");
}

TEST_F(ImportTest, MultipleSubGraphs) {
auto buffers = BuildBuffers();
auto tensors = BuildTensors();
Expand Down
71 changes: 63 additions & 8 deletions tensorflow/contrib/lite/tools/verifier.cc
Expand Up @@ -148,11 +148,52 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
// TODO(yichengfan): verify quantized tensors.
}

using flatbuffers::Offset;
using flatbuffers::Vector;

bool VerifyOperators(const Vector<Offset<Operator>>& operators,
ErrorReporter* error_reporter) {
for (const auto& op : operators) {
if (!op->inputs()) {
ReportError(error_reporter, "Missing 'inputs' for operator.");
return false;
}
if (!op->outputs()) {
ReportError(error_reporter, "Missing 'outputs' for operator.");
return false;
}
}
return true;
}

bool VerifySubGraphs(const Model& model, ErrorReporter* error_reporter) {
if (!model.subgraphs()) {
ReportError(error_reporter, "Missing 'subgraphs' section.");
return false;
}
for (const auto& subgraph : *model.subgraphs()) {
if (!subgraph->operators()) {
ReportError(error_reporter, "Missing 'operators' section in subgraph.");
return false;
}

if (!VerifyOperators(*subgraph->operators(), error_reporter)) {
return false;
}
}
return true;
}

// Verifies tensors have valid properties and legit buffer if set.
bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) {
if (!model.subgraphs()) {
return true;
}
if (!model.buffers()) {
ReportError(error_reporter, "Missing 'buffers' section.");
return false;
}

for (const auto& subgraph : *model.subgraphs()) {
if (!subgraph->tensors()) {
continue;
Expand All @@ -167,19 +208,23 @@ bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) {
return false;
}
auto* buffer = model.buffers()->Get(tensor->buffer());
if (!buffer || !buffer->data()) {
if (!buffer) {
ReportError(error_reporter, "Tensor buffer %d not set",
tensor->buffer());
return false;
}

if (tensor->type() == TensorType_STRING) {
if (!VerifyStringTensorBuffer(*buffer, error_reporter)) {
return false;
}
} else {
if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) {
return false;
// Many transient tensors don't have data in the flatbuffer. Their
// buffers will be allocated by the interpreter at run-time.
if (buffer->data()) {
if (tensor->type() == TensorType_STRING) {
if (!VerifyStringTensorBuffer(*buffer, error_reporter)) {
return false;
}
} else {
if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) {
return false;
}
}
}
}
Expand All @@ -193,6 +238,13 @@ bool VerifyOps(const Model& model, const OpResolver& resolver,
return true;
}
for (const auto& opcode : *model.operator_codes()) {
if (opcode->builtin_code() < BuiltinOperator_MIN ||
opcode->builtin_code() > BuiltinOperator_MAX) {
ReportError(error_reporter, "Operator id '%d' is out of range.",
opcode->builtin_code());
return false;
}

if (opcode->builtin_code() == BuiltinOperator_CUSTOM) {
if (!resolver.FindOp(opcode->custom_code()->c_str())) {
ReportError(error_reporter, "Unsupported custom op: %s",
Expand Down Expand Up @@ -223,6 +275,9 @@ bool Verify(const void* buf, size_t len, const OpResolver& resolver,
ReportError(error_reporter, "Invalid model version %d", model->version());
return false;
}
if (!VerifySubGraphs(*model, error_reporter)) {
return false;
}
if (!VerifyTensors(*model, error_reporter)) {
return false;
}
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/contrib/lite/tools/verifier.h
Expand Up @@ -23,6 +23,21 @@ limitations under the License.

namespace tflite {

class AlwaysTrueResolver : public OpResolver {
public:
AlwaysTrueResolver() {}
TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr,
nullptr};
return &null_registration;
}
TfLiteRegistration* FindOp(const char* op) const override {
static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr,
nullptr};
return &null_registration;
}
};

// Verifies the integrity of a Tensorflow Lite flatbuffer model file.
// Currently, it verifies:
// * The file is following a legit flatbuffer schema.
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/contrib/lite/tools/verifier_test.cc
Expand Up @@ -113,8 +113,8 @@ TEST(VerifyModel, TestEmptyModel) {
/*description=*/0, /*buffers=*/0);
::tflite::FinishModelBuffer(builder, model);

ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(),
MutableOpResolver{}, DefaultErrorReporter()));
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
MutableOpResolver{}, DefaultErrorReporter()));
}

TEST(VerifyModel, TestSimpleModel) {
Expand Down

0 comments on commit 41335ab

Please sign in to comment.