In [1]:
from typing import List
import numpy as np
import tvm
import tvm.testing
from tvm.script import tir as T

# Look up the runtime builtins that manage the cached-padding object. They are
# fetched in create / update / view order; `fforward` is the handle to the
# `..._update` builtin (distinct from the compiled TIR kernel `fupdate` below).
fcreate, fforward, fview = (
    tvm.get_global_func(builtin_name)
    for builtin_name in (
        "vm.builtin.cached_padding_1d_create",
        "vm.builtin.cached_padding_1d_update",
        "vm.builtin.cached_padding_1d_view",
    )
)


@T.prim_func
def cached_padding_1d_init(
    var_cache: T.handle,
):
    # Fill every element of the float32 cache buffer with 0.0.
    # The three extents (b, c, p) are symbolic and are bound from the shape of
    # the buffer passed in at call time.
    b = T.int32()
    c = T.int32()
    p = T.int32()

    cache = T.match_buffer(var_cache, (b, c, p), "float32")
    for i_b, i_c, i_p in T.grid(b, c, p):
        with T.block("cache_init"):
            v_b = T.axis.spatial(b, i_b)
            v_c = T.axis.spatial(c, i_c)
            v_p = T.axis.spatial(p, i_p)
            cache[v_b, v_c, v_p] = 0.0


@T.prim_func
def cached_padding_1d_update(
    var_cache: T.handle,
    var_data: T.handle,
    var_res: T.handle,
):
    # Prepend the cached prefix to `data` and refresh the cache.
    #
    # Step 1 ("res_update"): along the last axis, res[..., :p] comes from
    #   cache and res[..., p:out] comes from data[..., 0:out - p], i.e. res is
    #   the cache followed by data.
    # Step 2 ("cache_update"): cache[..., 0:p] = res[..., out - p:out], i.e.
    #   the cache keeps the last p positions of res for the next call.
    #
    # The two loops must stay in this order: step 2 overwrites `cache`, which
    # step 1 reads.
    #
    # NOTE(review): nothing in this kernel ties `out` to p + n or `b` to B.
    # It presumably relies on the caller allocating res with out == p + n and
    # passing b <= B; otherwise the data/cache reads go out of bounds.
    # Confirm against the runtime builtin that invokes it.

    # Cache extents. B is bound independently of the data batch b; only the
    # first b rows of the cache are read and written below.
    B = T.int32()
    c = T.int32()
    p = T.int32()

    # Input chunk length n and result length out (symbolic, bound per call).
    b = T.int32()
    n = T.int32()
    out = T.int32()

    cache = T.match_buffer(var_cache, (B, c, p), "float32")
    data = T.match_buffer(var_data, (b, c, n), "float32")
    res = T.match_buffer(var_res, (b, c, out), "float32")

    for bb, cc, oo in T.grid(b, c, out):
        with T.block("res_update"):
            vb, vc, vo = T.axis.remap("SSS", [bb, cc, oo])
            # if_then_else evaluates only the selected branch, so data is not
            # read at the negative index vo - p when vo < p.
            res[vb, vc, vo] = T.if_then_else(
                vo < p, cache[vb, vc, vo], data[vb, vc, vo - p]
            )

    for bb, cc, pp in T.grid(b, c, p):
        with T.block("cache_update"):
            vb, vc, vp = T.axis.remap("SSS", [bb, cc, pp])
            # Keep the trailing p positions of the freshly built result.
            cache[vb, vc, vp] = res[vb, vc, out - p + vp]


@T.prim_func
def cached_padding_1d_crop(
    var_x: T.handle,
    var_res: T.handle,
):
    # Copy the leading n positions of the last axis of x into res:
    # res[i, j, k] = x[i, j, k] for k < n.
    # NOTE(review): no check here that n <= out; the caller is presumably
    # expected to guarantee it — confirm.
    b = T.int32()
    c = T.int32()
    n = T.int32()
    out = T.int32()

    x = T.match_buffer(var_x, (b, c, out), "float32")
    res = T.match_buffer(var_res, (b, c, n), "float32")

    for i_b, i_c, i_n in T.grid(b, c, n):
        with T.block("res_crop"):
            v_b = T.axis.spatial(b, i_b)
            v_c = T.axis.spatial(c, i_c)
            v_n = T.axis.spatial(n, i_n)
            res[v_b, v_c, v_n] = x[v_b, v_c, v_n]


def build_tir_func(tir_funcs: List[tvm.tir.PrimFunc], target="llvm"):
    """Compile each PrimFunc on its own and return the entry functions.

    Args:
        tir_funcs: PrimFuncs to build; each one gets a separate `tvm.build`.
        target: Target passed to every `tvm.build` call.

    Returns:
        A list with the `entry_func` of each built module, in the same order
        as `tir_funcs`.
    """
    entry_funcs = []
    for prim_func in tir_funcs:
        runtime_module = tvm.build(prim_func, target=target)
        entry_funcs.append(runtime_module.entry_func)
    return entry_funcs


finit, fupdate, fcrop = build_tir_func(
    [cached_padding_1d_init, cached_padding_1d_update, cached_padding_1d_crop]
)

# Create the runtime cache object from the three compiled kernels.
# NOTE(review): the positional arguments 3, True and 1 are forwarded unchanged
# to vm.builtin.cached_padding_1d_create; their meaning is not visible here
# (note that 3 equals both the channel count and the view width below) —
# confirm against that builtin's signature.
cache = fcreate(
    3,
    True,
    1,
    finit,
    fupdate,
    fcrop,
)

# Seeded generator so reruns print the same input and cached values.
rng = np.random.default_rng(0)
x_np = rng.random((1, 3, 10), dtype=np.float32)
x = tvm.nd.array(x_np)
print(x_np)

# Push one chunk through the cache and read the cached window back.
# NDArray.asnumpy() is deprecated in TVM; NDArray.numpy() is the replacement.
fforward(cache, x)
view = fview(cache).numpy()

# After a single update, the view should equal the trailing positions of the
# input along the last axis (this matches the captured notebook output).
np.testing.assert_array_equal(view, x_np[..., -view.shape[-1]:])

# Bare expression so the notebook displays the view as the cell result.
view

[20:55:10] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`
[20:55:10] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`
[20:55:10] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`


[[[0.68613315 0.46777123 0.4602704  0.7681316  0.4725257  0.9414942
   0.16952671 0.6756237  0.53193885 0.5842968 ]
  [0.00798621 0.3718034  0.6281481  0.55726165 0.13244139 0.5540062
   0.97739136 0.96519315 0.6632463  0.21178755]
  [0.57561857 0.9785503  0.49822545 0.00793451 0.8648276  0.6589372
   0.16547087 0.34171608 0.55715764 0.6491728 ]]]


array([[[0.6756237 , 0.53193885, 0.5842968 ],
        [0.96519315, 0.6632463 , 0.21178755],
        [0.34171608, 0.55715764, 0.6491728 ]]], dtype=float32)