<a href="https://colab.research.google.com/github/seongheechoi/education/blob/main/%EC%8B%A4%EC%8A%B5_3_5_auto_scheduling_sparse_matrix_multiplication_on_CPU_with_custom_sketch_rule.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **TVM 실습자료 3.5: Auto-scheduling Sparse Matrix Multiplication on CPU with Custom Sketch Rule**

In [None]:
# Pin numpy to 1.26.4 (presumably because the apache-tvm wheel installed below
# does not work with numpy 2.x -- TODO confirm).
# NOTE: if numpy was already imported in this runtime, the `import` below reuses
# the module that is already loaded, so print() can still report the OLD version
# (the recorded output shows 2.0.2 while `pip list` shows 1.26.4).
# Restart the runtime after this cell before running the rest of the notebook.
!pip install numpy==1.26.4
import numpy as np
print(np.__version__)
# Ask pip which numpy version is now installed on disk.
!pip list | grep numpy

Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m77.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
opencv-python-headless 4.12.0.88 requires numpy<

2.0.2
numpy                                 1.26.4


In [None]:
# Linux/MacOS CPU build only!
# See tlcpack.ai for other pre-built binaries including CUDA
!python -m pip install --upgrade pip
!pip install apache-tvm

Collecting pip
  Downloading pip-25.1.1-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-25.1.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.1.1
Collecting apache-tvm
  Downloading apache_tvm-0.14.dev273-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.3 kB)
Downloading apache_tvm-0.14.dev273-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (69.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m56.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: apache-tvm
Successfully installed apache-tvm-0.14.dev273


In [None]:
import os

import numpy as np
import tvm
import tvm.testing
from tvm import te, auto_scheduler, runtime, topi
from tvm.auto_scheduler import _ffi_api
from tvm.topi.utils import get_const_tuple
from tvm.topi.sparse.utils import random_bsr_matrix

**Define the computation**

In [None]:
@auto_scheduler.register_workload
def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
    """Fused workload: relu(sparse_dense(relu(X), W) + B).

    W is described in BSR form by three placeholders (data, indices, indptr);
    the two index arrays are always int32, everything else uses ``dtype``.

    Returns the tensors in the order [X, W_data, W_indices, W_indptr, B, out];
    the built kernel takes its arguments in this same order.
    """
    dense_in = te.placeholder(shape=(M, K), dtype=dtype)
    bsr_data = te.placeholder(shape=w_data_shape, dtype=dtype)
    bsr_indices = te.placeholder(shape=w_indices_shape, dtype="int32")
    bsr_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32")
    bias = te.placeholder(shape=(M, N), dtype=dtype)

    # Sparse matmul on the ReLU-activated dense input.
    activated_in = topi.nn.relu(dense_in)
    product = topi.nn.sparse_dense(activated_in, bsr_data, bsr_indices, bsr_indptr)

    # Elementwise bias add (iterator names i, j kept so the IR is unchanged),
    # followed by a final ReLU.
    biased = te.compute((M, N), lambda i, j: product[i, j] + bias[i, j], name="BiasAdd")
    result = topi.nn.relu(biased)

    return [dense_in, bsr_data, bsr_indices, bsr_indptr, bias, result]

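**Generate the test data** (the special step for sparse workloads — registering the sparse weights as task inputs — happens when the search task is created below)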

In [None]:
# Define the basic shapes of this sparse computation
M = 128
K = 256
N = 512
BS_R = 16  # BSR block size along the rows of W (the N axis)
BS_C = 1  # BSR block size along the columns of W (the K axis)
density = 0.6
SEED = 0  # fixed seed so the generated data is reproducible on Restart & Run All

# Seed numpy's global RNG before drawing anything random. This fixes X and B;
# random_bsr_matrix is assumed to draw its sparsity pattern and values from the
# same global RNG (TODO confirm for the installed TVM version).
np.random.seed(SEED)

# Generate the test data with numpy. X is ReLU-activated up front; the workload
# applies relu(X) again, which is a no-op on non-negative data.
X_np = np.maximum(np.random.randn(M, K).astype("float32"), 0)  # Relu
W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
# toarray() returns a plain ndarray; todense() would return the discouraged
# np.matrix type, which would then leak into Y_np.
W_np = W_sp_np.toarray()
Y_np = X_np @ W_np.T  # Dense reference of the sparse matmul: (M, K) @ (K, N) -> (M, N)
B_np = np.random.randn(M, N).astype("float32")
Y_np = np.maximum(Y_np + B_np, 0)  # Bias add, then Relu

**Create the search task**

In [None]:
target = tvm.target.Target("llvm -mcpu=core-avx2")

# The real sparse weights are handed to the search task as task inputs.
# Their keys share the prefix
#   "sparse_dense_bsr_<N>_<K>_<BS_R>_<BS_C>_<len(indices)>_<len(indptr)>_"
# which is presumably the naming TVM's sparse task-input lookup expects
# (verify against the installed TVM version).
prefix = "sparse_dense_bsr_" + "".join(
    "%d_" % dim
    for dim in (N, K, BS_R, BS_C, W_sp_np.indices.shape[0], W_sp_np.indptr.shape[0])
)
task = tvm.auto_scheduler.SearchTask(
    func=sparse_dense,
    args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"),
    target=target,
    task_inputs={
        prefix + key: runtime.ndarray.array(host_array)
        for key, host_array in (
            ("W_data", W_sp_np.data),
            ("W_indices", W_sp_np.indices),
            ("W_indptr", W_sp_np.indptr),
        )
    },
    task_inputs_save_to_file=True,
)

# Show the computational graph the auto-scheduler will work on.
print("Computational DAG:", task.compute_dag, sep="\n")

Computational DAG:
placeholder = PLACEHOLDER [33]
placeholder = PLACEHOLDER [4916, 16, 1]
placeholder = PLACEHOLDER [4916]
placeholder = PLACEHOLDER [128, 256]
compute(i0, i1) = max(placeholder[i0, i1], 0f)
compute(i, nb_j, j) += (placeholder[(placeholder[nb_j] + elem_idx), j, c]*compute[i, (placeholder[(placeholder[nb_j] + elem_idx)] + c)])
compute(m, n) = compute[m, floordiv(n, 16), floormod(n, 16)]
placeholder = PLACEHOLDER [128, 512]
BiasAdd(i, j) = (compute[i, j] + placeholder[i, j])
compute(i0, i1) = max(BiasAdd[i0, i1], 0f)



**Write a custom sketch rule for the sparse dense op**

In [None]:
def meet_condition_func(search_policy, state, stage_id):
    """Decide whether the custom sparse-dense rule handles stage ``stage_id``.

    The two stages produced by the BSR sparse_dense op (the output stage and
    its ``_block`` stage) get APPLY_AND_SKIP_REST; every other stage gets PASS.
    Going by the enum names, PASS leaves the stage to the other sketch rules.
    """
    handled_tags = ("sparse_dense_sp_rhs_bsrmm", "sparse_dense_sp_rhs_bsrmm_block")
    wrapped = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
    stage_tag = wrapped.stages[stage_id].op.tag

    if stage_tag not in handled_tags:
        return auto_scheduler.PreloadCustomSketchRule.PASS
    return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST


def apply_func(search_policy, state, stage_id):
    """Apply the custom sketch rule to a BSR sparse_dense stage.

    Called for stages accepted by ``meet_condition_func``. For the
    ``sparse_dense_sp_rhs_bsrmm`` output stage it does the following:
    - tiles the ``..._block`` compute stage;
    - inlines the output stage into its single elementwise consumer (if any);
    - attaches the block computation inside that consumer's outer tile loop.

    Args:
        search_policy: The sketch policy running the search; its
            ``search_task`` provides the compute DAG and consumer analysis.
        state: The raw state object of the sketch being built.
        stage_id: Index of the stage to process.

    Returns:
        A list of ``[state_object, next_stage_id]`` pairs. Each pair is a new
        sketch state plus the stage index where sketch generation continues.
        Both return paths use this list-of-pairs shape.
    """
    s0 = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)

    if s0.stages[stage_id].op.tag == "sparse_dense_sp_rhs_bsrmm_block":
        # The block stage is scheduled together with its output stage, so leave
        # it unchanged and continue with the previous stage. Return a *list* of
        # pairs like the main path below (previously this was a bare pair,
        # which does not match the list-of-pairs return contract).
        return [[s0.state_object, stage_id - 1]]

    sparse_dense = s0.stages[stage_id].op
    sparse_dense_block = s0.stages[stage_id - 1].op
    # The block stage is expected to sit immediately before the output stage.
    assert sparse_dense.tag == "sparse_dense_sp_rhs_bsrmm"
    assert sparse_dense_block.tag == "sparse_dense_sp_rhs_bsrmm_block"

    # Set the default consumer of compute block
    consumer = sparse_dense

    # If sparse dense has a single elementwise consumer
    # We can compute inline the sparse_dense output stage
    consumers = _ffi_api.SearchPolicyUtilsGetConsumers(
        search_policy.search_task, s0.state_object, stage_id
    )
    if len(consumers) == 1:
        consumer_id = int(consumers.items()[0][0])
        if _ffi_api.SearchPolicyUtilsIsElementwiseMatch(
            search_policy.search_task, s0.state_object, stage_id, consumer_id
        ):
            consumer = s0.stages[consumer_id].op
            s0.compute_inline(sparse_dense)

    # Block-stage iterators are as follows:
    # - i: output row;
    # - nb_j: block index along the output columns;
    # - j: position inside that block;
    # - row_offset: reduction over the stored non-zero blocks;
    # - c: reduction inside a block.
    i, nb_j, j, row_offset, c = s0[sparse_dense_block].iters
    m, n = s0[consumer].iters
    # Split the row axis into 3 levels (sizes left as None, presumably chosen
    # by the search) and make the consumer's row axis follow that split.
    i0, i1, i2 = s0.split(sparse_dense_block, i, [None, None])
    m0, m1 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 1)
    # Same for the block-column axis (2 levels) and the consumer's column axis.
    j0, j1 = s0.split(sparse_dense_block, nb_j, [None])
    n0, n1 = s0.follow_split(consumer, n, len(s0.transform_steps) - 1, 1)
    # Outer tiles first, then the non-zero-block reduction, then the innermost
    # row tile and the in-block axes.
    s0.reorder(sparse_dense_block, [i0, j0, i1, j1, row_offset, i2, j, c])
    s0.reorder(consumer, [m0, n0, m1, n1])
    # Compute each block tile inside the consumer's (m0, n0) tile loop.
    s0.compute_at(sparse_dense_block, consumer, n0)

    # Skip stage_id - 1 (the block stage): it has just been scheduled above.
    return [[s0.state_object, stage_id - 2]]

In [None]:
log_file = "sparse_dense.json"

# A deliberately tiny search budget (10 measurement trials) keeps the demo
# short; raise num_measure_trials for real tuning. Measured records are
# written to log_file through the RecordToFile callback.
tune_option = auto_scheduler.TuningOptions(
    verbose=2,
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)

# Sketch search guided by an XGBoost cost model, with the custom sparse-dense
# sketch rule preloaded ahead of the search.
cost_model = auto_scheduler.XGBModel()
sparse_dense_rule = auto_scheduler.PreloadCustomSketchRule(
    meet_condition_func, apply_func, "SparseDense"
)
search_policy = auto_scheduler.SketchPolicy(
    task,
    program_cost_model=cost_model,
    init_search_callbacks=[sparse_dense_rule],
)

**Run the search**

In [None]:
def tune_and_evaluate(tune_option, search_policy):
    """Tune ``task``, apply the best schedule, then check and time the kernel.

    Relies on notebook globals: ``task``, ``log_file``, ``target`` and the
    numpy reference data ``X_np``, ``W_sp_np``, ``B_np``, ``Y_np``.
    Prints the lowered TIR and the median execution time in milliseconds.
    """
    # Search; measured records go to log_file via tune_option's callbacks.
    task.tune(tune_option, search_policy)

    # Re-load the best schedule recorded in the log and show the lowered TIR.
    sch, args = task.apply_best(log_file)
    print("Lowered TIR:")
    print(tvm.lower(sch, args, simple_mode=True))

    # Build for the CPU target and upload inputs plus an empty output buffer.
    func = tvm.build(sch, args, target)
    dev = tvm.cpu()
    kernel_args = [
        tvm.nd.array(host_array, device=dev)
        for host_array in (X_np, W_sp_np.data, W_sp_np.indices, W_sp_np.indptr, B_np)
    ]
    kernel_args.append(tvm.nd.empty(Y_np.shape, device=dev))
    output = kernel_args[-1]

    # Correctness: compare against the dense numpy reference.
    func(*kernel_args)
    tvm.testing.assert_allclose(Y_np, output.numpy(), atol=1e-4, rtol=1e-4)

    # Timing: time_evaluator with min_repeat_ms=500; report the median in ms.
    timer = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
    median_ms = np.median(timer(*kernel_args).results) * 1000
    print("Execution time of this operator: %.3f ms" % median_ms)


# Notice: tuning takes a long time. The call below is active in this notebook;
# comment it out to skip the search and the evaluation that follows it.
tune_and_evaluate(tune_option, search_policy)




Lowered TIR:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(placeholder: T.Buffer((128, 256), "float32"), placeholder_1: T.Buffer((4916, 16, 1), "float32"), placeholder_2: T.Buffer((4916,), "int32"), placeholder_3: T.Buffer((33,), "int32"), placeholder_4: T.Buffer((128, 512), "float32"), compute: T.Buffer((128, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for i0_outer in T.parallel(8):
            compute_1 = T.allocate([128], "float32", "global")
            for i1_outer in range(64):
                compute_2 = T.Buffer((128,), data=compute_1)
                for i_outer_inner in range(2):
                    cse_var_2: T.int32 = i1_outer // 2
                    cse_var_1: T.int32 = i_outer_inner * 64
                    compute_2[cse_var_1] = T.float32(0)
                    compute_2[cse_var_1 + 1] = T.float32(0)
               