Commit 29b5cff

Implement a pass to detect fusion groups of elementwise ops (PaddlePaddle#19884)

* Add fusion_group_pass and the elementwise pattern.

* Rewrite the detector of elementwise groups.
test=develop

* Add a comment in codegen.

* Add more unittest cases.
test=develop

* Move code_generator-related code to the fusion_group directory.

* Correct the include path.

* Add the definition of SubGraph and finish inserting the fusion_group op in the pass.

* Insert graph_viz_pass in the tester to visualize the graph for debugging.
Xreki committed Oct 29, 2019
1 parent cd24b6e commit 29b5cff
Showing 16 changed files with 695 additions and 38 deletions.
8 changes: 1 addition & 7 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -6,6 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
add_subdirectory(fusion_group)

# Usage: pass_library(target inference) will append to paddle_inference_pass.h
unset(INFER_IR_PASSES CACHE) # clear the global variable
@@ -30,8 +31,6 @@ function(pass_library TARGET DEST)
endif()
endfunction()

cc_library(codegen SRCS codegen.cc DEPS codegen_helper)
cc_library(codegen_helper SRCS codegen_helper.cc DEPS graph node graph_helper)
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
@@ -111,11 +110,6 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")

cc_library(pass_builder SRCS pass_builder.cc DEPS pass)

if(NOT APPLE AND NOT WIN32)
if(WITH_GPU)
cc_test(codegen_test SRCS codegen_test.cc DEPS codegen_helper codegen device_code lod_tensor)
endif()
endif()
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
11 changes: 11 additions & 0 deletions paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
@@ -0,0 +1,11 @@
cc_library(code_generator SRCS code_generator.cc code_generator_helper.cc DEPS graph)
if(NOT APPLE AND NOT WIN32)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor)
endif()
endif()

cc_library(fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS graph_pattern_detector pass)
cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass)
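
For context, the new tester would typically fetch the pass from the registry and run it over a graph. The following is a minimal sketch, assuming the usual PassRegistry::Instance().Get(...) / Pass::Apply(...) idiom used by other ir passes; it is not part of this diff:

#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

void ApplyFusionGroupPass(std::unique_ptr<paddle::framework::ir::Graph>* graph) {
  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("fusion_group_pass");
  // Apply() takes ownership of the raw graph pointer and returns the
  // rewritten graph.
  graph->reset(pass->Apply(graph->release()));
}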
paddle/fluid/framework/ir/{codegen.cc → fusion_group/code_generator.cc}
@@ -11,10 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/codegen.h"

#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <set>
#include <sstream>
#include "paddle/fluid/framework/ir/codegen_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"

namespace paddle {
namespace framework {
namespace ir {
@@ -23,9 +25,8 @@ CodeGenerator::CodeGenerator(CodeTemplate code_template) {
code_template_ = code_template;
}

// in order to get the right result of expression, we need to calculate, we
// store the expression as
// suffix Expressions using vector
// In order to get the right result of an expression, we need to calculate and
// store the expression as a suffix (postfix) expression using a vector.
std::string CodeGenerator::GenerateCode(TemplateVariable template_var) {
auto cuda_kernel = kernel_function + code_template_.Format(template_var);
return cuda_kernel;
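For orientation, a minimal sketch of driving the generator follows. CodeTemplate's string constructor, a TemplateVariable::Add(identifier, expression) helper, and the $name placeholder syntax are assumptions not shown in this hunk:

#include <string>
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"

std::string BuildKernelString() {
  using namespace paddle::framework::ir;
  // Hypothetical template: $func_name is a placeholder to be substituted.
  CodeTemplate temp("extern \"C\" __global__ void $func_name(float* x, int n) {}");
  CodeGenerator generator(temp);
  TemplateVariable vars;
  vars.Add("func_name", "fused_elementwise_0");  // Add() is assumed here.
  // GenerateCode() prepends the shared kernel_function helper string.
  return generator.GenerateCode(vars);
}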
paddle/fluid/framework/ir/{codegen.h → fusion_group/code_generator.h}
@@ -11,10 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/codegen_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"

namespace paddle {
namespace framework {
@@ -23,8 +24,11 @@ namespace ir {
class CodeGenerator {
public:
explicit CodeGenerator(CodeTemplate code_template);

std::string GenerateCode(TemplateVariable template_var);
// TODO(wangchao66) std::string GenerateCode(const Graph& graph)

// TODO(wangchao): add a more general interface
// std::string Generate(const std::string name, const SubGraph& subgraph);

private:
CodeTemplate code_template_;
paddle/fluid/framework/ir/{codegen_helper.cc → fusion_group/code_generator_helper.cc}
@@ -1,21 +1,23 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/ir/codegen_helper.h"
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

namespace paddle {
namespace framework {
namespace ir {
@@ -50,6 +52,7 @@ std::string OperationExpression::GetLHSTemplate() {
bool OperationExpression::SupportState() {
return (support_table.find(op_) == support_table.end());
}

// We traverse the graph and get the group; all input ids and output ids are
// unique for the nodes which belong to the group.
std::string OperationExpression::GetExpression() {
paddle/fluid/framework/ir/{codegen_helper.h → fusion_group/code_generator_helper.h}
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iostream>
@@ -81,6 +82,7 @@ class TemplateVariable {
private:
std::unordered_map<std::string, std::string> strings_;
};

class CodeTemplate {
public:
CodeTemplate() = default;
@@ -110,6 +112,7 @@ class CodeTemplate {

return EmitIndents(ret);
}

std::string EmitIndents(std::string str) {
std::string ret = str;
int space_num = 0;
@@ -147,6 +150,7 @@ static std::string EmitUniqueName(std::vector<OperationExpression> expression) {
}
return ret.str();
}

// We get the parameter list code for the expression information.
static std::string EmitDeclarationCode(
std::vector<OperationExpression> expression, std::string type) {
paddle/fluid/framework/ir/{codegen_test.cc → fusion_group/code_generator_tester.cc}
@@ -11,19 +11,20 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/codegen.h"

#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/codegen_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h"
#ifdef PADDLE_WITH_CUDA

TEST(codegen, cuda) {
#ifdef PADDLE_WITH_CUDA
TEST(code_generator, cuda) {
std::vector<int> mul_input{1, 2};
std::vector<int> add_input{3, 4};
std::vector<int> sub_input{5, 6};
161 changes: 161 additions & 0 deletions paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
@@ -0,0 +1,161 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {

static std::unordered_set<std::string> binary_op_types = {
"elementwise_add", "elementwise_sub", "elementwise_mul",
"elementwise_div", "elementwise_min", "elementwise_max"};

static std::unordered_set<std::string> unary_op_types = {"relu", "sigmoid",
"tanh"};

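// Returns true only when n is an op node whose type is in op_types and which
// has at least one output.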
static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
Node* n) {
if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) {
auto iter = op_types.find(n->Op()->Type());
if (iter != op_types.end()) {
return true;
}
}
return false;
}

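// A binary elementwise op is fusible only when both inputs are var nodes with
// identical, non-empty shapes, so no broadcasting is involved.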
static bool IsBinaryOp(Node* n) {
if (IsSpecifiedOp(binary_op_types, n) && n->inputs.size() == 2U) {
auto* x = n->inputs[0];
auto* y = n->inputs[1];

std::vector<int64_t> x_shape;
std::vector<int64_t> y_shape;
if (x && x->IsVar() && x->Var()) {
x_shape = x->Var()->GetShape();
}
if (y && y->IsVar() && y->Var()) {
y_shape = y->Var()->GetShape();
}
if (x_shape.size() == 0U || x_shape.size() != y_shape.size()) {
return false;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] != y_shape[i]) {
return false;
}
}
return true;
}
return false;
}

static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(unary_op_types, n); }

bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}

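// Returns true if var node n feeds at least one elementwise op; when name is
// non-empty, n must additionally be the 0th input of that op's `name` argument.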
bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n,
std::string name) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->outputs) {
if (IsElementwiseOp(op)) {
if (name.empty()) {
return true;
} else if (IsNthInput(n, op, name, 0)) {
return true;
}
}
}
}
return false;
}

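// Returns true if var node n is produced by an elementwise op.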
bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op)) {
return true;
}
}
}
return false;
}

void ElementwiseGroupDetector::Insert(Node* n) {
if (subgraph_.nodes_set.find(n) == subgraph_.nodes_set.end()) {
VLOG(5) << "Insert " << n->Name() << " to subgraph " << name_;
subgraph_.nodes_set.insert(n);
}
}

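// Recursively grows the subgraph from n through both inputs and outputs,
// skipping except_nodes, and returns the number of elementwise ops collected.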
int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
std::unordered_set<Node*> except_nodes_set;
for (size_t i = 0; i < except_nodes.size(); ++i) {
except_nodes_set.insert(except_nodes[i]);
}

int num_operations = 0;
if (IsElementwiseOp(n)) {
Insert(n);
num_operations += 1;
for (auto* var : n->inputs) {
Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
for (auto* var : n->outputs) {
Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
} else if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
for (auto* op : n->outputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
}
return num_operations;
}

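// Detection starts from a var node that is the X input of an elementwise op
// but is not itself the output of one, i.e. an entry point of the group.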
int ElementwiseGroupDetector::operator()(Node* n) {
if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
name_ = n->Name();
Insert(n);
num_operations_ = Search(n, n->inputs);
VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", "
<< num_operations_ << " operations, " << GetSubgraph().GetNumNodes()
<< " nodes";
}
return num_operations_;
}

} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
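
To make the detector's use concrete, here is a hypothetical sketch of how fusion_group_pass might drive it over a graph; the loop shape and the minimum-group-size threshold are assumptions, not content of this diff:

#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

static void DetectElementwiseGroups(Graph* graph) {
  for (Node* n : graph->Nodes()) {
    fusion_group::ElementwiseGroupDetector detector;
    // operator() returns 0 unless n is an entry var of a new elementwise group.
    int num_operations = detector(n);
    if (num_operations >= 2) {  // Threshold assumed, not from this diff.
      // The pass would record detector.GetSubgraph() here and later replace
      // it with a single fusion_group op.
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle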