Skip to content

Commit bc765ba

Browse files
authored
Merge pull request #445 from tensor-compiler/array_algebra_iter_const
Array algebra iter const
2 parents 7154fd0 + 2b8bb8b commit bc765ba

File tree

1 file changed

+116
-18
lines changed

1 file changed

+116
-18
lines changed

src/lower/merge_lattice.cpp

Lines changed: 116 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -541,11 +541,53 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
541541
bool locateLeft = locateFromLeft(left, right);
542542

543543
// Append all combinations of a and b merge points
544-
for (auto& leftPoint : left.points()) {
545-
for (auto& rightPoint : right.points()) {
546-
points.push_back(intersectPoints(leftPoint, rightPoint, locateLeft));
544+
struct pointSort {
545+
bool operator()(const MergePoint& a, const MergePoint& b) {
546+
size_t left_size = a.iterators().size() + a.locators().size();
547+
size_t right_size = b.iterators().size() + b.locators().size();
548+
return left_size > right_size;
549+
}
550+
} pointSorter;
551+
552+
// Append all combinations of the merge points of a and b
553+
auto sorted_apoint = left.points();
554+
auto sorted_bpoint = right.points();
555+
std::sort(sorted_apoint.begin(), sorted_apoint.end(), pointSorter);
556+
std::sort(sorted_bpoint.begin(), sorted_bpoint.end(), pointSorter);
557+
558+
set<Iterator> apoint_root_set;
559+
if (!sorted_apoint.empty())
560+
apoint_root_set = sorted_apoint.begin()->tensorRegion();
561+
562+
set<Iterator>bpoint_root_set;
563+
if (!sorted_bpoint.empty())
564+
bpoint_root_set = sorted_bpoint.begin()->tensorRegion();
565+
566+
567+
for (auto& apoint : sorted_apoint) {
568+
for (auto& bpoint : sorted_bpoint) {
569+
bool hasIntersection = true;
570+
571+
auto apoint_set = apoint.tensorRegion();
572+
auto bpoint_set = bpoint.tensorRegion();
573+
574+
for (auto& it : apoint_set) {
575+
if (!std::count(bpoint_set.begin(), bpoint_set.end(), it) &&
576+
std::count(bpoint_root_set.begin(), bpoint_root_set.end(), it)) {
577+
hasIntersection = false;
578+
}
579+
}
580+
for (auto& it : bpoint_set) {
581+
if (!std::count(apoint_set.begin(), apoint_set.end(), it) &&
582+
std::count(apoint_root_set.begin(), apoint_root_set.end(), it)) {
583+
hasIntersection = false;
584+
}
585+
}
586+
if (hasIntersection)
587+
points.push_back(intersectPoints(apoint, bpoint, locateLeft));
547588
}
548589
}
590+
std::sort(points.begin(), points.end(), pointSorter);
549591

550592
// Correctness: ensures that points produced on BOTH the left and the
551593
// right lattices are produced in the final intersection.
@@ -557,7 +599,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
557599
// points and resolves conflicts arising between omitters and
558600
// producers
559601
points = removeDuplicatedTensorRegions(points, true);
560-
602+
561603
// Optimization: Removed a subLattice of points if the entire subLattice is
562604
// made of only omitters
563605
// points = removeUnnecessaryOmitterPoints(points);
@@ -577,10 +619,49 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
577619
{
578620
vector<MergePoint> points;
579621

622+
struct pointSort {
623+
bool operator()(const MergePoint& a, const MergePoint& b) {
624+
size_t left_size = a.iterators().size() + a.locators().size();
625+
size_t right_size = b.iterators().size() + b.locators().size();
626+
return left_size > right_size;
627+
}
628+
} pointSorter;
629+
580630
// Append all combinations of the merge points of a and b
581-
for (auto& apoint : left.points()) {
582-
for (auto& bpoint : right.points()) {
583-
points.push_back(unionPoints(apoint, bpoint));
631+
auto sorted_apoint = left.points();
632+
auto sorted_bpoint = right.points();
633+
std::sort(sorted_apoint.begin(), sorted_apoint.end(), pointSorter);
634+
std::sort(sorted_bpoint.begin(), sorted_bpoint.end(), pointSorter);
635+
636+
set<Iterator> apoint_root_set;
637+
if (!sorted_apoint.empty())
638+
apoint_root_set = sorted_apoint.begin()->tensorRegion();
639+
640+
set<Iterator>bpoint_root_set;
641+
if (!sorted_bpoint.empty())
642+
bpoint_root_set = sorted_bpoint.begin()->tensorRegion();
643+
644+
for (auto& apoint : sorted_apoint) {
645+
for (auto& bpoint : sorted_bpoint) {
646+
bool hasIntersection = true;
647+
648+
auto apoint_set = apoint.tensorRegion();
649+
auto bpoint_set = bpoint.tensorRegion();
650+
651+
for (auto& it : apoint_set) {
652+
if (!std::count(bpoint_set.begin(), bpoint_set.end(), it) &&
653+
std::count(bpoint_root_set.begin(), bpoint_root_set.end(), it)) {
654+
hasIntersection = false;
655+
}
656+
}
657+
for (auto& it : bpoint_set) {
658+
if (!std::count(apoint_set.begin(), apoint_set.end(), it) &&
659+
std::count(apoint_root_set.begin(), apoint_root_set.end(), it)) {
660+
hasIntersection = false;
661+
}
662+
}
663+
if (hasIntersection)
664+
points.push_back(unionPoints(apoint, bpoint));
584665
}
585666
}
586667

@@ -590,22 +671,13 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
590671
// Append the merge points of b
591672
util::append(points, right.points());
592673

593-
struct pointSort {
594-
bool operator()(const MergePoint& a, const MergePoint& b) {
595-
size_t left_size = a.iterators().size() + a.locators().size();
596-
size_t right_size = b.iterators().size() + b.locators().size();
597-
return left_size > right_size;
598-
}
599-
} pointSorter;
600-
601674
std::sort(points.begin(), points.end(), pointSorter);
602675

603676
// Correctness: This ensures that points omitted on BOTH the left and the
604677
// right lattices are omitted in the Union. Needed since some
605678
// subpoints may produce leading to erroneous producer regions
606679
points = correctPointTypesAfterUnion(left.points(), right.points(), points);
607680

608-
609681
// Correctness: Deduplicate regions that are described by multiple lattice
610682
// points and resolves conflicts arising between omitters and
611683
// producers
@@ -671,6 +743,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
671743
*/
672744
static MergePoint unionPoints(MergePoint left, MergePoint right)
673745
{
746+
674747
vector<Iterator> iterators= combine(left.iterators(),right.iterators());
675748
vector<Iterator> locaters = combine(left.locators(), right.locators());
676749
vector<Iterator> results = combine(left.results(), right.results());
@@ -784,6 +857,21 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
784857
return deduplicates;
785858
}
786859

860+
static vector<Iterator>
861+
removeDimensionIterators(const vector<Iterator>& iterators)
862+
{
863+
vector<Iterator> result;
864+
865+
// Remove all but one of the dense iterators, which are all the same.
866+
for (auto& iterator : iterators) {
867+
if (!iterator.isDimensionIterator()) {
868+
result.push_back(iterator);
869+
}
870+
}
871+
return result;
872+
}
873+
874+
787875
static vector<MergePoint>
788876
flipPoints(const vector<MergePoint>& points) {
789877
vector<MergePoint> flippedPoints;
@@ -1119,8 +1207,18 @@ ostream& operator<<(ostream& os, const MergeLattice& ml) {
11191207
}
11201208

11211209
bool operator==(const MergeLattice& a, const MergeLattice& b) {
1122-
auto& apoints = a.points();
1123-
auto& bpoints = b.points();
1210+
auto apoints = a.points();
1211+
auto bpoints = b.points();
1212+
struct pointSort {
1213+
bool operator()(const MergePoint& a, const MergePoint& b) {
1214+
size_t left_size = a.iterators().size() + a.locators().size();
1215+
size_t right_size = b.iterators().size() + b.locators().size();
1216+
return left_size > right_size;
1217+
}
1218+
} pointSorter;
1219+
1220+
std::sort(apoints.begin(), apoints.end(), pointSorter);
1221+
std::sort(bpoints.begin(), bpoints.end(), pointSorter);
11241222
if (apoints.size() != bpoints.size()) {
11251223
return false;
11261224
}

0 commit comments

Comments
 (0)