Skip to content

Commit

Permalink
use make_unique to create buffers for sparse search results
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian committed Jan 19, 2024
1 parent 3ea2cec commit 31236d8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
17 changes: 8 additions & 9 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
std::fill(labels, labels + nq * topk, -1);

auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int64_t i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
Expand All @@ -406,7 +406,7 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr

const auto& row = xq[index];
if (row.size() == 0) {
return Status::success;
return;
}
sparse::MaxMinHeap<float> heap(topk);
for (int64_t j = 0; j < rows; ++j) {
Expand All @@ -424,12 +424,10 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
cur_distances[j] = heap.top().val;
heap.pop();
}
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
RETURN_IF_ERROR(fut.result().value());
fut.get();
}
return Status::success;
}
Expand All @@ -445,10 +443,11 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d
return expected<DataSetPtr>::Err(status, msg);
}
int topk = cfg.k.value();
auto labels = new sparse::label_t[nq * topk];
auto distances = new float[nq * topk];
SearchSparseWithBuf(base_dataset, query_dataset, labels, distances, config, bitset);
return GenResultDataSet(nq, topk, labels, distances);
auto labels = std::make_unique<sparse::label_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

SearchSparseWithBuf(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset);
return GenResultDataSet(nq, topk, labels.release(), distances.release());
}

} // namespace knowhere
Expand Down
10 changes: 5 additions & 5 deletions src/index/sparse/sparse_index_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,21 @@ class SparseInvertedIndexNode : public IndexNode {
auto refine_factor = cfg.refine_factor.value_or(10);
auto drop_ratio_search = cfg.drop_ratio_search.value_or(0.0f);

auto p_id = new sparse::label_t[nq * k];
auto p_dist = new float[nq * k];
auto p_id = std::make_unique<sparse::label_t[]>(nq * k);
auto p_dist = std::make_unique<float[]>(nq * k);

std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int64_t idx = 0; idx < nq; ++idx) {
futs.emplace_back(search_pool_->push([&, idx = idx]() {
futs.emplace_back(search_pool_->push([&, idx = idx, p_id = p_id.get(), p_dist = p_dist.get()]() {
index_->Search(queries[idx], k, drop_ratio_search, p_dist + idx * k, p_id + idx * k, refine_factor,
bitset);
}));
}
for (auto& fut : futs) {
fut.wait();
fut.get();
}
return GenResultDataSet(nq, k, p_id, p_dist);
return GenResultDataSet(nq, k, p_id.release(), p_dist.release());
}

[[nodiscard]] expected<std::vector<std::shared_ptr<IndexNode::iterator>>>
Expand Down

0 comments on commit 31236d8

Please sign in to comment.