diff --git a/cinn/common/cinn_value.cc b/cinn/common/cinn_value.cc index e5b2c56279452..e87327bf95a34 100644 --- a/cinn/common/cinn_value.cc +++ b/cinn/common/cinn_value.cc @@ -117,7 +117,7 @@ void CINNValuePack::AddValue(const CINNValue &value) { values_.push_back(value); } void CINNValuePack::Clear() { values_.clear(); } -const char *CINNValuePack::type_info() const { return "CINNValuePack"; } +const char *CINNValuePack::type_info() const { return __type_info__; } CINNValue &CINNValue::operator=(int32_t value) { *this = CINNValue(value); diff --git a/cinn/common/cinn_value.h b/cinn/common/cinn_value.h index 6984cd6f7178b..4ba4c5a07d294 100644 --- a/cinn/common/cinn_value.h +++ b/cinn/common/cinn_value.h @@ -60,10 +60,11 @@ struct CINNValuePack : public common::Object { private: CINNValuePack() = default; std::vector<CINNValue> values_; + static constexpr const char* __type_info__ = "CINNValuePack"; }; struct CINNValuePackShared : public Shared<CINNValuePack> { - CINNValuePackShared(CINNValuePack* ptr) : Shared<CINNValuePack>(ptr) {} + explicit CINNValuePackShared(CINNValuePack* ptr) : Shared<CINNValuePack>(ptr) {} CINNValue& operator[](int offset) { return (*operator->())[offset]; } const CINNValue& operator[](int offset) const { return (*operator->())[offset]; } diff --git a/cinn/common/graph_utils.h b/cinn/common/graph_utils.h index 28dabcde01cff..de4b3a71f4643 100644 --- a/cinn/common/graph_utils.h +++ b/cinn/common/graph_utils.h @@ -35,13 +35,14 @@ class GraphEdge : public Object { GraphNode* source() const { return source_; } GraphNode* sink() const { return sink_; } - const char* type_info() const override { return "graph_edge"; } + const char* type_info() const override { return __type_info__; } private: //! Source of this edge. GraphNode* source_{}; //! End of this edge. 
GraphNode* sink_{}; + static constexpr const char* __type_info__ = "graph_edge"; }; struct GraphEdgeCompare { diff --git a/cinn/frontend/syntax.h b/cinn/frontend/syntax.h index b171d068d9013..98a4ebb41098f 100644 --- a/cinn/frontend/syntax.h +++ b/cinn/frontend/syntax.h @@ -68,7 +68,7 @@ struct _Instruction_ : public common::Object { const char* type_info() const override { return __type_info__; } - static constexpr const char* __type_info__ = "cinn_frontend_instruction"; + static constexpr const char* __type_info__ = "cinn_frontend_instruction"; }; struct Instruction : public common::Shared<_Instruction_> { diff --git a/cinn/hlir/framework/op_strategy.h b/cinn/hlir/framework/op_strategy.h index 403a41bb6325c..08a2c9c3fb5c0 100644 --- a/cinn/hlir/framework/op_strategy.h +++ b/cinn/hlir/framework/op_strategy.h @@ -52,17 +52,17 @@ class OpImpl : public common::Object { return nullptr; } - const char* type_info() const override { return _type_key; } + const char* type_info() const override { return __type_info__; } private: - static constexpr char* _type_key = "OpImplementation"; + static constexpr const char* __type_info__ = "OpImplementation"; }; //! Specialized implementations for operators under certain conditions. class OpSpec : public common::Object { public: //! List of implementations. - std::vector<OpImpl*> implementations; + std::vector<std::shared_ptr<OpImpl>> implementations; /** \brief Condition to enable the specialization. * Could be undefined to represent generic case. 
@@ -71,10 +71,10 @@ class OpSpec : public common::Object { */ std::string condition; - const char* type_info() const override { return _type_key; } + const char* type_info() const override { return __type_info__; } void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) { - auto n = make_shared<OpImpl>(); + auto n = std::make_shared<OpImpl>(); n->fcompute = fcompute; n->fschedule = fschedule; n->name = std::move(name); @@ -83,15 +83,15 @@ class OpSpec : public common::Object { } private: - static constexpr char* _type_key = "OpSpecialization"; + static constexpr const char* __type_info__ = "OpSpecialization"; }; //! Operator strategy class. class OpStrategy : public common::Object { public: - const char* type_info() const override { return "CINNOpStrategy"; } + const char* type_info() const override { return __type_info__; } //! List of operator specializations. - std::vector<OpSpec*> specializations; + std::vector<std::shared_ptr<OpSpec>> specializations; /** * \brief Add an implementation. @@ -103,21 +103,40 @@ class OpStrategy : public common::Object { void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) { //! TODO(haozech) : here curr_cond should get the condition from outside. //! Expected : auto curr_cond = SpecializedCondition::Current(); - std::string curr_cond = "current_condition"; - OpSpec* op_spec; - for (OpSpec* op_spec : specializations) { - if (op_spec->condition == curr_cond) { + std::string curr_condition = "default"; + for (auto op_spec : specializations) { + if (op_spec->condition == curr_condition) { op_spec->AddImpl(fcompute, fschedule, std::move(name), plevel); return; } } - OpSpec* n = make_shared<OpSpec>(); - n->condition = curr_cond; + std::shared_ptr<OpSpec> n = std::make_shared<OpSpec>(); + n->condition = curr_condition; n->AddImpl(fcompute, fschedule, std::move(name), plevel); this->specializations.push_back(n); } + + private: + static constexpr const char* __type_info__ = "OpStrategy"; }; +std::shared_ptr<OpImpl> SelectImpl(std::shared_ptr<OpStrategy> strategy) { + //! 
should get the host info from global environment. + std::string curr_condition = "default"; + std::shared_ptr res = nullptr; + for (auto spec : strategy->specializations) { + if (spec->condition == "default") { + for (auto i : spec->implementations) { + if (!res || res->plevel < i->plevel) { + res = i; + } + } + } + } + CHECK(res) << "There is no available strategy implementation! SelectImpl failed!"; + return res; +} + } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/framework/op_test.cc b/cinn/hlir/framework/op_test.cc index 2cdf427cdfde0..d1684260c9192 100644 --- a/cinn/hlir/framework/op_test.cc +++ b/cinn/hlir/framework/op_test.cc @@ -2,15 +2,23 @@ #include +#include + #include #include "cinn/cinn.h" + +#include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op_strategy.h" #include "cinn/hlir/pe/broadcast.h" namespace cinn { namespace hlir { namespace framework { +using StrategyFunction = std::function( + const NodeAttr, const std::vector, common::Type, const common::Target)>; + +using CCompute = std::function(const std::vector)>; CINN_REGISTER_OP(add) .describe("test of op Add") @@ -19,30 +27,68 @@ CINN_REGISTER_OP(add) .set_attr("nick_name", "plus") .set_support_level(4); -common::Shared GetStrategyTest() { - ir::PackedFunc::body_t body = [](ir::Args args, ir::RetValue* ret) { - Expr a = args[0]; - Expr b = args[1]; - (*ret) = Expr(pe::Add(a.as_tensor_ref(), b.as_tensor_ref(), "C").get()); +std::shared_ptr StrategyTest(const NodeAttr &attr, + const std::vector &inputs, + common::Type out_type, + const common::Target &target) { + ir::PackedFunc::body_t compute_body = [](ir::Args args, ir::RetValue *ret) { + common::CINNValuePackShared a = args[0]; + ir::Expr A = a[0]; + ir::Expr B = a[1]; + *ret = common::CINNValuePack::Make( + {common::CINNValue(ir::Expr(pe::Add(A.as_tensor_ref(), B.as_tensor_ref(), "C").get()))}); }; - ir::PackedFunc fcompute(body); - // TODO(haozech): fschedule should be an instance of 
pe::schedule... - ir::PackedFunc fschedule; - common::Shared strategy(make_shared()); - //! To build more complex strategy, we can add more than 1 - //! implementations to one Opstrategy, with different plevel. - strategy->AddImpl(fcompute, fschedule, "test.strategy", 10); + ir::PackedFunc fcompute(compute_body); + + ir::PackedFunc::body_t schedule_body = [](ir::Args args, ir::RetValue *ret) { + common::CINNValuePackShared a = args[0]; + ir::Expr A = a[0]; + A.as_tensor_ref()->stage()->Vectorize(1, 16); + A.as_tensor_ref()->stage()->Unroll(1); + *ret = common::CINNValuePack::Make({common::CINNValue(A)}); + }; + ir::PackedFunc fschedule(schedule_body); + + std::shared_ptr strategy = std::make_shared(); + + if (target.arch == common::Target::Arch ::X86) { + strategy->AddImpl(fcompute, fschedule, "test.strategy.x86", 10); + } else { + strategy->AddImpl(fcompute, fschedule, "test.strategy.else", 10); + } + strategy->AddImpl(fcompute, fschedule, "test.strategy.lowPlevel.x86", 5); return strategy; } TEST(Operator, GetAttr) { - auto add = Operator::Get("add"); - auto test_strategy = GetStrategyTest(); - Operator temp = *add; - temp.set_attr>("CINNStrategy", test_strategy); + auto add = Operator::Get("add"); + Operator temp = *add; + temp.set_attr("CINNStrategy", StrategyTest); auto nick = Operator::GetAttr("nick_name"); - auto strategy = Operator::GetAttr>("CINNStrategy"); - ASSERT_EQ(strategy[add]->specializations[0]->implementations[0]->name, "test.strategy"); + auto strategy = Operator::GetAttr("CINNStrategy"); + + Expr M(100), N(32); + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + NodeAttr attr; + std::vector inputs{A, B}; + common::Type type; + common::Target target; + target.arch = common::Target::Arch::X86; + auto impl = SelectImpl(strategy[add](attr, inputs, type, target)); + + common::CINNValuePackShared cinn_input = common::CINNValuePack::Make({common::CINNValue(A), common::CINNValue(B)}); + common::CINNValuePackShared C = 
impl->fcompute(cinn_input); + C = impl->fschedule(C); + for (int i = 0; i < C.get()->size(); i++) { + ir::Expr temp = C[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("add1", inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + + ASSERT_EQ(impl->name, "test.strategy.x86"); ASSERT_EQ(add->description, "test of op Add"); ASSERT_EQ(nick[add], "plus"); } diff --git a/cinn/hlir/framework/schedule.h b/cinn/hlir/framework/schedule.h index ca675bf156c2c..f3e0085dc33cf 100644 --- a/cinn/hlir/framework/schedule.h +++ b/cinn/hlir/framework/schedule.h @@ -14,7 +14,7 @@ namespace framework { */ class Schedule : public common::Object { public: - const char* type_info() const override { return "CINNSchedule"; } + const char* type_info() const override { return __type_info__; } /** * \brief Get the stage corresponds to the op @@ -36,6 +36,9 @@ class Schedule : public common::Object { //! map of original operation to the stages std::unordered_map stage_map; + + private: + static constexpr const char* __type_info__ = "CINNSchedule"; }; } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/ir/node.h b/cinn/ir/node.h index 5110095302e49..ac6a6f4cd7bce 100644 --- a/cinn/ir/node.h +++ b/cinn/ir/node.h @@ -141,7 +141,7 @@ class IrNode : public common::Object { const char* type_info() const override { return __type_info__; } protected: - constexpr static const char* __type_info__ = "IRNode"; + static constexpr const char* __type_info__ = "IRNode"; Type type_; }; diff --git a/cinn/poly/stage.h b/cinn/poly/stage.h index 86966c04f0304..1e473ee1c97b1 100644 --- a/cinn/poly/stage.h +++ b/cinn/poly/stage.h @@ -219,7 +219,7 @@ class Stage : public Object { //! Get the statements. 
std::vector<std::string> input_statements() const; - virtual const char* type_info() const { return "Status"; } + virtual const char* type_info() const { return __type_info__; } inline const ir::VectorizeInfo& vectorize_info() const { return vectorize_info_; } inline const std::set<int>& unroll_info() const { return unroll_info_; } @@ -280,6 +280,8 @@ class Stage : public Object { std::set<int> locked_axis_; + static constexpr const char* __type_info__ = "Status"; + friend isl_map* __isl_give GatherAccesses(Stage* stage, const std::string& tensor_name); friend class PolyGroupScheduler; };