From 3ec8a183fb8e4faef42b1b81c15ce3b96a3abb00 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Thu, 6 Feb 2020 20:41:52 +0800 Subject: [PATCH] add tile transform --- cinn/common/CMakeLists.txt | 2 +- cinn/common/object.h | 1 + cinn/common/pod_value.cc | 31 ++++++++----------------------- cinn/ir/ir.cc | 28 ++++++++++++++++++++++++++++ cinn/ir/operation.cc | 24 +----------------------- cinn/ir/operation.h | 17 ++--------------- cinn/ir/tensor.cc | 4 ++-- cinn/poly/element.cc | 8 ++++---- cinn/poly/element_test.cc | 30 ++++++++++++++++++++++++++++-- cinn/poly/map.h | 3 +++ cinn/utils/functional.cc | 4 ++++ 11 files changed, 82 insertions(+), 70 deletions(-) diff --git a/cinn/common/CMakeLists.txt b/cinn/common/CMakeLists.txt index c9648e974f1db..a087e0211e209 100644 --- a/cinn/common/CMakeLists.txt +++ b/cinn/common/CMakeLists.txt @@ -4,6 +4,6 @@ cc_library(common graph_utils.cc DEPS boost utils) -cc_test(test_pod_value SRCS pod_value_test.cc DEPS common) +cc_test(test_pod_value SRCS pod_value_test.cc DEPS common ir) cc_test(test_shared SRCS shared_test.cc DEPS common) cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS common) diff --git a/cinn/common/object.h b/cinn/common/object.h index 387d99c58a4c9..90cf38b9df430 100644 --- a/cinn/common/object.h +++ b/cinn/common/object.h @@ -1,3 +1,4 @@ +#pragma once #include "cinn/common/shared.h" namespace cinn { diff --git a/cinn/common/pod_value.cc b/cinn/common/pod_value.cc index 49d66eab5f1d9..cf99fcdffcfbe 100644 --- a/cinn/common/pod_value.cc +++ b/cinn/common/pod_value.cc @@ -1,8 +1,15 @@ #include "cinn/common/pod_value.h" -#include "cinn/ir/ir.h" #include "cinn/ir/node.h" namespace cinn { + +namespace ir { + +class Expr; +class Var; + +} // namespace ir + namespace common { //! Implement the type_code for all the supported types. @@ -91,16 +98,6 @@ void PODValue::Set(char const *v) { type_code_ = TypeCode(); value_.v_str = const_cast(v); } -template <> -void PODValue::Set(ir::Var v) { - type_code_ = TypeCode(); - value_.v_handle = v.ptr(); -} -template <> -void PODValue::Set(ir::Expr v) { - type_code_ = TypeCode(); - value_.v_handle = v.ptr(); -} // @} //! Implement ToValue. @@ -141,18 +138,6 @@ Value ToValue(char const *v) { val.v_str = const_cast(v); return val; } -template <> -Value ToValue(ir::Expr v) { - Value val; - val.v_handle = v.ptr(); - return val; -} -template <> -Value ToValue(ir::Var v) { - Value val; - val.v_handle = v.ptr(); - return val; -} // @} } // namespace common diff --git a/cinn/ir/ir.cc b/cinn/ir/ir.cc index 39c7c065b7cb5..9be7df8153014 100644 --- a/cinn/ir/ir.cc +++ b/cinn/ir/ir.cc @@ -220,4 +220,32 @@ Expr Call::Make(Type type, return Expr(node); } } // namespace ir + +namespace common { + +template <> +void PODValue::Set(ir::Var v) { + type_code_ = TypeCode(); + value_.v_handle = v.ptr(); +} +template <> +void PODValue::Set(ir::Expr v) { + type_code_ = TypeCode(); + value_.v_handle = v.ptr(); +} +template <> +Value ToValue(ir::Expr v) { + Value val; + val.v_handle = v.ptr(); + return val; +} +template <> +Value ToValue(ir::Var v) { + Value val; + val.v_handle = v.ptr(); + return val; +} + +} // namespace common + } // namespace cinn diff --git a/cinn/ir/operation.cc b/cinn/ir/operation.cc index 1381fc77b5385..cfafd9ad17ae0 100644 --- a/cinn/ir/operation.cc +++ b/cinn/ir/operation.cc @@ -1,27 +1,5 @@ #include "cinn/ir/operation.h" namespace cinn { -namespace ir { - -Operation ExternOp::Make(std::string name, - std::string tag, - std::map attrs, - std::vector inputs, - std::vector input_placeholders, - std::vector output_placeholders, - Stmt body) { - auto n = common::make_shared(); - n->name = std::move(name); - n->tag = std::move(tag); - n->attrs = std::move(attrs); - CHECK_EQ(inputs.size(), input_placeholders.size()); - - n->inputs = std::move(inputs); - n->input_placeholders = std::move(input_placeholders); - n->output_placeholders = std::move(output_placeholders); - n->body = std::move(body); - return Operation(n); -} - -} // namespace ir +namespace ir {} // namespace ir } // namespace cinn \ No newline at end of file diff --git a/cinn/ir/operation.h b/cinn/ir/operation.h index 7da8e6f3848a3..8b384a9c82ee5 100644 --- a/cinn/ir/operation.h +++ b/cinn/ir/operation.h @@ -42,13 +42,7 @@ struct PlaceholderOp : public _Operation_ { //! The data type of the input. Type dtype; - static Operation Make(std::string name, std::vector shape, Type dtype) { - auto n = common::make_shared(); - n->name = name; - n->shape = shape; - n->dtype = dtype; - return Operation(n); - } + static Operation Make(std::string name, std::vector shape, Type dtype); }; /** @@ -68,14 +62,7 @@ struct ComputeOp : public _Operation_ { std::string tag, std::map attrs, std::vector axis, - std::vector body) { - auto n = common::make_shared(); - n->name = std::move(name); - n->tag = std::move(tag); - n->attrs = std::move(attrs); - n->body = std::move(body); - return Operation(n); - } + std::vector body); }; } // namespace ir diff --git a/cinn/ir/tensor.cc b/cinn/ir/tensor.cc index e5a8939ffcab1..b8e1cc59b840c 100644 --- a/cinn/ir/tensor.cc +++ b/cinn/ir/tensor.cc @@ -18,12 +18,12 @@ Tensor::Tensor(const std::vector &shape, Type type) : IrNodeRef(common::ma } const _Tensor_ *Tensor::operator->() const { - auto *p = As<_Tensor_>(); + auto *p = Object::As<_Tensor_>(); CHECK(p) << "type not match"; return p; } _Tensor_ *Tensor::operator->() { - auto *p = As<_Tensor_>(); + auto *p = Object::As<_Tensor_>(); CHECK(p) << "type not match"; return p; } diff --git a/cinn/poly/element.cc b/cinn/poly/element.cc index 0820bc76017cb..c8a3663a5a63d 100644 --- a/cinn/poly/element.cc +++ b/cinn/poly/element.cc @@ -86,11 +86,11 @@ std::tuple Element::Tile(const Iterator const Iterator &level1, int factor0, int factor1) { - Iterator level0_inner(InnerName(level0)); - Iterator level0_outer(OuterName(level0)); - Iterator level1_inner(InnerName(level1)); - Iterator level1_outer(OuterName(level1)); + Iterator level0_inner, level0_outer; + Iterator level1_inner, level1_outer; + std::tie(level0_outer, level0_inner) = Split(level0, factor0); + std::tie(level1_outer, level1_inner) = Split(level1, factor1); return std::make_tuple(level0_outer, level0_inner, level1_outer, level1_inner); } diff --git a/cinn/poly/element_test.cc b/cinn/poly/element_test.cc index 82f8e54ce30ca..6d6a65fa05d17 100644 --- a/cinn/poly/element_test.cc +++ b/cinn/poly/element_test.cc @@ -5,12 +5,38 @@ namespace cinn { namespace poly { -TEST(Element, basic) { +TEST(Element, split) { isl::ctx ctx(isl_ctx_alloc()); isl::set domain(ctx, "{ S[i,j]: 0<=i,j<=100 }"); Element ele(domain); - ele.Split(Iterator("i"), 4); + Iterator outer, inner; + std::tie(outer, inner) = ele.Split(Iterator("i"), 4); + LOG(INFO) << ele.schedule(); + EXPECT_EQ(utils::GetStreamCnt(ele.schedule()), + "{ S[i, j] -> S[i_outer, i_inner, j' = j] : (-i + i_inner) mod 4 = 0 and -3 + i <= 4i_outer <= i and 0 <= " + "i_inner <= 3 }"); + + EXPECT_EQ(outer.id, "i_outer"); + EXPECT_EQ(inner.id, "i_inner"); +} + +TEST(Element, tile) { + isl::ctx ctx(isl_ctx_alloc()); + isl::set domain(ctx, "{ S[i,j,k]: 0<=i,j,k<=100 }"); + Element ele(domain); + + Iterator outer0, inner0, outer1, inner1; + std::tie(outer0, inner0, outer1, inner1) = ele.Tile(Iterator("i"), Iterator("j"), 4, 6); + LOG(INFO) << ele.schedule(); + EXPECT_EQ(outer0.id, "i_outer"); + EXPECT_EQ(outer1.id, "j_outer"); + EXPECT_EQ(inner0.id, "i_inner"); + EXPECT_EQ(outer1.id, "j_outer"); + EXPECT_EQ( + utils::GetStreamCnt(ele.schedule()), + "{ S[i, j, k] -> S[i_outer, i_inner, j_outer, j_inner, k' = k] : (-i + i_inner) mod 4 = 0 and (-j + j_inner) mod " + "6 = 0 and -3 + i <= 4i_outer <= i and 0 <= i_inner <= 3 and -5 + j <= 6j_outer <= j and 0 <= j_inner <= 5 }"); } } // namespace poly diff --git a/cinn/poly/map.h b/cinn/poly/map.h index 0d557c6f7dceb..9a993be701f21 100644 --- a/cinn/poly/map.h +++ b/cinn/poly/map.h @@ -16,10 +16,13 @@ namespace poly { struct Iterator { std::string id; + Iterator() = default; explicit Iterator(const std::string& id) : id(id) {} explicit Iterator(const Iterator& x) : id(x.id) {} explicit Iterator(Iterator&& x) : id(std::move(x.id)) {} + Iterator& operator=(const Iterator& other) { id = other.id; } + friend std::ostream& operator<<(std::ostream& os, const Iterator& x); }; diff --git a/cinn/utils/functional.cc b/cinn/utils/functional.cc index 8b137891791fe..6929a834bccc9 100644 --- a/cinn/utils/functional.cc +++ b/cinn/utils/functional.cc @@ -1 +1,5 @@ +#include "cinn/utils/functional.h" +namespace cinn { +namespace utils {} // namespace utils +} // namespace cinn