Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Relay->Relax translator (ResNet example) (#75)
Browse files Browse the repository at this point in the history
* Relay translator; build static bert forward/backward.

* rebase

* Add ops.

* resnet demo

* cleanup code.

* Rebase.

* Address comments.

* leverage FTVMCompute for most op translation; reuse relay.Constant.

* lint.
  • Loading branch information
YuchenJin authored and junrushao committed Feb 9, 2023
1 parent b241723 commit 686095b
Show file tree
Hide file tree
Showing 12 changed files with 410 additions and 21 deletions.
53 changes: 53 additions & 0 deletions apps/relax_examples/resnet.py
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example ResNet workload by translating the Relay program to Relax"""

import tvm
import tvm.testing
from tvm.relay import testing
from tvm import relax, relay
from tvm.relax.testing import relay_translator, nn
from tvm.runtime import vm as vm_rt
from tvm.script import relax as R
import numpy as np

if __name__ == "__main__":
relay_mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32")

# translate the ResNet model from Relay to Relax
relax_mod = relay_translator.from_relay(relay_mod["main"])

# print the ResNet IRmodule got translated
print(R.parser.astext(relax_mod))

# build the IRModule and create relax vm
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(relax_mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

# init weights and run the model on relax vm
shape = (1, 3, 224, 224)
data = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
params = nn.init_params(relax_mod)
res = vm["main"](data, *params)

# check correctness by comparing with relay result
exe = relay.vm.compile(relay_mod, target)
relay_vm = vm_rt.VirtualMachine(exe, tvm.cpu())
inputs = [data] + params
expected_output = relay_vm.run(*inputs)
tvm.testing.assert_allclose(res.numpy(), expected_output.numpy(), rtol=1e-4, atol=1e-4)
1 change: 1 addition & 0 deletions include/tvm/relax/expr.h
Expand Up @@ -22,6 +22,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/span.h>
#include <tvm/node/node.h>
#include <tvm/relax/type.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
Expand Down
32 changes: 26 additions & 6 deletions python/tvm/relax/block_builder.py
Expand Up @@ -199,7 +199,10 @@ def _convert_te_arg_helper(arg):
key, str
), "emit_te only supports dict with string as the key currently"
return {k: _convert_te_arg_helper(arg[k]) for k in arg}
elif isinstance(arg, (int, float, str)):
elif (
isinstance(arg, (int, float, str, tir.IntImm, tvm.ir.Type, tvm.ir.Attrs))
or arg is None
):
return arg
raise TypeError("not supported type in emit_te: {}".format(type(arg)))

Expand Down Expand Up @@ -291,12 +294,20 @@ def emit(self, expr: Expr) -> Var:
def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
"""Emit a call node according to the te function.
This function converts arguments from relax expression to te tensor,
The callback func should return a te tensor.
The callback func should return a te tensor or a list of te tensors.
Parameters
----------
func : Callable
A function that return a te tensor.
A function that returns a te tensor or a list of te tensors.
args : Any, optional
arguments passed to the function.
kwargs : Any, optional
The keyword arguments passed to the function.
Note that the key "primfunc_name_hint" is reserved for passing name hint
to the PrimFunc that gets generated.
Returns
-------
Expand Down Expand Up @@ -403,22 +414,31 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"])
= relax.call_tir(((n + 1),), te_func, (y,), (n,))
return gv
"""
primfunc_name_hint = kwargs.pop("primfunc_name_hint", None)
new_args, te_arg_list = self._convert_te_arg(args)
new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs)

te_args = te_arg_list + te_kwarg_list

te_out = func(*new_args, **new_kwargs)
assert isinstance(te_out, tvm.te.tensor.Tensor) or (
isinstance(te_out, (tuple, list))
isinstance(te_out, (tuple, list, tvm.ir.Array))
and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
), "only support te.tensor or tuple/list of te.tensor as function output"
), "only support te.tensor or tuple/list/Array of te.tensor as function output"

if isinstance(te_out, (tuple, list, tvm.ir.Array)) and len(te_out) == 1:
te_out = te_out[0]

outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out)
unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)

inputs = [*te_args] + outs
tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars)
gvar = self.add_func(tir_func, func.__name__)

if primfunc_name_hint:
gvar = self.add_func(tir_func, primfunc_name_hint)
else:
gvar = self.add_func(tir_func, func.__name__)

call_args = [x.op.value for x in te_args]
output_shape = (
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/relax/testing/__init__.py
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin
"""The Relax testing namespace containing nn and translator."""

from .nn import *
from .relay_translator import *
19 changes: 19 additions & 0 deletions python/tvm/relax/testing/_ffi_api.py
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
"""FFI API for for Relax."""
import tvm._ffi

tvm._ffi._init_api("relax", __name__)
29 changes: 16 additions & 13 deletions python/tvm/relax/testing/nn.py
Expand Up @@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
"""PyTorch-like nn.Module API for constructing workloads."""


from typing import List, Any, Callable
import tvm
from typing import List, Optional, Union, Dict, Any, Callable
from tvm import relax, topi, tir
import numpy as np

Expand Down Expand Up @@ -51,11 +54,11 @@ class Module:
"""Base class for all model modules.
A neural network or a layer can subclass this class.
Example
-------
.. code-block:: python
# Define a linear layer
class Linear(Module)
def __init__(self, in_features, out_features, bias=True):
Expand All @@ -67,7 +70,8 @@ def __init__(self, in_features, out_features, bias=True):
else:
self.bias = None
# All submodules should implement forward. Defines the forward computation performed at every call.
# All submodules should implement forward.
# Defines the forward computation performed at every call.
def forward(self, input: relax.Expr) -> relax.Var:
y = emit_te(topi.matmul, input, self.weight)
if self.bias is not None:
Expand All @@ -79,7 +83,7 @@ def parameters(self) -> List[Parameter]:
"""Return the list of parameters in the module."""
return _unpack_params(self.__dict__)

def forward(self):
def forward(self, input: relax.Expr):
"""Define the computation performed at every call."""
raise NotImplementedError()

Expand All @@ -90,22 +94,21 @@ def __call__(self, *args, **kwargs):
def _unpack_params(value: object) -> List[relax.Var]:
if isinstance(value, Parameter):
return [value]
elif isinstance(value, Module):
if isinstance(value, Module):
return value.parameters()
elif isinstance(value, dict):
if isinstance(value, dict):
params = []
for k, v in value.items():
for v in value.values():
params += _unpack_params(v)
return params
elif isinstance(value, (list, tuple)):
if isinstance(value, (list, tuple)):
params = []
for v in value:
params += _unpack_params(v)
return params
elif isinstance(value, (int, float, str)):
if isinstance(value, (int, float, str)):
return []
else:
raise TypeError("not supported type when unpacking parameters: {}".format(type(value)))
raise TypeError("not supported type when unpacking parameters: {}".format(type(value)))


def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]:
Expand All @@ -122,7 +125,7 @@ def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]:
shape.append(int(i))
else:
raise TypeError("cannot initialize for unknown-shape parameters.")
params.append(tvm.nd.array(np.random.rand(*shape).astype(np.float32)))
params.append(tvm.nd.array(np.zeros(shape).astype(np.float32)))
else:
raise TypeError("cannot initialize for unknown-shape parameters.")
return params
Expand Down

0 comments on commit 686095b

Please sign in to comment.