Skip to content

Commit

Permalink
Merge pull request #50 from Dobatymo/fix-openmp-single
Browse files Browse the repository at this point in the history
Only use openmp if batch size > 1
  • Loading branch information
Dobatymo committed Sep 1, 2023
2 parents 4007905 + 2a780ca commit c09d884
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pynear/include/BKTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ template <typename key_t, typename distance_t, typename metric> class BKTree {
std::vector<std::vector<key_t>> keys_out(keys.size());

#if (ENABLE_OMP_PARALLEL)
#pragma omp parallel for schedule(static, 1)
#pragma omp parallel for schedule(static, 1) if (keys.size() > 1)
#endif
// i should be size_t, however msvc requires signed integral loop variables (except with -openmp:llvm)
for (int i = 0; i < static_cast<int>(keys.size()); ++i) {
Expand Down
8 changes: 4 additions & 4 deletions pynear/include/VPTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ template <typename T, typename distance_type, distance_type (*distance)(const T
results.resize(queries.size());

#if (ENABLE_OMP_PARALLEL)
#pragma omp parallel for schedule(static, 1)
#pragma omp parallel for schedule(static, 1) if (queries.size() > 1)
#endif
// i should be size_t, however msvc requires signed integral loop variables (except with -openmp:llvm)
for (int i = 0; i < queries.size(); ++i) {
for (int i = 0; i < static_cast<int>(queries.size()); ++i) {
const T &query = queries[i];
std::priority_queue<VPTreeSearchElement> knnQueue;
searchKNN(_rootPartition, query, k, knnQueue);
Expand All @@ -212,10 +212,10 @@ template <typename T, typename distance_type, distance_type (*distance)(const T
distances.resize(queries.size());

#if (ENABLE_OMP_PARALLEL)
#pragma omp parallel for schedule(static, 1)
#pragma omp parallel for schedule(static, 1) if (queries.size() > 1)
#endif
// i should be size_t, see above
for (int i = 0; i < queries.size(); ++i) {
for (int i = 0; i < static_cast<int>(queries.size()); ++i) {
const T &query = queries[i];
distance_type dist = 0;
int64_t index = -1;
Expand Down

0 comments on commit c09d884

Please sign in to comment.