Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#15 from zyfncg/drr_opt
Browse files Browse the repository at this point in the history
[DRR] support std::vector<int> attribute in DRR
  • Loading branch information
yuanlehome committed Aug 18, 2023
2 parents 2a7c68f + 63383f5 commit 3ddb601
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 63 deletions.
10 changes: 4 additions & 6 deletions paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {

for (size_t i = 0; i < drr_output_tensors.size(); ++i) {
if (!Matched) break;

// check child ops
auto drr_child_ops = drr_output_tensors[i]->consumers();
auto ir_output_value = ir_node->result(i);
Expand Down Expand Up @@ -266,9 +265,8 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
const MatchContextImpl& src_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
MatchContextImpl res_match_ctx;
// add input tensors info for res_match_ctx;
const auto& input_tensors = result_pattern_graph.input_tensors();
for (const auto& in_tensor : input_tensors) {
// add input tensors info for res_match_ctx
for (const auto& in_tensor : result_pattern_graph.input_tensors()) {
res_match_ctx.BindIrValue(
in_tensor,
std::make_shared<IrValue>(src_match_ctx.GetIrValue(in_tensor)));
Expand All @@ -277,8 +275,8 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
// topo order visit result_pattern_graph
GraphTopo graph_topo_visit(&result_pattern_graph);
graph_topo_visit.WalkGraphNodesTopoOrder(
[&rewriter, &res_match_ctx](const OpCall& op_call) {
CreateOperation(op_call, rewriter, &res_match_ctx);
[&src_match_ctx, &rewriter, &res_match_ctx](const OpCall& op_call) {
CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
});

return res_match_ctx;
Expand Down
90 changes: 90 additions & 0 deletions paddle/ir/pattern_rewrite/drr/ir_operation_creator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/pattern_rewrite/drr/ir_operation_creator.h"
#include "paddle/fluid/ir/dialect/pd_op.h"

namespace ir {
namespace drr {

Value GetIrValueByDrrTensor(const Tensor& tensor,
const MatchContextImpl& res_match_ctx) {
return res_match_ctx.GetIrValue(tensor.name()).get();
}

std::vector<Value> GetIrValuesByDrrTensors(
const std::vector<const Tensor*>& tensors,
const MatchContextImpl& res_match_ctx) {
std::vector<Value> ir_values;
ir_values.reserve(tensors.size());
for (const auto* tensor : tensors) {
ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx));
}
return ir_values;
}

ir::AttributeMap CreateAttributeMap(const OpCall& op_call,
const MatchContextImpl& src_match_ctx) {
ir::AttributeMap attr_map;
for (const auto& kv : op_call.attributes()) {
attr_map[kv.first] = src_match_ctx.GetIrAttr(kv.second.name());
}
return attr_map;
}

template <typename T>
T GetAttr(const std::string& attr_name,
const OpCall& op_call,
const MatchContextImpl& src_match_ctx) {
return src_match_ctx.Attr<T>(op_call.attributes().at(attr_name).name());
}

Operation* CreateOperation(const OpCall& op_call,
const MatchContextImpl& src_match_ctx,
ir::PatternRewriter& rewriter, // NOLINT
MatchContextImpl* res_match_ctx) {
if (op_call.name() == "pd.reshape") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
GetIrValuesByDrrTensors(inputs, *res_match_ctx);
// TODO(zyfncg): support attr in build op.
Operation* reshape_op = rewriter.Build<paddle::dialect::ReshapeOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
ir_values[1].dyn_cast<ir::OpResult>());
res_match_ctx->BindIrValue(
op_call.outputs()[0]->name(),
std::make_shared<IrValue>(reshape_op->result(0)));
res_match_ctx->BindIrValue(
op_call.outputs()[1]->name(),
std::make_shared<IrValue>(reshape_op->result(1)));
return reshape_op;
} else if (op_call.name() == "pd.transpose") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
GetIrValuesByDrrTensors(inputs, *res_match_ctx);
Operation* transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
GetAttr<std::vector<int>>("perm", op_call, src_match_ctx));
res_match_ctx->BindIrValue(
op_call.outputs()[0]->name(),
std::make_shared<IrValue>(transpose_op->result(0)));
return transpose_op;
}

LOG(ERROR) << "Unknown op " << op_call.name();
return nullptr;
}

} // namespace drr
} // namespace ir
52 changes: 2 additions & 50 deletions paddle/ir/pattern_rewrite/drr/ir_operation_creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,61 +20,13 @@
#include "paddle/ir/pattern_rewrite/drr/match_context_impl.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"

