Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AMX support to speed up Faiss Inner-Product #535

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
a28d310
Add AMX support to speed up Faiss Inner-Product
mellonyou Apr 25, 2024
5d2afb3
Merge branch 'zilliztech:main' into main
mellonyou Apr 28, 2024
7b6f49a
Add AMX support to speed up Faiss Inner-Product
mellonyou Apr 28, 2024
64a8804
Port the onednn code to knowhere, and modify it to follow dynamic hoo…
mellonyou May 15, 2024
71fd0cf
Merge branch 'zilliztech:main' into main
mellonyou May 15, 2024
420b8c2
Merge branch 'zilliztech:main' into amx_ip
mellonyou May 15, 2024
64bd1b9
Merge branch 'main' into amx_ip
mellonyou May 15, 2024
47b8f42
Merge branch 'main' into amx_ip
mellonyou May 15, 2024
2dbb422
Merge branch 'amx_ip' of https://github.com/mellonyou/knowhere into a…
mellonyou May 15, 2024
1d281ae
Merge branch 'amx_ip' of https://github.com/mellonyou/knowhere into a…
mellonyou May 15, 2024
9b3c1b1
Merge branch 'amx_ip' of https://github.com/mellonyou/knowhere into a…
mellonyou May 15, 2024
961e7cf
Merge branch 'amx_ip' of https://github.com/mellonyou/knowhere into a…
mellonyou May 15, 2024
5c17682
Merge branch 'amx_ip' of https://github.com/mellonyou/knowhere into a…
mellonyou May 15, 2024
1169f95
Merge branch 'zilliztech:main' into amx_ip
mellonyou May 27, 2024
62eba85
Merge branch 'zilliztech:main' into amx_ip
mellonyou Jun 5, 2024
417601b
Add searchwithbuf and rangesearch interface implementation with AMX o…
mellonyou Jun 5, 2024
76cc32d
Add result filter after AMX Inner Product.
mellonyou Jul 1, 2024
ff5c7cd
Merge branch 'amx_ip' of https://github.com/mellonyou/knowhere into a…
mellonyou Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cmake/libs/libfaiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,36 @@ if(__X86_64)
set(UTILS_SSE_SRC src/simd/distances_sse.cc)
set(UTILS_AVX_SRC src/simd/distances_avx.cc)
set(UTILS_AVX512_SRC src/simd/distances_avx512.cc)
set(UTILS_ONEDNN_SRC src/simd/distances_onednn.cc)

add_library(utils_sse OBJECT ${UTILS_SSE_SRC})
add_library(utils_avx OBJECT ${UTILS_AVX_SRC})
add_library(utils_avx512 OBJECT ${UTILS_AVX512_SRC})
add_library(utils_onednn OBJECT ${UTILS_ONEDNN_SRC})

target_compile_options(utils_sse PRIVATE -msse4.2 -mpopcnt)
target_compile_options(utils_avx PRIVATE -mfma -mf16c -mavx2 -mpopcnt)
target_compile_options(utils_avx512 PRIVATE -mfma -mf16c -mavx512f -mavx512dq
-mavx512bw -mpopcnt)

if(WITH_DNNL)
add_library(
knowhere_utils STATIC
${UTILS_SRC} $<TARGET_OBJECTS:utils_sse> $<TARGET_OBJECTS:utils_avx>
$<TARGET_OBJECTS:utils_avx512> $<TARGET_OBJECTS:utils_onednn>)

find_package(DNNL REQUIRED)
find_library(RT_LIB rt)
find_library(DNNL_LIB dnnl)
target_link_libraries(knowhere_utils PUBLIC ${RT_LIB} ${DNNL_LIB})

add_definitions(-DKNOWHERE_WITH_DNNL)
else()
add_library(
knowhere_utils STATIC
${UTILS_SRC} $<TARGET_OBJECTS:utils_sse> $<TARGET_OBJECTS:utils_avx>
$<TARGET_OBJECTS:utils_avx512>)
endif()
target_link_libraries(knowhere_utils PUBLIC glog::glog)
endif()

Expand Down Expand Up @@ -111,6 +127,13 @@ if(__X86_64)
faiss PUBLIC OpenMP::OpenMP_CXX ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}
faiss_avx2 faiss_avx512 knowhere_utils)
target_compile_definitions(faiss PRIVATE FINTEGER=int)
if(WITH_DNNL)
find_package(DNNL REQUIRED)
find_library(RT_LIB rt)
find_library(DNNL_LIB dnnl)
target_link_libraries(faiss PRIVATE ${RT_LIB} ${DNNL_LIB})
add_definitions(-DFAISS_WITH_DNNL)
endif()
endif()

if(__AARCH64)
Expand Down
3 changes: 3 additions & 0 deletions conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class KnowhereConan(ConanFile):
"with_benchmark": [True, False],
"with_coverage": [True, False],
"with_faiss_tests": [True, False],
"with_dnnl": [True, False],
}
default_options = {
"shared": True,
Expand All @@ -50,6 +51,7 @@ class KnowhereConan(ConanFile):
"boost:without_test": True,
"fmt:header_only": True,
"with_faiss_tests": False,
"with_dnnl": False,
}

