Skip to content

Commit

Permalink
use ir::get_type_name
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome committed Aug 22, 2023
1 parent e92436e commit 16f066a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
10 changes: 5 additions & 5 deletions paddle/fluid/ir/drr/api/drr_pattern_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,27 @@

#pragma once

#include <typeinfo>

#include "paddle/fluid/ir/drr/api/drr_pattern_context.h"
#include "paddle/fluid/ir/drr/drr_rewrite_pattern.h"
#include "paddle/ir/core/type_name.h"

namespace ir {
namespace drr {

template <typename DrrPattern>
class DrrPatternBase {
public:
virtual ~DrrPatternBase() = default;

// Define the Drr Pattern.
virtual void operator()(ir::drr::DrrPatternContext* ctx) const = 0;

std::unique_ptr<DrrRewritePattern> Build(
std::unique_ptr<DrrRewritePattern<DrrPattern>> Build(
ir::IrContext* ir_context, ir::PatternBenefit benefit = 1) const {
DrrPatternContext drr_context;
this->operator()(&drr_context);
return std::make_unique<DrrRewritePattern>(
typeid(*this).name(), drr_context, ir_context, benefit);
return std::make_unique<DrrRewritePattern<DrrPattern>>(
drr_context, ir_context, benefit);
}
};

Expand Down
8 changes: 3 additions & 5 deletions paddle/fluid/ir/drr/drr_rewrite_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,17 @@
namespace ir {
namespace drr {

template <typename DrrPattern>
class DrrRewritePattern : public ir::RewritePattern {
public:
explicit DrrRewritePattern(const std::string& drr_pattern_name,
const DrrPatternContext& drr_context,
explicit DrrRewritePattern(const DrrPatternContext& drr_context,
ir::IrContext* context,
ir::PatternBenefit benefit = 1)
: ir::RewritePattern(
drr_context.source_pattern_graph()->AnchorNode()->name(),
benefit,
context,
{}),
drr_pattern_name_(drr_pattern_name),
source_pattern_graph_(drr_context.source_pattern_graph()),
constraints_(drr_context.constraints()),
result_pattern_graph_(drr_context.result_pattern_graph()) {
Expand All @@ -56,7 +55,7 @@ class DrrRewritePattern : public ir::RewritePattern {
std::shared_ptr<MatchContextImpl> src_match_ctx =
std::make_shared<MatchContextImpl>();
if (PatternGraphMatch(op, src_match_ctx)) {
VLOG(6) << "DRR pattern (" << drr_pattern_name_
VLOG(6) << "DRR pattern (" << ir::get_type_name<DrrPattern>()
<< ") is matched in program.";
PatternGraphRewrite(*src_match_ctx, rewriter);
return true;
Expand Down Expand Up @@ -331,7 +330,6 @@ class DrrRewritePattern : public ir::RewritePattern {
});
}

const std::string drr_pattern_name_;
const std::shared_ptr<SourcePatternGraph> source_pattern_graph_;
const std::vector<Constraint> constraints_;
const std::shared_ptr<ResultPatternGraph> result_pattern_graph_;
Expand Down
15 changes: 10 additions & 5 deletions test/cpp/ir/pattern_rewrite/drr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/ir/transforms/dead_code_elimination_pass.h"

class RemoveRedundentReshapePattern : public ir::drr::DrrPatternBase {
class RemoveRedundentReshapePattern
: public ir::drr::DrrPatternBase<RemoveRedundentReshapePattern> {
public:
void operator()(ir::drr::DrrPatternContext *ctx) const override {
// Source patterns:待匹配的子图
Expand All @@ -44,7 +45,8 @@ class RemoveRedundentReshapePattern : public ir::drr::DrrPatternBase {
}
};

class FoldBroadcastToConstantPattern : public ir::drr::DrrPatternBase {
class FoldBroadcastToConstantPattern
: public ir::drr::DrrPatternBase<FoldBroadcastToConstantPattern> {
public:
void operator()(ir::drr::DrrPatternContext *ctx) const override {
ir::drr::SourcePattern pat = ctx->SourcePattern();
Expand All @@ -71,7 +73,8 @@ class FoldBroadcastToConstantPattern : public ir::drr::DrrPatternBase {
}
};

class RemoveRedundentTransposePattern : public ir::drr::DrrPatternBase {
class RemoveRedundentTransposePattern
: public ir::drr::DrrPatternBase<RemoveRedundentTransposePattern> {
public:
void operator()(ir::drr::DrrPatternContext *ctx) const override {
// Source pattern: 待匹配的子图
Expand All @@ -93,7 +96,8 @@ class RemoveRedundentTransposePattern : public ir::drr::DrrPatternBase {
}
};

class RemoveRedundentCastPattern : public ir::drr::DrrPatternBase {
class RemoveRedundentCastPattern
: public ir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
void operator()(ir::drr::DrrPatternContext *ctx) const override {
auto pat = ctx->SourcePattern();
pat.Tensor("tmp") =
Expand All @@ -106,7 +110,8 @@ class RemoveRedundentCastPattern : public ir::drr::DrrPatternBase {
}
};

class RemoveUselessCastPattern : public ir::drr::DrrPatternBase {
class RemoveUselessCastPattern
: public ir::drr::DrrPatternBase<RemoveUselessCastPattern> {
public:
void operator()(ir::drr::DrrPatternContext *ctx) const override {
auto pat = ctx->SourcePattern();
Expand Down

0 comments on commit 16f066a

Please sign in to comment.