Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#20 from Superjomn/fea/add-ast-gen
Browse files Browse the repository at this point in the history
add ast build
  • Loading branch information
Superjomn committed Feb 6, 2020
2 parents ea6cdb5 + bebbacb commit b20f50f
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 0 deletions.
1 change: 1 addition & 0 deletions cinn/poly/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ cc_library(poly SRCS

cc_test(test_poly_element SRCS element_test.cc DEPS poly)
cc_test(test_schedule SRCS schedule_test.cc DEPS poly)
cc_test(test_ast_gen SRCS ast_gen_test.cc DEPS poly)
45 changes: 45 additions & 0 deletions cinn/poly/ast_gen.cc
Original file line number Diff line number Diff line change
@@ -1 +1,46 @@
#include "cinn/poly/ast_gen.h"

namespace cinn {
namespace poly {

isl::ast_node AstGen::operator()(const std::vector<Element> &elements, const Scheduler &scheduler) {
// Collect domains.
auto sets = utils::Map<std::vector<Element>, isl::set>(elements, [](const Element &e) { return e.domain(); });
isl::union_set domain = SetsToUnionSet(sets);

isl::ctx ctx = elements.front().domain().ctx();

// Collect schedule from scheduler.
auto schedules = scheduler.BuildSchedule();
std::vector<isl::map> maps;
for (auto &ele : elements) {
auto it = schedules.find(ele.id());
CHECK(it != std::end(schedules));
maps.push_back(it->second);
}
auto schedule = MapsToUnionMap(maps);

// Build it.
auto build = isl::ast_build::from_context(context_);
// Set iterators.
if (!iterator_names_.empty()) {
auto iterator_names = scheduler.WrapIteratorNames(iterator_names_);
isl::id_list ids = isl::manage(isl_id_list_alloc(ctx.get(), iterator_names.size()));
for (int i = 0; i < iterator_names.size(); i++) {
ids = isl::manage(isl_id_list_add(ids.release(), isl_id_alloc(ctx.get(), iterator_names[i].c_str(), nullptr)));
}
build = isl::manage(isl_ast_build_set_iterators(build.release(), ids.release()));
}

auto ast = build.node_from_schedule_map(schedule.intersect_domain(domain));
VLOG(2) << "\n" << isl_ast_node_to_C_str(ast.get());
return ast;
}

AstGen &AstGen::SetIteratorNames(const std::vector<std::string> &names) {
iterator_names_ = names;
return *this;
}

} // namespace poly
} // namespace cinn
29 changes: 29 additions & 0 deletions cinn/poly/ast_gen.h
Original file line number Diff line number Diff line change
@@ -1 +1,30 @@
#pragma once
#include <isl/cpp.h>
#include "cinn/poly/element.h"
#include "cinn/poly/isl_utils.h"
#include "cinn/poly/schedule.h"
#include "cinn/utils/functional.h"

namespace cinn {
namespace poly {

class AstGen {
public:
AstGen(const isl::set& context) : context_(context) {}

/**
* Set forloop iterator names.
* @param names
* @return AstGen itself.
*/
AstGen& SetIteratorNames(const std::vector<std::string>& names);

isl::ast_node operator()(const std::vector<Element>& elements, const Scheduler& scheduler);

private:
isl::set context_;
std::vector<std::string> iterator_names_;
};

} // namespace poly
} // namespace cinn
23 changes: 23 additions & 0 deletions cinn/poly/ast_gen_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "cinn/poly/ast_gen.h"
#include <gtest/gtest.h>

namespace cinn {
namespace poly {

TEST(ast_gen, basic) {
isl::ctx ctx(isl_ctx_alloc());
Element A(isl::set(ctx, "{ A[i,j,k]: 0<i,j,k<100 }"));
Element B(isl::set(ctx, "{ B[i,j,k]: 0<i,j,k<100 }"));

Scheduler scheduler;
scheduler.RegisterElement(A);
scheduler.RegisterElement(B);
scheduler.After(A, B, 2);

AstGen gen(isl::set(ctx, "{:}"));
gen.SetIteratorNames({"i", "j", "k"});
gen({A, B}, scheduler);
}

} // namespace poly
} // namespace cinn
18 changes: 18 additions & 0 deletions cinn/poly/isl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,23 @@ void SetDimNames(isl::set *set, const std::vector<std::string> &names) {
}
}

isl::union_map MapsToUnionMap(const std::vector<isl::map> &maps) {
CHECK(!maps.empty());
isl::union_map umap = isl::manage(isl_union_map_from_map(maps.front().copy()));
for (int i = 1; i < maps.size(); i++) {
umap = isl::manage(isl_union_map_add_map(umap.release(), maps[i].copy()));
}
return umap;
}

isl::union_set SetsToUnionSet(const std::vector<isl::set> &sets) {
CHECK(!sets.empty());
isl::union_set uset = isl::manage(isl_union_set_from_set(sets.front().copy()));
for (int i = 1; i < sets.size(); i++) {
uset = isl::manage(isl_union_set_add_set(uset.release(), sets[i].copy()));
}
return uset;
}

} // namespace poly
} // namespace cinn
4 changes: 4 additions & 0 deletions cinn/poly/isl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@ std::vector<std::string> GetDimNames(const isl::map& x, isl_dim_type dim_type);
void SetDimNames(isl::set* set, const std::vector<std::string>& names);
void SetDimNames(isl::map* map, isl_dim_type dim_type, const std::vector<std::string>& names);

//! Convert a list of isl::map to isl::union_map
isl::union_map MapsToUnionMap(const std::vector<isl::map>& maps);
isl::union_set SetsToUnionSet(const std::vector<isl::set>& sets);

} // namespace poly
} // namespace cinn
10 changes: 10 additions & 0 deletions cinn/poly/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,15 @@ std::map<std::string, isl::map> Scheduler::BuildSchedule() const {
return res;
}

std::vector<std::string> Scheduler::WrapIteratorNames(const std::vector<std::string> &names) const {
CHECK_EQ(names.size(), space_size());
std::vector<std::string> res;
for (int i = 0; i < space_size(); i++) {
res.push_back(""); // fake name for time space.
res.push_back(names[i]); // name for the corresponding iterator.
}
return res;
}

} // namespace poly
} // namespace cinn
14 changes: 14 additions & 0 deletions cinn/poly/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ class Scheduler {
*/
void FinalizeRegistration();

/**
* Tell whether the registration is finalized.
*/
bool finalized() const { return registration_finalized_; }

/**
* Mark this should schedule after another.
*
Expand All @@ -108,6 +113,15 @@ class Scheduler {
*/
std::map<std::string, isl::map> BuildSchedule() const;

/**
* Wrap the iterator names with time space.
* @param names the original iterator names.
* @return the iterator names with time space included.
*/
std::vector<std::string> WrapIteratorNames(const std::vector<std::string> &names) const;

int space_size() const { return space_size_; }

private:
/**
* The polyhedral schedule, any schedule is performed on it.
Expand Down

0 comments on commit b20f50f

Please sign in to comment.