Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[TVMScript] Switch to the new parser (#276)
Browse files Browse the repository at this point in the history
* [TVMScript] Support cross-function call for relax function

This PR adds support for cross-function call for relax function, by declaring a function signature (i.e. an empty function that contains params and return type/shape but w/o body.)

However, the PR meets the issue of block_builder shape deduction, which does not use function `ret_shape` to infer the shape of GlobalVar Calls.
  • Loading branch information
Hzfengsy authored Nov 16, 2022
1 parent a443b8d commit 1c73696
Show file tree
Hide file tree
Showing 63 changed files with 1,742 additions and 1,386 deletions.
2 changes: 1 addition & 1 deletion apps/relax_examples/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

# get and print the IRmodule being built
mod = builder.get()
print(R.parser.astext(mod))
mod.show()

# build the IRModule and create relax vm
target = tvm.target.Target("llvm", host="llvm")
Expand Down
2 changes: 1 addition & 1 deletion apps/relax_examples/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
relax_mod = relay_translator.from_relay(relay_mod["main"], target)

# print the ResNet IRmodule got translated
print(R.parser.astext(relax_mod))
relax_mod.show()

# build the IRModule and create relax vm
ex = relax.vm.build(relax_mod, target)
Expand Down
5 changes: 4 additions & 1 deletion include/tvm/script/ir_builder/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ TVM_DLL IRModuleFrame IRModule();
* \brief Declare a Function without given the specific function implementation.
* \note It is usually used in cross-function call. And we can specify the function by `DefFunction`
* \param func_name The function unique name.
* \param func_signature A Function w/o body, which used to specify the function signature
* (i.e. func params and func return type/shape).
* \return The corresponding GlobalVar.
*/
TVM_DLL GlobalVar DeclFunction(const String& func_name);
TVM_DLL GlobalVar DeclFunction(const String& func_name,
const Optional<BaseFunc>& func_signature = NullOpt);

/*!
* \brief Define the function which is declared before.
Expand Down
20 changes: 16 additions & 4 deletions include/tvm/script/ir_builder/relax/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class SeqExprFrameNode : public RelaxFrameNode {
TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode);

public:
void EnterWithScope() override;
void ExitWithScope() override;
};

Expand Down Expand Up @@ -94,6 +95,11 @@ class FunctionFrameNode : public SeqExprFrameNode {
* If the `ret_type` is not None, check the deduced type is a base type of the given one.
*/
Optional<Type> ret_type;
/*!
* \brief The function return shape.
* \sa ret_type
*/
Optional<tvm::relax::Expr> ret_shape;
/*! \brief The function attributes. */
Map<String, ObjectRef> attrs;
/*! \brief The block builder to create Relax function. */
Expand Down Expand Up @@ -130,17 +136,23 @@ class BlockFrameNode : public RelaxFrameNode {
/*! \brief The variables emitted in this block. */
Array<tvm::relax::Var> emitted_vars;
/*!
* \brief (Only used for a dataflow block.) A boolean indicating if the dataflow block is ended of
* construction. If it is true, any new binding trying to be emitted into this block will cause an
* error.
* \brief A boolean indicating if the dataflow block is ended of construction.
* If it is true, any new binding trying to be emitted into this block will cause an error.
* \note Only used for a dataflow block.
*/
bool block_ended;
/*!
* \brief The output vars of the dataflow block.
* \note Only used for a dataflow block.
*/
Array<tvm::relax::Var> output_vars;

void VisitAttrs(tvm::AttrVisitor* v) {
RelaxFrameNode::VisitAttrs(v);
v->Visit("is_dataflow", &is_dataflow);
v->Visit("emitted_vars", &emitted_vars);
v->Visit("block_ended", &block_ended);
v->Visit("output_vars", &output_vars);
// `block_ended` is not visited.
}

static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame";
Expand Down
20 changes: 10 additions & 10 deletions include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ TVM_DLL FunctionFrame Function();
* \param shape The shape of the parameter.
* \return The created function parameter var.
*/
TVM_DLL tvm::relax::Var Arg(const String& name, const Type& type,
const tvm::relax::ShapeExpr& shape);
TVM_DLL tvm::relax::Var Arg(const String& name, const Type& type, const tvm::relax::Expr& shape);

/*!
* \brief Specify the name of the last function frame.
Expand All @@ -99,6 +98,12 @@ TVM_DLL void FuncAttrs(Map<String, ObjectRef> attrs);
*/
TVM_DLL void FuncRetType(tvm::Type ret_type);

/*!
* \brief Specify the return shape of the last function frame.
* \param ret_shape The return shape.
*/
TVM_DLL void FuncRetShape(tvm::relax::Expr ret_shape);

/*!
* \brief Specify the return value of the last function frame.
* \param value The return value.
Expand Down Expand Up @@ -130,25 +135,20 @@ TVM_DLL void DataflowBlockOutput(const Array<tvm::relax::Var>& vars);
/*!
* \brief Emit a binding to the last binding block frame.
* \param value The right side value of the bindings to be emitted.
* \param is_dataflow_var A boolean indicating if the emitted binding variable is a dataflow
* variable.
* \return The left side var of the emitted binding.
*/
TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value, bool is_dataflow_var);
TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value);

