# Relax 快速上手

In [1]:
import set_env

In [5]:
import contextlib
import tempfile

import numpy as np
import tvm
from tvm import relax
from tvm._ffi.base import TVMError
from tvm.script import relax as R, tir as T, ir as I

## `unique`

In [6]:
# Relax module for the `unique` check below: ``foo`` applies R.unique to the
# same input twice and returns (unsorted result, default-sorted result).
@I.ir_module
class InputModule:
    @R.function
    def foo(x: R.Tensor(("m", "n"), "int64")):
        # x: 2-D int64 tensor with symbolic dims m and n.
        # sorted=False: the check below expects the values in reverse
        # first-occurrence order (see how it builds expected_output).
        y = R.unique(x, sorted=False)
        # `sorted` left at its default; compared against np.unique's sorted output.
        y_sorted = R.unique(x)
        return y, y_sorted

def run_cpu(mod, func_name, *args):
    """Compile ``mod`` for the LLVM (CPU) target and invoke one of its functions.

    Parameters
    ----------
    mod : tvm.IRModule
        Relax module to build.
    func_name : str
        Name of the Relax function in ``mod`` to call.
    *args
        Runtime arguments forwarded unchanged to that function
        (e.g. ``tvm.nd.array`` or ``tvm.runtime.ShapeTuple`` values).

    Returns
    -------
    The value returned by the Relax virtual machine for the call.
    """
    # The module is rebuilt on every call; acceptable for these small test modules.
    executable = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(executable, tvm.cpu())
    return vm[func_name](*args)

In [9]:
# Fixed seed so a failure can be reproduced. Explicit int64 dtype because
# InputModule.foo takes R.Tensor(..., "int64") and NumPy's default integer
# dtype is platform-dependent.
np.random.seed(0)
data_numpy = np.random.randint(0, 16, (16, 16), dtype=np.int64)
data = tvm.nd.array(data_numpy)
result, result_sorted = run_cpu(InputModule, "foo", data)

# np.unique returns the sorted unique values plus, for each, the index of its
# first occurrence in the flattened input.
expected_output_sorted, indices = np.unique(data_numpy, return_index=True)
# Unsorted mode is expected to list values in reverse first-occurrence order.
flat = data_numpy.flatten()  # flattened once, outside the comprehension
expected_output = [flat[index] for index in sorted(indices, reverse=True)]
np.testing.assert_array_equal(expected_output_sorted, result_sorted.numpy())
np.testing.assert_array_equal(expected_output, result.numpy())

## 其他测试

In [None]:


# Relax module whose ``foo`` issues six R.print calls with different
# argument/format combinations; test_print captures and checks the output.
@tvm.script.ir_module
class PrintTest:
    # pure=False: R.print is called for its side effect (printing).
    @R.function(pure=False)
    def foo(x: R.Tensor((), "int32")):
        # results have to be bound, but we don't use them
        # TODO: We should allow calls whose results are not bound for side effects;
        #       it would be easy syntactic sugar to add.
        # Expected output line for x == 1 is noted beside each call (see test_print).
        p1 = R.print(x)  # "1"
        p2 = R.print(x, format="Number: {}")  # "Number: 1"
        t = (x, x)
        p3 = R.print(t, format="Tuple: {}")  # "Tuple: (1, 1)"
        p4 = R.print(x, t)  # no format, args space-separated: "1 (1, 1)"
        p5 = R.print(x, x, format="Custom print: {} {}")  # "Custom print: 1 1"
        p6 = R.print(x, t, format="Another print: {} {}")  # "Another print: 1 (1, 1)"
        return x


