Skip to content

Commit

Permalink
refactor: move parallel code into a separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
william-silversmith committed Dec 30, 2023
1 parent 1eb7d0a commit 4715659
Showing 1 changed file with 51 additions and 51 deletions.
102 changes: 51 additions & 51 deletions fastmorph/fastmorph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,47 @@
#include <vector>
#include <cstdlib>
#include <cmath>
#include <functional>
#include "threadpool.h"

namespace fastmorph {

void parallelize_blocks(
const std::function<void(
const uint64_t, const uint64_t,
const uint64_t, const uint64_t,
const uint64_t, const uint64_t
)> &process_block,
const uint64_t sx, const uint64_t sy, const uint64_t sz,
const uint64_t threads, const uint64_t offset
) {
const uint64_t block_size = 64;

const uint64_t grid_x = std::max(static_cast<uint64_t>((sx + block_size/2) / block_size), static_cast<uint64_t>(1));
const uint64_t grid_y = std::max(static_cast<uint64_t>((sy + block_size/2) / block_size), static_cast<uint64_t>(1));
const uint64_t grid_z = std::max(static_cast<uint64_t>((sz + block_size/2) / block_size), static_cast<uint64_t>(1));

const int real_threads = std::max(std::min(threads, grid_x * grid_y * grid_z), static_cast<uint64_t>(0));

ThreadPool pool(real_threads);

for (uint64_t gz = 0; gz < grid_z; gz++) {
for (uint64_t gy = 0; gy < grid_y; gy++) {
for (uint64_t gx = 0; gx < grid_x; gx++) {
pool.enqueue([=]() {
process_block(
std::max(offset, gx * block_size), std::min((gx+1) * block_size, sx - offset),
std::max(offset, gy * block_size), std::min((gy+1) * block_size, sy - offset),
std::max(offset, gz * block_size), std::min((gz+1) * block_size, sz - offset)
);
});
}
}
}

pool.join();
}


template <typename LABEL>
void multilabel_dilate(
Expand Down Expand Up @@ -251,31 +288,13 @@ void multilabel_dilate(
}
};

const uint64_t block_size = 64;

const uint64_t grid_x = std::max(static_cast<uint64_t>((sx + block_size/2) / block_size), static_cast<uint64_t>(1));
const uint64_t grid_y = std::max(static_cast<uint64_t>((sy + block_size/2) / block_size), static_cast<uint64_t>(1));
const uint64_t grid_z = std::max(static_cast<uint64_t>((sz + block_size/2) / block_size), static_cast<uint64_t>(1));

const int real_threads = std::max(std::min(threads, grid_x * grid_y * grid_z), static_cast<uint64_t>(0));

ThreadPool pool(real_threads);

for (uint64_t gz = 0; gz < grid_z; gz++) {
for (uint64_t gy = 0; gy < grid_y; gy++) {
for (uint64_t gx = 0; gx < grid_x; gx++) {
pool.enqueue([=]() {
process_block(
gx * block_size, std::min((gx+1) * block_size, sx),
gy * block_size, std::min((gy+1) * block_size, sy),
gz * block_size, std::min((gz+1) * block_size, sz)
);
});
}
}
}

pool.join();
parallelize_blocks(
std::function<void(
const uint64_t,const uint64_t,const uint64_t,
const uint64_t,const uint64_t,const uint64_t
)>(process_block),
sx, sy, sz, threads, /*offset=*/0
);
}

template <typename LABEL>
Expand Down Expand Up @@ -423,32 +442,13 @@ void multilabel_erode(

#undef FILL_STENCIL

const uint64_t block_size = 64;

const uint64_t grid_x = std::max(static_cast<uint64_t>((sx + block_size/2) / block_size), static_cast<uint64_t>(1));
const uint64_t grid_y = std::max(static_cast<uint64_t>((sy + block_size/2) / block_size), static_cast<uint64_t>(1));
const uint64_t grid_z = std::max(static_cast<uint64_t>((sz + block_size/2) / block_size), static_cast<uint64_t>(1));

const int real_threads = std::max(std::min(threads, grid_x * grid_y * grid_z), static_cast<uint64_t>(0));

ThreadPool pool(real_threads);

for (uint64_t gz = 0; gz < grid_z; gz++) {
for (uint64_t gy = 0; gy < grid_y; gy++) {
for (uint64_t gx = 0; gx < grid_x; gx++) {
pool.enqueue([=]() {
const uint64_t one = 1;
process_block(
std::max(one, gx * block_size), std::min((gx+1) * block_size, sx - 1),
std::max(one, gy * block_size), std::min((gy+1) * block_size, sy - 1),
std::max(one, gz * block_size), std::min((gz+1) * block_size, sz - 1)
);
});
}
}
}

pool.join();
parallelize_blocks(
std::function<void(
const uint64_t,const uint64_t,const uint64_t,
const uint64_t,const uint64_t,const uint64_t
)>(process_block),
sx, sy, sz, threads, /*offset=*/1
);
}

};
Expand Down

0 comments on commit 4715659

Please sign in to comment.