# Matrix Multiplication

In [None]:
import tvm
import tvm.testing
from tvm import te
import numpy as np

# Problem dimensions: an (M, K) matrix is multiplied by a (K, N) matrix,
# producing an (M, N) result.  All three extents are the same value.
M = K = N = 1024

# Element type used when creating the NumPy input arrays and output buffers.
dtype = "float32"

In [None]:
import timeit

# Compile for, and run on, the local CPU through LLVM.
target = tvm.target.Target(target="llvm", host="llvm")
dev = tvm.device(target.kind.name, 0)

# Random input matrices placed on the device; reused by every TVM variant.
a = tvm.runtime.tensor(np.random.rand(M, K).astype(dtype), dev)
b = tvm.runtime.tensor(np.random.rand(K, N).astype(dtype), dev)

# Time NumPy's matmul as a reference point.  The setup code runs in timeit's
# own namespace, so the sizes and dtype are substituted into it textually;
# dtype comes from the notebook-level setting so the baseline always uses the
# same element type as the TVM kernels.  Note that the setup creates its own
# random operands, separate from ``a`` and ``b`` above.
np_repeat = 100
np_setup = (
    "import numpy as np\n"
    f"M = {M}\n"
    f"K = {K}\n"
    f"N = {N}\n"
    f"dtype = {dtype!r}\n"
    "a = np.random.rand(M, K).astype(dtype)\n"
    "b = np.random.rand(K, N).astype(dtype)\n"
)
np_running_time = timeit.timeit(
    setup=np_setup,
    stmt="answer = np.dot(a, b)",
    number=np_repeat,
)
print(f"np running time: {np_running_time / np_repeat:f}")

# Reference product of the actual inputs; later correctness checks compare
# kernel output against this.
answer = np.dot(a.numpy(), b.numpy())

# TVM Matrix Multiplication using TE

In [None]:
from typing import Any

# Reduction axis running over the shared dimension K.
k = te.reduce_axis((0, K), "k")

# Symbolic inputs; the element type follows the notebook-level ``dtype`` so
# the kernel signature matches the arrays that are passed to it.
A = te.placeholder((M, K), name="A", dtype=dtype)
B = te.placeholder((K, N), name="B", dtype=dtype)

# C[x, y] = sum over k of A[x, k] * B[k, y]
C: Any = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")

# Lower the TE expression to a PrimFunc exported under the symbol "mmult",
# wrap it in a module and compile it for the selected target.
prim_func = te.create_prim_func([A, B, C]).with_attr("global_symbol", "mmult")
mod = tvm.IRModule({"mmult": prim_func})
lib = tvm.build(mod, target=target)
func = lib["mmult"]

