Skip to content

Commit

Permalink
The final candidate of the version that uses Faiss 1.7.4 (#181)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
  • Loading branch information
alexanderguzhva committed Nov 9, 2023
1 parent 6b2fe56 commit f7fd047
Show file tree
Hide file tree
Showing 395 changed files with 47,637 additions and 12,354 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ut.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
ut:
name: ut on ubuntu-20.04
runs-on: ubuntu-20.04
timeout-minutes: 60
timeout-minutes: 90
strategy:
fail-fast: false
steps:
Expand Down
12 changes: 12 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ knowhere_option(WITH_BENCHMARK "Build with benchmark" OFF)
knowhere_option(WITH_COVERAGE "Build with coverage" OFF)
knowhere_option(WITH_CCACHE "Build with ccache" ON)
knowhere_option(WITH_PROFILER "Build with profiler" OFF)
knowhere_option(WITH_FAISS_TESTS "Build with Faiss unit tests" OFF)

# this is needed for clang on ubuntu:20.04, otherwise
# the linked fails with 'undefined reference' error.
# fmt v9 was used by the time the error was encountered.
# clang on ubuntu:22.04 seems to be unaffected.
# gcc seems to be unaffected.
add_definitions(-DFMT_HEADER_ONLY)

# this is needed for clang on ubuntu:20.04, otherwise
# the linked fails with 'undefined reference' error.
Expand Down Expand Up @@ -156,6 +164,10 @@ if(WITH_BENCHMARK)
add_subdirectory(benchmark)
endif()

if(WITH_FAISS_TESTS)
add_subdirectory(tests/faiss)
endif()

install(TARGETS knowhere
DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/knowhere"
Expand Down
4 changes: 4 additions & 0 deletions cmake/libs/libfaiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ knowhere_file_glob(GLOB FAISS_AVX2_SRCS

list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX512_SRCS})

# disable RHNSW
knowhere_file_glob(GLOB FAISS_RHNSW_SRCS thirdparty/faiss/faiss/impl/RHNSW.cpp)
list(REMOVE_ITEM FAISS_SRCS ${FAISS_RHNSW_SRCS})

if(__X86_64)
set(UTILS_SRC src/simd/distances_ref.cc src/simd/hook.cc)
set(UTILS_SSE_SRC src/simd/distances_sse.cc)
Expand Down
5 changes: 5 additions & 0 deletions conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class KnowhereConan(ConanFile):
"with_ut": [True, False],
"with_benchmark": [True, False],
"with_coverage": [True, False],
"with_faiss_tests": [True, False],
}
default_options = {
"shared": True,
Expand All @@ -46,6 +47,7 @@ class KnowhereConan(ConanFile):
"with_coverage": False,
"boost:without_test": True,
"fmt:header_only": True,
"with_faiss_tests": False,
}

