Skip to content

Commit

Permalink
Add graph_compiler functions and infershape pass (PaddlePaddle#175)
Browse files Browse the repository at this point in the history
* Add graph_compiler functions and infershape pass

* improve infershape_pass test

* add check_type() for object class
  • Loading branch information
haozech committed Aug 18, 2020
1 parent 63ebbb8 commit e57aad1
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 47 deletions.
9 changes: 9 additions & 0 deletions cinn/common/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ struct Object {
return static_cast<const T*>(this);
}

//! Check whether this object's dynamic type is exactly \p T.
//! Compares the runtime type string from type_info() against the
//! static T::__type_info__ tag.
//! \return true iff the two type-info strings are identical.
template <typename T>
bool check_type() const {
  // strcmp == 0 means the type strings match; return the comparison directly.
  return std::strcmp(type_info(), T::__type_info__) == 0;
}

//! The reference count, which make all the derived type able to share.
mutable RefCount __ref_count__;
};
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ cc_test(test_hlir_framework_scope SRCS scope_test.cc DEPS core)
cc_test(test_hlir_framework_instruction SRCS instruction_test.cc DEPS core)
cc_test(test_hlir_framework_op SRCS op_test.cc DEPS core)
cc_test(test_hlir_framework_print_graph_pass SRCS print_graph_pass_test.cc DEPS core)

cc_test(test_hlir_framework_infershape_pass SRCS infershape_pass_test.cc DEPS core)

