# Adding a Compiler Pass to Relax

As with Relay, the primary means of implementing optimizations and analyses in Relax is compiler passes. This tutorial gives an overview of the different ways to traverse and transform the Relax AST. It then uses an example to introduce writing Relax passes that analyze and transform various components of the Relax AST.

## AST Traversal

In a traditional compilation process, a source program is represented as text and then parsed into a tree of grammar constructs called an [abstract syntax tree (AST)](https://en.wikipedia.org/wiki/Abstract_syntax_tree), which can be processed more easily than the source text. (In Relax, we use a decorator to parse Python code and produce a Relax AST, but the principle is the same.) Compiler passes operate on the AST by traversing the tree of grammar constructs (hence we say that they make a *pass* over the AST) and either:

1. modify the AST (for example, by inlining functions or unrolling loops), thus implementing a *program transformation*, or
2. collect information about the AST **without** modifying it (for example, the number of loops in the program), thus implementing a *program analysis*.

Analyses can be helpful for implementing program transformations by checking that assumptions about the input program have been met. This tutorial will outline the steps for implementing analyses and optimizations in Relax and, in the process, explain the mechanisms for doing so.

### Traversing the AST: Functors

Abstract syntax trees (ASTs) are, by their nature, recursive: expressions in a program contain subexpressions, which may in turn contain further subexpressions. Hence, AST traversals must also be recursive, and they should allow customizing both the order of the traversal and the handling of each grammar construct.

In compilers implemented in object-oriented programming languages, this is generally accomplished using the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern), in which the visitor (the parent class) examines the root of the AST in its `visit` method and passes it to a method that handles that particular grammar construct. For example, if the root of the AST is an `IfNode`, the visitor dispatches to the `visit_if` method. The individual visitor methods can perform whatever operations are necessary and, crucially, can also call the parent class's `visit` method to process any subexpressions recursively without worrying about how to dispatch them, since the recursive call to `visit` handles that.

For Relax’s AST, the most general form of the visitor pattern is called an `ExprFunctor`, which is a base class for any compiler pass. Implementations are provided in both [C++](https://github.com/tlc-pack/relax/blob/6495823859cdcbe87899b31de1b82a824e08c4f3/include/tvm/relax/expr_functor.h#L68-L139) and [Python](https://github.com/tlc-pack/relax/blob/6495823859cdcbe87899b31de1b82a824e08c4f3/python/tvm/relax/expr_functor.py#L36-L119). In the C++ implementation, `VisitExpr` is the entry-point method (taking any `Expr` node and dispatching accordingly) and there are overrides of `VisitExpr_` for each grammar construct to handle the different cases. Since Python does not support overloading by argument types, the Python version uses `visit_expr` as the entry point and `visit_{name}_` methods for each grammar construct. (Note: One difference between the versions is that the C++ version also permits auxiliary arguments to be passed in addition to the input AST, though this can be accomplished in Python by wrapping the visitor methods.)
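To make the dispatch scheme concrete, here is a minimal, self-contained Python sketch of a functor in the style described above. It does not use TVM: the node classes (`Constant`, `Var`, `If`), `ToyExprFunctor`, and `ToyPrinter` are hypothetical stand-ins, and the real Python `ExprFunctor` may dispatch differently. It only illustrates how `visit_expr` can look up a `visit_{name}_` method by node type and how, as in `ExprFunctor`, no case has a default.

```python
from dataclasses import dataclass


# Hypothetical toy AST nodes standing in for Relax's Constant, Var, and If.
@dataclass(frozen=True)
class Constant:
    value: int


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class If:
    cond: object
    true_branch: object
    false_branch: object


class ToyExprFunctor:
    """Entry point dispatches by node type to visit_{name}_, e.g. If -> visit_if_."""

    def visit_expr(self, expr):
        method = getattr(self, f"visit_{type(expr).__name__.lower()}_", None)
        if method is None:
            # As in ExprFunctor, no case has a default implementation.
            raise NotImplementedError(f"no case for {type(expr).__name__}")
        return method(expr)


class ToyPrinter(ToyExprFunctor):
    """Hypothetical pass that renders an expression as text."""

    def visit_constant_(self, const):
        return str(const.value)

    def visit_var_(self, var):
        return var.name

    def visit_if_(self, if_expr):
        # Recursive calls go through visit_expr, which takes care of dispatch.
        cond = self.visit_expr(if_expr.cond)
        then_ = self.visit_expr(if_expr.true_branch)
        else_ = self.visit_expr(if_expr.false_branch)
        return f"if {cond} then {then_} else {else_}"
```

For instance, `ToyPrinter().visit_expr(If(Var("c"), Constant(1), Constant(0)))` returns `"if c then 1 else 0"`, while calling `visit_expr` on the bare `ToyExprFunctor` raises `NotImplementedError` because none of its cases are filled in.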

Note that the entry-point method can be overridden to include logic that runs before every expression node is dispatched. This is a commonly used pattern in both the Relay and Relax codebases; a C++ example is given below. (The same can be accomplished in Python by overriding `visit_expr` and calling the base implementation via `super`.)

```cpp
void PrintVisitor::VisitExpr(const Expr& expr) {
  // fires for every expression type
  std::cout << "Here" << std::endl;
  // uses the base class's unmodified entry-point method to dispatch by node type
  ExprFunctor::VisitExpr(expr);
}
```
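In Python, the same effect can be sketched by overriding `visit_expr` in a subclass and delegating to `super().visit_expr`. The toy classes below (`Constant`, `Tuple`, `ToyVisitor`, `PrintVisitor`) are hypothetical and independent of TVM; the `log` list stands in for the `std::cout` output so the behavior is easy to check.

```python
from dataclasses import dataclass


# Hypothetical toy AST nodes; not Relax classes.
@dataclass(frozen=True)
class Constant:
    value: int


@dataclass(frozen=True)
class Tuple:
    fields: tuple


class ToyVisitor:
    """Toy base visitor: visit_expr dispatches to visit_{name}_ by node type."""

    def visit_expr(self, expr):
        return getattr(self, f"visit_{type(expr).__name__.lower()}_")(expr)

    def visit_constant_(self, const):
        pass

    def visit_tuple_(self, tup):
        for field in tup.fields:
            self.visit_expr(field)


class PrintVisitor(ToyVisitor):
    def __init__(self):
        self.log = []

    def visit_expr(self, expr):
        # Fires for every expression type (recorded instead of printed).
        self.log.append(f"Here: {type(expr).__name__}")
        # Use the base class's unmodified entry point to dispatch by node type.
        return super().visit_expr(expr)
```

Because `visit_tuple_` recurses through `self.visit_expr`, the override fires for the tuple and for each of its fields.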

While `ExprFunctor`s are general enough to implement any compiler pass, they leave all of the individual visitor cases undefined, requiring the user to fill them in. A typical pass is affected by only a few node types, so implementing it with a functor would require filling in cases for many node types that need no special processing. For common cases, we provide child classes of `ExprFunctor` that have sensible defaults and thus require overriding fewer cases. We discuss them below.



### Implementing Analyses: Visitors

An `ExprVisitor` is an `ExprFunctor` that traverses an AST and, by default, does not modify it. It is therefore used as the parent class for writing analysis passes. The base `ExprVisitor` includes a default implementation for each case, namely visiting each subexpression of the current AST node in order. Again, there are implementations in both [C++](https://github.com/tlc-pack/relax/blob/6495823859cdcbe87899b31de1b82a824e08c4f3/src/relax/ir/expr_functor.cc#L36-L197) and [Python](https://github.com/tlc-pack/relax/blob/6495823859cdcbe87899b31de1b82a824e08c4f3/python/tvm/relax/expr_functor.py#L122-L261).

For example, here is the visitor case for `IfNode`:

```cpp
void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitSpan(op->span);
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}
```

(`Span`s are used for storing metadata for error messages; they are not AST nodes *per se*.)

Even though the individual dispatch methods in `ExprVisitor` do not modify the AST, `ExprVisitor` is nevertheless very useful for implementing analyses. This can be accomplished by keeping some state in the `ExprVisitor` that tracks whatever value the analysis is meant to return, modifying that state in the individual cases, and finally returning the stored state. The basic workflow can be summarized as follows:

```python
# Placeholders (not real APIs): ReturnType, get_state, update_state,
# some_condition_is_met, and another_condition_is_met stand for your own logic.
from tvm.relax import Expr, ExprVisitor, If, Tuple


# wrapper over the visitor logic;
# it may also make sense to implement this as a method on the visitor,
# especially if there may be a need to call it recursively
def perform_analysis(program: Expr) -> ReturnType:
    visitor = MyProgramAnalysis()
    visitor.visit_expr(program)
    return visitor.get_state()


class MyProgramAnalysis(ExprVisitor):
    # initialization logic, etc.

    def visit_tuple_(self, tuple_expr: Tuple) -> None:
        for expr in tuple_expr.fields:
            if self.some_condition_is_met(expr):
                self.update_state(expr)
                # you can decide whether or not to visit the subexpressions
                self.visit_expr(expr)

    # etc., define cases as appropriate

    def visit_if_(self, if_expr: If) -> None:
        if self.another_condition_is_met(if_expr):
            self.update_state(if_expr)
        # use the default visitor implementation to visit the subexpressions,
        # if that is the desired behavior
        super().visit_if_(if_expr)
```
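The workflow above can be made fully runnable on a toy AST. In this sketch, which does not use TVM, `ToyExprVisitor` plays the role of `ExprVisitor` (its default cases visit every subexpression in order), `CallCounter` is a hypothetical analysis that records how often each operator is called, and `count_calls` is the wrapper that creates the visitor, runs it, and returns its state.

```python
from dataclasses import dataclass


# Hypothetical toy AST; Relax's real nodes live in tvm.relax.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


@dataclass(frozen=True)
class If:
    cond: object
    true_branch: object
    false_branch: object


class ToyExprVisitor:
    """Default cases visit every subexpression in order and change nothing."""

    def visit_expr(self, expr):
        getattr(self, f"visit_{type(expr).__name__.lower()}_")(expr)

    def visit_var_(self, var):
        pass

    def visit_call_(self, call):
        for arg in call.args:
            self.visit_expr(arg)

    def visit_if_(self, if_expr):
        self.visit_expr(if_expr.cond)
        self.visit_expr(if_expr.true_branch)
        self.visit_expr(if_expr.false_branch)


class CallCounter(ToyExprVisitor):
    """Analysis state: how many times each operator is called."""

    def __init__(self):
        self.counts = {}

    def visit_call_(self, call):
        self.counts[call.op] = self.counts.get(call.op, 0) + 1
        # Keep the default traversal so nested calls are counted too.
        super().visit_call_(call)


def count_calls(program) -> dict:
    # Wrapper: create the visitor, run it, and return the accumulated state.
    visitor = CallCounter()
    visitor.visit_expr(program)
    return visitor.counts
```

Only `visit_call_` is overridden; the `If` and `Var` cases fall back to the defaults, which is exactly the convenience `ExprVisitor` provides over a bare functor.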

For a real example, we may consider Relax’s [well-formedness analysis](https://github.com/tlc-pack/relax/blob/relax/src/relax/analysis/well_formed.cc), which checks that the input program is in [A-normal form](https://en.wikipedia.org/wiki/A-normal_form) (further discussed below), that all fields of a shape expression are integer-typed, and that all variables have exactly one definition (so there are no variables used that aren’t defined and no variable defined more than once).

The well-formedness checker keeps a Boolean field called `well_formed` that is true if the program meets the criteria for well-formedness and false otherwise. The visitor is invoked from a wrapper function called `WellFormed`, which initializes the visitor, first registers the `GlobalVar` of every function in the `IRModule` (so that functions can refer to one another regardless of order), then applies the visitor to every Relax function and returns the final value of `well_formed`:

```cpp
bool WellFormed(const IRModule& m, Optional<DiagnosticContext> diag_ctx) {
  WellFormedChecker well_formed_checker = WellFormedChecker(diag_ctx);
  for (const auto& it : m->functions) {
    // register GlobalVar in the IRModule first
    well_formed_checker.RegisterGlobalVar(it.first);
  }

  for (const auto& it : m->functions) {
    // visit relax.Function
    if (auto* n = it.second.as<FunctionNode>()) {
      Function func = GetRef<Function>(n);
      well_formed_checker.VisitExpr(func);
    }
  }

  return well_formed_checker.well_formed;
}
```

Here is the well-formedness checker’s logic for handling bound variables (e.g., function parameters or variables in assignments):

```cpp
void VisitVarDef_(const VarNode* var) {
    Var gv = GetRef<Var>(var);
    if (var_set_.count(gv) == 1) {
      Malformed(Diagnostic::Error(var->span)
                << "Var " << gv->name_hint() << " is defined more than once.");
    }
    // register Var
    var_set_.insert(gv);
}
```

First, it checks that the variable being visited does not already have a definition associated with it (giving an error if it does). Next, it adds the variable to its set of variables with definitions, which it can use to validate any future variable definitions encountered during the AST traversal.

Meanwhile, the case for `IfNode` visits its subexpressions recursively and must also save and restore some of the checker’s persistent state:

```cpp
void VisitExpr_(const IfNode* op) {
    this->VisitExpr(op->cond);
    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set_ = var_set_;
    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set_ =
        prim_expr_visitor_.symbolic_var_set_;
    this->VisitBody(op->true_branch);
    var_set_ = previous_var_set_;
    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
    this->VisitBody(op->false_branch);
    var_set_ = previous_var_set_;
    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
}
```

Before visiting the true and false branches of the `IfNode`, the visitor saves the current sets of defined variables (both Relax variables and symbolic TIR variables) and restores them after visiting each branch. This is necessary because the branches are [lexically scoped](https://en.wikipedia.org/wiki/Scope_(computer_science)#Lexical_scope_vs._dynamic_scope): definitions inside a branch are not visible outside it. This is an example of how the logic in a compiler pass may need to account for the semantics of Relax.
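The following toy Python sketch combines the two checks just described: rejecting duplicate definitions and restoring the set of defined variables around each branch of an `If`. It is not TVM’s checker; the node classes (`Var`, `Let`, `If`) and `ToyWellFormedChecker` are hypothetical, and it also reports uses of undefined variables, which the criteria listed earlier include. Using `eq=False` makes variables hash by identity, loosely mirroring the `ObjectPtrHash`/`ObjectPtrEqual` sets in the C++ code.

```python
from dataclasses import dataclass


# Hypothetical toy nodes; eq=False gives identity-based hashing and equality.
@dataclass(eq=False)
class Var:
    name: str


@dataclass(eq=False)
class Let:  # "var = value; body"
    var: Var
    value: object
    body: object


@dataclass(eq=False)
class If:
    cond: object
    true_branch: object
    false_branch: object


class ToyWellFormedChecker:
    def __init__(self):
        self.well_formed = True
        self.var_set = set()
        self.errors = []

    def visit(self, expr):
        if isinstance(expr, Var):
            if expr not in self.var_set:
                self.malformed(f"Var {expr.name} is used before it is defined.")
        elif isinstance(expr, Let):
            self.visit(expr.value)
            self.visit_var_def(expr.var)
            self.visit(expr.body)
        elif isinstance(expr, If):
            self.visit(expr.cond)
            previous = set(self.var_set)  # snapshot before entering a branch
            self.visit(expr.true_branch)
            # Definitions made inside a branch are not visible afterwards.
            self.var_set = set(previous)
            self.visit(expr.false_branch)
            self.var_set = previous

    def visit_var_def(self, var):
        if var in self.var_set:
            self.malformed(f"Var {var.name} is defined more than once.")
        self.var_set.add(var)

    def malformed(self, msg):
        self.well_formed = False
        self.errors.append(msg)


def well_formed(expr, params=()):
    # Hypothetical wrapper: parameters count as definitions.
    checker = ToyWellFormedChecker()
    for param in params:
        checker.visit_var_def(param)
    checker.visit(expr)
    return checker.well_formed
```

With `x` as a parameter, `if x then (a = x; a) else x` is well formed, but using `a` in the false branch is not, because `a` is only defined inside the true branch.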

Many analyses in Relax are implemented using a function called `PostOrderVisit`, which recurses down an AST and applies a user-provided function to each AST node *after* completing the recursive visits of that node’s subexpressions (hence “post-order”). Its logic is very simple and allows these analyses to be implemented very compactly:

```cpp
class ExprApplyVisit : public ExprVisitor {
 public:
  explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}

  void VisitExpr(const Expr& e) final {
    ExprVisitor::VisitExpr(e);
    f_(e);
  }

 private:
  std::function<void(const Expr&)> f_;
};

void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
  ExprApplyVisit(fvisit).VisitExpr(e);
}
```
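The same post-order idea can be sketched in Python without TVM. The `Var`/`Call` classes and `post_order_visit` below are hypothetical; only `Call` nodes have children in this toy AST.

```python
from dataclasses import dataclass
from typing import Callable


# Hypothetical toy AST nodes; not Relax classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


def children(expr):
    # In this toy AST, only Call nodes have subexpressions.
    return expr.args if isinstance(expr, Call) else ()


def post_order_visit(expr, fvisit: Callable[[object], None]) -> None:
    # Recurse into all subexpressions first, then apply fvisit to this node.
    for child in children(expr):
        post_order_visit(child, fvisit)
    fvisit(expr)
```

Collecting names for `f(g(x), y)` yields `["x", "g", "y", "f"]`: each node is reported only after all of its subexpressions.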

### Implementing Transformations: (Base) Mutators

`ExprVisitor`s are useful when the goal is to examine an input program and draw some conclusion about it, but compiler optimizations generally require making changes to programs. `ExprMutator`s serve this purpose, but they do not actually change an AST in place; rather, they traverse the AST and produce a new AST that may incorporate changes. In fact, the default implementation of an `ExprMutator` simply traverses an AST and returns the very same AST. Users can therefore override individual cases to make the mutator return a different program, while any cases left unchanged simply preserve those parts of the program. This presents the abstraction of modifying a program when, in reality, a new program is being constructed. See the implementations in [C++](https://github.com/tlc-pack/relax/blob/6495823859cdcbe87899b31de1b82a824e08c4f3/src/relax/ir/expr_functor.cc#L223-L665) and [Python](https://github.com/tlc-pack/relax/blob/6495823859cdcbe87899b31de1b82a824e08c4f3/python/tvm/relax/expr_functor.py#L264-L714).

For example, here is the default implementation for the tuple node case:

```cpp
Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) {
  bool unchanged = true;
  tvm::Array<Expr> fields;
  for (Expr field : op->fields) {
    Expr new_field = this->VisitExpr(field);
    fields.push_back(new_field);
    unchanged &= new_field.same_as(field);
  }

  if (unchanged) {
    return GetRef<Expr>(op);
  } else {
    Expr new_tuple = Tuple(fields, op->span);
    return new_tuple;
  }
}
```

Notice that the method returns the original node if visiting each field returned that same field; this avoids allocating new AST nodes for no reason (and since Relax AST nodes are immutable, sharing them is safe). Thus, only if visiting one of the fields of this tuple results in a change will the result be a new tuple node that includes the changed field.
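This identity-preserving behavior can be illustrated with a toy Python mutator that does not use TVM. `ToyExprMutator` and the hypothetical `RenameVar` transformation are illustrative only; `is` plays the role of `same_as`, and frozen dataclasses stand in for immutable AST nodes.

```python
from dataclasses import dataclass


# Hypothetical immutable toy AST nodes; not Relax classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Tuple:
    fields: tuple


class ToyExprMutator:
    def visit_expr(self, expr):
        return getattr(self, f"visit_{type(expr).__name__.lower()}_")(expr)

    def visit_var_(self, var):
        return var

    def visit_tuple_(self, tup):
        new_fields = tuple(self.visit_expr(f) for f in tup.fields)
        # Reuse the original node when no field changed (identity check, like same_as).
        if all(new is old for new, old in zip(new_fields, tup.fields)):
            return tup
        return Tuple(new_fields)


class RenameVar(ToyExprMutator):
    """Hypothetical transformation: replace every Var named `old` with `new`."""

    def __init__(self, old, new):
        self.old, self.new = old, new

    def visit_var_(self, var):
        return Var(self.new) if var.name == self.old else var
```

A rename that matches nothing returns the original root object. A rename that matches one variable rebuilds only the tuples on the path to it, and untouched subtrees are shared between the old and new programs.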

The basic workflow of implementing a program transformation using the base mutator might look like this:

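```python
from tvm.relax.expr_functor import ExprMutatorBase


class MyProgramTransformation(ExprMutatorBase):
    def visit_if_(self, if_expr):
        # e.g., if it matches some pattern
        # (some_condition_is_met and my_new_node are placeholders for your own logic)
        if some_condition_is_met(if_expr):
            # essentially replaces the encountered expression with some other expression
            # (in reality, a new program is built featuring the new subexpression)
            return my_new_node()
        # otherwise, fall back to the default If case, which rebuilds the node only if a
        # subexpression changed (super().visit_expr would dispatch straight back to this
        # method and recurse forever)
        return super().visit_if_(if_expr)

    # any cases not overridden will preserve the program as it was
```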

### Using A-Normal Form in `ExprMutator`

The class shown above was `ExprMutatorBase` rather than `ExprMutator`. That is because the `ExprMutator` used to implement program transformations in Relax does not directly traverse arbitrary ASTs as in the example given; instead, it expects ASTs to first be *normalized*, namely put into *[administrative normal form](https://en.wikipedia.org/wiki/A-normal_form)*, abbreviated “A-normal form” or “ANF.”

In ANF, the subexpressions of an expression may only be *atomic expressions* (constants, function definitions, or variables). This means that complex expressions such as call nodes cannot contain other complex expressions as subexpressions; the only way to pass their results to other complex expressions is to bind them to variables. For example, `f(g(x))` and `if f(x) then 1 else 0` are not permitted in ANF, while `a = g(x); f(a)` and `a = f(x); if a then 1 else 0` are.
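As a rough illustration of what normalization does, here is a toy Python function that flattens nested calls into a list of bindings. It is a simplified sketch, not Relax’s `Normalize` pass: it handles only a hypothetical `Var`/`Call` AST, treats only variables as atomic, and names fresh variables `lv0`, `lv1`, and so on.

```python
from dataclasses import dataclass
from itertools import count


# Hypothetical toy AST nodes; not Relax classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


def to_anf(expr):
    """Return (bindings, result) where every Call argument is atomic (a Var)."""
    bindings = []
    fresh = count()

    def atomize(e):
        if isinstance(e, Var):
            return e
        # Normalize the arguments first so they are bound before this call.
        args = tuple(atomize(a) for a in e.args)
        var = Var(f"lv{next(fresh)}")
        bindings.append((var, Call(e.op, args)))
        return var

    if isinstance(expr, Var):
        return bindings, expr
    # The outermost call may stay unbound; only its arguments must be atomic.
    return bindings, Call(expr.op, tuple(atomize(a) for a in expr.args))
```

Applied to `f(g(x))`, it returns the single binding `lv0 = g(x)` and the result `f(lv0)`, matching the `a = g(x); f(a)` form. The bindings list also records the evaluation order explicitly: inner calls are always bound before the calls that use them.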

As described in [this page](https://matt.might.net/articles/a-normalization/) by Matt Might, one advantage of ANF for compilers is that it transforms a deeply nested program into one where the order of operations is spelled out clearly in a series of bindings. Optimizations often need to reason directly about execution order and, as Might also observes, such bindings are easy to lower to assembly instructions.

In Relax’s `ExprMutator`, the use of ANF has several advantages for writing passes:

1. It generally simplifies the handling of complex expressions by guaranteeing that subexpressions are variables or constants, which reduces the analysis needed to conclude that certain program transformations are safe. For example, consider a call node: if the arguments to the call could be arbitrary expressions, one of them might, in principle, be a call to a `PackedFunc` that has side effects, so eliminating the call or reordering its arguments could change the program’s semantics. In ANF, however, the arguments are guaranteed to be atomic expressions and thus cannot have side effects.
2. In ANF, similarly to [static single assignment (SSA)](https://en.wikipedia.org/wiki/Static_single-assignment_form) form, complex nested expressions are transformed into a series of bindings. This way, the subexpressions can be addressed individually, which is often convenient for writing passes. (Indeed, the `ExprMutator` implementations include helper methods for tracking which variables are in scope and what their definitions are, allowing pass implementations to query them easily.) Additionally, passes have less need to reason about the order of operations within a nested expression: the order is fully determined by the bindings, and it is easy to introduce new bindings and reorder them as needed.
3. It is common for Relax passes to add bindings to the current scope (a binding block). In a deeply nested expression, it can become difficult to reason about the order in which bindings will be added and whether there are dependencies between them. With ANF, this is never an issue because the amount of nesting is limited.

[The `Normalize` pass in Relax](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/normalize.cc) transforms an arbitrary Relax AST into ANF.

For a real example, `ExprMutator` was used to implement a [constant-folding pass in Relax](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/fold_constant.cc) rather compactly. For variables that are bound to constants, the mutator simply replaces the variables with those constants:

```cpp
Expr VisitExpr_(const VarNode* op) final {
  Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
  // `as` returns non-null only if opt is defined and holds a ConstantNode
  if (opt.as<relax::ConstantNode>()) {
    return opt.value();
  }
  return ExprMutator::VisitExpr_(op);
}
```

The only more involved case is handling calls to TIR functions: if a TIR function is called with arguments that are all constants, the call itself can simply be evaluated at compile time.
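The two rules just described can be sketched in Python over a toy ANF program, represented as a list of `(Var, expr)` bindings. This is an illustration under stated assumptions, not TVM’s `FoldConstant`: `KNOWN_OPS` is a hypothetical table of compile-time-evaluable operators standing in for TIR functions, and the node classes are not Relax classes.

```python
import operator
from dataclasses import dataclass


# Hypothetical toy AST nodes; not Relax classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: int


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


# Hypothetical stand-ins for TIR functions that can be evaluated at compile time.
KNOWN_OPS = {"add": operator.add, "mul": operator.mul}


def fold_constants(bindings):
    """Fold a list of (Var, expr) ANF bindings; returns the new bindings."""
    env = {}  # Var -> Constant, for variables bound to constants
    folded = []
    for var, value in bindings:
        if isinstance(value, Var):
            # Alias binding: replace a variable bound to a constant by that constant.
            value = env.get(value, value)
        elif isinstance(value, Call):
            # Replace variable arguments that are bound to constants.
            args = tuple(env.get(a, a) if isinstance(a, Var) else a for a in value.args)
            if value.op in KNOWN_OPS and all(isinstance(a, Constant) for a in args):
                # All arguments are constants: evaluate the call now.
                value = Constant(KNOWN_OPS[value.op](*(a.value for a in args)))
            else:
                value = Call(value.op, args)
        if isinstance(value, Constant):
            env[var] = value
        folded.append((var, value))
    return folded
```

For the bindings `a = 2; b = mul(a, 3); c = add(b, x)`, the result binds `b` to the constant `6` and rewrites `c` to `add(6, x)`; the last call cannot be folded because `x` is not a constant. Because the program is in ANF, a single forward pass over the bindings is enough: every argument is already a variable or constant whose value, if known, was recorded earlier.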




## Granularity
### Function
### Dataflow Block
### Module

## Example: Conv-Conv Fusion

```python
from __future__ import annotations

import tvm
from tvm import relax, tir, topi
from tvm.ir.module import IRModule
from tvm.ir.transform import module_pass
from tvm.relax import ExprMutator
from tvm.script import relax as R, tir as T
```

## Unfused Convolution Module

```python
bb = relax.BlockBuilder()


# Main function: two back-to-back 3x3 convolutions
def build_model(x, w1, w2):
    with bb.function("main", [x, w1, w2]):
        with bb.dataflow():
            lv1 = bb.emit_te(topi.nn.conv2d, x, w1, strides=1, padding=1, dilation=1)
            gv = bb.emit_output(bb.emit_te(topi.nn.conv2d, lv1, w2, strides=1, padding=1, dilation=1))
        bb.emit_func_output(gv)
    return bb.get()


tensor_type = relax.DynTensorType(4, "float32")
x = relax.Var("x", (1, 16, 64, 64), tensor_type)
w1 = relax.Var("w1", (16, 16, 3, 3), tensor_type)
w2 = relax.Var("w2", (16, 16, 3, 3), tensor_type)

mod = build_model(x, w1, w2)
# Tag the convolution PrimFunc so the fusion pass can recognize it
mod["conv2d"] = mod["conv2d"].with_attr("my_op_kind", "convolution")
mod.show()
```


```python
class FuseTwoConvMutator(ExprMutator):
    """Fuses two back-to-back convolutions into a single Relax function."""

    def __init__(self, mod: IRModule) -> None:
        super().__init__(mod)
        self.mod_ = mod

    def is_convolution(self, call_node):
        """Returns the callee PrimFunc if `call_node` is a call_tir to a convolution, else False."""
        call_tir_op = tvm.ir.Op.get("relax.call_tir")
        if call_node.op != call_tir_op:
            return False
        global_var = call_node.args[0]
        tir_func = self.mod_[global_var]
        # Only PrimFuncs tagged with my_op_kind="convolution" count; untagged ones are skipped
        attrs = tir_func.attrs
        if attrs is None or "my_op_kind" not in attrs or attrs["my_op_kind"] != "convolution":
            return False
        return tir_func

    def pattern_match(self, call_node):
        """Matches conv(conv(x, w1), w2); returns [first_conv_call, second_conv_call] or None."""
        # call_node is the candidate second convolution
        if not self.is_convolution(call_node):
            return None
        operands = call_node.args[1]
        input_tensor = operands[0]
        # In ANF the input is a variable; look up the expression bound to it
        value = self.lookup_binding(input_tensor)
        if value is None:
            return None
        if not isinstance(value, relax.Call):
            return None
        if not self.is_convolution(value):
            return None
        return [value, call_node]

    def transform(self) -> IRModule:
        # Visit every Relax function and write the rewritten version back
        for global_var, func in self.mod_.functions.items():
            if isinstance(func, relax.Function):
                func = self.visit_expr(func)
                self.builder_.update_func(global_var, func)
        return self.builder_.get()

    def visit_call_(self, call_node: relax.Call) -> relax.Call:
        conv_calls = self.pattern_match(call_node)
        if not conv_calls:
            return call_node
        first_conv, second_conv = conv_calls
        # Parameters of the fused function: x, w1 (first conv) and w2 (second conv)
        args = [first_conv.args[1][0], first_conv.args[1][1], second_conv.args[1][1]]
        x = relax.Var("x", args[0].shape_, args[0]._checked_type_)
        w1 = relax.Var("w1", args[1].shape_, args[1]._checked_type_)
        w2 = relax.Var("w2", args[2].shape_, args[2]._checked_type_)
        # Use each call's own PrimFunc so convolutions with different shapes also work
        conv2d_first = first_conv.args[0]
        conv2d_second = second_conv.args[0]

        # Build the fused subgraph: lv0 = conv(x, w1); conv_2 = conv(lv0, w2)
        lv0 = relax.DataflowVar("lv0", first_conv.shape_, first_conv._checked_type_)
        conv_1 = relax.call_tir(conv2d_first, [x, w1], first_conv.shape_, dtype="float32")
        b0 = relax.VarBinding(lv0, conv_1)
        gv = relax.Var("conv_2", second_conv.shape_, second_conv._checked_type_)
        conv_2 = relax.call_tir(conv2d_second, [lv0, w2], second_conv.shape_, dtype="float32")
        b1 = relax.VarBinding(gv, conv_2)
        blocks = [relax.DataflowBlock([b0, b1])]
        seq_expr = relax.SeqExpr(blocks, gv)
        ret_type = second_conv._checked_type_

        func_name = "fused_conv_conv"
        fused_conv_conv = (
            relax.Function([x, w1, w2], seq_expr, ret_type)
            .with_attr("global_symbol", func_name)
            .with_attr("Primitive", 1)
        )
        normalized = self.builder_.normalize(fused_conv_conv)
        global_var = self.builder_.add_func(normalized, func_name)

        # Replace the second convolution with a call to the fused subgraph
        return relax.Call(global_var, args, None, None)


@module_pass(opt_level=2, name="fuse_two_conv")
class FuseTwoConv:
    """The wrapper for the FuseTwoConv pass."""

    def transform_module(self, mod, ctx):
        return FuseTwoConvMutator(mod).transform()
```

```python
after = FuseTwoConv()(mod)
if not relax.analysis.well_formed(after):
    print("NOT WELL FORMED")
after.show("dark")
```



```python
# Merge the PrimFuncs called inside the "Primitive" fused function into one PrimFunc
after_tir_fuse = relax.transform.FuseTIR()(after)
if not relax.analysis.well_formed(after_tir_fuse):
    print("NOT WELL FORMED")
after_tir_fuse.show()
```

## Play With Fused TIR Schedule