Skip to content

Commit

Permalink
[Relay] Change some passes to mix mode (apache#6695)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored and Trevor Morris committed Oct 28, 2020
1 parent 7a6154d commit f9038b6
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 30 deletions.
8 changes: 6 additions & 2 deletions src/relay/analysis/util.cc
Expand Up @@ -71,7 +71,7 @@ class TypeVarTVisitor : public TypeVisitor {
InsertionSet<TypeVar>* bound_type_vars_;
};

class TypeVarEVisitor : private ExprVisitor {
class TypeVarEVisitor : private MixedModeVisitor {
public:
explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}

Expand Down Expand Up @@ -131,6 +131,8 @@ class TypeVarEVisitor : private ExprVisitor {
return CollectAll();
}

using MixedModeVisitor::VisitExpr_;

void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) {
type_vars_.Insert(tp);
Expand Down Expand Up @@ -159,7 +161,7 @@ class TypeVarEVisitor : private ExprVisitor {
const IRModule& mod_;
};

class VarVisitor : protected ExprVisitor, protected PatternVisitor {
class VarVisitor : protected MixedModeVisitor, protected PatternVisitor {
public:
Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr);
Expand Down Expand Up @@ -204,6 +206,8 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
vars_.Insert(v);
}

using MixedModeVisitor::VisitExpr_;

void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }

void VisitExpr_(const FunctionNode* op) final {
Expand Down
16 changes: 7 additions & 9 deletions src/relay/analysis/well_formed.cc
Expand Up @@ -32,7 +32,7 @@ namespace tvm {
namespace relay {

//! brief make sure each Var is bound at most once in a scope.
class WellFormedChecker : private ExprVisitor, PatternVisitor {
class WellFormedChecker : private MixedModeVisitor, PatternVisitor {
public:
Optional<DiagnosticContext> diag_ctx;
Span occurs_in;
Expand Down Expand Up @@ -79,6 +79,8 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {
total_bound.insert(v);
}

using MixedModeVisitor::VisitExpr_;

void VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
if (current_bound.count(v) == 0) {
Expand Down Expand Up @@ -126,7 +128,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {

// CHECK(call->attrs.defined());
CHECK(call->type_args.defined());
ExprVisitor::VisitExpr_(call);
MixedModeVisitor::VisitExpr_(call);
}

void VisitClause(const Clause& c) final {
Expand All @@ -139,18 +141,14 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {

void VisitVar(const Var& v) final { Bound(v); }

void VisitExpr(const Expr& e) final {
public:
bool CheckWellFormed(const Expr& e) {
if (auto v = e.as<VarNode>()) {
VisitExpr_(v);
} else {
// this->occurs_in = e->span;
ExprVisitor::VisitExpr(e);
VisitExpr(e);
}
}

public:
bool CheckWellFormed(const Expr& e) {
this->VisitExpr(e);
return well_formed;
}
};
Expand Down
4 changes: 3 additions & 1 deletion src/relay/ir/expr_functor.cc
Expand Up @@ -517,10 +517,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr ex
});

// Implement bind.
class ExprBinder : public ExprMutator, PatternMutator {
class ExprBinder : public MixedModeMutator, PatternMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) {}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const LetNode* op) final {
CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
Expand Down
6 changes: 4 additions & 2 deletions src/relay/transforms/de_duplicate.cc
Expand Up @@ -31,7 +31,7 @@ namespace tvm {
namespace relay {

Expr DeDup(const Expr& e) {
class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator {
class DeDupMutator : public TypeMutator, public MixedModeMutator, public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVar(tv->name_hint, tv->kind);
Expand All @@ -47,12 +47,14 @@ Expr DeDup(const Expr& e) {
return ret;
}

Expr VisitExpr(const Expr& e) final {
Expr DispatchVisitExpr(const Expr& e) final {
auto ret = ExprMutator::VisitExpr(e);
ret->checked_type_ = e->checked_type_;
return ret;
}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return rename_.count(v) != 0 ? rename_.at(v) : v;
Expand Down
32 changes: 16 additions & 16 deletions src/relay/transforms/fold_constant.cc
Expand Up @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec

// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
class ConstantFolder : public MixedModeMutator {
public:
explicit ConstantFolder(IRModule module)
: module_(module),
Expand All @@ -89,6 +89,8 @@ class ConstantFolder : public ExprMutator {
cast_op_(Op::Get("cast")),
ndarray_size_op_(Op::Get("ndarray_size")) {}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
Expand Down Expand Up @@ -118,7 +120,7 @@ class ConstantFolder : public ExprMutator {
}
}

Expr VisitExpr_(const CallNode* call) final {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (inside_primitive) {
return GetRef<Expr>(call);
}
Expand All @@ -127,26 +129,25 @@ class ConstantFolder : public ExprMutator {
std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};

auto origin_args = call->args;
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
call = post.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
if (call->args.size() == 0) return post;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
if (op == nullptr) return post;
if (skip_list.count(op->name)) {
return res;
return post;
}
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
if (op_stateful.get(GetRef<Op>(op), false)) return post;
// Try to evaluate shape_of op
if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
return EvaluateShapeOf(res, origin_args, call->attrs);
return EvaluateShapeOf(post, origin_args, call->attrs);
}

if (call->op == ndarray_size_op_) {
return EvaluateNdarraySize(res, origin_args, call->attrs);
return EvaluateNdarraySize(post, origin_args, call->attrs);
}

// We should think about potentially constant evaluation over these ops too.
Expand All @@ -162,19 +163,18 @@ class ConstantFolder : public ExprMutator {
}
}
if (all_const_args) {
return ConstEvaluate(res);
return ConstEvaluate(post);
} else {
return res;
return post;
}
}

Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
op = post.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
return post;
}
}

Expand Down

0 comments on commit f9038b6

Please sign in to comment.