<a href="https://colab.research.google.com/github/seongheechoi/education/blob/main/%EC%8B%A4%EC%8A%B5_3_2_optimizing_operators_with_auto_scheduling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **TVM 실습자료 3.2: Optimizing Operators with Auto-scheduling**

In [None]:
!pip install numpy==1.26.4
import numpy as np
print(np.__version__)
!pip list | grep numpy

1.26.4
numpy                                 1.26.4


In [None]:
# Linux/MacOS CPU build only!
# See tlcpack.ai for other pre-built binaries including CUDA
!python -m pip install --upgrade pip
!pip install apache-tvm

Collecting pip
  Downloading pip-25.1.1-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-25.1.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m29.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.1.1
Collecting apache-tvm
  Downloading apache_tvm-0.14.dev273-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.3 kB)
Downloading apache_tvm-0.14.dev273-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (69.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m39.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: apache-tvm
Successfully installed apache-tvm-0.14.dev273


In [None]:
import os

import numpy as np
import tvm
from tvm import te, auto_scheduler

**Defining the Matrix Multiplication with Bias Addition**

In [None]:
@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):
    """Describe ``out = A x B + C`` as a tensor-expression compute graph.

    Args:
        N: Number of rows of ``A``, ``C`` and ``out``.
        L: Shared (reduction) dimension: columns of ``A`` and rows of ``B``.
        M: Number of columns of ``B``, ``C`` and ``out``.
        dtype: Element type used for the three input placeholders.

    Returns:
        The list ``[A, B, C, out]``: the three input placeholders followed by
        the output tensor.
    """
    # Inputs: A is (N, L), B is (L, M), and the matrix C to be added is (N, M).
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)

    # Reduction axis running over the shared dimension L.
    k = te.reduce_axis((0, L), name="k")
    # matmul[i, j] = sum over k of A[i, k] * B[k, j]
    matmul = te.compute(
        (N, M),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",
        attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B
    )
    # Element-wise addition of C on top of the matmul result.
    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")

    return [A, B, C, out]

**Create the search task**

In [None]:
# Compile for CPU through the LLVM backend.
target = tvm.target.Target("llvm")

# Square problem: all three matrix dimensions are 1024.
N, L, M = 1024, 1024, 1024
workload_args = (N, L, M, "float32")
task = auto_scheduler.SearchTask(func=matmul_add, args=workload_args, target=target)

# Show the compute DAG that the search task was built from.
print("Computational DAG:", task.compute_dag, sep="\n")

Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])



**Set Parameters for Auto-Scheduler**

In [None]:
# File handed to the RecordToFile callback; later cells read it back by name.
log_file = "matmul.json"

record_to_log = auto_scheduler.RecordToFile(log_file)
tune_option = auto_scheduler.TuningOptions(
    verbose=2,
    num_measure_trials=10,  # kept small for the tutorial
    measure_callbacks=[record_to_log],
)

**Run the search**

In [None]:
# Launch the search with the options configured above.
task.tune(tune_option)

# Read the log back and take the best schedule together with its tensor arguments.
best_record = task.apply_best(log_file)
sch, args = best_record






**Inspecting the Optimized Schedule**

In [None]:
# Lower the tuned schedule and print the result under a header line.
print("Lowered TIR:")
lowered_mod = tvm.lower(sch, args, simple_mode=True)
print(lowered_mod)

Lowered TIR:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32"), out: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        auto_scheduler_layout_transform = T.allocate([1048576], "float32", "global")
        auto_scheduler_layout_transform_1 = T.Buffer((1048576,), data=auto_scheduler_layout_transform)
        for ax0_ax1_fused_ax2_fused in T.parallel(8):
            for ax4, ax5, ax6, ax7 in T.grid(32, 16, 32, 8):
                B_1 = T.Buffer((1048576,), data=B.data)
                auto_scheduler_layout_transform_1[ax0_ax1_fused_ax2_fused * 131072 + ax4 * 4096 + ax5 * 256 + ax6 * 8 + ax7] = B_1[ax4 * 32768 + ax6 * 1024 + ax0_ax1_fused_ax2_fused * 128 + ax5 * 8 + ax7]
        for i_outer_outer_j_outer_outer_fused_i_out

**Check correctness and evaluate performance**

In [None]:
# Build the tuned schedule into a callable for the target chosen earlier.
func = tvm.build(sch, args, target)

# Seeded generator: the check runs on the same inputs every time, so a failing
# comparison can be reproduced instead of depending on one random draw.
rng = np.random.default_rng(0)
a_np = rng.uniform(size=(N, L)).astype(np.float32)
b_np = rng.uniform(size=(L, M)).astype(np.float32)
c_np = rng.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np  # NumPy reference result

# Copy the inputs to the CPU device and allocate the output buffer there.
dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
# The evaluator's results are multiplied by 1000 for display in milliseconds;
# the median is reported.
median_ms = np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000
print("Execution time of this operator: %.3f ms" % median_ms)

Execution time of this operator: 215.292 ms


**Using the record file**

In [None]:
# Print whatever task.print_best returns for the log file, under a header line.
print("Equivalent python schedule:")
best_schedule_code = task.print_best(log_file)
print(best_schedule_code)

Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=8)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=4)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=16)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=16)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=2)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=32)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=32)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=16)
out_j_o_i, out_j_i = s[out].split(out_j, factor=128)
out_j_o_o, out_j_o_i = 