# Relax ONNX 前端

In [1]:
import numpy as np
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

import onnx
from onnx import helper, TensorProto, ModelProto, mapping
import onnxruntime

# Fixed-seed Mersenne Twister so the random test inputs are reproducible run to run.
bg = np.random.MT19937(seed=0)
rg = np.random.Generator(bit_generator=bg)
# Optional debugging aid: re-enable to check IR well-formedness after every pass.
# from tvm.relax.ir.instrument import WellFormedInstrument
# tvm.transform.PassContext.current().override_instruments([WellFormedInstrument()])

生成模拟（随机）输入：

In [2]:
def generate_random_inputs(
    model: ModelProto,
    inputs: dict[str, np.ndarray] | None = None,
    dynamic_dim_size: int = 1,
) -> dict[str, np.ndarray]:
    """Produce a value for every graph input of ``model``.

    Values given in ``inputs`` (and not None) are used unchanged. Every other graph
    input gets a random tensor drawn from the module-level generator ``rg``.

    Parameters
    ----------
    model: ONNX model whose ``graph.input`` entries are filled in.
    inputs: optional user-supplied values keyed by input name.
    dynamic_dim_size: extent used for any dimension without a concrete
        ``dim_value`` (a symbolic ``dim_param`` or an unknown dim).

    Returns
    -------
    A dict mapping each graph input name to its array.
    """
    input_values = {}
    for value_info in model.graph.input:
        name = value_info.name
        if inputs is not None and inputs.get(name) is not None:
            input_values[name] = inputs[name]
            continue

        tensor_type = value_info.type.tensor_type
        # Symbolic/unknown dims report dim_value == 0. Using that verbatim would give a
        # zero-sized tensor, and the ORT-vs-TVM comparison would pass vacuously.
        shape = [
            dim.dim_value if dim.HasField("dim_value") else dynamic_dim_size
            for dim in tensor_type.shape.dim
        ]

        # Use the declared element type; fall back to float32 when it is unset.
        if tensor_type.elem_type:
            dtype = helper.tensor_dtype_to_np_dtype(tensor_type.elem_type)
        else:
            dtype = "float32"

        if dtype == "bool":
            random_value = rg.choice(a=[False, True], size=shape)
        else:
            random_value = rg.standard_normal(size=shape).astype(dtype)
        input_values[name] = random_value
    return input_values

检查 ONNX Runtime 与 TVM 输出的一致性：

In [3]:
def _tvm_output_to_numpy(value) -> np.ndarray:
    """Convert a single TVM VM output to a NumPy array for comparison.

    A ``ShapeTuple`` has no ``.numpy()`` method, so it becomes a 1-D int64 array of
    its dimensions. Any other value is assumed to be an ``NDArray``.
    """
    if isinstance(value, tvm.runtime.ShapeTuple):
        return np.array([int(dim) for dim in value], dtype="int64")
    return value.numpy()


def check_correctness(
    model: ModelProto, inputs: dict[str, np.ndarray] | None = None, opset: int | None = None
) -> None:
    """Run an ONNX model on onnxruntime and on TVM (via the importer) and assert the
    results match; an exception is raised otherwise.

    Parameters
    ----------
    model: the ONNX model under test. When ``opset`` is given,
        ``model.opset_import[0].version`` is overwritten in place.
    inputs: optional dict with values for some or all of the model's inputs;
        missing ones are generated randomly.
    opset: opset version used by the ONNX importer.
    """
    if opset is not None:
        model.opset_import[0].version = opset

    # Fill in any inputs the caller did not supply with random values.
    inputs = generate_random_inputs(model, inputs)

    # Reference results from onnxruntime.
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)

    # Import into Relax, decompose training-only ops, and legalize to TensorIR.
    tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)

    # Separate the weights from the module, then build and load it in the VM.
    tvm_model, params = relax.frontend.detach_params(tvm_model)
    with tvm.transform.PassContext(opt_level=3):
        ex = relax.build(tvm_model, target="llvm")
        vm = relax.VirtualMachine(ex, tvm.cpu())

    # Graph inputs in the order of main's parameters, followed by the detached weights.
    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")
    # Wrap a single output so it is handled like a tuple of outputs. A ShapeTuple is
    # itself iterable, so it must be wrapped explicitly rather than iterated.
    if isinstance(tvm_output, (tvm.nd.NDArray, tvm.runtime.ShapeTuple)):
        tvm_output = [tvm_output]

    assert len(tvm_output) == len(ort_output), "Unequal number of outputs"

    for tvm_out, ort_out in zip(tvm_output, ort_output):
        # TODO: allow configurable tolerances.
        # None is sometimes used to represent an unused output.
        if ort_out is not None:
            np.testing.assert_allclose(_tvm_output_to_numpy(tvm_out), ort_out, atol=1e-5)

