# Adding a Compiler Pass to Relax

As with Relay, the primary means of implementing optimizations and analyses in Relax is via compiler passes. This tutorial gives an overview of the different ways to traverse and transform the Relax AST. It then uses an example to introduce writing Relax passes that analyze and transform various components of the Relax AST.

## AST Traversal

In a traditional compilation process, a source program is represented as text and is then parsed into a tree of grammar constructs, which is called an [abstract syntax tree (AST)](https://en.wikipedia.org/wiki/Abstract_syntax_tree), that can be processed more easily than the source program. (In Relax, we use a decorator to parse Python code and produce a Relax AST, but the principle is the same.) Compiler passes operate on the AST by traversing the tree of grammar constructs — hence we say that they make a *pass* over the AST — and either:

1. modify the AST (for example, inline functions, unroll loops, etc.), thus implementing a *program transformation*, or
2. collect information about the AST **without** modifying it (for example, the number of loops in the program), thus implementing a *program analysis.*

Analyses can be helpful for implementing program transformations by checking that assumptions about the input program have been met. This tutorial will outline the steps for implementing analyses and optimizations in Relax and, in the process, explain the mechanisms for doing so.

### Traversing the AST: Functors

Abstract syntax trees (ASTs) are, by their nature, recursive: Expressions in a program contain subexpressions, which may in turn contain further subexpressions. Hence, AST traversals must also be done in a recursive process and in a way that allows for customizing the order of the traversal and handling for different grammar constructs.

In compilers that are implemented in object-oriented programming languages, this is generally accomplished using the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern), in which the visitor (the parent class) examines the root of the AST using its `visit` method and passes the AST root to a method that operates on that particular grammar construct. For example, if the root of the AST is an `IfNode`, the visitor will dispatch to the `visit_if` method. The individual visitor methods can perform whatever operations are necessary and, crucially, also call the parent class's `visit` method to process any subexpressions recursively without worrying about how to dispatch them, since the recursive call to `visit` handles that.

For Relax’s AST, the most general form of the visitor pattern is called an `ExprFunctor`, which is a base class for any compiler pass. Implementations are provided in both [C++](https://github.com/tlc-pack/relax/blob/relax/include/tvm/relax/expr_functor.h) and [Python](https://github.com/tlc-pack/relax/blob/relax/python/tvm/relax/expr_functor.py). In the C++ implementation, `VisitExpr` is the entry-point method (taking any `Expr` node and dispatching accordingly) and there are overloads of `VisitExpr_` for each grammar construct to handle the different cases. Since Python does not support overloading by argument types, the Python version uses `visit_expr` as the entry point and `visit_{name}_` methods for each grammar construct. (Note: One difference between the versions is that the C++ version also permits auxiliary arguments to be passed in addition to the input AST, though this can be accomplished in Python by wrapping the visitor methods.)

Note that the entry-point method can be overridden to include logic that fires before every expression node. This is a commonly used pattern in both the Relay and Relax codebases, with an example given below. (The same can be accomplished in Python using `super`.)

```cpp
void PrintVisitor::VisitExpr(const Expr& expr) {
  // fires for every expression type
  std::cout << "Here" << std::endl;
  // uses the base class's unmodified entry-point method to dispatch by node type
  ExprFunctor::VisitExpr(expr);
}
```
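The dispatch mechanism and the entry-point override can also be pictured in plain Python. The following is only a toy sketch with no TVM dependency: `Const`, `Add`, `ToyFunctor`, `Evaluator`, and `LoggingEvaluator` are hypothetical names invented for illustration and are not part of Relax. It mimics the `visit_{name}_` naming scheme and uses `super` to run logic before every node is dispatched.

```python
from dataclasses import dataclass


# Toy AST node types (hypothetical; these are not Relax classes).
@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Add:
    lhs: object
    rhs: object


class ToyFunctor:
    """Dispatches on the node's class name, mimicking `visit_{name}_` methods."""

    def visit_expr(self, expr):
        method = getattr(self, f"visit_{type(expr).__name__.lower()}_", None)
        if method is None:
            # Like a bare functor, no case has a default implementation.
            raise NotImplementedError(f"no case for {type(expr).__name__}")
        return method(expr)


class Evaluator(ToyFunctor):
    def visit_const_(self, expr):
        return expr.value

    def visit_add_(self, expr):
        # Recursive calls go back through the entry point, which re-dispatches.
        return self.visit_expr(expr.lhs) + self.visit_expr(expr.rhs)


class LoggingEvaluator(Evaluator):
    def __init__(self):
        self.seen = []

    def visit_expr(self, expr):
        # Fires for every node type before the normal dispatch.
        self.seen.append(type(expr).__name__)
        return super().visit_expr(expr)


evaluator = LoggingEvaluator()
result = evaluator.visit_expr(Add(Const(1), Add(Const(2), Const(3))))
print(result, evaluator.seen)
```

Because every recursive call goes through `visit_expr`, the logging logic sees each node exactly once without any of the individual cases having to know about it.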

While `ExprFunctor`s are general enough to allow for implementing any compiler pass using them, they leave all of the individual visitor cases undefined, requiring the user to fill them in. Passes typically only have a few node types of interest that affect them, so using a functor would require filling in implementations for many node types that likely will not require special processing. For common cases, we provide child classes of `ExprFunctor` that have sensible defaults and thus require overriding fewer cases. We discuss them below.



### Implementing Analyses: Visitors

An `ExprVisitor` is an `ExprFunctor` that traverses an AST and, by default, does not modify the AST. Thus, it is used as a parent class for writing analysis passes. The base `ExprVisitor` includes a default implementation for each case, namely visiting each subexpression of the current AST node in order. There are, again, implementations in both [C++](https://github.com/tlc-pack/relax/blob/relax/include/tvm/relax/expr_functor.h#L186) and [Python](https://github.com/tlc-pack/relax/blob/relax/python/tvm/relax/expr_functor.py#L351) (the latter as `PyExprVisitor`).

For example, here is the visitor case for `IfNode`:

```cpp
void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitSpan(op->span);
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}
```

(`Span`s are used for storing metadata for error messages; they are not AST nodes *per se*.)

Even though the individual dispatch methods in `ExprVisitor` do not modify the AST, `ExprVisitor` is nevertheless very useful for implementing analyses. This can be accomplished by keeping some state in the `ExprVisitor` that tracks whatever value the analysis is meant to return, modifying that state in the individual cases, and finally returning the stored state. The basic workflow can be summarized as follows:

```python
# wrapper over the visitor logic; 
# it may also make sense to implement this as a method on the visitor,
# especially if there may be a need to call it recursively
def perform_analysis(program: Expr) -> ReturnType:
  visitor = MyProgramAnalysis()
  visitor.visit_expr(program)
  return visitor.get_state()

@relax.expr_functor.visitor
class MyProgramAnalysis(PyExprVisitor):
  # initialization logic, etc.

  def visit_tuple_(self, tuple_expr: Tuple) -> None:
    for expr in tuple_expr.fields:
      if self.some_condition_is_met(expr):
        self.update_state(expr)
        # can decide whether or not you need to visit the subexpressions
        self.visit_expr(expr)
      
  # etc., define cases as appropriate

  def visit_if_(self, if_expr: If) -> None:
    if self.another_condition_is_met(if_expr):
      self.update_state(if_expr)
    # can use the default visitor implementation to dispatch 
    # if that is the desired behavior
    super().visit_if_(if_expr)
```

For a real example, we may consider Relax’s [well-formedness analysis](https://github.com/tlc-pack/relax/blob/relax/src/relax/analysis/well_formed.cc), which checks that the input program is in [A-normal form](https://en.wikipedia.org/wiki/A-normal_form) (further discussed below), that all fields of a shape expression are integer-typed, and that all variables have exactly one definition (so there are no variables used that aren’t defined and no variable defined more than once).

The well-formedness checker keeps a Boolean field called `well_formed` that is true if the program meets the criteria for well-formedness and false otherwise. The visitor is invoked from a wrapper method called `WellFormed` that initializes the visitor, applies it to every function in the `IRModule`, and returns the final value of `well_formed`:

```cpp
bool WellFormed(const IRModule& m, Optional<DiagnosticContext> diag_ctx) {
  WellFormedChecker well_formed_checker = WellFormedChecker(diag_ctx);
  for (const auto& it : m->functions) {
    // register GlobalVar in the IRModule first
    well_formed_checker.RegisterGlobalVar(it.first);
  }

  for (const auto& it : m->functions) {
    // visit relax.Function
    if (auto* n = it.second.as<FunctionNode>()) {
      Function func = GetRef<Function>(n);
      well_formed_checker.VisitExpr(func);
    }
  }

  return well_formed_checker.well_formed;
}
```

Here is the well-formedness checker’s logic for handling bound variables (e.g., parameters to functions or variables in assignments):

```cpp
void VisitVarDef_(const VarNode* var) {
    Var gv = GetRef<Var>(var);
    if (var_set_.count(gv) == 1) {
      Malformed(Diagnostic::Error(var->span)
                << "Var " << gv->name_hint() << " is defined more than once.");
    }
    // register Var
    var_set_.insert(gv);
}
```

First, it checks that the variable being visited does not already have a definition associated with it (giving an error if it does). Next, it adds the variable to its set of variables with definitions, which it can use to validate any future variable definitions encountered during the AST traversal.

Meanwhile, the case for `IfNode` does some recursive visits and has to manage some of the persistent state:

```cpp
void VisitExpr_(const IfNode* op) {
    this->VisitExpr(op->cond);
    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set_ = var_set_;
    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set_ =
        prim_expr_visitor_.symbolic_var_set_;
    this->VisitBody(op->true_branch);
    var_set_ = previous_var_set_;
    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
    this->VisitBody(op->false_branch);
    var_set_ = previous_var_set_;
    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
}
```

Before visiting the true and false branches of the `IfNode`, the visitor stores the previous sets of definitions because the branches are [lexically scoped](https://en.wikipedia.org/wiki/Scope_(computer_science)#Lexical_scope_vs._dynamic_scope), so definitions in the branches are not visible outside the branches. This is an example of how the logic in a compiler pass may need to account for the semantics of Relax.
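The save-and-restore idea behind these two cases can be sketched in plain Python. This is a simplified illustration, not the Relax checker: `Var`, `Let`, `If`, and `DefinitionChecker` are hypothetical names, and variables here compare by name, whereas the C++ code above compares by object identity.

```python
from dataclasses import dataclass


# Toy AST (hypothetical names; these are not Relax classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Let:
    var: Var  # defines `var` for the rest of the enclosing body


@dataclass(frozen=True)
class If:
    cond: Var
    true_branch: tuple  # sequence of statements
    false_branch: tuple


class DefinitionChecker:
    def __init__(self):
        self.defined = set()
        self.errors = []

    def visit_body(self, stmts):
        for stmt in stmts:
            if isinstance(stmt, Let):
                self.visit_var_def(stmt.var)
            elif isinstance(stmt, If):
                self.visit_if(stmt)

    def visit_var_def(self, var):
        if var in self.defined:
            self.errors.append(f"Var {var.name} is defined more than once.")
        self.defined.add(var)

    def visit_if(self, node):
        if node.cond not in self.defined:
            self.errors.append(f"Var {node.cond.name} is not defined.")
        saved = set(self.defined)  # snapshot before entering a branch
        self.visit_body(node.true_branch)
        self.defined = set(saved)  # definitions in the branch do not escape
        self.visit_body(node.false_branch)
        self.defined = set(saved)


c, y = Var("c"), Var("y")

# `y` is defined once per branch and once afterwards; each scope is separate.
scoped = DefinitionChecker()
scoped.visit_body((Let(c), If(c, (Let(y),), (Let(y),)), Let(y)))

# Defining `c` twice in the same scope is reported.
duplicate = DefinitionChecker()
duplicate.visit_body((Let(c), Let(c)))
print(scoped.errors, duplicate.errors)
```

Without the restore steps, the second and third definitions of `y` in the first program would be flagged as duplicates even though each lives in its own scope.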

Many analyses in Relax are implemented using a function called `PostOrderVisit`, which simply recurses down an AST and applies a user-provided function on each AST node *after* completing any recursive visits for that node (hence “post-order”). Its logic is very simple and allows these analyses to be implemented very compactly:

```cpp
class ExprApplyVisit : public ExprVisitor {
 public:
  explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}

  void VisitExpr(const Expr& e) final {
    ExprVisitor::VisitExpr(e);
    f_(e);
  }

 private:
  std::function<void(const Expr&)> f_;
};

void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
  ExprApplyVisit(fvisit).VisitExpr(e);
}
```
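To make the visiting order concrete, here is a small Python sketch of the same idea on a toy expression type. The names `Var`, `Call`, `post_order_visit`, and `used_var_names` are hypothetical and chosen for this illustration; they are not the Relax Python API.

```python
from dataclasses import dataclass
from typing import Callable, List


# Toy AST (hypothetical names; these are not Relax classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


def post_order_visit(expr, fvisit: Callable[[object], None]) -> None:
    """Visit all children first, then apply `fvisit` to the node itself."""
    if isinstance(expr, Call):
        for arg in expr.args:
            post_order_visit(arg, fvisit)
    fvisit(expr)


def used_var_names(expr) -> List[str]:
    """A compact analysis built on top of the traversal."""
    names: List[str] = []

    def collect(node):
        if isinstance(node, Var) and node.name not in names:
            names.append(node.name)

    post_order_visit(expr, collect)
    return names


expr = Call("add", (Var("x"), Call("mul", (Var("y"), Var("x")))))

order = []
post_order_visit(expr, lambda n: order.append(n.name if isinstance(n, Var) else n.op))
print(order)
print(used_var_names(expr))
```

The callback sees `mul` only after both of its arguments, and `add` last of all, which is the property analyses rely on when they combine results from subexpressions.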



### Implementing Transformations: (Base) Mutators

`ExprVisitor`s are useful when the goal is to examine an input program and draw some conclusion about it, but implementing compiler optimizations generally requires making changes to programs. `ExprMutator`s are used for the latter purpose, but they do not actually change an AST in-place; rather, they traverse the AST and produce a new AST that may incorporate changes. In fact, the default implementation of an `ExprMutator` simply traverses an AST and returns the very same AST—this allows users to override individual cases to have the mutator return a different program, with the helpful default that any cases left unchanged will simply preserve those parts of the program. This generally presents the abstraction of making changes to a program when what is really happening is that a new program is being constructed. See the implementations in [C++](https://github.com/tlc-pack/relax/blob/relax/include/tvm/relax/expr_functor.h) and [Python](https://github.com/tlc-pack/relax/blob/relax/python/tvm/relax/expr_functor.py) as `PyExprMutator`.

For example, here is the default implementation for the tuple node case:

```cpp
Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) {
  bool unchanged = true;
  tvm::Array<Expr> fields;
  for (Expr field : op->fields) {
    Expr new_field = this->VisitExpr(field);
    fields.push_back(new_field);
    unchanged &= new_field.same_as(field);
  }

  if (unchanged) {
    return GetRef<Expr>(op);
  } else {
    Expr new_tuple = Tuple(fields, op->span);
    return new_tuple;
  }
}
```

Notice that the method simply returns the original node if the result of visiting all the subexpressions is the same as in the original expression; this avoids allocating new AST nodes for no reason. (Also, since Relax AST nodes are immutable, this is safe.) Thus, if a visit to one of the fields of this tuple results in a change being made to the AST, the result will be a new tuple node that includes the changed field.
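The sharing behavior described above can be observed directly in a small Python sketch. The classes below (`Const`, `TupleExpr`, `ToyMutator`, `ZeroToOne`) are hypothetical stand-ins, not Relax classes, and Python's `is` plays the role of `same_as`.

```python
from dataclasses import dataclass


# Toy immutable AST (hypothetical names; these are not Relax classes).
@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class TupleExpr:
    fields: tuple


class ToyMutator:
    """Default cases rebuild a node only if one of its children changed."""

    def visit_expr(self, expr):
        if isinstance(expr, TupleExpr):
            return self.visit_tuple_(expr)
        return self.visit_const_(expr)

    def visit_const_(self, expr):
        return expr

    def visit_tuple_(self, expr):
        new_fields = tuple(self.visit_expr(f) for f in expr.fields)
        if all(new is old for new, old in zip(new_fields, expr.fields)):
            return expr  # nothing changed: reuse the original node
        return TupleExpr(new_fields)


class ZeroToOne(ToyMutator):
    # Only one case is overridden; tuples keep the default behavior.
    def visit_const_(self, expr):
        return Const(1) if expr.value == 0 else expr


untouched = TupleExpr((Const(5), Const(7)))
inner = TupleExpr((Const(2),))
mixed = TupleExpr((Const(0), inner))

same = ZeroToOne().visit_expr(untouched)
rewritten = ZeroToOne().visit_expr(mixed)
print(same is untouched, rewritten)
```

Only the path from the root to the changed constant is rebuilt; the unchanged `inner` tuple is shared between the old and the new tree.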

The basic workflow of implementing a program transformation using the base mutator might look like this:

```python
class MyProgramTransformation(ExprMutatorBase):
  def visit_if_(self, if_expr):
    # e.g., if it matches some pattern
    if some_condition_is_met(if_expr):
      # essentially replaces the encountered expression with some other expression
      # (in reality, a new program is built featuring the new subexpression)
      return my_new_node()
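    # delegate to the default case for this node type; calling visit_expr here
    # would dispatch back to this method and recurse forever
    return super().visit_if_(if_expr)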

  # any cases not overridden will preserve the program as it was
```



### Using A-Normal Form in `ExprMutator`

The class shown above was `ExprMutatorBase` rather than `ExprMutator`. That is because the `ExprMutator` used to implement program transformations in Relax does not directly traverse arbitrary ASTs like in the example given, but instead expects ASTs to first be *normalized*, namely put into *[administrative normal form](https://en.wikipedia.org/wiki/A-normal_form)*, abbreviated “A-normal form” or “ANF.”

In ANF, expressions are allowed to have only *atomic expressions* (constants, function definitions, or variables) for subexpressions, meaning that complex expressions like call nodes cannot contain other complex expressions as subexpressions; the only way to pass their results to other complex expressions is to bind them to variables. For example, `f(g(x))` and `if f(x) then 1 else 0` are not permitted in ANF, while `a = g(x); f(a)` and `a = f(x); if a then 1 else 0` are.
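As a rough sketch of what such a conversion does, the following Python function flattens nested calls into a list of bindings. It is a toy illustration, not Relax's normalizer: `Var`, `Call`, `to_anf`, and `show` are hypothetical names, only calls and variables are handled, and the fresh names `t0`, `t1`, ... are assumed not to clash with existing variables.

```python
from dataclasses import dataclass


# Toy AST (hypothetical names; these are not Relax classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


def to_anf(expr):
    """Return (bindings, body) where every call argument is a variable."""
    bindings = []

    def atomize(e):
        if isinstance(e, Var):
            return e
        # Normalize the arguments first so bindings appear in evaluation order.
        flat = Call(e.op, tuple(atomize(a) for a in e.args))
        tmp = Var(f"t{len(bindings)}")  # assumed not to clash with user names
        bindings.append((tmp, flat))
        return tmp

    if isinstance(expr, Var):
        return bindings, expr
    body = Call(expr.op, tuple(atomize(a) for a in expr.args))
    return bindings, body


def show(e):
    if isinstance(e, Var):
        return e.name
    return f"{e.op}({', '.join(show(a) for a in e.args)})"


bindings, body = to_anf(Call("f", (Call("g", (Var("x"),)),)))
for var, value in bindings:
    print(f"{var.name} = {show(value)}")
print(show(body))
```

For `f(g(x))` this yields the binding `t0 = g(x)` followed by the body `f(t0)`, matching the `a = g(x); f(a)` form above up to the choice of variable name.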

As described in [this page](https://matt.might.net/articles/a-normalization/) by Matt Might, one advantage of ANF for compilers is that it transforms a deeply nested program into one where the order of operations is spelled out very clearly in a series of bindings. It is often useful in optimizations to reason directly about execution order and, as Might also observes, such bindings are easy to lower to assembly instructions.

In Relax’s `ExprMutator`, the use of ANF has several advantages for writing passes:

1. It generally simplifies the handling of complex expressions by guaranteeing that subexpressions are variables or constants, thus reducing the amount of analysis needed to conclude that certain program transformations are safe. For example, consider the case of a call node: If the arguments to the call could be general expressions, it is possible, in principle, that one of the subexpressions is a call to a `PackedFunc` that has side effects. Thus, it would change the program’s semantics to either eliminate the call or reorder the arguments. In ANF, however, the arguments are guaranteed to be atomic expressions and thus cannot have any side effects.
2. In ANF, similarly to [static single assignment (SSA)](https://en.wikipedia.org/wiki/Static_single-assignment_form), complex nested expressions are instead transformed into a series of bindings. This way, the subexpressions can be addressed individually, which is often convenient for writing passes. (Indeed, the `ExprMutator` implementations include helper methods for tracking which variables are in scope and what their definitions are, allowing pass implementations to easily query for them.) Additionally, there is less need for passes to reason about the order of operations within a nested expression: everything is determined by the order of the bindings, and it is easy to introduce new bindings and reorder them as needed.
3. It is common for Relax passes to add bindings to the current scope (a binding block). In a deeply nested expression, it can become difficult to reason about the order in which bindings will be added and whether there are any dependencies between these. With ANF, this is never an issue because the amount of nesting is limited.
4. The `ExprMutator` automatically normalizes expressions (via the [ExprNormalizer](https://github.com/tlc-pack/relax/blob/relax/src/relax/ir/block_builder.cc#L39) of its internal block builder `builder_`) and updates the `checked_type_` and `shape_` fields of expressions, ensuring that changes made during the pass will be appropriately taken into account.

*We expect all Relax passes to take ANF input, produce ANF output, and use `ExprMutator`. `ExprMutatorBase` is used only in very rare cases; currently, only normalization uses it (when the user-written TVMScript is not in ANF and may have nested expressions).*

[The `Normalize` pass in Relax](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/normalize.cc) transforms an arbitrary Relax AST into ANF.

For a real example, `ExprMutator` was used to implement a [constant-folding pass in Relax](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/fold_constant.cc) rather compactly. For variables that are bound to constants, the mutator simply replaces the variables with those constants:

```cpp
Expr VisitExpr_(const VarNode* op) final {
    Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
    // `as` check checks if opt is not null and is instance of constant
    if (opt.as<relax::ConstantNode>()) {
      return opt.value();
    }
    return ExprMutator::VisitExpr_(op);
}
```

The only complex case is handling calls to TIR functions: If there is a call to a TIR function and all the arguments are constants, then the TIR call itself can simply be evaluated at compile time.
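The two ingredients described here, replacing variables bound to constants and evaluating calls whose arguments are all constants, can be sketched on a toy ANF binding list in Python. This is an illustration under simplifying assumptions, not the Relax pass: `Var`, `Const`, `Call`, `OPS`, and `fold_constants` are hypothetical names, and a small table of Python operators stands in for evaluating TIR functions at compile time.

```python
import operator
from dataclasses import dataclass


# Toy ANF program pieces (hypothetical names; these are not Relax classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


# Stand-in for functions that can be evaluated at compile time.
OPS = {"add": operator.add, "mul": operator.mul}


def fold_constants(bindings):
    known = {}  # variable name -> Const, in the spirit of a binding lookup
    folded = []
    for var, value in bindings:
        if isinstance(value, Call):
            # Replace variables that are bound to constants.
            args = tuple(
                known.get(a.name, a) if isinstance(a, Var) else a for a in value.args
            )
            if value.op in OPS and all(isinstance(a, Const) for a in args):
                value = Const(OPS[value.op](*(a.value for a in args)))
            else:
                value = Call(value.op, args)
        elif isinstance(value, Var):
            value = known.get(value.name, value)
        if isinstance(value, Const):
            known[var.name] = value
        folded.append((var, value))
    return folded


program = [
    (Var("a"), Const(2)),
    (Var("b"), Call("add", (Var("a"), Const(3)))),
    (Var("c"), Call("mul", (Var("b"), Var("x")))),
]
folded = fold_constants(program)
for var, value in folded:
    print(var.name, "=", value)
```

Because the input is in ANF, a single pass over the bindings in order is enough: by the time a call is reached, everything its arguments refer to has already been processed.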




## Pass Granularity

One of the design points when writing analyses or transformations is choosing the granularity at which your pass operates. The Relax pass infrastructure enables developers to write passes at the Module, Function, DataflowBlock, or TIR PrimFunc level. Choosing the right granularity can make it significantly easier to write your analyses and transformations.

#### Module Pass
Module-level passes implement global analyses and optimizations, such as interprocedural optimizations (IPO). Passes at this level have full control of a given Relax IRModule, including adding and deleting functions. For example, the [FuseOps pass](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/fuse_ops.cc#L776), which together with the FuseTIR pass performs fusion in Relax, is a module pass because it adds new functions to the IRModule.

#### Function Pass
Function-level passes implement intra-function optimizations for a given Relax IRModule. A function pass fetches one function at a time from the function list in the IRModule for optimization, so its scope is a single Relax function. Consequently, these passes cannot add or delete global functions, as they are not aware of global information. For example, the [FoldConstant pass](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/fold_constant.cc#L206) is a function pass because it performs constant folding within function scope.

#### TIR PrimFunc Pass
TIR PrimFunc-level passes apply analyses and transformations to all the TIR PrimFuncs within the IRModule. A PrimFunc pass fetches one TIR function at a time from the function list in the IRModule for optimization, so its scope is a single TIR PrimFunc. Consequently, these passes cannot add or delete global functions, as they are not aware of global information. The [AnnotateTIROpPattern pass](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/annotate_tir_op_pattern.cc#L43), which annotates the op pattern for TIR functions, is an example of a PrimFunc pass.

#### DataflowBlock Pass
Dataflow block-level passes implement dataflow block optimizations for a given Relax IRModule. A dataflow block pass fetches one dataflow block at a time from the functions in an IRModule and returns a rewritten DataflowBlock, so its scope is a single Relax dataflow block. Consequently, these passes cannot modify the global-scope Vars and symbolic shape Vars defined inside the dataflow block. All the operations in a dataflow block are side-effect-free and do not contain advanced control flow (such as if-then-else) or nested scopes. We expect most optimizations to happen at the dataflow block level; if a pass's changes are local to a dataflow block, the pass developer can choose dataflow block granularity. The [FMARewrite pass](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/fma_rewrite.cc#L150) is a toy dataflow block pass that performs fused multiply-add rewriting.

The following helper functions are provided in C++ to create each of the pass types described above. These helpers are also exposed to the Python frontend, so users can create a specific pass object through Python APIs.

```c++
Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    tvm::Array<String> required, 
    bool traceable);

Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    tvm::Array<String> required,
    bool traceable);

Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level, 
    String name, 
    tvm::Array<String> required, 
    bool traceable);

Pass CreateDataflowBlockPass(
    const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
    int opt_level, 
    String name, 
    tvm::Array<String> required, 
    bool traceable);
```
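The practical difference between the granularities is what the user-supplied function gets to see and return. The following Python sketch models that difference with a toy "module" that is just a dictionary from function names to bodies. It is not TVM's API: `create_module_pass` and `create_function_pass` here are hypothetical stand-ins written for this illustration, and the string bodies are placeholders for real functions.

```python
from typing import Callable, Dict

# Toy "IRModule": function name -> function body (a plain string placeholder).
Module = Dict[str, str]


def create_module_pass(pass_func: Callable[[Module], Module]) -> Callable[[Module], Module]:
    """The pass function receives the whole module and may add or delete functions."""
    return lambda mod: pass_func(dict(mod))


def create_function_pass(pass_func: Callable[[str, Module], str]) -> Callable[[Module], Module]:
    """The pass function sees one function at a time; the set of functions is fixed."""

    def run(mod: Module) -> Module:
        return {name: pass_func(body, mod) for name, body in mod.items()}

    return run


# A function-level rewrite: drop a no-op prefix from every body.
strip_nops = create_function_pass(lambda body, mod: body.replace("nop; ", ""))

# A module-level rewrite: introduce a new global function.
add_helper = create_module_pass(lambda mod: {**mod, "helper": "ret"})

mod = {"main": "nop; call f", "f": "nop; ret"}
result = add_helper(strip_nops(mod))
print(result)
```

The function pass cannot change which functions exist because its wrapper controls the iteration over the module, whereas the module pass returns an entirely new mapping.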


## Example: Conv-Conv Fusion

In order to understand the process of writing a Relax pass, we will look at a `Conv -> Conv` fusion pass as a guide. It is a *toy* pass that highlights how Relax can be used to analyze the high-level graph, fuse the underlying TIR functions, and experiment with the schedule.

We begin by initializing the TVM environment.

In [1]:
import numpy as np
import tvm
from tvm import tir, relax, topi
from tvm.ir.module import IRModule
from tvm.ir.transform import module_pass
from tvm.relax import PyExprMutator
from tvm.relax.analysis import remove_all_unused
from typing import Union


### Build IRModule with Unfused Convolutions

The code below constructs our input IRModule. The IRModule has two convolution operations represented by two successive `R.call_tir(conv2d, ...)` calls in the `main` function. The `conv2d` TIR function performs the convolution operation.


In [2]:
# We will use the BlockBuilder API to create the input IRModule. We could also
# potentially use TVMScript to generate it.
bb = relax.BlockBuilder()

def build_model(x, w1, w2) -> IRModule:
    
    # Main function
    with bb.function("main", [x, w1, w2]):
        with bb.dataflow():
            lv1 = bb.emit_te(topi.nn.conv2d, x, w1, strides=1, padding=1, dilation=1)
            gv = bb.emit_output(bb.emit_te(topi.nn.conv2d, lv1, w2, strides=1, padding=1, dilation=1))
        bb.emit_func_output(gv)
    
    # Return the IRModule being built
    return bb.get()

# Input Relax variables.
tensor_type = relax.DynTensorType(4, "float32")
x = relax.Var("x", (1, 16, 64, 64), tensor_type)
w1 = relax.Var("w1", (16, 16, 3, 3), tensor_type)
w2 = relax.Var("w2", (16, 16, 3, 3), tensor_type)

# Get the input IRModule
unfused_mod = build_model(x, w1, w2)

Next, we mark the TIR convolution function with the attribute `my_op_kind` = `convolution`. This will help us identify the convolution operation in our pass later.

Note that this is a workaround until Relax has a full operator set. Currently, Relax has a sophisticated [pattern matching infrastructure](https://github.com/tlc-pack/relax/issues/160) for matching over Relax operators, types, and attributes. However, due to the lack of a full operator set, we have to mark TIR functions with attributes in order to identify the underlying operation. For now, Relax pass developers can add their own markers similar to `my_op_kind` to get around this issue. Pattern matching will become much simpler once a full Relax operator set is added.

In [3]:
unfused_mod["conv2d"] = unfused_mod["conv2d"].with_attr("my_op_kind", "convolution")
unfused_mod.show()


### Fusing Convolutions

The goal is to fuse the two successive convolution operations, i.e., to replace the following in the input IRModule

```python
lv = R.call_tir(conv2d, (x, w1), (1, 16, 64, 64), dtype="float32")
lv1 = R.call_tir(conv2d, (lv, w2), (1, 16, 64, 64), dtype="float32")
```

with
```python
lv1 = R.call_tir(grouped_conv_conv, (x, w1, w2), (1, 16, 64, 64), dtype="float32")
```
where `grouped_conv_conv` is a TIR function that performs the two convolutions.

This can be achieved in two steps:

__Step 1:__ Group the two convolutions into a *primitive* Relax function. A *primitive* Relax function is one which has a single dataflow block with only `call_tir` bindings inside it. For easier identification, these functions are marked with the `Primitive: 1` attribute.

__Step 2:__ Use the Relax [FuseTIR pass](https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/fuse_tir.cc) to fuse the called TIR functions in the *primitive* Relax function `grouped_conv_conv` into a single TIR function and further replace all calls to the Relax function `grouped_conv_conv` with a `call_tir` to the fused TIR function.

The outputs of Step 1 and Step 2 are shown below:

__IRModule after Step 1__

```python
@R.function
def main(x: Tensor((1, 16, 64, 64), "float32"), w1: Tensor((16, 16, 3, 3), "float32"), w2: Tensor((16, 16, 3, 3), "float32")) -> Tensor(None, "float32", ndim = 4):
    # block 0
    with R.dataflow():
        lv1: Tensor((1, 16, 64, 64), "float32") = grouped_conv_conv(x, w1, w2)
        gv: Tensor((1, 16, 64, 64), "float32") = lv1
        R.output(gv)
    return gv

# primitive Relax function    
@R.function
def grouped_conv_conv(x1: Tensor((1, 16, 64, 64), "float32"), w11: Tensor((16, 16, 3, 3), "float32"), w21: Tensor((16, 16, 3, 3), "float32")) -> Tensor(None, "float32", ndim = 4):
    # block 0
    with R.dataflow():
        lv0 = R.call_tir(conv2d, (x1, w11), (1, 16, 64, 64), dtype="float32")
        conv_2 = R.call_tir(conv2d, (lv0, w21), (1, 16, 64, 64), dtype="float32")
        R.output(conv_2)
    return conv_2
```

__IRModule after Step 2__

```python
@R.function
def main(x: Tensor((1, 16, 64, 64), "float32"), w1: Tensor((16, 16, 3, 3), "float32"), w2: Tensor((16, 16, 3, 3), "float32")) -> Tensor(None, "float32", ndim = 4):
    # block 0
    with R.dataflow():
        lv1 = R.call_tir(fused_conv_conv, (x, w1, w2), (1, 16, 64, 64), dtype="float32")
        gv: Tensor((1, 16, 64, 64), "float32") = lv1
        R.output(gv)
    return gv
```
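The effect of Step 1 on the bindings of `main` can be pictured with a toy rewrite over a list of `(variable, callee, arguments)` triples. This Python sketch is an assumption-laden simplification, not the pass itself: `group_two_convs` and the tuple encoding are hypothetical, every `conv2d` is assumed to take exactly two arguments, and the new primitive function is represented only by its name.

```python
from typing import Dict, List, Tuple

# Toy binding: (bound variable, callee name, argument names).
Binding = Tuple[str, str, Tuple[str, ...]]


def group_two_convs(bindings: List[Binding], output: str) -> List[Binding]:
    producers: Dict[str, Binding] = {}  # variable -> binding that defines it
    rewritten: List[Binding] = []
    for var, callee, args in bindings:
        producer = producers.get(args[0]) if args else None
        if callee == "conv2d" and producer is not None and producer[1] == "conv2d":
            # Conv -> Conv: call the grouped function on the original inputs.
            x, w1 = producer[2]
            binding = (var, "grouped_conv_conv", (x, w1, args[1]))
        else:
            binding = (var, callee, args)
        producers[var] = binding
        rewritten.append(binding)

    # Remove bindings whose results are no longer used (dead code).
    live = {output}
    kept: List[Binding] = []
    for var, callee, args in reversed(rewritten):
        if var in live:
            live.update(args)
            kept.append((var, callee, args))
    return list(reversed(kept))


before = [
    ("lv", "conv2d", ("x", "w1")),
    ("lv1", "conv2d", ("lv", "w2")),
]
after = group_two_convs(before, output="lv1")
print(after)
```

The first convolution disappears from `main` only because nothing else uses `lv`; if another binding consumed it, the dead-code step would keep it.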



### Toy Pass to Group Convolutions (Step 1)

Next we write the toy pass to group two successive convolution operations into a *primitive* Relax function. There are a few things to consider here:

* __Pass Granularity__: Since our pass adds a new Relax function to the module, its changes are not limited to dataflow block or function scope, so it cannot be a function or dataflow block pass. It must be a module pass. *Potentially, we could have made this a dataflow block pass and added a function local to the dataflow block scope, which could later be lifted to module scope by a lambda-lifting pass.*
* __ExprVisitor or ExprMutator or ExprMutatorBase__: Since it is a transformation pass, we must use either `ExprMutator` or `ExprMutatorBase`. For simplicity, we would like to work with IR in ANF and not handle all the complexities of non-ANF programs. Hence, we use `ExprMutator`.

In [4]:
@relax.expr_functor.mutator
class GroupTwoConvsMutator(PyExprMutator):

    def __init__(self, mod: IRModule) -> None:
        # ExprMutator has an internal BlockBuilder `builder_` which keeps an IRModule `context_mod_`
        # being built. We can optionally initialize this IRModule with the current module
        # with copy-on-write semantics which we can update/add functions to.
        super().__init__(mod)
        self.mod_ = mod
    
    # Matches the call_node against the Conv -> Conv pattern. If the pattern matches,
    # returns the two convolution call nodes as a list, otherwise returns None.
    def pattern_match(self, call_node) -> Union[None, list]:
        # Helper function to check if the call node op is a `call_tir` operator and
        # the called TIR function is a convolution operation.
        def is_convolution(call_node: relax.Call) -> bool:
            if not isinstance(call_node, relax.Call):
                return False
            call_tir_op = tvm.ir.Op.get("relax.call_tir")
            if call_node.op != call_tir_op:
                return False
            global_var = call_node.args[0]
            tir_func = self.mod_[global_var]
            if tir_func.attrs["my_op_kind"] != "convolution":
                return False
            return True
        
        # Check if the current call_node is a convolution operation.
        if not is_convolution(call_node):
            return None

        # Check if the input tensor to this call_node is also a convolution operation.
        operands = call_node.args[1]
        input_tensor = operands[0]
        value = self.lookup_binding(input_tensor)
        if not is_convolution(value):
            return None
        return [value, call_node]

    def transform(self) -> IRModule:
        # Iterate over all the functions in the IRModule
        for global_var, func in self.mod_.functions.items():
            # Skip non-relax functions
            if not isinstance(func, relax.Function):
                continue
            # Skip primitive functions
            if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
                continue
            # Update the non-primitive Relax function
            updated_func = self.visit_expr(func)
            # Remove any dead code in the updated function
            updated_func = remove_all_unused(updated_func)
            self.builder_.update_func(global_var, updated_func)
        
        # At the end of the transformation we return the updated IRModule from the BlockBuilder.
        return self.builder_.get()
    
    # We only need to override the Call node mutator. If this call_node matches
    # the Conv->Conv pattern, we can group these calls in a new primitive
    # Relax function and replace the current call with a call to the primitive
    # function.
    def visit_call_(self, call_node: relax.Call) -> relax.Call:
        # Check if the call node matches our expected pattern
        conv_calls = self.pattern_match(call_node)
        if not conv_calls:
            return call_node
        
        # Get current vars from convolution calls
        x, w1, w2 = conv_calls[0].args[1][0], conv_calls[0].args[1][1], conv_calls[1].args[1][1]

        # Construct the parameters of the new function
        param_x = relax.Var("param_x", x.shape_, x._checked_type_)
        param_w1 = relax.Var("param_w1", w1.shape_, w1._checked_type_)
        param_w2 = relax.Var("param_w2", w2.shape_, w2._checked_type_)

        # Get the TIR convolution functions
        tir_convolution_0 = conv_calls[0].args[0]
        tir_convolution_1 = conv_calls[1].args[0]

        # Next we construct the primitive function with grouped convolutions
        
        # First convolution binding
        lv0 = relax.DataflowVar("lv0", conv_calls[0].shape_, conv_calls[0]._checked_type_)
        conv_1 = relax.call_tir(tir_convolution_0, [param_x, param_w1], conv_calls[0].shape_, dtype="float32")
        bindings = [relax.VarBinding(lv0, conv_1)]

        # Second convolution binding
        gv = relax.Var("gv", conv_calls[1].shape_, conv_calls[1]._checked_type_)
        conv_2 = relax.call_tir(tir_convolution_1, [lv0, param_w2], conv_calls[1].shape_, dtype="float32")
        bindings.append(relax.VarBinding(gv, conv_2))

        block = relax.DataflowBlock(bindings)
        seq_expr = relax.SeqExpr([block], gv)
        ret_type = conv_calls[1]._checked_type_
        func_name = "grouped_conv_conv"
        grouped_conv_conv = relax.Function([param_x, param_w1, param_w2], seq_expr, ret_type)

        # Add the global_symbol and Primitive attributes. The FuseTIR pass later
        # uses the Primitive attribute to fuse the called TIR functions.
        grouped_conv_conv = grouped_conv_conv.with_attr("global_symbol", func_name).with_attr("Primitive", 1)

        # Normalize the newly created function and add it to the module
        normalized = self.builder_.normalize(grouped_conv_conv)
        global_var = self.builder_.add_func(normalized, func_name)

        # Construct a call to the primitive function
        return relax.Call(global_var, [x, w1, w2], None, None)
        

@module_pass(opt_level=2, name="group_two_conv")
class GroupTwoConvsPass:
    """Module-pass wrapper around GroupTwoConvsMutator."""

    def transform_module(self, mod, ctx):
        return GroupTwoConvsMutator(mod).transform()


In [5]:
grouped_mod = GroupTwoConvsPass()(unfused_mod)
if not relax.analysis.well_formed(grouped_mod):
    print("IRModule is not well-formed")
grouped_mod.show()



### Fuse Underlying TIR Convolution Functions (Step 2)

Now we can use the FuseTIR pass to fuse the TIR convolution functions called by the primitive function. *Note that keeping the Relax FuseOps and FuseTIR passes separate is what allows us to implement custom fusion strategies, such as the Conv->Conv fusion shown here.*

In [6]:
fused_mod = relax.transform.FuseTIR()(grouped_mod)
if not relax.analysis.well_formed(fused_mod):
    print("IRModule is not well-formed")
fused_mod.show()

## Play with TIR Schedule

Next we can experiment with the schedule of the fused Conv->Conv TIR function using the TIR Schedule API. The following code is for demonstration purposes only and does not necessarily represent an optimized schedule.

This gives us the flexibility to change the TIR schedule manually or through a transformation pass and then use the Relax VM build API to compile and run on a target. This enables developers to quickly test out their ideas and check their impact in the end-to-end compilation of the model.
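
Loop reordering, which the schedule in the next cell applies, changes only the order in which iterations run, not the value that is computed. The following is a toy pure-Python sketch of that idea, not TVM code: `conv2d_direct` and its `order` parameter are hypothetical names introduced for illustration, and the convolution assumes a single image, stride 1, and no padding.

```python
import itertools


def conv2d_direct(x, w, order):
    """Direct convolution (one image, stride 1, no padding) with a chosen loop order.

    x is indexed [ci][y][x] and w is indexed [co][ci][ky][kx]; `order` lists
    the six loop names from outermost to innermost.
    """
    n_ci, height, width = len(x), len(x[0]), len(x[0][0])
    n_co, k_h, k_w = len(w), len(w[0][0]), len(w[0][0][0])
    out_h, out_w = height - k_h + 1, width - k_w + 1
    extents = {"co": n_co, "y": out_h, "x": out_w, "ci": n_ci, "ky": k_h, "kx": k_w}
    assert sorted(order) == sorted(extents), "order must name each loop exactly once"

    out = [[[0] * out_w for _ in range(out_h)] for _ in range(n_co)]
    # itertools.product varies its last range fastest, so `order` behaves like
    # a hand-written loop nest listed from outermost to innermost.
    for idx in itertools.product(*(range(extents[name]) for name in order)):
        i = dict(zip(order, idx))
        out[i["co"]][i["y"]][i["x"]] += (
            x[i["ci"]][i["y"] + i["ky"]][i["x"] + i["kx"]]
            * w[i["co"]][i["ci"]][i["ky"]][i["kx"]]
        )
    return out


# Small integer inputs so that the comparison below is exact.
x = [[[(c * 16 + r * 4 + col) % 7 - 3 for col in range(4)] for r in range(4)]
     for c in range(2)]
w = [[[[(o + c + ky * 3 + kx) % 5 - 2 for kx in range(3)] for ky in range(3)]
      for c in range(2)] for o in range(3)]

default_order = ("co", "y", "x", "ci", "ky", "kx")
# Same permutation as the sch.reorder calls in the next cell, minus the batch loop.
reordered = ("y", "x", "ky", "kx", "co", "ci")

assert conv2d_direct(x, w, default_order) == conv2d_direct(x, w, reordered)
```

The inputs are integers on purpose: with floating-point data, a different loop order changes the accumulation order, so the two results may differ in the last few bits even though the schedule is still valid.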

In [7]:
sch = tvm.tir.Schedule(fused_mod)
sch.work_on("grouped_conv_conv")

# Experiment with loop reordering. This is for demonstration purposes only
# and does not represent an optimized loop order.
conv1_block = sch.get_block("conv2d_nchw")
conv2_block = sch.get_block("conv2d_nchw_1")
c1_n, c1_co, c1_y, c1_x, c1_ci, c1_ky, c1_kx = sch.get_loops(conv1_block)
c2_n, c2_co, c2_y, c2_x, c2_ci, c2_ky, c2_kx = sch.get_loops(conv2_block)
sch.reorder(c1_n, c1_y, c1_x, c1_ky, c1_kx, c1_co, c1_ci)
sch.reorder(c2_n, c2_y, c2_x, c2_ky, c2_kx, c2_co, c2_ci)
pad_block_1 = sch.get_block("pad_temp")
pad_block_2 = sch.get_block("pad_temp_1")
sch.compute_at(pad_block_2, c2_ci)
sch.compute_at(conv1_block, c2_kx)

# Let's see what the modified module looks like
sch.mod.show()


In [8]:
# Compile and run
x = tvm.nd.array(np.random.normal(size=[1, 16, 64, 64]).astype("float32"))
w1 = tvm.nd.array(np.random.normal(size=[16, 16, 3, 3]).astype("float32"))
w2 = tvm.nd.array(np.random.normal(size=[16, 16, 3, 3]).astype("float32"))

# Build the IRModule and create the Relax VM
target = tvm.target.Target("llvm")
ex = relax.vm.build(sch.mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())

res = vm["main"](x, w1, w2)
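
One way to sanity-check `res` is to compare it against a plain NumPy reference. The following is a minimal sketch under stated assumptions: `conv2d_nchw_ref` is a hypothetical helper, and it assumes NCHW layout, stride 1, and a symmetric zero padding of 1. The actual convolution parameters are not stated in this section, so adjust `padding` to match the model.

```python
import numpy as np


def conv2d_nchw_ref(x, w, padding=1):
    """Reference NCHW convolution with stride 1 and symmetric zero padding.

    The stride and the default padding are assumptions of this sketch,
    not values taken from the model above.
    """
    n, c_in, height, width = x.shape
    c_out, c_in_w, k_h, k_w = w.shape
    assert c_in == c_in_w, "input channels of x and w must match"

    x_pad = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    out_h = height + 2 * padding - k_h + 1
    out_w = width + 2 * padding - k_w + 1
    out = np.zeros((n, c_out, out_h, out_w), dtype=np.result_type(x, w))

    # Accumulate one kernel offset at a time; each step is a channel contraction
    # between a shifted window of the input and one [c_out, c_in] weight slice.
    for ky in range(k_h):
        for kx in range(k_w):
            window = x_pad[:, :, ky:ky + out_h, kx:kx + out_w]
            out += np.einsum("nchw,oc->nohw", window, w[:, :, ky, kx])
    return out


# Self-check: a kernel with a single 1 at the center of each matching channel
# acts as the identity when padding=1.
x_demo = np.arange(32, dtype="float32").reshape(1, 2, 4, 4)
w_demo = np.zeros((2, 2, 3, 3), dtype="float32")
w_demo[0, 0, 1, 1] = 1.0
w_demo[1, 1, 1, 1] = 1.0
assert np.array_equal(conv2d_nchw_ref(x_demo, w_demo), x_demo)
```

If the assumptions hold, the VM output could be compared with `conv2d_nchw_ref(conv2d_nchw_ref(x.numpy(), w1.numpy()), w2.numpy())` using `np.testing.assert_allclose` after converting `res` to a NumPy array. The tolerances for that comparison are a judgment call, because the scheduled kernel accumulates in a different order than the reference.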