Skip to content

Commit

Permalink
[Arith] Fix iter_affine_map for non-constant extent (apache#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and Hzfengsy committed Jul 4, 2021
1 parent 42cc496 commit 9d7a7ee
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
55 changes: 40 additions & 15 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,8 @@ class IterMapRewriter : public ExprMutator {
return analyzer_->CanProveEqual(lhs, rhs) || analyzer_->CanProve(floormod(lhs, rhs) == 0);
}

PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig);
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig);

static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
tir::ExprDeepEqual equal;
Expand Down Expand Up @@ -897,7 +897,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) {
// cannot multiply two iterators, mark as unresolved.
++unresolved_count_;
return Mul(a, b);
return GetRef<PrimExpr>(op);
}

if (!a->IsInstance<IterMapExprNode>()) {
Expand All @@ -916,7 +916,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
}
}

PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig) {
// floordiv(x*scale, rhs)
if (is_one(rhs)) return std::move(lhs);
if (!is_one(lhs->scale)) {
Expand All @@ -932,7 +932,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
return floordiv(lhs, rhs);
return orig;
}
}
}
Expand All @@ -954,7 +954,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
return floordiv(lhs, rhs);
return orig;
}
}

Expand Down Expand Up @@ -982,25 +982,40 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
if (b->IsInstance<IterMapExprNode>()) {
// cannot divide an iterator, mark as unresolved.
++unresolved_count_;
return FloorDiv(a, b);
return GetRef<PrimExpr>(op);
}

if (a->IsInstance<IterSumExprNode>()) {
IterSumExpr ret = Downcast<IterSumExpr>(a);
<<<<<<< HEAD
if (Optional<IterSplitExpr> opt = TryFuseIters(ret)) {
return SplitFloorDivConst(opt.value(), b, GetRef<PrimExpr>(op));
=======
PrimExpr base = ret->base;
if (!CanProveDivisible(base, b)) {
++unresolved_count_;
return GetRef<PrimExpr>(op);
}
ret.CopyOnWrite()->base = 0;
if (auto opt = TryFuseIters(ret)) {
auto res = SplitFloorDivConst(opt.value(), b, GetRef<PrimExpr>(op));
auto res_op = res.as<IterSplitExprNode>();
return res_op && !is_zero(base)
? IterSumExpr({GetRef<IterSplitExpr>(res_op)}, analyzer_->Simplify(base / b))
: res + analyzer_->Simplify(base / b);
>>>>>>> [Arith] Fix iter_affine_map for non-constant extent (#280)
} else {
++unresolved_count_;
return FloorDiv(a, b);
return GetRef<PrimExpr>(op);
}
} else {
CHECK(a->IsInstance<IterSplitExprNode>());
IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
return SplitFloorDivConst(ret, b);
return SplitFloorDivConst(ret, b, GetRef<PrimExpr>(op));
}
}

PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig) {
// floormod(x*scale, rhs)
if (is_one(rhs)) return make_zero(lhs->dtype);
if (!is_one(lhs->scale)) {
Expand All @@ -1014,7 +1029,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
return floormod(lhs, rhs);
return orig;
}
}
}
Expand All @@ -1028,7 +1043,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
return floormod(lhs, rhs);
return orig;
}
}

Expand Down Expand Up @@ -1056,21 +1071,31 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
if (b->IsInstance<IterMapExprNode>()) {
// cannot mod an iterator, mark as unresolved.
++unresolved_count_;
return FloorMod(a, b);
return GetRef<PrimExpr>(op);
}

if (a->IsInstance<IterSumExprNode>()) {
IterSumExpr ret = Downcast<IterSumExpr>(a);
<<<<<<< HEAD
if (Optional<IterSplitExpr> opt = TryFuseIters(ret)) {
=======
PrimExpr base = ret->base;
if (!CanProveDivisible(base, b)) {
++unresolved_count_;
return GetRef<PrimExpr>(op);
}
ret.CopyOnWrite()->base = 0;
if (auto opt = TryFuseIters(ret)) {
>>>>>>> [Arith] Fix iter_affine_map for non-constant extent (#280)
return SplitFloorModConst(opt.value(), b, GetRef<PrimExpr>(op));
} else {
++unresolved_count_;
return FloorMod(a, b);
return GetRef<PrimExpr>(op);
}
} else {
CHECK(a->IsInstance<IterSplitExprNode>());
IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
return SplitFloorModConst(ret, b);
return SplitFloorModConst(ret, b, GetRef<PrimExpr>(op));
}
}

Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ def test_split():
res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)]))
assert len(res) == 0

res = tvm.arith.detect_iter_map(iter_var_par([(0, 4)]), [fld(x, flm(flm(y, 8), 6))],
var_dom([(x, 24), (y, 8)]))
assert len(res) == 0


def test_compound():
x = tvm.tir.Var("x", "int32"), 10
Expand Down

0 comments on commit 9d7a7ee

Please sign in to comment.