Improve test of strategy in op_test (PaddlePaddle#172)
* add primitive layer. add function primitive::add

* optimize the primitive::add function

* change the name and struct of add function

* change opfunction to pe.

* fix codestyle and debug

* fix codestyle

* add op module. test ci

* fix bug

* fix bug

* fix CMakelist

* add class graph and class node. early version

* fix bug and improve codestyle

* fix codestyle

* fix codestyle

* fix bug

* adjust codestyle for comment

* fix override function

* add pass module

* add example pass and unittest

* fix codestyle

* fix namespace

* add strategy module

* fix function name problem

* fix codestyle

* fix bug

* add strategy test

* using vector for input and output in compute and schedule

* fix

* fix type_info() function
haozech committed Aug 18, 2020
1 parent 4ace494 commit 63ebbb8
Showing 9 changed files with 111 additions and 39 deletions.
2 changes: 1 addition & 1 deletion cinn/common/cinn_value.cc
@@ -117,7 +117,7 @@ void CINNValuePack::AddValue(const CINNValue &value) {
values_.push_back(value);
}
void CINNValuePack::Clear() { values_.clear(); }
const char *CINNValuePack::type_info() const { return "CINNValuePack"; }
const char *CINNValuePack::type_info() const { return __type_info__; }

CINNValue &CINNValue::operator=(int32_t value) {
*this = CINNValue(value);
3 changes: 2 additions & 1 deletion cinn/common/cinn_value.h
@@ -60,10 +60,11 @@ struct CINNValuePack : public common::Object {
private:
CINNValuePack() = default;
std::vector<CINNValue> values_;
static constexpr char* __type_info__ = "CINNValuePack";
};

struct CINNValuePackShared : public Shared<CINNValuePack> {
CINNValuePackShared(CINNValuePack* ptr) : Shared<CINNValuePack>(ptr) {}
explicit CINNValuePackShared(CINNValuePack* ptr) : Shared<CINNValuePack>(ptr) {}

CINNValue& operator[](int offset) { return (*operator->())[offset]; }
const CINNValue& operator[](int offset) const { return (*operator->())[offset]; }
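Several files in this commit apply the same refactor: the string literal previously returned inline from type_info() is hoisted into a static __type_info__ member of the class. Below is a minimal standalone sketch of that pattern, not CINN's actual headers; the Object and ValuePack types here are simplified stand-ins, and the member is declared const char* (a string literal cannot bind to a plain char* in standard C++).

#include <iostream>

// Hypothetical stand-in for common::Object: subclasses report a stable
// type name through a virtual type_info() hook.
struct Object {
  virtual const char* type_info() const = 0;
  virtual ~Object() = default;
};

struct ValuePack : public Object {
  // The type name now lives in one place instead of an inline literal.
  const char* type_info() const override { return __type_info__; }

 private:
  static constexpr const char* __type_info__ = "CINNValuePack";
};

int main() {
  ValuePack pack;
  std::cout << pack.type_info() << "\n";  // prints "CINNValuePack"
  return 0;
}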
3 changes: 2 additions & 1 deletion cinn/common/graph_utils.h
@@ -35,13 +35,14 @@ class GraphEdge : public Object {

GraphNode* source() const { return source_; }
GraphNode* sink() const { return sink_; }
const char* type_info() const override { return "graph_edge"; }
const char* type_info() const override { return __type_info__; }

private:
//! Source of this edge.
GraphNode* source_{};
//! End of this edge.
GraphNode* sink_{};
static constexpr char* __type_info__ = "graph_edge";
};

struct GraphEdgeCompare {
2 changes: 1 addition & 1 deletion cinn/frontend/syntax.h
@@ -68,7 +68,7 @@ struct _Instruction_ : public common::Object {

const char* type_info() const override { return __type_info__; }

static constexpr const char* __type_info__ = "cinn_frontend_instruction";
static constexpr char* __type_info__ = "cinn_frontend_instruction";
};

struct Instruction : public common::Shared<_Instruction_> {
47 changes: 33 additions & 14 deletions cinn/hlir/framework/op_strategy.h
@@ -52,17 +52,17 @@ class OpImpl : public common::Object {
return nullptr;
}

const char* type_info() const override { return _type_key; }
const char* type_info() const override { return __type_info__; }

private:
static constexpr char* _type_key = "OpImplementation";
static constexpr char* __type_info__ = "OpImplementation";
};

//! Specialized implementations for operators under certain conditions.
class OpSpec : public common::Object {
public:
//! List of implementations.
std::vector<OpImpl*> implementations;
std::vector<std::shared_ptr<OpImpl>> implementations;

/** \brief Condition to enable the specialization.
* Could be undefined to represent generic case.
@@ -71,10 +71,10 @@
*/
std::string condition;

const char* type_info() const override { return _type_key; }
const char* type_info() const override { return __type_info__; }

void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) {
auto n = make_shared<OpImpl>();
auto n = std::make_shared<OpImpl>();
n->fcompute = fcompute;
n->fschedule = fschedule;
n->name = std::move(name);
@@ -83,15 +83,15 @@ }
}

private:
static constexpr char* _type_key = "OpSpecialization";
static constexpr char* __type_info__ = "OpSpecialization";
};

//! Operator strategy class.
class OpStrategy : public common::Object {
public:
const char* type_info() const override { return "CINNOpStrategy"; }
const char* type_info() const override { return __type_info__; }
//! List of operator specializations.
std::vector<OpSpec*> specializations;
std::vector<std::shared_ptr<OpSpec>> specializations;

/**
* \brief Add an implementation.
@@ -103,21 +103,40 @@
void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) {
//! TODO(haozech) : here curr_cond should get the condition from outside.
//! Expected : auto curr_cond = SpecializedCondition::Current();
std::string curr_cond = "current_condition";
OpSpec* op_spec;
for (OpSpec* op_spec : specializations) {
if (op_spec->condition == curr_cond) {
std::string curr_condition = "default";
for (auto op_spec : specializations) {
if (op_spec->condition == curr_condition) {
op_spec->AddImpl(fcompute, fschedule, std::move(name), plevel);
return;
}
}
OpSpec* n = make_shared<OpSpec>();
n->condition = curr_cond;
std::shared_ptr<OpSpec> n = std::make_shared<OpSpec>();
n->condition = curr_condition;
n->AddImpl(fcompute, fschedule, std::move(name), plevel);
this->specializations.push_back(n);
}

private:
static constexpr char* __type_info__ = "OpStrategy";
};

std::shared_ptr<OpImpl> SelectImpl(std::shared_ptr<OpStrategy> strategy) {
//! should get the host info from global environment.
std::string curr_condition = "default";
std::shared_ptr<OpImpl> res = nullptr;
for (auto spec : strategy->specializations) {
if (spec->condition == "default") {
for (auto i : spec->implementations) {
if (!res || res->plevel < i->plevel) {
res = i;
}
}
}
}
CHECK(res) << "There is no available strategy implementation! SelectImpl failed!";
return res;
}

} // namespace framework
} // namespace hlir
} // namespace cinn
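The new SelectImpl helper walks the strategy's "default" specialization and keeps the implementation with the highest plevel. The following is a minimal standalone sketch of that selection rule under simplified, hypothetical types (Impl, Spec, Strategy are stand-ins, not the CINN classes above):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins for OpImpl / OpSpec / OpStrategy, only to illustrate
// the rule: among the "default" specialization's implementations, the one
// with the highest plevel wins.
struct Impl {
  std::string name;
  int plevel;
};
struct Spec {
  std::string condition;
  std::vector<std::shared_ptr<Impl>> implementations;
};
struct Strategy {
  std::vector<std::shared_ptr<Spec>> specializations;
};

std::shared_ptr<Impl> Select(const Strategy& strategy) {
  std::shared_ptr<Impl> res;
  for (const auto& spec : strategy.specializations) {
    if (spec->condition != "default") continue;
    for (const auto& impl : spec->implementations) {
      if (!res || res->plevel < impl->plevel) res = impl;
    }
  }
  return res;  // null if no "default" specialization has an implementation
}

int main() {
  Strategy s;
  auto spec = std::make_shared<Spec>();
  spec->condition = "default";
  spec->implementations = {std::make_shared<Impl>(Impl{"low", 5}),
                           std::make_shared<Impl>(Impl{"high", 10})};
  s.specializations.push_back(spec);
  std::cout << Select(s)->name << "\n";  // prints "high"
  return 0;
}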
82 changes: 64 additions & 18 deletions cinn/hlir/framework/op_test.cc
@@ -2,15 +2,23 @@

#include <gtest/gtest.h>

#include <functional>

#include <string>

#include "cinn/cinn.h"

#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/pe/broadcast.h"

namespace cinn {
namespace hlir {
namespace framework {
using StrategyFunction = std::function<std::shared_ptr<OpStrategy>(
const NodeAttr, const std::vector<ir::Tensor>, common::Type, const common::Target)>;

using CCompute = std::function<std::shared_ptr<ir::Tensor>(const std::vector<ir::Tensor>)>;

CINN_REGISTER_OP(add)
.describe("test of op Add")
@@ -19,30 +27,68 @@ CINN_REGISTER_OP(add)
.set_attr<std::string>("nick_name", "plus")
.set_support_level(4);

common::Shared<OpStrategy> GetStrategyTest() {
ir::PackedFunc::body_t body = [](ir::Args args, ir::RetValue* ret) {
Expr a = args[0];
Expr b = args[1];
(*ret) = Expr(pe::Add(a.as_tensor_ref(), b.as_tensor_ref(), "C").get());
std::shared_ptr<OpStrategy> StrategyTest(const NodeAttr &attr,
const std::vector<ir::Tensor> &inputs,
common::Type out_type,
const common::Target &target) {
ir::PackedFunc::body_t compute_body = [](ir::Args args, ir::RetValue *ret) {
common::CINNValuePackShared a = args[0];
ir::Expr A = a[0];
ir::Expr B = a[1];
*ret = common::CINNValuePack::Make(
{common::CINNValue(ir::Expr(pe::Add(A.as_tensor_ref(), B.as_tensor_ref(), "C").get()))});
};
ir::PackedFunc fcompute(body);
// TODO(haozech): fschedule should be an instance of pe::schedule...
ir::PackedFunc fschedule;
common::Shared<OpStrategy> strategy(make_shared<OpStrategy>());
//! To build more complex strategy, we can add more than 1
//! implementations to one Opstrategy, with different plevel.
strategy->AddImpl(fcompute, fschedule, "test.strategy", 10);
ir::PackedFunc fcompute(compute_body);

ir::PackedFunc::body_t schedule_body = [](ir::Args args, ir::RetValue *ret) {
common::CINNValuePackShared a = args[0];
ir::Expr A = a[0];
A.as_tensor_ref()->stage()->Vectorize(1, 16);
A.as_tensor_ref()->stage()->Unroll(1);
*ret = common::CINNValuePack::Make({common::CINNValue(A)});
};
ir::PackedFunc fschedule(schedule_body);

std::shared_ptr<OpStrategy> strategy = std::make_shared<OpStrategy>();

if (target.arch == common::Target::Arch ::X86) {
strategy->AddImpl(fcompute, fschedule, "test.strategy.x86", 10);
} else {
strategy->AddImpl(fcompute, fschedule, "test.strategy.else", 10);
}
strategy->AddImpl(fcompute, fschedule, "test.strategy.lowPlevel.x86", 5);
return strategy;
}

TEST(Operator, GetAttr) {
auto add = Operator::Get("add");
auto test_strategy = GetStrategyTest();
Operator temp = *add;
temp.set_attr<common::Shared<OpStrategy>>("CINNStrategy", test_strategy);
auto add = Operator::Get("add");
Operator temp = *add;
temp.set_attr<StrategyFunction>("CINNStrategy", StrategyTest);
auto nick = Operator::GetAttr<std::string>("nick_name");
auto strategy = Operator::GetAttr<common::Shared<OpStrategy>>("CINNStrategy");
ASSERT_EQ(strategy[add]->specializations[0]->implementations[0]->name, "test.strategy");
auto strategy = Operator::GetAttr<StrategyFunction>("CINNStrategy");

Expr M(100), N(32);
Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});

NodeAttr attr;
std::vector<ir::Tensor> inputs{A, B};
common::Type type;
common::Target target;
target.arch = common::Target::Arch::X86;
auto impl = SelectImpl(strategy[add](attr, inputs, type, target));

common::CINNValuePackShared cinn_input = common::CINNValuePack::Make({common::CINNValue(A), common::CINNValue(B)});
common::CINNValuePackShared C = impl->fcompute(cinn_input);
C = impl->fschedule(C);
for (int i = 0; i < C.get()->size(); i++) {
ir::Expr temp = C[i];
inputs.push_back(temp.as_tensor_ref());
}
auto func = Lower("add1", inputs);
LOG(INFO) << "Test Strategy Codegen:\n" << func;

ASSERT_EQ(impl->name, "test.strategy.x86");
ASSERT_EQ(add->description, "test of op Add");
ASSERT_EQ(nick[add], "plus");
}
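The test's key change is that "CINNStrategy" is no longer a fixed strategy object attached to the operator; it is a function that builds a strategy for the concrete attributes, inputs, and target, which lets the x86 and non-x86 branches register different implementation names. A rough standalone sketch of that idea follows; Target, Strategy, and AddStrategy here are hypothetical simplifications, not the CINN registration API.

#include <functional>
#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-ins: the attribute stores a factory function rather
// than a prebuilt strategy object.
struct Target {
  enum class Arch { X86, Other };
  Arch arch = Arch::Other;
};
struct Strategy {
  std::string chosen_impl;
};

using StrategyFn = std::function<std::shared_ptr<Strategy>(const Target&)>;

std::shared_ptr<Strategy> AddStrategy(const Target& target) {
  auto s = std::make_shared<Strategy>();
  // Pick an implementation name per target arch, mirroring the x86/else
  // branches in StrategyTest above.
  s->chosen_impl = (target.arch == Target::Arch::X86) ? "test.strategy.x86"
                                                      : "test.strategy.else";
  return s;
}

int main() {
  StrategyFn fn = AddStrategy;  // what set_attr<StrategyFunction>(...) would store
  Target t;
  t.arch = Target::Arch::X86;
  std::cout << fn(t)->chosen_impl << "\n";  // prints "test.strategy.x86"
  return 0;
}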
5 changes: 4 additions & 1 deletion cinn/hlir/framework/schedule.h
@@ -14,7 +14,7 @@ namespace framework {
*/
class Schedule : public common::Object {
public:
const char* type_info() const override { return "CINNSchedule"; }
const char* type_info() const override { return __type_info__; }

/**
* \brief Get the stage corresponds to the op
@@ -36,6 +36,9 @@

//! map of original operation to the stages
std::unordered_map<std::string, ir::Tensor> stage_map;

private:
static constexpr char* __type_info__ = "CINNSchedule";
};
} // namespace framework
} // namespace hlir
2 changes: 1 addition & 1 deletion cinn/ir/node.h
@@ -141,7 +141,7 @@ class IrNode : public common::Object {
const char* type_info() const override { return __type_info__; }

protected:
constexpr static const char* __type_info__ = "IRNode";
static constexpr char* __type_info__ = "IRNode";
Type type_;
};

4 changes: 3 additions & 1 deletion cinn/poly/stage.h
@@ -219,7 +219,7 @@ class Stage : public Object {
//! Get the statements.
std::vector<std::string> input_statements() const;

virtual const char* type_info() const { return "Status"; }
virtual const char* type_info() const { return __type_info__; }

inline const ir::VectorizeInfo& vectorize_info() const { return vectorize_info_; }
inline const std::set<int>& unroll_info() const { return unroll_info_; }
@@ -280,6 +280,8 @@ class Stage : public Object {

std::set<int> locked_axis_;

static constexpr char* __type_info__ = "Status";

friend isl_map* __isl_give GatherAccesses(Stage* stage, const std::string& tensor_name);
friend class PolyGroupScheduler;
};
