@@ -541,11 +541,53 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
541
541
bool locateLeft = locateFromLeft (left, right);
542
542
543
543
// Append all combinations of a and b merge points
544
- for (auto & leftPoint : left.points ()) {
545
- for (auto & rightPoint : right.points ()) {
546
- points.push_back (intersectPoints (leftPoint, rightPoint, locateLeft));
544
+ struct pointSort {
545
+ bool operator ()(const MergePoint& a, const MergePoint& b) {
546
+ size_t left_size = a.iterators ().size () + a.locators ().size ();
547
+ size_t right_size = b.iterators ().size () + b.locators ().size ();
548
+ return left_size > right_size;
549
+ }
550
+ } pointSorter;
551
+
552
+ // Append all combinations of the merge points of a and b
553
+ auto sorted_apoint = left.points ();
554
+ auto sorted_bpoint = right.points ();
555
+ std::sort (sorted_apoint.begin (), sorted_apoint.end (), pointSorter);
556
+ std::sort (sorted_bpoint.begin (), sorted_bpoint.end (), pointSorter);
557
+
558
+ set<Iterator> apoint_root_set;
559
+ if (!sorted_apoint.empty ())
560
+ apoint_root_set = sorted_apoint.begin ()->tensorRegion ();
561
+
562
+ set<Iterator>bpoint_root_set;
563
+ if (!sorted_bpoint.empty ())
564
+ bpoint_root_set = sorted_bpoint.begin ()->tensorRegion ();
565
+
566
+
567
+ for (auto & apoint : sorted_apoint) {
568
+ for (auto & bpoint : sorted_bpoint) {
569
+ bool hasIntersection = true ;
570
+
571
+ auto apoint_set = apoint.tensorRegion ();
572
+ auto bpoint_set = bpoint.tensorRegion ();
573
+
574
+ for (auto & it : apoint_set) {
575
+ if (!std::count (bpoint_set.begin (), bpoint_set.end (), it) &&
576
+ std::count (bpoint_root_set.begin (), bpoint_root_set.end (), it)) {
577
+ hasIntersection = false ;
578
+ }
579
+ }
580
+ for (auto & it : bpoint_set) {
581
+ if (!std::count (apoint_set.begin (), apoint_set.end (), it) &&
582
+ std::count (apoint_root_set.begin (), apoint_root_set.end (), it)) {
583
+ hasIntersection = false ;
584
+ }
585
+ }
586
+ if (hasIntersection)
587
+ points.push_back (intersectPoints (apoint, bpoint, locateLeft));
547
588
}
548
589
}
590
+ std::sort (points.begin (), points.end (), pointSorter);
549
591
550
592
// Correctness: ensures that points produced on BOTH the left and the
551
593
// right lattices are produced in the final intersection.
@@ -557,7 +599,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
557
599
// points and resolves conflicts arising between omitters and
558
600
// producers
559
601
points = removeDuplicatedTensorRegions (points, true );
560
-
602
+
561
603
// Optimization: Removed a subLattice of points if the entire subLattice is
562
604
// made of only omitters
563
605
// points = removeUnnecessaryOmitterPoints(points);
@@ -577,10 +619,49 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
577
619
{
578
620
vector<MergePoint> points;
579
621
622
+ struct pointSort {
623
+ bool operator ()(const MergePoint& a, const MergePoint& b) {
624
+ size_t left_size = a.iterators ().size () + a.locators ().size ();
625
+ size_t right_size = b.iterators ().size () + b.locators ().size ();
626
+ return left_size > right_size;
627
+ }
628
+ } pointSorter;
629
+
580
630
// Append all combinations of the merge points of a and b
581
- for (auto & apoint : left.points ()) {
582
- for (auto & bpoint : right.points ()) {
583
- points.push_back (unionPoints (apoint, bpoint));
631
+ auto sorted_apoint = left.points ();
632
+ auto sorted_bpoint = right.points ();
633
+ std::sort (sorted_apoint.begin (), sorted_apoint.end (), pointSorter);
634
+ std::sort (sorted_bpoint.begin (), sorted_bpoint.end (), pointSorter);
635
+
636
+ set<Iterator> apoint_root_set;
637
+ if (!sorted_apoint.empty ())
638
+ apoint_root_set = sorted_apoint.begin ()->tensorRegion ();
639
+
640
+ set<Iterator>bpoint_root_set;
641
+ if (!sorted_bpoint.empty ())
642
+ bpoint_root_set = sorted_bpoint.begin ()->tensorRegion ();
643
+
644
+ for (auto & apoint : sorted_apoint) {
645
+ for (auto & bpoint : sorted_bpoint) {
646
+ bool hasIntersection = true ;
647
+
648
+ auto apoint_set = apoint.tensorRegion ();
649
+ auto bpoint_set = bpoint.tensorRegion ();
650
+
651
+ for (auto & it : apoint_set) {
652
+ if (!std::count (bpoint_set.begin (), bpoint_set.end (), it) &&
653
+ std::count (bpoint_root_set.begin (), bpoint_root_set.end (), it)) {
654
+ hasIntersection = false ;
655
+ }
656
+ }
657
+ for (auto & it : bpoint_set) {
658
+ if (!std::count (apoint_set.begin (), apoint_set.end (), it) &&
659
+ std::count (apoint_root_set.begin (), apoint_root_set.end (), it)) {
660
+ hasIntersection = false ;
661
+ }
662
+ }
663
+ if (hasIntersection)
664
+ points.push_back (unionPoints (apoint, bpoint));
584
665
}
585
666
}
586
667
@@ -590,22 +671,13 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
590
671
// Append the merge points of b
591
672
util::append (points, right.points ());
592
673
593
- struct pointSort {
594
- bool operator ()(const MergePoint& a, const MergePoint& b) {
595
- size_t left_size = a.iterators ().size () + a.locators ().size ();
596
- size_t right_size = b.iterators ().size () + b.locators ().size ();
597
- return left_size > right_size;
598
- }
599
- } pointSorter;
600
-
601
674
std::sort (points.begin (), points.end (), pointSorter);
602
675
603
676
// Correctness: This ensures that points omitted on BOTH the left and the
604
677
// right lattices are omitted in the Union. Needed since some
605
678
// subpoints may produce leading to erroneous producer regions
606
679
points = correctPointTypesAfterUnion (left.points (), right.points (), points);
607
680
608
-
609
681
// Correctness: Deduplicate regions that are described by multiple lattice
610
682
// points and resolves conflicts arising between omitters and
611
683
// producers
@@ -671,6 +743,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
671
743
*/
672
744
static MergePoint unionPoints (MergePoint left, MergePoint right)
673
745
{
746
+
674
747
vector<Iterator> iterators= combine (left.iterators (),right.iterators ());
675
748
vector<Iterator> locaters = combine (left.locators (), right.locators ());
676
749
vector<Iterator> results = combine (left.results (), right.results ());
@@ -784,6 +857,21 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
784
857
return deduplicates;
785
858
}
786
859
860
+ static vector<Iterator>
861
+ removeDimensionIterators (const vector<Iterator>& iterators)
862
+ {
863
+ vector<Iterator> result;
864
+
865
+ // Remove all but one of the dense iterators, which are all the same.
866
+ for (auto & iterator : iterators) {
867
+ if (!iterator.isDimensionIterator ()) {
868
+ result.push_back (iterator);
869
+ }
870
+ }
871
+ return result;
872
+ }
873
+
874
+
787
875
static vector<MergePoint>
788
876
flipPoints (const vector<MergePoint>& points) {
789
877
vector<MergePoint> flippedPoints;
@@ -1119,8 +1207,18 @@ ostream& operator<<(ostream& os, const MergeLattice& ml) {
1119
1207
}
1120
1208
1121
1209
bool operator ==(const MergeLattice& a, const MergeLattice& b) {
1122
- auto & apoints = a.points ();
1123
- auto & bpoints = b.points ();
1210
+ auto apoints = a.points ();
1211
+ auto bpoints = b.points ();
1212
+ struct pointSort {
1213
+ bool operator ()(const MergePoint& a, const MergePoint& b) {
1214
+ size_t left_size = a.iterators ().size () + a.locators ().size ();
1215
+ size_t right_size = b.iterators ().size () + b.locators ().size ();
1216
+ return left_size > right_size;
1217
+ }
1218
+ } pointSorter;
1219
+
1220
+ std::sort (apoints.begin (), apoints.end (), pointSorter);
1221
+ std::sort (bpoints.begin (), bpoints.end (), pointSorter);
1124
1222
if (apoints.size () != bpoints.size ()) {
1125
1223
return false ;
1126
1224
}
0 commit comments