From 7c94871621f714a19443a36db2ff280b2150e382 Mon Sep 17 00:00:00 2001 From: Amit Srivastava Date: Mon, 8 Apr 2019 12:02:25 +0530 Subject: [PATCH] Cleanup & TC added l2_norm graph transformation. Clean-up: replace FindOperator with FindOp and add missing unit tests for the l2_norm graph transformation. --- .../identify_l2_normalization.cc | 16 +- .../toco/graph_transformations/tests/BUILD | 12 ++ .../tests/identify_l2_normalization_test.cc | 141 ++++++++++++++++++ 3 files changed, 154 insertions(+), 15 deletions(-) create mode 100644 tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc index 3b7c88ac62e48e..c8f453c254e0e0 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -25,20 +25,6 @@ limitations under the License. 
 
 namespace toco {
 
-namespace {
-
-std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
-    Model* model, const Operator* op) {
-  auto it = model->operators.begin();
-  for (; it != model->operators.end(); ++it) {
-    if (it->get() == op) {
-      break;
-    }
-  }
-  return it;
-}
-}  // namespace
-
 ::tensorflow::Status IdentifyL2Normalization::Run(Model* model,
                                                   std::size_t op_index,
                                                   bool* modified) {
@@ -150,7 +136,7 @@ ::tensorflow::Status IdentifyL2Normalization::Run(Model* model,
   AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
 
   // Erase the subgraph that is now replaced by L2Normalization
-  model->operators.erase(FindOperator(model, square_op));
+  model->operators.erase(FindOp(*model, square_op));
   DeleteOpAndArraysIfUnused(model, sum_op);
   if (add_op) {
     DeleteOpAndArraysIfUnused(model, add_op);
diff --git a/tensorflow/lite/toco/graph_transformations/tests/BUILD b/tensorflow/lite/toco/graph_transformations/tests/BUILD
index 03d331226d885e..d13f2b035098c9 100644
--- a/tensorflow/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/lite/toco/graph_transformations/tests/BUILD
@@ -41,6 +41,18 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "identify_l2_normalization_test",
+    srcs = ["identify_l2_normalization_test.cc"],
+    deps = [
+        "//tensorflow/lite/toco:graph_transformations",
+        "//tensorflow/lite/toco:model",
+        "//tensorflow/lite/toco:tooling_util",
+        "@com_google_absl//absl/memory",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 tf_cc_test(
     name = "fuse_binary_into_following_affine_test",
     srcs = ["fuse_binary_into_following_affine_test.cc"],
diff --git a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc
new file mode 100644
index 00000000000000..4c55b7d6dcbb06
--- /dev/null
+++ b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc
@@ -0,0 +1,141 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <tuple>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+void RunIdentifyL2Normalization(const std::vector<float>& input,
+                                const std::vector<int>& input_shape,
+                                const std::vector<int>& output_shape,
+                                const bool div_square = false) {
+  Model model;
+  Array& input0 = model.GetOrCreateArray("input0");
+  Array& output = model.GetOrCreateArray("output");
+
+  *input0.mutable_shape()->mutable_dims() = input_shape;
+  input0.data_type = ArrayDataType::kFloat;
+  input0.GetMutableBuffer<ArrayDataType::kFloat>().data = input;
+
+  *output.mutable_shape()->mutable_dims() = output_shape;
+
+  auto sq_op = new TensorFlowSquareOperator;
+  sq_op->inputs = {"input0"};
+  sq_op->outputs = {"output"};
+
+  Array& sumoutput = model.GetOrCreateArray("Sumoutput");
+  *sumoutput.mutable_shape()->mutable_dims() = output_shape;
+
+  auto sum_op = new TensorFlowSumOperator;
+  sum_op->inputs = {sq_op->outputs[0]};
+  sum_op->outputs = {"Sumoutput"};
+
+  if (div_square) {
+    Array& sqrtoutput = model.GetOrCreateArray("squarertoutput");
+    *sqrtoutput.mutable_shape()->mutable_dims() = output_shape;
+
+    auto sqrt_op = new TensorFlowSqrtOperator;
+    sqrt_op->inputs = {sum_op->outputs[0]};
+    sqrt_op->outputs = {"squarertoutput"};
+
+    Array& divoutput = model.GetOrCreateArray("Divoutput");
+    *divoutput.mutable_shape()->mutable_dims() = output_shape;
+
+    auto div_op = new DivOperator;
+    div_op->inputs = {"input0", sqrt_op->outputs[0]};
+    div_op->outputs = {"Divoutput"};
+
+    /*Stack everything with the model*/
+    model.operators.push_back(std::unique_ptr<Operator>(div_op));
+    model.operators.push_back(std::unique_ptr<Operator>(sqrt_op));
+    model.operators.push_back(std::unique_ptr<Operator>(sum_op));
+    model.operators.push_back(std::unique_ptr<Operator>(sq_op));
+  } else {
+    Array& rsqoutput = model.GetOrCreateArray("Rsquareoutput");
+    *rsqoutput.mutable_shape()->mutable_dims() = output_shape;
+
+    auto rsqrt_op = new TensorFlowRsqrtOperator;
+    rsqrt_op->inputs = {sum_op->outputs[0]};
+    rsqrt_op->outputs = {"Rsquareoutput"};
+
+    Array& muloutput = model.GetOrCreateArray("Muloutput");
+    *muloutput.mutable_shape()->mutable_dims() = output_shape;
+
+    auto mul_op = new MulOperator;
+    mul_op->inputs = {"input0", rsqrt_op->outputs[0]};
+    mul_op->outputs = {"Muloutput"};
+
+    /*Stack everything with the model*/
+    model.operators.push_back(std::unique_ptr<Operator>(mul_op));
+    model.operators.push_back(std::unique_ptr<Operator>(rsqrt_op));
+    model.operators.push_back(std::unique_ptr<Operator>(sum_op));
+    model.operators.push_back(std::unique_ptr<Operator>(sq_op));
+  }
+
+  bool modified;
+  ASSERT_TRUE(IdentifyL2Normalization().Run(&model, 0, &modified).ok());
+  for (auto& op_it : model.operators) {
+    Operator* op = op_it.get();
+    // Since the optimization has kicked in we should not find any
+    // Mul, Rsqrt, Add, Sqr operators
+    if (div_square) {
+      EXPECT_FALSE(op->type == OperatorType::kDiv);
+      EXPECT_FALSE(op->type == OperatorType::kSqrt);
+    } else {
+      EXPECT_FALSE(op->type == OperatorType::kMul);
+      EXPECT_FALSE(op->type == OperatorType::kRsqrt);
+    }
+    EXPECT_FALSE(op->type == OperatorType::kAdd);
+    EXPECT_FALSE(op->type == OperatorType::kSquare);
+  }
+}
+
+// Test that the Mul(x, Rsqrt(Sum(Square(x)))) subgraph is fused into L2Norm
+TEST(IdentifyL2Normalization, MulRsqrtTest) {
+  RunIdentifyL2Normalization(
+      // Input data
+      {3, 1, 4, 1, -5, 9, -2, 6, 5, 3, 5, 8},
+
+      // Input shape
+      {3, 4},
+
+      {3, 4},
+
+      false);
+}
+
+TEST(IdentifyL2Normalization, DivSqrtNormTest) {
+  RunIdentifyL2Normalization(
+      // Input data
+      {3, 1, 4, 1, -5, 9, -2, 6, 5, 3, 5, 8},
+
+      // Input shape
+      {3, 4},
+
+      {3, 4},
+
+      true);
+}
+
+}  // namespace
+}  // namespace toco