Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make sure we rethrow exceptions in async tasks #355

Merged
merged 2 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "folly/executors/CPUThreadPoolExecutor.h"
#include "folly/futures/Future.h"
#include "knowhere/expected.h"
#include "knowhere/log.h"

namespace knowhere {
Expand Down Expand Up @@ -211,4 +212,23 @@ class ThreadPool {

constexpr static size_t kTaskQueueFactor = 16;
};

// T is either folly::Unit or Status
template <typename T>
inline Status
WaitAllSuccess(std::vector<folly::Future<T>>& futures) {
static_assert(std::is_same<T, folly::Unit>::value || std::is_same<T, Status>::value,
"WaitAllSuccess can only be used with folly::Unit or knowhere::Status");
auto allFuts = folly::collectAll(futures.begin(), futures.end()).get();
for (const auto& result : allFuts) {
result.throwUnlessValue();
if constexpr (!std::is_same_v<T, folly::Unit>) {
if (result.value() != Status::success) {
return result.value();
}
}
}
return Status::success;
}

} // namespace knowhere
24 changes: 7 additions & 17 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,9 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
return GenResultDataSet(nq, cfg.k.value(), labels, distances);
}
Expand Down Expand Up @@ -233,11 +230,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
RETURN_IF_ERROR(ret);
}
RETURN_IF_ERROR(WaitAllSuccess(futs));
return Status::success;
}

Expand Down Expand Up @@ -348,12 +341,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}

int64_t* ids = nullptr;
Expand Down
19 changes: 3 additions & 16 deletions src/common/thread/thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <utility>

#include "knowhere/comp/thread_pool.h"

namespace knowhere {

void
Expand All @@ -33,14 +34,7 @@ ExecOverSearchThreadPool(std::vector<std::function<void()>>& tasks) {
}));
}
std::this_thread::yield();
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& f : futures) {
f.wait();
}
for (auto& f : futures) {
f.result().value();
}
WaitAllSuccess(futures);
}

void
Expand All @@ -55,14 +49,7 @@ ExecOverBuildThreadPool(std::vector<std::function<void()>>& tasks) {
}));
}
std::this_thread::yield();
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& f : futures) {
f.wait();
}
for (auto& f : futures) {
f.result().value();
}
WaitAllSuccess(futures);
}

void
Expand Down
28 changes: 6 additions & 22 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,6 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, const Config& c
std::vector<int64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<DistType> warmup_result_dists(warmup_num, 0);

bool all_searches_are_good = true;

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(warmup_num);
for (_s64 i = 0; i < (int64_t)warmup_num; ++i) {
Expand All @@ -490,16 +488,14 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, const Config& c
warmup_result_dists.data() + (index * 1), 4);
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

bool failed = TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success;

if (warmup != nullptr) {
diskann::aligned_free(warmup);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a memory leak in case of exception

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TryDiskANNCall catches all exceptions and returns non success Status, thus memory leak won't happen here.

}

if (!all_searches_are_good) {
if (failed) {
LOG_KNOWHERE_ERROR_ << "Failed to do search on warmup file for DiskANN.";
return Status::diskann_inner_error;
}
Expand Down Expand Up @@ -545,7 +541,6 @@ DiskANNIndexNode<DataType>::Search(const DataSet& dataset, const Config& cfg, co
auto p_id = new int64_t[k * nq];
auto p_dist = new DistType[k * nq];

bool all_searches_are_good = true;
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
for (int64_t row = 0; row < nq; ++row) {
Expand All @@ -559,13 +554,8 @@ DiskANNIndexNode<DataType>::Search(const DataSet& dataset, const Config& cfg, co
#endif
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

Expand Down Expand Up @@ -621,7 +611,6 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
bool all_searches_are_good = true;
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(search_pool_->push([&, index = row]() {
std::vector<int64_t> indices;
Expand All @@ -639,12 +628,7 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf
}
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}
if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

Expand Down
18 changes: 2 additions & 16 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
} catch (const std::exception& e) {
std::unique_ptr<int64_t[]> auto_delete_ids(ids);
std::unique_ptr<float[]> auto_delete_dis(distances);
Expand Down Expand Up @@ -216,14 +209,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids,
lims);
} catch (const std::exception& e) {
Expand Down
28 changes: 5 additions & 23 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
WaitAllSuccess(futures);
futures.clear();
}

Expand All @@ -146,13 +140,7 @@ class HnswIndexNode : public IndexNode {
futures.emplace_back(
build_pool->push([&, idx = i]() { index_->repairGraphConnectivity(unreached[idx]); }));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
WaitAllSuccess(futures);
}
build_time.RecordSection("graph repair");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
Expand Down Expand Up @@ -213,9 +201,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

auto res = GenResultDataSet(nq, k, p_id, p_dist);

Expand Down Expand Up @@ -300,9 +286,7 @@ class HnswIndexNode : public IndexNode {
}));
}
// wait for initial search(in top layers and search for seed_ef in base layer) to finish
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

return vec;
}
Expand Down Expand Up @@ -365,9 +349,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

// filter range search result
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius_for_filter, range_filter, dis, ids,
Expand Down
18 changes: 2 additions & 16 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -589,14 +589,7 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSet& dataset, const Config&
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
} catch (const std::exception& e) {
delete[] ids;
delete[] distances;
Expand Down Expand Up @@ -718,14 +711,7 @@ IvfIndexNode<DataType, IndexType>::RangeSearch(const DataSet& dataset, const Con
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
Expand Down
Loading
Loading