Skip to content

Commit

Permalink
[TIR] Tir constants integration into compilation pipeline (apache#8509)
Browse files Browse the repository at this point in the history
* [TIR] Introduce tir.allocate_const to TIR

This PR is adding non-scalar constant representation in TIR. This is used to
express constants (i.e., parameters) in the TIR instead of bypassing the
TIR as it's done until now.

Change-Id: Id3afc4d7197260cb43ecde60f05ccbce3fc42430

Co-authored-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Change-Id: Id4a09a637c9c1fd7d49989c6c10f474a78569e18

* [TIR] Integrate tir constant nodes in compilation pipeline

This PR integrates tir.allocate_const to the compilation pipeline to support --link-params.

Change-Id: Ic8d0cb75d596299fcae7078b304598afbf0c5494

Co-authored-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Change-Id: Id98cc682bbfacfe75c4d8b260fd41658f1f196b2

* [TIR] tir.const extraction

This commit tries to implement an amendment to tir.constant RFC
with centralized storage of constant data within the IRModule
Please note that data and irmod_storage_idx are not mutual exclisive
further more the irmod_storage_idx is valid only immediatly after
prim func addition to the mod or after update within the mod.
If prim func is out of the the module scope then the index become
meangless. irmod_storage_idx also is not used in calculation of hash
function of the tir.constant node.

Change-Id: I40742ed580468b0252ea3fec02184cba65e20871

* unit test fixed

Change-Id: Ied2186554d4cbad44b2346216c8be92449e55732

* cmsis-nn codegen fix

Now handled case when params of the functions came as constants

Change-Id: I5874e182e34ef94e23048eaf3c61b01a56d91131

* Fixes for unittests

Change-Id: I5b82ee3f80337155706b5470973f494a301b5d90

* Rebasing tests fixes

Change-Id: I94ac87907081bab53c1dd1ab2db106ae057b4b19

* Linter: added method param description

Change-Id: I2f8c4c8d244b74c794abaa6079c46cc593ffcbdb

* Printing removal fix

This patch removes forgotten print in fuse_ops

Change-Id: I4bb5934f3b4cd5fde19d36a8e3319aae136bce8a

* Bugfix

Fixed concurrent map update bug here

Change-Id: Ifec3bf5030086d9079b9e493096f17dfd82297ec

* Reworked logic for not to introduce empty constant list to modue attrs

Change-Id: I082c85b3b4b70c218f0d714f5613ef6e178bd020

* Added support for tir builtin::tvm_access_ptr

This fixed unit tests for tests/python/integration/test_arm_mprofile_dsp.py

Change-Id: I10919f301ef9ddc3fd87f0e1a8414e9a52fc7938

* Unit test fix

Fixes unit tests in torch frontend

Change-Id: I6c179834f93dd202605d1ce5a7f07d987b9dc469

* Addressed requested changes

Addressed changes requested upstream

Change-Id: I741e52b89eb285732c23b1ac7ff277e757a088c3

* Namespace usage changed to conform earlier C++ standard

Change-Id: I1b29238cfe2a6bedb525f4f823a3a540f631d836

* Bugfix

Change-Id: I57a44b714b307278a243817ec2864e53ad31366b

* updated IRModuleNode::ExtractPrimFuncConstants

Updated IRModuleNode::ExtractPrimFuncConstants as per
request upstream.

Change-Id: I35db0145fb5827efd0445ce665d0c99465274016

* Minor changes

typo fixd
renamed ExtractPrimFuncConstants to ExtractConstants
removed getters/setters from FuseMutator and added parametrized
constructor

Change-Id: Ib2326805781779b88c963a8642ff683c8755956e

* Moved LinkedParam/LinkedParamNode

Moved LinkedParam/LinkedParamNode from tvm::tir namespace to tvm
namespace

Change-Id: Ie3f0303bd4f7890c6d680268c91f2051977bc7f4

* Addressed upstream comments

Changed BindParams argument to Array<NDArray>
Removed 'name' argument from te.const
Switched to in-depth comparision of NDArrays in constant de-duplication
Removed extra final comma from NDArrayToTIR
Changed return type of ConstantAllocationSize to int64_t
Made link_param a tvm.testing.parameter for test_fuse_take and test_fuse_gather_nd

Change-Id: I4285099cc63756aa5ebe91a5bd207d4135499b41

* Removed unnecessary forward declaration

+linter

Change-Id: I2a6c0d1f97773aeb1ae3f458da252a22079ccdb1

* Constant extractor now is a separate pass

Change-Id: Ia4adca9d3315b26fbdc006ef7c115900c081e303

* Added forgotten file + unit test fix

Change-Id: Ice305f4fefd13fe95e97574e6d63ffeb664621df

* Changed to IRModule pass

Refactored ExtractPrimFuncConstants to IRModule pass.
deDup -> DeDup
Refactored logic of Applicator supplementary class

Change-Id: I6c120d175eb6790ba90f176c4f856bde8f0c7c94

* bugfix after rebasing

Change-Id: Ie3ee6ea2479476a30f486baef74f20070f117942

* -v -> -vv to have more debug information

Change-Id: I12c63731663b9c9ea574b9ed5cb17311ba3cf701

Co-authored-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
  • Loading branch information
2 people authored and pfk-beta committed Apr 11, 2022
1 parent 391bf17 commit 0dfa62d
Show file tree
Hide file tree
Showing 70 changed files with 1,221 additions and 338 deletions.
42 changes: 42 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,44 @@
#include <vector>

namespace tvm {
/*!
* \brief Describes one parameter that should be linked into the generated module.
*
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
* use the information contained in this node to include the parameter data in the generated
* module.
*/
class LinkedParamNode : public Object {
public:
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
int64_t id;

/*! \brief Parameter data which should get linked into the final module. */
::tvm::runtime::NDArray param;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("id", &id);
v->Visit("param", &param);
}

static constexpr const char* _type_key = "tir.LinkedParam";
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
* \brief Managed reference to LinkedParamNode.
*/
class LinkedParam : public ObjectRef {
public:
TVM_DLL LinkedParam(int64_t id, tvm::runtime::NDArray param);

TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

class IRModule;

/*!
* \brief IRModule that holds functions and type definitions.
*
Expand Down Expand Up @@ -504,6 +541,11 @@ constexpr const char* kRuntime = "runtime";
*/
constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";

/*
* \brief Module attribute for tir constants
*/
constexpr const char* kConstantsArray = "Constants";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_MODULE_H_
9 changes: 9 additions & 0 deletions include/tvm/node/structural_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <tvm/node/functor.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/ndarray.h>

#include <functional>
#include <string>
Expand Down Expand Up @@ -199,5 +200,13 @@ class SHashReducer {
bool map_free_vars_;
};

class SEqualReducer;
struct NDArrayContainerTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce);
static bool SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs, SEqualReducer equal);
};

} // namespace tvm
#endif // TVM_NODE_STRUCTURAL_HASH_H_
7 changes: 4 additions & 3 deletions include/tvm/relay/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class ExecutorNode : public Object {
}

