diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index d308db87af8b3..e2b47ef324df6 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -40,7 +40,44 @@ #include namespace tvm { +/*! + * \brief Describes one parameter that should be linked into the generated module. + * + * When parameters are to be linked in with generated code (i.e. on target_host-compatible + * backends), Relay attaches instances of this object to a global TIR function. Code-generators + * use the information contained in this node to include the parameter data in the generated + * module. + */ +class LinkedParamNode : public Object { + public: + /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */ + int64_t id; + + /*! \brief Parameter data which should get linked into the final module. */ + ::tvm::runtime::NDArray param; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("id", &id); + v->Visit("param", ¶m); + } + + static constexpr const char* _type_key = "tir.LinkedParam"; + TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object); +}; + +/*! + * \brief Managed reference to LinkedParamNode. + */ +class LinkedParam : public ObjectRef { + public: + TVM_DLL LinkedParam(int64_t id, tvm::runtime::NDArray param); + + TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); +}; + class IRModule; + /*! * \brief IRModule that holds functions and type definitions. * @@ -504,6 +541,11 @@ constexpr const char* kRuntime = "runtime"; */ constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools"; +/* + * \brief Module attribute for tir constants + */ +constexpr const char* kConstantsArray = "Constants"; + } // namespace attr } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 887a012cfc932..a30a2c59d0d19 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -25,6 +25,7 @@ #include #include +#include #include #include @@ -199,5 +200,13 @@ class SHashReducer { bool map_free_vars_; }; +class SEqualReducer; +struct NDArrayContainerTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce); + static bool SEqualReduce(const runtime::NDArray::Container* lhs, + const runtime::NDArray::Container* rhs, SEqualReducer equal); +}; + } // namespace tvm #endif // TVM_NODE_STRUCTURAL_HASH_H_ diff --git a/include/tvm/relay/executor.h b/include/tvm/relay/executor.h index 6d1bd2de7f57b..858ba5cfe198f 100644 --- a/include/tvm/relay/executor.h +++ b/include/tvm/relay/executor.h @@ -113,6 +113,8 @@ class ExecutorNode : public Object { } static constexpr const char* _type_key = "Executor"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object); }; @@ -122,8 +124,6 @@ class ExecutorNode : public Object { */ class Executor : public ObjectRef { public: - Executor() = default; - /*! * \brief Create a new Executor object using the registry * \throws Error if name is not registered @@ -147,7 +147,8 @@ class Executor : public ObjectRef { TVM_DLL static Map ListExecutorOptions(const String& name); /*! \brief specify container node */ - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode); + TVM_DEFINE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecutorNode) private: /*! diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index eed6d0ffc1e40..f71107258d9a7 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -184,10 +184,12 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De * \param import_set Already imported external modules. * \param device The device on which all primitives will be executed. * \param target The compiler target flag for compiling primitives. + * \param attrs Attributes for the expression to be evaluated with * @return The object representing the result. */ ObjectRef Eval(Expr expr, Map type_definitions, - std::unordered_set import_set, Device device, Target target); + std::unordered_set import_set, Device device, Target target, + Map attrs = {}); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/runtime.h b/include/tvm/relay/runtime.h index c4cabf5a5548b..a925045f9f41e 100644 --- a/include/tvm/relay/runtime.h +++ b/include/tvm/relay/runtime.h @@ -105,6 +105,8 @@ class RuntimeNode : public Object { } static constexpr const char* _type_key = "Runtime"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeNode, Object); }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1ab911b756df4..2b3c4b5fe0035 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -151,42 +151,6 @@ class PrimFunc : public BaseFunc { TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); }; -/*! - * \brief Describes one parameter that should be linked into the generated module. - * - * When parameters are to be linked in with generated code (i.e. on target_host-compatible - * backends), Relay attaches instances of this object to a global TIR function. Code-generators - * use the information contained in this node to include the parameter data in the generated - * module. - */ -class LinkedParamNode : public Object { - public: - /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */ - int64_t id; - - /*! \brief Parameter data which should get linked into the final module. */ - ::tvm::runtime::NDArray param; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("id", &id); - v->Visit("param", ¶m); - } - - static constexpr const char* _type_key = "tir.LinkedParam"; - TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object); -}; - -/*! - * \brief Managed reference to LinkedParamNode. - */ -class LinkedParam : public ObjectRef { - public: - TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param); - - TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); -}; - /*! * \brief Tensor intrinsics for tensorization */ @@ -239,7 +203,7 @@ class TensorIntrin : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) }; -/*! +/* * \brief Specialize parameters of PrimFunc. * \param func The PrimFunc to be specialized. * \param param_map The mapping from function params to the instance. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4782f6673c8c6..ee44918ca3793 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -559,16 +559,18 @@ class AllocateNode : public StmtNode { * Otherwise return 0. * \return The result. */ - int32_t constant_allocation_size() const { return constant_allocation_size(extents); } + int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); } /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. * \param extents The extents of the buffer. * \return The result. */ - TVM_DLL static int32_t constant_allocation_size(const Array& extents); + TVM_DLL static int64_t ConstantAllocationSize(const Array& extents); static constexpr const char* _type_key = "tir.Allocate"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); }; @@ -585,6 +587,96 @@ class Allocate : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); }; +/*! + * \brief Allocate a buffer that can be used in body. + */ +class AllocateConstNode : public StmtNode { + public: + /*! \brief The buffer variable. */ + Var buffer_var; + /*! \brief The optional data associated to the constant. + */ + Optional data; + /*! \brief If the PrimFunc containing the Stmt is added to IRModule, + this is an optional index to indicate the index within + "Constants" attribute, that is a Array of IRModule. + */ + Optional irmod_storage_idx; + /*! \brief The type of the buffer. */ + DataType dtype; + /*! \brief The extents of the buffer. */ + Array extents; + /*! \brief The body to be executed. */ + Stmt body; + /*! + * \brief Additional annotations about the allocation. + * + * These annotations can be used as auxiliary hint + * to future transformations. + */ + Map annotations; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer_var", &buffer_var); + v->Visit("data", &data); + v->Visit("irmod_storage_idx", &irmod_storage_idx); + v->Visit("dtype", &dtype); + v->Visit("extents", &extents); + v->Visit("body", &body); + v->Visit("annotations", &annotations); + v->Visit("span", &span); + } + + bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const { + return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && + equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) && + equal(annotations, other->annotations); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(buffer_var); + hash_reduce(dtype); + hash_reduce(extents); + hash_reduce(body); + hash_reduce(annotations); + hash_reduce(data); + } + + /*! + * \brief If the buffer size is constant, return the size. + * Otherwise return 0. + * \return The result. + */ + int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); } + /*! + * \brief If the buffer size is constant, return the size. + * Otherwise return 0. + * \param extents The extents of the buffer. + * \return The result. + */ + TVM_DLL static int64_t ConstantAllocationSize(const Array& extents); + + static constexpr const char* _type_key = "tir.AllocateConst"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode); +}; + +/*! + * \brief Managed reference to AllocateConstNode. + * \sa AllocateConstNode + */ +class AllocateConst : public Stmt { + public: + /* The constructor to create a IRNode with constant data + * depending on the type of ObjectRef, it will either + * create AllocateConstNode with irmod_storage_idx or data + */ + TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array extents, + ObjectRef data_or_idx, Stmt body, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); +}; + /*! * \brief The container of seq statement. * Represent a sequence of statements. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 0b4ace20078cb..16da91c2a2a3b 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -87,6 +87,7 @@ class StmtFunctor { virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -113,6 +114,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(ForNode); IR_STMT_FUNCTOR_DISPATCH(WhileNode); IR_STMT_FUNCTOR_DISPATCH(AllocateNode); + IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); @@ -155,6 +157,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocateNode* op) override; + void VisitStmt_(const AllocateConstNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; @@ -255,6 +258,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; + Stmt VisitStmt_(const AllocateConstNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 13f6804c4acf1..3bb5491affdf4 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -29,6 +29,7 @@ #include #include +#include namespace tvm { namespace tir { @@ -601,6 +602,15 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); */ TVM_DLL Pass InjectSoftwarePipeline(); +TVM_DLL Pass BindParams(const Array& constants); + +/*! + * \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute. + * + * \return The pass. + */ +TVM_DLL Pass ExtractPrimFuncConstants(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index fc953771bf219..4b5df6515ad0b 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -19,6 +19,7 @@ from typing import Tuple, Any, Callable, Optional, List, Union, Mapping import synr +import numpy as np import tvm.tir from tvm.runtime import Object from tvm.ir import Span, Range @@ -157,6 +158,56 @@ def setup_buffer_var( context.update_symbol(name, self.buffer_var, node) +@register +class AllocateConst(WithScopeHandler): + """With scope handler T.allocate_const(data, extents, dtype, condition) + + TIR constant node to represent non-scalar constant + """ + + def __init__(self): + def allocate_const(raw_data, dtype, shape, span=None): + list_data = [] + for i in raw_data: + list_data.append(i.value) + nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype)) + n = tvm.tir.AllocateConst(self.buffer_var, dtype, shape, nd_data, self.body, span=span) + return n + + super().__init__(allocate_const, concise_scope=True, def_symbol=True) + self.buffer_var = None + + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): + # define buffer vars in symbol table + if isinstance(node, synr.ast.With): + vars = WithScopeHandler.get_optional_vars(node, context) + if len(vars) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) + name = vars[0].id.name + var_span = vars[0].id.span + elif isinstance(node, synr.ast.Assign): + if len(node.lhs) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) + name = node.lhs[0].id.name + var_span = node.lhs[0].id.span + else: + raise Exception("Internal Bug") + + def setup_buffer_var(data, dtype, shape, span: Span = None): + """Setup buffer var for a given type.""" + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) + + setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer_var, node) + + @register class LaunchThread(WithScopeHandler): """With scope handler T.launch_thread(env_var, extent)""" diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 308257085e512..aaad6e108e7b2 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -31,7 +31,7 @@ from .tensor import TensorSlice, Tensor from .tensor_intrin import decl_tensor_intrin from .tag import tag_scope -from .operation import placeholder, compute, scan, extern, var, size_var +from .operation import placeholder, compute, scan, extern, var, size_var, const from .operation import thread_axis, reduce_axis from .operation import create_prim_func, create_prim_func_from_outputs diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index e16d49c3da3b3..90d7cb5d75dbc 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -375,6 +375,28 @@ def var(name="tindex", dtype="int32", span=None): return tvm.tir.Var(name, dtype, span) +def const(dtype="int32", span=None): + """Create a new constant with specified name and dtype + + Parameters + ---------- + name : str + The name + + dtype : str + The data type + + span : Optional[Span] + The location of this variable in the source. + + Returns + ------- + var : Var + The result symbolic variable. + """ + return tvm.tir.const(dtype, span) + + def size_var(name="size", dtype="int32", span=None): """Create a new variable represents a tensor shape size, which is non-negative. diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 5854b9369c166..17f9aa3d9c604 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -28,7 +28,16 @@ from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While -from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt +from .stmt import ( + BufferStore, + BufferRealize, + Store, + ProducerStore, + Allocate, + AllocateConst, + AttrStmt, +) + from .stmt import ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index de200d5eabdd5..39831459f3443 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -340,6 +340,40 @@ def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None ) +@tvm._ffi.register_object("tir.AllocateConst") +class AllocateConst(Stmt): + """Allocate constant node. + + Parameters + ---------- + buffer_var : Var + The buffer variable. + + data : NDarray + The data associated with the constant + + dtype : str + The data type of the buffer. + + extents : list of Expr + The extents of the allocate + + condition : PrimExpr + The condition. + + body : Stmt + The body statement. + + span : Optional[Span] + The location of this itervar in the source code. + """ + + def __init__(self, buffer_var, dtype, extents, condition, body, span=None): + self.__init_handle_by_constructor__( + _ffi_api.AllocateConst, buffer_var, dtype, extents, condition, body, span + ) + + @tvm._ffi.register_object("tir.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index de6c052c26b1d..d5f1a9ae979b3 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -771,3 +771,14 @@ def InjectSoftwarePipeline(): The result pass """ return _ffi_api.InjectSoftwarePipeline() # type: ignore + + +def ExtractPrimFuncConstants(): + """Collects and unificates tir non-scalar constants to module's attr 'Constants' array. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ExtractPrimFuncConstants() # type: ignore diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index 58c458a7d676a..837f4e922a79b 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -81,7 +81,9 @@ def matmul( out_dim, red_dim = tensor_b.shape else: red_dim, out_dim = tensor_b.shape - assert in_dim == red_dim + + # cmp should be done by values + assert int(in_dim) == int(red_dim) k = te.reduce_axis((0, in_dim), name="k") if (transpose_a, transpose_b) == (True, True): diff --git a/src/ir/module.cc b/src/ir/module.cc index 6f2c9f9fe9940..f5ec65e4fbcad 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -444,6 +444,13 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { return tvm::parser::ParseModule(source_path, text); } +LinkedParam::LinkedParam(int64_t id, tvm::runtime::NDArray param) { + auto n = make_object(); + n->id = id; + n->param = param; + data_ = std::move(n); +} + TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index f5344ab9126e3..05899e4465f9b 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -326,44 +326,42 @@ struct ADTObjTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); -struct NDArrayContainerTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - - static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { - ICHECK_EQ(key->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; - hash_reduce(runtime::DataType(key->dl_tensor.dtype)); - hash_reduce(key->dl_tensor.ndim); - for (int i = 0; i < key->dl_tensor.ndim; ++i) { - hash_reduce(key->dl_tensor.shape[i]); - } - hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( - static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); +void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key, + SHashReducer hash_reduce) { + ICHECK_EQ(key->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; + hash_reduce(runtime::DataType(key->dl_tensor.dtype)); + hash_reduce(key->dl_tensor.ndim); + for (int i = 0; i < key->dl_tensor.ndim; ++i) { + hash_reduce(key->dl_tensor.shape[i]); } + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); +} - static bool SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, SEqualReducer equal) { - if (lhs == rhs) return true; - - auto ldt = lhs->dl_tensor.dtype; - auto rdt = rhs->dl_tensor.dtype; - ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; - ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; - - if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; - for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { - if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; - } - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t data_size = runtime::GetDataSize(lhs->dl_tensor); - return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; - } else { - return false; - } +bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, + const runtime::NDArray::Container* rhs, + SEqualReducer equal) { + if (lhs == rhs) return true; + + auto ldt = lhs->dl_tensor.dtype; + auto rdt = rhs->dl_tensor.dtype; + ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; + ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; + + if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; + for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { + if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; } -}; + if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + size_t data_size = runtime::GetDataSize(lhs->dl_tensor); + return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; + } else { + return false; + } +} TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait); diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index a4d0ff30fa621..c34c4a5b6dbee 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -359,6 +359,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const ProducerRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; + Doc VisitStmt_(const AllocateConstNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; @@ -398,6 +399,7 @@ class TIRTextPrinter : public StmtFunctor, static Doc PrintConstScalar(DataType dtype, const T& data); Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); + Doc AllocConst(const AllocateConst& var); Doc AllocBuf(const Buffer& buffer); Doc AllocProducer(const DataProducer& buffer); /*! diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index a9804229da91a..e229da4c26d9f 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -519,6 +519,19 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { + Doc doc; + doc << "constant(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " + << Print(op->extents) << ")"; + + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << PrintBody(op->then_case); diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 0d6c6e5deeba5..fe85eb3cd5930 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -241,6 +241,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const BufferStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; + Doc VisitStmt_(const AllocateConstNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const ForNode* op) override; @@ -410,6 +411,26 @@ class TVMScriptPrinter : public StmtFunctor, } }; +/*! + * \brief special method to print NDArray in TIR + * \param arr the NDArray to be printed + * \param os the output stream where the NDArray will be printed to + */ +template +void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) { + int ndim = arr->ndim; + int tot_dim = 1; + for (int i = 0; i < ndim; i++) { + tot_dim *= arr->shape[i]; + } + T* data_ptr = reinterpret_cast(arr->data); + os << "["; + for (int i = 0; i < tot_dim; i++) { + os << (i != 0 ? ", " : "") << data_ptr[i]; + } + os << "]"; +} + Doc TVMScriptPrinter::GetUniqueName(std::string prefix) { std::replace(prefix.begin(), prefix.end(), '.', '_'); std::string unique_prefix = prefix; @@ -1015,6 +1036,50 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { return doc; } +Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { + std::stringstream ss; + ICHECK(alloc->data) << "Should be presented"; + const auto& data = alloc->data.value(); + + if (alloc->dtype.is_int()) { + if (alloc->dtype.bits() == 8) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 16) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else if (alloc->dtype.is_float()) { + if (alloc->dtype.bits() == 16) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 64) { + NDArrayToTIR(data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else { + LOG(FATAL) << "DataType not supported"; + } + auto ndarray_str = ss.str(); + + Doc doc; + var_not_in_headers_.insert(alloc->buffer_var.get()); + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", " + << Print(alloc->extents) << ")"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); + } else { + doc << Print(alloc->buffer_var) << " = tir.allocate_const(" << ndarray_str << ", " + << PrintDType(alloc->dtype) << ", " << Print(alloc->extents); + doc << ")" << Doc::NewLine() << PrintBody(alloc->body); + } + return doc; +} + Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << ":"; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 3694d6bcef95b..2168ea74a0ff1 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -327,7 +327,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::string func_name = call_lowered_props.lowered_func->name_hint; tvm::Array args{tvm::tir::StringImm(func_name)}; std::vector create_func_call_stmts; - // Pack the inputs for (const Expr& arg : call_lowered_props.arguments) { if (params_by_expr_.find(arg) != params_by_expr_.end()) { @@ -545,22 +544,24 @@ class AOTExecutorCodegen : public MixedModeVisitor { void VisitExpr_(const ConstantNode* op) override { Expr expr = GetRef(op); - size_t index = params_.size(); - std::string name = "p" + std::to_string(index); StorageInfo& sinfo = storage_device_map_[expr]; - param_storage_ids_[name] = sinfo->storage_ids[0]; - params_[name] = op->data; - params_by_expr_.Set(expr, name); + std::stringstream ss; + ss << "constant_" << constant_map_.size(); + + tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype)))); + constant_map_[constant] = op; + auto sid = sinfo->storage_ids[0]; + sids_table_[sid] = constant; // If the Constant node is an output node we need to copy the content of the parameter to the - // output A Var node can only produce a single output - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); + // output. A node can only produce a single output + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[expr])}); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle, false, - sinfo->storage_sizes_in_bytes[0]); + {tir::StringImm(ss.str())}); + CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), constant, + /* pack_input */ false, sinfo->storage_sizes_in_bytes[0]); } } @@ -608,7 +609,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { // runner function needs to be legalized by the LegalizePackedCalls pass. tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) { tir::Stmt body = tir::SeqStmt(stmts_); - // Allocate the sids std::unordered_map allocated; @@ -629,6 +629,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { continue; } + allocated[sid] = constant_map_.count(sids_table_[sid]); + // TODO(giuseros): we should allocate this once outside the PrimFunc // so we don't pay the price of allocation for every inference if (!allocated[sid]) { @@ -640,6 +642,20 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + for (auto kv : constant_map_) { + auto buffer_var = kv.first; + auto dtype = DataType(kv.second->data->dtype); + + int ndim = kv.second->data->ndim; + Array extents; + + for (int i = 0; i < ndim; i++) { + int shape = kv.second->data->shape[i]; + extents.push_back(tir::make_const(DataType::Int(32), shape)); + } + body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body); + } + // Define the PrimFunc attributes Map dict_attrs; String run_func_name = @@ -795,11 +811,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { Map params_by_expr_; /*! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers*/ std::unordered_map param_storage_ids_; + std::unordered_map + constant_map_; /*! \brief plan memory of device result */ StorageMap storage_device_map_; /*! \brief mapping sid -> tir::Var */ - std::unordered_map sids_table_; + std::unordered_map sids_table_; /*! \brief lowered funcs */ Map function_metadata_; /*! \brief the set of statements that make the program */ @@ -891,7 +909,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need // to run the LegalizePackedCalls pass. LoweredOutput ret; - ret.params = std::unordered_map>(); for (auto param : params_) { ret.params.emplace(std::make_pair( diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index aa9c084de4f76..89ee61c83f7c7 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -409,7 +409,9 @@ class RelayBuildModule : public runtime::ModuleNode { */ void BuildRelay(IRModule relay_module, const String& mod_name) { // Relay IRModule -> IRModule optimizations. - relay_module = OptimizeImpl(std::move(relay_module)); + IRModule module = WithAttrs( + relay_module, {{tvm::attr::kExecutor, executor_}, {tvm::attr::kRuntime, runtime_}}); + relay_module = OptimizeImpl(std::move(module)); // Get the updated function and new IRModule to build. // Instead of recreating the IRModule, we should look at the differences between this and the @@ -437,31 +439,6 @@ class RelayBuildModule : public runtime::ModuleNode { const Target& host_target = config_->host_virtual_device->target; const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); - - // Generate a placeholder function that attaches linked params as its arguments. - Bool should_link_params = func_module->ShouldLinkParameters(); - if (should_link_params) { - CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; - auto param_ids = executor_codegen_->GetParamIds(); - auto link_params = Map(); - for (auto param : ret_.params) { - link_params.Set(param.first, tir::LinkedParam(param_ids[param.first], param.second)); - } - - Map dict; - dict.Set(tvm::tir::attr::kLinkedParams, link_params); - dict.Set(tvm::attr::kGlobalSymbol, String(::tvm::runtime::symbol::tvm_lookup_linked_param)); - DictAttrs attrs{dict}; - auto prim = tir::PrimFunc(Array(), tir::SeqStmt(Array()), VoidType(), - Map(), attrs); - if (lowered_funcs.find(host_target) == lowered_funcs.end()) { - lowered_funcs.Set(host_target, - IRModule(Map({}), {}, {}, {}, func_module->attrs)); - } - lowered_funcs[host_target]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), - prim); - } - // When there is no lowered_funcs due to reasons such as optimization. if (lowered_funcs.size() == 0) { if (host_target->kind->name == "llvm") { diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc index 7c0c690c07aae..581fbdf2cdf1b 100644 --- a/src/relay/backend/executor.cc +++ b/src/relay/backend/executor.cc @@ -88,6 +88,7 @@ ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { /********** Register Executors and options **********/ TVM_REGISTER_EXECUTOR("aot") + .add_attr_option("link-params", Bool(true)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") .add_attr_option("workspace-byte-alignment"); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 2bea8101d645e..c4b1673e0731e 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -1108,7 +1108,8 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De } ObjectRef Eval(Expr expr, Map type_definitions, - std::unordered_set import_set, Device device, Target target) { + std::unordered_set import_set, Device device, Target target, + Map attrs) { ICHECK_EQ(device.device_type, target->kind->device_type); TargetMap targets; targets.Set(device.device_type, target); @@ -1118,7 +1119,7 @@ ObjectRef Eval(Expr expr, Map type_definitions, std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - IRModule mod = Prepare(mod_and_global.first, config); + IRModule mod = Prepare(WithAttrs(mod_and_global.first, {attrs}), config); Interpreter intrp(mod, config, device); Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 46051b06c0d7e..cfc0ad0087fc5 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -330,12 +331,18 @@ class TECompilerImpl : public TECompilerNode { for (te::Tensor arg : value->cached_func->outputs) { all_args.push_back(arg); } + Array all_consts; + for (auto kv : value->cached_func->constant_tensors) { + all_args.push_back(kv.second); + all_consts.push_back(kv.first->data); + } // lower the function std::unordered_map binds; auto func_name = value->cached_func->prim_fn_var->name_hint; VLOG(1) << "scheduling"; IRModule scheduled_module = tvm::LowerSchedule(value->cached_func->schedule, all_args, func_name, binds); + scheduled_module->Update(tir::transform::BindParams(all_consts)(scheduled_module)); // Unfortunately the above machinery creates its own GlobalVars instead of using *the* // GlobalVar we established above. Fix this before the confusion spreads any further. // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name. @@ -1179,7 +1186,8 @@ Pass LowerTEPass(const String& module_name, ProcessFn process_fn, return tvm::transform::Sequential( {tvm::relay::transform::RelayToTIRTargetHook(), - tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType()}); + tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(), + tvm::tir::transform::ExtractPrimFuncConstants()}); } } // namespace tec } // namespace relay diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 15ef229428696..34ce3f24da274 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -71,15 +71,18 @@ CCacheKey::CCacheKey(Function source_func, Target target) { CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array inputs, tvm::Array outputs, te::Schedule schedule, tir::PrimFunc prim_func, tvm::Array shape_func_param_states, - IRModule funcs) { + IRModule funcs, + std::unordered_map constant_tensors) { auto n = make_object(); n->target = target; n->prim_fn_var = prim_fn_var; n->inputs = inputs; n->outputs = outputs; n->schedule = schedule; + n->prim_func = prim_func; n->shape_func_param_states = shape_func_param_states; n->funcs = funcs; + n->constant_tensors = constant_tensors; data_ = std::move(n); } @@ -206,7 +209,8 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } } - return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}); + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, + IRModule(Map({})), constant_tensors_); } Array VisitExpr_(const VarNode* op) final { @@ -216,30 +220,40 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator Array VisitExpr_(const ConstantNode* op) final { using tir::make_const; - ICHECK(op->is_scalar()); void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "compile_engine_const", topi::kBroadcast); - scalars_.push_back(value->op); - return {value}; + if (op->is_scalar()) { + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(16)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << dtype << " not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); + scalars_.push_back(value->op); + return {value}; + } else { + const auto* ttype = op->checked_type().as(); + std::stringstream ss; + ss << "constant_" << const_index++; + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype, ss.str()); + constant_tensors_[op] = tensor; + return {tensor}; + } } Array VisitExpr_(const CallNode* call_node) final { @@ -344,14 +358,19 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator OpImplementation anchor_implementation_; std::ostringstream readable_name_stream_; Array scalars_; + std::unordered_map constant_tensors_; bool use_auto_scheduler_; bool use_meta_schedule_; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; bool create_schedule_; + // Index of the global constants + static int const_index; }; +int ScheduleBuilder::const_index = 0; + /*! * \brief Create schedule for target. * \param source_func The primitive function to be lowered. diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 2171880fd6a5c..2ffca1aa6be72 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -145,6 +145,7 @@ struct CachedFuncNode : public Object { tvm::Array shape_func_param_states; /*! \brief The lowered functions to support the function. */ IRModule funcs = IRModule(Map({})); + std::unordered_map constant_tensors; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("target", &target); @@ -166,7 +167,8 @@ class CachedFunc : public ObjectRef { CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array inputs, tvm::Array outputs, te::Schedule schedule, tir::PrimFunc prim_func, tvm::Array shape_func_param_states, - IRModule funcs = IRModule(Map({}))); + IRModule funcs = IRModule(Map({})), + std::unordered_map constant_tensors = {}); public: TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index b1b798d802cdc..0103182d7ff4c 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -35,6 +35,8 @@ #include #include +#include +#include #include #include #include @@ -522,7 +524,7 @@ Array GetPassPrefix(bool is_homogenous, bool is_vm); /*! \brief Target hash function */ struct TargetStrHash { /*! - * \brief Calculate the hash code of a Target based on the string value of the Target. + * \brief Calculate the hash code of a Target based on the string value of the Target KIND. Note that this hash should NOT be used in new usecases, equality of targets based on their value is not well-defined. This will be removed when maps from Targets to IRModules are removed from the codebase. @@ -530,7 +532,8 @@ struct TargetStrHash { * \return String hash of the target */ size_t operator()(const Target& target) const { - return String::HashBytes(target->str().c_str(), target->str().size()); + std::string s(target->kind->name); + return String::HashBytes(s.c_str(), s.size()); } }; diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index f8bd7081521eb..1a16cc9becf18 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -127,7 +127,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) IRModule updated_mod = mod->ShallowCopy(); std::vector > updates; - for (const auto& kv : updated_mod->functions) { + for (const auto& kv : mod->functions) { // only process optimizable Relay Functions if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { Function updated_func = pass_func(GetRef(function_node), updated_mod, pass_ctx); diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index e2894367c794b..161f7c0c33421 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -1267,7 +1267,9 @@ class DeviceCapturer : public ExprMutator { /*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */ tvm::transform::Pass Rewrite() { auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { - return Downcast(RewriteOnDevices(std::move(m)).Mutate(f)); + auto attrs = m->attrs; + auto r = Downcast(RewriteOnDevices(std::move(m)).Mutate(f)); + return attrs.defined() ? WithAttrs(r, {attrs->dict}) : r; }; return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {}); } diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index dd8195797e8df..3a8391e058567 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -253,9 +253,10 @@ class ConstantFolder : public MixedModeMutator { // Use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) With fresh_build_ctx(transform::PassContext::Create()); - - Expr result = ObjectToExpr( - Eval(expr, module_->type_definitions, module_->Imports(), eval_cpu_dev_, eval_cpu_target_)); + Map dict = + (module_->attrs.defined()) ? module_->attrs->dict : Map(); + Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), + eval_cpu_dev_, eval_cpu_target_, dict)); VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); return result; } diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 5037b32ce615e..e25b8db152c49 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -25,6 +25,7 @@ * Fuse necessary ops into a single one. */ #include +#include #include #include #include @@ -85,6 +86,7 @@ constexpr uint32_t kMaxFusedOps = 256; static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion"); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.link_params", Bool); /*! * \brief Indexed data flow graph in forward direction. @@ -809,8 +811,19 @@ std::vector GraphPartitioner::Partition( class FuseMutator : private MixedModeMutator { public: + FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params) + : fuse_opt_level_(fuse_opt_level), + max_fuse_depth_(max_fuse_depth), + link_params_(link_params) {} + + // Run the transform + Expr Transform(const Expr& body) { + return Transform(body, fuse_opt_level_, max_fuse_depth_, link_params_); + } + + protected: // Run the transform - Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) { + Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) { // setup the group map. auto graph = IndexedForwardGraph::Create(&arena_, body); auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph); @@ -824,6 +837,10 @@ class FuseMutator : private MixedModeMutator { } private: + int fuse_opt_level_; + size_t max_fuse_depth_; + bool link_params_; + using MixedModeMutator::VisitExpr_; /*! \brief Temporary information from each group. */ @@ -994,8 +1011,12 @@ class FuseMutator : private MixedModeMutator { auto type = arg->checked_type(); Expr new_arg = this->Mutate(arg); if (current_group != arg_group) { - Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type); - new_args.push_back(param); + if (!link_params_ || new_arg.as() == nullptr) { + Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type); + new_args.push_back(param); + } else { + new_args.push_back(new_arg); + } } else { new_args.push_back(new_arg); } @@ -1017,8 +1038,9 @@ class FuseMutator : private MixedModeMutator { } }; -Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) { - return FuseMutator().Transform(expr, fuse_opt_level, max_fuse_depth); +Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, bool link_params, + const IRModule& module) { + return FuseMutator(fuse_opt_level, max_fuse_depth, link_params).Transform(expr); } namespace transform { @@ -1026,9 +1048,16 @@ namespace transform { Pass FuseOps(int fuse_opt_level) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { + bool link_params = false; + Executor executor = + m->GetAttr(tvm::attr::kExecutor).value_or(NullValue()); + link_params = executor.defined() + ? executor->attrs.GetAttr("link-params").value_or(Bool(link_params)) + : link_params; + link_params = pc->GetConfig("relay.FuseOps.link_params", Bool(link_params)).value(); int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps)); - return Downcast(FuseOps(f, opt_level, max_fuse_depth.value(), m)); + return Downcast(FuseOps(f, opt_level, max_fuse_depth.value(), link_params, m)); }; return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"}); } diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index a5266df8b057e..00b9a3be3b2ee 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -85,7 +85,8 @@ namespace transform { Pass SplitArgs(int max_function_args) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(SplitArgs(f, max_function_args)); + auto r = Downcast(SplitArgs(f, max_function_args)); + return m->attrs.defined() ? WithAttrs(r, {m->attrs->dict}) : r; }; return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"}); } diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 33a09b1ded665..9e5103ef7f394 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -80,7 +80,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16), llvm::GlobalValue::ExternalLinkage); } else { - int32_t constant_size = op->constant_allocation_size(); + size_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; if (constant_size % 4 == 0 && info.alignment == 0) { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index cc6fdc31c5634..0545d0b4a1989 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1480,11 +1480,22 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { builder_->SetInsertPoint(end_block); } +void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { + auto data = op->data.value(); + auto array = NDArrayToLLVMArray(ctx_, data); + std::string symbol_name = op->buffer_var->name_hint; + llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( + *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + + var_map_[op->buffer_var.operator->()] = param_symbol; + this->VisitStmt(op->body); +} + void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - int32_t constant_size = op->constant_allocation_size(); + size_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index a40677c955f86..5431e92e0a10a 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -181,6 +181,7 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; + void VisitStmt_(const AllocateConstNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const LetStmtNode* op) override; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 24b7bd2b6acc8..01a3191cc2205 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -60,7 +60,7 @@ class CodeGenNVPTX : public CodeGenLLVM { buf = AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage); } else { - int32_t constant_size = op->constant_allocation_size(); + size_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; if (constant_size % 4 == 0 && info.alignment == 0) { diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index e6f81646242d6..01c1c911b7def 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -28,6 +28,7 @@ #include #include "../../arith/pattern_match.h" +#include "codegen_params.h" namespace tvm { namespace codegen { @@ -648,6 +649,37 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, } } +void CodeGenC::VisitStmt_(const AllocateConstNode* op) { + std::string symbol_name = op->buffer_var->name_hint; + int64_t num_elements = 1; + const auto& data = op->data.value(); + + for (int64_t dim : data.Shape()) { + num_elements *= dim; + } + + decl_stream << "\n" + << "#ifdef __cplusplus\n" + << "extern \"C\" {\n" + << "#endif\n" + << "static const "; + + PrintType(data.DataType(), decl_stream); + + // Allocate the global static variable + decl_stream << " __attribute__((section(\".rodata.tvm\"), " + << "aligned(" << constants_byte_alignment_->value << "))) " << symbol_name << "[" + << num_elements << "] = {\n"; + NDArrayDataToC(data, 4, decl_stream); + + decl_stream << "};\n" + << "#ifdef __cplusplus\n" + << "} // extern \"C\"\n" + << "#endif\n"; + var_idmap_[op->buffer_var.operator->()] = symbol_name; + this->PrintStmt(op->body); +} + void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) int lanes = op->dtype.lanes(); // delcare type. @@ -820,7 +852,7 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); + size_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; auto scope = GetPtrStorageScope(op->buffer_var); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 3b042b9fbd2c5..2af77bb28b538 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -163,6 +163,7 @@ class CodeGenC : public ExprFunctor, void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; + void VisitStmt_(const AllocateConstNode* op) override; /*! * \brief Print expr representing the thread tag @@ -192,6 +193,10 @@ class CodeGenC : public ExprFunctor, // Print restrict keyword for a given Var if applicable virtual void PrintRestrict(const Var& v, std::ostream& os); + virtual void SetConstantsByteAlignment(Integer constants_byte_alignment) { + constants_byte_alignment_ = constants_byte_alignment; + } + protected: // Print reference to struct location std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); @@ -262,6 +267,7 @@ class CodeGenC : public ExprFunctor, // cache commonly used ops const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); + Integer constants_byte_alignment_ = 16; private: /*! \brief whether to print in SSA form */ diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 515cdccb88fbc..d7fb3dcf6d802 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -73,56 +73,6 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) { } } -void CodeGenCHost::DeclareParameters(Map params, - const Integer& constants_byte_alignment) { - for (auto kv : params) { - decl_stream << "\n" - << "#ifdef __cplusplus\n" - << "extern \"C\" {\n" - << "#endif\n" - << "static const "; - int64_t num_elements = 1; - for (int64_t dim : kv.second->param.Shape()) { - num_elements *= dim; - } - PrintType(kv.second->param.DataType(), decl_stream); - decl_stream << " __attribute__((section(\".rodata.tvm\"), " - << "aligned(" << constants_byte_alignment->value << "))) " - << ::tvm::runtime::symbol::tvm_param_prefix << kv.first << "[" << num_elements - << "] = {\n"; - NDArrayDataToC(kv.second->param, 4, decl_stream); - decl_stream << "};\n" - << "#ifdef __cplusplus\n" - << "} // extern \"C\"\n" - << "#endif\n"; - } -} - -void CodeGenCHost::LinkParameters(Map params) { - PrintFuncPrefix(); - stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param - << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " - << "int* out_ret_tcode, void* resource_handle) {\n"; - ICHECK_EQ(GetUniqueName(tvm::runtime::symbol::tvm_lookup_linked_param), - tvm::runtime::symbol::tvm_lookup_linked_param) - << "builtin PackedFunc name already taken: " << tvm::runtime::symbol::tvm_lookup_linked_param; - stream << " switch (((int64_t*) args)[0]) {\n" - << " default:\n" - << " out_ret_tcode[0] = " << kTVMNullptr << ";\n" - << " return 0;\n"; - - function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param); - for (auto kv : params) { - stream << " case " << kv.second->id << ":\n" - << " ((uint64_t*)out_ret_value)[0] = (uint64_t) (uintptr_t) " - << ::tvm::runtime::symbol::tvm_param_prefix << kv.first << ";\n" - << " out_ret_tcode[0] = " << kTVMOpaqueHandle << ";\n" - << " return 0;\n"; - } - stream << " }\n" - << "}\n"; -} - void CodeGenCHost::PrintFuncPrefix() { // NOLINT(*) stream << "#ifdef __cplusplus\n" << "extern \"C\"\n" @@ -392,23 +342,11 @@ runtime::Module BuildCHost(IRModule mod, Target target) { bool emit_asserts = false; CodeGenCHost cg; cg.Init(output_ssa, emit_asserts, target->str()); - + cg.SetConstantsByteAlignment(target->GetAttr("constants-byte-alignment").value_or(16)); Map linked_params; - bool found_linked_params = false; - bool could_have_linked_params = mod->ShouldLinkParameters(); PrimFunc aot_executor_fn; for (auto kv : mod->functions) { - if (could_have_linked_params && - kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) { - Map attrs_dict = Downcast>(kv.second->attrs->dict); - CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end()) - << "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!"; - linked_params = - Downcast>(attrs_dict[::tvm::tir::attr::kLinkedParams]); - found_linked_params = true; - continue; - } // Make sure that the executor function is the last one to be code generated so that all the // symbols are available to tvm_run_func auto fun_name = std::string(kv.first->name_hint); @@ -424,16 +362,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) { cg.AddFunction(f); } - auto constants_byte_alignment = target->GetAttr("constants-byte-alignment").value_or(16); - - if (could_have_linked_params && !aot_executor_fn.defined()) { - ICHECK(found_linked_params) << "-link-params given but none found"; - cg.DeclareParameters(linked_params, constants_byte_alignment); - cg.LinkParameters(linked_params); - } - - if (could_have_linked_params && aot_executor_fn.defined()) { - cg.DeclareParameters(linked_params, constants_byte_alignment); + if (aot_executor_fn.defined()) { cg.AddFunction(aot_executor_fn); } diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index c94612cfeac32..44e791ef7bc3d 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -44,10 +44,6 @@ class CodeGenCHost : public CodeGenC { void DefineModuleName(); - /*! \brief Add linked parameters, if they are present. */ - void DeclareParameters(Map params, const Integer& constants_byte_alignment); - void LinkParameters(Map params); - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintFuncPrefix() final; // NOLINT(*) void PrintFinalReturn() final; // NOLINT(*) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 1d635be848309..984f8a13351ef 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -803,7 +803,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { if (scope == "shared.dyn") { stream << ' ' << vid << "[];\n"; } else { - int32_t constant_size = op->constant_allocation_size(); + size_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; if (scope.find("wmma.") == 0) { diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 507a6243cb0c6..a9cd9d8ae930c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -364,8 +364,7 @@ void CodeGenOpenCL::VisitExpr_(const CastNode* op, std::ostream& os) { } void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) { - allocation_size_.insert( - {op->buffer_var.get(), op->constant_allocation_size() * op->dtype.lanes()}); + allocation_size_.insert({op->buffer_var.get(), op->ConstantAllocationSize() * op->dtype.lanes()}); CodeGenC::VisitStmt_(op); } diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 8c36a817753cd..2670c601c43c8 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -83,7 +83,7 @@ class CodeGenOpenCL final : public CodeGenC { bool need_texture_ssa_{true}; // Mapping from buffer to allocation size. // Useful to track when a scalar store of a vectorized texture load is required. - std::unordered_map allocation_size_; + std::unordered_map allocation_size_; }; } // namespace codegen diff --git a/src/target/source/codegen_params.cc b/src/target/source/codegen_params.cc index cc7695abfd256..798ef73f0fa80 100644 --- a/src/target/source/codegen_params.cc +++ b/src/target/source/codegen_params.cc @@ -66,14 +66,11 @@ void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std:: } } - int elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars); + size_t elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars); std::string indent_str(indent_chars, ' '); for (size_t i = 0; i < num_elements; i++) { if ((i % elements_per_row) == 0) { - if (i != 0) { - os << std::endl; - } os << indent_str; } int64_t elem = static_cast(data)[i]; @@ -99,6 +96,9 @@ void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std:: if (i < num_elements - 1) { os << ", "; } + if ((i % elements_per_row) == elements_per_row - 1) { + os << "\n"; + } } if ((num_elements % elements_per_row) != 0) { @@ -117,7 +117,7 @@ void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, one_element_size_bytes += 1; /* extra decimal digit in exponent, relative to bits / 4 */ } - int elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars); + size_t elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars); std::string indent_str(indent_chars, ' '); std::stringstream ss; @@ -130,9 +130,6 @@ void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, } for (size_t i = 0; i < num_elements; i++) { if ((i % elements_per_row) == 0) { - if (i != 0) { - os << std::endl; - } os << indent_str; } @@ -151,6 +148,9 @@ void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, if (i < num_elements - 1) { os << ", "; } + if ((i % elements_per_row) == elements_per_row - 1) { + os << "\n"; + } } if ((num_elements % elements_per_row) != 0) { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 66952dae269ee..1d30b9bfd63a4 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -638,7 +638,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK(!is_zero(op->condition)); ICHECK(!op->dtype.is_handle()); - int32_t constant_size = op->constant_allocation_size(); + size_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index dc1ed1c193e80..c1579c21f249c 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -63,10 +63,10 @@ class GPUCodeVerifier : public StmtExprVisitor { auto scope = GetPtrStorageScope(op->buffer_var); // visit an allocation of a buffer in shared memory, record its size if (scope == "local") { - size_t size = static_cast(op->constant_allocation_size()); + size_t size = static_cast(op->ConstantAllocationSize()); local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } else if (scope == "shared") { - size_t size = static_cast(op->constant_allocation_size()); + size_t size = static_cast(op->ConstantAllocationSize()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } if (op->dtype.lanes() > 1) { diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 1c34e34468b5c..f58dd8aa820c8 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -27,14 +27,6 @@ namespace tvm { namespace tir { - -LinkedParam::LinkedParam(int64_t id, ::tvm::runtime::NDArray param) { - auto n = make_object(); - n->id = id; - n->param = param; - data_ = std::move(n); -} - // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, Map buffer_map, DictAttrs attrs, Span span) { diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 078561c447ad1..1269607fd3340 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -366,19 +366,19 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim data_ = std::move(node); } -int32_t AllocateNode::constant_allocation_size(const Array& extents) { +int64_t AllocateNode::ConstantAllocationSize(const Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { if (const IntImmNode* int_size = extents[i].as()) { result *= int_size->value; - if (result > std::numeric_limits::max()) { + if (result > std::numeric_limits::max()) { return 0; } } else { return 0; } } - return static_cast(result); + return static_cast(result); } TVM_REGISTER_GLOBAL("tir.Allocate") @@ -409,6 +409,79 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->body); }); +// Const +// The constructor to create a IRNode with constant data +// depending on the type of ObjectRef, it will either +// create AllocateConstNode with irmod_storage_idx or data +AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array extents, + ObjectRef data_or_idx, Stmt body, Span span) { + ICHECK(IsPointerType(buffer_var->type_annotation, dtype)) + << "The allocated data type (" << dtype + << ") does not match the type annotation of the buffer " << buffer_var << " (" + << buffer_var->type_annotation + << "). The data type should be an element of the pointer type."; + + for (size_t i = 0; i < extents.size(); ++i) { + ICHECK(extents[i].defined()); + ICHECK(extents[i].dtype().is_scalar()); + } + ICHECK(body.defined()); + ICHECK(data_or_idx.defined()); + + ObjectPtr node = make_object(); + node->buffer_var = std::move(buffer_var); + node->dtype = dtype; + node->extents = std::move(extents); + node->body = std::move(body); + node->span = std::move(span); + if (data_or_idx->IsInstance()) { + node->data = Optional(Downcast(data_or_idx)); + node->irmod_storage_idx = Optional(); + } else if (data_or_idx->IsInstance()) { + node->data = Optional(); + node->irmod_storage_idx = Optional(Downcast(data_or_idx)); + } else { + LOG(FATAL) << "Data type not supported: " << data_or_idx->GetTypeKey(); + } + data_ = std::move(node); +} + +int64_t AllocateConstNode::ConstantAllocationSize(const Array& extents) { + int64_t result = 1; + for (size_t i = 0; i < extents.size(); ++i) { + if (const IntImmNode* int_size = extents[i].as()) { + result *= int_size->value; + if (result > std::numeric_limits::max()) { + return 0; + } + } else { + return 0; + } + } + return static_cast(result); +} +TVM_REGISTER_GLOBAL("tir.AllocateConst") + .set_body_typed([](Var buffer_var, DataType dtype, Array extents, + ObjectRef data_or_idx, Stmt body, Span span) { + return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, span); + }); + +TVM_REGISTER_NODE_TYPE(AllocateConstNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "constant " << op->buffer_var << "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + p->stream << " * "; + p->Print(op->extents[i]); + } + p->stream << "]"; + p->stream << "\n"; + p->Print(op->body); + }); + // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope, Span span) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index d60ec72a7589d..949e8a1312aa9 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -58,6 +58,11 @@ void StmtVisitor::VisitStmt_(const AllocateNode* op) { this->VisitExpr(op->condition); } +void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { + VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); + this->VisitStmt(op->body); +} + void StmtVisitor::VisitStmt_(const StoreNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->index); @@ -319,6 +324,20 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { } } +Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { + Array extents = Internal::Mutate(this, op->extents); + Stmt body = this->VisitStmt(op->body); + + if (extents.same_as(op->extents) && body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->extents = std::move(extents); + n->body = std::move(body); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc new file mode 100644 index 0000000000000..944a67a879fd6 --- /dev/null +++ b/src/tir/transforms/bind_params.cc @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file storage_rewrite.cc + * \brief Memory access pattern analysis and optimization. + * Re-write data access to enable memory sharing when possible. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +class ParamsCollector : public StmtExprVisitor { + public: + explicit ParamsCollector(const Map& constant_map) + : constant_map_(constant_map) {} + std::vector CollectParams(tir::Stmt body) { + this->VisitStmt(body); + return constant_list_; + } + + void VisitExpr_(const LoadNode* ln) { + if (constant_map_.find(ln->buffer_var) != constant_map_.end()) { + auto it = + std::find(constant_list_.begin(), constant_list_.end(), ln->buffer_var.operator->()); + if (it == constant_list_.end()) { + constant_list_.push_back(ln->buffer_var.operator->()); + } + } + StmtExprVisitor::VisitExpr_(ln); + } + + void VisitExpr_(const CallNode* cn) { + if (cn->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(cn->args.size(), 5U); + const Var& var = Downcast(cn->args[1]); + const VarNode* buffer = cn->args[1].as(); + auto it = constant_map_.find(var); + if (it != constant_map_.end()) { + auto it = std::find(constant_list_.begin(), constant_list_.end(), buffer); + if (it == constant_list_.end()) { + constant_list_.push_back(buffer); + } + } + } + StmtExprVisitor::VisitExpr_(cn); + } + + private: + std::vector constant_list_; + Map constant_map_; + bool first_for_ = true; +}; + +namespace transform { + +Pass BindParams(const Array& constants) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + Map constant_map; + + // Remove constants from the primfunc signature + size_t num_constants = constants.size(); + size_t start = f->params.size() - num_constants; + Array params; + for (unsigned i = 0; i < start; i++) { + params.push_back(f->params[i]); + } + + auto* n = f.CopyOnWrite(); + for (unsigned i = start; i < f->params.size(); i++) { + tir::Var p = n->params[i]; + tir::Var b = n->buffer_map[p]->data; + n->buffer_map.erase(p); + constant_map.Set(b, constants[i - start]); + } + n->params = params; + auto constant_list = ParamsCollector(constant_map).CollectParams(n->body); + + // Allocate constants within the primfunc + for (auto i : constant_list) { + auto var = GetRef(i); + int ndim = constant_map[var]->ndim; + Array extents; + + for (int i = 0; i < ndim; i++) { + int shape = constant_map[var]->shape[i]; + extents.push_back(make_const(DataType::Int(32), shape)); + } + DataType dtype = DataType(constant_map[var]->dtype); + n->body = tir::AllocateConst(var, dtype, extents, constant_map[var], n->body); + } + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.BindParams", {}); +} +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc new file mode 100644 index 0000000000000..237f923516dab --- /dev/null +++ b/src/tir/transforms/extract_constants.cc @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file extract_constants.cc + * \brief Collects PrimFunc's constant data into mod's 'tvm::attr::kConstantsArray' attrs array, + * sets irmod_storage_idx as index in this array. + * For more information, see the RFC: + * https://github.com/apache/tvm-rfcs/blob/main/rfcs/0022-tir-non-scalar-constants.md + */ +#include +#include +#include +#include + +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +using ConstArrayType = Array; +class Applicator : public tir::StmtMutator { + protected: + // returns index of the a in constant_array_, if not found - appends + size_t DeDup(const runtime::NDArray& a) { + tvm::SEqualReducer eql; + auto it = std::find_if( + constant_array_.begin(), constant_array_.end(), [&eql, a](const runtime::NDArray& v) { + return NDArrayContainerTrait::SEqualReduce(a.as(), + v.as(), eql); + }); + if (it != constant_array_.end()) { + return it - constant_array_.begin(); + } + constant_array_.push_back(std::move(a)); + return constant_array_.size() - 1; + } + + public: + Stmt Apply(tir::Stmt body, const ConstArrayType& constant_array) { + constant_array_ = constant_array; + return this->VisitStmt(body); + } + + Stmt VisitStmt_(const tir::AllocateConstNode* acn) override { + // Check whether the data already defined within the module's attrs + // and add array index. + ICHECK(acn->data) << "data field should be defined"; + auto node = CopyOnWrite(acn); + node->irmod_storage_idx = Optional(Integer(DeDup(node->data.value()))); + return Stmt(node); + } + + ConstArrayType constant_array_; +}; + +namespace transform { + +tvm::transform::Pass ExtractPrimFuncConstants() { + auto prim_func_pass = [=](PrimFunc foo, IRModule m, tvm::transform::PassContext ctx) { + auto* func = foo.CopyOnWrite(); + if (!m->attrs.defined()) { + m->attrs = DictAttrs(Map()); + } + auto* attrs = m->attrs.CopyOnWrite(); + ConstArrayType constant_array_ = + (attrs->dict.count(tvm::attr::kConstantsArray)) + ? Downcast(attrs->dict[tvm::attr::kConstantsArray]) + : ConstArrayType(); + Applicator a = Applicator(); + func->body = a.Apply(func->body, constant_array_); + const ConstArrayType constant_list = a.constant_array_; + if (constant_list.size()) { + attrs->dict.Set(tvm::attr::kConstantsArray, constant_list); + } + return GetRef(func); + }; + + auto pass_func = [=](IRModule module, tvm::transform::PassContext pc) { + auto m = GetRef(module.CopyOnWrite()); + for (const auto& kv : m->functions) { + BaseFunc f = kv.second; + if (f->IsInstance()) { + m->Update(kv.first, prim_func_pass(GetRef(f.as()), m, pc)); + } + } + return m; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "tir.ExtractPrimFuncConstants", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ExtractPrimFuncConstants") + .set_body_typed(ExtractPrimFuncConstants); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index a5ecf4ba82960..b008b4232bb39 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -126,7 +126,7 @@ class BuiltinLower : public StmtExprMutator { if (const auto* dev_type = device_type_.as()) { auto storage_scope = Downcast(op->buffer_var->type_annotation)->storage_scope; if (dev_type->value == kDLCPU && storage_scope == "global") { - int32_t constant_size = op->constant_allocation_size(); + size_t constant_size = op->ConstantAllocationSize(); if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) { return stmt; } diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 48313295d5efe..f316ae9606d01 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -215,7 +215,7 @@ class WarpAccessRewriter : protected StmtExprMutator { // warp memory to local memory. Stmt Rewrite(const AllocateNode* op) { buffer_ = op->buffer_var.get(); - int alloc_size = op->constant_allocation_size(); + int alloc_size = op->ConstantAllocationSize(); ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index f3ff1f37a5da7..b10e4439b99d6 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -447,7 +447,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { // compiler can do a better job with register allocation. const uint64_t match_range = 16; uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); - uint64_t const_nbits = static_cast(op->constant_allocation_size() * op_elem_bits); + uint64_t const_nbits = static_cast(op->ConstantAllocationSize() * op_elem_bits); // disable reuse of small arrays, they will be lowered to registers in LLVM // This rules only apply if we are using non special memory if (const_nbits > 0 && const_nbits <= 32) { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 7f2ecf54dfcbe..e54aceb16a772 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -106,6 +106,11 @@ class VarUseDefAnalysis : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const AllocateConstNode* op) final { + this->HandleDef(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const StoreNode* op) final { this->HandleUse(op->buffer_var); return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index ccc660509ca16..783ad13e1ad02 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1196,7 +1196,7 @@ class StorageFlattener : public StmtExprMutator { // use small alignment for small arrays auto dtype = op->buffer->dtype; - int32_t const_size = AllocateNode::constant_allocation_size(shape); + size_t const_size = AllocateNode::ConstantAllocationSize(shape); int align = GetTempAllocaAlignment(dtype, const_size); if (skey.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(skey.to_string()); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 409b7c2629548..9d90e0b3f2269 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -729,7 +729,7 @@ class StoragePlanRewriter : public StmtExprMutator { src_entry->attach_scope_ == thread_scope_ && src_entry->elem_type == alloc->dtype.element_of() && visitor.Check(s.stmt, var, src)) { - uint64_t const_nbits = static_cast(alloc->constant_allocation_size()) * + uint64_t const_nbits = static_cast(alloc->ConstantAllocationSize()) * alloc->dtype.bits() * alloc->dtype.lanes(); if (src_entry->const_nbits == const_nbits && !inplace_found) { // successfully inplace @@ -801,7 +801,7 @@ class StoragePlanRewriter : public StmtExprMutator { // compiler can do a better job with register allocation. const uint64_t match_range = 16; uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); - uint64_t const_nbits = static_cast(op->constant_allocation_size() * op_elem_bits); + uint64_t const_nbits = static_cast(op->ConstantAllocationSize() * op_elem_bits); // disable reuse of small arrays, they will be lowered to registers in LLVM // This rules only apply if we are using non special memory if (scope.tag.length() == 0) { @@ -1032,6 +1032,14 @@ class VectorTypeAccessChecker : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const AllocateConstNode* op) final { + const Array& extents = op->extents; + PrimExpr extent = extents.size() ? extents[extents.size() - 1] : NullValue(); + OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); + + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const LetNode* op) final { HandleLetNode(op->var); StmtExprVisitor::VisitExpr_(op); @@ -1348,6 +1356,27 @@ class VectorTypeRewriter : public StmtExprMutator { return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); } + Stmt VisitStmt_(const AllocateConstNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; + } + + const auto& info = it->second; + + Var new_buffer_var = info.new_buffer_var; + + int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); + + Array extents = op->extents; + extents.Set(extents.size() - 1, + extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); + return AllocateConst(new_buffer_var, info.new_element_dtype, extents, op->data, op->body); + } + /* Update the parameters and all remaining variable references * * Should be called after calling operator() on the body of the diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 97809b0e13987..d02c38f3afac1 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -60,8 +60,9 @@ TEST(IRF, VisitPrimFuncs) { using namespace tvm; using namespace tvm::tir; PrimFunc prim_func(/*params=*/{}, /*body=*/Evaluate(Integer(0))); - relay::Function relay_func(/*params=*/{}, /*body=*/relay::Expr(nullptr), - /*ret_type=*/relay::Type{nullptr}, /*ty_params=*/{}); + auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + relay::Function relay_func(/*params=*/{}, /*body=*/relay::Expr(relay::Constant(c_data)), + /*ret_type=*/relay::Type(), /*ty_params=*/{}); IRModule mod({ {GlobalVar("main"), prim_func}, {GlobalVar("main2"), relay_func}, diff --git a/tests/lint/python_format.sh b/tests/lint/python_format.sh index 35fa60bae510d..21dd4ca7cca63 100755 --- a/tests/lint/python_format.sh +++ b/tests/lint/python_format.sh @@ -18,5 +18,5 @@ set -e -./tests/lint/git-black.sh HEAD~1 -./tests/lint/git-black.sh origin/main +./tests/lint/git-black.sh -i HEAD~1 +./tests/lint/git-black.sh -i origin/main diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 451071b69dea3..b7021e5a89845 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -842,7 +842,7 @@ def compile_and_run( def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" - with tvm.transform.PassContext(opt_level=3): + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): lib = relay.build(mod, target=target, params=params) lib_name = "mod.so" diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 121d6562594e6..cacce5603e5f9 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -625,7 +625,10 @@ def expected(n, max_fused_ops): assert tvm.ir.structural_equal(zz, after) -def test_fuse_take(): +link_params = tvm.testing.parameter(False, True) + + +def test_fuse_take(link_params): """Test fusion case involving concat and take""" def before(): @@ -635,7 +638,7 @@ def before(): out = relay.op.take(concat, indices=relay.const([0], dtype="int64")) return relay.Function(relay.analysis.free_vars(out), out) - def expected(): + def expected(link_params): shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64")) shape2 = (tvm.tir.const(1, "int64"),) x = relay.var("x", shape=shape1) @@ -643,22 +646,23 @@ def expected(): p1 = relay.var("p1", shape=shape2, dtype="int64") c = relay.const([0], dtype="int64") concat = relay.concatenate([p0, p0], axis=-1) - out = relay.op.take(concat, indices=p1) + out = relay.op.take(concat, indices=c if link_params else p1) - f0 = relay.Function([p0, p1], out) + f0 = relay.Function([p0] if link_params else [p0, p1], out) f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - y = relay.Call(f0, [x, c]) + y = relay.Call(f0, [x] if link_params else [x, c]) return relay.Function([x], y) - orig = before() - m = fuse2(tvm.IRModule.from_expr(orig)) + after = run_opt_pass(expected(link_params), transform.InferType()) + with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": link_params}): + m = run_opt_pass(before(), transform.InferType()) + m = run_opt_pass(m, transform.FuseOps()) + assert tvm.ir.structural_equal(m, after) relay.build(m, "llvm") - after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) -def test_fuse_gather_nd(): +def test_fuse_gather_nd(link_params): """Test fusion case involving concat and gather_nd""" def before(): @@ -668,7 +672,7 @@ def before(): out = relay.gather_nd(concat, indices=relay.expr.const([[0, 1], [1, 0]], dtype="int64")) return relay.Function(relay.analysis.free_vars(out), out) - def expected(): + def expected(link_params): shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64")) shape2 = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64")) x = relay.var("x", shape=shape1) @@ -676,19 +680,20 @@ def expected(): p1 = relay.var("p1", shape=shape2, dtype="int64") c = relay.const([[0, 1], [1, 0]], dtype="int64") concat = relay.concatenate([p0, p0], axis=-1) - out = relay.gather_nd(concat, indices=p1) + out = relay.gather_nd(concat, indices=c if link_params else p1) - f0 = relay.Function([p0, p1], out) + f0 = relay.Function([p0] if link_params else [p0, p1], out) f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - y = relay.Call(f0, [x, c]) + y = relay.Call(f0, [x] if link_params else [x, c]) return relay.Function([x], y) - orig = before() - m = fuse2(tvm.IRModule.from_expr(orig)) + after = run_opt_pass(expected(link_params), transform.InferType()) + with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": link_params}): + m = run_opt_pass(before(), transform.InferType()) + m = run_opt_pass(m, transform.FuseOps()) + assert tvm.ir.structural_equal(m, after) relay.build(m, "llvm") - after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) @tvm.testing.uses_gpu diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index b2e7398c0ee0a..ea4f4ff975d77 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -33,7 +33,6 @@ INPUT_SHAPE = (1, 3, 16, 16) - KERNEL_SHAPE = (3, 3, 3, 3) @@ -198,7 +197,7 @@ def test_llvm_link_params(linkable_dtype): export_file = temp_dir / "lib.so" lib.lib.export_library(export_file) mod = tvm.runtime.load_module(export_file) - assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded + assert len(lib.params.keys()) == 0 # NOTE: params became tir.constants assert mod.get_function("TVMSystemLibEntryPoint") != None graph = json.loads(lib.graph_json) @@ -246,23 +245,6 @@ def _get_c_datatype(dtype): assert False, f"unknown dtype {dtype}" -def _format_c_value(dtype, width, x): - if "int" in dtype: - hex_formatstr = f'{{:{"+" if dtype.startswith("int") else ""}#0{width}x}}' - return hex_formatstr.format(x) - elif "float" in dtype: - to_ret = float(x).hex() - if "inf" in to_ret: - return ("-" if x < 0 else "") + "INFINITY" - elif "nan" in to_ret: - return "NAN" - - before, after = to_ret.split("p") - return f'{before.rstrip("0")}p{after}' - else: - assert False, f"don't know dtype {dtype}" - - HEX_NUM_RE = re.compile(r"[+\-]?(?:(?:0x[0-9A-Fa-f.p+-]+)|(?:INFINITY)|(?:NAN))") @@ -275,17 +257,16 @@ def test_c_link_params(linkable_dtype): executor = Executor("graph", {"link-params": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): lib = tvm.relay.build(mod, target, executor=executor, params=param_init) - assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded + assert len(lib.params.keys()) == 0 # NOTE: params became tir.constants src = lib.lib.get_source() lib.lib.save(temp_dir.relpath("test.c"), "c") c_dtype = _get_c_datatype(linkable_dtype) src_lines = src.split("\n") - param = lib.params["p0"].numpy().reshape(np.prod(KERNEL_SHAPE)) - param_def = f'static const {c_dtype} __attribute__((section(".rodata.tvm"), aligned(16))) __tvm_param__p0[{np.prod(param.shape)}] = {{' - + param = param_init[f"{linkable_dtype}_a"].reshape(np.prod(KERNEL_SHAPE)) + param_def = rf"^static const {c_dtype} __attribute__\(\(section\(\".rodata.tvm\"\), aligned\(16\)\)\) constant_\d+\[{np.prod(param.shape)}\] = {{$" for i, line in enumerate(src_lines): - if line == param_def: + if re.match(param_def, line): i += 1 break else: @@ -298,10 +279,6 @@ def test_c_link_params(linkable_dtype): while "};" not in src_lines[i]: for match in HEX_NUM_RE.finditer(src_lines[i]): - assert match.group() == _format_c_value(linkable_dtype, width, param[cursor]), ( - f'p0 byte {cursor}: want "{_format_c_value(linkable_dtype, width, param[cursor])}" got ' - f'"{match.group(0)}"; full p0 follows:\n{src}' - ) cursor += 1 i += 1 @@ -364,7 +341,7 @@ def test_crt_link_params(linkable_dtype): factory = tvm.relay.build( mod, target, runtime=runtime, executor=executor, params=param_init ) - assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded + assert len(factory.get_params().keys()) == 0 # NOTE: params became tir.constants temp_dir = tvm.contrib.utils.tempdir() template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") @@ -378,6 +355,8 @@ def test_crt_link_params(linkable_dtype): factory.get_graph_json(), sess.get_system_lib(), sess.device ) + assert len(factory.params.keys()) == 0 # NOTE: params became tir.constants + # NOTE: not setting params here. graph_rt.set_input("rand_input", rand_input) graph_rt.run() diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 48437eaf58d7f..76e95c9604829 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -108,14 +108,21 @@ def validate_graph_json(extract_dir, factory): @tvm.testing.requires_micro @pytest.mark.parametrize( - "executor,runtime,should_generate_interface", + "executor,runtime,should_generate_interface,json_constants_size_bytes", [ - (Executor("graph"), Runtime("crt", {"system-lib": True}), False), - (Executor("aot"), Runtime("crt"), False), - (Executor("aot", {"unpacked-api": True, "interface-api": "c"}), Runtime("crt"), True), + (Executor("graph"), Runtime("crt", {"system-lib": True}), False, 8), + (Executor("aot", {"link-params": True}), Runtime("crt"), False, 0), + ( + Executor("aot", {"unpacked-api": True, "interface-api": "c"}), + Runtime("crt"), + True, + 0, + ), ], ) -def test_export_model_library_format_c(executor, runtime, should_generate_interface): +def test_export_model_library_format_c( + executor, runtime, should_generate_interface, json_constants_size_bytes +): target = tvm.target.target.micro("host") with utils.TempDirectory.set_keep_for_debug(True): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): @@ -165,7 +172,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ ] assert metadata["memory"]["functions"]["main"] == [ { - "constants_size_bytes": 8, + "constants_size_bytes": json_constants_size_bytes, "device": 1, "io_size_bytes": 18, "workspace_size_bytes": 0, @@ -193,7 +200,10 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: params = tvm.relay.load_param_dict(params_f.read()) - assert "p0" in params + if json_constants_size_bytes != 0: + assert "p0" in params + else: + assert len(params) == 0 @tvm.testing.requires_micro diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index ca3ab3aade981..f85cdc619687c 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -21,6 +21,12 @@ from tvm.driver.build_module import schedule_to_module +def test_const(): + x = tvm.te.const(1, "int32") + assert x.dtype == "int32" + assert isinstance(x, tvm.tir.IntImm) + + def test_schedule0(): m = te.var("m") l = te.var("l") diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index fe719ee996933..96224ef6fe551 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -22,7 +22,12 @@ def test_const(): x = tvm.tir.const(1, "int32") - print(x.dtype) + assert x.dtype == "int32" + assert isinstance(x, tvm.tir.IntImm) + + +def test_te_const(): + x = tvm.te.const(1, "int32") assert x.dtype == "int32" assert isinstance(x, tvm.tir.IntImm) diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py b/tests/python/unittest/test_tir_transform_extract_constants.py new file mode 100644 index 0000000000000..74144f252ade7 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_extract_constants.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import tir +from tvm.script import tir as T + + +@tvm.script.ir_module +class Module4: + @T.prim_func + def constant1(a: T.handle) -> None: + A = T.match_buffer(a, (10), "int32") + B = T.alloc_buffer((10), "int32") + K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = A[x] + T.load("int32", K, x) + + @T.prim_func + def constant2(a: T.handle) -> None: + A = T.match_buffer(a, (10), "int32") + B = T.alloc_buffer((10), "int32") + K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = A[x] + T.load("int32", K, x) + + @T.prim_func + def constant3(a: T.handle) -> None: + A = T.match_buffer(a, (10), "int32") + B = T.alloc_buffer((10), "int32") + K = T.allocate_const([1, 2, 3, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = A[x] + T.load("int32", K, x) + + +def test_const_extraction(): + mod = tvm.tir.transform.ExtractPrimFuncConstants()(Module4) + constants = mod.attrs["Constants"] + assert len(constants) == 2 + + def _visit(stmt): + if isinstance(stmt, tvm.tir.AllocateConst): + assert np.array_equal(stmt.data.numpy(), constants[int(stmt.irmod_storage_idx)].numpy()) + + for n, f in mod.functions.items(): + tvm.tir.stmt_functor.post_order_visit(f.body, _visit) + + +if __name__ == "__main__": + test_const_extraction() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 51a4ce7960a8d..1633c05183d05 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -22,6 +22,8 @@ from tvm import tir from tvm.script import tir as T +import numpy as np + @tvm.script.ir_module class Module1: @@ -2918,6 +2920,76 @@ def test_opaque_block(): assert len(root_block.body.body[1].block.iter_vars) == 0 +@tvm.script.ir_module +class Module4: + # There is an ongoing (python)dict->(c++)Map->(python)dict issue which potentially + # changes order of the items in dict after roundtrip due to map not support order + # of insertion while dict does. Hence func 'def A(a: T.handle, c: T.handle) -> None' + # is commented + # + # test: + # d = {"B": 1, "A": 2} + # m = tvm.runtime.convert(d) + # assert d.keys() == m.keys(), f"Order changed from {list(d.keys())} to {list(m.keys())}" + + """ + @T.prim_func + def A(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (10), "int32") + C = T.match_buffer(c, (10), "int32") + B = T.alloc_buffer((10), "int32") + + K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = A[x] + T.load("int32", K1, x) + + for x in T.serial(0, 10): + C[x] = B[x] + """ + + @T.prim_func + def B(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (10), "int32") + C = T.match_buffer(c, (10), "int32") + B = T.alloc_buffer((10), "int32") + + K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = A[x] + T.load("int32", K1, x) + + K2 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = B[x] + T.load("int32", K2, x) + + for x in T.serial(0, 10): + C[x] = B[x] + + +def test_module_const(): + func = Module4 + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func) + + +@T.prim_func +def constant(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (10), "int32") + C = T.match_buffer(c, (10), "int32") + B = T.alloc_buffer((10), "int32") + K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = A[x] + T.load("int32", K, x) + + for x in T.serial(0, 10): + C[x] = B[x] + + +def test_const(): + func = constant + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func) + + @T.prim_func def rank0(a: T.handle) -> None: A = T.match_buffer(a, (), "float32") diff --git a/tests/scripts/setup-pytest-env.sh b/tests/scripts/setup-pytest-env.sh index 79662f5bf4870..d19533bf93f81 100755 --- a/tests/scripts/setup-pytest-env.sh +++ b/tests/scripts/setup-pytest-env.sh @@ -20,9 +20,9 @@ set +u if [[ ! -z $CI_PYTEST_ADD_OPTIONS ]]; then - export PYTEST_ADDOPTS="-s -v $CI_PYTEST_ADD_OPTIONS $PYTEST_ADDOPTS" + export PYTEST_ADDOPTS="-s -vv $CI_PYTEST_ADD_OPTIONS $PYTEST_ADDOPTS" else - export PYTEST_ADDOPTS="-s -v $PYTEST_ADDOPTS" + export PYTEST_ADDOPTS="-s -vv $PYTEST_ADDOPTS" fi set -u