In [None]:
from __future__ import annotations  # must import to defer parsing of annotations
import os
import numpy as np
import tvm
from tvm.relay import Call
from tvm import relax, tir, topi
from tvm.runtime import container
from tvm.relax.testing import nn

import tvm.script
from tvm.script import tir as T, relax as R

In [None]:
# A single BlockBuilder, filled in by the model-construction cell that follows.
builder = relax.BlockBuilder()

# Layer widths: input features -> two hidden layers -> output classes.
input_size, hidden_sizes, output_size = 784, [128, 32], 10

In [None]:
with builder.function(name="main"):
    # Linear -> ReLU -> Linear -> ReLU -> Linear -> LogSoftmax, sized by
    # input_size / hidden_sizes / output_size from the config cell.
    model = nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[1], output_size),
        nn.LogSoftmax(),
    )

    # The batch dimension is a symbolic int64 TIR variable, so `main` is not
    # tied to one fixed batch size.
    n = tir.Var("n", "int64")
    data = nn.Placeholder((n, input_size), name="data")

    # Apply the model to the placeholder, then close `main` with the input
    # followed by every model parameter as its argument list.
    output = model(data)
    params = [data] + model.parameters()
    builder.emit_func_output(output, params=params)

In [None]:
# Finalize everything emitted into `builder` as a module, then dump it as TVMScript text.
mod = builder.get()
print(R.parser.astext(mod))

In [None]:
class Linear(nn.Module):
    """Applies a linear transformation to the input data: :math:`y = xA + b`.

    ``Parameter`` and ``emit_te`` are referenced through the imported ``nn``
    module. This notebook never imports them unqualified, so the bare names
    would raise ``NameError`` as soon as the class is instantiated or called.

    Args:
        in_features: expected size of the input's last dimension (rows of the weight).
        out_features: columns of the weight, i.e. the size of the output's last dimension.
        bias: when true, a ``(out_features,)`` bias parameter is added after the matmul;
            otherwise ``self.bias`` is ``None``.
    """

    def __init__(self, in_features, out_features, bias=True):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter((in_features, out_features), name="linear_weight")
        if bias:
            self.bias = nn.Parameter((out_features,), name="linear_bias")
        else:
            self.bias = None

    def forward(self, input: relax.Expr) -> relax.Var:
        """Emit ``topi.matmul(input, weight)`` (plus ``topi.add`` of the bias, if any).

        Returns:
            The variable produced by the last ``nn.emit_te`` call.
        """
        y = nn.emit_te(topi.matmul, input, self.weight)
        if self.bias is not None:
            y = nn.emit_te(topi.add, y, self.bias)
        return y

In [None]:
def build_mlp(data, weight):
    """Build a module containing one function ``mlp(data, weight) = relu(data @ weight)``.

    The matmul is emitted from ``tvm.contrib.cblas.matmul`` (no transposes) and
    the activation from ``topi.nn.relu``.
    NOTE(review): running the cblas-based kernel presumably requires a TVM build
    with BLAS support — confirm for the target environment.

    Args:
        data: relax.Var used as the left matmul operand and first function parameter.
        weight: relax.Var used as the right matmul operand and second function parameter.

    Returns:
        The module returned by ``BlockBuilder.get()``.
    """
    # Import explicitly instead of relying on another module (e.g. topi) having
    # already loaded ``tvm.contrib.cblas`` as a side effect of its own imports.
    from tvm.contrib import cblas

    bb = relax.BlockBuilder()
    with bb.function("mlp", [data, weight]):
        gv0 = bb.emit_te(cblas.matmul, data, weight, transa=False, transb=False)
        gv1 = bb.emit_te(topi.nn.relu, gv0)
        bb.emit_func_output(gv1)

    return bb.get()


In [None]:
# n and m are symbolic int64 TIR variables, so the input sizes stay dynamic.
n = tir.Var("n", "int64")
m = tir.Var("m", "int64")

# Relax inputs: `data` is (n, m) and `weight` is (m, n), both annotated as
# rank-2 float32 tensors.
data = relax.Var("data", [n, m], relax.DynTensorType(2, "float32"))
weight = relax.Var("weight", [m, n], relax.DynTensorType(2, "float32"))