static constexpr const char* _type_key = "Executor";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object);
};

Expand All @@ -122,8 +124,6 @@ class ExecutorNode : public Object {
*/
class Executor : public ObjectRef {
public:
Executor() = default;

/*!
* \brief Create a new Executor object using the registry
* \throws Error if name is not registered
Expand All @@ -147,7 +147,8 @@ class Executor : public ObjectRef {
TVM_DLL static Map<String, String> ListExecutorOptions(const String& name);

/*! \brief specify container node */
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode);
TVM_DEFINE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecutorNode)

private:
/*!
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,12 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
* \param import_set Already imported external modules.
* \param device The device on which all primitives will be executed.
* \param target The compiler target flag for compiling primitives.
* \param attrs Attributes for the expression to be evaluated with
* @return The object representing the result.
*/
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target);
std::unordered_set<String> import_set, Device device, Target target,
Map<String, ObjectRef> attrs = {});

} // namespace relay
} // namespace tvm
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/relay/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class RuntimeNode : public Object {
}

static constexpr const char* _type_key = "Runtime";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeNode, Object);
};

Expand Down
38 changes: 1 addition & 37 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,42 +151,6 @@ class PrimFunc : public BaseFunc {
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
};

/*!
* \brief Describes one parameter that should be linked into the generated module.
*
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
* use the information contained in this node to include the parameter data in the generated
* module.
*/
class LinkedParamNode : public Object {
public:
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
int64_t id;

/*! \brief Parameter data which should get linked into the final module. */
::tvm::runtime::NDArray param;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("id", &id);
v->Visit("param", &param);
}

