From 35e3eed195eaf61325ee421a0bd6c81b3d439417 Mon Sep 17 00:00:00 2001 From: Amit Srivastava Date: Mon, 8 Apr 2019 14:42:58 +0530 Subject: [PATCH] CleanUp & TC added for L2_pool Graph transformation. Clean-up: replace FindOperator with FindOp and add missing unit tests for the L2_pool graph transformation --- .../graph_transformations/identify_l2_pool.cc | 20 +---- .../toco/graph_transformations/tests/BUILD | 11 +++ .../tests/identify_l2_pool_test.cc | 88 +++++++++++++++++++ 3 files changed, 102 insertions(+), 17 deletions(-) create mode 100644 tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc index 6e0a7cdc31af2b..152330931ef0b6 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc @@ -24,20 +24,6 @@ limitations under the License. namespace toco { -namespace { - -std::vector>::iterator FindOperator( - Model* model, const Operator* op) { - auto it = model->operators.begin(); - for (; it != model->operators.end(); ++it) { - if (it->get() == op) { - break; - } - } - return it; -} -} // namespace - ::tensorflow::Status IdentifyL2Pool::Run(Model* model, std::size_t op_index, bool* modified) { *modified = false; @@ -105,9 +91,9 @@ ::tensorflow::Status IdentifyL2Pool::Run(Model* model, std::size_t op_index, model->EraseArray(sqrt_op->inputs[0]); // Erase three operators being replaced. - model->operators.erase(FindOperator(model, square_op)); - model->operators.erase(FindOperator(model, avpool_op)); - model->operators.erase(FindOperator(model, sqrt_op)); + model->operators.erase(FindOp(*model, square_op)); + model->operators.erase(FindOp(*model, avpool_op)); + model->operators.erase(FindOp(*model, sqrt_op)); *modified = true; return ::tensorflow::Status::OK(); diff --git a/tensorflow/lite/toco/graph_transformations/tests/BUILD b/tensorflow/lite/toco/graph_transformations/tests/BUILD index 03d331226d885e..07a7a473578c5e 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/lite/toco/graph_transformations/tests/BUILD @@ -18,6 +18,17 @@ tf_cc_test( ], ) +tf_cc_test( + name = "identify_l2_pool_test", + srcs = ["identify_l2_pool_test.cc"], + deps = [ + "//tensorflow/lite/toco:graph_transformations", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", + "@com_google_googletest//:gtest_main", + ], +) + tf_cc_test( name = "resolve_constant_concatenation_test", srcs = ["resolve_constant_concatenation_test.cc"], diff --git a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc new file mode 100644 index 00000000000000..bfb2acf3aa6fe8 --- /dev/null +++ b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include +#include "absl/memory/memory.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +void RunIdentifyL2Pool(const std::vector& input, + const std::vector& input_shape, + const std::vector& output_shape) { + Model model; + Array& input0 = model.GetOrCreateArray("input0"); + Array& output = model.GetOrCreateArray("output"); + + *input0.mutable_shape()->mutable_dims() = input_shape; + input0.data_type = ArrayDataType::kFloat; + input0.GetMutableBuffer().data = input; + + *output.mutable_shape()->mutable_dims() = output_shape; + + auto sq_op = new TensorFlowSquareOperator; + sq_op->inputs = {"input0"}; + sq_op->outputs = {"output"}; + + Array& avgpooloutput = model.GetOrCreateArray("Avgpooloutput"); + *avgpooloutput.mutable_shape()->mutable_dims() = output_shape; + + auto avgpool_op = new AveragePoolOperator; + avgpool_op->inputs = {sq_op->outputs[0]}; + avgpool_op->outputs = {"Avgpooloutput"}; + + Array& sqrtoutput = model.GetOrCreateArray("Sqrtoutput"); + *sqrtoutput.mutable_shape()->mutable_dims() = output_shape; + + auto sqrt_op = new TensorFlowSqrtOperator; + sqrt_op->inputs = {avgpool_op->outputs[0]}; + sqrt_op->outputs = {"Sqrtoutput"}; + + /*Stack everything with the model*/ + model.operators.push_back(std::unique_ptr(sqrt_op)); + model.operators.push_back(std::unique_ptr(avgpool_op)); + model.operators.push_back(std::unique_ptr(sq_op)); + + bool modified; + ASSERT_TRUE(IdentifyL2Pool().Run(&model, 0, &modified).ok()); + for (auto& op_it : model.operators) { + Operator* op = op_it.get(); + // Since the optimization has kicked in we should not find any + // Square, avgpool & Sqrt operators + EXPECT_FALSE(op->type == OperatorType::kSqrt); + EXPECT_FALSE(op->type == OperatorType::kAveragePool); + EXPECT_FALSE(op->type == OperatorType::kSquare); + } +} +} // namespace + +TEST(IdentifyL2Pool, SimpleTest) { + RunIdentifyL2Pool( + // Input data + {3, 1, 4, 1, -5, 9, -2, 6, 5, 3, 5, 8}, + + // Input shape + {3, 4}, + + {3, 4}); +} + +} // namespace toco