# 拓展 TVM

```{topic} 导航
本示例展示了外部库如何通过 C++ 和 Python API 扩展 TVM。

- 该库扩展了 TVM 的功能。
- Python 模块加载新的共享库，并可以通过 TVM 的 Python API 与之进行互操作。
```

::::::{dropdown} 代码
:::::{tab-set-code}

::::{literalinclude} extension/src/tvm_ext.cc
:language: C++
::::

::::{literalinclude} extension/tvm_ext/__init__.py
:language: python
::::

:::::
::::::

编译：

```bash
make outputs/libs/libtvm_ext.so
```

In [1]:
import numpy as np
import torch
from torch import nn
import tvm
from tvm.ir.module import IRModule
from tvm import te, topi, relay
from extension import tvm_ext

## `tvm_ext.bind_add`

In [2]:
def add(a, b):
    """Return the sum ``a + b``; this is the Python callable handed to ``bind_add``."""
    total = a + b
    return total

# ``bind_add`` receives ``add`` together with the value 1; invoking the
# returned function with 2 is expected to produce 3.
f = tvm_ext.bind_add(add, 1)
result = f(2)
assert result == 3

## `tvm.ext_dev`

In [14]:
n = 10
A = te.placeholder((n,), name="A")
# Element-wise "add one" over a 1-D tensor of length n.
B = te.compute((n,), lambda *idx: A(*idx) + 1.0, name="B")
s = te.create_schedule(B.op)

def check_llvm():
    """Build the add-one kernel for ``ext_dev`` and verify its output.

    The target is ``Target("ext_dev", "llvm")``; the second argument is
    presumably the host target (TODO confirm against tvm.target.Target docs).
    """
    target = tvm.target.Target("ext_dev", "llvm")
    func = tvm.build(s, [A, B], target)
    dev = tvm.ext_dev(0)
    host_input = np.random.uniform(size=n).astype(A.dtype)
    a = tvm.nd.array(host_input, dev)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)
    # Launch the kernel; the result is read back from ``b`` below.
    func(a, b)
    np.testing.assert_allclose(b.numpy(), a.numpy() + 1)

check_llvm()

## `tvm_ext.sym_add`

In [6]:
a = te.var("a")
b = te.var("b")
# sym_add should return a symbolic expression whose operands are the
# very variables passed in.
c = tvm_ext.sym_add(a, b)
assert c.a == a
assert c.b == b

## `tvm_ext.ivec_create`

In [7]:
ivec = tvm_ext.ivec_create(1, 2, 3)
# The returned object should be wrapped as the Python-side IntVec class.
assert isinstance(ivec, tvm_ext.IntVec)
for pos, want in ((0, 1), (1, 2)):
    assert ivec[pos] == want

def ivec_cb(v2):
    """Check that an IntVec passed back into Python keeps its type and contents."""
    assert isinstance(v2, tvm_ext.IntVec)
    assert v2[2] == 3

# Convert the Python callback via tvm.runtime.convert (presumably yielding a
# TVM function object) and call it with ivec.
tvm.runtime.convert(ivec_cb)(ivec)

## `extract_ext_funcs`

In [8]:
# Gather the functions the extension exposes through its TVMExtDeclare
# entry point into a name -> callable mapping, then call "mul".
declare = tvm_ext._LIB.TVMExtDeclare
fdict = tvm._ffi.registry.extract_ext_funcs(declare)
mul = fdict["mul"]
assert mul(3, 4) == 12

## `extern_call`

In [16]:
n = 10
A = te.placeholder((n,), name="A")
# Each output element comes from calling the external symbol
# "TVMTestAddOne" on the matching input element (the symbol is presumably
# provided by the loaded extension library — TODO confirm).
B = te.compute(
    (n,),
    lambda *indices: tvm.tir.call_extern("float32", "TVMTestAddOne", A(*indices)),
    name="B",
)
s = te.create_schedule(B.op)

def check_llvm():
    """Build the extern-call kernel for LLVM on the CPU and check it adds one."""
    func = tvm.build(s, [A, B], "llvm")
    dev = tvm.cpu(0)
    host_input = np.random.uniform(size=n).astype(A.dtype)
    a = tvm.nd.array(host_input, dev)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)
    # Launch the kernel; the result is read back from ``b`` below.
    func(a, b)
    np.testing.assert_allclose(b.numpy(), a.numpy() + 1)

check_llvm()

## `tvm_ext.NDSubClass`

In [17]:
a = tvm_ext.NDSubClass.create(additional_info=3)
b = tvm_ext.NDSubClass.create(additional_info=5)
assert isinstance(a, tvm_ext.NDSubClass)
# Adding two instances is expected to add their additional_info values.
c = a + b
d = a + a
e = b + b
for obj, want in ((a, 3), (b, 5), (c, 8), (d, 6), (e, 10)):
    assert obj.additional_info == want