static constexpr const char* _type_key = "tir.LinkedParam";
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
* \brief Managed reference to LinkedParamNode.
*/
class LinkedParam : public ObjectRef {
public:
TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);

TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief Tensor intrinsics for tensorization
*/
Expand Down Expand Up @@ -239,7 +203,7 @@ class TensorIntrin : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode)
};

/*!
/*
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
* \param param_map The mapping from function params to the instance.
Expand Down
96 changes: 94 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,16 +559,18 @@ class AllocateNode : public StmtNode {
* Otherwise return 0.
* \return The result.
*/
int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \param extents The extents of the buffer.
* \return The result.
*/
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);

static constexpr const char* _type_key = "tir.Allocate";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
};

Expand All @@ -585,6 +587,96 @@ class Allocate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};

/*!
* \brief Allocate a buffer that can be used in body.
*/
class AllocateConstNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The optional data associated to the constant.
*/
Optional<runtime::NDArray> data;
/*! \brief If the PrimFunc containing the Stmt is added to IRModule,
this is an optional index to indicate the index within
"Constants" attribute, that is a Array<NDArray> of IRModule.
*/
Optional<Integer> irmod_storage_idx;
/*! \brief The type of the buffer. */
DataType dtype;
/*! \brief The extents of the buffer. */
Array<PrimExpr> extents;
/*! \brief The body to be executed. */
Stmt body;
/*!
* \brief Additional annotations about the allocation.
*
* These annotations can be used as auxiliary hint
* to future transformations.
*/
Map<String, ObjectRef> annotations;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("data", &data);
v->Visit("irmod_storage_idx", &irmod_storage_idx);
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("body", &body);
v->Visit("annotations", &annotations);
v->Visit("span", &span);
}

bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(buffer_var);
hash_reduce(dtype);
hash_reduce(extents);
hash_reduce(body);
hash_reduce(annotations);
hash_reduce(data);
}

/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \return The result.
*/
int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \param extents The extents of the buffer.
* \return The result.
*/
TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);

static constexpr const char* _type_key = "tir.AllocateConst";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode);
};

/*!
* \brief Managed reference to AllocateConstNode.
* \sa AllocateConstNode
*/
class AllocateConst : public Stmt {
public:
/* The constructor to create a IRNode with constant data
* depending on the type of ObjectRef, it will either
* create AllocateConstNode with irmod_storage_idx or data
*/
TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
ObjectRef data_or_idx, Stmt body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
};

/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -113,6 +114,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(ForNode);
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
Expand Down Expand Up @@ -155,6 +157,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AllocateConstNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
Expand Down Expand Up @@ -255,6 +258,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const WhileNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const AllocateConstNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/tir/function.h>

#include <string>
#include <vector>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -601,6 +602,15 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*/
TVM_DLL Pass InjectSoftwarePipeline();

TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);

/*!
* \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute.
*
* \return The pass.
*/
TVM_DLL Pass ExtractPrimFuncConstants();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
Loading

0 comments on commit 0dfa62d

Please sign in to comment.