# How to Use TVM Pass Instrument

**Original author**: [Chi-Wei Wang](https://github.com/chiwwang)

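As more and more passes are implemented, it becomes increasingly useful to instrument pass execution, analyze the effect of each pass, and observe various events.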

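Passes can be instrumented by providing a list of {py:class}`tvm.ir.instrument.PassInstrument` instances to {py:class}`tvm.transform.PassContext`. A pass instrument for collecting timing information ({py:class}`tvm.ir.instrument.PassTimingInstrument`) is provided, and an extension mechanism is available through the {py:func}`tvm.instrument.pass_instrument` decorator.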

This tutorial demonstrates how developers can use ``PassContext`` to instrument passes. See also {ref}`pass-infra`.

```python
import tvm
import tvm.relay as relay
from tvm.relay.testing import resnet
from tvm.contrib.download import download_testdata
from tvm.relay.build_module import bind_params_by_name
from tvm.ir.instrument import (
    PassTimingInstrument,
    pass_instrument,
)
```

## Create an Example Relay Program

We use the pre-defined resnet-18 network in Relay.

```python
batch_size = 1
num_of_image_class = 1000
image_shape = (3, 224, 224)
output_shape = (batch_size, num_of_image_class)
relay_mod, relay_params = resnet.get_workload(
    num_layers=18, batch_size=batch_size, image_shape=image_shape
)
```

```python
print("Printing the IR module...")
print(relay_mod["main"])
```

```
Printing the IR module...
fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %bn_data_gamma: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, %bn_data_beta: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, %bn_data_moving_mean: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, %bn_data_moving_var: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, %conv0_weight: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %bn0_gamma: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %bn0_beta: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %bn0_moving_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %bn0_moving_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %stage1_unit1_bn1_gamma: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %stage1_unit1_bn1_beta: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %stage1_unit1_bn1_moving_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, 
```

## Create PassContext With Instruments

To run all passes with an instrument, pass it to the ``PassContext`` constructor through the ``instruments`` argument. The built-in ``PassTimingInstrument`` is used to profile the execution time of each pass.

```python
timing_inst = PassTimingInstrument()
with tvm.transform.PassContext(instruments=[timing_inst]):
    relay_mod = relay.transform.InferType()(relay_mod)
    relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
    # Get the profile results before exiting the context.
    profiles = timing_inst.render()
print("Printing results of timing profile...")
print(profiles)
```

```
Printing results of timing profile...
InferType: 8012us [8012us] (51.67%; 51.67%)
FoldScaleAxis: 7495us [5us] (48.33%; 48.33%)
	FoldConstant: 7490us [1545us] (48.30%; 99.93%)
		InferType: 5944us [5944us] (38.33%; 79.37%)
```



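Reading the report: each line appears to give a pass's total time, then its own (self) time in brackets, then two percentages. Checking the numbers above, the first percentage matches the pass's share of the summed top-level time (8012us + 7495us = 15507us), the second matches its share of the enclosing pass, and the self time is the total minus the directly nested passes. This description of the format is inferred from the output, not taken from TVM documentation. The pure-Python sketch below (``parse_report`` and ``check_report`` are hypothetical helpers, not TVM APIs) recomputes those figures; the tolerances absorb rounding of the displayed microseconds.

```python
import re

# Assumed line format, inferred from the printed report (not from TVM docs):
#   <tabs><name>: <total>us [<self>us] (<% of top-level total>%; <% of parent>%)
LINE = re.compile(r"^(\t*)(\S+): (\d+)us \[(\d+)us\] \(([\d.]+)%; ([\d.]+)%\)$")


def parse_report(text):
    """Return (depth, name, total_us, self_us, pct_all, pct_parent) per report line."""
    rows = []
    for line in text.splitlines():
        m = LINE.match(line)
        if m:
            tabs, name, total, self_us, pct_all, pct_parent = m.groups()
            rows.append(
                (len(tabs), name, int(total), int(self_us), float(pct_all), float(pct_parent))
            )
    return rows


def check_report(rows, pct_tol=0.05, us_tol=1):
    """Recompute self time and both percentages; return the summed top-level time."""
    grand_total = sum(row[2] for row in rows if row[0] == 0)
    parent_total = {}  # depth -> total time of the most recent row at that depth
    for i, (depth, name, total, self_us, pct_all, pct_parent) in enumerate(rows):
        # Direct children: following rows one level deeper, up to the next sibling.
        children = 0
        for row in rows[i + 1:]:
            if row[0] <= depth:
                break
            if row[0] == depth + 1:
                children += row[2]
        assert abs((total - children) - self_us) <= us_tol, name
        parent = grand_total if depth == 0 else parent_total[depth - 1]
        assert abs(100.0 * total / grand_total - pct_all) <= pct_tol, name
        assert abs(100.0 * total / parent - pct_parent) <= pct_tol, name
        parent_total[depth] = total
    return grand_total


report = """InferType: 8012us [8012us] (51.67%; 51.67%)
FoldScaleAxis: 7495us [5us] (48.33%; 48.33%)
\tFoldConstant: 7490us [1545us] (48.30%; 99.93%)
\t\tInferType: 5944us [5944us] (38.33%; 79.37%)"""
print(check_report(parse_report(report)))  # 15507
```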
## Use the Current PassContext With Instruments

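You can also use the current ``PassContext`` and register ``PassInstrument`` instances with the ``override_instruments`` method. Note that if any instruments are already registered, ``override_instruments`` first calls their ``exit_pass_ctx`` methods. It then switches to the new instruments and calls their ``enter_pass_ctx`` methods. See the following sections and {py:func}`tvm.instrument.pass_instrument` for details on these methods.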

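The exit-then-enter order of ``override_instruments`` described above can be modelled in a few lines of plain Python. ``ToyPassContext`` and ``LoggingInstrument`` are hypothetical stand-ins written only for this illustration; they are not TVM classes and do not reproduce TVM's implementation.

```python
class LoggingInstrument:
    """Hypothetical instrument that only records enter/exit calls."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def enter_pass_ctx(self):
        self.log.append(f"{self.name} enter_pass_ctx")

    def exit_pass_ctx(self):
        self.log.append(f"{self.name} exit_pass_ctx")


class ToyPassContext:
    """Hypothetical stand-in for PassContext; models only override_instruments."""

    def __init__(self):
        self.instruments = []

    def override_instruments(self, instruments):
        # Leave the currently registered instruments first...
        for inst in self.instruments:
            inst.exit_pass_ctx()
        # ...then switch to the new ones and enter them.
        self.instruments = list(instruments)
        for inst in self.instruments:
            inst.enter_pass_ctx()


log = []
ctx = ToyPassContext()
ctx.override_instruments([LoggingInstrument("A", log)])
ctx.override_instruments([LoggingInstrument("B", log)])
ctx.override_instruments([])  # an empty list only exits the current instruments
print(log)
# ['A enter_pass_ctx', 'A exit_pass_ctx', 'B enter_pass_ctx', 'B exit_pass_ctx']
```

The last call mirrors the "register an empty list" step later in this tutorial: with no new instruments, the only visible effect is the old instruments' ``exit_pass_ctx``.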
```python
cur_pass_ctx = tvm.transform.PassContext.current()
cur_pass_ctx.override_instruments([timing_inst])
relay_mod = relay.transform.InferType()(relay_mod)
relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
profiles = timing_inst.render()
print("Printing results of timing profile...")
print(profiles)
```

```
Printing results of timing profile...
InferType: 8879us [8879us] (53.79%; 53.79%)
FoldScaleAxis: 7629us [6us] (46.21%; 46.21%)
	FoldConstant: 7623us [1488us] (46.18%; 99.92%)
		InferType: 6134us [6134us] (37.16%; 80.47%)
```



Register an empty list to clear the existing instruments.

Note that ``exit_pass_ctx`` of ``PassTimingInstrument`` is called. The profiles are cleared, so nothing is printed.

```python
cur_pass_ctx.override_instruments([])
# Uncomment the call to .render() to see a warning like:
# Warning: no passes have been profiled, did you enable pass profiling?
# profiles = timing_inst.render()
```

## Create a Customized Instrument Class

A customized instrument class can be created with the {py:func}`tvm.instrument.pass_instrument` decorator.

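Let's create an instrument class that computes, for each pass, the change in the number of occurrences of each operator. The name of each operator is available as ``op.name``; counting the operators before and after a pass and taking the difference gives the change caused by that pass.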

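Leaving out the Relay traversal, the counting idea is a per-operator dictionary difference. The sketch below uses plain lists of made-up operator names in place of a Relay module (an assumption for illustration; it is not part of the TVM API). It also shows why a plain ``collections.Counter`` subtraction is not enough here: Counter's ``-`` operator keeps only positive counts, so operators removed by a pass would silently disappear from the result.

```python
from collections import Counter


def op_diff(after, before):
    """Per-operator change from `before` to `after`; zero changes are dropped."""
    keys = set(after) | set(before)
    diff = {k: after.get(k, 0) - before.get(k, 0) for k in keys}
    return {k: v for k, v in diff.items() if v != 0}


# Hypothetical operator sequences standing in for a module before and after a pass
# that rewrites nn.bias_add into add.
before = Counter(["nn.conv2d", "nn.bias_add", "nn.relu", "nn.conv2d"])
after = Counter(["nn.conv2d", "add", "nn.relu", "nn.conv2d"])

print(op_diff(after, before) == {"add": 1, "nn.bias_add": -1})  # True
print(after - before)  # Counter({'add': 1}): the -1 for nn.bias_add is lost
```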
```python
@pass_instrument
class RelayCallNodeDiffer:
    def __init__(self):
        self._op_diff = []
        # Passes can be nested.
        # Use a stack to make sure we get correct before/after pairs.
        self._op_cnt_before_stack = []

    def enter_pass_ctx(self):
        self._op_diff = []
        self._op_cnt_before_stack = []

    def exit_pass_ctx(self):
        assert len(self._op_cnt_before_stack) == 0, "The stack is not empty. Something went wrong."

    def run_before_pass(self, mod, info):
        self._op_cnt_before_stack.append((info.name, self._count_nodes(mod)))

    def run_after_pass(self, mod, info):
        # Pop out the latest recorded pass.
        name_before, op_to_cnt_before = self._op_cnt_before_stack.pop()
        assert name_before == info.name, "name_before: {}, info.name: {} doesn't match".format(
            name_before, info.name
        )
        cur_depth = len(self._op_cnt_before_stack)
        op_to_cnt_after = self._count_nodes(mod)
        op_diff = self._diff(op_to_cnt_after, op_to_cnt_before)
        # Only record passes causing differences.
        if op_diff:
            self._op_diff.append((cur_depth, info.name, op_diff))

    def get_pass_to_op_diff(self):
        """
        return [
          (depth, pass_name, {op_name: diff_num, ...}), ...
        ]
        """
        return self._op_diff

    @staticmethod
    def _count_nodes(mod):
        """Count the number of occurrences of each operator in the module"""
        ret = {}

        def visit(node):
            if isinstance(node, relay.expr.Call):
                if hasattr(node.op, "name"):
                    op_name = node.op.name
                else:
                    # Some CallNodes have no 'name', e.g. calls to a relay.Function.
                    return
                ret[op_name] = ret.get(op_name, 0) + 1

        relay.analysis.post_order_visit(mod["main"], visit)
        return ret

    @staticmethod
    def _diff(d_after, d_before):
        """Calculate the difference of two dictionaries along their keys.
        The result is values in d_after minus values in d_before.
        """
        ret = {}
        key_after, key_before = set(d_after), set(d_before)
        for k in key_before & key_after:
            tmp = d_after[k] - d_before[k]
            if tmp:
                ret[k] = tmp
        for k in key_after - key_before:
            ret[k] = d_after[k]
        for k in key_before - key_after:
            ret[k] = -d_before[k]
        return ret
```

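The stack in ``RelayCallNodeDiffer`` matters because passes nest: in the timing profile earlier, ``FoldScaleAxis`` runs ``FoldConstant``, which in turn runs ``InferType``. Each ``run_after_pass`` must be paired with the most recent unmatched ``run_before_pass``, and the stack size at that moment is the depth stored in ``_op_diff``. A minimal pure-Python sketch of that pairing follows; ``DepthRecorder`` and ``run_nested`` are hypothetical names, and no TVM code is involved.

```python
class DepthRecorder:
    """Hypothetical instrument-like class that pairs before/after calls with a stack."""

    def __init__(self):
        self._stack = []
        self.records = []  # (depth, pass_name), in the order passes finish

    def run_before_pass(self, name):
        self._stack.append(name)

    def run_after_pass(self, name):
        # The most recently started pass must be the one finishing now.
        assert self._stack.pop() == name
        self.records.append((len(self._stack), name))


def run_nested(recorder, name, children=()):
    """Hypothetical driver: run pass `name`, whose body runs the `children` passes."""
    recorder.run_before_pass(name)
    for child_name, grandchildren in children:
        run_nested(recorder, child_name, grandchildren)
    recorder.run_after_pass(name)


rec = DepthRecorder()
# Nesting taken from the timing profile: FoldScaleAxis -> FoldConstant -> InferType.
run_nested(rec, "FoldScaleAxis", [("FoldConstant", [("InferType", [])])])
print(rec.records)
# [(2, 'InferType'), (1, 'FoldConstant'), (0, 'FoldScaleAxis')]
```

Because an outer pass finishes last, it is recorded last with the smallest depth; the same happens to the outer ``sequential`` pass in the results printed below.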
## Apply Passes and Multiple Instrument Classes

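Multiple instrument classes can be used in one ``PassContext``. Note, however, that instrument methods are executed sequentially, in the order given by the ``instruments`` argument. Consequently, for an instrument class such as ``PassTimingInstrument``, the execution time of the other instrument classes inevitably gets counted in its final profile results.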

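Why this is unavoidable can be seen in a toy model where every hook runs in list order and a timer starts in ``run_before_pass`` and stops in ``run_after_pass``. This is a pure-Python sketch with a deterministic fake clock; ``SlowInstrument``, ``TimerInstrument``, and ``run_pass`` are hypothetical and only assume the in-order hook calls described above, not how ``PassTimingInstrument`` is actually implemented.

```python
class FakeClock:
    """Deterministic clock: time only advances when tick() is called."""

    def __init__(self):
        self.now = 0

    def tick(self, amount):
        self.now += amount


class SlowInstrument:
    """Hypothetical instrument whose before/after hooks each cost 10 time units."""

    def __init__(self, clock):
        self.clock = clock

    def run_before_pass(self, pass_name):
        self.clock.tick(10)

    def run_after_pass(self, pass_name):
        self.clock.tick(10)


class TimerInstrument:
    """Hypothetical timer: starts in run_before_pass and stops in run_after_pass."""

    def __init__(self, clock):
        self.clock = clock
        self.elapsed = {}
        self._start = None

    def run_before_pass(self, pass_name):
        self._start = self.clock.now

    def run_after_pass(self, pass_name):
        self.elapsed[pass_name] = self.clock.now - self._start


def run_pass(instruments, pass_name, cost, clock):
    """Toy driver: call each hook on every instrument in list order."""
    for inst in instruments:
        inst.run_before_pass(pass_name)
    clock.tick(cost)  # the pass itself
    for inst in instruments:
        inst.run_after_pass(pass_name)


def measure(order):
    """Time a 5-unit pass with instruments arranged as in `order`."""
    clock = FakeClock()
    make = {"slow": SlowInstrument, "timer": TimerInstrument}
    instruments = [make[kind](clock) for kind in order]
    run_pass(instruments, "FoldConstant", 5, clock)
    timer = next(i for i in instruments if isinstance(i, TimerInstrument))
    return timer.elapsed["FoldConstant"]


print(measure(["timer"]))          # 5: nothing else to count
print(measure(["slow", "timer"]))  # 15: the pass + slow.run_after_pass
print(measure(["timer", "slow"]))  # 15: slow.run_before_pass + the pass
```

In this model, moving the timer to the front only changes which of the other instrument's hooks lands inside the measured interval; it does not remove the overhead.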
```python
call_node_inst = RelayCallNodeDiffer()
desired_layouts = {
    "nn.conv2d": ["NHWC", "HWIO"],
}
pass_seq = tvm.transform.Sequential(
    [
        relay.transform.FoldConstant(),
        relay.transform.ConvertLayout(desired_layouts),
        relay.transform.FoldConstant(),
    ]
)
relay_mod["main"] = bind_params_by_name(relay_mod["main"], relay_params)
# timing_inst is placed after call_node_inst.
# Therefore the execution time of ``call_node_inst.run_after_pass()`` is also counted.
with tvm.transform.PassContext(opt_level=3, instruments=[call_node_inst, timing_inst]):
    relay_mod = pass_seq(relay_mod)
    profiles = timing_inst.render()
# Uncomment the next line to see the timing-profile results.
# print(profiles)
```



We can see how many CallNodes of each op type were added or removed.

Print the change in the number of occurrences of each operator caused by each pass:

```python
from pprint import pprint

pprint(call_node_inst.get_pass_to_op_diff())
```

```
[(1, 'CanonicalizeOps', {'add': 1, 'nn.bias_add': -1}),
 (1, 'ConvertLayout', {'expand_dims': 1, 'layout_transform': 23}),
 (1, 'FoldConstant', {'expand_dims': -1, 'layout_transform': -21}),
 (0, 'sequential', {'add': 1, 'layout_transform': 2, 'nn.bias_add': -1})]
```

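The depth-0 ``sequential`` entry is the net effect of the whole ``Sequential``, and in this output it equals the sum of the depth-1 entries: ``expand_dims`` is added by ``ConvertLayout`` and removed again by ``FoldConstant``, and 23 - 21 = 2 ``layout_transform`` calls remain. Passes with no net change are not recorded at all (see the ``if op_diff:`` check in ``run_after_pass``). A short pure-Python check of this observation, using the printed numbers (``sum_diffs`` is a hypothetical helper):

```python
# Entries copied from the printed output above: (depth, pass_name, {op: change}).
op_diffs = [
    (1, "CanonicalizeOps", {"add": 1, "nn.bias_add": -1}),
    (1, "ConvertLayout", {"expand_dims": 1, "layout_transform": 23}),
    (1, "FoldConstant", {"expand_dims": -1, "layout_transform": -21}),
    (0, "sequential", {"add": 1, "layout_transform": 2, "nn.bias_add": -1}),
]


def sum_diffs(entries):
    """Add up per-operator changes across entries, dropping operators that net to zero."""
    total = {}
    for _depth, _name, diff in entries:
        for op, change in diff.items():
            total[op] = total.get(op, 0) + change
    return {op: change for op, change in total.items() if change != 0}


inner = sum_diffs([e for e in op_diffs if e[0] == 1])
outer = next(e[2] for e in op_diffs if e[0] == 0)
print(inner == outer)  # True
```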

## Exception Handling

What happens if an exception occurs in a method of a ``PassInstrument``?

Define ``PassInstrument`` classes that raise exceptions when entering/exiting a ``PassContext``:

```python
class PassExampleBase:
    def __init__(self, name):
        self._name = name

    def enter_pass_ctx(self):
        print(self._name, "enter_pass_ctx")

    def exit_pass_ctx(self):
        print(self._name, "exit_pass_ctx")

    def should_run(self, mod, info):
        print(self._name, "should_run")
        return True

    def run_before_pass(self, mod, pass_info):
        print(self._name, "run_before_pass")

    def run_after_pass(self, mod, pass_info):
        print(self._name, "run_after_pass")


@pass_instrument
class PassFine(PassExampleBase):
    pass


@pass_instrument
class PassBadEnterCtx(PassExampleBase):
    def enter_pass_ctx(self):
        print(self._name, "bad enter_pass_ctx!!!")
        raise ValueError("{} bad enter_pass_ctx".format(self._name))


@pass_instrument
class PassBadExitCtx(PassExampleBase):
    def exit_pass_ctx(self):
        print(self._name, "bad exit_pass_ctx!!!")
        raise ValueError("{} bad exit_pass_ctx".format(self._name))
```

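If an exception occurs in ``enter_pass_ctx``, ``PassContext`` disables pass instrumentation. It then runs ``exit_pass_ctx`` of every ``PassInstrument`` that successfully finished ``enter_pass_ctx``.

In the following example, we can see that ``exit_pass_ctx`` of ``PassFine_0`` is executed after the exception.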


```python
demo_ctx = tvm.transform.PassContext(
    instruments=[
        PassFine("PassFine_0"),
        PassBadEnterCtx("PassBadEnterCtx"),
        PassFine("PassFine_1"),
    ]
)
try:
    with demo_ctx:
        relay_mod = relay.transform.InferType()(relay_mod)
except ValueError as ex:
    print("Catching", str(ex).split("\n")[-1])
```

```
PassFine_0 enter_pass_ctx
PassBadEnterCtx bad enter_pass_ctx!!!
PassFine_0 exit_pass_ctx
Catching ValueError: PassBadEnterCtx bad enter_pass_ctx
```

```
[11:07:25] /workspace/tvm/src/ir/transform.cc:196: Pass instrumentation entering pass context failed.
[11:07:25] /workspace/tvm/src/ir/transform.cc:197: Disable pass instrumentation.
[11:07:25] /workspace/tvm/src/ir/transform.cc:201: PassFine exiting PassContext ...
[11:07:25] /workspace/tvm/src/ir/transform.cc:203: PassFine exited PassContext.
```

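The ``enter_pass_ctx`` failure handling shown above can be modelled as "enter in order, roll back only what was entered, disable, re-raise". The sketch below is plain Python with a hypothetical ``Instrument`` class and ``enter_all`` helper; clearing the list stands in for "Disable pass instrumentation." from the log and is an assumption about the observable behavior, not TVM's code.

```python
class Instrument:
    """Hypothetical instrument that logs hooks and can be told to fail on enter."""

    def __init__(self, name, log, fail_enter=False):
        self.name = name
        self.log = log
        self.fail_enter = fail_enter

    def enter_pass_ctx(self):
        if self.fail_enter:
            raise ValueError(f"{self.name} bad enter_pass_ctx")
        self.log.append(f"{self.name} enter_pass_ctx")

    def exit_pass_ctx(self):
        self.log.append(f"{self.name} exit_pass_ctx")


def enter_all(instruments):
    """Enter in order; on failure, exit only the entered ones, disable, re-raise."""
    entered = []
    try:
        for inst in instruments:
            inst.enter_pass_ctx()
            entered.append(inst)
    except Exception:
        for inst in entered:
            inst.exit_pass_ctx()
        instruments.clear()  # models "Disable pass instrumentation."
        raise


log = []
insts = [
    Instrument("PassFine_0", log),
    Instrument("PassBadEnterCtx", log, fail_enter=True),
    Instrument("PassFine_1", log),
]
try:
    enter_all(insts)
except ValueError as ex:
    log.append(f"Catching {ex}")
print(log)
# ['PassFine_0 enter_pass_ctx', 'PassFine_0 exit_pass_ctx',
#  'Catching PassBadEnterCtx bad enter_pass_ctx']
print(insts)  # []: nothing is left for a later override to exit
```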

An exception in a ``PassInstrument`` instance causes all instruments of the current ``PassContext`` to be cleared, so nothing is printed when ``override_instruments`` is called.

```python
demo_ctx.override_instruments([])  # no PassFine_0 exit_pass_ctx printed....etc
```

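If an exception occurs in ``exit_pass_ctx``, pass instrumentation is disabled and the exception is then propagated. This means that ``PassInstrument`` instances registered after the one that raised the exception do not execute ``exit_pass_ctx``.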

```python
demo_ctx = tvm.transform.PassContext(
    instruments=[
        PassFine("PassFine_0"),
        PassBadExitCtx("PassBadExitCtx"),
        PassFine("PassFine_1"),
    ]
)
try:
    # PassFine_1 executes enter_pass_ctx, but not exit_pass_ctx.
    with demo_ctx:
        relay_mod = relay.transform.InferType()(relay_mod)
except ValueError as ex:
    print("Catching", str(ex).split("\n")[-1])
```

```
PassFine_0 enter_pass_ctx
PassBadExitCtx enter_pass_ctx
PassFine_1 enter_pass_ctx
PassFine_0 should_run
PassBadExitCtx should_run
PassFine_1 should_run
PassFine_0 run_before_pass
PassBadExitCtx run_before_pass
PassFine_1 run_before_pass
PassFine_0 run_after_pass
PassBadExitCtx run_after_pass
PassFine_1 run_after_pass
PassFine_0 exit_pass_ctx
PassBadExitCtx bad exit_pass_ctx!!!
Catching ValueError: PassBadExitCtx bad exit_pass_ctx
```

```
[11:07:25] /workspace/tvm/src/ir/transform.cc:220: Pass instrumentation exiting pass context failed.
```

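Compared with the enter case, there is no rollback on exit: the first failing ``exit_pass_ctx`` disables instrumentation and the exception propagates, so instruments after it never exit. A plain-Python model of that rule, with a hypothetical ``Instrument`` class and ``exit_all`` helper (clearing the list again stands in for disabling instrumentation):

```python
class Instrument:
    """Hypothetical instrument that logs exit calls and can be told to fail on exit."""

    def __init__(self, name, log, fail_exit=False):
        self.name = name
        self.log = log
        self.fail_exit = fail_exit

    def exit_pass_ctx(self):
        if self.fail_exit:
            raise ValueError(f"{self.name} bad exit_pass_ctx")
        self.log.append(f"{self.name} exit_pass_ctx")


def exit_all(instruments):
    """Exit in order; the first failure disables instrumentation and propagates."""
    try:
        for inst in instruments:
            inst.exit_pass_ctx()
    except Exception:
        instruments.clear()  # models "disable pass instrumentation"
        raise


log = []
insts = [
    Instrument("PassFine_0", log),
    Instrument("PassBadExitCtx", log, fail_exit=True),
    Instrument("PassFine_1", log),
]
try:
    exit_all(insts)
except ValueError as ex:
    log.append(f"Catching {ex}")
print(log)
# ['PassFine_0 exit_pass_ctx', 'Catching PassBadExitCtx bad exit_pass_ctx']
```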

Exceptions raised in ``should_run``, ``run_before_pass`` and ``run_after_pass`` are not handled explicitly; we rely on the context manager (the ``with`` syntax) to exit the ``PassContext`` safely.

Take ``run_before_pass`` as an example:

```python
@pass_instrument
class PassBadRunBefore(PassExampleBase):
    def run_before_pass(self, mod, pass_info):
        print(self._name, "bad run_before_pass!!!")
        raise ValueError("{} bad run_before_pass".format(self._name))


demo_ctx = tvm.transform.PassContext(
    instruments=[
        PassFine("PassFine_0"),
        PassBadRunBefore("PassBadRunBefore"),
        PassFine("PassFine_1"),
    ]
)
try:
    # All exit_pass_ctx are called.
    with demo_ctx:
        relay_mod = relay.transform.InferType()(relay_mod)
except ValueError as ex:
    print("Catching", str(ex).split("\n")[-1])
```

```
PassFine_0 enter_pass_ctx
PassBadRunBefore enter_pass_ctx
PassFine_1 enter_pass_ctx
PassFine_0 should_run
PassBadRunBefore should_run
PassFine_1 should_run
PassFine_0 run_before_pass
PassBadRunBefore bad run_before_pass!!!
PassFine_0 exit_pass_ctx
PassBadRunBefore exit_pass_ctx
PassFine_1 exit_pass_ctx
Catching ValueError: PassBadRunBefore bad run_before_pass
```


Note also that pass instrumentation is not disabled in this case. So if ``override_instruments`` is called, ``exit_pass_ctx`` of the previously registered ``PassInstrument`` instances is called.

```python
demo_ctx.override_instruments([])
```

```
PassFine_0 exit_pass_ctx
PassBadRunBefore exit_pass_ctx
PassFine_1 exit_pass_ctx
```


If the pass execution is not wrapped in a ``with`` statement, ``exit_pass_ctx`` is not called. Let's try this with the current ``PassContext``:

```python
cur_pass_ctx = tvm.transform.PassContext.current()
cur_pass_ctx.override_instruments(
    [
        PassFine("PassFine_0"),
        PassBadRunBefore("PassBadRunBefore"),
        PassFine("PassFine_1"),
    ]
)
```

```
PassFine_0 enter_pass_ctx
PassBadRunBefore enter_pass_ctx
PassFine_1 enter_pass_ctx
```


Then run a pass. As expected, ``exit_pass_ctx`` is not executed after the exception.

```python
try:
    # No ``exit_pass_ctx`` is executed.
    relay_mod = relay.transform.InferType()(relay_mod)
except ValueError as ex:
    print("Catching", str(ex).split("\n")[-1])
```

```
PassFine_0 should_run
PassBadRunBefore should_run
PassFine_1 should_run
PassFine_0 run_before_pass
PassBadRunBefore bad run_before_pass!!!
Catching ValueError: PassBadRunBefore bad run_before_pass
```

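The difference between the ``with`` run and the bare call comes down to Python's context-manager protocol: ``__exit__`` runs even when the body raises, while a pass started outside ``with`` has nothing that triggers the exit hooks. The pure-Python sketch below models this with a hypothetical ``ToyPassContext`` (not TVM's class); in the second case ``__enter__`` is called by hand to stand in for entering the instruments through ``override_instruments``.

```python
class Instrument:
    """Hypothetical instrument that logs hooks and can fail in run_before_pass."""

    def __init__(self, name, log, fail_before=False):
        self.name = name
        self.log = log
        self.fail_before = fail_before

    def enter_pass_ctx(self):
        self.log.append(f"{self.name} enter_pass_ctx")

    def exit_pass_ctx(self):
        self.log.append(f"{self.name} exit_pass_ctx")

    def run_before_pass(self, pass_name):
        if self.fail_before:
            raise ValueError(f"{self.name} bad run_before_pass")
        self.log.append(f"{self.name} run_before_pass")


class ToyPassContext:
    """Hypothetical context: enter hooks in __enter__, exit hooks in __exit__."""

    def __init__(self, instruments):
        self.instruments = list(instruments)

    def __enter__(self):
        for inst in self.instruments:
            inst.enter_pass_ctx()
        return self

    def __exit__(self, exc_type, exc, tb):
        # Python calls __exit__ even when the body raised, so every exit hook runs.
        for inst in self.instruments:
            inst.exit_pass_ctx()
        return False  # do not swallow the exception

    def run_pass(self, pass_name):
        # No try/except here: an error from a hook simply propagates.
        for inst in self.instruments:
            inst.run_before_pass(pass_name)


def make(log):
    return [
        Instrument("PassFine_0", log),
        Instrument("PassBadRunBefore", log, fail_before=True),
        Instrument("PassFine_1", log),
    ]


with_log = []
try:
    with ToyPassContext(make(with_log)) as ctx:
        ctx.run_pass("InferType")
except ValueError:
    pass

manual_log = []
ctx = ToyPassContext(make(manual_log))
ctx.__enter__()  # entered by hand, with no matching __exit__
try:
    ctx.run_pass("InferType")
except ValueError:
    pass

print(with_log[-3:])   # the three exit_pass_ctx entries
print(manual_log[-1])  # 'PassFine_0 run_before_pass': no exit hook ran
```

Since the model keeps its instruments registered after the failure, a later clean-up step would still find them, which matches the ``override_instruments([])`` output in this section.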

Clear the instruments.

```python
cur_pass_ctx.override_instruments([])
```

```
PassFine_0 exit_pass_ctx
PassBadRunBefore exit_pass_ctx
PassFine_1 exit_pass_ctx
```
