Skip to content

Commit

Permalink
make sure we do not crash due to uncaught exceptions when we called f…
Browse files Browse the repository at this point in the history
…olly::Future::wait but not trying to get the values; use folly::collect to simplify code

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian committed Jan 20, 2024
1 parent 4a28aaa commit cbaf51b
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 151 deletions.
20 changes: 20 additions & 0 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "folly/executors/CPUThreadPoolExecutor.h"
#include "folly/futures/Future.h"
#include "knowhere/expected.h"
#include "knowhere/log.h"

namespace knowhere {
Expand Down Expand Up @@ -211,4 +212,23 @@ class ThreadPool {

constexpr static size_t kTaskQueueFactor = 16;
};

// T is either folly::Unit or Status
template <typename T>
inline Status
WaitAllSuccess(std::vector<folly::Future<T>>& futures) {
static_assert(std::is_same<T, folly::Unit>::value || std::is_same<T, Status>::value,
"WaitAllSuccess can only be used with folly::Unit or knowhere::Status");
auto allFuts = folly::collectAll(futures.begin(), futures.end()).get();
for (const auto& result : allFuts) {
result.throwUnlessValue();
if constexpr (!std::is_same_v<T, folly::Unit>) {
if (result.value() != Status::success) {
return result.value();
}
}
}
return Status::success;
}

} // namespace knowhere
24 changes: 7 additions & 17 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,9 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
return GenResultDataSet(nq, cfg.k.value(), labels, distances);
}
Expand Down Expand Up @@ -233,11 +230,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
RETURN_IF_ERROR(ret);
}
RETURN_IF_ERROR(WaitAllSuccess(futs));
return Status::success;
}

Expand Down Expand Up @@ -348,12 +341,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}

int64_t* ids = nullptr;
Expand Down
19 changes: 3 additions & 16 deletions src/common/thread/thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <utility>

#include "knowhere/comp/thread_pool.h"

namespace knowhere {

void
Expand All @@ -33,14 +34,7 @@ ExecOverSearchThreadPool(std::vector<std::function<void()>>& tasks) {
}));
}
std::this_thread::yield();
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& f : futures) {
f.wait();
}
for (auto& f : futures) {
f.result().value();
}
WaitAllSuccess(futures);
}

void
Expand All @@ -55,14 +49,7 @@ ExecOverBuildThreadPool(std::vector<std::function<void()>>& tasks) {
}));
}
std::this_thread::yield();
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& f : futures) {
f.wait();
}
for (auto& f : futures) {
f.result().value();
}
WaitAllSuccess(futures);
}

void
Expand Down
28 changes: 6 additions & 22 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,6 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, const Config& c
std::vector<int64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<DistType> warmup_result_dists(warmup_num, 0);

bool all_searches_are_good = true;

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(warmup_num);
for (_s64 i = 0; i < (int64_t)warmup_num; ++i) {
Expand All @@ -490,16 +488,14 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, const Config& c
warmup_result_dists.data() + (index * 1), 4);
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

bool failed = TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success;

if (warmup != nullptr) {
diskann::aligned_free(warmup);
}

if (!all_searches_are_good) {
if (failed) {
LOG_KNOWHERE_ERROR_ << "Failed to do search on warmup file for DiskANN.";
return Status::diskann_inner_error;
}
Expand Down Expand Up @@ -545,7 +541,6 @@ DiskANNIndexNode<DataType>::Search(const DataSet& dataset, const Config& cfg, co
auto p_id = new int64_t[k * nq];
auto p_dist = new DistType[k * nq];

bool all_searches_are_good = true;
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
for (int64_t row = 0; row < nq; ++row) {
Expand All @@ -559,13 +554,8 @@ DiskANNIndexNode<DataType>::Search(const DataSet& dataset, const Config& cfg, co
#endif
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

Expand Down Expand Up @@ -621,7 +611,6 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
bool all_searches_are_good = true;
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(search_pool_->push([&, index = row]() {
std::vector<int64_t> indices;
Expand All @@ -639,12 +628,7 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf
}
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}
if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

Expand Down
18 changes: 2 additions & 16 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
} catch (const std::exception& e) {
std::unique_ptr<int64_t[]> auto_delete_ids(ids);
std::unique_ptr<float[]> auto_delete_dis(distances);
Expand Down Expand Up @@ -216,14 +209,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids,
lims);
} catch (const std::exception& e) {
Expand Down
28 changes: 5 additions & 23 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
WaitAllSuccess(futures);
futures.clear();
}

Expand All @@ -146,13 +140,7 @@ class HnswIndexNode : public IndexNode {
futures.emplace_back(
build_pool->push([&, idx = i]() { index_->repairGraphConnectivity(unreached[idx]); }));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
WaitAllSuccess(futures);
}
build_time.RecordSection("graph repair");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
Expand Down Expand Up @@ -213,9 +201,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

auto res = GenResultDataSet(nq, k, p_id, p_dist);

Expand Down Expand Up @@ -300,9 +286,7 @@ class HnswIndexNode : public IndexNode {
}));
}
// wait for initial search(in top layers and search for seed_ef in base layer) to finish
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

return vec;
}
Expand Down Expand Up @@ -365,9 +349,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

// filter range search result
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius_for_filter, range_filter, dis, ids,
Expand Down
18 changes: 2 additions & 16 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -589,14 +589,7 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSet& dataset, const Config&
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
} catch (const std::exception& e) {
delete[] ids;
delete[] distances;
Expand Down Expand Up @@ -718,14 +711,7 @@ IvfIndexNode<DataType, IndexType>::RangeSearch(const DataSet& dataset, const Con
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
Expand Down
Loading

0 comments on commit cbaf51b

Please sign in to comment.