Skip to content

Commit

Permalink
make sure we rethrow exceptions in async tasks (#355)
Browse files Browse the repository at this point in the history
* make sure we do not crash due to uncaught exceptions when we called folly::Future::wait but not trying to get the values; use folly::collect to simplify code

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>

* fix some possible memory leaks when thread pool tasks throw exceptions

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>

---------

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian committed Jan 23, 2024
1 parent 170e88e commit 042d20d
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 197 deletions.
20 changes: 20 additions & 0 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "folly/executors/CPUThreadPoolExecutor.h"
#include "folly/futures/Future.h"
#include "knowhere/expected.h"
#include "knowhere/log.h"

namespace knowhere {
Expand Down Expand Up @@ -211,4 +212,23 @@ class ThreadPool {

constexpr static size_t kTaskQueueFactor = 16;
};

// T is either folly::Unit or Status
template <typename T>
inline Status
WaitAllSuccess(std::vector<folly::Future<T>>& futures) {
static_assert(std::is_same<T, folly::Unit>::value || std::is_same<T, Status>::value,
"WaitAllSuccess can only be used with folly::Unit or knowhere::Status");
auto allFuts = folly::collectAll(futures.begin(), futures.end()).get();
for (const auto& result : allFuts) {
result.throwUnlessValue();
if constexpr (!std::is_same_v<T, folly::Unit>) {
if (result.value() != Status::success) {
return result.value();
}
}
}
return Status::success;
}

} // namespace knowhere
36 changes: 13 additions & 23 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

int topk = cfg.k.value();
auto labels = new int64_t[nq * topk];
auto distances = new float[nq * topk];
auto labels = std::make_unique<int64_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
Expand Down Expand Up @@ -128,14 +128,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
return GenResultDataSet(nq, cfg.k.value(), labels, distances);
return GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());
}

template <typename DataType>
Expand Down Expand Up @@ -233,11 +230,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
RETURN_IF_ERROR(ret);
}
RETURN_IF_ERROR(WaitAllSuccess(futs));
return Status::success;
}

Expand Down Expand Up @@ -348,12 +341,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}

int64_t* ids = nullptr;
Expand Down
19 changes: 3 additions & 16 deletions src/common/thread/thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <utility>

#include "knowhere/comp/thread_pool.h"

namespace knowhere {

void
Expand All @@ -33,14 +34,7 @@ ExecOverSearchThreadPool(std::vector<std::function<void()>>& tasks) {
}));
}
std::this_thread::yield();
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& f : futures) {
f.wait();
}
for (auto& f : futures) {
f.result().value();
}
WaitAllSuccess(futures);
}

void
Expand All @@ -55,14 +49,7 @@ ExecOverBuildThreadPool(std::vector<std::function<void()>>& tasks) {
}));
}
std::this_thread::yield();
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& f : futures) {
f.wait();
}
for (auto& f : futures) {
f.result().value();
}
WaitAllSuccess(futures);
}

void
Expand Down
42 changes: 13 additions & 29 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,6 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, const Config& c
std::vector<int64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<DistType> warmup_result_dists(warmup_num, 0);

bool all_searches_are_good = true;

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(warmup_num);
for (_s64 i = 0; i < (int64_t)warmup_num; ++i) {
Expand All @@ -490,16 +488,14 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, const Config& c
warmup_result_dists.data() + (index * 1), 4);
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

bool failed = TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success;

if (warmup != nullptr) {
diskann::aligned_free(warmup);
}

if (!all_searches_are_good) {
if (failed) {
LOG_KNOWHERE_ERROR_ << "Failed to do search on warmup file for DiskANN.";
return Status::diskann_inner_error;
}
Expand Down Expand Up @@ -542,34 +538,28 @@ DiskANNIndexNode<DataType>::Search(const DataSet& dataset, const Config& cfg, co
search_conf.search_list_size.value());
}

auto p_id = new int64_t[k * nq];
auto p_dist = new DistType[k * nq];
auto p_id = std::make_unique<int64_t[]>(k * nq);
auto p_dist = std::make_unique<DistType[]>(k * nq);