/*!
* \brief Emit a match_shape binding to the last binding block frame.
* \param value The value of the MatchShape to be emitted.
* \param pattern The pattern of the MatchShape to be emitted.
* \param emit_var A boolean indicating if the MatchShape contains the emitted variable.
* \param is_dataflow_var A boolean indicating if the emitted variable is a dataflow variable when
* `emit_var` is true. When `emit_var` is false, the value of this flag will be ignored.
* \return The emitted var if `emit_var` is true. Otherwise, return `NullOpt`.
*/
TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value, //
const Array<PrimExpr>& pattern, //
bool emit_var, //
bool is_dataflow_var);
bool emit_var);

///////////////////////////// Type Deduce //////////////////////////////

Expand All @@ -161,7 +161,7 @@ TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value,
* And we annotate to the var with more detailed type.
*/
TVM_DLL void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type,
const Optional<tvm::relax::ShapeExpr>& anno_shape);
const Optional<tvm::relax::Expr>& anno_shape);

///////////////////////////// If Then Else /////////////////////////////

Expand Down
9 changes: 4 additions & 5 deletions python/tvm/ir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Function defintiions."""
from __future__ import annotations
"""Function definitions."""
from typing import Union, Dict
from enum import IntEnum
import tvm.runtime
Expand All @@ -42,7 +41,7 @@ def attrs(self):
"""Return the attrs member of the function."""
return _ffi_api.BaseFunc_Attrs(self)

def with_attr(self, attr_key_or_dict, attr_value=None) -> BaseFunc:
def with_attr(self, attr_key_or_dict, attr_value=None) -> "BaseFunc":
"""Create a new copy of the function and update the attribute.
Parameters
Expand Down Expand Up @@ -71,7 +70,7 @@ def with_attr(self, attr_key_or_dict, attr_value=None) -> BaseFunc:
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value)
)

def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> BaseFunc:
def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "BaseFunc":
"""Copy the IRModule and add the given attribute map to it.
Parameters
----------
Expand All @@ -87,7 +86,7 @@ def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> BaseFunc:

return _ffi_api.BaseFuncWithAttrs(self, attr_map)

def without_attr(self, attr_key: str) -> BaseFunc:
def without_attr(self, attr_key: str) -> "BaseFunc":
"""Create a new copy of the function with an attribute without provided key.
Parameters
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""IRModule that holds the functions and type definitions."""
from __future__ import annotations
from typing import Optional, Union, Dict
import ast
from tvm._ffi.base import string_types
Expand Down Expand Up @@ -333,7 +332,7 @@ def get_attrs(self):

return _ffi_api.Module_GetAttrs(self)

def with_attr(self, attr_key, attr_value) -> IRModule:
def with_attr(self, attr_key, attr_value) -> "IRModule":
"""Copy the IRModule and add an attribute to it.
Parameters
Expand All @@ -352,7 +351,7 @@ def with_attr(self, attr_key, attr_value) -> IRModule:

