Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#17 from Superjomn/fea/add-scheduler
Browse files Browse the repository at this point in the history
make scheduler works
  • Loading branch information
Superjomn committed Feb 6, 2020
2 parents fc24b37 + a0a1112 commit 7f64f9c
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 7 deletions.
15 changes: 12 additions & 3 deletions cinn/poly/element.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "cinn/poly/element.h"
#include "cinn/poly/isl_utils.h"
#include "cinn/utils/functional.h"

namespace cinn {
namespace poly {
Expand Down Expand Up @@ -38,7 +39,7 @@ Element::Element(isl::set domain) : domain_(domain) {
std::tuple<Iterator, Iterator> Element::Split(const Iterator &level, int factor) {
int offset = isl_set_find_dim_by_name(domain_.get(), isl_dim_set, level.id.c_str());
CHECK_GE(offset, 0) << "iterator " << level << " not in " << domain_;
auto dim_names = GetDimNames(domain_);
auto dim_names = GetDimNames(schedule_, isl_dim_out);

VLOG(2) << "domain: " << domain_;
VLOG(2) << "schedule: " << schedule_;
Expand All @@ -65,14 +66,18 @@ std::tuple<Iterator, Iterator> Element::Split(const Iterator &level, int factor)
}
}

Map transform(domain_.ctx(), "", from_iters, to_iters, conds, "");
Map transform(domain_.ctx(), id(), from_iters, to_iters, conds, id());
VLOG(3) << "transform: " << transform.__str__();
schedule_ = schedule_.apply_range(transform.to_isl());
auto range_dims =
utils::Map<std::vector<Iterator>, std::vector<std::string>>(to_iters, [](const Iterator &x) { return x.id; });
SetDimNames(&schedule_, isl_dim_out, range_dims);

VLOG(3) << "transform " << transform.to_isl();
VLOG(3) << "schedule after transform: " << schedule_;

std::make_tuple(outer_iter, inner_iter);
VLOG(3) << "iterators: " << outer_iter << " " << inner_iter;
return std::make_tuple(outer_iter, inner_iter);
}

void Element::Reorder(const std::vector<Iterator> &order) {}
Expand Down Expand Up @@ -107,5 +112,9 @@ std::string OuterName(const Iterator &iterator) { return OuterName(iterator.id);

const char *Element::id() const { return isl_set_get_tuple_name(domain_.get()); }

std::tuple<Iterator, Iterator> Element::Split(const std::string &level, int factor) {
return std::move(Split(Iterator(level), factor));
}

} // namespace poly
} // namespace cinn
2 changes: 2 additions & 0 deletions cinn/poly/element.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Element {
*/
std::tuple<Iterator, Iterator> //
Split(const Iterator& level, int factor);
std::tuple<Iterator, Iterator> //
Split(const std::string& level, int factor);

/**
* Reorder the iterators.
Expand Down
19 changes: 19 additions & 0 deletions cinn/poly/isl_utils.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "cinn/poly/isl_utils.h"
#include <glog/logging.h>
#include <isl/cpp.h>

namespace cinn {
Expand All @@ -20,5 +21,23 @@ std::vector<std::string> GetDimNames(const isl::map &x, isl_dim_type dim_type) {
return res;
}

void SetDimNames(isl::map *map, isl_dim_type dim_type, const std::vector<std::string> &names) {
const int dim = isl_map_dim(map->get(), dim_type);
CHECK_EQ(dim, names.size());

for (int i = 0; i < dim; i++) {
*map = isl::manage(isl_map_set_dim_name(map->release(), dim_type, i, names[i].c_str()));
}
}

void SetDimNames(isl::set *set, const std::vector<std::string> &names) {
int dim = isl_set_dim(set->get(), isl_dim_set);
CHECK_EQ(dim, names.size());

for (int i = 0; i < dim; i++) {
*set = isl::manage(isl_set_set_dim_name(set->release(), isl_dim_set, i, names[i].c_str()));
}
}

} // namespace poly
} // namespace cinn
7 changes: 5 additions & 2 deletions cinn/poly/isl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ namespace poly {

//! Get dimension names from isl containers.
// @{
std::vector<std::string> GetDimNames(const isl::set &x);
std::vector<std::string> GetDimNames(const isl::map &x, isl_dim_type dim_type);
std::vector<std::string> GetDimNames(const isl::set& x);
std::vector<std::string> GetDimNames(const isl::map& x, isl_dim_type dim_type);
// @}

void SetDimNames(isl::set* set, const std::vector<std::string>& names);
void SetDimNames(isl::map* map, isl_dim_type dim_type, const std::vector<std::string>& names);

} // namespace poly
} // namespace cinn
6 changes: 4 additions & 2 deletions cinn/poly/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ namespace poly {
struct Iterator {
std::string id;

explicit Iterator(std::string id) : id(std::move(id)) {}
explicit Iterator(const std::string& id) : id(id) {}
explicit Iterator(const Iterator& x) : id(x.id) {}
explicit Iterator(Iterator&& x) : id(std::move(x.id)) {}

friend std::ostream& operator<<(std::ostream& os, const Iterator& x);
};
Expand All @@ -25,7 +27,7 @@ struct Condition {
Iterator iterator;
std::string cond;

Condition(Iterator iterator, std::string cond) : iterator(std::move(iterator)), cond(std::move(cond)) {}
Condition(const Iterator& iterator, std::string cond) : iterator(iterator), cond(std::move(cond)) {}

friend std::ostream& operator<<(std::ostream& os, const Condition& x) {
os << x.__str__();
Expand Down
1 change: 1 addition & 0 deletions cinn/poly/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ void Scheduler::RegisterElement(const Element &x) {
CHECK(!registration_finalized_) << "element registration has been finalized.";
space_size_ = std::max(space_size_, isl_map_dim(x.schedule().get(), isl_dim_out));
VLOG(3) << "space_size: " << space_size_;
VLOG(3) << "schedule: " << x.schedule();

// Use the dimensions from element's schedule's range as the new domain dimensions because in Element, the schedule is
// like '{ S0[i,j] -> S0[i_outer, i_inner, j] }', the scheduler should schedule base on the range.
Expand Down
27 changes: 27 additions & 0 deletions cinn/poly/schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,37 @@ TEST(Schedule, basic) {

auto schedule = scheduler.BuildSchedule();

EXPECT_EQ(utils::GetStreamCnt(schedule["A"]), "{ A[i, j] -> [t0 = 0, d0 = i, t1 = 0, d1 = j] }");
EXPECT_EQ(utils::GetStreamCnt(schedule["B"]), "{ B[i, j] -> [t0 = 0, d0 = i, t1 = 1, d1 = j] }");

for (auto item : schedule) {
LOG(INFO) << item.first << " " << item.second;
}
}

TEST(Schedule, basic_with_transform) {
isl::ctx ctx(isl_ctx_alloc());
Element A(isl::set(ctx, "[]->{ A[i,j]: 0<i,j<100 }"));
Element B(isl::set(ctx, "[]->{ B[i,j]: 0<i,j<100 }"));
auto x = A.Split("i", 4);
LOG(INFO) << A.schedule();
B.Split(Iterator("j"), 6);
LOG(INFO) << B.schedule();

Scheduler scheduler;
scheduler.RegisterElement(A);
scheduler.RegisterElement(B);
scheduler.After(A, B, 1);
auto schedule = scheduler.BuildSchedule();
for (auto item : schedule) {
LOG(INFO) << item.first << " " << item.second;
}

EXPECT_EQ(utils::GetStreamCnt(schedule["A"]),
"{ A[i_outer, i_inner, j] -> [t0 = 0, d0 = i_outer, t1 = 0, d1 = i_inner, t2 = 0, d2 = j] }");
EXPECT_EQ(utils::GetStreamCnt(schedule["B"]),
"{ B[i, j_outer, j_inner] -> [t0 = 0, d0 = i, t1 = 1, d1 = j_outer, t2 = 0, d2 = j_inner] }");
}

} // namespace poly
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
cc_library(utils SRCS string.cc
target.cc
functional.cc
)
1 change: 1 addition & 0 deletions cinn/utils/functional.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

17 changes: 17 additions & 0 deletions cinn/utils/functional.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include <functional>

namespace cinn {
namespace utils {

template <typename InT, typename OutT>
OutT Map(const InT& in, std::function<typename OutT::value_type(const typename InT::value_type&)> fn) {
OutT res;
std::transform(
in.begin(), in.end(), std::back_inserter(res), [&](const typename InT::value_type& x) { return fn(x); });
return res;
}

} // namespace utils
} // namespace cinn

0 comments on commit 7f64f9c

Please sign in to comment.