Skip to content

Commit

Permalink
Add AMX support to speed up Faiss Inner-Product
Browse files Browse the repository at this point in the history
Signed-off-by: Fangzheng Zhang <fangzheng.zhang@intel.com>
  • Loading branch information
mellonyou committed Apr 28, 2024
1 parent b0f0e5c commit 7b6f49a
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 79 deletions.
7 changes: 7 additions & 0 deletions cmake/libs/libfaiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ if(__X86_64)
faiss PUBLIC OpenMP::OpenMP_CXX ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}
faiss_avx2 faiss_avx512 knowhere_utils)
target_compile_definitions(faiss PRIVATE FINTEGER=int)
# Optional Intel oneDNN (DNNL) support, enabled when WITH_DNNL is set.
if(WITH_DNNL)
# REQUIRED: configuration fails if the DNNL package cannot be found.
find_package(DNNL REQUIRED)
# Look up the rt and dnnl libraries by name. The results are not checked
# here, so a failed lookup would surface later as a *-NOTFOUND link entry.
find_library(RT_LIB rt)
find_library(DNNL_LIB dnnl)
target_link_libraries(faiss PRIVATE ${RT_LIB} ${DNNL_LIB})
# Directory-scoped definition (not limited to the faiss target); it enables
# the sources guarded by #ifdef FAISS_WITH_DNNL.
add_definitions(-DFAISS_WITH_DNNL)
endif()
endif()

if(__AARCH64)
Expand Down
3 changes: 3 additions & 0 deletions conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class KnowhereConan(ConanFile):
"with_benchmark": [True, False],
"with_coverage": [True, False],
"with_faiss_tests": [True, False],
"with_dnnl": [True, False],
}
default_options = {
"shared": True,
Expand All @@ -50,6 +51,7 @@ class KnowhereConan(ConanFile):
"boost:without_test": True,
"fmt:header_only": True,
"with_faiss_tests": False,
"with_dnnl": False,
}

exports_sources = (
Expand Down Expand Up @@ -164,6 +166,7 @@ def generate(self):
tc.variables["WITH_BENCHMARK"] = self.options.with_benchmark
tc.variables["WITH_COVERAGE"] = self.options.with_coverage
tc.variables["WITH_FAISS_TESTS"] = self.options.with_faiss_tests
tc.variables["WITH_DNNL"] = self.options.with_dnnl
tc.generate()

deps = CMakeDeps(self)
Expand Down
11 changes: 11 additions & 0 deletions scripts/install_deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
set -e

UNAME="$(uname -s)"
CURRENT_DIR=$(pwd)

case "${UNAME}" in
Linux*) MACHINE=Linux;;
Expand Down Expand Up @@ -77,6 +78,16 @@ if [[ "${MACHINE}" == "Linux" ]]; then
sudo ln -s /usr/local/lib/libopenblasp-r0.3.21.so /usr/local/lib/libblas.so.3 && \
sudo ln -s /usr/local/lib/pkgconfig/openblas.pc /usr/local/lib/pkgconfig/blas.pc
fi

# Install Intel oneDNN v3.5-pc from source, built with the Intel icx/icpx compilers.
# Any tarball left over from a previous run is removed and the download name is
# pinned with -O: otherwise wget would save the new file as v3.5-pc.tar.gz.1 and
# tar would unpack the stale (possibly truncated) archive.
# NOTE(review): CC/CXX stay exported and the working directory stays in the
# oneDNN build directory after this step, exactly as before.
cd "${CURRENT_DIR}" && \
rm -rf oneDNN-3.5-pc v3.5-pc.tar.gz && \
wget -O v3.5-pc.tar.gz https://github.com/oneapi-src/oneDNN/archive/refs/tags/v3.5-pc.tar.gz && \
tar zxvf v3.5-pc.tar.gz && cd oneDNN-3.5-pc && \
mkdir -p build && cd build && \
export CC=icx && export CXX=icpx && \
cmake .. && make -j8 && \
sudo cmake --build . --target install
elif [[ -x "$(command -v yum)" ]]; then
# for CentOS 7
sudo yum install -y epel-release centos-release-scl-rh wget && \
Expand Down
142 changes: 79 additions & 63 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include "faiss/MetricType.h"
#include "faiss/utils/binary_distances.h"
#include "faiss/utils/distances.h"
#ifdef FAISS_WITH_DNNL
#include "faiss/utils/onednn_utils.h"
#endif
#include "knowhere/bitsetview_idselector.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/config.h"
Expand Down Expand Up @@ -86,73 +89,86 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto labels = std::make_unique<int64_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
#ifdef FAISS_WITH_DNNL
if (faiss::is_dnnl_enabled() && (faiss_metric_type == faiss::METRIC_INNER_PRODUCT) && (is_cosine == false)) {
BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels.get(), distances.get()};
faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, id_selector);
} else {
#endif
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
break;
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
dim / 8, id_selector);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf,
id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8,
id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
dim / 8, id_selector);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8,
cur_distances, cur_labels, id_selector);
break;
}
default: {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value();
return Status::invalid_metric_type;
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances,
cur_labels, id_selector);
break;
}
default: {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value();
return Status::invalid_metric_type;
}
}
return Status::success;
}));
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
return Status::success;
}));
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
#ifdef FAISS_WITH_DNNL
}
#endif
auto res = GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
Expand Down
64 changes: 48 additions & 16 deletions thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/utils.h>
#ifdef FAISS_WITH_DNNL
#include <faiss/utils/onednn_utils.h>
#endif