bool all_searches_are_good = true;
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(search_pool_->push([&, index = row]() {
futures.emplace_back(search_pool_->push([&, index = row, p_id_ptr = p_id.get(), p_dist_ptr = p_dist.get()]() {
diskann::QueryStats stats;
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id + (index * k),
p_dist + (index * k), beamwidth, false, &stats, feder_result, bitset,
filter_ratio, for_tuning);
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id_ptr + (index * k),
p_dist_ptr + (index * k), beamwidth, false, &stats, feder_result,
bitset, filter_ratio, for_tuning);
#ifdef NOT_COMPILE_FOR_SWIG
knowhere_diskann_search_hops.Observe(stats.n_hops);
#endif
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

auto res = GenResultDataSet(nq, k, p_id, p_dist);
auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release());

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down Expand Up @@ -621,7 +611,6 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
bool all_searches_are_good = true;
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(search_pool_->push([&, index = row]() {
std::vector<int64_t> indices;
Expand All @@ -639,12 +628,7 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf
}
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}
if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

Expand Down
18 changes: 2 additions & 16 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
} catch (const std::exception& e) {
std::unique_ptr<int64_t[]> auto_delete_ids(ids);
std::unique_ptr<float[]> auto_delete_dis(distances);
Expand Down Expand Up @@ -216,14 +209,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids,
lims);
} catch (const std::exception& e) {
Expand Down
48 changes: 15 additions & 33 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
WaitAllSuccess(futures);
futures.clear();
}

Expand All @@ -146,13 +140,7 @@ class HnswIndexNode : public IndexNode {
futures.emplace_back(
build_pool->push([&, idx = i]() { index_->repairGraphConnectivity(unreached[idx]); }));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
WaitAllSuccess(futures);
}
build_time.RecordSection("graph repair");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
Expand Down Expand Up @@ -186,8 +174,8 @@ class HnswIndexNode : public IndexNode {
feder_result = std::make_unique<feder::hnsw::FederResult>();
}

auto p_id = new int64_t[k * nq];
auto p_dist = new DistType[k * nq];
auto p_id = std::make_unique<int64_t[]>(k * nq);
auto p_dist = std::make_unique<DistType[]>(k * nq);

hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value(), hnsw_cfg.for_tuning.value()};
bool transform =
Expand All @@ -196,12 +184,12 @@ class HnswIndexNode : public IndexNode {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(search_pool_->push([&, idx = i]() {
futs.emplace_back(search_pool_->push([&, idx = i, p_id_ptr = p_id.get(), p_dist_ptr = p_dist.get()]() {
auto single_query = (const char*)xq + idx * index_->data_size_;
auto rst = index_->searchKnn(single_query, k, bitset, &param, feder_result);
size_t rst_size = rst.size();
auto p_single_dis = p_dist + idx * k;
auto p_single_id = p_id + idx * k;
auto p_single_dis = p_dist_ptr + idx * k;
auto p_single_id = p_id_ptr + idx * k;
for (size_t idx = 0; idx < rst_size; ++idx) {
const auto& [dist, id] = rst[idx];
p_single_dis[idx] = transform ? (-dist) : dist;
Expand All @@ -213,11 +201,9 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

auto res = GenResultDataSet(nq, k, p_id, p_dist);
auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release());

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down Expand Up @@ -300,9 +286,7 @@ class HnswIndexNode : public IndexNode {
}));
}
// wait for initial search(in top layers and search for seed_ef in base layer) to finish
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

return vec;
}
Expand Down Expand Up @@ -335,10 +319,6 @@ class HnswIndexNode : public IndexNode {

hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value()};

int64_t* ids = nullptr;
DistType* dis = nullptr;
size_t* lims = nullptr;

std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<DistType>> result_dist_array(nq);
std::vector<size_t> result_size(nq);
Expand All @@ -365,9 +345,11 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

int64_t* ids = nullptr;
DistType* dis = nullptr;
size_t* lims = nullptr;

// filter range search result
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius_for_filter, range_filter, dis, ids,
Expand Down
Loading

0 comments on commit 042d20d

Please sign in to comment.