def test_print():
    """Check the exact text produced by the R.print calls in ``PrintTest.foo``.

    ``sys.stdout`` is redirected to a temporary file while ``foo`` runs with
    the int32 scalar ``1``; the captured text is then compared to the
    expected output of all six prints.
    """
    with tempfile.TemporaryFile(mode="w+") as test_out:
        # redirect_stdout restores the previous sys.stdout on exit, even if
        # run_cpu raises. Assumes R.print writes through Python's sys.stdout,
        # as the original capture approach also did.
        with contextlib.redirect_stdout(test_out):
            run_cpu(PrintTest, "foo", tvm.nd.array(np.array(1).astype("int32")))
        test_out.seek(0)
        printed_text = test_out.read()
    # One line per R.print call, in program order.
    expected = (
        "1\n"
        "Number: 1\n"
        "Tuple: (1, 1)\n"
        "1 (1, 1)\n"
        "Custom print: 1 1\n"
        "Another print: 1 (1, 1)\n"
    )
    # Exact comparison: a substring test would also accept empty output.
    assert printed_text == expected, ("printed_text is ", printed_text)


# Relax module of R.assert_op cases. ``passes`` and ``pass_with_args`` assert
# a true condition; the ``*fail*`` functions assert False, and
# test_assert_op checks the resulting error text for each of them.
@tvm.script.ir_module
class AssertOpTest:
    # Every function is pure=False: R.assert_op is used for its run-time
    # effect (raising when the condition is false).
    @R.function(pure=False)
    def passes(x: R.Tensor((), "int32")):
        p1 = R.assert_op(relax.const(True))
        return x

    @R.function(pure=False)
    def pass_with_args(x: R.Tensor((), "int32")):
        # Condition holds, so the format string is never rendered.
        p1 = R.assert_op(relax.const(True), x, format="You won't see me")
        return x

    @R.function(pure=False)
    def simple_fail(x: R.Tensor((), "int32")):
        # test_assert_op expects the error text to contain "Assertion Failed".
        p1 = R.assert_op(relax.const(False))
        return x

    @R.function(pure=False)
    def fail_with_message(x: R.Tensor((), "int32")):
        # test_assert_op expects the literal message "I failed..." in the error.
        p1 = R.assert_op(relax.const(False), format="I failed...")
        return x

    @R.function(pure=False)
    def fail_with_args(x: R.Tensor((), "int32")):
        # no format
        # Args passed as a list; with x == 5 the error is expected to contain "5, 5".
        p1 = R.assert_op(relax.const(False), [x, x])
        return x

    @R.function(pure=False)
    def fail_with_formatted_message(x: R.Tensor((), "int32")):
        # Format string filled with x; with x == 6 the error should contain "Number: 6".
        p1 = R.assert_op(relax.const(False), x, format="Number: {}")
        return x


def test_assert_op():
    """Run every ``AssertOpTest`` function and check its pass/fail outcome.

    Functions asserting a true condition must return normally. The failing
    ones must raise ``TVMError`` whose message contains ``"AssertionError"``
    and the expected (possibly formatted) text.
    """

    def scalar(value):
        # int32 scalar NDArray, matching the R.Tensor((), "int32") parameters.
        return tvm.nd.array(np.array(value).astype("int32"))

    def check_assertion_error(func_name, func_arg, expected_message):
        # Require ``func_name`` to fail with a TVMError mentioning expected_message.
        try:
            run_cpu(AssertOpTest, func_name, func_arg)
        except TVMError as e:
            # TVM will print out a TVMError that will contain the
            # generated error at the bottom of a stack trace
            message = e.args[0]
            assert "AssertionError" in message, message
            assert expected_message in message, message
        else:
            # Reached only when the call returned normally.
            raise AssertionError(f"{func_name!r} was expected to raise TVMError")

    run_cpu(AssertOpTest, "passes", scalar(1))
    run_cpu(AssertOpTest, "pass_with_args", scalar(2))
    check_assertion_error("simple_fail", scalar(3), "Assertion Failed")
    check_assertion_error("fail_with_message", scalar(4), "I failed...")
    check_assertion_error("fail_with_args", scalar(5), "5, 5")
    check_assertion_error("fail_with_formatted_message", scalar(6), "Number: 6")


