In [1]:
import tvm
import numpy as np

In [2]:
# override=True keeps this cell idempotent: re-running it in the same kernel
# would otherwise fail because the global name is already registered.
@tvm.register_func("tvm.tir.trace_change_int_first", override=True)
def trace_change_int_first(x):
    """Trace callback that returns its argument plus one.

    Registered as the global packed function
    ``tvm.tir.trace_change_int_first`` so that ``tvm.tir.trace`` can invoke it
    from compiled code; its return value is used as the traced expression's
    value.
    """
    return x + 1


# A distinct Python name avoids silently shadowing the first callback.
@tvm.register_func("tvm.tir.trace_change_int_second", override=True)
def trace_change_int_second(x):
    """Trace callback that returns its argument plus two.

    Registered as the global packed function
    ``tvm.tir.trace_change_int_second`` for use with ``tvm.tir.trace``.
    """
    return x + 2

In [3]:
n = 4
dtype = "int64"

# Two chained stages: Y[i] = trace(X[i]) and Z[i] = trace(Y[i]).
# Each trace call is routed to one of the registered Python callbacks.
# NOTE: the lambda argument name `i` is kept on purpose; te.compute uses it
# to name the loop iteration variable.
x = tvm.te.placeholder((n,), name="X", dtype=dtype)
y = tvm.te.compute(
    x.shape,
    lambda i: tvm.tir.trace([x[i]], "tvm.tir.trace_change_int_first"),
)
z = tvm.te.compute(
    x.shape,
    lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_int_second"),
)
s = tvm.te.create_schedule(z.op)
f = tvm.build(s, [x, y, z], "llvm")

# Input is all ones; the two output buffers start zeroed and are written by `f`.
xnd = tvm.nd.array(np.ones(n, dtype=x.dtype))
ynd = tvm.nd.array(np.zeros(n, dtype=y.dtype))
znd = tvm.nd.array(np.zeros(n, dtype=z.dtype))
f(xnd, ynd, znd)

In [9]:
# Verify the chained trace callbacks (+1 then +2) instead of relying on
# eyeballing the output; then display the three arrays.
np.testing.assert_array_equal(ynd.numpy(), xnd.numpy() + 1)
np.testing.assert_array_equal(znd.numpy(), ynd.numpy() + 2)
xnd, ynd, znd

(<tvm.nd.NDArray shape=(4,), cpu(0)>
 array([1, 1, 1, 1]),
 <tvm.nd.NDArray shape=(4,), cpu(0)>
 array([2, 2, 2, 2]),
 <tvm.nd.NDArray shape=(4,), cpu(0)>
 array([4, 4, 4, 4]))

## 测试 (Tests)

In [6]:
from tvm.ir import IRModule, structural_equal
from tvm import relay
from tvm.relay.transform import SimplifyInference, InferType


def test_simplify_batchnorm(dtype="float32"):
    """Verify that ``SimplifyInference`` rewrites ``nn.batch_norm`` + ``nn.dropout``.

    For several (rank, channel axis, number of stacked layers) cases, a graph
    of ``batch_norm`` followed by ``dropout`` is type-checked, passed through
    ``SimplifyInference``, and compared structurally against a hand-built
    scale-and-shift reference graph.

    Parameters
    ----------
    dtype : str
        Element type for every tensor variable and constant in both graphs.
    """

    def reference_bn(data, gamma, beta, mean, var, axis=1, epsilon=1e-5, shape=None):
        # Inference-mode batch norm written out by hand:
        #   (data - mean) / sqrt(var + epsilon) * gamma + beta
        # expressed as data * scale + shift.
        inv_std = relay.const(1, dtype) / relay.sqrt(var + relay.const(epsilon, dtype))
        scale = relay.multiply(inv_std, gamma)
        shift = relay.add(relay.multiply(relay.negative(mean), scale), beta)
        # Give the per-channel vectors one unit axis for every dimension that
        # follows `axis`, so they line up with `data` in the final arithmetic.
        trailing_dims = len(shape) - axis - 1
        if trailing_dims:
            scale = relay.expand_dims(scale, axis=1, num_newaxis=trailing_dims)
            shift = relay.expand_dims(shift, axis=1, num_newaxis=trailing_dims)
        return data * scale + shift

    def run_case(rank, axis, num_layers):
        eps = 0.01
        data_type = relay.TensorType((10,) * rank, dtype)
        param_type = relay.TensorType((10,), dtype)
        x = relay.var("x", data_type)
        beta = relay.var("beta", param_type)
        gamma = relay.var("gamma", param_type)
        moving_var = relay.var("moving_var", param_type)
        moving_mean = relay.var("moving_mean", param_type)

        actual = expected = x
        for _layer in range(num_layers):
            # Graph handed to the pass: (input + 1) -> batch_norm -> dropout.
            normalized, _new_mean, _new_var = relay.nn.batch_norm(
                actual + relay.const(1, dtype),
                gamma,
                beta,
                moving_mean,
                moving_var,
                axis=axis,
                epsilon=eps,
            )
            actual = relay.nn.dropout(normalized)
            # Hand-written graph the simplified result must match.
            expected = reference_bn(
                expected + relay.const(1, dtype),
                gamma,
                beta,
                moving_mean,
                moving_var,
                axis=axis,
                epsilon=eps,
                shape=data_type.shape,
            )

        # Types must be inferred before SimplifyInference can run.
        mod = InferType()(IRModule.from_expr(actual))
        mod = SimplifyInference()(mod)
        assert structural_equal(mod["main"].body, expected, map_free_vars=True)

    for case in [(2, 1, 1), (4, 1, 1), (4, 0, 3)]:
        run_case(*case)


if __name__ == "__main__":
    # Cover both full- and half-precision graphs.
    for _dtype in ("float32", "float16"):
        test_simplify_batchnorm(dtype=_dtype)

In [12]:
import numpy as np
import tvm
from tvm import relay

def run_opt_pass(expr: relay.Expr, opt_pass: tvm.transform.Pass) -> relay.Expr:
    """Type-check a Relay expression and apply one optimization pass to it.

    ``expr`` is wrapped in a fresh ``IRModule`` via ``IRModule.from_expr``;
    ``InferType`` runs first so the pass receives a typed module, then
    ``opt_pass`` is applied.

    Parameters
    ----------
    expr : relay.Expr
        A Relay expression or ``relay.Function`` (not an ``IRModule``: the
        value is handed to ``IRModule.from_expr``).
    opt_pass : tvm.transform.Pass
        The pass to apply after type inference.

    Returns
    -------
    relay.Expr
        The transformed ``main`` function when ``expr`` is a
        ``relay.Function``; otherwise only that function's body, so the
        result mirrors the kind of value passed in.
    """
    mod = tvm.IRModule.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body

In [16]:
import time
import ctypes

import tvm
from tvm import te
from tvm.contrib.utils import tempdir
from tvm.runtime.module import BenchmarkResult