# 追踪变量

In [1]:
import numpy as np
import tvm
from tvm import te, relay

In [2]:
n = tvm.tir.const(128, "int32")
a = te.placeholder((n,), name="a")
b = te.placeholder((n,), name="b")
c = te.compute((n,), lambda i: a[i] + b[i], name="c")

sch = te.create_schedule(c.op)
ir = tvm.lower(sch, [a, b, c])
ir.show()

In [3]:
loops = []

def find_width8(op):
    """找出所有范围能被 8 除的 'tir.For' 节点。"""
    if isinstance(op, tvm.tir.For):
        if isinstance(op.extent, tvm.tir.IntImm):
            if op.extent.value % 8 == 0:
                loops.append(op)

In [4]:
def vectorize8(op):
    """Split can vectorize the loops found in `find_width8`."""
    if op in loops:
        extent = op.extent.value
        name = op.loop_var.name
        lo, li = te.var(name + ".outer"), te.var(name + ".inner")
        body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
        body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)
        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)
        return body
    return None


@tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
    global loops
    tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)
    if not loops:
        return f
    # 最后一个 list 参数表示要转换的节点类型。
    # 因此，在这种情况下，只有 `For` 节点会调用 `vectorize8`
    return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ["tir.For"]))

In [5]:
with tvm.transform.PassContext(config={"tir.add_lower_pass":
                                       [(1, vectorize)]}):
    tvm.lower(sch, [a, b, c]).show()

In [6]:
def example():
    shape = (1, 64, 54, 54)
    c_data = np.empty(shape).astype("float32")
    c = relay.const(c_data)
    weight = relay.var("weight", shape=(64, 64, 3, 3))
    x = relay.var("x", relay.TensorType((1, 64, 56, 56), "float32"))
    conv = relay.nn.conv2d(x, weight, kernel_size=(3, 3))
    y = relay.add(c, c)
    y = relay.multiply(y, relay.const(2, "float32"))
    y = relay.add(conv, y)
    z = relay.add(y, c)
    z1 = relay.add(y, c)
    z2 = relay.add(z, z1)
    return relay.Function([x, weight], z2)

In [7]:
f = example()
mod = tvm.IRModule.from_expr(f)
print(mod)

def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) {
  %0 = add(meta[relay.Constant][0], meta[relay.Constant][0]);
  %1 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %2 = multiply(%0, 2f);
  %3 = add(%1, %2);
  %4 = add(%3, meta[relay.Constant][0]);
  %5 = add(%3, meta[relay.Constant][0]);
  add(%4, %5)
}




In [8]:
transform = relay.transform.FoldConstant()
mod = transform(mod)
print(mod)

def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %1 = add(%0, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %3 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  add(%2, %3) /* ty=Tensor[(1, 64, 54, 54), float32] */
}




In [11]:
f = example()
mod = tvm.IRModule.from_expr(f)
# Glob 感兴趣的 passes.
seq = tvm.transform.Sequential(
    [
        relay.transform.FoldConstant(),
        relay.transform.EliminateCommonSubexpr(),
        relay.transform.FuseOps(fuse_opt_level=2),
    ]
)
mod1 = seq(mod)
print(mod1)

def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {
  %4 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {
    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;
    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;
    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;
    %3 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;
    add(%2, %3) /* ty=Tensor[(1, 64, 54, 54), float32] */
  } /* ty=fn (Tensor[(1, 64, 56, 56), 

In [12]:
@relay.transform.function_pass(opt_level=1)
class CustomPipeline:
    """Simple test function to replace one argument to another."""

    def __init__(self, multiplier):
        self.multiplier = multiplier

    # This function can define a pass.
    def transform_function(self, func, mod, ctx):
        obj = self

        class ReplaceConstant(tvm.relay.ExprMutator):
            def visit_constant(self, c):
                return relay.multiply(obj.multiplier, c)

        return ReplaceConstant().visit(func)


f = example()
mod = tvm.IRModule.from_expr(f)
custom_pass = CustomPipeline(multiplier=relay.const(3, "float32"))
assert custom_pass.info.name == "CustomPipeline"
mod3 = custom_pass(mod)
print(mod3)

def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {
  %0 = multiply(3f /* ty=float32 */, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %1 = add(%0, %0) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %2 = multiply(3f /* ty=float32 */, 2f /* ty=float32 */) /* ty=float32 */;
  %3 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %4 = multiply(%1, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %5 = add(%3, %4) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %6 = add(%5, %0) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %7 = add(%5, %0) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  add(%6, %7) /* ty=Tensor[(1, 64, 54, 54), float32] */
}


