Commit
add add operator (PaddlePaddle#176)
Superjomn committed Aug 19, 2020
1 parent e57aad1 commit 5aef829
Showing 17 changed files with 175 additions and 93 deletions.
1 change: 1 addition & 0 deletions cinn/common/common.h
@@ -17,6 +17,7 @@ using common::make_shared;
using common::Object;
using common::ref_count;
using common::Shared;
using common::UniqName;

// Type related.
using common::Bool;
2 changes: 2 additions & 0 deletions cinn/common/context.h
@@ -73,5 +73,7 @@ class Context {
mutable std::string runtime_include_dir_;
};

static std::string UniqName(const std::string& prefix) { return Context::Global().NewName(prefix); }

} // namespace common
} // namespace cinn
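
For orientation, a minimal usage sketch of the new helper; the exact suffix appended by Context::NewName is an assumption, not shown in this diff:

#include "cinn/common/context.h"

// Sketch only: UniqName forwards to the global Context's name generator,
// so repeated calls with the same prefix are expected to yield distinct names.
void Demo() {
  std::string a = cinn::common::UniqName("module");  // e.g. "module_0" (suffix format assumed)
  std::string b = cinn::common::UniqName("module");  // expected to differ from `a`
}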
14 changes: 8 additions & 6 deletions cinn/common/object.h
@@ -30,16 +30,18 @@ struct Object {
//! Type safe cast.
template <typename T>
T* safe_as() {
CHECK(std::strcmp(type_info(), T::__type_info__) == 0)
<< "type mismatch, this is a " << type_info() << ", but want a " << T::__type_info__;
return static_cast<T*>(this);
if (std::strcmp(type_info(), T::__type_info__) == 0) {
return static_cast<T*>(this);
}
return nullptr;
}
//! Type safe cast.
template <typename T>
const T* safe_as() const {
CHECK(std::strcmp(type_info(), T::__type_info__) == 0)
<< "type mismatch, this is a " << type_info() << ", but want a " << T::__type_info__;
return static_cast<const T*>(this);
if (std::strcmp(type_info(), T::__type_info__) == 0) {
return static_cast<const T*>(this);
}
return nullptr;
}

//! Check if the type is right.
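
With this change, safe_as no longer aborts on a type mismatch; callers are expected to branch on the returned pointer. A hedged sketch of the new calling pattern (`obj` and `SomeType` are hypothetical):

// Sketch only: a mismatched cast now yields nullptr instead of failing a CHECK.
if (auto* typed = obj->safe_as<SomeType>()) {
  // obj really is a SomeType; use `typed` here.
} else {
  // obj is some other Object subtype; handle it or skip.
}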
3 changes: 3 additions & 0 deletions cinn/hlir/framework/CMakeLists.txt
@@ -6,6 +6,9 @@ set(srcs
memory.cc
instruction.cc
graph_compiler.cc
graph.cc
node.cc
pass.cc
)

if(WITH_CUDA)
40 changes: 40 additions & 0 deletions cinn/hlir/framework/graph.cc
@@ -0,0 +1,40 @@
#include "cinn/hlir/framework/graph.h"

namespace cinn {
namespace hlir {
namespace framework {

Graph::Graph(frontend::Program prog) {
std::unordered_map<std::string, std::vector<int>> res;
int counter = 0;
for (size_t i = 0; i < prog.size(); i++) {
auto temp = prog[i];
Node* node_tmp =
new Node(Operator::Get(temp->op_type), temp->op_type, temp->op_type + "_" + std::to_string(counter++));
std::shared_ptr<Node> node_ptr(node_tmp);
node_tmp->attrs.attr_store = temp->attrs;
for (frontend::Variable j : temp->inputs) {
NodeData* input_data = this->RetriveNode(j->id)->as<NodeData>();
if (!input_data) {
res[j->id] = j->shape;
input_data = new NodeData(nullptr, 0, 0, j->id);
input_data->LinkTo(node_tmp);
this->RegisterNode(j->id, input_data);
} else {
input_data->LinkTo(node_tmp);
}
}
for (frontend::Variable j : temp->outputs) {
int out_idx = 0;
NodeData* output_data = new NodeData(node_ptr, out_idx++, 0, j->id);
node_tmp->LinkTo(output_data);
this->RegisterNode(j->id, output_data);
}
this->RegisterNode(node_tmp->id(), node_tmp);
}
this->attrs["infer_shape"] = std::make_shared<std::any>(res);
}

} // namespace framework
} // namespace hlir
} // namespace cinn
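
The constructor records each input variable's shape under the graph attribute "infer_shape"; a short sketch of reading it back, mirroring the test later in this commit:

// Sketch, assuming `g` is a Graph built from a frontend::Program as above.
auto shapes = g->GetAttr<std::unordered_map<std::string, std::vector<int>>>("infer_shape");
for (const auto& kv : shapes) {
  LOG(INFO) << "var " << kv.first << " has " << kv.second.size() << " dimensions";
}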
34 changes: 4 additions & 30 deletions cinn/hlir/framework/graph.h
@@ -19,36 +19,7 @@ namespace framework {
*/
class Graph : public cinn::common::Graph {
public:
explicit Graph(frontend::Program prog) {
std::unordered_map<std::string, std::vector<int>> res;
int counter = 0;
for (size_t i = 0; i < prog.size(); i++) {
auto temp = prog[i];
Node* node_tmp =
new Node(Operator::Get(temp->op_type), temp->op_type, temp->op_type + "_" + std::to_string(counter++));
std::shared_ptr<Node> node_ptr(node_tmp);
node_tmp->attrs.attr_store = temp->attrs;
for (frontend::Variable j : temp->inputs) {
NodeData* input_data = this->RetriveNode(j->id)->as<NodeData>();
if (!input_data) {
res[j->id] = j->shape;
input_data = new NodeData(nullptr, 0, 0, j->id);
input_data->LinkTo(node_tmp);
this->RegisterNode(j->id, input_data);
} else {
input_data->LinkTo(node_tmp);
}
}
for (frontend::Variable j : temp->outputs) {
int out_idx = 0;
NodeData* output_data = new NodeData(node_ptr, out_idx++, 0, j->id);
node_tmp->LinkTo(output_data);
this->RegisterNode(j->id, output_data);
}
this->RegisterNode(node_tmp->id(), node_tmp);
}
this->attrs["infer_shape"] = std::make_shared<std::any>(res);
}
explicit Graph(frontend::Program prog);

/** \brief outputs of the computation graph. */
std::vector<NodeData*> outputs;
@@ -89,6 +60,9 @@ class Graph : public cinn::common::Graph {
auto it = attrs.find(attr_name);
return it != attrs.end();
}

private:
CINN_DISALLOW_COPY_AND_ASSIGN(Graph);
};

} // namespace framework
2 changes: 0 additions & 2 deletions cinn/hlir/framework/graph_compiler.cc
@@ -6,8 +6,6 @@
namespace cinn {
namespace hlir {
namespace framework {
using StrategyFunction = std::function<std::shared_ptr<OpStrategy>(
const NodeAttr, const std::vector<ir::Tensor>, common::Type, const common::Target)>;

std::unique_ptr<Program> GraphCompiler::Build() {
auto [nodes, edges] = graph_->topological_order();
6 changes: 2 additions & 4 deletions cinn/hlir/framework/graph_compiler.h
@@ -4,6 +4,7 @@
#include <string>
#include <utility>
#include <vector>

#include "cinn/backends/compiler.h"
#include "cinn/common/macros.h"
#include "cinn/hlir/framework/graph.h"
@@ -37,10 +38,7 @@ class Program {
class GraphCompiler final {
public:
GraphCompiler(Target target, const std::shared_ptr<Scope>& scope, Graph* const graph)
: target_(std::move(target)),
scope_(scope),
graph_(graph),
m_builder_(Context::Global().NewName("module"), target) {}
: target_(std::move(target)), scope_(scope), graph_(graph), m_builder_(UniqName("module"), target) {}

std::unique_ptr<Program> Build();

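
A hedged construction sketch for the updated constructor; how `target`, `scope`, and `graph` are obtained is assumed, not part of this change:

// Sketch only: the module builder's name now comes from UniqName("module"),
// so multiple GraphCompiler instances get distinct module names.
GraphCompiler compiler(target, scope, graph.get());
std::unique_ptr<Program> runtime_program = compiler.Build();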
8 changes: 4 additions & 4 deletions cinn/hlir/framework/infershape_pass_test.cc
@@ -60,20 +60,20 @@ CINN_REGISTER_PASS(InferShape)

TEST(Operator, GetAttr) {
frontend::Program prog;
// TODO(Superjomn) Replace with Placeholder here.
frontend::Variable a("a");
frontend::Variable b("b");
a->shape = {100, 32};
b->shape = {100, 32};
auto c = prog.add(a, b);
auto d = prog.add(c, b);
auto e = prog.add(c, d);
ASSERT_EQ(prog.size(), 3);
Graph* g = new Graph(prog);
ApplyPass(g, "InferShape");
ASSERT_EQ(prog.size(), 3UL);
std::unique_ptr<Graph> g(new Graph(prog));
ApplyPass(g.get(), "InferShape");
auto s = g->GetAttr<std::unordered_map<std::string, std::vector<int>>>("infer_shape");
for (auto i : s) {
LOG(INFO) << "Var id is: " << i.first << " and Var shape is: ";
std::vector<int> correct_shape{100, 32};
for (auto j : i.second) {
LOG(INFO) << j << " ";
}
17 changes: 17 additions & 0 deletions cinn/hlir/framework/node.cc
@@ -0,0 +1,17 @@
#include "cinn/hlir/framework/node.h"

namespace cinn {
namespace hlir {
namespace framework {

std::tuple<common::GraphEdge *, common::GraphEdge *> Node::LinkTo(NodeData *other) {
return this->common::GraphNode::LinkTo(other->as<common::GraphNode>());
}

std::tuple<common::GraphEdge *, common::GraphEdge *> NodeData::LinkTo(Node *other) {
return this->common::GraphNode::LinkTo(other->as<common::GraphNode>());
}

} // namespace framework
} // namespace hlir
} // namespace cinn
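
Moving these overloads out of node.h avoids multiple definitions when the header is included from several translation units (this rationale is inferred, not stated in the commit). The wiring pattern stays as in graph.cc above:

// Sketch of the edge wiring used by the Graph constructor: data -> op -> data.
input_data->LinkTo(op_node);   // NodeData feeds the operator Node
op_node->LinkTo(output_data);  // the operator Node produces the output NodeData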
10 changes: 1 addition & 9 deletions cinn/hlir/framework/node.h
@@ -4,8 +4,8 @@
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "cinn/common/graph_utils.h"
#include "cinn/hlir/framework/op.h"

@@ -149,14 +149,6 @@ class NodeData : public common::GraphNode {
std::string id_;
};

std::tuple<common::GraphEdge *, common::GraphEdge *> Node::LinkTo(NodeData *other) {
return this->common::GraphNode::LinkTo(other->as<common::GraphNode>());
}

std::tuple<common::GraphEdge *, common::GraphEdge *> NodeData::LinkTo(Node *other) {
return this->common::GraphNode::LinkTo(other->as<common::GraphNode>());
}

} // namespace framework
} // namespace hlir
} // namespace cinn
11 changes: 8 additions & 3 deletions cinn/hlir/framework/op_strategy.h
@@ -6,13 +6,18 @@
#include "cinn/hlir/framework/schedule.h"
#include "cinn/ir/packed_func.h"

using CINNCompute = cinn::ir::PackedFunc;
using CINNSchedule = cinn::ir::PackedFunc;

namespace cinn {
namespace hlir {
namespace framework {

using CINNCompute = ir::PackedFunc;
using CINNSchedule = ir::PackedFunc;

struct OpStrategy;

using StrategyFunction = std::function<std::shared_ptr<OpStrategy>(
const NodeAttr, const std::vector<ir::Tensor>, common::Type, const common::Target)>;

//! Operator implementation that includes compute and schedule function.
class OpImpl : public common::Object {
public:
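
A hedged sketch of a function matching the relocated StrategyFunction alias, modeled on the StrategyTest helper in op_test.cc below; the packed-function bodies are elided:

// Sketch only: an operator's strategy function returns an OpStrategy holding
// (compute, schedule) implementations, optionally specialized per target.
std::shared_ptr<OpStrategy> MyStrategy(const NodeAttr attr,
                                       const std::vector<ir::Tensor> inputs,
                                       common::Type out_type,
                                       const common::Target target) {
  ir::PackedFunc fcompute;   // compute body elided in this sketch
  ir::PackedFunc fschedule;  // schedule body elided in this sketch
  auto strategy = std::make_shared<OpStrategy>();
  strategy->AddImpl(fcompute, fschedule, "my.strategy.x86", 10);  // name and priority as in the test
  return strategy;
}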
6 changes: 1 addition & 5 deletions cinn/hlir/framework/op_test.cc
@@ -3,20 +3,16 @@
#include <gtest/gtest.h>

#include <functional>

#include <string>

#include "cinn/cinn.h"

#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/pe/broadcast.h"

namespace cinn {
namespace hlir {
namespace framework {
using StrategyFunction = std::function<std::shared_ptr<OpStrategy>(
const NodeAttr, const std::vector<ir::Tensor>, common::Type, const common::Target)>;

using CCompute = std::function<std::shared_ptr<ir::Tensor>(const std::vector<ir::Tensor>)>;

@@ -49,7 +45,7 @@ std::shared_ptr<OpStrategy> StrategyTest(const NodeAttr &attr,
};
ir::PackedFunc fschedule(schedule_body);

std::shared_ptr<OpStrategy> strategy = std::make_shared<OpStrategy>();
auto strategy = std::make_shared<OpStrategy>();

if (target.arch == common::Target::Arch ::X86) {
strategy->AddImpl(fcompute, fschedule, "test.strategy.x86", 10);
38 changes: 38 additions & 0 deletions cinn/hlir/framework/pass.cc
@@ -0,0 +1,38 @@
#include "cinn/hlir/framework/pass.h"

namespace cinn {
namespace hlir {
namespace framework {

void ApplyPasses(Graph* g, const std::vector<std::string>& passes) {
std::vector<const PassFunctionRegister*> fpass;
for (auto& name : passes) {
auto* reg = Registry<PassFunctionRegister>::Global()->Find(name);
CHECK(reg) << "Cannot find pass " << name << " in the registry";
fpass.push_back(reg);
}
for (auto* r : fpass) {
for (auto& dep : r->graph_attr_dependency) {
CHECK_NE(g->attrs.count(dep), 0) << "To apply pass [" << r->name << "], Graph's attribute [" << dep
<< "] is required, but it is not available.";
if (g->attrs.count(dep) == 0) {
auto* pass_dep = FindPassDep(dep);
CHECK(!pass_dep) << "And the attribute is provided by pass [" << pass_dep->name << "].";
}
}
r->body(g);
}
}

const PassFunctionRegister* FindPassDep(const std::string& attr_name) {
for (auto* r : Registry<PassFunctionRegister>::Global()->List()) {
for (auto& s : r->graph_attr_targets) {
if (s == attr_name) return r;
}
}
return nullptr;
}

} // namespace framework
} // namespace hlir
} // namespace cinn
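
A minimal driving sketch, mirroring the updated infershape_pass_test.cc in this commit; `prog` is assumed to be an already assembled frontend::Program:

// Sketch only: build a Graph from a frontend program, then run a pass pipeline on it.
std::unique_ptr<Graph> graph(new Graph(prog));
ApplyPasses(graph.get(), {"InferShape"});  // or ApplyPass(graph.get(), "InferShape") for a single pass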
29 changes: 2 additions & 27 deletions cinn/hlir/framework/pass.h
@@ -77,40 +77,15 @@ class PassFunctionRegister : public FunctionRegEntryBase<PassFunctionRegister, P
}
};

const PassFunctionRegister* FindPassDep(const std::string& attr_name) {
for (auto* r : Registry<PassFunctionRegister>::Global()->List()) {
for (auto& s : r->graph_attr_targets) {
if (s == attr_name) return r;
}
}
return nullptr;
}
const PassFunctionRegister* FindPassDep(const std::string& attr_name);

/**
* \brief Apply a sequence of passes on a graph.
* @param g The input graph to apply passes on.
* @param passes The sequence of pass.
* @return The graph after being modified by the passes.
*/
void ApplyPasses(Graph* g, const std::vector<std::string>& passes) {
std::vector<const PassFunctionRegister*> fpass;
for (auto& name : passes) {
auto* reg = Registry<PassFunctionRegister>::Global()->Find(name);
CHECK(reg) << "Cannot find pass " << name << " in the registry";
fpass.push_back(reg);
}
for (auto* r : fpass) {
for (auto& dep : r->graph_attr_dependency) {
CHECK_NE(g->attrs.count(dep), 0) << "To apply pass [" << r->name << "], Graph's attribute [" << dep
<< "] is required, but it is not available.";
if (g->attrs.count(dep) == 0) {
auto* pass_dep = FindPassDep(dep);
CHECK(!pass_dep) << "And the attribute is provided by pass [" << pass_dep->name << "].";
}
}
r->body(g);
}
}
void ApplyPasses(Graph* g, const std::vector<std::string>& passes);

// Apply a single pass on a graph.
inline void ApplyPass(Graph* g, const std::string& pass) { return ApplyPasses(g, {pass}); }