From f9038b6ffb03c90fa272cbca487d384ec89eb9e3 Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Fri, 16 Oct 2020 23:47:27 +0800 Subject: [PATCH] [Relay] Change some passes to mix mode (#6695) --- src/relay/analysis/util.cc | 8 +++++-- src/relay/analysis/well_formed.cc | 16 ++++++-------- src/relay/ir/expr_functor.cc | 4 +++- src/relay/transforms/de_duplicate.cc | 6 +++-- src/relay/transforms/fold_constant.cc | 32 +++++++++++++-------------- 5 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 59ce01ce4227a..edf8fb644c576 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -71,7 +71,7 @@ class TypeVarTVisitor : public TypeVisitor { InsertionSet* bound_type_vars_; }; -class TypeVarEVisitor : private ExprVisitor { +class TypeVarEVisitor : private MixedModeVisitor { public: explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {} @@ -131,6 +131,8 @@ class TypeVarEVisitor : private ExprVisitor { return CollectAll(); } + using MixedModeVisitor::VisitExpr_; + void VisitExpr_(const FunctionNode* f) final { for (const auto& tp : f->type_params) { type_vars_.Insert(tp); @@ -159,7 +161,7 @@ class TypeVarEVisitor : private ExprVisitor { const IRModule& mod_; }; -class VarVisitor : protected ExprVisitor, protected PatternVisitor { +class VarVisitor : protected MixedModeVisitor, protected PatternVisitor { public: Array Free(const Expr& expr) { this->VisitExpr(expr); @@ -204,6 +206,8 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { vars_.Insert(v); } + using MixedModeVisitor::VisitExpr_; + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index 3e409d10b8855..5abbbc94fb364 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -32,7 +32,7 @@ namespace tvm { namespace relay { //! brief make sure each Var is bound at most once in a scope. -class WellFormedChecker : private ExprVisitor, PatternVisitor { +class WellFormedChecker : private MixedModeVisitor, PatternVisitor { public: Optional diag_ctx; Span occurs_in; @@ -79,6 +79,8 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { total_bound.insert(v); } + using MixedModeVisitor::VisitExpr_; + void VisitExpr_(const VarNode* op) final { Var v = GetRef(op); if (current_bound.count(v) == 0) { @@ -126,7 +128,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { // CHECK(call->attrs.defined()); CHECK(call->type_args.defined()); - ExprVisitor::VisitExpr_(call); + MixedModeVisitor::VisitExpr_(call); } void VisitClause(const Clause& c) final { @@ -139,18 +141,14 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { void VisitVar(const Var& v) final { Bound(v); } - void VisitExpr(const Expr& e) final { + public: + bool CheckWellFormed(const Expr& e) { if (auto v = e.as()) { VisitExpr_(v); } else { // this->occurs_in = e->span; - ExprVisitor::VisitExpr(e); + VisitExpr(e); } - } - - public: - bool CheckWellFormed(const Expr& e) { - this->VisitExpr(e); return well_formed; } }; diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index cbc41d225d4b5..a09179bcc5854 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -517,10 +517,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr ex }); // Implement bind. -class ExprBinder : public ExprMutator, PatternMutator { +class ExprBinder : public MixedModeMutator, PatternMutator { public: explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} + using MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const LetNode* op) final { CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let"; return ExprMutator::VisitExpr_(op); diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index d90e5c584df36..8c62fe6100c37 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -31,7 +31,7 @@ namespace tvm { namespace relay { Expr DeDup(const Expr& e) { - class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator { + class DeDupMutator : public TypeMutator, public MixedModeMutator, public PatternMutator { public: TypeVar Fresh(const TypeVar& tv) { TypeVar ret = TypeVar(tv->name_hint, tv->kind); @@ -47,12 +47,14 @@ Expr DeDup(const Expr& e) { return ret; } - Expr VisitExpr(const Expr& e) final { + Expr DispatchVisitExpr(const Expr& e) final { auto ret = ExprMutator::VisitExpr(e); ret->checked_type_ = e->checked_type_; return ret; } + using MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const VarNode* op) final { Var v = GetRef(op); return rename_.count(v) != 0 ? rename_.at(v) : v; diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 660aff2eed9a6..8d2cba05be492 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec // TODO(tvm-team) consider combine dead-code with constant folder. // or make a more powerful partial evaluator. -class ConstantFolder : public ExprMutator { +class ConstantFolder : public MixedModeMutator { public: explicit ConstantFolder(IRModule module) : module_(module), @@ -89,6 +89,8 @@ class ConstantFolder : public ExprMutator { cast_op_(Op::Get("cast")), ndarray_size_op_(Op::Get("ndarray_size")) {} + using MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const LetNode* op) final { Expr value = this->Mutate(op->value); if (value.as()) { @@ -118,7 +120,7 @@ class ConstantFolder : public ExprMutator { } } - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { if (inside_primitive) { return GetRef(call); } @@ -127,26 +129,25 @@ class ConstantFolder : public ExprMutator { std::unordered_set skip_list{"zeros_like", "ones_like", "full_like", "full"}; auto origin_args = call->args; - Expr res = ExprMutator::VisitExpr_(call); - call = res.as(); + call = post.as(); // We don't constant fold function with zero arguments. // This is a heuristic that is useful. // For example it is harmful to fold ones(shape=(4, 5)). - if (call->args.size() == 0) return res; + if (call->args.size() == 0) return post; const OpNode* op = call->op.as(); - if (op == nullptr) return res; + if (op == nullptr) return post; if (skip_list.count(op->name)) { - return res; + return post; } // skip stateful ops. - if (op_stateful.get(GetRef(op), false)) return res; + if (op_stateful.get(GetRef(op), false)) return post; // Try to evaluate shape_of op if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) { - return EvaluateShapeOf(res, origin_args, call->attrs); + return EvaluateShapeOf(post, origin_args, call->attrs); } if (call->op == ndarray_size_op_) { - return EvaluateNdarraySize(res, origin_args, call->attrs); + return EvaluateNdarraySize(post, origin_args, call->attrs); } // We should think about potentially constant evaluation over these ops too. @@ -162,19 +163,18 @@ class ConstantFolder : public ExprMutator { } } if (all_const_args) { - return ConstEvaluate(res); + return ConstEvaluate(post); } else { - return res; + return post; } } - Expr VisitExpr_(const TupleGetItemNode* op) final { - Expr res = ExprMutator::VisitExpr_(op); - op = res.as(); + Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { + op = post.as(); if (const auto* tuple = op->tuple.as()) { return tuple->fields[op->index]; } else { - return res; + return post; } }