In [1]:
from __future__ import absolute_import, print_function

import tvm
from tvm import te
import tvm.testing
import numpy as np

  from pandas import MultiIndex, Int64Index


In [2]:
# Problem sizes: A is (N, L), B is (M, L) and the result C is (N, M).
N, M, L = 1024, 512, 64
A = te.placeholder((N, L), name="A")
B = te.placeholder((M, L), name="B")
# Reduction axis over the shared dimension of length L.
k = te.reduce_axis((0, L), name="k")
# C[i, j] = sum_k A[i, k] * B[j, k]; B is indexed row-wise, so this is A times B transposed.
C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")
# Create a schedule with no primitives applied yet and print its lowered IR.
s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 64), "float32"), B: T.Buffer((512, 64), "float32"), C: T.Buffer((1024, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        for i, j in T.grid(1024, 512):
            C_1 = T.Buffer((524288,), data=C.data)
            C_1[i * 512 + j] = T.float32(0)
            for k in range(64):
                cse_var_1: T.int32 = i * 512 + j
                A_1 = T.Buffer((65536,), data=A.data)
                B_1 = T.Buffer((32768,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[i * 64 + k] * B_1[j * 64 + k]


In [3]:
# Split factor; later cells reuse it both as a split factor and as an intrinsic size.
factor = 16
x, y = C.op.axis
(z,) = C.op.reduce_axis
# Split y into an outer loop (yo) and an inner loop (yi) using `factor`.
yo, yi = s[C].split(y, factor=factor)
# Loop order from outermost to innermost: x, yo, yi, then the reduction axis z.
s[C].reorder(x, yo, yi, z)
print(tvm.lower(s, [A, B, C], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 64), "float32"), B: T.Buffer((512, 64), "float32"), C: T.Buffer((1024, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        for i, j_outer, j_inner in T.grid(1024, 32, 16):
            C_1 = T.Buffer((524288,), data=C.data)
            C_1[i * 512 + j_outer * 16 + j_inner] = T.float32(0)
            for k in range(64):
                cse_var_1: T.int32 = i * 512 + j_outer * 16 + j_inner
                A_1 = T.Buffer((65536,), data=A.data)
                B_1 = T.Buffer((32768,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[i * 64 + k] * B_1[j_outer * 1024 + j_inner * 64 + k]


In [4]:
def intrin_gemv(m, l):
    """Declare a tensor intrinsic for a matrix-vector product.

    The declared computation is ``c[i] = sum_k a[k] * b[i, k]`` where ``a`` has
    shape ``(l,)``, ``b`` has shape ``(m, l)`` and ``c`` has shape ``(m,)``.
    The intrinsic is implemented as one extern call to ``gemv_update``.

    Parameters
    ----------
    m : extent of the output vector (number of rows of ``b``).
    l : extent of the reduction axis.

    Returns
    -------
    The value returned by ``te.decl_tensor_intrin`` for this computation.
    """
    vec = te.placeholder((l,), name="a")
    mat = te.placeholder((m, l), name="b")
    red = te.reduce_axis((0, l), name="k")
    out = te.compute((m,), lambda i: te.sum(vec[red] * mat[i, red], axis=red), name="c")

    # Buffer declarations bound to the tensors above. The matrix buffer uses a
    # symbolic row stride ("s1") while its innermost stride is fixed to 1.
    vec_buf = tvm.tir.decl_buffer(vec.shape, vec.dtype, name="A", offset_factor=1, strides=[1])
    mat_buf = tvm.tir.decl_buffer(
        mat.shape, mat.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1]
    )
    out_buf = tvm.tir.decl_buffer(out.shape, out.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        # Two inputs (vector, matrix) and a single output are expected.
        in_vec, in_mat = ins
        result = outs[0]
        # gemv_update(cc, aa, bb, m, l, stride): the last argument is the row
        # stride of the matrix buffer.
        extern_call = tvm.tir.call_extern(
            "int32",
            "gemv_update",
            result.access_ptr("w"),
            in_vec.access_ptr("r"),
            in_mat.access_ptr("r"),
            m,
            l,
            in_mat.strides[0],
        )
        builder = tvm.tir.ir_builder.create()
        builder.emit(extern_call)
        return builder.get()

    bindings = {vec: vec_buf, mat: mat_buf, out: out_buf}
    return te.decl_tensor_intrin(out.op, intrin_func, binds=bindings)

In [5]:
# Build a GEMV intrinsic with m = factor and l = L, then tensorize at yi.
# The lowered IR printed by this cell shows a single gemv_update call inside
# the (i, j_outer) loops.
gemv = intrin_gemv(factor, L)
s[C].tensorize(yi, gemv)
print(tvm.lower(s, [A, B, C], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 64), "float32"), B: T.Buffer((512, 64), "float32"), C: T.Buffer((1024, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        for i, j_outer in T.grid(1024, 32):
            T.call_extern("int32", "gemv_update", T.tvm_access_ptr(T.type_annotation("float32"), C.data, i * 512 + j_outer * 16, 16, 2), T.tvm_access_ptr(T.type_annotation("float32"), A.data, i * 64, 64, 1), T.tvm_access_ptr(T.type_annotation("float32"), B.data, j_outer * 1024, 1024, 1), 16, 64, 64)


In [6]:
def gemv_impl():
    """Compile the C implementation of ``gemv_update`` to LLVM IR.

    The C function accumulates into ``cc`` with ``+=``; it does not clear
    ``cc`` first.

    Returns
    -------
    The value returned by ``clang.create_llvm`` for the embedded C source.
    """
    from tvm.contrib import clang, utils

    c_source = """
      extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < l; ++j) {
                cc[i] += aa[j] * bb[i * stride + j];
            }
        }
        return 0;
      }
    """

    scratch = utils.tempdir()
    output_path = scratch.relpath("temp.ll")
    # Create LLVM IR from the C source code.
    return clang.create_llvm(c_source, output=output_path)

In [7]:
# Attach the LLVM IR returned by gemv_impl() to axis x through the
# "import_llvm" pragma; it appears as pragma_import_llvm in the lowered IR.
s[C].pragma(x, "import_llvm", gemv_impl())
print(tvm.lower(s, [A, B, C], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 64), "float32"), B: T.Buffer((512, 64), "float32"), C: T.Buffer((1024, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        i = T.int32()
        T.attr(T.iter_var(i, None, "DataPar", ""), "pragma_import_llvm", metadata["tir.StringImm"][0])
        for i, j_outer in T.grid(1024, 32):
            T.call_extern("int32", "gemv_update", T.tvm_access_ptr(T.type_annotation("float32"), C.data, i * 512 + j_outer * 16, 16, 2), T.tvm_access_ptr(T.type_annotation("float32"), A.data, i * 64, 64, 1), T.tvm_access_ptr(T.type_annotation("float32"), B.data, j_outer * 1024, 1024, 1), 16, 64, 64)

# Metadata omitted. Use show_meta=True in script() method to show it.


In [8]:
from tvm.topi.utils import get_const_tuple

# Compile the tensorized schedule for the LLVM target.
func = tvm.build(s, [A, B, C], target="llvm", name="gemv")

dtype = A.dtype
dev = tvm.device("cpu", 0)

# Use a seeded generator so the randomized check is reproducible from run to run.
rng = np.random.default_rng(0)
a = rng.uniform(size=get_const_tuple(A.shape)).astype(dtype)
b = rng.uniform(size=get_const_tuple(B.shape)).astype(dtype)
# The output must start at zero: gemv_update accumulates with `+=` and the
# lowered IR contains no separate initialisation of C.
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), dev)
func(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c)
# Reference result: A times B transposed, matching the definition of C.
tvm.testing.assert_allclose(c.numpy(), np.dot(a, b.T), rtol=1e-3)

In [9]:
# Split the reduction axis z as well, then order the loops x, yo, zo, yi, zi so
# that the (yi, zi) pair is innermost.
# NOTE(review): this mutates the existing schedule `s` in place, so the cell is
# presumably not safe to run a second time on the same kernel -- confirm.
zo, zi = s[C].split(z, factor=factor)
s[C].reorder(x, yo, zo, yi, zi)

In [10]:
def gemv_impl():
    """Compile the C implementations of ``gemv_update`` and ``gemv_reset``.

    ``gemv_update`` accumulates into ``cc`` with ``+=``; ``gemv_reset`` writes
    zero to the first ``m`` entries of ``cc``.

    Returns
    -------
    The value returned by ``clang.create_llvm`` for the embedded C source.
    """
    from tvm.contrib import clang, utils

    c_source = """
      extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < l; ++j) {
                cc[i] += aa[j] * bb[i * stride + j];
            }
        }
        return 0;
      }
      extern "C" int gemv_reset(float *cc, int m) {
        for (int i = 0; i < m; ++i) {
            cc[i] = 0.0;
        }
        return 0;
      }
    """

    scratch = utils.tempdir()
    output_path = scratch.relpath("temp.ll")
    # Create LLVM IR from the C source code.
    return clang.create_llvm(c_source, output=output_path)

def intrin_gemv(m, l):
    """Declare a GEMV tensor intrinsic with separate reset and update steps.

    The declared computation is ``c[i] = sum_k a[k] * b[i, k]`` where ``a`` has
    shape ``(l,)``, ``b`` has shape ``(m, l)`` and ``c`` has shape ``(m,)``.
    The lowering function returns three statements: the body, a reduction
    reset (``gemv_reset``) and a reduction update (``gemv_update``).

    Parameters
    ----------
    m : extent of the output vector (number of rows of ``b``).
    l : extent of the reduction axis.

    Returns
    -------
    The value returned by ``te.decl_tensor_intrin`` for this computation.
    """
    vec = te.placeholder((l,), name="a")
    mat = te.placeholder((m, l), name="b")
    red = te.reduce_axis((0, l), name="k")
    out = te.compute((m,), lambda i: te.sum(vec[red] * mat[i, red], axis=red), name="c")

    # Buffer declarations bound to the tensors above. The matrix buffer uses a
    # symbolic row stride ("s1") while its innermost stride is fixed to 1.
    vec_buf = tvm.tir.decl_buffer(vec.shape, vec.dtype, name="A", offset_factor=1, strides=[1])
    mat_buf = tvm.tir.decl_buffer(
        mat.shape, mat.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1]
    )
    out_buf = tvm.tir.decl_buffer(out.shape, out.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        in_vec, in_mat = ins
        result = outs[0]

        def _as_stmt(extern_call):
            # Wrap a single extern call in a fresh IR builder and return the statement.
            builder = tvm.tir.ir_builder.create()
            builder.emit(extern_call)
            return builder.get()

        def _update_stmt():
            # Accumulating GEMV; used both as the body and as the reduction update.
            return _as_stmt(
                tvm.tir.call_extern(
                    "int32",
                    "gemv_update",
                    result.access_ptr("w"),
                    in_vec.access_ptr("r"),
                    in_mat.access_ptr("r"),
                    m,
                    l,
                    in_mat.strides[0],
                )
            )

        def _reset_stmt():
            # Clears the m output elements through gemv_reset.
            return _as_stmt(
                tvm.tir.call_extern("int32", "gemv_reset", result.access_ptr("w"), m)
            )

        # Order matters: (body, reduce_reset, reduce_update).
        return _update_stmt(), _reset_stmt(), _update_stmt()

    bindings = {vec: vec_buf, mat: mat_buf, out: out_buf}
    return te.decl_tensor_intrin(out.op, intrin_func, binds=bindings)

In [11]:
# Both tensorized axes (yi and zi) were split with `factor`, so the intrinsic
# is declared with m = factor and l = factor.
gemv = intrin_gemv(factor, factor)
s[C].tensorize(yi, gemv)
# Attach the LLVM IR that provides gemv_update and gemv_reset.
s[C].pragma(yo, "import_llvm", gemv_impl())

func = tvm.build(s, [A, B, C], target="llvm", name="gemv")

# `get_const_tuple`, `dtype` and `dev` come from the earlier build-and-verify cell.
# Use a seeded generator so the randomized check is reproducible from run to run.
rng = np.random.default_rng(0)
a = rng.uniform(size=get_const_tuple(A.shape)).astype(dtype)
b = rng.uniform(size=get_const_tuple(B.shape)).astype(dtype)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), dev)
func(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c)
# Reference result: A times B transposed, matching the definition of C.
tvm.testing.assert_allclose(c.numpy(), np.dot(a, b.T), rtol=1e-3)