Skip to content

Commit

Permalink
Add searchwithbuf and rangesearch interface implementation with AMX o…
Browse files Browse the repository at this point in the history
…nednn.

Signed-off-by: Eric Zhang <fangzheng.zhang@intel.com>
  • Loading branch information
mellonyou committed Jun 5, 2024
1 parent 62eba85 commit 417601b
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 11 deletions.
63 changes: 63 additions & 0 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,15 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto labels = ids;
auto distances = dis;

#ifdef KNOWHERE_WITH_DNNL
if (faiss::is_dnnl_enabled() && (faiss_metric_type == faiss::METRIC_INNER_PRODUCT) && (is_cosine == false)) {
BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, id_selector);
} else {
#endif
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
Expand Down Expand Up @@ -291,6 +300,9 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}));
}
RETURN_IF_ERROR(WaitAllSuccess(futs));
#ifdef KNOWHERE_WITH_DNNL
}
#endif

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
if (cfg.trace_id.has_value()) {
Expand Down Expand Up @@ -367,6 +379,54 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<float>> result_dist_array(nq);

#ifdef KNOWHERE_WITH_DNNL
if (faiss::is_dnnl_enabled() && (faiss_metric_type == faiss::METRIC_INNER_PRODUCT) && (is_cosine == false)) {
if (is_sparse) {
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
auto cur_query = (const sparse::SparseRow<float>*)xq + index;
auto xb_sparse = (const sparse::SparseRow<float>*)xb;
for (int j = 0; j < nb; ++j) {
if (!bitset.empty() && bitset.test(j)) {
continue;
}
auto dist = cur_query->dot(xb_sparse[j]);
if (dist > radius && dist <= range_filter) {
result_id_array[index].push_back(j);
result_dist_array[index].push_back(dist);
}
}
return Status::success;
}));
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
} else {
faiss::RangeSearchResult res(nq);

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

faiss::range_search_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, id_selector);
for (int i = 0; i < nq; ++i) {
auto elem_cnt = res.lims[nq];
result_dist_array[i].resize(elem_cnt);
result_id_array[i].resize(elem_cnt);
for (size_t j = 0; j < elem_cnt; j++) {
result_dist_array[i][j] = res.distances[j];
result_id_array[i][j] = res.labels[j];
}
if (cfg.range_filter.value() != defaultRangeFilter) {
FilterRangeSearchResultForOneNq(result_dist_array[i], result_id_array[i], is_ip, radius, range_filter);
}
}
}
} else {
#endif
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
Expand Down Expand Up @@ -449,6 +509,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
#ifdef KNOWHERE_WITH_DNNL
}
#endif

int64_t* ids = nullptr;
float* distances = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion src/simd/distances_onednn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ thread_local faiss::inner_product_desc inner_product_desc_t;
void fvec_f32bf16f32_inner_product_onednn(uint32_t xrow, uint32_t xcol, uint32_t yrow, uint32_t ycol,
float* in_f32_1, float* in_f32_2, float** out_f32) {
inner_product_desc_t.init(xrow, xcol, yrow, ycol, in_f32_1, in_f32_2);
inner_product_desc_t.execut(out_f32);
inner_product_desc_t.execute(out_f32);
}
}
#endif
19 changes: 10 additions & 9 deletions src/simd/distances_onednn.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,6 @@ struct inner_product_desc {
return;
}

this->xrow = xrow;
this->xcol = xcol;
this->yrow = yrow;
this->ycol = ycol;
this->in_f32_1 = in_f32_1;
this->in_f32_2 = in_f32_2;

f32_md1 = dnnl::memory::desc({xrow, xcol}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
f32_md2 = dnnl::memory::desc({yrow, ycol}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
f32_dst_md2 = dnnl::memory::desc({xrow, yrow}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
Expand All @@ -176,11 +169,19 @@ struct inner_product_desc {
bf16_mem1 = dnnl::memory(inner_product_pd.src_desc(), cpu_engine);

// update state for new base data
if (this->in_f32_2 != in_f32_2)
if (this->in_f32_2 != in_f32_2) {
BaseData::getState().store(BASE_DATA_STATE::MODIFIED);
}

this->xrow = xrow;
this->xcol = xcol;
this->yrow = yrow;
this->ycol = ycol;
this->in_f32_1 = in_f32_1;
this->in_f32_2 = in_f32_2;
}

void execut(float** out_f32) {
void execute(float** out_f32) {
dnnl::reorder(f32_mem1, bf16_mem1).execute(engine_stream, f32_mem1, bf16_mem1);
BASE_DATA_STATE expected = BASE_DATA_STATE::MODIFIED;

Expand Down
2 changes: 1 addition & 1 deletion thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ void exhaustive_inner_product_seq_impl(
resi.begin(i);
for (size_t j = 0; j < ny; j++) {
float ip = res_arr[i*ny + j];
resi.add_result(ip , j);
resi.add_result(ip, j);
}
resi.end();
}
Expand Down

0 comments on commit 417601b

Please sign in to comment.