diff --git a/include/knowhere/comp/thread_pool.h b/include/knowhere/comp/thread_pool.h index 13f75cf4..f62f2b24 100644 --- a/include/knowhere/comp/thread_pool.h +++ b/include/knowhere/comp/thread_pool.h @@ -22,6 +22,7 @@ #include "folly/executors/CPUThreadPoolExecutor.h" #include "folly/futures/Future.h" +#include "knowhere/expected.h" #include "knowhere/log.h" namespace knowhere { @@ -211,4 +212,23 @@ class ThreadPool { constexpr static size_t kTaskQueueFactor = 16; }; + +// T is either folly::Unit or Status +template +inline Status +WaitAllSuccess(std::vector>& futures) { + static_assert(std::is_same::value || std::is_same::value, + "WaitAllSuccess can only be used with folly::Unit or knowhere::Status"); + auto allFuts = folly::collectAll(futures.begin(), futures.end()).get(); + for (const auto& result : allFuts) { + result.throwUnlessValue(); + if constexpr (!std::is_same_v) { + if (result.value() != Status::success) { + return result.value(); + } + } + } + return Status::success; +} + } // namespace knowhere diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index f54d9795..47bd6adc 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -62,17 +62,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset bool is_cosine = IsMetricType(metric_str, metric::COSINE); int topk = cfg.k.value(); - auto labels = new int64_t[nq * topk]; - auto distances = new float[nq * topk]; + auto labels = std::make_unique(nq * topk); + auto distances = std::make_unique(nq * topk); auto pool = ThreadPool::GetGlobalSearchThreadPool(); std::vector> futs; futs.reserve(nq); for (int i = 0; i < nq; ++i) { - futs.emplace_back(pool->push([&, index = i] { + futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] { ThreadPool::ScopedOmpSetter setter(1); - auto cur_labels = labels + topk * index; - auto cur_distances = distances + topk * index; + auto cur_labels = labels_ptr + topk * index; + auto cur_distances = distances_ptr + topk * index; BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; @@ -128,14 +128,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset return Status::success; })); } - for (auto& fut : futs) { - fut.wait(); - auto ret = fut.result().value(); - if (ret != Status::success) { - return expected::Err(ret, "failed to brute force search"); - } + auto ret = WaitAllSuccess(futs); + if (ret != Status::success) { + return expected::Err(ret, "failed to brute force search"); } - return GenResultDataSet(nq, cfg.k.value(), labels, distances); + return GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release()); } template @@ -233,11 +230,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ return Status::success; })); } - for (auto& fut : futs) { - fut.wait(); - auto ret = fut.result().value(); - RETURN_IF_ERROR(ret); - } + RETURN_IF_ERROR(WaitAllSuccess(futs)); return Status::success; } @@ -348,12 +341,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da return Status::success; })); } - for (auto& fut : futs) { - fut.wait(); - auto ret = fut.result().value(); - if (ret != Status::success) { - return expected::Err(ret, "failed to brute force search"); - } + auto ret = WaitAllSuccess(futs); + if (ret != Status::success) { + return expected::Err(ret, "failed to brute force search"); } int64_t* ids = nullptr; diff --git a/src/common/thread/thread.cc b/src/common/thread/thread.cc index 9d21692b..0e41e039 100644 --- a/src/common/thread/thread.cc +++ b/src/common/thread/thread.cc @@ -19,6 +19,7 @@ #include #include "knowhere/comp/thread_pool.h" + namespace knowhere { void @@ -33,14 +34,7 @@ ExecOverSearchThreadPool(std::vector>& tasks) { })); } std::this_thread::yield(); - // check for exceptions. value() is {}, so either - // a call does nothing, or it throws an inner exception. - for (auto& f : futures) { - f.wait(); - } - for (auto& f : futures) { - f.result().value(); - } + WaitAllSuccess(futures); } void @@ -55,14 +49,7 @@ ExecOverBuildThreadPool(std::vector>& tasks) { })); } std::this_thread::yield(); - // check for exceptions. value() is {}, so either - // a call does nothing, or it throws an inner exception. - for (auto& f : futures) { - f.wait(); - } - for (auto& f : futures) { - f.result().value(); - } + WaitAllSuccess(futures); } void diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index 2b13a94d..171038b3 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -479,8 +479,6 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& c std::vector warmup_result_ids_64(warmup_num, 0); std::vector warmup_result_dists(warmup_num, 0); - bool all_searches_are_good = true; - std::vector> futures; futures.reserve(warmup_num); for (_s64 i = 0; i < (int64_t)warmup_num; ++i) { @@ -490,16 +488,14 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& c warmup_result_dists.data() + (index * 1), 4); })); } - for (auto& future : futures) { - if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) { - all_searches_are_good = false; - } - } + + bool failed = TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success; + if (warmup != nullptr) { diskann::aligned_free(warmup); } - if (!all_searches_are_good) { + if (failed) { LOG_KNOWHERE_ERROR_ << "Failed to do search on warmup file for DiskANN."; return Status::diskann_inner_error; } @@ -542,34 +538,28 @@ DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, co search_conf.search_list_size.value()); } - auto p_id = new int64_t[k * nq]; - auto p_dist = new DistType[k * nq]; + auto p_id = std::make_unique(k * nq); + auto p_dist = std::make_unique(k * nq); - bool all_searches_are_good = true; std::vector> futures; futures.reserve(nq); for (int64_t row = 0; row < nq; ++row) { - futures.emplace_back(search_pool_->push([&, index = row]() { + futures.emplace_back(search_pool_->push([&, index = row, p_id_ptr = p_id.get(), p_dist_ptr = p_dist.get()]() { diskann::QueryStats stats; - pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id + (index * k), - p_dist + (index * k), beamwidth, false, &stats, feder_result, bitset, - filter_ratio, for_tuning); + pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id_ptr + (index * k), + p_dist_ptr + (index * k), beamwidth, false, &stats, feder_result, + bitset, filter_ratio, for_tuning); #ifdef NOT_COMPILE_FOR_SWIG knowhere_diskann_search_hops.Observe(stats.n_hops); #endif })); } - for (auto& future : futures) { - if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) { - all_searches_are_good = false; - } - } - if (!all_searches_are_good) { + if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) { return expected::Err(Status::diskann_inner_error, "some search failed"); } - auto res = GenResultDataSet(nq, k, p_id, p_dist); + auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release()); // set visit_info json string into result dataset if (feder_result != nullptr) { @@ -621,7 +611,6 @@ DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cf std::vector> futures; futures.reserve(nq); - bool all_searches_are_good = true; for (int64_t row = 0; row < nq; ++row) { futures.emplace_back(search_pool_->push([&, index = row]() { std::vector indices; @@ -639,12 +628,7 @@ DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cf } })); } - for (auto& future : futures) { - if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) { - all_searches_are_good = false; - } - } - if (!all_searches_are_good) { + if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) { return expected::Err(Status::diskann_inner_error, "some search failed"); } diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index dc086fda..1f353d85 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -127,14 +127,7 @@ class FlatIndexNode : public IndexNode { })); } // wait for the completion - for (auto& fut : futs) { - fut.wait(); - } - // check for exceptions. value() is {}, so either - // a call does nothing, or it throws an inner exception. - for (auto& fut : futs) { - fut.result().value(); - } + WaitAllSuccess(futs); } catch (const std::exception& e) { std::unique_ptr auto_delete_ids(ids); std::unique_ptr auto_delete_dis(distances); @@ -216,14 +209,7 @@ class FlatIndexNode : public IndexNode { })); } // wait for the completion - for (auto& fut : futs) { - fut.wait(); - } - // check for exceptions. value() is {}, so either - // a call does nothing, or it throws an inner exception. - for (auto& fut : futs) { - fut.result().value(); - } + WaitAllSuccess(futs); GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims); } catch (const std::exception& e) { diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index f5c53e47..cb3f8ac5 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -126,13 +126,7 @@ class HnswIndexNode : public IndexNode { } })); } - for (auto& future : futures) { - future.wait(); - } - // check for exceptions - for (auto& future : futures) { - future.result().value(); - } + WaitAllSuccess(futures); futures.clear(); } @@ -146,13 +140,7 @@ class HnswIndexNode : public IndexNode { futures.emplace_back( build_pool->push([&, idx = i]() { index_->repairGraphConnectivity(unreached[idx]); })); } - for (auto& future : futures) { - future.wait(); - } - // check for exceptions - for (auto& future : futures) { - future.result().value(); - } + WaitAllSuccess(futures); } build_time.RecordSection("graph repair"); LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_ @@ -186,8 +174,8 @@ class HnswIndexNode : public IndexNode { feder_result = std::make_unique(); } - auto p_id = new int64_t[k * nq]; - auto p_dist = new DistType[k * nq]; + auto p_id = std::make_unique(k * nq); + auto p_dist = std::make_unique(k * nq); hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value(), hnsw_cfg.for_tuning.value()}; bool transform = @@ -196,12 +184,12 @@ class HnswIndexNode : public IndexNode { std::vector> futs; futs.reserve(nq); for (int i = 0; i < nq; ++i) { - futs.emplace_back(search_pool_->push([&, idx = i]() { + futs.emplace_back(search_pool_->push([&, idx = i, p_id_ptr = p_id.get(), p_dist_ptr = p_dist.get()]() { auto single_query = (const char*)xq + idx * index_->data_size_; auto rst = index_->searchKnn(single_query, k, bitset, ¶m, feder_result); size_t rst_size = rst.size(); - auto p_single_dis = p_dist + idx * k; - auto p_single_id = p_id + idx * k; + auto p_single_dis = p_dist_ptr + idx * k; + auto p_single_id = p_id_ptr + idx * k; for (size_t idx = 0; idx < rst_size; ++idx) { const auto& [dist, id] = rst[idx]; p_single_dis[idx] = transform ? (-dist) : dist; @@ -213,11 +201,9 @@ class HnswIndexNode : public IndexNode { } })); } - for (auto& fut : futs) { - fut.wait(); - } + WaitAllSuccess(futs); - auto res = GenResultDataSet(nq, k, p_id, p_dist); + auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release()); // set visit_info json string into result dataset if (feder_result != nullptr) { @@ -300,9 +286,7 @@ class HnswIndexNode : public IndexNode { })); } // wait for initial search(in top layers and search for seed_ef in base layer) to finish - for (auto& fut : futs) { - fut.wait(); - } + WaitAllSuccess(futs); return vec; } @@ -335,10 +319,6 @@ class HnswIndexNode : public IndexNode { hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value()}; - int64_t* ids = nullptr; - DistType* dis = nullptr; - size_t* lims = nullptr; - std::vector> result_id_array(nq); std::vector> result_dist_array(nq); std::vector result_size(nq); @@ -365,9 +345,11 @@ class HnswIndexNode : public IndexNode { } })); } - for (auto& fut : futs) { - fut.wait(); - } + WaitAllSuccess(futs); + + int64_t* ids = nullptr; + DistType* dis = nullptr; + size_t* lims = nullptr; // filter range search result GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius_for_filter, range_filter, dis, ids, diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index c0155080..589a46c4 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -589,14 +589,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& })); } // wait for the completion - for (auto& fut : futs) { - fut.wait(); - } - // check for exceptions. value() is {}, so either - // a call does nothing, or it throws an inner exception. - for (auto& fut : futs) { - fut.result().value(); - } + WaitAllSuccess(futs); } catch (const std::exception& e) { delete[] ids; delete[] distances; @@ -718,14 +711,7 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Con })); } // wait for the completion - for (auto& fut : futs) { - fut.wait(); - } - // check for exceptions. value() is {}, so either - // a call does nothing, or it throws an inner exception. - for (auto& fut : futs) { - fut.result().value(); - } + WaitAllSuccess(futs); GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims); } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); diff --git a/tests/ut/test_utils.cc b/tests/ut/test_utils.cc index a899c3fd..61dcaec1 100644 --- a/tests/ut/test_utils.cc +++ b/tests/ut/test_utils.cc @@ -9,11 +9,15 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. +#include +#include #include #include "catch2/catch_approx.hpp" #include "catch2/catch_test_macros.hpp" +#include "knowhere/comp/thread_pool.h" #include "knowhere/comp/time_recorder.h" +#include "knowhere/expected.h" #include "knowhere/heap.h" #include "knowhere/utils.h" #include "knowhere/version.h" @@ -122,3 +126,83 @@ TEST_CASE("Test DiskLoad") { REQUIRE(!knowhere::UseDiskLoad(knowhere::IndexEnum::INDEX_HNSW, knowhere::Version::GetCurrentVersion().VersionNumber())); } + +TEST_CASE("Test WaitAllSuccess with folly::Unit futures") { + auto pool = knowhere::ThreadPool::GetGlobalSearchThreadPool(); + std::vector> futures; + + SECTION("All futures succeed") { + for (size_t i = 0; i < 10; ++i) { + futures.emplace_back(pool->push([]() { return folly::Unit(); })); + } + REQUIRE(knowhere::WaitAllSuccess(futures) == knowhere::Status::success); + } + + SECTION("One future throws an exception") { + for (size_t i = 0; i < 10; ++i) { + futures.emplace_back(pool->push([i]() { + if (i == 5) { + throw std::runtime_error("Task failed"); + } + return folly::Unit(); + })); + } + REQUIRE_THROWS_AS(knowhere::WaitAllSuccess(futures), std::runtime_error); + } + + SECTION("WaitAllSuccess should wait until all tasks finish even if any throws exception") { + std::atomic externalValue{0}; + + futures.emplace_back(pool->push([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + REQUIRE(externalValue.load() == 1); + externalValue.store(2); + return folly::Unit(); + })); + + futures.emplace_back(pool->push([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + externalValue.store(1); + throw std::runtime_error("Task failed"); + })); + + REQUIRE_THROWS_AS(knowhere::WaitAllSuccess(futures), std::runtime_error); + REQUIRE(externalValue.load() == 2); + } +} + +TEST_CASE("Test WaitAllSuccess with knowhere::Status futures") { + auto pool = knowhere::ThreadPool::GetGlobalSearchThreadPool(); + std::vector> futures; + + SECTION("All futures succeed with Status::success") { + for (size_t i = 0; i < 10; ++i) { + futures.emplace_back(pool->push([]() { return knowhere::Status::success; })); + } + REQUIRE(knowhere::WaitAllSuccess(futures) == knowhere::Status::success); + } + + SECTION("One future returns Status::invalid_args") { + for (size_t i = 0; i < 10; ++i) { + futures.emplace_back(pool->push([i]() { + if (i == 5) { + return knowhere::Status::invalid_args; + } + return knowhere::Status::success; + })); + } + REQUIRE(knowhere::WaitAllSuccess(futures) == knowhere::Status::invalid_args); + } + + SECTION("One future throws an exception") { + for (size_t i = 0; i < 10; ++i) { + futures.emplace_back(pool->push([i]() { + if (i == 5) { + throw std::runtime_error("Task failed"); + } + return knowhere::Status::success; + })); + } + REQUIRE_THROWS_AS(knowhere::WaitAllSuccess(futures), std::runtime_error); + } +} diff --git a/thirdparty/DiskANN/src/aux_utils.cpp b/thirdparty/DiskANN/src/aux_utils.cpp index 8f409925..04488e56 100644 --- a/thirdparty/DiskANN/src/aux_utils.cpp +++ b/thirdparty/DiskANN/src/aux_utils.cpp @@ -24,6 +24,7 @@ #include "diskann/partition_and_pq.h" #include "diskann/percentile_stats.h" #include "diskann/pq_flash_index.h" +#include "knowhere/comp/thread_pool.h" #include "tsl/robin_set.h" #include "diskann/utils.h" @@ -496,7 +497,7 @@ namespace diskann { paras.Set("saturate_graph", 1); paras.Set("save_path", mem_index_path); paras.Set("accelerate_build", accelerate_build); - paras.Set("shuffle_build", shuffle_build); + paras.Set("shuffle_build", shuffle_build); std::unique_ptr> _pvamanaIndex = std::unique_ptr>(new diskann::Index( @@ -600,7 +601,7 @@ namespace diskann { } _u64 sample_num, sample_dim; - T *samples = nullptr; + std::unique_ptr samples; if (file_exists(sample_file)) { diskann::load_bin(sample_file, samples, sample_num, sample_dim); } else { @@ -620,7 +621,7 @@ namespace diskann { num_nodes_to_cache = points_num; } - uint8_t *pq_code = nullptr; + std::unique_ptr pq_code; diskann::FixedChunkPQTable pq_table; uint64_t pq_chunks, pq_npts = 0; if (file_exists(pq_pivots_path) && file_exists(pq_compressed_code_path)) { @@ -684,7 +685,8 @@ namespace diskann { auto compute_dists = [&, scratch_ids, pq_table_dists]( const unsigned *ids, const _u64 n_ids, float *dists_out) { - aggregate_coords(ids, n_ids, pq_code, pq_chunks, scratch_ids.get()); + aggregate_coords(ids, n_ids, pq_code.get(), pq_chunks, + scratch_ids.get()); pq_dist_lookup(scratch_ids.get(), n_ids, pq_chunks, pq_table_dists.get(), dists_out); }; @@ -745,9 +747,7 @@ namespace diskann { })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); std::sort(node_count_list.begin(), node_count_list.end(), [](std::pair<_u32, _u32> &a, std::pair<_u32, _u32> &b) { @@ -761,10 +761,6 @@ namespace diskann { save_bin(cache_file, node_list.data(), num_nodes_to_cache, 1); - if (samples != nullptr) - delete[] samples; - if (pq_code != nullptr) - delete[] pq_code; } // General purpose support for DiskANN interface @@ -787,6 +783,7 @@ namespace diskann { std::vector tuning_sample_result_ids_64(tuning_sample_num, 0); std::vector tuning_sample_result_dists(tuning_sample_num, 0); diskann::QueryStats *stats = new diskann::QueryStats[tuning_sample_num]; + std::unique_ptr stats_deleter(stats); std::vector> futures; futures.reserve(tuning_sample_num); @@ -800,9 +797,7 @@ namespace diskann { stats + index); })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; double qps = @@ -825,8 +820,6 @@ namespace diskann { } if (cur_bw > 64) stop_flag = true; - - delete[] stats; } return best_bw; } diff --git a/thirdparty/DiskANN/src/index.cpp b/thirdparty/DiskANN/src/index.cpp index e631bbfc..dbae53e9 100644 --- a/thirdparty/DiskANN/src/index.cpp +++ b/thirdparty/DiskANN/src/index.cpp @@ -14,6 +14,7 @@ #include #include #include "knowhere/log.h" +#include "knowhere/comp/thread_pool.h" #include "tsl/robin_set.h" #include "tsl/robin_map.h" #include @@ -363,7 +364,7 @@ namespace diskann { return 0; } size_t tag_bytes_written; - TagT *tag_data = new TagT[_nd + _num_frozen_pts]; + auto tag_data = std::make_unique(_nd + _num_frozen_pts); for (_u32 i = 0; i < _nd; i++) { if (_location_to_tag.find(i) != _location_to_tag.end()) { tag_data[i] = _location_to_tag[i]; @@ -377,11 +378,10 @@ namespace diskann { } try { tag_bytes_written = - save_bin(tags_file, tag_data, _nd + _num_frozen_pts, 1); + save_bin(tags_file, tag_data.get(), _nd + _num_frozen_pts, 1); } catch (std::system_error &e) { throw FileException(tags_file, e, __FUNCSIG__, __FILE__, __LINE__); } - delete[] tag_data; return tag_bytes_written; } @@ -776,7 +776,7 @@ namespace diskann { template unsigned Index::calculate_entry_point() { // allocate and init centroid - float *center = new float[_aligned_dim](); + auto center = std::make_unique(_aligned_dim); for (size_t j = 0; j < _aligned_dim; j++) center[j] = 0; @@ -788,37 +788,37 @@ namespace diskann { center[j] /= (float) _nd; // compute all to one distance - float *distances = new float[_nd](); - auto l2_distance_fun = get_distance_function(diskann::Metric::L2); + auto distances = std::make_unique(_nd); + auto l2_distance_fun = get_distance_function(diskann::Metric::L2); - auto num_threads = _build_thread_pool->size(); + auto num_threads = _build_thread_pool->size(); std::vector> futures; futures.reserve(num_threads); auto future_task_size = DIV_ROUND_UP(_nd, num_threads); for (_s64 i = 0; i < (_s64) _nd; i += future_task_size) { futures.emplace_back(_build_thread_pool->push( - [&, beg_id = i, end_id = std::min(_nd, i + future_task_size)]() { + [&, beg_id = i, end_id = std::min(_nd, i + future_task_size), + center_ptr = center.get(), distances_ptr = distances.get()]() { for (auto node_id = beg_id; node_id < (_s64) end_id; node_id++) { // extract point and distance reference - float &dist = distances[node_id]; + float &dist = distances_ptr[node_id]; const T *cur_vec = _data + (node_id * (size_t) _aligned_dim); if constexpr (std::is_same::value) { - dist = l2_distance_fun(center, cur_vec, (size_t) _aligned_dim); + dist = + l2_distance_fun(center_ptr, cur_vec, (size_t) _aligned_dim); } else { dist = 0; float diff = 0; for (size_t j = 0; j < _aligned_dim; j++) { - diff = (center[j] - (float) cur_vec[j]) * - (center[j] - (float) cur_vec[j]); + diff = (center_ptr[j] - (float) cur_vec[j]) * + (center_ptr[j] - (float) cur_vec[j]); dist += diff; } } } })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); // find imin unsigned min_idx = 0; float min_dist = distances[0]; @@ -829,8 +829,6 @@ namespace diskann { } } - delete[] distances; - delete[] center; return min_idx; } @@ -1343,7 +1341,7 @@ namespace diskann { _indexingRange = parameters.Get("R"); _indexingMaxC = parameters.Get("C"); const bool accelerate_build = parameters.Get("accelerate_build"); - const bool shuffle_build = parameters.Get("shuffle_build"); + const bool shuffle_build = parameters.Get("shuffle_build"); const float last_round_alpha = parameters.Get("alpha"); unsigned L = _indexingQueueSize; @@ -1483,9 +1481,7 @@ namespace diskann { prune_neighbors(node, pool, pruned_list); })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); diff = std::chrono::high_resolution_clock::now() - s; sync_time += diff.count(); @@ -1509,9 +1505,7 @@ namespace diskann { } })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); s = std::chrono::high_resolution_clock::now(); futures.clear(); @@ -1532,9 +1526,7 @@ namespace diskann { } })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); futures.clear(); for (_s64 node_ctr = 0; node_ctr < (_s64) (visit_order.size()); @@ -1567,9 +1559,7 @@ namespace diskann { })); } } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); diff = std::chrono::high_resolution_clock::now() - s; inter_time += diff.count(); @@ -1638,9 +1628,7 @@ namespace diskann { })); } } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); if (_nd > 0) { LOG_KNOWHERE_DEBUG_ << "final cleanup done. Link time: " << ((double) link_timer.elapsed() / (double) 1000000) @@ -1683,9 +1671,7 @@ namespace diskann { } } } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); diskann::cout << "Prune time : " << timer.elapsed() / 1000 << "ms" << std::endl; @@ -2385,9 +2371,7 @@ namespace diskann { })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); if (_support_eager_delete) update_in_graph(); diff --git a/thirdparty/DiskANN/src/partition_and_pq.cpp b/thirdparty/DiskANN/src/partition_and_pq.cpp index c8078a22..30792de9 100644 --- a/thirdparty/DiskANN/src/partition_and_pq.cpp +++ b/thirdparty/DiskANN/src/partition_and_pq.cpp @@ -32,6 +32,7 @@ #include "diskann/parameters.h" #include "tsl/robin_set.h" #include "diskann/utils.h" +#include "knowhere/comp/thread_pool.h" #include #include @@ -355,9 +356,7 @@ int generate_pq_pivots(const float *passed_train_data, size_t num_train, } })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); diskann::save_bin(pq_pivots_path.c_str(), full_pivot_data.get(), (size_t) num_centers, dim); @@ -574,9 +573,7 @@ int generate_pq_data_from_pivots(const std::string data_file, } })); } - for (auto &future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); if (num_centers > 256) { compressed_file_writer.write( diff --git a/thirdparty/DiskANN/src/utils.cpp b/thirdparty/DiskANN/src/utils.cpp index 989b8ca8..ff86be9a 100644 --- a/thirdparty/DiskANN/src/utils.cpp +++ b/thirdparty/DiskANN/src/utils.cpp @@ -1,4 +1,6 @@ #include "diskann/utils.h" +#include "knowhere/comp/thread_pool.h" + #include namespace diskann { @@ -23,9 +25,7 @@ namespace diskann { } })); } - for (auto& future : futures) { - future.wait(); - } + knowhere::WaitAllSuccess(futures); writr.write((char*) read_buf, npts * ndims * sizeof(float)); }