diff --git a/cinn/common/common.h b/cinn/common/common.h index c8acbf460e3f0..a352808081f13 100644 --- a/cinn/common/common.h +++ b/cinn/common/common.h @@ -17,6 +17,7 @@ using common::make_shared; using common::Object; using common::ref_count; using common::Shared; +using common::UniqName; // Type related. using common::Bool; diff --git a/cinn/common/context.h b/cinn/common/context.h index a4b39c9cf5255..4a1638ca50bbd 100644 --- a/cinn/common/context.h +++ b/cinn/common/context.h @@ -73,5 +73,7 @@ class Context { mutable std::string runtime_include_dir_; }; +static std::string UniqName(const std::string& prefix) { return Context::Global().NewName(prefix); } + } // namespace common } // namespace cinn diff --git a/cinn/common/object.h b/cinn/common/object.h index 1ac4fb09043eb..3512848a439fc 100644 --- a/cinn/common/object.h +++ b/cinn/common/object.h @@ -30,16 +30,18 @@ struct Object { //! Type safe cast. template T* safe_as() { - CHECK(std::strcmp(type_info(), T::__type_info__) == 0) - << "type mismatch, this is a " << type_info() << ", but want a " << T::__type_info__; - return static_cast(this); + if (std::strcmp(type_info(), T::__type_info__) == 0) { + return static_cast(this); + } + return nullptr; } //! Type safe cast. template const T* safe_as() const { - CHECK(std::strcmp(type_info(), T::__type_info__) == 0) - << "type mismatch, this is a " << type_info() << ", but want a " << T::__type_info__; - return static_cast(this); + if (std::strcmp(type_info(), T::__type_info__) == 0) { + return static_cast(this); + } + return nullptr; } //! Check if the type is right. diff --git a/cinn/hlir/framework/CMakeLists.txt b/cinn/hlir/framework/CMakeLists.txt index 657165d3ef088..3d7cfb2039cee 100644 --- a/cinn/hlir/framework/CMakeLists.txt +++ b/cinn/hlir/framework/CMakeLists.txt @@ -6,6 +6,9 @@ set(srcs memory.cc instruction.cc graph_compiler.cc + graph.cc + node.cc + pass.cc ) if(WITH_CUDA) diff --git a/cinn/hlir/framework/graph.cc b/cinn/hlir/framework/graph.cc new file mode 100644 index 0000000000000..a99247b41895a --- /dev/null +++ b/cinn/hlir/framework/graph.cc @@ -0,0 +1,40 @@ +#include "cinn/hlir/framework/graph.h" + +namespace cinn { +namespace hlir { +namespace framework { + +Graph::Graph(frontend::Program prog) { + std::unordered_map> res; + int counter = 0; + for (size_t i = 0; i < prog.size(); i++) { + auto temp = prog[i]; + Node* node_tmp = + new Node(Operator::Get(temp->op_type), temp->op_type, temp->op_type + "_" + std::to_string(counter++)); + std::shared_ptr node_ptr(node_tmp); + node_tmp->attrs.attr_store = temp->attrs; + for (frontend::Variable j : temp->inputs) { + NodeData* input_data = this->RetriveNode(j->id)->as(); + if (!input_data) { + res[j->id] = j->shape; + input_data = new NodeData(nullptr, 0, 0, j->id); + input_data->LinkTo(node_tmp); + this->RegisterNode(j->id, input_data); + } else { + input_data->LinkTo(node_tmp); + } + } + for (frontend::Variable j : temp->outputs) { + int out_idx = 0; + NodeData* output_data = new NodeData(node_ptr, out_idx++, 0, j->id); + node_tmp->LinkTo(output_data); + this->RegisterNode(j->id, output_data); + } + this->RegisterNode(node_tmp->id(), node_tmp); + } + this->attrs["infer_shape"] = std::make_shared(res); +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 8573dc896384f..2ede9bd77fbcb 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -19,36 +19,7 @@ namespace framework { */ class Graph : public cinn::common::Graph { public: - explicit Graph(frontend::Program prog) { - std::unordered_map> res; - int counter = 0; - for (size_t i = 0; i < prog.size(); i++) { - auto temp = prog[i]; - Node* node_tmp = - new Node(Operator::Get(temp->op_type), temp->op_type, temp->op_type + "_" + std::to_string(counter++)); - std::shared_ptr node_ptr(node_tmp); - node_tmp->attrs.attr_store = temp->attrs; - for (frontend::Variable j : temp->inputs) { - NodeData* input_data = this->RetriveNode(j->id)->as(); - if (!input_data) { - res[j->id] = j->shape; - input_data = new NodeData(nullptr, 0, 0, j->id); - input_data->LinkTo(node_tmp); - this->RegisterNode(j->id, input_data); - } else { - input_data->LinkTo(node_tmp); - } - } - for (frontend::Variable j : temp->outputs) { - int out_idx = 0; - NodeData* output_data = new NodeData(node_ptr, out_idx++, 0, j->id); - node_tmp->LinkTo(output_data); - this->RegisterNode(j->id, output_data); - } - this->RegisterNode(node_tmp->id(), node_tmp); - } - this->attrs["infer_shape"] = std::make_shared(res); - } + explicit Graph(frontend::Program prog); /** \brief outputs of the computation graph. */ std::vector outputs; @@ -89,6 +60,9 @@ class Graph : public cinn::common::Graph { auto it = attrs.find(attr_name); return it != attrs.end(); } + + private: + CINN_DISALLOW_COPY_AND_ASSIGN(Graph); }; } // namespace framework diff --git a/cinn/hlir/framework/graph_compiler.cc b/cinn/hlir/framework/graph_compiler.cc index bbd2dd67ad1a5..a5e4b796e8ccb 100644 --- a/cinn/hlir/framework/graph_compiler.cc +++ b/cinn/hlir/framework/graph_compiler.cc @@ -6,8 +6,6 @@ namespace cinn { namespace hlir { namespace framework { -using StrategyFunction = std::function( - const NodeAttr, const std::vector, common::Type, const common::Target)>; std::unique_ptr GraphCompiler::Build() { auto [nodes, edges] = graph_->topological_order(); diff --git a/cinn/hlir/framework/graph_compiler.h b/cinn/hlir/framework/graph_compiler.h index 5f5284a4f1362..d5372f1e1cc46 100644 --- a/cinn/hlir/framework/graph_compiler.h +++ b/cinn/hlir/framework/graph_compiler.h @@ -4,6 +4,7 @@ #include #include #include + #include "cinn/backends/compiler.h" #include "cinn/common/macros.h" #include "cinn/hlir/framework/graph.h" @@ -37,10 +38,7 @@ class Program { class GraphCompiler final { public: GraphCompiler(Target target, const std::shared_ptr& scope, Graph* const graph) - : target_(std::move(target)), - scope_(scope), - graph_(graph), - m_builder_(Context::Global().NewName("module"), target) {} + : target_(std::move(target)), scope_(scope), graph_(graph), m_builder_(UniqName("module"), target) {} std::unique_ptr Build(); diff --git a/cinn/hlir/framework/infershape_pass_test.cc b/cinn/hlir/framework/infershape_pass_test.cc index 0797b4ac4e5e2..78cb53b877d49 100644 --- a/cinn/hlir/framework/infershape_pass_test.cc +++ b/cinn/hlir/framework/infershape_pass_test.cc @@ -60,6 +60,7 @@ CINN_REGISTER_PASS(InferShape) TEST(Operator, GetAttr) { frontend::Program prog; + // TODO(Superjomn) Replace with Placeholder here. frontend::Variable a("a"); frontend::Variable b("b"); a->shape = {100, 32}; @@ -67,13 +68,12 @@ TEST(Operator, GetAttr) { auto c = prog.add(a, b); auto d = prog.add(c, b); auto e = prog.add(c, d); - ASSERT_EQ(prog.size(), 3); - Graph* g = new Graph(prog); - ApplyPass(g, "InferShape"); + ASSERT_EQ(prog.size(), 3UL); + std::unique_ptr g(new Graph(prog)); + ApplyPass(g.get(), "InferShape"); auto s = g->GetAttr>>("infer_shape"); for (auto i : s) { LOG(INFO) << "Var id is: " << i.first << " and Var shape is: "; - std::vector correct_shape{100, 32}; for (auto j : i.second) { LOG(INFO) << j << " "; } diff --git a/cinn/hlir/framework/node.cc b/cinn/hlir/framework/node.cc new file mode 100644 index 0000000000000..74878f7453de2 --- /dev/null +++ b/cinn/hlir/framework/node.cc @@ -0,0 +1,17 @@ +#include "cinn/hlir/framework/node.h" + +namespace cinn { +namespace hlir { +namespace framework { + +std::tuple Node::LinkTo(NodeData *other) { + return this->common::GraphNode::LinkTo(other->as()); +} + +std::tuple NodeData::LinkTo(Node *other) { + return this->common::GraphNode::LinkTo(other->as()); +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/node.h b/cinn/hlir/framework/node.h index 7ee1a3c27aaf0..8461eeed53a0c 100644 --- a/cinn/hlir/framework/node.h +++ b/cinn/hlir/framework/node.h @@ -4,8 +4,8 @@ #include #include #include +#include #include - #include "cinn/common/graph_utils.h" #include "cinn/hlir/framework/op.h" @@ -149,14 +149,6 @@ class NodeData : public common::GraphNode { std::string id_; }; -std::tuple Node::LinkTo(NodeData *other) { - return this->common::GraphNode::LinkTo(other->as()); -} - -std::tuple NodeData::LinkTo(Node *other) { - return this->common::GraphNode::LinkTo(other->as()); -} - } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/framework/op_strategy.h b/cinn/hlir/framework/op_strategy.h index 08a2c9c3fb5c0..d64ade43abb7f 100644 --- a/cinn/hlir/framework/op_strategy.h +++ b/cinn/hlir/framework/op_strategy.h @@ -6,13 +6,18 @@ #include "cinn/hlir/framework/schedule.h" #include "cinn/ir/packed_func.h" -using CINNCompute = cinn::ir::PackedFunc; -using CINNSchedule = cinn::ir::PackedFunc; - namespace cinn { namespace hlir { namespace framework { +using CINNCompute = ir::PackedFunc; +using CINNSchedule = ir::PackedFunc; + +struct OpStrategy; + +using StrategyFunction = std::function( + const NodeAttr, const std::vector, common::Type, const common::Target)>; + //! Operator implementation that includes compute and schedule function. class OpImpl : public common::Object { public: diff --git a/cinn/hlir/framework/op_test.cc b/cinn/hlir/framework/op_test.cc index d1684260c9192..dbd0c28db490f 100644 --- a/cinn/hlir/framework/op_test.cc +++ b/cinn/hlir/framework/op_test.cc @@ -3,11 +3,9 @@ #include #include - #include #include "cinn/cinn.h" - #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op_strategy.h" #include "cinn/hlir/pe/broadcast.h" @@ -15,8 +13,6 @@ namespace cinn { namespace hlir { namespace framework { -using StrategyFunction = std::function( - const NodeAttr, const std::vector, common::Type, const common::Target)>; using CCompute = std::function(const std::vector)>; @@ -49,7 +45,7 @@ std::shared_ptr StrategyTest(const NodeAttr &attr, }; ir::PackedFunc fschedule(schedule_body); - std::shared_ptr strategy = std::make_shared(); + auto strategy = std::make_shared(); if (target.arch == common::Target::Arch ::X86) { strategy->AddImpl(fcompute, fschedule, "test.strategy.x86", 10); diff --git a/cinn/hlir/framework/pass.cc b/cinn/hlir/framework/pass.cc new file mode 100644 index 0000000000000..9427a4793de04 --- /dev/null +++ b/cinn/hlir/framework/pass.cc @@ -0,0 +1,38 @@ +#include "cinn/hlir/framework/pass.h" + +namespace cinn { +namespace hlir { +namespace framework { + +void ApplyPasses(Graph* g, const std::vector& passes) { + std::vector fpass; + for (auto& name : passes) { + auto* reg = Registry::Global()->Find(name); + CHECK(reg) << "Cannot find pass " << name << " in the registry"; + fpass.push_back(reg); + } + for (auto* r : fpass) { + for (auto& dep : r->graph_attr_dependency) { + CHECK_NE(g->attrs.count(dep), 0) << "To apply pass [" << r->name << "], Graph's attribute [" << dep + << "] is required, but it is not available."; + if (g->attrs.count(dep) == 0) { + auto* pass_dep = FindPassDep(dep); + CHECK(!pass_dep) << "And the attribute is provided by pass [" << pass_dep->name << "]."; + } + } + r->body(g); + } +} + +const PassFunctionRegister* FindPassDep(const std::string& attr_name) { + for (auto* r : Registry::Global()->List()) { + for (auto& s : r->graph_attr_targets) { + if (s == attr_name) return r; + } + } + return nullptr; +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/pass.h b/cinn/hlir/framework/pass.h index f019743a6899e..b32248768d8b2 100644 --- a/cinn/hlir/framework/pass.h +++ b/cinn/hlir/framework/pass.h @@ -77,14 +77,7 @@ class PassFunctionRegister : public FunctionRegEntryBase::Global()->List()) { - for (auto& s : r->graph_attr_targets) { - if (s == attr_name) return r; - } - } - return nullptr; -} +const PassFunctionRegister* FindPassDep(const std::string& attr_name); /** * \brief Apply a sequence of passes on a graph. @@ -92,25 +85,7 @@ const PassFunctionRegister* FindPassDep(const std::string& attr_name) { * @param passes The sequence of pass. * @return The graph after being modified by the passes. */ -void ApplyPasses(Graph* g, const std::vector& passes) { - std::vector fpass; - for (auto& name : passes) { - auto* reg = Registry::Global()->Find(name); - CHECK(reg) << "Cannot find pass " << name << " in the registry"; - fpass.push_back(reg); - } - for (auto* r : fpass) { - for (auto& dep : r->graph_attr_dependency) { - CHECK_NE(g->attrs.count(dep), 0) << "To apply pass [" << r->name << "], Graph's attribute [" << dep - << "] is required, but it is not available."; - if (g->attrs.count(dep) == 0) { - auto* pass_dep = FindPassDep(dep); - CHECK(!pass_dep) << "And the attribute is provided by pass [" << pass_dep->name << "]."; - } - } - r->body(g); - } -} +void ApplyPasses(Graph* g, const std::vector& passes); // Apply a single pass on a graph. inline void ApplyPass(Graph* g, const std::string& pass) { return ApplyPasses(g, {pass}); } diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc index 3ea643a3f905e..c3bfad34e5e40 100644 --- a/cinn/hlir/op/nn.cc +++ b/cinn/hlir/op/nn.cc @@ -1,11 +1,52 @@ +#include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/pe/broadcast.h" + +namespace cinn { +namespace hlir { +namespace op { +using common::CINNValue; +using common::CINNValuePack; +using common::CINNValuePackShared; +using framework::OpStrategy; +using framework::StrategyFunction; + +std::shared_ptr StrategyForAdd(const framework::NodeAttr &attr, + const std::vector &inputs, + Type out_type, + const Target &target) { + framework::CINNCompute add_compute([](ir::Args args, ir::RetValue *ret) { + CINNValuePackShared a = args[0]; + ir::Expr A = a[0]; + ir::Expr B = a[1]; + CHECK(A.as_tensor()); + CHECK(B.as_tensor()); + *ret = + CINNValuePack::Make({CINNValue(ir::Expr(pe::Add(A.as_tensor_ref(), B.as_tensor_ref(), UniqName("C")).get()))}); + }); + + framework::CINNSchedule add_schedule([](ir::Args args, ir::RetValue *ret) { + CINNValuePackShared a = args[0]; + ir::Expr A = a[0]; + *ret = CINNValuePack::Make({CINNValue(A)}); + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(add_compute, add_schedule, "strategy.add.x86", 1); + + return strategy; +} + +} // namespace op +} // namespace hlir +} // namespace cinn CINN_REGISTER_HELPER(nn_ops) { CINN_REGISTER_OP(add) - .describe("Add") + .describe("Add two tensors") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("add", "add") + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForAdd) .set_support_level(4); } diff --git a/cmake/core.cmake b/cmake/core.cmake index dac3832522a4f..4b60d44c7f22a 100644 --- a/cmake/core.cmake +++ b/cmake/core.cmake @@ -1,4 +1,4 @@ -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -mavx") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -mavx -Wno-write-strings") function(cc_library TARGET_NAME) set(options STATIC static SHARED shared)