# Assemble the MLP module and show it as TVMScript text.
mod = build_mlp(data, weight)
print(R.parser.astext(mod))

# Compile for an LLVM target (host also LLVM) and load the executable into a
# Relax VM on the CPU device.
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())

In [None]:
# Run the compiled MLP on the Relax VM with concrete sizes n=16, m=32, and check
# the result against a NumPy reference of relu(data @ weight).
# NumPy-side names are suffixed with `_np` so the relax.Var objects `data` and
# `weight` from the previous cell are not overwritten with NDArrays.
rng = np.random.default_rng(0)  # seeded so the run is reproducible
data_np = rng.random((16, 32), dtype=np.float32)
weight_np = rng.random((32, 16), dtype=np.float32)
res = vm["mlp"](tvm.nd.array(data_np), tvm.nd.array(weight_np))
print(res)
np.testing.assert_allclose(res.asnumpy(), np.maximum(data_np @ weight_np, 0.0), rtol=1e-5)

In [None]:
# override=True keeps this cell re-runnable: TVM refuses to register a global
# function name that already exists unless override is requested.
@tvm.register_func("test.vm.tile", override=True)
def tile_packed(a, b):
    """Packed function called by ``R.call_tir("test.vm.tile", ...)`` in the script below.

    Nothing is returned; the result is written into the caller-provided output ``b``.

    Args:
        a: input NDArray.
        b: output NDArray that receives ``np.tile(a, (1, 2))``. For a 2-D ``a``,
            this is ``a``'s columns repeated twice, i.e. shape ``(rows, 2 * cols)``.
    """
    b[:] = tvm.nd.array(np.tile(a.asnumpy(), (1, 2)))

In [None]:
# TVMScript for a module whose `foo` calls the packed function "test.vm.tile"
# via R.call_tir, declaring an (n, m * 2) float32 output.
src = """@tvm.script.ir_module
class InputModule:
    @R.function
    def foo(x: Tensor[(n, m), "float32"]) -> Tensor:
        with relax.dataflow():
            y = R.call_tir("test.vm.tile", (x), (n, m * 2), dtype="float32")
            relax.output(y)
        return y
"""

# Parse the source above into a module and echo it back as TVMScript text.
print("======================\nOriginal Relax Program\n")
mod = R.parser.from_source(src)
code = R.parser.astext(mod)
print(code)

In [None]:
# PASS0: relax.transform.ToNonDataflow.
# NOTE: `mod` is rebound to the pass result, so re-running this cell on its own
# applies the pass to an already-transformed module.
print("======================\nPASS0: To Non Dataflow\n")
mod = relax.transform.ToNonDataflow()(mod)
print(R.parser.astext(mod))

In [None]:
# PASS1: relax.transform.CallTIRRewrite.
# The banner names the pass that actually runs here (CallTIRRewrite).
# NOTE: `mod` is rebound to the pass result, so re-running this cell on its own
# applies the pass to an already-transformed module.
print("======================")
print("PASS1: CallTIR Rewrite\n")
mod = relax.transform.CallTIRRewrite()(mod)
print(R.parser.astext(mod))

In [None]:
# PASS2: relax.transform.VMMemoryLower.
# NOTE: `mod` is rebound to the pass result, so re-running this cell on its own
# applies the pass to an already-transformed module.
print("======================\nPASS2: Memory Lower\n")
mod = relax.transform.VMMemoryLower()(mod)
print(R.parser.astext(mod))


In [None]:
# PASS3: relax.transform.VMShapeLower.
# NOTE: `mod` is rebound to the pass result, so re-running this cell on its own
# applies the pass to an already-transformed module.
print("======================\nPASS3: Shape Lower\n")
mod = relax.transform.VMShapeLower()(mod)
print(R.parser.astext(mod))

In [None]:
# Build the lowered module for LLVM, run `foo` on the Relax VM, and check the
# output against NumPy's tile of the same input.
print("======================\nBuild & Execute")

target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())

shape = (3, 4)
inp = tvm.nd.array(np.random.random(shape).astype(np.float32))
out = vm["foo"](inp)
print("input: ", inp)
print("output: ", out)
np.testing.assert_allclose(actual=np.tile(inp.asnumpy(), (1, 2)), desired=out.asnumpy())