## 算子测试

### 输入名称清理（`sanitize`）

In [4]:
# Each entry: (ONNX input names, parameter names the importer is expected to produce
# after sanitizing them into valid, unique identifiers).
workloads = [
    ([".", "123"], ["_", "input_123"]),
    ([".", "_"], ["_", "__1"]),
    (["123", "input_123"], ["input_123", "input_123_1"]),
]
for input_names, expected_names in workloads:
    node = helper.make_node("Add", inputs=input_names, outputs=["output"])
    graph = helper.make_graph(
        [node],
        "test",
        inputs=[
            helper.make_tensor_value_info(str(var), TensorProto.FLOAT, [32, 32])
            for var in input_names
        ],
        outputs=[
            helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 32]),
        ],
    )
    model = helper.make_model(graph, producer_name="test_sanitizer")

    tvm_model = from_onnx(model)

    # Compare the whole list so that missing or extra parameters are caught too
    # (a per-index check silently passes when fewer parameters come back).
    actual_names = [param.name_hint for param in tvm_model["main"].params]
    assert actual_names == expected_names, (
        f"inputs {input_names}: expected {expected_names}, got {actual_names}"
    )



### 辅助函数

In [5]:
def _verify_single_node_model(op_name, inputs, output, graph_name, attrs=None, domain=None):
    """Build a one-node ONNX model and check TVM against onnxruntime on it.

    Parameters
    ----------
    op_name: ONNX operator type of the single node.
    inputs: ``(name, elem_type, shape)`` triples, one per graph input, in node-input order.
    output: ``(name, elem_type, shape)`` triple describing the sole graph output.
    graph_name: used as both the graph name and the model's producer name.
    attrs: optional node attributes forwarded to ``helper.make_node``.
    domain: optional operator domain.
    """
    test_node = helper.make_node(
        op_name,
        [name for name, _, _ in inputs],
        [output[0]],
        **(attrs or {}),
        domain=domain,
    )
    graph = helper.make_graph(
        [test_node],
        graph_name,
        inputs=[helper.make_tensor_value_info(*spec) for spec in inputs],
        outputs=[helper.make_tensor_value_info(*output)],
    )
    model = helper.make_model(graph, producer_name=graph_name)
    check_correctness(model)


def verify_unary(op_name, shape, attrs: dict | None = None, domain=None):
    """Check float32 ``y = op(x)`` where input and output both have ``shape``."""
    _verify_single_node_model(
        op_name,
        [("x", TensorProto.FLOAT, shape)],
        ("y", TensorProto.FLOAT, shape),
        "elemwise_test",
        attrs,
        domain,
    )


def verify_binary(op_name, shape_a, shape_b, shape_c, attrs: dict | None = None, domain=None):
    """Check float32 ``c = op(a, b)`` with the given input/output shapes."""
    _verify_single_node_model(
        op_name,
        [("a", TensorProto.FLOAT, shape_a), ("b", TensorProto.FLOAT, shape_b)],
        ("c", TensorProto.FLOAT, shape_c),
        "binary_test",
        attrs,
        domain,
    )


def verify_compare(op_name, shape, attrs: dict | None = None, domain=None):
    """Check a comparison ``c = op(a, b)``: float32 inputs, bool output, all of ``shape``."""
    _verify_single_node_model(
        op_name,
        [("a", TensorProto.FLOAT, shape), ("b", TensorProto.FLOAT, shape)],
        ("c", TensorProto.BOOL, shape),
        "compare_test",
        attrs,
        domain,
    )


def verify_ternary(
    op_name, shape_a, shape_b, shape_c, shape_d, attrs: dict | None = None, domain=None
):
    """Check float32 ``d = op(a, b, c)`` with the given input/output shapes."""
    _verify_single_node_model(
        op_name,
        [
            ("a", TensorProto.FLOAT, shape_a),
            ("b", TensorProto.FLOAT, shape_b),
            ("c", TensorProto.FLOAT, shape_c),
        ],
        ("d", TensorProto.FLOAT, shape_d),
        "ternary_test",
        attrs,
        domain,
    )