diff --git a/paddle/fluid/ir/drr/api/drr_pattern_context.h b/paddle/fluid/ir/drr/api/drr_pattern_context.h index 45be4b3963a7b..85be176fcf478 100644 --- a/paddle/fluid/ir/drr/api/drr_pattern_context.h +++ b/paddle/fluid/ir/drr/api/drr_pattern_context.h @@ -14,11 +14,13 @@ #pragma once +#include #include #include #include #include #include +#include #include "paddle/fluid/ir/drr/api/match_context.h" @@ -34,9 +36,9 @@ class PatternGraph; class SourcePatternGraph; class ResultPatternGraph; -class Attribute { +class NormalAttribute { public: - explicit Attribute(const std::string& name) : attr_name_(name) {} + explicit NormalAttribute(const std::string& name) : attr_name_(name) {} const std::string& name() const { return attr_name_; } @@ -44,6 +46,23 @@ class Attribute { std::string attr_name_; }; +using AttrComputeFunc = std::function; + +class ComputeAttribute { + public: + explicit ComputeAttribute(const AttrComputeFunc& attr_compute_func) + : attr_compute_func_(attr_compute_func) {} + + const AttrComputeFunc& attr_compute_func() const { + return attr_compute_func_; + } + + private: + AttrComputeFunc attr_compute_func_; +}; + +using Attribute = std::variant; + class TensorShape { public: explicit TensorShape(const std::string& tensor_name) @@ -245,7 +264,12 @@ class ResultPattern { return ctx_->ResultTensorPattern(name); } - Attribute Attr(const std::string& attr_name) { return Attribute(attr_name); } + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } + Attribute Attr(const AttrComputeFunc& attr_compute_func) const { + return ComputeAttribute(attr_compute_func); + } private: friend class SourcePattern; @@ -269,7 +293,9 @@ class SourcePattern { return ctx_->SourceTensorPattern(name); } - Attribute Attr(const std::string& attr_name) { return Attribute(attr_name); } + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } void RequireEqual(const TensorShape& first, const TensorShape& second) { ctx_->RequireEqual(first, second); diff --git a/paddle/fluid/ir/drr/ir_operation_creator.cc b/paddle/fluid/ir/drr/ir_operation_creator.cc index 729a47cc3a691..15a9bb2df5083 100644 --- a/paddle/fluid/ir/drr/ir_operation_creator.cc +++ b/paddle/fluid/ir/drr/ir_operation_creator.cc @@ -39,7 +39,14 @@ ir::AttributeMap CreateAttributeMap(const OpCall& op_call, const MatchContextImpl& src_match_ctx) { ir::AttributeMap attr_map; for (const auto& kv : op_call.attributes()) { - attr_map[kv.first] = src_match_ctx.GetIrAttr(kv.second.name()); + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + attr_map[kv.first] = src_match_ctx.GetIrAttr(arg.name()); + } + }, + kv.second); } return attr_map; } @@ -48,7 +55,16 @@ template T GetAttr(const std::string& attr_name, const OpCall& op_call, const MatchContextImpl& src_match_ctx) { - return src_match_ctx.Attr(op_call.attributes().at(attr_name).name()); + const auto& attr = op_call.attributes().at(attr_name); + if (std::holds_alternative(attr)) { + return src_match_ctx.Attr(std::get(attr).name()); + } else if (std::holds_alternative(attr)) { + MatchContext ctx(std::make_shared(src_match_ctx)); + return std::any_cast( + std::get(attr).attr_compute_func()(ctx)); + } else { + IR_THROW("Unknown attrbute type for : %s.", attr_name); + } } Operation* CreateOperation(const OpCall& op_call, diff --git a/paddle/fluid/ir/drr/match_context_impl.h b/paddle/fluid/ir/drr/match_context_impl.h index 6a184e6d45527..bc2ccae99e9f2 100644 --- a/paddle/fluid/ir/drr/match_context_impl.h +++ b/paddle/fluid/ir/drr/match_context_impl.h @@ -107,15 +107,22 @@ class MatchContextImpl final { operation_map_.emplace(op_call, op); const auto& attrs = op_call->attributes(); for (const auto& kv : attrs) { - BindIrAttr(kv.second.name(), op->get()->attribute(kv.first)); + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + BindIrAttr(arg.name(), op->get()->attribute(kv.first)); + } + }, + kv.second); } } + private: void BindIrAttr(const std::string& attr_name, ir::Attribute attr) { attr_map_.emplace(attr_name, attr); } - private: std::unordered_map> tensor_map_; std::unordered_map> operation_map_; diff --git a/test/cpp/ir/pattern_rewrite/drr_test.cc b/test/cpp/ir/pattern_rewrite/drr_test.cc index b3803f93df0a2..19e05b9378d92 100644 --- a/test/cpp/ir/pattern_rewrite/drr_test.cc +++ b/test/cpp/ir/pattern_rewrite/drr_test.cc @@ -88,9 +88,18 @@ class RemoveRedundentTransposePattern // Result patterns: 要替换的子图 ir::drr::ResultPattern res = pat.ResultPattern(); - // todo 先简单用perm2替换 + const auto &new_perm_attr = + res.Attr([](const ir::drr::MatchContext &match_ctx) -> std::any { + const auto &perm1 = match_ctx.Attr>("perm_1"); + const auto &perm2 = match_ctx.Attr>("perm_2"); + std::vector new_perm; + for (int v : perm2) { + new_perm.emplace_back(perm1[v]); + } + return new_perm; + }); const auto &tranpose_continuous = - res.Op("pd.transpose", {{"perm", pat.Attr("perm_2")}}); + res.Op("pd.transpose", {{"perm", new_perm_attr}}); res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); }