# Run the unscheduled kernel once and compare it with the NumPy reference.
c = tvm.runtime.tensor(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

In [None]:
def evaluate_operation(lib, optimization, log, func_name="mmult", number=10):
    """Validate one compiled kernel variant, then time it and record the result.

    The kernel is executed once on the notebook-level inputs ``a`` and ``b``
    and its output is compared with the NumPy reference ``answer``.  Only
    after that check passes is the kernel benchmarked.

    Parameters
    ----------
    lib
        Compiled TVM module that exposes ``func_name``.
    optimization : str
        Label printed next to the timing and stored in ``log``.
    log : list
        Receives one ``(optimization, mean_time)`` tuple per call.
    func_name : str, optional
        Symbol of the kernel inside ``lib``.  Defaults to ``"mmult"``.
    number : int, optional
        Forwarded as ``number`` to ``lib.time_evaluator``.  Defaults to 10.

    Raises
    ------
    AssertionError
        If the kernel lookup yields a falsy value or the output does not
        match ``answer`` within ``rtol=1e-5``.
    """
    func = lib[func_name]
    assert func

    # Fresh zeroed output buffer so results of earlier runs cannot leak in.
    c = tvm.runtime.tensor(np.zeros((M, N), dtype=dtype), dev)
    func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    # Benchmark the very symbol that was just validated rather than relying
    # on ``lib.entry_name`` happening to resolve to the same function.
    evaluator = lib.time_evaluator(func_name, dev, number=number)
    mean_time = evaluator(a, b, c).mean
    print(f"{optimization}: {mean_time:f}")
    log.append((optimization, mean_time))

In [None]:
# Collects one (label, mean_time) tuple per evaluated kernel variant.
log = []

# Benchmark the unscheduled kernel; it becomes the first entry in ``log``.
evaluate_operation(lib, "none", log)

In [None]:
# Show the IRModule of the unscheduled kernel.
print(mod)

# Optimization 1: Blocking

In [None]:
# Tile edge length: the x and y loops are each cut into chunks of bn.
bn = 32

# Start a schedule on the current module and fetch the three loops
# surrounding the "C" block.
sch = tvm.tir.Schedule(mod)
block_c = sch.get_block("C", func_name="mmult")
x, y, k = sch.get_loops(block_c)

# Split each loop into an (outer, inner) pair.  The inner extents are bn for
# x and y and 4 for the reduction loop; None leaves the outer extent to TVM.
xo, xi = sch.split(x, factors=[None, bn])
yo, yi = sch.split(y, factors=[None, bn])
ko, ki = sch.split(k, factors=[None, 4])

# Loop order: outer tile loops, then both reduction loops, then the
# inner x and y loops.
sch.reorder(xo, yo, ko, ki, xi, yi)

blocked_mod = sch.mod
blocked_lib = tvm.build(blocked_mod, target=target)

# Execute the blocked kernel once on a zeroed output buffer and compare the
# result with the NumPy reference.
c = tvm.runtime.tensor(np.zeros((M, N), dtype=dtype), dev)
func = blocked_lib["mmult"]
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

In [None]:
# Validate, time and log the blocked kernel under the label "blocking".
evaluate_operation(blocked_lib, "blocking", log)

In [None]:
# Show the IRModule produced by the blocking schedule.
print(blocked_mod)

# Optimization 2: Vectorization

In [None]:
# ``sch`` and ``yi`` are carried over from the blocking cell, so this adds
# vectorization of the innermost loop on top of the existing blocked schedule.
sch.vectorize(yi)

# Snapshot the schedule's module and compile it.
vectorized_mod = sch.mod
vectorized_lib = tvm.build(vectorized_mod, target=target)

In [None]:
# Validate, time and log the vectorized kernel under the label "vectorization".
evaluate_operation(vectorized_lib, "vectorization", log)

In [None]:
# Show the IRModule after vectorization.
print(vectorized_mod)

# Optimization 3: Loop Permutation

In [None]:
# Begin again with a fresh schedule on ``mod`` instead of extending the
# previous one, because the loop order chosen below is different.
sch = tvm.tir.Schedule(mod)
block_c = sch.get_block("C", func_name="mmult")
x, y, k = sch.get_loops(block_c)

# Same splits as before: inner extents of bn for x and y, 4 for the
# reduction loop.
xo, xi = sch.split(x, factors=[None, bn])
yo, yi = sch.split(y, factors=[None, bn])
ko, ki = sch.split(k, factors=[None, 4])

# Permuted order: the inner reduction loop ki now sits between xi and yi,
# and the innermost loop yi is vectorized.
sch.reorder(xo, yo, ko, xi, ki, yi)
sch.vectorize(yi)

permuted_mod = sch.mod
permuted_lib = tvm.build(permuted_mod, target=target)

# Execute the permuted kernel once on a zeroed output buffer and compare the
# result with the NumPy reference.
c = tvm.runtime.tensor(np.zeros((M, N), dtype=dtype), dev)
func = permuted_lib["mmult"]
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

In [None]:
# Validate, time and log the permuted kernel under the label "loop permutation".
evaluate_operation(permuted_lib, "loop permutation", log)

In [None]:
# Show the IRModule produced by the loop-permutation schedule.
print(permuted_mod)

# Optimization 4: Array Packing

In [None]:
# A new reduction axis is needed because ``k`` was rebound to a schedule loop
# by the earlier scheduling cells.
k = te.reduce_axis((0, K), "k")

# Repack B into shape (N // bn, K, bn): element [x, y, z] holds B[y, x * bn + z].
packedB = te.compute(
    (N // bn, K, bn),
    lambda x, y, z: B[y, x * bn + z],
    name="packedB",
)

# Same matrix product as before, but reading B through the packed layout:
# column y lives in panel y // bn at offset y % bn.
C: Any = te.compute(
    (M, N),
    lambda x, y: te.sum(
        A[x, k] * packedB[tvm.tir.floordiv(y, bn), k, tvm.tir.indexmod(y, bn)],
        axis=k,
    ),
    name="C",
)

# NOTE: ``mod`` is rebound here, so any schedule created from ``mod`` after
# this cell starts from the packed formulation.
prim_func = te.create_prim_func([A, B, C]).with_attr("global_symbol", "mmult")
mod = tvm.IRModule({"mmult": prim_func})

sch = tvm.tir.Schedule(mod)

# Schedule the compute block: split, reorder and vectorize the innermost loop.
block_c = sch.get_block("C", func_name="mmult")
x, y, k = sch.get_loops(block_c)
xo, xi = sch.split(x, factors=[None, bn])
yo, yi = sch.split(y, factors=[None, bn])
ko, ki = sch.split(k, factors=[None, 4])

sch.reorder(xo, yo, ko, xi, ki, yi)
sch.vectorize(yi)

# Schedule the packing block: vectorize its innermost loop and run its
# outermost loop in parallel.
block_pack = sch.get_block("packedB", func_name="mmult")
xp, yp, zp = sch.get_loops(block_pack)
sch.vectorize(zp)
sch.parallel(xp)

packing_mod = sch.mod
packing_lib = tvm.build(packing_mod, target=target)

# Execute the packed kernel once on a zeroed output buffer and compare the
# result with the NumPy reference.
c = tvm.runtime.tensor(np.zeros((M, N), dtype=dtype), dev)
func = packing_lib["mmult"]
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

In [None]:
# Validate, time and log the packed kernel under the label "array packing".
evaluate_operation(packing_lib, "array packing", log)

In [None]:
# Show the IRModule produced by the array-packing schedule.
print(packing_mod)

# Optimization 5: Optimizing Block Writing Through Caching

In [None]:
# Fresh schedule on the current ``mod`` (the packed formulation, since the
# array-packing cell rebound ``mod``).  The packedB block is left unscheduled
# in this cell.
sch = tvm.tir.Schedule(mod)

block_c = sch.get_block("C", func_name="mmult")
# The loops are fetched before cache_write so the handles refer to the
# compute block's own loop nest.
x, y, k = sch.get_loops(block_c)

# Insert a write cache in "global" scope for write buffer 0 of block C.
# ``CC`` is presumably the block that copies the cache back into C —
# confirm against the printed module.
CC = sch.cache_write(block_c, 0, "global")

# Same splits and loop order as the loop-permutation schedule.
xo, xi = sch.split(x, factors=[None, bn])
yo, yi = sch.split(y, factors=[None, bn])
ko, ki = sch.split(k, factors=[None, 4])
sch.reorder(xo, yo, ko, xi, ki, yi)

# Last two loops around CC; only yc is used below.
xc, yc = sch.get_loops(CC)[-2:]
sch.unroll(ki)
sch.vectorize(yc)

# NOTE(review): yc is vectorized *before* CC is moved under yo, and yi of the
# compute block is not vectorized in this cell.  Check in the printed module
# that the vectorization survives the move and that this is the intent.
sch.reverse_compute_at(CC, yo)

write_cache_mod = sch.mod
write_cache_lib = tvm.build(write_cache_mod, target=target)

func = write_cache_lib["mmult"]

# Run once on a zeroed output buffer and compare with the NumPy reference.
c = tvm.runtime.tensor(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

In [None]:
# Validate, time and log the write-cache kernel under the label "write cache".
evaluate_operation(write_cache_lib, "write cache", log)

In [None]:
# Show the IRModule produced by the write-cache schedule.
print(write_cache_mod)

# Optimization 6: Parallelization

In [None]:
# ``sch`` is not recreated here, so these primitives extend the schedule left
# behind by the write-cache cell.

# Run the outermost loop around the compute block in parallel.
block_c = sch.get_block("C", func_name="mmult")
xo = sch.get_loops(block_c)[0]
sch.parallel(xo)

# Packing block: outermost loop in parallel, innermost loop vectorized.
block_pack = sch.get_block("packedB", func_name="mmult")
xp, yp, zp = sch.get_loops(block_pack)
sch.parallel(xp)
sch.vectorize(zp)

In [None]:
# Snapshot the parallelized schedule's module and compile it.
parallelization_mod = sch.mod
parallelization_lib = tvm.build(parallelization_mod, target=target)

func = parallelization_lib["mmult"]

# Run once on a zeroed output buffer and compare with the NumPy reference.
c = tvm.runtime.tensor(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

In [None]:
# Validate, time and log the parallelized kernel.  It gets its own label so
# the summary does not show it as a second "write cache" row.
evaluate_operation(parallelization_lib, "parallelization", log)

In [None]:
# Show the IRModule produced by the parallelized schedule.
print(parallelization_mod)

# Summary

In [None]:
# The first logged timing is the reference every row is divided by, so the
# "Performance" column is a time ratio (values below 1 are faster).
baseline = log[0][1]

# Every cell is right-justified to 20 characters; columns are tab-separated.
header = ("Operator", "Timing", "Performance")
print("\t".join(title.rjust(20) for title in header))
for result in log:
    label, mean_time = result
    cells = (label, str(mean_time), str(mean_time / baseline))
    print("\t".join(cell.rjust(20) for cell in cells))