foreach(cpp ${srcs})
set(core_src
Expand Down
44 changes: 40 additions & 4 deletions cinn/hlir/framework/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <vector>

#include "cinn/common/graph_utils.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/node.h"

namespace cinn {
Expand All @@ -18,17 +19,52 @@ namespace framework {
*/
class Graph : public cinn::common::Graph {
public:
//! Build the hlir graph from a frontend program: one op Node per program
//! instruction and one NodeData per distinct input/output variable.
//! Variable shapes are collected into the "infer_shape" graph attribute
//! (a map from variable id to shape) for later passes.
explicit Graph(frontend::Program prog) {
  std::unordered_map<std::string, std::vector<int>> shape_dict;
  int counter = 0;
  for (size_t i = 0; i < prog.size(); i++) {
    auto temp = prog[i];
    Node* node_tmp =
        new Node(Operator::Get(temp->op_type), temp->op_type, temp->op_type + "_" + std::to_string(counter++));
    std::shared_ptr<Node> node_ptr(node_tmp);
    node_tmp->attrs.attr_store = temp->attrs;
    for (frontend::Variable j : temp->inputs) {
      // Guard against RetriveNode returning null for a not-yet-registered id
      // before downcasting (the original dereferenced it unconditionally).
      auto* retrived        = this->RetriveNode(j->id);
      NodeData* input_data  = retrived ? retrived->as<NodeData>() : nullptr;
      if (!input_data) {
        shape_dict[j->id] = j->shape;
        input_data = new NodeData(nullptr, 0, 0, j->id);
        input_data->LinkTo(node_tmp);
        this->RegisterNode(j->id, input_data);
      } else {
        input_data->LinkTo(node_tmp);
      }
    }
    // BUGFIX: out_idx must persist across the outputs loop so each output of
    // the node receives a distinct output index (0, 1, 2, ...). Previously it
    // was re-initialized to 0 on every iteration, making `out_idx++` a no-op.
    int out_idx = 0;
    for (frontend::Variable j : temp->outputs) {
      NodeData* output_data = new NodeData(node_ptr, out_idx++, 0, j->id);
      node_tmp->LinkTo(output_data);
      this->RegisterNode(j->id, output_data);
    }
    this->RegisterNode(node_tmp->id(), node_tmp);
  }
  this->attrs["infer_shape"] = std::make_shared<std::any>(shape_dict);
}

/** \brief outputs of the computation graph. */
std::vector<NodeData*> outputs;

/** \brief attributes of a graph */
std::unordered_map<std::string, std::shared_ptr<std::any>> attrs;

//! Typed convenience overloads that forward to the base common::Graph
//! registration, erasing the concrete node type to common::GraphNode.
//! NOTE: the flattened diff previously left a duplicate definition of the
//! (size_t, Node*) overload, which is a redefinition error; one copy kept.
void RegisterNode(size_t key, Node* node) { this->common::Graph::RegisterNode(key, node->as<common::GraphNode>()); }
void RegisterNode(size_t key, NodeData* node) {
  this->common::Graph::RegisterNode(key, node->as<common::GraphNode>());
}
void RegisterNode(const std::string& key, Node* node) {
  this->common::Graph::RegisterNode(key, node->as<common::GraphNode>());
}
void RegisterNode(const std::string& key, NodeData* node) {
  this->common::Graph::RegisterNode(key, node->as<common::GraphNode>());
}

/**
Expand Down
49 changes: 46 additions & 3 deletions cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
#include "cinn/hlir/framework/graph_compiler.h"

#include <unordered_map>
#include "cinn/hlir/framework/instruction.h"

namespace cinn {
namespace hlir {
namespace framework {
// Signature of an op's "CINNStrategy" attribute: maps (node attributes,
// input tensors, output type, target) to the op's compute/schedule strategy.
using StrategyFunction = std::function<std::shared_ptr<OpStrategy>(
const NodeAttr, const std::vector<ir::Tensor>, common::Type, const common::Target)>;

std::unique_ptr<Program> GraphCompiler::Build() {
auto [nodes, edges] = graph_->topological_order();
for (auto n : nodes) {
auto* node = n->safe_as<Node>();
if (node) {
auto lowered_func = GetOpFunc(node->op());
auto lowered_func = GetOpFunc(node);
m_builder_.AddFunction(lowered_func);
}
}
Expand All @@ -32,8 +35,8 @@ std::vector<std::unique_ptr<Instruction>> GraphCompiler::BuildInstructions() {
auto* node = n->safe_as<Node>();
if (node) {
auto instr = std::unique_ptr<Instruction>(
new Instruction(target_, scope_.get(), OpGetInputNames(node->op()), OpGetOutputNames(node->op())));
auto* fn = compiler_->Lookup(GenOpFuncName(node->op()));
new Instruction(target_, scope_.get(), OpGetInputNames(node), OpGetOutputNames(node)));
auto* fn = compiler_->Lookup(GenOpFuncName(node));
CHECK(fn);
instr->SetLoweredFunc(fn);
instructions.push_back(std::move(instr));
Expand All @@ -42,6 +45,46 @@ std::vector<std::unique_ptr<Instruction>> GraphCompiler::BuildInstructions() {
return instructions;
}

// Lower a single op node to an ir::LoweredFunc: build float placeholders for
// the node's inputs using shapes from the "infer_shape" graph attribute, run
// the op's CINN strategy (compute then schedule), and lower the combined
// input/output tensors under the node's id.
ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) {
  auto strategy = Operator::GetAttr<StrategyFunction>("CINNStrategy");
  // Shapes previously computed and stored on the graph (see InferShape pass).
  auto res = graph_->GetAttr<std::unordered_map<std::string, std::vector<int>>>("infer_shape");
  std::vector<ir::Tensor> inputs;
  std::vector<common::CINNValue> cinn_inputs;
  for (auto i : node->inlinks()) {
    std::string input_id      = i->source()->as<NodeData>()->id();
    std::vector<int> in_shape = res[input_id];
    lang::Placeholder<float> temp(input_id, in_shape);
    inputs.push_back(temp);
    cinn_inputs.push_back(common::CINNValue(temp));
  }
  // NOTE(review): `type` is default-constructed before being handed to the
  // strategy — presumably the strategy tolerates an empty type; confirm.
  common::Type type;
  auto impl = SelectImpl(strategy[node->op()](node->attrs, inputs, type, target_));

  common::CINNValuePackShared C = impl->fcompute(common::CINNValuePack::Make(cinn_inputs));
  C = impl->fschedule(C);
  // Append the strategy's output tensors so Lower sees inputs + outputs.
  for (int i = 0; i < C.get()->size(); i++) {
    ir::Expr temp = C[i];
    inputs.push_back(temp.as_tensor_ref());
  }
  // BUGFIX: the original fell off the end of this non-void function after
  // computing the lowered func, which is undefined behavior.
  return Lower(node->id(), inputs);
}

// Collect the ids of every NodeData feeding \p node, in inlink order.
std::vector<std::string> GraphCompiler::OpGetInputNames(const Node* node) const {
  std::vector<std::string> input_names;
  for (auto link : node->inlinks()) {
    input_names.push_back(link->source()->as<NodeData>()->id());
  }
  return input_names;
}

// Collect the ids of every NodeData produced by \p node, in outlink order.
std::vector<std::string> GraphCompiler::OpGetOutputNames(const Node* node) const {
  std::vector<std::string> output_names;
  for (auto link : node->outlinks()) {
    output_names.push_back(link->sink()->as<NodeData>()->id());
  }
  return output_names;
}

} // namespace framework
} // namespace hlir
} // namespace cinn
11 changes: 5 additions & 6 deletions cinn/hlir/framework/graph_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,16 @@ class GraphCompiler final {
std::unique_ptr<Program> Build();

private:
// TODO(haozech) add implementation
ir::LoweredFunc GetOpFunc(const Operator* op) { CINN_NOT_IMPLEMENTED; }
ir::LoweredFunc GetOpFunc(const Node* node);

std::string GenOpFuncName(const Operator* op) const {
return "fn_" + op->name + "_" + std::to_string(op->get_index());
std::string GenOpFuncName(const Node* node) const {
return "fn_" + node->op()->name + "_" + std::to_string(node->op()->get_index());
}

// TODO(haozech) add implementation
std::vector<std::string> OpGetInputNames(const Operator* op) const { CINN_NOT_IMPLEMENTED; }
std::vector<std::string> OpGetInputNames(const Node* node) const;
// TODO(haozech) add implementation
std::vector<std::string> OpGetOutputNames(const Operator* op) const { CINN_NOT_IMPLEMENTED; }
std::vector<std::string> OpGetOutputNames(const Node* node) const;

std::vector<std::unique_ptr<Instruction>> BuildInstructions();

Expand Down
87 changes: 87 additions & 0 deletions cinn/hlir/framework/infershape_pass_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include <gtest/gtest.h>

#include <any>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/ir/packed_func.h"

namespace cinn {
namespace hlir {
namespace framework {

//! Infer-shape function for the test `add` op: the output shape equals the
//! first input's shape. Takes the shapes by const reference to avoid deep-
//! copying the nested vectors on every call (still convertible to the
//! by-value std::function signature used at registration).
std::vector<std::vector<int>> AddInferShape(const std::vector<std::vector<int>>& inputs_shape) {
  CHECK(inputs_shape.size() && inputs_shape[0].size()) << "The input's shape size is 0! Please check again.";
  std::vector<std::vector<int>> res{inputs_shape[0]};
  return res;
}

// Register a minimal `add` op for this test: two inputs, one output, with
// AddInferShape wired up as its "infer_shape" attribute.
CINN_REGISTER_OP(add)
.describe("test of op Add")
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<std::string>("nick_name", "plus")
.set_attr<std::function<std::vector<std::vector<int>>(std::vector<std::vector<int>>)>>("infer_shape", AddInferShape)
.set_support_level(4);

// Propagate shapes through the graph in topological order: for each op node,
// gather its input shapes from the graph's "infer_shape" dict, call the op's
// registered infer-shape function, and record the resulting output shapes.
// The updated dict is written back as the "infer_shape" graph attribute.
void InferShapePass(Graph* src) {
// Seed map: variable id -> shape, initially filled by the Graph constructor.
auto res = src->GetAttr<std::unordered_map<std::string, std::vector<int>>>("infer_shape");
auto store_node = std::get<0>(src->topological_order());
// Per-op infer-shape functions registered via the op's "infer_shape" attr.
auto op_infershape =
Operator::GetAttr<std::function<std::vector<std::vector<int>>(std::vector<std::vector<int>>)>>("infer_shape");
for (auto i : store_node) {
// Only op nodes participate; NodeData entries are skipped.
if (i->check_type<Node>()) {
std::vector<std::vector<int>> inputs_shape;
for (auto j : i->inlinks()) {
inputs_shape.push_back(res[j->source()->safe_as<NodeData>()->id()]);
}
auto out_shape = op_infershape[i->safe_as<Node>()->op()](inputs_shape);
int counter = 0;
// The op must produce exactly one shape per output edge.
CHECK_EQ(i->outlinks().size(), out_shape.size())
<< "The output number of node " << i->id() << " is " << i->outlinks().size()
<< " , which is different with the output shape size " << out_shape.size() << " . And the op type is "
<< i->safe_as<Node>()->op()->name;
for (auto j : i->outlinks()) {
res[j->sink()->safe_as<NodeData>()->id()] = out_shape[counter++];
}
}
}
// Publish the completed shape dictionary back onto the graph.
src->attrs["infer_shape"] = std::make_shared<std::any>(res);
}

// Register the pass so tests can run it via ApplyPass(g, "InferShape").
CINN_REGISTER_PASS(InferShape)
.describe("This pass infer the shape of tensor and save to g.attrs[\"infer_shape\"].")
.set_change_structure(false)
.provide_graph_attr("infer_shape")
.set_body(InferShapePass);

// End-to-end check of the InferShape pass: build a three-op add program with
// {100, 32} inputs, run the pass, and verify every variable's inferred shape.
TEST(Operator, GetAttr) {
  frontend::Program prog;
  frontend::Variable a("a");
  frontend::Variable b("b");
  a->shape = {100, 32};
  b->shape = {100, 32};
  auto c = prog.add(a, b);
  auto d = prog.add(c, b);
  auto e = prog.add(c, d);
  ASSERT_EQ(prog.size(), 3);
  // Own the graph locally; the original leaked a raw `new Graph(prog)`.
  auto g = std::make_unique<Graph>(prog);
  ApplyPass(g.get(), "InferShape");
  auto shape_dict = g->GetAttr<std::unordered_map<std::string, std::vector<int>>>("infer_shape");
  // All variables in this program are elementwise adds of {100, 32} inputs,
  // so every inferred shape must equal {100, 32}.
  const std::vector<int> correct_shape{100, 32};
  for (const auto& kv : shape_dict) {
    LOG(INFO) << "Var id is: " << kv.first << " and Var shape is: ";
    for (auto dim : kv.second) {
      LOG(INFO) << dim << " ";
    }
    // Check the rank before indexing (the original hard-coded [0]/[1] and
    // never used correct_shape at all).
    ASSERT_EQ(kv.second.size(), correct_shape.size()) << "The infered shape rank is wrong.";
    for (size_t k = 0; k < correct_shape.size(); ++k) {
      CHECK_EQ(kv.second[k], correct_shape[k]) << "The infered shape is wrong.";
    }
  }
}

} // namespace framework
} // namespace hlir
} // namespace cinn
15 changes: 10 additions & 5 deletions cinn/hlir/framework/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ using NodePtr = std::shared_ptr<Node>;
* and other parameters like axis.
*/
struct NodeAttr {
using attr_t = std::variant<int, float, std::string, std::vector<int>, std::vector<float>, std::vector<std::string>>;

/**
* \brief The operator this node uses.
*/
Expand All @@ -36,7 +38,7 @@ struct NodeAttr {
/**
* \brief The attributes stored as string in dictionary.
*/
std::unordered_map<std::string, std::string> attr_store;
std::unordered_map<std::string, attr_t> attr_store;
};

/**
Expand All @@ -50,7 +52,7 @@ class Node : public common::GraphNode {
this->attrs.node_name = name;
this->id_ = std::move(id);
}

const char *type_info() const override { return __type_info__; }
std::tuple<common::GraphEdge *, common::GraphEdge *> LinkTo(NodeData *other);
/**
* \brief Get the unique id of this NodeData.
Expand Down Expand Up @@ -88,6 +90,8 @@ class Node : public common::GraphNode {
* \brief NodeData represents the output data from an operator.
*/
class NodeData : public common::GraphNode {
using attr_t = std::variant<int, float, std::string, std::vector<int>, std::vector<float>, std::vector<std::string>>;

public:
NodeData(NodePtr node, uint32_t index, uint32_t version, std::string id)
: source_node(std::move(node)), output_index(index), version(version), id_(std::move(id)) {}
Expand All @@ -99,8 +103,8 @@ class NodeData : public common::GraphNode {
const char *op_name,
std::string node_name,
std::vector<NodeData> inputs,
std::string id = nullptr,
std::unordered_map<std::string, std::string> attrs = std::unordered_map<std::string, std::string>()) {
std::string id = nullptr,
std::unordered_map<std::string, attr_t> attrs = std::unordered_map<std::string, attr_t>()) {
auto res = std::make_shared<NodeData>();
res->id_ = std::move(id);
res->source_node = Node::Create();
Expand All @@ -110,6 +114,7 @@ class NodeData : public common::GraphNode {
return res;
}

const char *type_info() const override { return __type_info__; }
/**
* \brief Get the unique id of this NodeData.
*/
Expand All @@ -135,7 +140,7 @@ class NodeData : public common::GraphNode {
*/
uint32_t version;

static constexpr char *__type_info__ = "hlir_framework_node";
static constexpr char *__type_info__ = "hlir_framework_nodedata";

private:
/**
Expand Down
Loading

0 comments on commit e57aad1

Please sign in to comment.