# `contains_impure_call`

In [1]:
import tvm
from tvm import relax as rx
from tvm.relax.analysis import contains_impure_call
from tvm.script import relax as R

In [2]:
# A function built only from tensor arithmetic (add, multiply, a constant):
# the analysis is expected to find no impure call in it.
@tvm.script.ir_module
class PureTest:
    @R.function
    def pure_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        # Scalar int32 arithmetic only; no side-effecting ops such as R.print.
        y = R.add(x, x)
        z = R.multiply(x, y)
        return R.add(z, R.const(1, "int32"))

# Display the parsed module (IRModule.show) for inspection in the notebook.
PureTest.show()
assert not contains_impure_call(PureTest["pure_func"])

In [6]:
# A function whose body calls R.print: the analysis is expected to report an
# impure call.
@tvm.script.ir_module
class ImpureTest:
    # pure=False declares that this function may contain impure calls;
    # presumably a function left as pure would be rejected for calling
    # R.print — TODO confirm against TVM's well-formedness rules.
    @R.function(pure=False)
    def impure_func() -> R.Object:
        y = R.print(format="I am a message")
        return y

assert contains_impure_call(ImpureTest["impure_func"])

In [7]:
# Defining an impure local function without calling it does not count as an
# impure call of the enclosing function (first assert). The nested function
# itself, analyzed on its own, does contain one (second assert).
@tvm.script.ir_module
class NestedTest:
    @R.function
    def pure_with_impure_nested() -> R.Tensor((), "int32"):
        # Never called: it exists only so the outer function contains an
        # impure *definition* but no impure *call*.
        @R.function(pure=False)
        def impure_inner() -> R.Object:
            y = R.print(format="Another, worse, message")
            return y

        x = R.const(0, dtype="int32")
        return R.add(x, x)

# The outer function performs no impure call...
assert not contains_impure_call(NestedTest["pure_with_impure_nested"])
# ...but the nested function does. This relies on impure_inner being the
# first statement of the body, so it becomes the first binding of the first
# binding block.
assert contains_impure_call(
    NestedTest["pure_with_impure_nested"].body.blocks[0].bindings[0].value
)

In [6]:
# Ignoring a recursive call. This is useful when a transformation removes an
# impure operation and the compiler then needs to check whether the
# (formerly) impure function has become pure. Otherwise, the function's
# recursive call to itself would still be reported as impure.
@tvm.script.ir_module
class RecursiveTest:
    @R.function(pure=False)
    def recursive_impure() -> R.Object:
        x = R.const(1, "int32")
        y = R.add(x, x)
        # The impure operation; a later step drops this binding.
        z = R.print(x, y, format="{} {}")
        # Recursive call to this same (impure) function via its global name.
        w = RecursiveTest.recursive_impure()
        return w

# As written, with the print call included, the function contains an impure call.
assert contains_impure_call(RecursiveTest["recursive_impure"])
# ...now strip the print and see what the analysis reports.
body = RecursiveTest["recursive_impure"].body

# The final binding is the recursive call; the call's `op` is the function's
# own global name, which the analysis can be told to ignore.
own_name = body.blocks[0].bindings[-1].value.op

# Keep bindings 0 (const), 1 (add) and the last one (recursive call);
# index 2, the R.print, is left out.
new_bindings = [body.blocks[0].bindings[i] for i in (0, 1, -1)]

# Assemble the SeqExpr directly rather than re-normalizing, so the reused
# vars keep their existing StructInfo. Normalization would clean that up, but
# the point is to model an intermediate state seen in the middle of a pass.
new_body = rx.SeqExpr(
    [rx.BindingBlock(new_bindings)],
    body.body,
)

# When its own name is ignored, nothing impure remains. Without own_name,
# the recursive call (whose StructInfo says impure) is still flagged.
assert not contains_impure_call(new_body, own_name=own_name)
assert contains_impure_call(new_body)