exports_sources = (
Expand Down Expand Up @@ -164,6 +166,7 @@ def generate(self):
tc.variables["WITH_BENCHMARK"] = self.options.with_benchmark
tc.variables["WITH_COVERAGE"] = self.options.with_coverage
tc.variables["WITH_FAISS_TESTS"] = self.options.with_faiss_tests
tc.variables["WITH_DNNL"] = self.options.with_dnnl
tc.generate()

deps = CMakeDeps(self)
Expand Down
11 changes: 11 additions & 0 deletions scripts/install_deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
set -e

UNAME="$(uname -s)"
CURRENT_DIR=$(pwd)

case "${UNAME}" in
Linux*) MACHINE=Linux;;
Expand Down Expand Up @@ -77,6 +78,16 @@ if [[ "${MACHINE}" == "Linux" ]]; then
sudo ln -s /usr/local/lib/libopenblasp-r0.3.21.so /usr/local/lib/libblas.so.3 && \
sudo ln -s /usr/local/lib/pkgconfig/openblas.pc /usr/local/lib/pkgconfig/blas.pc
fi

# Install Intel oneDNN
cd ${CURRENT_DIR} && \
rm -rf oneDNN-3.5-pc && \
wget https://github.com/oneapi-src/oneDNN/archive/refs/tags/v3.5-pc.tar.gz && \
tar zxvf v3.5-pc.tar.gz && cd oneDNN-3.5-pc && \
mkdir -p build && cd build && \
export CC=icx && export CXX=icpx && \
cmake .. && make -j8 && \
sudo cmake --build . --target install
elif [[ -x "$(command -v yum)" ]]; then
# for CentOS 7
sudo yum install -y epel-release centos-release-scl-rh wget && \
Expand Down
205 changes: 143 additions & 62 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
#include "knowhere/sparse_utils.h"
#include "knowhere/utils.h"

#ifdef KNOWHERE_WITH_DNNL
#include "simd/distances_onednn.h"
#endif

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
#include "knowhere/tracer.h"
#endif
Expand Down Expand Up @@ -83,73 +87,87 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto labels = std::make_unique<int64_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;
#ifdef KNOWHERE_WITH_DNNL
if (faiss::is_dnnl_enabled() && (faiss_metric_type == faiss::METRIC_INNER_PRODUCT) && (is_cosine == false)) {
BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels.get(), distances.get()};
faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, id_selector);
} else {
#endif
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
break;
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
dim / 8, id_selector);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf,
id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8,
id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
dim / 8, id_selector);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8,
cur_distances, cur_labels, id_selector);
break;
}
default: {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value();
return Status::invalid_metric_type;
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances,
cur_labels, id_selector);
break;
}
default: {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value();
return Status::invalid_metric_type;
}
}
return Status::success;
}));
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
return Status::success;
}));
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
#ifdef KNOWHERE_WITH_DNNL
}
#endif
auto res = GenResultDataSet(nq, cfg.k.value(), std::move(labels), std::move(distances));

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
Expand Down Expand Up @@ -203,6 +221,15 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto labels = ids;
auto distances = dis;

#ifdef KNOWHERE_WITH_DNNL
if (faiss::is_dnnl_enabled() && (faiss_metric_type == faiss::METRIC_INNER_PRODUCT) && (is_cosine == false)) {
BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, id_selector);
} else {
#endif
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
Expand Down Expand Up @@ -267,6 +294,9 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}));
}
RETURN_IF_ERROR(WaitAllSuccess(futs));
#ifdef KNOWHERE_WITH_DNNL
}
#endif

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
if (cfg.trace_id.has_value()) {
Expand Down Expand Up @@ -341,6 +371,54 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<float>> result_dist_array(nq);

#ifdef KNOWHERE_WITH_DNNL
if (faiss::is_dnnl_enabled() && (faiss_metric_type == faiss::METRIC_INNER_PRODUCT) && (is_cosine == false)) {
if (is_sparse) {
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
auto cur_query = (const sparse::SparseRow<float>*)xq + index;
auto xb_sparse = (const sparse::SparseRow<float>*)xb;
for (int j = 0; j < nb; ++j) {
if (!bitset.empty() && bitset.test(j)) {
continue;
}
auto dist = cur_query->dot(xb_sparse[j]);
if (dist > radius && dist <= range_filter) {
result_id_array[index].push_back(j);
result_dist_array[index].push_back(dist);
}
}
return Status::success;
}));
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
} else {
faiss::RangeSearchResult res(nq);

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

faiss::range_search_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, id_selector);
for (int i = 0; i < nq; ++i) {
auto elem_cnt = res.lims[nq];
result_dist_array[i].resize(elem_cnt);
result_id_array[i].resize(elem_cnt);
for (size_t j = 0; j < elem_cnt; j++) {
result_dist_array[i][j] = res.distances[j];
result_id_array[i][j] = res.labels[j];
}
if (cfg.range_filter.value() != defaultRangeFilter) {
FilterRangeSearchResultForOneNq(result_dist_array[i], result_id_array[i], is_ip, radius, range_filter);
}
}
}
} else {
#endif
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
Expand Down Expand Up @@ -423,6 +501,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
#ifdef KNOWHERE_WITH_DNNL
}
#endif

auto range_search_result =
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter);
Expand Down
14 changes: 14 additions & 0 deletions src/simd/distances_onednn.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifdef KNOWHERE_WITH_DNNL
#include "distances_onednn.h"

namespace faiss {

thread_local faiss::inner_product_desc inner_product_desc_t;

void fvec_f32bf16f32_inner_product_onednn(uint32_t xrow, uint32_t xcol, uint32_t yrow, uint32_t ycol,
float* in_f32_1, float* in_f32_2, float** out_f32) {
inner_product_desc_t.init(xrow, xcol, yrow, ycol, in_f32_1, in_f32_2);
inner_product_desc_t.execute(out_f32);
}
}
#endif
Loading