This is the kernel we get directly from auto-tuning. See profile.ipynb

```
import tvm

from tvm import relax
from tvm.relax.frontend import nn

from typing import Optional
from tvm import te
from tvm import dlight as dl
from tvm.target import Target
import numpy as np
import tempfile

import timeit

from mlc_dac.layers import CachedWNConv1d

conv1d = CachedWNConv1d(512, 512, 7, stride=1, dilation=9, padding=0)
mod, params = conv1d.export_tvm(
    {"forward": {"x": nn.spec.Tensor((5, 512, 62), "float32")}},
    debug=True
)

trials = 2000
target = Target.from_device("metal")

with target, tempfile.TemporaryDirectory() as tmp_dir:
    seq = tvm.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
            relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
        ]
    )

    mod = seq(mod)

mod.show()
```

In [2]:
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def cached_padding_1d_crop(var_x: T.handle, var_res: T.handle):
        T.func_attr({"op_pattern": 0})
        b, c, out = T.int32(), T.int32(), T.int32()
        x = T.match_buffer(var_x, (b, c, out))
        n = T.int32()
        res = T.match_buffer(var_res, (b, c, n))
        # with T.block("root"):
        for bb, cc, nn in T.grid(b, c, n):
            with T.block("res_crop"):
                vb, vc, vn = T.axis.remap("SSS", [bb, cc, nn])
                T.reads(x[vb, vc, vn])
                T.writes(res[vb, vc, vn])
                res[vb, vc, vn] = x[vb, vc, vn]

    @T.prim_func
    def cached_padding_1d_init(var_cache: T.handle):
        T.func_attr({"op_pattern": 0})
        b, c, p = T.int32(), T.int32(), T.int32()
        cache = T.match_buffer(var_cache, (b, c, p))
        # with T.block("root"):
        for bb, cc, pp in T.grid(b, c, p):
            with T.block("cache_init"):
                vb, vc, vp = T.axis.remap("SSS", [bb, cc, pp])
                T.reads()
                T.writes(cache[vb, vc, vp])
                cache[vb, vc, vp] = T.float32(0.0)

    @T.prim_func
    def cached_padding_1d_update(var_cache: T.handle, var_data: T.handle, var_res: T.handle):
        T.func_attr({"op_pattern": 8})
        B, c, p = T.int32(), T.int32(), T.int32()
        cache = T.match_buffer(var_cache, (B, c, p))
        b, n = T.int32(), T.int32()
        data = T.match_buffer(var_data, (b, c, n))
        out = T.int32()
        res = T.match_buffer(var_res, (b, c, out))
        # with T.block("root"):
        for bb, cc, oo in T.grid(b, c, out):
            with T.block("res_update"):
                vb, vc, vo = T.axis.remap("SSS", [bb, cc, oo])
                T.reads(cache[vb, vc, vo], data[vb, vc, vo - p])
                T.writes(res[vb, vc, vo])
                res[vb, vc, vo] = T.if_then_else(vo < p, cache[vb, vc, vo], data[vb, vc, vo - p])
        for bb, cc, pp in T.grid(b, c, p):
            with T.block("cache_update"):
                vb, vc, vp = T.axis.remap("SSS", [bb, cc, pp])
                T.reads(res[vb, vc, out - p + vp])
                T.writes(cache[vb, vc, vp])
                cache[vb, vc, vp] = res[vb, vc, out - p + vp]

    @T.prim_func(private=True)
    def fused_conv1d_add(x: T.Buffer((1, 512, 62), "float32"), weight: T.Buffer((512, 512, 7), "float32"), bias: T.Buffer((1, 512, 1), "float32"), out: T.Buffer((1, 512, 8), "float32")):
        T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        out_local = T.alloc_buffer((1, 512, 8), scope="local")
        pad_shared = T.alloc_buffer((1, 512, 62), scope="shared")
        weight_shared = T.alloc_buffer((512, 512, 7), scope="shared")
        for lbx in T.thread_binding(64, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for ltx in T.thread_binding(64, thread="threadIdx.x"):
                for yy in range(1):
                    with T.block("conv1d_init"):
                        v_ff = T.axis.spatial(512, lbx * 8 + ltx // 8)
                        v_yy = T.axis.spatial(8, ltx % 8 + yy)
                        T.reads()
                        T.writes(out_local[0, v_ff, v_yy])
                        out_local[0, v_ff, v_yy] = T.float32(0.0)
                for ic_outer in range(8):
                    for shared_fused_outer in range(16):
                        for shared_tid in T.thread_binding(64, thread="threadIdx.x"):
                            for vec_idx in T.vectorized(4):
                                with T.block("pad_shared"):
                                    v1 = T.axis.spatial(512, ic_outer * 64 + (shared_fused_outer * 256 + shared_tid * 4 + vec_idx) // 62)
                                    v2 = T.axis.spatial(62, (shared_fused_outer * 256 + shared_tid * 4 + vec_idx) % 62)
                                    T.where((shared_fused_outer * 64 + shared_tid) * 4 + vec_idx < 3968)
                                    T.reads(x[0, v1, v2])
                                    T.writes(pad_shared[0, v1, v2])
                                    pad_shared[0, v1, v2] = x[0, v1, v2]
                    for weight_outer in range(56):
                        for shared_tid in T.thread_binding(64, thread="threadIdx.x"):
                            with T.block("wnconv1d_shared"):
                                v0 = T.axis.spatial(512, lbx * 8 + (weight_outer * 64 + shared_tid) // 448)
                                v1 = T.axis.spatial(512, ic_outer * 64 + (weight_outer * 64 + shared_tid) % 448 // 7)
                                v2 = T.axis.spatial(7, (weight_outer * 64 + shared_tid) % 7)
                                T.reads(weight[v0, v1, v2])
                                T.writes(weight_shared[v0, v1, v2])
                                weight_shared[v0, v1, v2] = weight[v0, v1, v2]
                    for rc_1, yy, rc_2, ry_2 in T.grid(16, 1, 4, 7):
                        with T.block("conv1d_ncw_update"):
                            v_ff = T.axis.spatial(512, lbx * 8 + ltx // 8)
                            v_yy = T.axis.spatial(8, ltx % 8 + yy)
                            v_rc = T.axis.reduce(512, ic_outer * 64 + rc_1 * 4 + rc_2)
                            v_ry = T.axis.reduce(7, ry_2)
                            T.reads(out_local[0, v_ff, v_yy], pad_shared[0, v_rc, v_ry * 9 + v_yy], weight_shared[v_ff, v_rc, v_ry])
                            T.writes(out_local[0, v_ff, v_yy])
                            out_local[0, v_ff, v_yy] = out_local[0, v_ff, v_yy] + pad_shared[0, v_rc, v_ry * 9 + v_yy] * weight_shared[v_ff, v_rc, v_ry]
                for yy in range(1):
                    with T.block("conv1d_ncw_intermediate_local"):
                        v1 = T.axis.spatial(512, lbx * 8 + ltx // 8)
                        v2 = T.axis.spatial(8, ltx % 8 + yy)
                        T.reads(out_local[0, v1, v2], bias[0, v1, 0])
                        T.writes(out[0, v1, v2])
                        out[0, v1, v2] = out_local[0, v1, v2] + bias[0, v1, 0]

    @T.prim_func(private=True)
    def fused_tir_sqrt_divide_multiply(lv4: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32"), weight_v: T.Buffer((T.int64(512), T.int64(512), T.int64(7)), "float32"), weight_g: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32"), T_multiply_intermediate: T.Buffer((T.int64(512), T.int64(512), T.int64(7)), "float32")):
        T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                for ax0_ax1_ax2_fused_0 in range(T.int64(7)):
                    with T.block("T_divide"):
                        v_ax0 = T.axis.spatial(T.int64(512), (ax0_ax1_ax2_fused_0 * T.int64(262144) + ax0_ax1_ax2_fused_1 * T.int64(1024) + ax0_ax1_ax2_fused_2) // T.int64(3584))
                        v_ax1 = T.axis.spatial(T.int64(512), (ax0_ax1_ax2_fused_0 * T.int64(262144) + ax0_ax1_ax2_fused_1 * T.int64(1024) + ax0_ax1_ax2_fused_2) % T.int64(3584) // T.int64(7))
                        v_ax2 = T.axis.spatial(T.int64(7), (ax0_ax1_ax2_fused_0 * T.int64(262144) + ax0_ax1_ax2_fused_1 * T.int64(1024) + ax0_ax1_ax2_fused_2) % T.int64(7))
                        T.reads(weight_g[v_ax0, T.int64(0), T.int64(0)], weight_v[v_ax0, v_ax1, v_ax2], lv4[v_ax0, T.int64(0), T.int64(0)])
                        T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                        T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = weight_g[v_ax0, T.int64(0), T.int64(0)] * (weight_v[v_ax0, v_ax1, v_ax2] / T.sqrt(lv4[v_ax0, T.int64(0), T.int64(0)]))

    @T.prim_func(private=True)
    def fused_tir_square_sum(weight_v: T.Buffer((T.int64(512), T.int64(512), T.int64(7)), "float32"), lv3_red_intermediate: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32")):
        T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_ax2_fused in T.thread_binding(T.int64(512), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}):
            for k1_k2_fused_0 in range(T.int64(112)):
                for k1_k2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    with T.block("lv3_red"):
                        v_ax0 = T.axis.spatial(T.int64(512), ax0_ax1_ax2_fused)
                        v_ax1 = T.axis.spatial(T.int64(1), T.int64(0))
                        v_ax2 = T.axis.spatial(T.int64(1), T.int64(0))
                        v_k1 = T.axis.reduce(T.int64(512), (k1_k2_fused_0 * T.int64(32) + k1_k2_fused_1) // T.int64(7))
                        v_k2 = T.axis.reduce(T.int64(7), (k1_k2_fused_0 * T.int64(32) + k1_k2_fused_1) % T.int64(7))
                        T.reads(weight_v[v_ax0, v_k1, v_k2])
                        T.writes(lv3_red_intermediate[v_ax0, v_ax1, v_ax2])
                        with T.init():
                            lv3_red_intermediate[v_ax0, v_ax1, v_ax2] = T.float32(0.0)
                        lv3_red_intermediate[v_ax0, v_ax1, v_ax2] = lv3_red_intermediate[v_ax0, v_ax1, v_ax2] + weight_v[v_ax0, v_k1, v_k2] * weight_v[v_ax0, v_k1, v_k2]

    @T.prim_func(private=True)
    def reshape(bias: T.Buffer((T.int64(512),), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(512), T.int64(1)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                    v_ax1 = T.axis.spatial(T.int64(512), ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1)
                    v_ax2 = T.axis.spatial(T.int64(1), T.int64(0))
                    T.reads(bias[(v_ax1 + v_ax2) % T.int64(512)])
                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                    T_reshape[v_ax0, v_ax1, v_ax2] = bias[(v_ax1 + v_ax2) % T.int64(512)]

    @R.function
    def _initialize_effect() -> R.Tuple(R.Object, R.Object, R.Object):
        cls = Module
        with R.dataflow():
            _io: R.Object = R.null_value()
            cache_cache: R.Object = R.call_pure_packed("vm.builtin.cached_padding_1d_create", R.prim_value(0), R.prim_value(0), R.prim_value(32), R.prim_value(0), cls.cached_padding_1d_init, cls.cached_padding_1d_update, cls.cached_padding_1d_crop, sinfo_args=(R.Object,))
            downsampling_delay_cache: R.Object = R.call_pure_packed("vm.builtin.cached_padding_1d_create", R.prim_value(0), R.prim_value(1), R.prim_value(32), R.prim_value(0), cls.cached_padding_1d_init, cls.cached_padding_1d_update, cls.cached_padding_1d_crop, sinfo_args=(R.Object,))
            gv: R.Tuple(R.Object, R.Object, R.Object) = _io, cache_cache, downsampling_delay_cache
            R.output(gv)
        return gv

    @R.function
    def forward(x: R.Tensor((1, 512, 62), dtype="float32"), _io: R.Object, cache_cache: R.Object, downsampling_delay_cache: R.Object, weight_g: R.Tensor((512, 1, 1), dtype="float32"), weight_v: R.Tensor((512, 512, 7), dtype="float32"), bias: R.Tensor((512,), dtype="float32")) -> R.Tuple(R.Tensor((1, 512, 8), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object)):
        R.func_attr({"num_input": 4})
        cls = Module
        with R.dataflow():
            lv1: R.Tensor((1, 512, 62), dtype="float32") = R.call_pure_packed("vm.builtin.cached_padding_1d_update", downsampling_delay_cache, x, sinfo_args=(R.Tensor((1, 512, 62), dtype="float32"),))
            lv2: R.Tensor((1, 512, 62), dtype="float32") = R.call_pure_packed("vm.builtin.cached_padding_1d_update", cache_cache, lv1, sinfo_args=(R.Tensor((1, 512, 62), dtype="float32"),))
            lv = R.call_tir(cls.fused_tir_square_sum, (weight_v,), out_sinfo=R.Tensor((512, 1, 1), dtype="float32"))
            lv1_1 = R.call_tir(cls.fused_tir_sqrt_divide_multiply, (lv, weight_v, weight_g), out_sinfo=R.Tensor((512, 512, 7), dtype="float32"))
            lv8 = R.call_tir(cls.reshape, (bias,), out_sinfo=R.Tensor((1, 512, 1), dtype="float32"))
            lv2_1 = R.call_tir(cls.fused_conv1d_add, (lv2, lv1_1, lv8), out_sinfo=R.Tensor((1, 512, 8), dtype="float32"))
            gv1: R.Tuple(R.Tensor((1, 512, 8), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object)) = lv2_1, (_io, cache_cache, downsampling_delay_cache)
            R.output(gv1)
        return gv1

This is how we build and run the auto-tuned TIR kernel

In [7]:
import tvm

from tvm import relax
from tvm.relax.frontend import nn

from typing import Optional
from tvm import te
from tvm import dlight as dl
from tvm.target import Target
import numpy as np

mod = Module

device = tvm.metal()
target = Target.from_device("metal")
with target:
    seq = dl.ApplyDefaultSchedule(
        dl.gpu.Fallback(),
    )
    vm_mod = seq(mod)

ex = relax.build(vm_mod, target)
vm = relax.VirtualMachine(ex, device, profile=True)
effects = vm.module["_initialize_effect"]()

np.random.seed(0)

weight_g = np.random.randn(512, 1, 1).astype("float32")
weight_v = np.random.randn(512, 512, 7).astype("float32")
bias = np.random.randn(512).astype("float32")

tvm_params = [weight_g, weight_v, bias]
tvm_params = [tvm.nd.array(param, device=device) for param in tvm_params]

audio_data = np.random.randn(1, 512, 62).astype("float32")
audio_data = tvm.nd.array(audio_data, device=device)

time_eval = vm.time_evaluator("forward", device, 10, 5)(audio_data, *effects, *tvm_params)
print(time_eval)

out = vm.module["forward"](audio_data, *effects, *tvm_params)
print(out[0].asnumpy())

time_eval = vm.time_evaluator("forward", device, 10, 5)(audio_data, *effects, *tvm_params)
print(time_eval)

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   0.1427       0.1405       0.1633       0.1308       0.0111                  
[[[ 2.824526   -1.3132825   3.8338013  ... -1.8539417  -0.52261025
   -0.92020506]
  [ 2.191527    1.5215199   1.6404716  ...  1.4687572   0.7257114
    2.0899653 ]
  [ 1.780653    1.9595389   1.6723082  ... -0.44428122  1.0687006
    0.28043497]
  ...
  [-2.4418955  -2.6515665  -1.9757025  ... -1.0010937  -1.3631712
   -0.07432413]
  [ 2.2440364   1.9657502   1.8604834  ... -1.0799818   1.519831
    0.00400889]
  [-0.43892244  0.04964514 -1.30905    ... -1.3583254  -1.4844707
   -0.57526046]]]
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   0.1312       0.1333       0.1388       0.1177       0.0072                  


Now, we try to write a kernel with the optimized schedule (Notice the tiling size is the same) for generic input shapes. The `conv1d` kernel takes shapes as inputs and generates a TIR kernel with the schedule

In [1]:
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

def get_conv1d_fn(x, weight, stride, dilation):
    NUM_BLKS = 64
    bdx = 64

    (_, C_in, N), model_dtype = x.shape, x.dtype
    (C_out, C_in, K), storage_dtype = weight.shape, weight.dtype
    O = (N - dilation * (K - 1) - 1) // stride + 1

    out_ch_per_block = (C_out + NUM_BLKS - 1) // NUM_BLKS
    threads_per_out_ch = bdx // out_ch_per_block
    spatial_per_thread = (O + threads_per_out_ch - 1) // threads_per_out_ch

    assert C_in % 16 == 0, "C_in must be a multiple of 16"
    in_ch_per_block = C_in // 8
    weight_per_block = K * in_ch_per_block

    rc1 = 16
    rc2 = in_ch_per_block // rc1

    vec_blk = 4 * bdx

    @T.prim_func()
    def fused_conv1d_add(var_x: T.handle, weight: T.Buffer((C_out, C_in, K), storage_dtype), bias: T.Buffer((C_out), storage_dtype), var_out: T.handle):
        # T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        x = T.match_buffer(var_x, (1, C_in, N), model_dtype)
        out = T.match_buffer(var_out, (1, C_out, O), model_dtype)
        
        out_local = T.alloc_buffer((1, C_out, O), scope="local")
        pad_shared = T.alloc_buffer((1, C_in, N), scope="shared")
        weight_shared = T.alloc_buffer((C_out, C_in, K), scope="shared")

        for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for vtx in T.thread_binding(T.int64(1), thread="vthread.x"):
                for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
                    for yy in range(spatial_per_thread):
                        with T.block("conv1d_init"):
                            v_ff = T.axis.spatial(C_out, lbx * out_ch_per_block + ltx // threads_per_out_ch)
                            v_yy = T.axis.spatial(O, ltx % threads_per_out_ch * spatial_per_thread + yy)
                            T.reads()
                            T.writes(out_local[0, v_ff, v_yy])
                            out_local[0, v_ff, v_yy] = T.float32(0.0)
                    for ic_outer in range(T.int32(8)):
                        for shared_fused_outer in range(T.ceildiv(in_ch_per_block * N, bdx * 4)):
                            for shared_tid in T.thread_binding(bdx, thread="threadIdx.x"):
                                for vec_idx in T.vectorized(T.int32(4)):
                                    with T.block("pad_shared"):
                                        v1 = T.axis.spatial(C_in, ic_outer * in_ch_per_block + (shared_fused_outer * vec_blk + shared_tid * T.int32(4) + vec_idx) // N)
                                        v2 = T.axis.spatial(N, (shared_fused_outer * vec_blk + shared_tid * T.int32(4) + vec_idx) % N)
                                        T.where((shared_fused_outer * bdx + shared_tid) * T.int32(4) + vec_idx < in_ch_per_block * N)
                                        T.reads(x[0, v1, v2])
                                        T.writes(pad_shared[0, v1, v2])
                                        pad_shared[0, v1, v2] = x[0, v1, v2]
                        for weight_outer in range(T.ceildiv(weight_per_block * out_ch_per_block, bdx)):
                            for shared_tid in T.thread_binding(bdx, thread="threadIdx.x"):
                                with T.block("wnconv1d_shared"):
                                    v0 = T.axis.spatial(C_out, lbx * out_ch_per_block + (weight_outer * bdx + shared_tid) // weight_per_block)
                                    v1 = T.axis.spatial(C_in, ic_outer * in_ch_per_block + (weight_outer * bdx + shared_tid) % weight_per_block // K)
                                    v2 = T.axis.spatial(K, (weight_outer * bdx + shared_tid) % K)
                                    T.reads(weight[v0, v1, v2])
                                    T.writes(weight_shared[v0, v1, v2])
                                    weight_shared[v0, v1, v2] = weight[v0, v1, v2]
                        for rc_1, yy, rc_2, ry_2 in T.grid(rc1, spatial_per_thread, rc2, K):
                            with T.block("conv1d_ncw_update"):
                                v_ff = T.axis.spatial(C_out, lbx * out_ch_per_block + ltx // threads_per_out_ch)
                                v_yy = T.axis.spatial(O, ltx % threads_per_out_ch * spatial_per_thread + yy)
                                v_rc = T.axis.reduce(C_in, ic_outer * in_ch_per_block + rc_1 * rc2 + rc_2)
                                v_ry = T.axis.reduce(K, ry_2)
                                T.reads(out_local[0, v_ff, v_yy], pad_shared[0, v_rc, v_ry * dilation + v_yy * stride], weight_shared[v_ff, v_rc, v_ry])
                                T.writes(out_local[0, v_ff, v_yy])
                                out_local[0, v_ff, v_yy] = out_local[0, v_ff, v_yy] + pad_shared[0, v_rc, v_ry * dilation + v_yy * stride] * weight_shared[v_ff, v_rc, v_ry]
                    for yy in range(spatial_per_thread):
                        with T.block("conv1d_ncw_intermediate_local"):
                            v1 = T.axis.spatial(C_out, lbx * out_ch_per_block + ltx // threads_per_out_ch)
                            v2 = T.axis.spatial(O, ltx % threads_per_out_ch * spatial_per_thread + yy)
                            T.reads(out_local[0, v1, v2], bias[v1])
                            T.writes(out[0, v1, v2])
                            out[0, v1, v2] = out_local[0, v1, v2] + bias[v1]

    return fused_conv1d_add

    # return nn.op.tensor_ir_op(
    #     fused_conv1d_add,
    #     "fused_conv1d_add",
    #     args=[x, weight, bias],
    #     out=nn.Tensor.placeholder((1, C_out, O), dtype=model_dtype),
    # )

[01:37:15] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`
[01:37:15] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`
[01:37:15] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`


We can then evaluate the performance and correctness of the generated kernel

In [3]:
import tvm

from tvm import relax

from tvm import dlight as dl
from tvm.target import Target
import numpy as np
import torch

mod = tvm.IRModule()
mod["fused_conv1d_add"] = get_conv1d_fn(
    tvm.te.placeholder((1, 512, 62), dtype="float32"),
    tvm.te.placeholder((512, 512, 7), dtype="float32"),
    1,
    9,
)
mod.show()

device = tvm.metal()
target = Target.from_device("metal")
with target:
    seq = dl.ApplyDefaultSchedule(
        dl.gpu.Fallback(),
    )
    vm_mod = seq(mod)

rt_mod = tvm.build(vm_mod, target=target)

np.random.seed(0)

weight = np.random.randn(512, 512, 7).astype("float32")
bias = np.random.randn(512).astype("float32")

tvm_params = [weight, bias]
tvm_params = [tvm.nd.array(param, device=device) for param in tvm_params]

out = tvm.nd.empty((1, 512, 8), device=device)

audio_data = np.random.randn(1, 512, 62).astype("float32")
audio_data = tvm.nd.array(audio_data, device=device)

time_eval = rt_mod.time_evaluator("fused_conv1d_add", device, 10, 5)(audio_data, *tvm_params, out)
print(time_eval)

rt_mod["fused_conv1d_add"](audio_data, *tvm_params, out)
out_relax = out.asnumpy()
out_torch = torch.nn.functional.conv1d(torch.tensor(audio_data.asnumpy()), torch.tensor(weight), torch.tensor(bias), stride=1, dilation=9)
print(np.allclose(out_relax, out_torch.numpy(), atol=1e-5))

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   0.0754       0.0722       0.0927       0.0692       0.0088                  
True


Let's now measure the speedup before and after optimization

In [19]:
np.random.seed(0)

weight = np.random.randn(512, 512, 7).astype("float32")
bias = np.random.randn(512).astype("float32")
data = np.random.randn(1, 512, 62).astype("float32")

# Pytorch MPS

import torch
from tvm.relax.frontend import nn

weight_torch = torch.tensor(weight).to("mps")
bias_torch = torch.tensor(bias).to("mps")
data_torch = torch.tensor(data).to("mps")

conv1d_torch = torch.nn.Conv1d(512, 512, 7, stride=1, padding=0, dilation=9).to("mps")
conv1d_torch.weight.data = weight_torch
conv1d_torch.bias.data = bias_torch

# Pytorch CPU

weight_torch_cpu = torch.tensor(weight).to("cpu")
bias_torch_cpu = torch.tensor(bias).to("cpu")
data_torch_cpu = torch.tensor(data).to("cpu")

conv1d_torch_cpu = torch.nn.Conv1d(512, 512, 7, stride=1, padding=0, dilation=9).to("cpu")
conv1d_torch_cpu.weight.data = weight_torch_cpu
conv1d_torch_cpu.bias.data = bias_torch_cpu

import time

def time_it(func, warmup=10, runs=100):
    with torch.no_grad():
        # Warmup
        for _ in range(warmup):
            func()
        
        # Timing
        start = time.perf_counter()
        for _ in range(runs):
            func()
        avg_time = (time.perf_counter() - start) / runs
    return f"{avg_time*1000:.4f}ms"

print("Torch with MPS")
result = time_it(lambda: conv1d_torch(data_torch))
print(result)

print("Torch with CPU")
result = time_it(lambda: conv1d_torch_cpu(data_torch_cpu))
print(result)

# Before

relax_mod = nn.Conv1D(512, 512, 7, stride=1, padding=0, dilation=9)
mod, _ = relax_mod.export_tvm({
    "forward": {
        "x": nn.spec.Tensor([1, 512, 62], "float32"),
    }
})

device = tvm.metal()
target = Target.from_device("metal")
with target:
    seq = tvm.transform.Sequential(
        [
            tvm.relax.transform.LegalizeOps(),
            tvm.relax.transform.AnnotateTIROpPattern(),
            tvm.relax.transform.FoldConstant(),
            tvm.relax.transform.FuseOps(),
            tvm.relax.transform.FuseTIR(),
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )
    vm_mod = seq(mod)

ex = relax.build(vm_mod, target)
vm = relax.VirtualMachine(ex, device, profile=True)

tvm_params = [weight, bias]
tvm_params = [tvm.nd.array(param, device=device) for param in tvm_params]

data = tvm.nd.array(data, device=device)

time_eval = vm.time_evaluator("forward", device, 10, 5)(data, *tvm_params)
print(time_eval)

def time_tvm(fn, data, tvm_params, warmup=10, runs=100):
    # Warmup
    for _ in range(warmup):
        fn(data, *tvm_params)
    
    # Timing
    start = time.perf_counter()
    for _ in range(runs):
        fn(data, *tvm_params)
    avg_time = (time.perf_counter() - start) / runs
    return f"{avg_time*1000:.4f}ms"

print("TVM without tuning")
avg_time = time_tvm(vm["forward"], data, tvm_params)
print(avg_time)

# After


bb = relax.BlockBuilder()
conv1d_fn = get_conv1d_fn(
    tvm.te.placeholder((1, 512, 62), dtype="float32"),
    tvm.te.placeholder((512, 512, 7), dtype="float32"),
    1,
    9,
)

x = relax.Var("x", R.Tensor([1, 512, 62], "float32"))
weight = relax.Var("weight", R.Tensor([512, 512, 7], "float32"))
bias = relax.Var("bias", R.Tensor([512], "float32"))

with bb.function("forward", [x, weight, bias]):
    with bb.dataflow():
        tir_conv1d = bb.add_func(conv1d_fn, "fused_conv1d_add")
        gv = bb.emit(
            relax.call_tir(tir_conv1d, [x, weight, bias], out_sinfo=relax.TensorStructInfo([1, 512, 8], "float32"))
        )
        bb.emit_output(gv)
    bb.emit_func_output(gv)

mod = bb.get()

with target:
    seq = tvm.transform.Sequential(
        [
            tvm.relax.transform.LegalizeOps(),
            tvm.relax.transform.AnnotateTIROpPattern(),
            tvm.relax.transform.FoldConstant(),
            tvm.relax.transform.FuseOps(),
            tvm.relax.transform.FuseTIR(),
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )
    vm_mod = seq(mod)

np.random.seed(0)

weight = np.random.randn(512, 512, 7).astype("float32")
bias = np.random.randn(512).astype("float32")

tvm_params = [weight, bias]
tvm_params = [tvm.nd.array(param, device=device) for param in tvm_params]

ex = relax.build(vm_mod, target=target)
vm = relax.VirtualMachine(ex, device, profile=True)
time_eval_after = vm.time_evaluator("forward", device, 10, 5)(data, *tvm_params)
print(time_eval_after)


print("TVM after tuning")
avg_time = avg_time = time_tvm(vm["forward"], data, tvm_params)
print(avg_time)

# report = vm.profile("forward", data, *tvm_params)
# print(report)

Torch with MPS
0.0425ms
Torch with CPU
0.6068ms
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   0.4154       0.3763       0.5601       0.3617       0.0747                  
TVM without tuning
0.1573ms
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   0.0820       0.0784       0.0890       0.0770       0.0052                  
TVM after tuning
0.0235ms
