From 29b5cffafded25f9bac74dea588143f4755e5912 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Tue, 29 Oct 2019 14:00:36 +0800 Subject: [PATCH] Implement a pass detect fusion group of elementwise op (#19884) * Add fusion_group_pass and elementwise pattern. * Rewrite the detector of elementwise group. test=develop * Add a comment in codegen. * Add more unittest cases. test=develop * Move code_generator related code to fusion_group directory. * Correct the including path. * Add the definition of SubGraph and finish the insert of fusion_group op in pass. * Insert graph_vis_pass in tester to visualize the graph for debug. --- paddle/fluid/framework/ir/CMakeLists.txt | 8 +- .../framework/ir/fusion_group/CMakeLists.txt | 11 ++ .../code_generator.cc} | 11 +- .../code_generator.h} | 8 +- .../code_generator_helper.cc} | 31 ++-- .../code_generator_helper.h} | 4 + .../code_generator_tester.cc} | 9 +- .../elementwise_group_detector.cc | 161 ++++++++++++++++++ .../fusion_group/elementwise_group_detector.h | 51 ++++++ .../ir/fusion_group/fusion_group_pass.cc | 115 +++++++++++++ .../ir/fusion_group/fusion_group_pass.h | 40 +++++ .../fusion_group/fusion_group_pass_tester.cc | 156 +++++++++++++++++ .../framework/ir/fusion_group/subgraph.h | 99 +++++++++++ .../ir/multihead_matmul_fuse_pass_tester.cc | 0 .../fluid/framework/ir/pass_tester_helper.h | 25 ++- .../ir/simplify_with_basic_ops_pass_tester.cc | 4 +- 16 files changed, 695 insertions(+), 38 deletions(-) create mode 100644 paddle/fluid/framework/ir/fusion_group/CMakeLists.txt rename paddle/fluid/framework/ir/{codegen.cc => fusion_group/code_generator.cc} (78%) rename paddle/fluid/framework/ir/{codegen.h => fusion_group/code_generator.h} (82%) rename paddle/fluid/framework/ir/{codegen_helper.cc => fusion_group/code_generator_helper.cc} (67%) rename paddle/fluid/framework/ir/{codegen_helper.h => fusion_group/code_generator_helper.h} (99%) rename paddle/fluid/framework/ir/{codegen_test.cc => fusion_group/code_generator_tester.cc} (96%) create mode 100644 paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc create mode 100644 paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h create mode 100644 paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc create mode 100644 paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h create mode 100644 paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc create mode 100644 paddle/fluid/framework/ir/fusion_group/subgraph.h mode change 100755 => 100644 paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index cd1bd80b35246..3859eb3e7067d 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -6,6 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") add_subdirectory(fuse_optimizer_ops_pass) add_subdirectory(memory_optimize_pass) add_subdirectory(multi_devices_graph_pass) +add_subdirectory(fusion_group) # Usage: pass_library(target inference) will append to paddle_inference_pass.h unset(INFER_IR_PASSES CACHE) # clear the global variable @@ -30,8 +31,6 @@ function(pass_library TARGET DEST) endif() endfunction() -cc_library(codegen SRCS codegen.cc DEPS codegen_helper) -cc_library(codegen_helper SRCS codegen_helper.cc DEPS graph node graph_helper) cc_library(node SRCS node.cc DEPS proto_desc) cc_library(graph SRCS graph.cc DEPS node pretty_log) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) @@ -111,11 +110,6 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") cc_library(pass_builder SRCS pass_builder.cc DEPS pass) -if(NOT APPLE AND NOT WIN32) - if(WITH_GPU) - cc_test(codegen_test SRCS codegen_test.cc DEPS codegen_helper codegen device_code lod_tensor) - endif() -endif() cc_test(node_test SRCS node_test.cc DEPS node) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt b/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt new file mode 100644 index 0000000000000..8d30c3efccb8b --- /dev/null +++ b/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt @@ -0,0 +1,11 @@ +cc_library(code_generator SRCS code_generator.cc code_generator_helper.cc DEPS graph) +if(NOT APPLE AND NOT WIN32) + if(WITH_GPU) + cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor) + endif() +endif() + +cc_library(fusion_group_pass + SRCS fusion_group_pass.cc elementwise_group_detector.cc + DEPS graph_pattern_detector pass) +cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass) diff --git a/paddle/fluid/framework/ir/codegen.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc similarity index 78% rename from paddle/fluid/framework/ir/codegen.cc rename to paddle/fluid/framework/ir/fusion_group/code_generator.cc index 60a5ff224a943..c477836607b2e 100644 --- a/paddle/fluid/framework/ir/codegen.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc @@ -11,10 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/ir/codegen.h" + +#include "paddle/fluid/framework/ir/fusion_group/code_generator.h" #include #include -#include "paddle/fluid/framework/ir/codegen_helper.h" +#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" + namespace paddle { namespace framework { namespace ir { @@ -23,9 +25,8 @@ CodeGenerator::CodeGenerator(CodeTemplate code_template) { code_template_ = code_template; } -// in order to get the right result of expression, we need to calculate, we -// store the expression as -// suffix Expressions using vector +// In order to get the right result of expression, we need to calculate and +// store the expression as suffix Expressions using vector. std::string CodeGenerator::GenerateCode(TemplateVariable template_var) { auto cuda_kernel = kernel_function + code_template_.Format(template_var); return cuda_kernel; diff --git a/paddle/fluid/framework/ir/codegen.h b/paddle/fluid/framework/ir/fusion_group/code_generator.h similarity index 82% rename from paddle/fluid/framework/ir/codegen.h rename to paddle/fluid/framework/ir/fusion_group/code_generator.h index 2cf61ada48e72..0e208d445f3ee 100644 --- a/paddle/fluid/framework/ir/codegen.h +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.h @@ -11,10 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #pragma once #include #include -#include "paddle/fluid/framework/ir/codegen_helper.h" +#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" namespace paddle { namespace framework { @@ -23,8 +24,11 @@ namespace ir { class CodeGenerator { public: explicit CodeGenerator(CodeTemplate code_template); + std::string GenerateCode(TemplateVariable template_var); - // TODO(wangchao66) std::string GenerateCode(const Graph& graph) + + // TODO(wangchao): add a more general interface + // std::string Generate(const std::string name, const SubGraph& subgraph); private: CodeTemplate code_template_; diff --git a/paddle/fluid/framework/ir/codegen_helper.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc similarity index 67% rename from paddle/fluid/framework/ir/codegen_helper.cc rename to paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc index 5e0b14253c1d1..19dbde16b8994 100644 --- a/paddle/fluid/framework/ir/codegen_helper.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc @@ -1,21 +1,23 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -#include "paddle/fluid/framework/ir/codegen_helper.h" +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" #include #include #include #include + namespace paddle { namespace framework { namespace ir { @@ -50,6 +52,7 @@ std::string OperationExpression::GetLHSTemplate() { bool OperationExpression::SupportState() { return (support_table.find(op_) == support_table.end()); } + // we Traverse the graph and get the group , all input id and output id is // unique for the node which belong the group std::string OperationExpression::GetExpression() { diff --git a/paddle/fluid/framework/ir/codegen_helper.h b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h similarity index 99% rename from paddle/fluid/framework/ir/codegen_helper.h rename to paddle/fluid/framework/ir/fusion_group/code_generator_helper.h index fbc59c4349042..5594d59fac824 100644 --- a/paddle/fluid/framework/ir/codegen_helper.h +++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #pragma once #include @@ -81,6 +82,7 @@ class TemplateVariable { private: std::unordered_map strings_; }; + class CodeTemplate { public: CodeTemplate() = default; @@ -110,6 +112,7 @@ class CodeTemplate { return EmitIndents(ret); } + std::string EmitIndents(std::string str) { std::string ret = str; int space_num = 0; @@ -147,6 +150,7 @@ static std::string EmitUniqueName(std::vector expression) { } return ret.str(); } + // we get the parameter list code for the expression information static std::string EmitDeclarationCode( std::vector expression, std::string type) { diff --git a/paddle/fluid/framework/ir/codegen_test.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc similarity index 96% rename from paddle/fluid/framework/ir/codegen_test.cc rename to paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc index 7877b21848401..f68661675a0dd 100644 --- a/paddle/fluid/framework/ir/codegen_test.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc @@ -11,19 +11,20 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/ir/codegen.h" + +#include "paddle/fluid/framework/ir/fusion_group/code_generator.h" #include #include #include #include -#include "paddle/fluid/framework/ir/codegen_helper.h" +#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/operators/math.h" #include "paddle/fluid/platform/device_code.h" #include "paddle/fluid/platform/init.h" -#ifdef PADDLE_WITH_CUDA -TEST(codegen, cuda) { +#ifdef PADDLE_WITH_CUDA +TEST(code_generator, cuda) { std::vector mul_input{1, 2}; std::vector add_input{3, 4}; std::vector sub_input{5, 6}; diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc new file mode 100644 index 0000000000000..68063c34d1d9c --- /dev/null +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc @@ -0,0 +1,161 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace fusion_group { + +static std::unordered_set binary_op_types = { + "elementwise_add", "elementwise_sub", "elementwise_mul", + "elementwise_div", "elementwise_min", "elementwise_max"}; + +static std::unordered_set unary_op_types = {"relu", "sigmoid", + "tanh"}; + +static bool IsSpecifiedOp(const std::unordered_set& op_types, + Node* n) { + if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) { + auto iter = op_types.find(n->Op()->Type()); + if (iter != op_types.end()) { + return true; + } + } + return false; +} + +static bool IsBinaryOp(Node* n) { + if (IsSpecifiedOp(binary_op_types, n) && n->inputs.size() == 2U) { + auto* x = n->inputs[0]; + auto* y = n->inputs[1]; + + std::vector x_shape; + std::vector y_shape; + if (x && x->IsVar() && x->Var()) { + x_shape = x->Var()->GetShape(); + } + if (y && y->IsVar() && y->Var()) { + y_shape = y->Var()->GetShape(); + } + if (x_shape.size() == 0U || x_shape.size() != y_shape.size()) { + return false; + } + for (size_t i = 0; i < x_shape.size(); ++i) { + if (x_shape[i] != y_shape[i]) { + return false; + } + } + return true; + } + return false; +} + +static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(unary_op_types, n); } + +bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) { + return IsBinaryOp(n) || IsUnaryOp(n); +} + +bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n, + std::string name) { + if (n && n->IsVar() && n->Var()) { + for (auto* op : n->outputs) { + if (IsElementwiseOp(op)) { + if (name.empty()) { + return true; + } else if (IsNthInput(n, op, name, 0)) { + return true; + } + } + } + } + return false; +} + +bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) { + if (n && n->IsVar() && n->Var()) { + for (auto* op : n->inputs) { + if (IsElementwiseOp(op)) { + return true; + } + } + } + return false; +} + +void ElementwiseGroupDetector::Insert(Node* n) { + if (subgraph_.nodes_set.find(n) == subgraph_.nodes_set.end()) { + VLOG(5) << "Insert " << n->Name() << " to subgraph " << name_; + subgraph_.nodes_set.insert(n); + } +} + +int ElementwiseGroupDetector::Search(Node* n, std::vector except_nodes) { + std::unordered_set except_nodes_set; + for (size_t i = 0; i < except_nodes.size(); ++i) { + except_nodes_set.insert(except_nodes[i]); + } + + int num_operations = 0; + if (IsElementwiseOp(n)) { + Insert(n); + num_operations += 1; + for (auto* var : n->inputs) { + Insert(var); + if (except_nodes_set.find(var) == except_nodes_set.end()) { + num_operations += Search(var, {n}); + } + } + for (auto* var : n->outputs) { + Insert(var); + if (except_nodes_set.find(var) == except_nodes_set.end()) { + num_operations += Search(var, {n}); + } + } + } else if (n && n->IsVar() && n->Var()) { + for (auto* op : n->inputs) { + if (IsElementwiseOp(op) && + except_nodes_set.find(op) == except_nodes_set.end()) { + num_operations += Search(op, {n}); + } + } + for (auto* op : n->outputs) { + if (IsElementwiseOp(op) && + except_nodes_set.find(op) == except_nodes_set.end()) { + num_operations += Search(op, {n}); + } + } + } + return num_operations; +} + +int ElementwiseGroupDetector::operator()(Node* n) { + if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) { + name_ = n->Name(); + Insert(n); + num_operations_ = Search(n, n->inputs); + VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", " + << num_operations_ << " operations, " << GetSubgraph().GetNumNodes() + << " nodes"; + } + return num_operations_; +} + +} // namespace fusion_group +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h new file mode 100644 index 0000000000000..ea1336819d451 --- /dev/null +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/ir/fusion_group/subgraph.h" +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace fusion_group { + +struct ElementwiseGroupDetector { + public: + int operator()(Node* n); + + SubGraph GetSubgraph() const { return subgraph_; } + + private: + bool IsElementwiseOp(Node* n); + bool IsInputOfElementwiseOp(Node* n, std::string name = ""); + bool IsOutputOfElementwiseOp(Node* n); + + void Insert(Node* n); + int Search(Node* n, std::vector except_nodes = {}); + + private: + std::string name_; + int num_operations_{0}; + SubGraph subgraph_; +}; + +} // namespace fusion_group +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc new file mode 100644 index 0000000000000..9f7dd15f62d81 --- /dev/null +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h" +#include +#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +void FusionGroupPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL(graph); + + int num_elementwise_groups = DetectFusionGroup(graph, 0); + LOG(INFO) << "Detect " << num_elementwise_groups + << " elementwise fusion groups."; +} + +int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const { + std::vector subgraphs; + std::unordered_set all_nodes = graph->Nodes(); + for (Node* n : all_nodes) { + bool is_found = false; + for (auto& subgraph : subgraphs) { + if (subgraph.nodes_set.find(n) != subgraph.nodes_set.end()) { + is_found = true; + break; + } + } + if (is_found) { + continue; + } + + fusion_group::SubGraph subgraph; + if (type == 0) { + fusion_group::ElementwiseGroupDetector detector; + int num_operations = detector(n); + if (num_operations >= 2) { + subgraph = detector.GetSubgraph(); + } + } + + if (!subgraph.IsEmpty()) { + subgraphs.push_back(subgraph); + } + } + + // TODO(liuyiqun): check whether there are intersection between subgraphs + for (size_t i = 0; i < subgraphs.size(); ++i) { + InsertFusionGroupOp(graph, subgraphs[i]); + } + return subgraphs.size(); +} + +void FusionGroupPass::InsertFusionGroupOp( + Graph* graph, const fusion_group::SubGraph& subgraph) const { + std::vector input_vars_of_subgraph = subgraph.GetInputVarNodes(); + std::vector output_vars_of_subgraph = subgraph.GetOutputVarNodes(); + std::unordered_set external_nodes; + + OpDesc op_desc; + op_desc.SetType("fusion_group"); + + std::vector input_names; + for (auto* n : input_vars_of_subgraph) { + input_names.push_back(n->Name()); + external_nodes.insert(n); + } + op_desc.SetInput("Xs", input_names); + + std::vector output_names; + for (auto* n : output_vars_of_subgraph) { + output_names.push_back(n->Name()); + external_nodes.insert(n); + } + op_desc.SetOutput("Outs", output_names); + op_desc.SetAttr("type", subgraph.type); + op_desc.SetAttr("func_name", subgraph.func_name); + + auto fusion_group_node = graph->CreateOpNode(&op_desc); + for (auto* in : input_vars_of_subgraph) { + IR_NODE_LINK_TO(in, fusion_group_node); + } + for (auto* out : output_vars_of_subgraph) { + IR_NODE_LINK_TO(fusion_group_node, out); + } + + std::unordered_set internal_nodes; + for (auto* n : subgraph.nodes_set) { + if (external_nodes.find(n) == external_nodes.end()) { + internal_nodes.insert(n); + } + } + GraphSafeRemoveNodes(graph, internal_nodes); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass); diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h new file mode 100644 index 0000000000000..c61db8f9ea0a0 --- /dev/null +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/ir/fusion_group/subgraph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class FusionGroupPass : public Pass { + protected: + void ApplyImpl(Graph* graph) const override; + + private: + int DetectFusionGroup(Graph* graph, int type = 0) const; + void InsertFusionGroupOp(Graph* graph, + const fusion_group::SubGraph& subgraph) const; + + const std::string name_scope_{"fusion_group"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc new file mode 100644 index 0000000000000..ac951fe1080e8 --- /dev/null +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc @@ -0,0 +1,156 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h" + +#include +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(FusionGroupPass, elementwise_list) { + // inputs operator output + // -------------------------------------------------------- + // (x, y) mul -> tmp_0 + // (tmp_0, z) elementwise_add -> tmp_1 + // tmp_1 relu -> tmp_2 + // (tmp_2, w) elementwise_add -> tmp_3 + // + // Expression: tmp_3 = relu(mul(x, y) + z) + w + Layers layers; + auto* x = layers.data("x", {16, 16}); + auto* y = layers.data("y", {16, 32}); + auto* tmp_0 = layers.mul(x, y); + tmp_0->SetShape({16, 32}); + auto* z = layers.data("z", {16, 32}); + auto* tmp_1 = layers.elementwise_add(tmp_0, z); + auto* tmp_2 = layers.relu(tmp_1); + tmp_2->SetShape({16, 32}); + auto* w = layers.data("w", {16, 32}); + layers.elementwise_add(tmp_2, w); + + std::unique_ptr graph(new Graph(layers.main_program())); + + // The following codes is to insert a graph_viz_pass to transform the graph to + // a .dot file. It is used for debug. + // auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass"); + // graph_viz_pass->Set("graph_viz_path", new + // std::string("00_elementwise_list.dot")); + // graph.reset(graph_viz_pass->Apply(graph.release())); + + auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass"); + VLOG(3) << DebugString(graph); + + graph.reset(fusion_group_pass->Apply(graph.release())); + int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_fusion_group_ops, 1); + + // The following codes is to insert a graph_viz_pass to transform the graph to + // a .dot file. It is used for debug. + // auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass"); + // graph_viz_pass->Set("graph_viz_path", new + // std::string("01_elementwise_list.fusion_group.dot")); + // graph.reset(graph_viz_pass->Apply(graph.release())); +} + +TEST(FusionGroupPass, elementwise_tree) { + // inputs operator output + // -------------------------------------------------------- + // (x0, y0) mul -> tmp_0 + // x1 sigmoid -> tmp_1 + // (tmp_0, tmp_1) elementwise_mul -> tmp_2 + // x2 sigmoid -> tmp_3 + // x3 tanh -> tmp_4 + // (tmp_3, tmp_4) elementwise_mul -> tmp_5 + // (tmp_2, tmp_5) elementwise_add -> tmp_6 + // x4 tanh -> tmp_7 + // x5 sigmoid -> tmp_8 + // (tmp_7, tmp_8) elementwise_mul -> tmp_9 + // (tmp_6, tmp_9) mul -> tmp_10 + // + // Expression: tmp_6 = mul(x0, y0) * sigmoid(x1) + sigmoid(x2) * tanh(x3) + // tmp_9 = tanh(x4) * sigmoid(x5) + // tmp_10 = mul(tmp_6, tmp_9) + Layers layers; + auto* x0 = layers.data("x0", {16, 16}); + auto* y0 = layers.data("y0", {16, 32}); + auto* tmp_0 = layers.mul(x0, y0); + tmp_0->SetShape({16, 32}); + + auto* x1 = layers.data("x1", {16, 32}); + auto* tmp_1 = layers.sigmoid(x1); + tmp_1->SetShape({16, 32}); + + auto* tmp_2 = layers.elementwise_mul(tmp_0, tmp_1); + tmp_2->SetShape({16, 32}); + + auto* x2 = layers.data("x2", {16, 32}); + auto* tmp_3 = layers.sigmoid(x2); + tmp_3->SetShape({16, 32}); + auto* x3 = layers.data("x3", {16, 32}); + auto* tmp_4 = layers.tanh(x3); + tmp_4->SetShape({16, 32}); + auto* tmp_5 = layers.elementwise_mul(tmp_3, tmp_4); + tmp_5->SetShape({16, 32}); + + auto* tmp_6 = layers.elementwise_add(tmp_2, tmp_5); + tmp_6->SetShape({16, 32}); + + auto* x4 = layers.data("x4", {16, 32}); + auto* tmp_7 = layers.tanh(x4); + tmp_7->SetShape({16, 32}); + auto* x5 = layers.data("x5", {16, 32}); + auto* tmp_8 = layers.sigmoid(x5); + tmp_8->SetShape({16, 32}); + + auto* tmp_9 = layers.elementwise_mul(tmp_7, tmp_8); + tmp_9->SetShape({16, 32}); + layers.mul(tmp_6, tmp_9); + + std::unique_ptr graph(new Graph(layers.main_program())); + + // The following codes is to insert a graph_viz_pass to transform the graph to + // a .dot file. It is used for debug. + // auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass"); + // graph_viz_pass->Set("graph_viz_path", new + // std::string("00_elementwise_tree.dot")); + // graph.reset(graph_viz_pass->Apply(graph.release())); + + auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass"); + LOG(INFO) << DebugString(graph); + + graph.reset(fusion_group_pass->Apply(graph.release())); + int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group"); + LOG(INFO) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_fusion_group_ops, 2); + + // The following codes is to insert a graph_viz_pass to transform the graph to + // a .dot file. It is used for debug. + // auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass"); + // graph_viz_pass->Set("graph_viz_path", new + // std::string("01_elementwise_tree.fusion_group.dot")); + // graph.reset(graph_viz_pass->Apply(graph.release())); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fusion_group_pass); +USE_PASS(graph_viz_pass); diff --git a/paddle/fluid/framework/ir/fusion_group/subgraph.h b/paddle/fluid/framework/ir/fusion_group/subgraph.h new file mode 100644 index 0000000000000..a1f06d55410fb --- /dev/null +++ b/paddle/fluid/framework/ir/fusion_group/subgraph.h @@ -0,0 +1,99 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace fusion_group { + +struct SubGraph { + int type{-1}; + std::string func_name; + std::unordered_set nodes_set; + + bool IsEmpty() { return nodes_set.empty(); } + + size_t GetNumNodes() { return nodes_set.size(); } + + int GetNumOperations() { + int num_operations = 0; + for (auto* n : nodes_set) { + if (n && n->IsOp() && n->Op()) { + num_operations++; + } + } + return num_operations; + } + + std::vector GetInputVarNodes() const { + // The order of input nodes should be consistent with that of the generated + // code. + std::vector input_vars; + for (auto* n : nodes_set) { + if (n && n->IsVar() && n->Var()) { + bool is_found = true; + // When the inputs size is 0, it is also considered the input var of + // subgraph. + if (n->inputs.size() == 0U) { + is_found = false; + } + // Normally a var node has only one input op node. + for (auto* in : n->inputs) { + if (nodes_set.find(in) == nodes_set.end()) { + is_found = false; + } + } + if (!is_found) { + input_vars.push_back(n); + } + } + } + return input_vars; + } + + std::vector GetOutputVarNodes() const { + // The order of output nodes should be consistant with that of the generated + // code. + std::vector output_vars; + for (auto* n : nodes_set) { + if (n && n->IsVar() && n->Var()) { + bool is_found = true; + if (n->outputs.size() == 0U) { + is_found = false; + } + for (auto* out : n->outputs) { + if (nodes_set.find(out) == nodes_set.end()) { + is_found = false; + } + } + if (!is_found) { + output_vars.push_back(n); + } + } + } + return output_vars; + } +}; + +} // namespace fusion_group +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc old mode 100755 new mode 100644 diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 0601b8801af96..2709b6e2facc9 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include "paddle/fluid/framework/op_proto_maker.h" @@ -92,6 +93,14 @@ struct Layers { return unary_op("relu", x, out); } + VarDesc* sigmoid(VarDesc* x, VarDesc* out = nullptr) { + return unary_op("sigmoid", x, out); + } + + VarDesc* tanh(VarDesc* x, VarDesc* out = nullptr) { + return unary_op("tanh", x, out); + } + VarDesc* fc(VarDesc* input, VarDesc* w, VarDesc* bias, int in_num_col_dims = 1, std::string activation_type = "") { VarDesc* out = lod_tensor(unique_name()); @@ -119,6 +128,10 @@ struct Layers { return binary_op("elementwise_add", x, y, out); } + VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) { + return binary_op("elementwise_mul", x, y, out); + } + VarDesc* dropout(VarDesc* x, float dropout_prob, std::string dropout_implementation) { VarDesc* out = lod_tensor(unique_name()); @@ -399,10 +412,9 @@ static std::string DebugString(Node* node) { return os.str(); } -static std::string DebugString(const std::unique_ptr& graph) { +static std::string DebugString(const std::unordered_set& nodes) { std::ostringstream os; - os << "Graph: {\n"; - for (auto* node : graph->Nodes()) { + for (auto* node : nodes) { if (node->IsOp() && node->Op()) { os << " "; } else if (node->IsVar() && node->Var()) { @@ -410,7 +422,12 @@ static std::string DebugString(const std::unique_ptr& graph) { } os << DebugString(node) << "\n"; } - os << "}\n"; + return os.str(); +} + +static std::string DebugString(const std::unique_ptr& graph) { + std::ostringstream os; + os << "Graph: {\n" << DebugString(graph->Nodes()) << "}\n"; return os.str(); } diff --git a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass_tester.cc b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass_tester.cc index 7fb67df495f1d..324b9c0b7da24 100644 --- a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass_tester.cc +++ b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass_tester.cc @@ -59,12 +59,12 @@ TEST(SimplifyWithBasicOpsPass, dropout) { int num_scale_nodes_after = GetNumOpNodes(graph, "scale"); VLOG(3) << DebugString(graph); - PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0UL); + PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0); if (dropout_implementation == "downgrade_in_infer") { PADDLE_ENFORCE_EQ(num_dropout_nodes_before, num_scale_nodes_after - num_scale_nodes_before); } else { - PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0UL); + PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0); } } }