From f4e8cf179ab678a6be7864d1c6efdd8bb3f99feb Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 9 Jun 2020 22:28:52 -0700 Subject: [PATCH] [REFACTOR][TIR] Provide->ProducerStore, Realize->ProducerRealize. (#5750) This PR finishes up the final step for DSL/TIR de-coupling to refactor Provide/Realize to use the DataProducer. As in the case of ProducerLoad, ProducerStore/Realize are not supposed to appear in a vaid TIR function ans are only used by high-level DSLs as intermediate structures. --- include/tvm/tir/stmt.h | 170 +++++++----------- include/tvm/tir/stmt_functor.h | 16 +- python/tvm/te/hybrid/parser.py | 8 +- python/tvm/te/hybrid/util.py | 8 +- python/tvm/tir/__init__.py | 4 +- python/tvm/tir/stmt.py | 40 ++--- src/contrib/hybrid/codegen_hybrid.cc | 42 ++--- src/contrib/hybrid/codegen_hybrid.h | 12 +- src/te/operation/compute_op.cc | 9 +- src/te/operation/cross_thread_reduction.cc | 4 +- src/te/operation/extern_op.cc | 3 +- src/te/operation/hybrid_op.cc | 10 +- src/te/operation/scan_op.cc | 2 +- src/te/schedule/schedule_ops.cc | 26 ++- ...hedule_postproc_rewrite_for_tensor_core.cc | 137 ++++++-------- .../schedule/schedule_postproc_to_primfunc.cc | 14 +- src/tir/ir/stmt.cc | 95 ++++------ src/tir/ir/stmt_functor.cc | 16 +- src/tir/transforms/inject_virtual_thread.cc | 2 +- src/tir/transforms/remove_no_op.cc | 4 +- src/tir/transforms/storage_flatten.cc | 4 +- src/tir/transforms/vectorize_loop.cc | 14 +- tests/lint/git-clang-format.sh | 14 +- .../python/unittest/test_te_hybrid_script.py | 40 ++--- tests/python/unittest/test_te_tensor.py | 4 +- tests/python/unittest/test_tir_constructor.py | 10 -- .../test_tir_transform_lower_warp_memory.py | 2 +- 27 files changed, 301 insertions(+), 409 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index dda964203e2a..118ec0fc809a 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -334,44 +334,92 @@ class BufferRealize : public Stmt { }; /*! - * \brief Store value into mult-dimensional array defined by func. + * \brief Store value into mult-dimensional array that will be read by the consumer + * of the producer. * - * \note Deprecated, move to BufferStore in the future. + * \note This node only appears in high-level DSLs that are built on top of the TIR. + * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower + * this node before TIR transformations. + * + * \sa DataProducer */ -class ProvideNode : public StmtNode { +class ProducerStoreNode : public StmtNode { public: - /*! \brief The function to be updated. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ - int value_index{0}; + /*! \brief The producer to store the results into. */ + DataProducer producer; /*! \brief The value to be stored. */ PrimExpr value; /*! \brief The index arguments of the function. */ - Array args; + Array indices; void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); + v->Visit("producer", &producer); v->Visit("value", &value); - v->Visit("args", &args); + v->Visit("indices", &indices); } - bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const { - return equal(func, other->func) && equal(value_index, other->value_index) && - equal(value, other->value) && equal(args, other->args); + bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const { + return equal(producer, other->producer) && equal(value, other->value) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(func); - hash_reduce(value_index); + hash_reduce(producer); hash_reduce(value); - hash_reduce(args); + hash_reduce(indices); + } + + TVM_DLL static Stmt make(DataProducer producer, PrimExpr value, Array indices); + + static constexpr const char* _type_key = "ProducerStore"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode); +}; + +/*! + * \brief Annotate the bounds where the data produced by the producer + * need to be written and read in body. + * We will need to allocate space for the corresponding regions. + * + * \note This node only appears in high-level DSLs that are built on top of the TIR. + * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower + * this node before TIR transformations. + * + * \sa DataProducer + */ +class ProducerRealizeNode : public StmtNode { + public: + /*! \brief The producer that produces the data. */ + DataProducer producer; + /*! \brief Bounds to be realized. */ + Region bounds; + /*! \brief Only realize if condition holds. */ + PrimExpr condition; + /*! \brief The body of realization. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("producer", &producer); + v->Visit("bounds", &bounds); + v->Visit("condition", &condition); + v->Visit("body", &body); + } + + TVM_DLL static Stmt make(DataProducer producer, Region bounds, PrimExpr condition, Stmt body); + + bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { + return equal(producer, other->producer) && equal(bounds, other->bounds) && + equal(condition, other->condition) && equal(body, other->body); } - TVM_DLL static Stmt make(FunctionRef func, int value_index, PrimExpr value, Array args); + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(producer); + hash_reduce(bounds); + hash_reduce(condition); + hash_reduce(body); + } - static constexpr const char* _type_key = "Provide"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode); + static constexpr const char* _type_key = "ProducerRealize"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode); }; /*! @@ -453,58 +501,6 @@ class FreeNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode); }; -/*! - * \brief Annotate the bounds where func need to be written and read in body. - * We will need to allocate space for the corresponding regions. - * - * \note Deprecated, move to BufferRealize in the future. - */ -class RealizeNode : public StmtNode { - public: - /*! \brief The function to be realized. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ - int value_index; - /*! \brief The data type of the array. */ - DataType dtype; - /*! \brief Bounds to be realized. */ - Region bounds; - /*! \brief Only realize if condition holds. */ - PrimExpr condition; - /*! \brief The body of realization. */ - Stmt body; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); - v->Visit("dtype", &dtype); - v->Visit("bounds", &bounds); - v->Visit("condition", &condition); - v->Visit("body", &body); - } - - TVM_DLL static Stmt make(FunctionRef func, int value_index, DataType dtype, Region bounds, - PrimExpr condition, Stmt body); - - bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const { - return equal(func, other->func) && equal(value_index, other->value_index) && - equal(dtype, other->dtype) && equal(bounds, other->bounds) && - equal(condition, other->condition) && equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(func); - hash_reduce(value_index); - hash_reduce(dtype); - hash_reduce(bounds); - hash_reduce(condition); - hash_reduce(body); - } - - static constexpr const char* _type_key = "Realize"; - TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode); -}; - /*! * \brief The container of seq statement. * Represent a sequence of statements. @@ -777,23 +773,6 @@ class Prefetch : public Stmt { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); }; -/*! - * \brief Auxiliary data structure used in IR Pass to indicate a tensor. - */ -struct TensorKey { - FunctionRef f; - int value_index; - - inline bool operator==(const TensorKey& other) const { - return f == other.f && value_index == other.value_index; - } - inline std::string GetName() const { - if (f->num_outputs() == 1) return f->func_name(); - std::ostringstream os; - os << f->func_name() << ".v" << value_index; - return os.str(); - } -}; /*! \brief namespace of possible attribute sin AttrStmt.attr_key */ namespace attr { @@ -933,17 +912,4 @@ TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type); } // namespace tir } // namespace tvm - -namespace std { -template <> -struct hash<::tvm::tir::TensorKey> { - std::size_t operator()(const ::tvm::tir::TensorKey& k) const { - size_t lhs = ::tvm::ObjectPtrHash()(k.f); - size_t rhs = static_cast(k.value_index); - lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); - return lhs; - } -}; -} // namespace std - #endif // TVM_TIR_STMT_H_ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 052ea92ce41e..9a85b3852254 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -92,8 +92,8 @@ class StmtFunctor { virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -114,8 +114,8 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(FreeNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); - IR_STMT_FUNCTOR_DISPATCH(ProvideNode); - IR_STMT_FUNCTOR_DISPATCH(RealizeNode); + IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); + IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode); IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); @@ -156,8 +156,8 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const FreeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const ProvideNode* op) override; - void VisitStmt_(const RealizeNode* op) override; + void VisitStmt_(const ProducerStoreNode* op) override; + void VisitStmt_(const ProducerRealizeNode* op) override; void VisitStmt_(const PrefetchNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; @@ -248,8 +248,8 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const FreeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; - Stmt VisitStmt_(const ProvideNode* op) override; - Stmt VisitStmt_(const RealizeNode* op) override; + Stmt VisitStmt_(const ProducerStoreNode* op) override; + Stmt VisitStmt_(const ProducerRealizeNode* op) override; Stmt VisitStmt_(const PrefetchNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const EvaluateNode* op) override; diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 75300ab405e9..913b4534eea6 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -212,7 +212,7 @@ def wrap_up_realize(self, node, body): _domain = [Range.make_by_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = tvm.runtime.convert(True) - body = tvm.tir.Realize(_buf.op, 0, _dtype, _domain, _true, body) + body = tvm.tir.ProducerRealize(_buf, _domain, _true, body) body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', tvm.runtime.convert(_scope), body) for elem in to_pop: @@ -307,7 +307,7 @@ def visit_AugAssign(self, node): read = tvm.tir.ProducerLoad(buf, args) value = HybridParser._binop_maker[type(node.op)](read, rhs) - return tvm.tir.Provide(buf.op, 0, value, args) + return tvm.tir.ProducerStore(buf, value, args) def visit_Assign(self, node): @@ -358,13 +358,13 @@ def visit_Assign(self, node): lhs = self.visit(lhs_) if lhs is not None: buf, args = lhs - return tvm.tir.Provide(buf.op, 0, rhs, args) + return tvm.tir.ProducerStore(buf, rhs, args) return util.make_nop() lhs, args = self.visit(lhs) _internal_assert(isinstance(lhs, Tensor), \ "An array access's LHS is expected to be a expr.Call!") - res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args) + res = tvm.tir.ProducerStore(lhs, rhs, args) return res diff --git a/python/tvm/te/hybrid/util.py b/python/tvm/te/hybrid/util.py index 35c59f11be70..810509b6e9cd 100644 --- a/python/tvm/te/hybrid/util.py +++ b/python/tvm/te/hybrid/util.py @@ -75,15 +75,15 @@ def replace_io(body, rmap): from tvm.tir import stmt_functor def replace(op): - if isinstance(op, _stmt.Provide) and op.func in rmap.keys(): - buf = rmap[op.func] - return _stmt.Provide(buf.op, op.value_index, op.value, op.args) + if isinstance(op, _stmt.ProducerStore) and op.producer.op in rmap.keys(): + buf = rmap[op.producer.op] + return _stmt.ProducerStore(buf, op.value, op.indices) if isinstance(op, _expr.ProducerLoad) and op.producer.op in rmap.keys(): buf = rmap[op.producer.op] return _expr.ProducerLoad(buf, op.indices) return None - return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call']) + return stmt_functor.ir_transform(body, None, replace, ['ProducerStore', 'ProducerLoad']) def _is_tvm_arg_types(args): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 9aec24a77f6f..982b31cc2f54 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -28,8 +28,8 @@ from .expr import IterVar, Any from .stmt import Stmt, LetStmt, AssertStmt, For -from .stmt import BufferStore, BufferRealize, Store, Provide, Allocate, AttrStmt -from .stmt import Free, Realize, SeqStmt +from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt +from .stmt import Free, ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .function import PrimFunc diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index eee5b0b002e0..f4d84716a47d 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -184,26 +184,23 @@ def __init__(self, buffer, bounds, condition, body): @tvm._ffi.register_object -class Provide(Stmt): - """Provide node. +class ProducerStore(Stmt): + """ProducerStore node. Parameters ---------- - func : Operation - The operation to create the function. - - value_index : int - The output value index + producer : DataProducer + The data producer. value : PrimExpr The value to be stored. - args : list of Expr - The index arguments of the Provide. + indices : list of Expr + The index arguments of the store. """ - def __init__(self, func, value_index, value, args): + def __init__(self, producer, value, indices): self.__init_handle_by_constructor__( - _ffi_api.Provide, func, value_index, value, args) + _ffi_api.ProducerStore, producer, value, indices) @tvm._ffi.register_object @@ -276,19 +273,13 @@ def __init__(self, buffer_var): @tvm._ffi.register_object -class Realize(Stmt): - """Realize node. +class ProducerRealize(Stmt): + """ProducerRealize node. Parameters ---------- - func : Operation - The operation to create the function. - - value_index : int - The output value index - - dtype : str - The data type of the operation. + producer : DataProducer + The data producer. bounds : list of range The bound of realize @@ -300,15 +291,12 @@ class Realize(Stmt): The realize body """ def __init__(self, - func, - value_index, - dtype, + producer, bounds, condition, body): self.__init_handle_by_constructor__( - _ffi_api.Realize, func, value_index, dtype, - bounds, condition, body) + _ffi_api.ProducerRealize, producer, bounds, condition, body) @tvm._ffi.register_object diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 706252057c98..e9ec585de164 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -205,7 +205,7 @@ void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) { // NOLINT(*) auto tensor = Downcast(op->producer); - os << GetTensorID(tensor->op, tensor->value_index); + os << GetTensorID(tensor); os << "["; for (size_t i = 0; i < op->indices.size(); ++i) { if (i) os << ", "; @@ -300,7 +300,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { PrintStmt(op->body); indent_ -= tab_; } else if (op->attr_key == tir::attr::realize_scope) { - auto v = Downcast(op->node); + auto v = Downcast(op->node); alloc_storage_scope_[v] = op->value.as()->value; PrintStmt(op->body); } else { @@ -309,20 +309,21 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { } } -void CodeGenHybrid::VisitStmt_(const RealizeNode* op) { - CHECK(alloc_storage_scope_.count(op->func)); - if (!alloc_storage_scope_[op->func].empty()) { +void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { + auto tensor = Downcast(op->producer); + CHECK(alloc_storage_scope_.count(tensor->op)); + if (!alloc_storage_scope_[tensor->op].empty()) { PrintIndent(); - stream << GetTensorID(op->func, op->value_index) << " = allocate(("; + stream << GetTensorID(tensor) << " = allocate(("; for (size_t i = 0; i < op->bounds.size(); ++i) { if (i) stream << ", "; stream << PrintExpr(op->bounds[i]->extent); } if (op->bounds.size() == 1) stream << ", "; stream << "), '"; - PrintType(op->dtype, stream); + PrintType(tensor->dtype, stream); stream << "', '"; - stream << alloc_storage_scope_[op->func] << "')\n"; + stream << alloc_storage_scope_[tensor->op] << "')\n"; } PrintStmt(op->body); } @@ -337,13 +338,14 @@ void CodeGenHybrid::VisitStmt_(const AssertStmtNode* op) { PrintStmt(op->body); } -void CodeGenHybrid::VisitStmt_(const ProvideNode* op) { +void CodeGenHybrid::VisitStmt_(const ProducerStoreNode* op) { + auto tensor = Downcast(op->producer); PrintIndent(); - stream << GetTensorID(op->func, op->value_index); + stream << GetTensorID(tensor); stream << "["; - for (size_t i = 0; i < op->args.size(); ++i) { + for (size_t i = 0; i < op->indices.size(); ++i) { if (i) stream << ", "; - PrintExpr(op->args[i], stream); + PrintExpr(op->indices[i], stream); } stream << "] = "; PrintExpr(op->value, stream); @@ -407,14 +409,14 @@ std::string CodeGenHybrid::GetVarID(const VarNode* v) { return id_map_[key] = GetUniqueName(v->name_hint); } -std::string CodeGenHybrid::GetTensorID(const FunctionRef& func, int value_index) { - auto key = std::make_pair(func.get(), value_index); +std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) { + auto key = std::make_pair(tensor->op.get(), tensor->value_index); if (id_map_.count(key)) { return id_map_[key]; } - std::string name_hint = func->func_name(); - if (func->num_outputs() > 1) { - name_hint += "_v" + std::to_string(value_index); + std::string name_hint = tensor->op->func_name(); + if (tensor->op->num_outputs() > 1) { + name_hint += "_v" + std::to_string(tensor->value_index); } return id_map_[key] = GetUniqueName(name_hint); } @@ -472,7 +474,7 @@ void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array& inputs, for (size_t i = 0; i < inputs.size(); ++i) { if (i) stream << ", "; if (auto tensor = inputs[i].as()) { - stream << GetTensorID(tensor->op, tensor->value_index); + stream << GetTensorID(GetRef(tensor)); } else { auto var = inputs[i].as(); CHECK(var) << "Input should either be a tensor or a variable!"; @@ -483,7 +485,7 @@ void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array& inputs, indent_ += tab_; for (size_t i = 0; i < outputs.size(); ++i) { PrintIndent(); - stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) << " = output_tensor(("; + stream << GetTensorID(outputs[i]) << " = output_tensor(("; for (size_t j = 0; j < outputs[i]->shape.size(); ++j) { if (j) stream << ", "; PrintExpr(outputs[i]->shape[j], stream); @@ -496,7 +498,7 @@ void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array& inputs, stream << "return "; for (size_t i = 0; i < outputs.size(); ++i) { if (i) stream << ", "; - stream << GetTensorID(outputs[i]->op, outputs[i]->value_index); + stream << GetTensorID(outputs[i]); } stream << "\n"; } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 5dd91c8fb65a..b01ca2763e28 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -25,6 +25,7 @@ #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ #include +#include #include #include #include @@ -119,11 +120,11 @@ class CodeGenHybrid : public ExprFunctor, // statment void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const StoreNode* op) override; - void VisitStmt_(const ProvideNode* op) override; + void VisitStmt_(const ProducerStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; - void VisitStmt_(const RealizeNode* op) override; + void VisitStmt_(const ProducerRealizeNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; @@ -164,12 +165,11 @@ class CodeGenHybrid : public ExprFunctor, std::string GetVarID(const VarNode* v); /*! * \brief Get or allocate the ID for the given tensor. - * \param func The tensor to allocate a name. - * \param value_index The value index of the given tensor. + * \param tensor The tensor to allocate a name. */ - std::string GetTensorID(const FunctionRef& func, int value_index); + std::string GetTensorID(const Tensor& tensor); /*! \brief the storage scope of allocation */ - std::map alloc_storage_scope_; + std::map alloc_storage_scope_; }; } // namespace contrib diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 41bf49f79aac..25715f439322 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -266,8 +266,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i - 1); - realize = - tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize); + realize = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); @@ -312,8 +311,8 @@ void MakeReduction(const ComputeOpNode* op, const Array& tensors, Stmt* Array update_value = (*combiner)(lhs, reduce->source); for (size_t i = 0; i < size; ++i) { Tensor t = tensors[i]; - inits.emplace_back(ProvideNode::make(t->op, t->value_index, init_value[i], args)); - provides.emplace_back(ProvideNode::make(t->op, t->value_index, update_value[i], args)); + inits.emplace_back(ProducerStoreNode::make(t, init_value[i], args)); + provides.emplace_back(ProducerStoreNode::make(t, update_value[i], args)); } *init = SeqStmt::Flatten(inits); *provide = SeqStmt::Flatten(provides); @@ -328,7 +327,7 @@ Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) { for (IterVar iv : op->axis) { args.push_back(iv->var); } - return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args); + return ProducerStoreNode::make(t, op->body[t->value_index], args); } Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index cdcb124d338b..e1ef617cea0d 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -210,8 +210,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector assigns(size); for (size_t idx = 0; idx < size; ++idx) { DataType t = reduces[idx]->dtype; - assigns[idx] = ProvideNode::make( - stage->op, idx, LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args); + assigns[idx] = ProducerStoreNode::make( + stage->op.output(idx), LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args); } Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(MakeIfNest(output_preds), assign_body); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 90f3f2e0a4c6..25a596f60140 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -128,8 +128,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = - tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize_body); } return realize_body; } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 55996a5afe77..d0ffcfcce2e3 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -152,8 +152,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = - tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize_body); } return realize_body; } @@ -460,12 +459,11 @@ class ProviderReplacer : public tir::StmtMutator { public: explicit ProviderReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} - Stmt VisitStmt_(const tir::ProvideNode* op) final { - Tensor t = Downcast(op->func).output(op->value_index); + Stmt VisitStmt_(const tir::ProducerStoreNode* op) final { + Tensor t = Downcast(op->producer); auto it = vmap_.find(t); if (it != vmap_.end()) { - Stmt ret = - tir::ProvideNode::make(it->second->op, it->second->value_index, op->value, op->args); + Stmt ret = tir::ProducerStoreNode::make(it->second, op->value, op->indices); found = true; return this->VisitStmt(ret); } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 01ef0cd82ce4..4e6c8247e263 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -246,7 +246,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_mapspatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), ret); + ret = tir::ProducerRealizeNode::make(t, bounds, const_true(), ret); } return ret; } diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 10f1ed3326ab..228ce45a7831 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -217,13 +217,12 @@ class SchedulePostProc : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + Stmt VisitStmt_(const ProducerRealizeNode* op) final { + auto key = Downcast(op->producer); auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = RealizeNode::make(it->second->op, it->second->value_index, op->dtype, op->bounds, - op->condition, op->body); + Stmt ret = ProducerRealizeNode::make(it->second, op->bounds, op->condition, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -233,12 +232,12 @@ class SchedulePostProc : public StmtExprMutator { } } - Stmt VisitStmt_(const ProvideNode* op) final { - TensorKey key{op->func, op->value_index}; + Stmt VisitStmt_(const ProducerStoreNode* op) final { + auto key = Downcast(op->producer); auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; - Stmt ret = ProvideNode::make(dst->op, dst->value_index, op->value, op->args); + Stmt ret = ProducerStoreNode::make(dst, op->value, op->indices); return this->VisitStmt(ret); } else { return StmtExprMutator::VisitStmt_(op); @@ -250,9 +249,7 @@ class SchedulePostProc : public StmtExprMutator { op = expr.as(); CHECK(op != nullptr); - auto tensor = Downcast(op->producer); - TensorKey key{tensor->op, tensor->value_index}; - + auto key = Downcast(op->producer); auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; @@ -304,9 +301,8 @@ class SchedulePostProc : public StmtExprMutator { private: void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(), Operation repl_op = Operation()) { - TensorKey key{src->op, src->value_index}; - replace_buffer_[key] = dst; - replace_realize_[key] = repl_realize; + replace_buffer_[src] = dst; + replace_realize_[src] = repl_realize; replace_op_[src->op.get()] = repl_op; } // The thread extent scope. @@ -314,9 +310,9 @@ class SchedulePostProc : public StmtExprMutator { // The scan value std::unordered_map var_value_; // buffer replacement - std::unordered_map replace_buffer_; + std::unordered_map replace_buffer_; // buffere realization to be replaced - std::unordered_map replace_realize_; + std::unordered_map replace_realize_; // replace producer consumer. std::unordered_map replace_op_; // integer analyzer diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index e0d58827e1d8..da45e8ae3dfe 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -53,11 +53,6 @@ struct Tile { int k{-1}; }; -TensorKey TensorKeyFromProducer(DataProducer producer) { - auto tensor = Downcast(producer); - return TensorKey{tensor->op, tensor->value_index}; -} - std::string simplify_name(std::string input) { auto pos = input.find("."); if (pos != std::string::npos) { @@ -88,7 +83,7 @@ class MMAMatcher : public StmtVisitor { bi.name = kv.second->name; bi.dtype = kv.second->dtype; bi.external = true; - buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; + buf_map_[kv.first] = bi; } } @@ -104,9 +99,9 @@ class MMAMatcher : public StmtVisitor { } } - void VisitStmt_(const ProvideNode* op) final { + void VisitStmt_(const ProducerStoreNode* op) final { StmtVisitor::VisitStmt_(op); - auto it = buf_map_.find(TensorKey{op->func, op->value_index}); + auto it = buf_map_.find(Downcast(op->producer)); if (it == buf_map_.end()) { return; } @@ -119,8 +114,8 @@ class MMAMatcher : public StmtVisitor { } } - void VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + void VisitStmt_(const ProducerRealizeNode* op) final { + auto key = Downcast(op->producer); if (buf_map_.count(key)) { if (!buf_map_.at(key).external) { return; @@ -128,8 +123,8 @@ class MMAMatcher : public StmtVisitor { this->VisitStmt(op->body); } else { BufferInfo bi; - bi.name = key.GetName(); - bi.dtype = op->dtype; + bi.name = key->GetNameHint(); + bi.dtype = key->dtype; buf_map_[key] = bi; this->VisitStmt(op->body); buf_map_[key].released = true; @@ -167,7 +162,7 @@ class MMAMatcher : public StmtVisitor { if (strkey != "local") { return false; } - auto it1 = buf_map_.find(TensorKey{tensor->op, tensor->value_index}); + auto it1 = buf_map_.find(tensor); if (it1 == buf_map_.end()) { return false; } @@ -179,7 +174,7 @@ class MMAMatcher : public StmtVisitor { } // Do the pattern matching - bool mma_sync_match_(const ProvideNode* op, BufferInfo store_buffer) { + bool mma_sync_match_(const ProducerStoreNode* op, BufferInfo store_buffer) { auto* add = op->value.as(); if (add == nullptr) { return false; @@ -227,9 +222,9 @@ class MMAMatcher : public StmtVisitor { return true; } - std::unordered_map buf_map_; + std::unordered_map buf_map_; std::unordered_map storage_scope_; - std::unordered_map> mma_sync_; + std::unordered_map> mma_sync_; std::unordered_map buf_name_; std::unordered_set frag_reg_; bool matched_{false}; @@ -365,7 +360,7 @@ class ScheduleAnalyser { private: std::unordered_map matrix_abc_; std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; + std::unordered_map> mma_sync_; std::unordered_map buf_name_; }; @@ -403,7 +398,7 @@ class BufferAnalyser : public StmtExprVisitor { bi.strides = kv.second->strides; bi.shape = kv.second->shape; bi.external = true; - buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; + buf_map_[kv.first] = bi; } } @@ -421,7 +416,7 @@ class BufferAnalyser : public StmtExprVisitor { te::Tensor tensor = Downcast(op->node); const CallNode* tuple = op->value.as(); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); - auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}]; + auto& vinfo = dim_align_[tensor]; size_t dim = tuple->args[0].as()->value; if (dim >= vinfo.size()) { vinfo.resize(dim + 1); @@ -434,15 +429,15 @@ class BufferAnalyser : public StmtExprVisitor { } } - void VisitStmt_(const ProvideNode* op) final { + void VisitStmt_(const ProducerStoreNode* op) final { StmtExprVisitor::VisitStmt_(op); - TensorKey key{op->func, op->value_index}; + auto key = Downcast(op->producer); auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key->GetNameHint(); const BufferInfo& bi = it->second; CHECK(!bi.released) << "Read a buffer that is already out of scope"; - if (matrix_abc_.count(key.GetName())) { + if (matrix_abc_.count(key->GetNameHint())) { if (bi.shape.size() < 2) { invalid_ = true; return; @@ -469,19 +464,19 @@ class BufferAnalyser : public StmtExprVisitor { } strides.push_back(make_const(DataType::Int(32), 1)); } - strides_.insert(std::make_pair(key.GetName(), strides)); + strides_.insert(std::make_pair(key->GetNameHint(), strides)); if (frag_reg_.count(bi.name)) { - PrimExpr dst = ProducerLoad(Downcast(op->func).output(0), op->args); + PrimExpr dst = ProducerLoad(op->producer, op->indices); frag_load_.insert(std::make_pair(op, dst)); - auto rel_index = bi.RelIndex(op->args); - if (op->args.size() < 2) { + auto rel_index = bi.RelIndex(op->indices); + if (op->indices.size() < 2) { invalid_ = true; return; } std::vector tile_size; - for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) { + for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { index_visitor.scaling_factor_ = 16; if (const IntImmNode* shape = bi.shape[i].as()) { tile_size.push_back(shape->value); @@ -530,7 +525,7 @@ class BufferAnalyser : public StmtExprVisitor { const ProducerLoadNode* value = op->value.as(); // TODO(tvm-team): string matching is dangerous, consider other means. if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) { - PrimExpr dst = ProducerLoad(Downcast(op->func).output(0), op->args); + PrimExpr dst = ProducerLoad(op->producer, op->indices); frag_store_.insert(std::make_pair(op, dst)); } } @@ -539,9 +534,8 @@ class BufferAnalyser : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); auto tensor = Downcast(op->producer); - TensorKey key{tensor->op, tensor->value_index}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; + auto it = buf_map_.find(tensor); + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << tensor->GetNameHint(); const BufferInfo& bi = it->second; CHECK(!bi.released) << "Read a buffer that is already out of scope"; @@ -572,7 +566,7 @@ class BufferAnalyser : public StmtExprVisitor { } strides.push_back(make_const(DataType::Int(32), 1)); } - strides_.insert(std::make_pair(key.GetName(), strides)); + strides_.insert(std::make_pair(tensor->GetNameHint(), strides)); if (!frag_reg_.count(bi.name)) { return; @@ -594,8 +588,8 @@ class BufferAnalyser : public StmtExprVisitor { } } - void VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + void VisitStmt_(const ProducerRealizeNode* op) final { + auto key = Downcast(op->producer); if (buf_map_.count(key)) { CHECK(buf_map_.at(key).external); this->VisitStmt(op->body); @@ -629,8 +623,8 @@ class BufferAnalyser : public StmtExprVisitor { strides = Array(rstrides.rbegin(), rstrides.rend()); } - bi.name = key.GetName(); - bi.dtype = op->dtype; + bi.name = key->GetNameHint(); + bi.dtype = key->dtype; bi.strides = strides; bi.shape = shape; @@ -726,15 +720,15 @@ class BufferAnalyser : public StmtExprVisitor { return false; } - std::unordered_map buf_map_; - std::unordered_map> dim_align_; + std::unordered_map buf_map_; + std::unordered_map> dim_align_; std::unordered_map storage_scope_; std::unordered_map matrix_abc_; std::unordered_map matrix_major_; std::unordered_set frag_reg_; std::unordered_map> strides_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; + std::unordered_map frag_load_; + std::unordered_map frag_store_; std::unordered_map thread_extent_; IndexVisitor index_visitor; Tile warp_tile_; @@ -787,30 +781,29 @@ class TensorCoreIRMutator : public StmtExprMutator { warp_tile_(buffer_analyser.warp_tile_), warp_threads_y_(buffer_analyser.warp_threads_y_) {} - Stmt VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + Stmt VisitStmt_(const ProducerRealizeNode* op) final { + auto key = Downcast(op->producer); bounds_[key] = op->bounds; Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); if (op != nullptr) { - if (!frag_reg_.count(key.GetName())) { + if (!frag_reg_.count(key->GetNameHint())) { return stmt; } - auto new_extents = get_tile_size_(simplify_name(key.GetName())); + auto new_extents = get_tile_size_(simplify_name(key->GetNameHint())); Region new_bounds; for (size_t i = 0; i < op->bounds.size() - 2; ++i) { new_bounds.push_back(op->bounds[i]); } - CHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key.GetName(); + CHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key->GetNameHint(); new_bounds.push_back( Range::make_by_min_extent(op->bounds[op->bounds.size() - 2]->min, new_extents[0])); new_bounds.push_back( Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - return RealizeNode::make(op->func, op->value_index, op->dtype, new_bounds, op->condition, - op->body); + return ProducerRealizeNode::make(op->producer, new_bounds, op->condition, op->body); } return stmt; } @@ -834,7 +827,7 @@ class TensorCoreIRMutator : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const ProvideNode* op) final { + Stmt VisitStmt_(const ProducerStoreNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); auto it = mma_sync_.find(op); if (it != mma_sync_.end()) { @@ -869,17 +862,14 @@ class TensorCoreIRMutator : public StmtExprMutator { }; auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, TensorKeyFromProducer(cc->producer), - mma_sync_call); + return add_buffer_bind_scope_(cc, buffer_node_c, mma_sync_call); }; auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, TensorKeyFromProducer(cb->producer), - call_add_c); + return add_buffer_bind_scope_(cb, buffer_node_b, call_add_c); }; - return add_buffer_bind_scope_(ca, buffer_node_a, TensorKeyFromProducer(ca->producer), - call_add_b); + return add_buffer_bind_scope_(ca, buffer_node_a, call_add_b); } auto it2 = frag_load_.find(op); @@ -896,8 +886,7 @@ class TensorCoreIRMutator : public StmtExprMutator { }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer), - fill_fragment_call); + return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); } const CallNode* value = op->value.as(); @@ -937,15 +926,13 @@ class TensorCoreIRMutator : public StmtExprMutator { }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer), - load_matrix_call); + return add_buffer_bind_scope_(pload, buffer_node, load_matrix_call); } auto it3 = frag_store_.find(op); if (it3 != frag_store_.end()) { - TensorKey key{op->func, op->value_index}; - auto it = strides_.find(key.GetName()); - CHECK(it != strides_.end()) << "Cannot find stride for " << key.GetName(); + auto it = strides_.find(op->producer->GetNameHint()); + CHECK(it != strides_.end()) << "Cannot find stride for " << op->producer->GetNameHint(); auto strides = it->second; CHECK_GE(strides.size(), 2); PrimExpr stride = strides[strides.size() - 2]; @@ -968,8 +955,7 @@ class TensorCoreIRMutator : public StmtExprMutator { }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer), - store_matrix_call); + return add_buffer_bind_scope_(pload, buffer_node, store_matrix_call); } return stmt; @@ -1028,10 +1014,10 @@ class TensorCoreIRMutator : public StmtExprMutator { } Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload, - const ObjectPtr& buffer_node, const TensorKey& key, + const ObjectPtr& buffer_node, const std::function& call_back) { auto tensor = Downcast(pload->producer); - auto it = bounds_.find(key); + auto it = bounds_.find(tensor); CHECK(it != bounds_.end()); Array min_bound; for (auto i : it->second) { @@ -1077,13 +1063,6 @@ class TensorCoreIRMutator : public StmtExprMutator { buffer_node->offset_factor = 1; Buffer buffer(buffer_node); - ObjectPtr tensor_node = make_object(); - tensor_node->value_index = key.value_index; - tensor_node->op = Downcast(key.f); - tensor_node->shape = shape; - tensor_node->dtype = tensor->dtype; - Tensor tensor_bind(tensor_node); - Array args; for (size_t i = 0; i < pload->indices.size(); ++i) { args.push_back(pload->indices[i]); @@ -1091,19 +1070,19 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto tuple = CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); - Array node = {buffer, tensor_bind}; + Array node = {buffer, tensor}; return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer)); } std::unordered_map matrix_abc_; std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; + std::unordered_map> mma_sync_; std::unordered_map> strides_; std::unordered_set frag_reg_; std::unordered_map loop_scaling_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map bounds_; + std::unordered_map frag_load_; + std::unordered_map frag_store_; + std::unordered_map bounds_; arith::Analyzer analyzer_; Tile warp_tile_; int warp_threads_y_{-1}; diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 96df24dc6c7a..74f4a2cf36e1 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -94,24 +94,24 @@ class TensorToBufferMapper : public StmtExprMutator { } } - Stmt VisitStmt_(const RealizeNode* op) final { - Tensor tensor = Downcast(op->func).output(op->value_index); + Stmt VisitStmt_(const ProducerRealizeNode* op) final { + Tensor tensor = Downcast(op->producer); Buffer buffer = GetOrAllocBuffer(tensor); auto ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); + op = ret.as(); return BufferRealize(buffer, op->bounds, op->condition, op->body); } - Stmt VisitStmt_(const ProvideNode* op) final { - Tensor tensor = Downcast(op->func).output(op->value_index); + Stmt VisitStmt_(const ProducerStoreNode* op) final { + Tensor tensor = Downcast(op->producer); Buffer buffer = GetBuffer(tensor); auto ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); + op = ret.as(); - return BufferStore(buffer, op->value, op->args); + return BufferStore(buffer, op->value, op->indices); } PrimExpr VisitExpr_(const ProducerLoadNode* op) final { diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 4c58fd63c69f..094abc331bde 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -123,24 +123,15 @@ TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { } }); -Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array args) { - CHECK(value_index >= 0 && value_index < func->num_outputs()) - << "value index output function return value bound"; - CHECK(value.defined()) << "Provide of undefined value\n"; - - for (size_t i = 0; i < args.size(); ++i) { - CHECK(args[i].defined()) << "Provide to undefined location\n"; - } - - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; +Stmt ProducerStoreNode::make(DataProducer producer, PrimExpr value, Array indices) { + ObjectPtr node = make_object(); + node->producer = std::move(producer); node->value = std::move(value); - node->args = std::move(args); + node->indices = std::move(indices); return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.Provide").set_body_typed(ProvideNode::make); +TVM_REGISTER_GLOBAL("tir.ProducerStore").set_body_typed(ProducerStoreNode::make); Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body) { @@ -161,6 +152,28 @@ Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array extents, return Stmt(node); } +Stmt ProducerRealizeNode::make(DataProducer producer, Region bounds, PrimExpr condition, + Stmt body) { + for (size_t i = 0; i < bounds.size(); ++i) { + CHECK(bounds[i]->min.defined()); + CHECK(bounds[i]->extent.defined()); + CHECK(bounds[i]->min.dtype().is_scalar()); + CHECK(bounds[i]->extent.dtype().is_scalar()); + } + CHECK(body.defined()); + CHECK(condition.defined()); + CHECK(condition.dtype().is_bool()); + + ObjectPtr node = make_object(); + node->producer = std::move(producer); + node->bounds = std::move(bounds); + node->condition = std::move(condition); + node->body = std::move(body); + return Stmt(node); +} + +TVM_REGISTER_GLOBAL("tir.ProducerRealize").set_body_typed(ProducerRealizeNode::make); + // overloaded, needs special handling // has default args TVM_REGISTER_GLOBAL("tir.Allocate") @@ -192,30 +205,6 @@ Stmt FreeNode::make(Var buffer_var) { TVM_REGISTER_GLOBAL("tir.Free").set_body_typed(FreeNode::make); -Stmt RealizeNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds, - PrimExpr condition, Stmt body) { - for (size_t i = 0; i < bounds.size(); ++i) { - CHECK(bounds[i]->min.defined()); - CHECK(bounds[i]->extent.defined()); - CHECK(bounds[i]->min.dtype().is_scalar()); - CHECK(bounds[i]->extent.dtype().is_scalar()); - } - CHECK(body.defined()); - CHECK(condition.defined()); - CHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; - node->dtype = dtype; - node->bounds = std::move(bounds); - node->condition = std::move(condition); - node->body = std::move(body); - return Stmt(node); -} - -TVM_REGISTER_GLOBAL("tir.Realize").set_body_typed(RealizeNode::make); - Prefetch::Prefetch(Buffer buffer, Array bounds) { data_ = make_object(buffer, bounds); } @@ -372,18 +361,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << op->func->func_name() << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; + p->stream << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; } + p->stream << "]"; p->stream << " ="; p->Print(op->value); p->stream << '\n'; @@ -459,10 +445,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << "realize " << op->func->func_name() << "("; + p->stream << "producer_realize " << op->producer->GetNameHint() << "("; for (size_t i = 0; i < op->bounds.size(); ++i) { p->stream << "["; p->Print(op->bounds[i]->min); @@ -472,9 +458,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) if (i < op->bounds.size() - 1) p->stream << ", "; } p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } if (!is_one(op->condition)) { p->stream << " if "; p->Print(op->condition); @@ -580,10 +563,10 @@ TVM_REGISTER_NODE_TYPE(LetStmtNode); TVM_REGISTER_NODE_TYPE(AssertStmtNode); TVM_REGISTER_NODE_TYPE(ForNode); TVM_REGISTER_NODE_TYPE(StoreNode); -TVM_REGISTER_NODE_TYPE(ProvideNode); +TVM_REGISTER_NODE_TYPE(ProducerStoreNode); TVM_REGISTER_NODE_TYPE(AllocateNode); TVM_REGISTER_NODE_TYPE(FreeNode); -TVM_REGISTER_NODE_TYPE(RealizeNode); +TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); TVM_REGISTER_NODE_TYPE(SeqStmtNode); TVM_REGISTER_NODE_TYPE(IfThenElseNode); TVM_REGISTER_NODE_TYPE(EvaluateNode); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 13d0b098dd4a..6d0a60fed13a 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -87,12 +87,12 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const ProvideNode* op) { - VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); +void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) { + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->value); } -void StmtVisitor::VisitStmt_(const RealizeNode* op) { +void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); this->VisitExpr(r->extent); @@ -261,20 +261,20 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { } } -Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { - Array args = Internal::Mutate(this, op->args); +Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) { + Array indices = Internal::Mutate(this, op->indices); PrimExpr value = this->VisitExpr(op->value); - if (args.same_as(op->args) && value.same_as(op->value)) { + if (indices.same_as(op->indices) && value.same_as(op->value)) { return GetRef(op); } else { auto n = CopyOnWrite(op); - n->args = std::move(args); + n->indices = std::move(indices); n->value = std::move(value); return Stmt(n); } } -Stmt StmtMutator::VisitStmt_(const RealizeNode* op) { +Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index e2a027d2f1f4..6528e97c3071 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -462,7 +462,7 @@ class VirtualThreadInjector : public StmtMutator { } } - Stmt VisitStmt_(const ProvideNode* op) final { + Stmt VisitStmt_(const ProducerStoreNode* op) final { LOG(FATAL) << "Need to call StorageFlatten first"; return GetRef(op); } diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 15a7e8638e5c..0463d448df86 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -84,9 +84,9 @@ class NoOpRemover : public StmtMutator { return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; } - Stmt VisitStmt_(const RealizeNode* op) final { + Stmt VisitStmt_(const ProducerRealizeNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); return is_no_op(op->body) ? op->body : stmt; } Stmt VisitStmt_(const EvaluateNode* op) final { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 447a1e3c1fb1..8f2a95b1f4c0 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -340,13 +340,13 @@ class StorageFlattener : public StmtExprMutator { return PrimExpr(); } - Stmt VisitStmt_(const ProvideNode* op) final { + Stmt VisitStmt_(const ProducerStoreNode* op) final { LOG(FATAL) << "Cannot handle Provide " << " please run SchedulePostProcToPrimFunc first"; return Stmt(); } - Stmt VisitStmt_(const RealizeNode* op) final { + Stmt VisitStmt_(const ProducerRealizeNode* op) final { LOG(FATAL) << "Cannot handle Realize " << " please run SchedulePostProcToPrimFunc first"; return Stmt(); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 61ec5724d49e..c0a546d745fd 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -278,17 +278,9 @@ class Vectorizer : public StmtExprMutator { } } } - // Provide - Stmt VisitStmt_(const ProvideNode* op) final { - PrimExpr new_value = this->VisitExpr(op->value); - int lane = new_value.dtype().lanes(); - Array new_args = MutateArray(op->args, &lane); - if (op->args.same_as(new_args) && op->value.same_as(new_value)) { - return GetRef(op); - } else { - new_value = BroadcastTo(new_value, lane); - return ProvideNode::make(op->func, op->value_index, new_value, new_args); - } + Stmt VisitStmt_(const ProducerStoreNode* op) final { + LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc"; + return Stmt(); } // Store Stmt VisitStmt_(const StoreNode* op) final { diff --git a/tests/lint/git-clang-format.sh b/tests/lint/git-clang-format.sh index dc4450ab6c69..5d2c4f3c468d 100755 --- a/tests/lint/git-clang-format.sh +++ b/tests/lint/git-clang-format.sh @@ -19,13 +19,6 @@ set -e set -u set -o pipefail -if [[ "$1" == "-i" ]]; then - INPLACE_FORMAT=1 - shift 1 -else - INPLACE_FORMAT=0 -fi - if [[ "$#" -lt 1 ]]; then echo "Usage: tests/lint/git-clang-format.sh [-i] " echo "" @@ -37,6 +30,13 @@ if [[ "$#" -lt 1 ]]; then exit 1 fi +if [[ "$1" == "-i" ]]; then + INPLACE_FORMAT=1 + shift 1 +else + INPLACE_FORMAT=0 +fi + cleanup() { rm -rf /tmp/$$.clang-format.txt diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index c6f28ad9a4b5..8ab65f129cc5 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -131,11 +131,11 @@ def test_outer_product(): assert isinstance(jbody.message, tvm.tir.StringImm) assert jbody.message.value == "index out of range!" jbody = jblock[1] - assert isinstance(jbody, tvm.tir.Provide) - assert jbody.func.name == 'c' - assert len(jbody.args) == 2 - assert jbody.args[0].name == 'i' - assert jbody.args[1].name == 'j' + assert isinstance(jbody, tvm.tir.ProducerStore) + assert jbody.producer.op.name == 'c' + assert len(jbody.indices) == 2 + assert jbody.indices[0].name == 'i' + assert jbody.indices[1].name == 'j' assert isinstance(jbody.value, tvm.tir.Mul) mul = jbody.value assert isinstance(mul.a, tvm.tir.ProducerLoad) @@ -187,26 +187,26 @@ def fanout(n, a): ibody = ir.body assert isinstance(ibody, tvm.tir.AttrStmt) abody = ibody.body - assert isinstance(abody, tvm.tir.Realize) + assert isinstance(abody, tvm.tir.ProducerRealize) assert abody.bounds[0].min.value == 0 assert abody.bounds[0].extent.value == 1 - assert abody.func.name == 'sigma' + assert abody.producer.op.name == 'sigma' #Check i loop body rbody = abody.body - assert isinstance(rbody[0], tvm.tir.Provide) - assert rbody[0].func.name == 'sigma' - assert len(rbody[0].args) == 1 - assert rbody[0].args[0].value == 0 + assert isinstance(rbody[0], tvm.tir.ProducerStore) + assert rbody[0].producer.op.name == 'sigma' + assert len(rbody[0].indices) == 1 + assert rbody[0].indices[0].value == 0 #Check fanout loop jloop = rbody[1] assert jloop.loop_var.name == 'j' assert jloop.min.value == 0 assert jloop.extent.value == 3 jbody = jloop.body - assert isinstance(jbody, tvm.tir.Provide) - assert len(jbody.args) == 1 - assert jbody.args[0].value == 0 - assert jbody.func.name == 'sigma' + assert isinstance(jbody, tvm.tir.ProducerStore) + assert len(jbody.indices) == 1 + assert jbody.indices[0].value == 0 + assert jbody.producer.op.name == 'sigma' assert isinstance(jbody.value, tvm.tir.Add) value = jbody.value assert isinstance(value.a, tvm.tir.ProducerLoad) @@ -217,9 +217,9 @@ def fanout(n, a): assert len(value.b.indices) == 1 assert tvm.ir.structural_equal(value.b.indices[0], ir.loop_var + jloop.loop_var) divide= rbody[2] - assert isinstance(divide, tvm.tir.Provide) - assert len(divide.args) == 1 - assert divide.args[0].value == 0 + assert isinstance(divide, tvm.tir.ProducerStore) + assert len(divide.indices) == 1 + assert divide.indices[0].value == 0 value = divide.value assert isinstance(value, tvm.tir.Mul) assert value.a.producer.name == 'sigma' @@ -227,8 +227,8 @@ def fanout(n, a): assert value.a.indices[0].value == 0 assert abs(value.b.value - (1 / 3.0)) < 1e-5 write = rbody[3] - assert isinstance(write, tvm.tir.Provide) - assert write.func.name == 'b' + assert isinstance(write, tvm.tir.ProducerStore) + assert write.producer.op.name == 'b' assert write.value.producer.name == 'sigma' assert len(write.value.indices) == 1 assert write.value.indices[0].value == 0 diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index a8ab3cfda25a..8d737c9f629b 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -261,8 +261,8 @@ def test_tuple_with_different_deps(): stmt = tvm.te.schedule.ScheduleOps(sch, bounds) def get_B1_realize(x): - if isinstance(x, tvm.tir.Realize) and \ - x.func == B1.op and x.value_index == 1: + if isinstance(x, tvm.tir.ProducerRealize) and \ + x.producer.op == B1.op and x.producer.value_index == 1: ret.append(x) ret = [] tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize) diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 86f87348ec53..8f03d1028bc6 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -158,12 +158,6 @@ def test_stmt_constructor(): assert x.index.value == 10 assert x.value.value == 1 - tensor = te.placeholder((), dtype="float32") - x = tvm.tir.Provide(tensor.op, 0, 10, []) - assert isinstance(x, tvm.tir.Provide) - assert x.value_index == 0 - assert x.value.value == 10 - x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop) assert isinstance(x, tvm.tir.Allocate) @@ -181,10 +175,6 @@ def test_stmt_constructor(): assert isinstance(x, tvm.tir.Free) assert x.buffer_var == buffer_var - x = tvm.tir.Realize(None, 0, "float", [], tvm.tir.const(1, "uint1"), nop) - assert isinstance(x, tvm.tir.Realize) - assert x.body == nop - x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), nop) diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 51da6ea9ad18..5801200c15da 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -69,7 +69,7 @@ def test_lower_warp_memory_correct_indices(): ir = tvm.te.schedule.ScheduleOps(s, bounds) inner_func = ir.body.body.body.body store_A_warp = inner_func.body.seq[0].body.body - indices = list(store_A_warp.args) + indices = list(store_A_warp.indices) # A.warp is actually many buffers, one for each warp, although they are all called A.warp # 1. If we are accessing from different threads within a same warp (different