Skip to content

Commit

Permalink
Optimize when the set is large
Browse files Browse the repository at this point in the history
  • Loading branch information
xumingkuan committed Sep 22, 2020
1 parent 691f706 commit 3c78343
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
45 changes: 34 additions & 11 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ bool StateFlowGraph::fuse() {
fusion_set.insert(fusion_set.end(), i);
}
}
const int kLargeFusionSetThreshold = std::max(n / 32, 4);

std::unordered_set<int> indices_to_delete;

Expand Down Expand Up @@ -353,21 +354,43 @@ bool StateFlowGraph::fuse() {
}
}
}
// The case without an edge: O(sum(size^2)) = O(n^2)
// The case without an edge: O(sum(size * min(size, n / 64))) = O(n^2 / 64)
for (auto &fusion_map : task_fusion_map) {
// TODO: optimize from O(size^2) to O(size * n / 64) when
// fusion_map.second.size() is large
std::vector<int> indices(fusion_map.second.begin(),
fusion_map.second.end());
for (int i = 0; i < (int)indices.size(); i++) {
const int a = indices[i];
if (!fused[a]) {
for (int j = i + 1; j < (int)indices.size(); j++) {
const int b = indices[j];
if (!fused[b] && !has_path[a][b] && !has_path[b][a]) {
do_fuse(std::min(a, b), std::max(a, b));
TI_ASSERT(std::is_sorted(indices.begin(), indices.end()));
if (fusion_map.second.size() >= kLargeFusionSetThreshold) {
// O(size * n / 64)
Bitset mask(n);
for (int a : indices) {
mask[a] = true;
}
for (int a : indices) {
if (!fused[a]) {
int b = (mask & ~(has_path[a] | has_path_reverse[a]))
.lower_bound(a + 1);
if (b == -1) {
mask[a] = false; // a can't be fused in this iteration
} else {
do_fuse(a, b);
mask[a] = false;
mask[b] = false;
updated = true;
break;
}
}
}
} else {
// O(size^2)
for (int i = 0; i < (int)indices.size(); i++) {
const int a = indices[i];
if (!fused[a]) {
for (int j = i + 1; j < (int)indices.size(); j++) {
const int b = indices[j];
if (!fused[b] && !has_path[a][b] && !has_path[b][a]) {
do_fuse(a, b);
updated = true;
break;
}
}
}
}
Expand Down
22 changes: 21 additions & 1 deletion taichi/util/bit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ Bitset &Bitset::operator|=(const Bitset &other) {
return *this;
}

// Non-mutating counterpart of operator|=: copies *this, applies |= with
// `other` to the copy, and returns the copy. Neither operand is modified.
Bitset Bitset::operator|(const Bitset &other) const {
  Bitset merged(*this);
  merged |= other;
  return merged;
}

Bitset &Bitset::operator^=(const Bitset &other) {
const int len = vec_.size();
TI_ASSERT(len == other.vec_.size());
Expand All @@ -83,13 +89,27 @@ Bitset &Bitset::operator^=(const Bitset &other) {
return *this;
}

// Returns a new Bitset, constructed with the same size() as *this, in which
// every storage word is the bitwise complement of the corresponding word of
// *this. *this is left untouched.
Bitset Bitset::operator~() const {
  Bitset flipped(size());
  auto dst = flipped.vec_.begin();
  for (const auto &word : vec_) {
    // Complement one whole storage word at a time.
    *dst = ~word;
    ++dst;
  }
  return flipped;
}

// Convenience wrapper: forwards to lower_bound() with the search starting at
// position 0 and returns whatever lower_bound() reports.
int Bitset::find_first_bit() const {
  const int search_start = 0;
  return lower_bound(search_start);
}

int Bitset::lower_bound(int x) const {
const int len = vec_.size();
TI_ASSERT(x >= 0 && x < len * kBits);
if (x >= len * kBits) {
return -1;
}
if (x < 0) {
x = 0;
}
int i = x / kBits;
if (x % kBits != 0) {
if (auto test = vec_[i] & (kMask ^ ((1ULL << (x % kBits)) - 1))) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/util/bit.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ class Bitset {
Bitset &operator&=(const Bitset &other);
Bitset operator&(const Bitset &other) const;
Bitset &operator|=(const Bitset &other);
// Returns a new Bitset obtained by copying *this and applying operator|= with
// `other` to the copy; *this is not modified.
Bitset operator|(const Bitset &other) const;
// In-place ^= with `other`. The implementation asserts (TI_ASSERT) that both
// bitsets have the same number of storage words.
Bitset &operator^=(const Bitset &other);
// Returns a new Bitset in which every storage word of *this is complemented;
// *this is not modified.
Bitset operator~() const;

// Find the index of the first set bit, or return -1 if no bit is set.
int find_first_bit() const;
Expand Down

0 comments on commit 3c78343

Please sign in to comment.