exports_sources = (
Expand Down Expand Up @@ -95,6 +97,8 @@ def requirements(self):
if self.options.with_benchmark:
self.requires("gtest/1.13.0")
self.requires("hdf5/1.14.0")
if self.options.with_faiss_tests:
self.requires("gtest/1.13.0")

@property
def _required_boost_components(self):
Expand Down Expand Up @@ -155,6 +159,7 @@ def generate(self):
tc.variables["WITH_UT"] = self.options.with_ut
tc.variables["WITH_BENCHMARK"] = self.options.with_benchmark
tc.variables["WITH_COVERAGE"] = self.options.with_coverage
tc.variables["WITH_FAISS_TESTS"] = self.options.with_faiss_tests
tc.generate()

deps = CMakeDeps(self)
Expand Down
33 changes: 33 additions & 0 deletions include/knowhere/bitsetview_idselector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#pragma once

#include <faiss/impl/IDSelector.h>

#include "knowhere/bitsetview.h"

namespace knowhere {

struct BitsetViewIDSelector : faiss::IDSelector {
BitsetView bitset_view;

inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} {
}

inline bool
is_member(faiss::idx_t id) const override final {
// it is by design that bitset_view.empty() is not tested here
return (!bitset_view.test(id));
}
};

} // namespace knowhere
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ namespace indexparam {
// IVF Params
constexpr const char* NPROBE = "nprobe";
constexpr const char* NLIST = "nlist";
constexpr const char* USE_ELKAN = "use_elkan";
constexpr const char* NBITS = "nbits"; // PQ/SQ
constexpr const char* M = "m"; // PQ param for IVFPQ
constexpr const char* SSIZE = "ssize";
Expand Down
50 changes: 32 additions & 18 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "faiss/MetricType.h"
#include "faiss/utils/binary_distances.h"
#include "faiss/utils/distances.h"
#include "knowhere/bitsetview_idselector.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/config.h"
#include "knowhere/expected.h"
Expand Down Expand Up @@ -67,36 +68,40 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset);
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset);
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
dim / 8, bitset);
dim / 8, id_selector);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
}
Expand All @@ -107,7 +112,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances,
cur_labels, bitset);
cur_labels, id_selector);
break;
}
default: {
Expand Down Expand Up @@ -161,36 +166,40 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset);
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset);
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
dim / 8, bitset);
dim / 8, id_selector);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
}
Expand All @@ -201,7 +210,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances,
cur_labels, bitset);
cur_labels, id_selector);
break;
}
default: {
Expand Down Expand Up @@ -263,10 +272,14 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
futs.emplace_back(pool->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
faiss::RangeSearchResult res(1);

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, id_selector);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
Expand All @@ -275,24 +288,25 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::range_search_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, radius,
&res, bitset);
&res, id_selector);
} else {
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
bitset);
id_selector);
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset);
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(faiss::METRIC_Jaccard, cur_query,
(const uint8_t*)xb, 1, nb, radius,
dim / 8, &res, id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(faiss::METRIC_Hamming, cur_query,
(const uint8_t*)xb, 1, nb, (int)radius,
dim / 8, &res, bitset);
dim / 8, &res, id_selector);
break;
}
default: {
Expand Down
1 change: 1 addition & 0 deletions src/common/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ static const std::unordered_set<std::string> ext_legal_json_keys = {"metric_type
"dim",
"nlist", // IVF param
"nprobe", // IVF param
"use_elkan", // IVF param
"ssize", // IVF_FLAT_CC param
"nbits", // IVF_PQ param
"m", // IVF_PQ param
Expand Down
33 changes: 29 additions & 4 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "faiss/index_io.h"
#include "index/flat/flat_config.h"
#include "io/memory_io.h"
#include "knowhere/bitsetview_idselector.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/factory.h"
#include "knowhere/log.h"
Expand Down Expand Up @@ -93,18 +94,31 @@ class FlatIndexNode : public IndexNode {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_ids = ids + k * index;
auto cur_dis = distances + k * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
auto cur_query = (const float*)x + dim * index;
std::unique_ptr<float[]> copied_query = nullptr;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
cur_query = copied_query.get();
}
index_->search(1, cur_query, k, cur_dis, cur_ids, bitset);

faiss::SearchParameters search_params;
search_params.sel = id_selector;

index_->search(1, cur_query, k, cur_dis, cur_ids, &search_params);
}
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
auto cur_i_dis = reinterpret_cast<int32_t*>(cur_dis);
index_->search(1, (const uint8_t*)x + index * dim / 8, k, cur_i_dis, cur_ids, bitset);

faiss::SearchParameters search_params;
search_params.sel = id_selector;

index_->search(1, (const uint8_t*)x + index * dim / 8, k, cur_i_dis, cur_ids, &search_params);

if (index_->metric_type == faiss::METRIC_Hamming) {
for (int64_t j = 0; j < k; j++) {
cur_dis[j] = static_cast<float>(cur_i_dis[j]);
Expand Down Expand Up @@ -166,17 +180,28 @@ class FlatIndexNode : public IndexNode {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
faiss::RangeSearchResult res(1);

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
auto cur_query = (const float*)xq + dim * index;
std::unique_ptr<float[]> copied_query = nullptr;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
cur_query = copied_query.get();
}
index_->range_search(1, cur_query, radius, &res, bitset);

faiss::SearchParameters search_params;
search_params.sel = id_selector;

index_->range_search(1, cur_query, radius, &res, &search_params);
}
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
index_->range_search(1, (const uint8_t*)xq + index * dim / 8, radius, &res, bitset);
faiss::SearchParameters search_params;
search_params.sel = id_selector;

index_->range_search(1, (const uint8_t*)xq + index * dim / 8, radius, &res, &search_params);
}
auto elem_cnt = res.lims[1];
result_dist_array[index].resize(elem_cnt);
Expand Down
Loading

0 comments on commit f7fd047

Please sign in to comment.