#ifndef FINTEGER
#define FINTEGER long
Expand Down Expand Up @@ -211,30 +214,59 @@ void exhaustive_inner_product_seq_impl(
using SingleResultHandler = typename BlockResultHandler::SingleResultHandler;
int nt = std::min(int(nx), omp_get_max_threads());

#ifdef FAISS_WITH_DNNL
if (is_dnnl_enabled()) {
float *res_arr = NULL;

comput_f32bf16f32_inner_product(nx, d, ny, d, const_cast<float*>(x), const_cast<float*>(y), &res_arr);
if (res_arr == NULL) {
printf("res_arr = NULL\n");
fflush(stderr);
exit(1);
}

#pragma omp parallel num_threads(nt)
{
SingleResultHandler resi(res);
{
SingleResultHandler resi(res);
#pragma omp for
for (int64_t i = 0; i < nx; i++) {
const float* x_i = x + i * d;
resi.begin(i);
for (size_t i = 0; i < nx; i++) {
resi.begin(i);
for (size_t j = 0; j < ny; j++) {
float ip = res_arr[i*ny + j];
resi.add_result(ip , j);
}
resi.end();
}
}
} else {
#endif
#pragma omp parallel num_threads(nt)
{
SingleResultHandler resi(res);
#pragma omp for
for (int64_t i = 0; i < nx; i++) {
const float* x_i = x + i * d;
resi.begin(i);

// the lambda that filters acceptable elements.
auto filter = [&selector](const size_t j) {
return selector.is_member(j);
};
// the lambda that filters acceptable elements.
auto filter = [&selector](const size_t j) {
return selector.is_member(j);
};

// the lambda that applies a filtered element.
auto apply = [&resi](const float ip, const idx_t j) {
resi.add_result(ip, j);
};
// the lambda that applies a filtered element.
auto apply = [&resi](const float ip, const idx_t j) {
resi.add_result(ip, j);
};

// compute distances
fvec_inner_products_ny_if(x_i, y, d, ny, filter, apply);
// compute distances
fvec_inner_products_ny_if(x_i, y, d, ny, filter, apply);

resi.end();
resi.end();
}
}
#ifdef FAISS_WITH_DNNL
}
#endif
}

template <class BlockResultHandler>
Expand Down
11 changes: 11 additions & 0 deletions thirdparty/faiss/faiss/utils/onednn_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#ifdef FAISS_WITH_DNNL
#include "onednn_utils.h"

// One descriptor instance per thread; every call on a given thread reuses it.
thread_local faiss::inner_product_desc inner_product_desc_t;

// Sets up the calling thread's descriptor with the two input matrices and
// their dimensions, then executes it, handing `out_f32` through to execut().
//
//   xrow, xcol   dimensions passed for the first input  (in_f32_1)
//   yrow, ycol   dimensions passed for the second input (in_f32_2)
//   out_f32      forwarded unchanged to inner_product_desc::execut()
//
// NOTE(review): the name suggests f32 inputs, bf16 computation and f32 output,
// and that *out_f32 is set to a buffer owned by the descriptor -- confirm
// against onednn_utils.h, which is not visible here.
void faiss::comput_f32bf16f32_inner_product(
        uint32_t xrow,
        uint32_t xcol,
        uint32_t yrow,
        uint32_t ycol,
        float* in_f32_1,
        float* in_f32_2,
        float** out_f32) {
    auto& desc = inner_product_desc_t;
    desc.init(xrow, xcol, yrow, ycol, in_f32_1, in_f32_2);
    desc.execut(out_f32);
}
#endif
Loading

0 comments on commit 7b6f49a

Please sign in to comment.