return _ffi_api.Module_WithAttr(self, attr_key, attr_value)

def without_attr(self, attr_key: str) -> IRModule:
def without_attr(self, attr_key: str) -> "IRModule":
"""Copy the IRModule and remove an attribute key and its associated value.
Parameters
----------
Expand All @@ -366,7 +365,7 @@ def without_attr(self, attr_key: str) -> IRModule:

return _ffi_api.Module_WithoutAttr(self, attr_key)

def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> IRModule:
def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "IRModule":
"""Copy the IRModule and add the given attribute map to it.
Parameters
----------
Expand Down
58 changes: 46 additions & 12 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,13 +819,33 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern":
return PrimArrPattern(shape)


def _is_call_tir(
func_pattern: DFPattern,
args: Union[List, Tuple, TuplePattern] = None,
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
) -> CallPattern:
if args is None:
args = wildcard()
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)

if shape is None:
shape = wildcard()
elif isinstance(shape, (list, Array)):
shape = PrimArrPattern(shape)
elif isinstance(shape, (tuple)):
shape = is_tuple(shape) # multiple shape patterns

return is_op("relax.call_tir")(func_pattern, args, shape)


def is_call_tir(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
) -> CallPattern:
"""
Syntax sugar for creating a CallPattern for call_tir
Syntax sugar for creating a CallPattern for call_tir that calls an function through global var.
Parameters
----------
Expand All @@ -841,19 +861,33 @@ def is_call_tir(
CallPattern
The resulting CallPattern
"""
if args is None:
args = wildcard()
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)
func_pattern = GlobalVarPattern(func_name)
return _is_call_tir(func_pattern, args, shape)

if shape is None:
shape = wildcard()
elif isinstance(shape, (list, Array)):
shape = PrimArrPattern(shape)
elif isinstance(shape, (tuple)):
shape = is_tuple(shape) # multiple shape patterns

return is_op("relax.call_tir")(GlobalVarPattern(func_name), args, shape)
def is_call_tir_extern(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
) -> CallPattern:
"""Syntax sugar for creating a CallPattern for call_tir that calls an extern function
Parameters
----------
func_name : str
Name of the CPS function to call.
args : Union[List[DFPattern], Tuple[DFPattern]], optional
Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments
shape : Union[Tuple, List[tvm.ir.PrimExpr], DFPattern], optional
Shape (or shapes in a tuple) of the output, by default None meaning arbitrary shape(s)
Returns
-------
CallPattern
The resulting CallPattern
"""
func_pattern = ExternFuncPattern(func_name)
return _is_call_tir(func_pattern, args, shape)


def is_call_packed(
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,23 @@ def name_hint(self):
return name

def __call__(self, *args: Any, attrs=None) -> Call:
if self.checked_type and isinstance(self.checked_type, ty.FuncType):
if self._checked_type_ and isinstance(self._checked_type_, ty.FuncType):
return Call(self, args, attrs=attrs)
else:
raise TypeError("Only vars with function type can be called")
raise TypeError(
f"Only vars with function type can be called, but got type: {self._checked_type_}"
)

def __getitem__(self, key):
if not isinstance(key, int):
raise TypeError("TupleGetItem only supports integer index")
var_type = self._checked_type_
if var_type and isinstance(var_type, ty.TupleType):
return TupleGetItem(self, key)
else:
raise TypeError(
f"Only vars with TupleType is subscriptable, but got type: {self._checked_type_}"
)


@tvm._ffi.register_object("relax.expr.DataflowVar")
Expand Down
Loading

0 comments on commit 1c73696

Please sign in to comment.