In [None]:
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np

In [None]:
# Declare the module.
# TODO: try working with each of the following workloads.
dim = 16_777_216  # vector length used by every buffer below
# dim = 10_000_000

# Element-wise vector addition written in TVMScript: C[i] = A[i] + B[i]
# for three 1-D float32 buffers of length `dim`.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(A: T.Buffer[dim, "float32"], 
             B: T.Buffer[dim, "float32"], 
             C: T.Buffer[dim, "float32"]):
        # extra annotations for the function
        # ("global_symbol" names the function "main"; "tir.noalias" is set to True)
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Single loop over every element; the scheduling helpers in this
        # notebook look the computation up through the block name "C".
        for i in range(dim):
            with T.block("C"):
                # declare a data parallel iterator on spatial domain
                vi = T.axis.spatial(dim, i)
                C[vi] = A[vi] + B[vi]
# Compilation target and the device on which the benchmarks run.
target = "llvm"
dev = tvm.device(target, 0)

# Operands for the kernel: `a` holds 0..dim-1, `b` is all ones, and `c` is an
# uninitialised output buffer of the same length.
a = tvm.nd.array(np.arange(dim, dtype="float32"))
b = tvm.nd.array(np.ones((dim,), dtype="float32"))
c = tvm.nd.empty(shape=(dim,), dtype="float32")

In [None]:
def split_sch(mod, factors=4):
    """Create a schedule for ``mod`` and split the loop around block "C".

    Parameters
    ----------
    mod
        Module handed to ``tvm.tir.Schedule``; it must contain a block named
        "C" that is surrounded by exactly one loop.
    factors : int or iterable of int
        Inner split factor(s). A single int ``f`` yields the factor list
        ``[None, f]``; an iterable ``(f1, f2, ...)`` yields
        ``[None, f1, f2, ...]``. The leading ``None`` leaves the outermost
        extent unspecified (presumably inferred by TVM -- see the
        ``Schedule.split`` documentation).

    Returns
    -------
    tuple
        ``(sch, loops)`` where ``sch`` is the new schedule and ``loops`` is
        whatever ``sch.split`` returns (one loop per entry of the factor list).

    Raises
    ------
    TypeError
        If ``factors`` is a string or is neither an int nor an iterable.
    """
    if isinstance(factors, int):
        inner_factors = [factors]
    elif isinstance(factors, (str, bytes)):
        # Strings are iterable but never a meaningful list of factors.
        raise TypeError(
            "factors must be an int or an iterable of ints, got %r" % (factors,)
        )
    else:
        # Accept lists, tuples and any other iterable; non-iterables raise
        # TypeError here, as the original list concatenation did.
        inner_factors = list(factors)

    split_factors = [None, *inner_factors]

    sch = tvm.tir.Schedule(mod)
    # Get block by its name
    block_c = sch.get_block("C")
    # Get loops surrounding the block
    (i,) = sch.get_loops(block_c)
    return sch, sch.split(i, factors=split_factors)

In [None]:
def split_and_benchmark(mod, factors=4, number=1000):
    """Split the loop in two, parallelize the outer loop, build and time it.

    NOTE: a later cell defines another ``split_and_benchmark`` (three-level
    split); once that cell has run, this definition is shadowed.

    Parameters
    ----------
    mod
        Module forwarded to ``split_sch``.
    factors : int or list of int
        Forwarded to ``split_sch``; must produce exactly two loops.
    number : int
        ``number`` argument passed to ``time_evaluator`` (defaults to the
        previously hard-coded 1000).

    Prints the mean time reported by the evaluator, scaled by 1000 and
    labelled as milliseconds. Reads the notebook globals ``target``, ``a``,
    ``b`` and ``c``.
    """
    sch, (i_0, i_1) = split_sch(mod, factors)
    # TODO: try different factors
    # TODO: try parallelize and not parallelize
    sch.parallel(i_0)
    # sch.vectorize(i_1)
    # Build for the same notebook-level `target` that the device is created
    # from, so the two cannot drift apart.
    # rt_mod = tvm.build(sch.mod, target="llvm -mcpu=skylake-avx512")
    rt_mod = tvm.build(sch.mod, target=target)
    device = tvm.device(target, 0)
    evaluator = rt_mod.time_evaluator(rt_mod.entry_name, device, number=number)
    print("Time is: %f ms" % (evaluator(a, b, c).mean * 1000))

In [None]:
def split_and_benchmark(mod, factors=[4, 4], number=1000):
    """Split the loop in three, parallelize outer, vectorize inner, and time it.

    NOTE: this re-defines ``split_and_benchmark``; the two-level variant from
    the earlier cell is shadowed once this cell has run.

    Parameters
    ----------
    mod
        Module forwarded to ``split_sch``.
    factors : list of int
        Forwarded to ``split_sch``; must produce exactly three loops. The
        default list is only read, never mutated, by this function.
    number : int
        ``number`` argument passed to ``time_evaluator`` (defaults to the
        previously hard-coded 1000).

    Prints the mean time reported by the evaluator, scaled by 1000 and
    labelled as milliseconds. Reads the notebook globals ``target``, ``a``,
    ``b`` and ``c``.
    """
    sch, (i_0, i_1, i_2) = split_sch(mod, factors)
    # TODO: try different reorder here
    # sch.reorder(i_0, i_2, i_1)
    sch.parallel(i_0)
    sch.vectorize(i_2)
    # print(sch.mod.script())
    # Build for the same notebook-level `target` that the device is created
    # from, so the two cannot drift apart.
    # rt_mod = tvm.build(sch.mod, target="llvm -mcpu=skylake-avx512")
    rt_mod = tvm.build(sch.mod, target=target)
    device = tvm.device(target, 0)
    evaluator = rt_mod.time_evaluator(rt_mod.entry_name, device, number=number)
    print("Time is: %f ms" % (evaluator(a, b, c).mean * 1000))