# Relax module of R.shape_of cases used by test_op_shape_of.
@tvm.script.ir_module
class ShapeOfTest:
    @R.function
    def get_shape(t: R.Tensor(ndim=-1, dtype="int32")) -> R.Shape(ndim=-1):
        # Rank not fixed (ndim=-1); the test passes both rank-0 and rank-3 inputs.
        return R.shape_of(t)

    @R.function
    def get_constrained_shape(t: R.Tensor(ndim=1, dtype="int32")) -> R.Shape(ndim=1):
        # require the input tensor to have rank 1
        return R.shape_of(t)

    @R.function
    def get_scalar_shape() -> R.Shape(()):
        # Shape of an int32 scalar constant; the test expects the empty shape.
        x: R.Tensor((), "int32") = R.const(1, dtype="int32")
        return R.shape_of(x)

    @R.function
    def get_constant_shape() -> R.Shape((2, 2)):
        # Shape of a 2x2 int32 constant; the test expects (2, 2).
        x: R.Tensor((2, 2), "int32") = R.const(
            np.array([[1, 2], [3, 4]], dtype="int32"), dtype="int32"
        )
        return R.shape_of(x)


def test_op_shape_of():
    """Check R.shape_of on constant, scalar, any-rank and rank-1 inputs."""
    # Each case: (function name, host input or None for no argument, expected dims).
    cases = [
        ("get_scalar_shape", None, []),
        ("get_constant_shape", None, [2, 2]),
        ("get_shape", np.array(1, dtype="int32"), []),
        ("get_shape", np.zeros((1, 2, 3)).astype("int32"), [1, 2, 3]),
        ("get_constrained_shape", np.zeros((1,)).astype("int32"), [1]),
    ]
    for func_name, host_input, expected_dims in cases:
        call_args = () if host_input is None else (tvm.nd.array(host_input),)
        observed = run_cpu(ShapeOfTest, func_name, *call_args)
        assert observed == tvm.runtime.ShapeTuple(expected_dims)


# Relax module of R.shape_to_tensor cases used by test_op_shape_to_tensor.
@tvm.script.ir_module
class ShapeToTensorTest:
    @R.function
    def const_shape(shape: R.Shape(ndim=-1)) -> R.Tensor(ndim=-1):
        # Accepts a shape of any rank; the test expects its dims back as a tensor.
        return R.shape_to_tensor(shape)

    @R.function
    def symbolic_shape(shape: R.Shape(("m", "n"))) -> R.Tensor(ndim=-1):
        # Input shape has two symbolic dims, m and n.
        # NOTE(review): m and n are declared here but not otherwise used in the
        # body; they may be redundant with the string names in the signature.
        # Confirm before removing.
        m = T.int64()
        n = T.int64()
        return R.shape_to_tensor(shape)


def test_op_shape_to_tensor():
    """Check R.shape_to_tensor: static struct info and run-time values."""
    # Struct info: both functions must produce a rank-1 tensor. The isinstance
    # checks are asserted, so a non-tensor result actually fails the test.
    for func_name in ("const_shape", "symbolic_shape"):
        body_sinfo = ShapeToTensorTest[func_name].body.struct_info
        assert isinstance(body_sinfo, tvm.relax.TensorStructInfo)
        assert body_sinfo.ndim == 1

    # Run time: the output NDArray holds the input shape's dimensions.
    for func_name, dims in (
        ("const_shape", [3, 2]),
        ("const_shape", [3, 3, 2]),
        ("const_shape", [3, 3, 2, 2]),
        ("symbolic_shape", [3, 2]),
    ):
        out = run_cpu(ShapeToTensorTest, func_name, tvm.runtime.ShapeTuple(dims))
        assert isinstance(out, tvm.runtime.ndarray.NDArray)
        assert np.array_equal(out.numpy(), np.array(dims))


def test_op_call_pure_packed():
    """Check that R.call_pure_packed can call ``vm.builtin.copy`` and that
    the returned tensor equals the input."""

    @tvm.script.ir_module
    class CallPureTest:
        @R.function
        def pure_copy(x: R.Tensor((3, 4), "float32")):
            # Packed builtin invoked as a pure op; sinfo_args gives the
            # result's struct info (same shape and dtype as the input).
            z = R.call_pure_packed(
                "vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))
            )
            return z

    np.random.seed(0)  # to avoid flakiness
    host_data = np.random.rand(3, 4).astype("float32")
    nd_input = tvm.nd.array(host_data)
    copied = run_cpu(CallPureTest, "pure_copy", nd_input).numpy()
    assert (copied == host_data).all()
