[Inference] rewrite identity_op_clean_pass (PaddlePaddle#55240)
* rewrite identity_op_clean_pass

* fix

* adjust identity_op_clean_pass order in gpu passes

* fix ut
yuanlehome authored and wz1qqx committed Jul 31, 2023
1 parent 9b92c6c commit fb4ec7f
Showing 6 changed files with 310 additions and 84 deletions.
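Before the per-file diffs, a quick orientation: the rewritten pass is driven exactly the way the new unit tests below drive it, i.e. fetched from the pass registry and applied to an ir::Graph. A minimal usage sketch assembled only from calls that appear in this commit's test file (the wrapper function RunIdentityOpClean and the namespace alias are illustrative, not part of the commit):

#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

namespace fw = paddle::framework;

// `program` is assumed to already contain an identity-like op,
// e.g. a scale op with scale=1 and bias=0.
void RunIdentityOpClean(const fw::ProgramDesc& program) {
  std::unique_ptr<fw::ir::Graph> graph(new fw::ir::Graph(program));
  auto pass = fw::ir::PassRegistry::Instance().Get("identity_op_clean_pass");
  graph.reset(pass->Apply(graph.release()));  // redundant ops are removed here
}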
4 changes: 4 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -434,6 +434,10 @@ cc_test(
  test_delete_assign_op_pass_cc
  SRCS delete_assign_op_pass_test.cc
  DEPS delete_assign_op_pass)
+cc_test(
+  test_identity_op_clean_pass_cc
+  SRCS identity_op_clean_pass_test.cc
+  DEPS identity_op_clean_pass)
cc_test(
  test_delete_dropout_pass_cc
  SRCS delete_dropout_op_pass_test.cc
151 changes: 76 additions & 75 deletions paddle/fluid/framework/ir/identity_op_clean_pass.cc
@@ -21,104 +21,105 @@ namespace paddle {
namespace framework {
namespace ir {

-class Graph;
+namespace patterns {

-void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
-  FusePassBase::Init("identity_scale_op_clean", graph);
-
-  // pre_op -> useless_op_in -> useless_op -> useless_op_out
-  // ->
-  // pre_op -> useless_op_out
-  GraphPatternDetector detector;
-  auto useless_op_in =
-      detector.mutable_pattern()
-          ->NewNode("useless_op_in")
-          ->assert_has_n_outputs(1)
-          ->assert_var_not_persistable()
-          ->assert_more([](Node* x) {
-            for (auto* op : x->inputs) {
-              auto op_type = op->Op()->Type();
-              if (op_type == "conditional_block" || op_type == "while") {
-                return false;
-              }
-            }
-            return true;
-          });
+// pre_op -> useless_op_in -> useless_op -> useless_op_out
+// ->
+// pre_op -> useless_op_out
+struct FindUselessOpPattern : public PatternBase {
+  FindUselessOpPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(useless_op_in);
+  PATTERN_DECL_NODE(useless_op);
+  PATTERN_DECL_NODE(useless_op_out);
+};
+
+FindUselessOpPattern::FindUselessOpPattern(PDPattern* pattern,
+                                           const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* useless_op_in = pattern->NewNode(useless_op_in_repr())
+                            ->assert_is_var()
+                            ->assert_var_not_persistable()
+                            ->assert_has_n_outputs(1)
+                            ->assert_more([](Node* x) {
+                              for (auto* op : x->inputs) {
+                                CHECK_EQ(op->IsOp(), true);
+                                const auto& op_type = op->Op()->Type();
+                                if (op_type == "conditional_block" ||
+                                    op_type == "while" || op_type == "feed") {
+                                  return false;
+                                }
+                              }
+                              return true;
+                            });

  // This useless_op must have only one input and one output!
-  auto useless_op =
-      detector.mutable_pattern()
-          ->NewNode("useless_op")
+  auto* useless_op =
+      pattern->NewNode(useless_op_repr())
          ->assert_is_op()
          ->assert_has_n_inputs(1)
          ->assert_has_n_outputs(1)
          ->assert_more([](Node* x) {
-            if (!x->IsOp()) {
-              return false;
-            }
-            if (x->Op()->Type() == "scale") {
+            const auto& op_type = x->Op()->Type();
+            if (op_type == "scale") {
              auto scale = x->Op()->GetAttrIfExists<float>("scale");
              auto bias = x->Op()->GetAttrIfExists<float>("bias");
-              if (bias == 0 && scale == 1) {
-                return true;
-              }
-            }
-            if (x->Op()->Type() == "cast") {
+              return bias == 0.f && scale == 1.f;
+            } else if (op_type == "cast") {
              auto in_dtype = x->Op()->GetAttrIfExists<int>("in_dtype");
              auto out_dtype = x->Op()->GetAttrIfExists<int>("out_dtype");
-              if (in_dtype == out_dtype) {
-                return true;
-              }
-            }
-            if (x->Op()->Type() == "c_identity") {
+              return in_dtype == out_dtype;
+            } else if (op_type == "c_identity") {
              return true;
+            } else if (op_type == "assign") {
+              const auto& in_name = x->Op()->Input("X")[0];
+              const auto& out_name = x->Op()->Output("Out")[0];
+              return in_name == out_name;
+            } else if (op_type == "concat") {
+              return x->Op()->Input("X").size() == 1;
            }
            // you can add more cases here.
            return false;
          });
-  auto useless_op_out = detector.mutable_pattern()->NewNode("useless_op_out");
+
+  auto* useless_op_out =
+      pattern->NewNode(useless_op_out_repr())->assert_is_var();

  useless_op->LinksFrom({useless_op_in}).LinksTo({useless_op_out});
+}

+}  // namespace patterns
+
+void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
+  Init(name_scope_, graph);
+
-  int found_subgraph_count = 0;
+  GraphPatternDetector gpd;
+  patterns::FindUselessOpPattern pattern(gpd.mutable_pattern(), name_scope_);
+
+  int found_count = 0;
  GraphPatternDetector::handle_t handler =
      [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
-        Node* useless_op_var = subgraph.at(useless_op);
-        Node* useless_op_in_var = subgraph.at(useless_op_in);
-        Node* useless_op_out_var = subgraph.at(useless_op_out);
-        const std::string useless_op_in_name = useless_op_in_var->Name();
-        const std::string useless_op_out_name = useless_op_out_var->Name();
-        // Remove links in graph
-        GraphSafeRemoveNodes(graph, {useless_op_in_var, useless_op_var});
-        // Modify pre_op_desc
-        // Link pre_op directly to scale_out
-        for (auto& node : graph->Nodes()) {
-          if (node->IsOp()) {
-            auto* op_desc = node->Op();
-            auto out_vars_map = op_desc->Outputs();
-            for (auto out_var_map : out_vars_map) {
-              auto names = out_var_map.second;
-              bool reset = false;
-              for (size_t i = 0; i < names.size(); i++) {
-                if (names[i] == useless_op_in_name) {
-                  reset = true;
-                  names[i] = useless_op_out_name;
-                  break;
-                }
-              }
-              if (reset) {
-                op_desc->SetOutput(out_var_map.first, names);
-                op_desc->Flush();
-                IR_NODE_LINK_TO(node, useless_op_out_var);
-                break;
-              }
-            }
-          }
-        }
-        found_subgraph_count++;
+        GET_IR_NODE_FROM_SUBGRAPH(useless_op_in, useless_op_in, pattern);
+        GET_IR_NODE_FROM_SUBGRAPH(useless_op, useless_op, pattern);
+        GET_IR_NODE_FROM_SUBGRAPH(useless_op_out, useless_op_out, pattern);
+        CHECK_EQ(useless_op_in->IsVar(), true);
+        CHECK_EQ(useless_op_out->IsVar(), true);
+        CHECK_EQ(useless_op->IsOp(), true);
+
+        for (auto* prev_op : useless_op_in->inputs) {
+          CHECK_EQ(prev_op->IsOp(), true);
+          prev_op->Op()->RenameOutput(useless_op_in->Var()->Name(),
+                                      useless_op_out->Var()->Name());
+          IR_NODE_LINK_TO(prev_op, useless_op_out);
+        }
+
+        GraphSafeRemoveNodes(graph, {useless_op_in, useless_op});
+        found_count++;
      };

-  detector(graph, handler);
-  AddStatis(found_subgraph_count);
+  gpd(graph, handler);
+  AddStatis(found_count);
}

} // namespace ir
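The hunk is cut off before the bottom of the file, so the pass registration that the new test's USE_PASS(identity_op_clean_pass) relies on is not visible here. It presumably keeps the usual form (a sketch of the standard registration macro, not the verbatim tail of the file):

REGISTER_PASS(identity_op_clean_pass,
              paddle::framework::ir::IdentityOpCleanPass);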
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/identity_op_clean_pass.h
@@ -27,7 +27,7 @@ class IdentityOpCleanPass : public FusePassBase {
  void ApplyImpl(ir::Graph* graph) const override;

 private:
-  virtual ~IdentityOpCleanPass() = default;
+  const std::string name_scope_{"identity_op_clean_pass"};
};

} // namespace ir
120 changes: 120 additions & 0 deletions paddle/fluid/framework/ir/identity_op_clean_pass_test.cc
@@ -0,0 +1,120 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

TEST(identity_op_clean_pass, assign) {
ProgramDesc program;
auto* x_var = program.MutableBlock(0)->Var("assign_x");
auto* out_var = program.MutableBlock(0)->Var("assign_out");
out_var->SetName(x_var->Name());
OpDesc* assign_op = program.MutableBlock(0)->AppendOp();
assign_op->SetType("assign");
assign_op->SetInput("X", {x_var->Name()});
assign_op->SetOutput("Out", {out_var->Name()});

std::unique_ptr<Graph> graph(new Graph(program));
auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
graph.reset(pass->Apply(graph.release()));
int assign_num = GetNumOpNodes(graph, "assign");
PADDLE_ENFORCE_EQ(
assign_num,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 assign after identity_op_clean_pass, "
"but actually has %d.",
assign_num));
}

TEST(identity_op_clean_pass, scale) {
ProgramDesc program;
auto* x_var = program.MutableBlock(0)->Var("scale_x");
auto* out_var = program.MutableBlock(0)->Var("scale_out");
OpDesc* scale_op = program.MutableBlock(0)->AppendOp();
scale_op->SetType("scale");
scale_op->SetInput("X", {x_var->Name()});
scale_op->SetOutput("Out", {out_var->Name()});
scale_op->SetAttr("scale", 1.f);
scale_op->SetAttr("bias", 0.f);

std::unique_ptr<Graph> graph(new Graph(program));
auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
graph.reset(pass->Apply(graph.release()));
int scale_num = GetNumOpNodes(graph, "scale");
PADDLE_ENFORCE_EQ(
scale_num,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 scale op after identity_op_clean_pass, "
"but actually has %d.",
scale_num));
}

TEST(identity_op_clean_pass, cast) {
ProgramDesc program;
auto* x_var = program.MutableBlock(0)->Var("cast_x");
auto* out_var = program.MutableBlock(0)->Var("cast_out");
OpDesc* cast_op = program.MutableBlock(0)->AppendOp();
cast_op->SetType("cast");
cast_op->SetInput("X", {x_var->Name()});
cast_op->SetOutput("Out", {out_var->Name()});
cast_op->SetAttr("in_dtype", 5);
cast_op->SetAttr("out_dtype", 5);

std::unique_ptr<Graph> graph(new Graph(program));
auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
graph.reset(pass->Apply(graph.release()));
int cast_num = GetNumOpNodes(graph, "cast");
PADDLE_ENFORCE_EQ(
cast_num,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 cast after identity_op_clean_pass, "
"but actually has %d.",
cast_num));
}

TEST(identity_op_clean_pass, concat) {
ProgramDesc program;
auto* x_var = program.MutableBlock(0)->Var("concat_x");
auto* out_var = program.MutableBlock(0)->Var("concat_out");
OpDesc* concat_op = program.MutableBlock(0)->AppendOp();
concat_op->SetType("concat");
concat_op->SetInput("X", {x_var->Name()});
concat_op->SetOutput("Out", {out_var->Name()});

std::unique_ptr<Graph> graph(new Graph(program));
auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
graph.reset(pass->Apply(graph.release()));
int concat_num = GetNumOpNodes(graph, "concat");
PADDLE_ENFORCE_EQ(
concat_num,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 concat after identity_op_clean_pass, "
"but actually has %d.",
concat_num));
}

} // namespace ir
} // namespace framework
} // namespace paddle

USE_PASS(identity_op_clean_pass);
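All four tests check ops that must disappear; a negative case is straightforward in the same style if wanted. A sketch reusing the same helpers (hypothetical test, not part of the commit; a scale op with scale=2 is not an identity and must survive the pass):

TEST(identity_op_clean_pass, scale_not_identity) {
  ProgramDesc program;
  program.MutableBlock(0)->Var("scale_x");
  program.MutableBlock(0)->Var("scale_out");
  OpDesc* scale_op = program.MutableBlock(0)->AppendOp();
  scale_op->SetType("scale");
  scale_op->SetInput("X", {"scale_x"});
  scale_op->SetOutput("Out", {"scale_out"});
  scale_op->SetAttr("scale", 2.f);  // not an identity, so the op must be kept
  scale_op->SetAttr("bias", 0.f);

  std::unique_ptr<Graph> graph(new Graph(program));
  auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
  graph.reset(pass->Apply(graph.release()));
  EXPECT_EQ(GetNumOpNodes(graph, "scale"), 1);
}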
2 changes: 1 addition & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -221,7 +221,6 @@ const std::vector<std::string> kCINNCompilerPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"map_op_to_another_pass", //
"identity_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"delete_quant_dequant_linear_op_pass", //
@@ -262,6 +261,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass", //
"identity_op_clean_pass", //
"conv2d_fusion_layout_transfer_pass", //
"transfer_layout_elim_pass",
"auto_mixed_precision_pass", //
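Because the commit only moves identity_op_clean_pass later in the GPU pass list rather than turning it off, a deployment that wants to rule the pass out while debugging a model can still drop it at configuration time. A small sketch using the public inference config API (standard AnalysisConfig calls, not part of this diff; the model path is a placeholder):

#include <string>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

void BuildGpuConfig(const std::string& model_dir) {
  paddle::AnalysisConfig config;
  config.SetModel(model_dir);
  config.EnableUseGpu(100 /*initial GPU memory pool, MB*/, 0 /*device id*/);
  // Drop the pass from the GPU strategy, e.g. while bisecting an accuracy issue.
  config.pass_builder()->DeletePass("identity_op_clean_pass");
  auto predictor = paddle::CreatePaddlePredictor(config);
  (void)predictor;
}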