#include "paddle/fluid/ir/dialect/pd_op.h"

namespace ir {
namespace drr {

Value GetIrValueByDrrTensor(const Tensor& tensor,
const MatchContextImpl& res_match_ctx) {
return res_match_ctx.GetIrValue(tensor.name()).get();
}

std::vector<Value> GetIrValuesByDrrTensors(
const std::vector<const Tensor*>& tensors,
const MatchContextImpl& res_match_ctx) {
std::vector<Value> ir_values;
ir_values.reserve(tensors.size());
for (const auto* tensor : tensors) {
ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx));
}
return ir_values;
}

Operation* CreateOperation(const OpCall& op_call,
const MatchContextImpl& src_match_ctx,
ir::PatternRewriter& rewriter, // NOLINT
MatchContextImpl* res_match_ctx) {
if (op_call.name() == "pd.reshape") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
GetIrValuesByDrrTensors(inputs, *res_match_ctx);
// TODO(zyfncg): support attr in build op.
Operation* reshape_op = rewriter.Build<paddle::dialect::ReshapeOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
std::vector<int64_t>{16, 3, 4, 16});
res_match_ctx->BindIrValue(
op_call.outputs()[0]->name(),
std::make_shared<IrValue>(reshape_op->result(0)));
res_match_ctx->BindIrValue(
op_call.outputs()[1]->name(),
std::make_shared<IrValue>(reshape_op->result(1)));
return reshape_op;
}
else if(op_call.name() == "pd.transpose") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values = GetIrValuesByDrrTensors(inputs, *res_match_ctx);
Operation* transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
std::vector<int>{0, 2, 1, 3});
res_match_ctx->BindIrValue(
op_call.outputs()[0]->name(),
std::make_shared<IrValue>(transpose_op->result(0)));
return transpose_op;
}

LOG(ERROR) << "Unknown op " << op_call.name();
return nullptr;
}
MatchContextImpl* res_match_ctx);

// template <typename Op>
// class CreateOperation {
Expand Down
31 changes: 25 additions & 6 deletions paddle/ir/pattern_rewrite/drr/match_context_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, Int32Attribute);
PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute);
PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute);

template <typename T>
struct IrAttrTypeCast {
static T To(const ir::Attribute& attr) {
return attr.dyn_cast<typename CppTypeToIrAttribute<T>::type>().data();
}
};

template <>
struct IrAttrTypeCast<std::vector<int>> {
static std::vector<int> To(const ir::Attribute& attr) {
std::vector<int> result;
for (size_t i = 0; i < attr.dyn_cast<ir::ArrayAttribute>().size(); i++) {
result.push_back(attr.dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>()
.data());
}
return result;
}
};

class MatchContextImpl final {
public:
MatchContextImpl() = default;
Expand All @@ -54,18 +75,16 @@ class MatchContextImpl final {
}

template <typename T>
T Attr(const std::string& attr_name) const {
return attr_map_.at(attr_name)
.dyn_cast<typename CppTypeToIrAttribute<T>::type>()
.data();
T Attr(const std::string& attr_id) const {
return IrAttrTypeCast<T>::To(attr_map_.at(attr_id));
}

const IrValue& GetIrValue(const std::string& tensor_name) const {
return *tensor_map_.at(tensor_name);
}

ir::Attribute GetIrAttr(const std::string& tensor_name) const {
return attr_map_.at(tensor_name);
ir::Attribute GetIrAttr(const std::string& attr_id) const {
return attr_map_.at(attr_id);
}

const std::unordered_map<const OpCall*, std::shared_ptr<IrOperation>>&
Expand Down
3 changes: 2 additions & 1 deletion test/cpp/ir/pattern_rewrite/drr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <memory>

#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h"
Expand Down Expand Up @@ -141,7 +142,7 @@ void BuildProgram(ir::Builder &builder) { // NOLINT

paddle::dialect::TransposeOp transpose_op2 =
builder.Build<paddle::dialect::TransposeOp>(transpose_op1.out(),
std::vector<int>{0, 1, 2, 3});
std::vector<int>{1, 0, 2, 3});

paddle::dialect::ReluOp relu_op_second =
builder.Build<paddle::dialect::ReluOp>(transpose_op2.out());
Expand Down

0 comments on commit 3ddb601

Please sign in to comment.