Skip to content

Commit

Permalink
Ensure topk results for IVF_FLAT_CC (#353)
Browse files Browse the repository at this point in the history
Signed-off-by: chasingegg <chao.gao@zilliz.com>
  • Loading branch information
chasingegg committed Jan 19, 2024
1 parent 4a28aaa commit 433de97
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 24 deletions.
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ constexpr const char* M = "m"; // PQ param for IVFPQ
constexpr const char* SSIZE = "ssize";
constexpr const char* REORDER_K = "reorder_k";
constexpr const char* WITH_RAW_DATA = "with_raw_data";
constexpr const char* ENSURE_TOPK_FULL = "ensure_topk_full";
// RAFT Params
constexpr const char* REFINE_RATIO = "refine_ratio";
constexpr const char* CACHE_DATASET_ON_DEVICE = "cache_dataset_on_device";
Expand Down
15 changes: 12 additions & 3 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,17 +541,26 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSet& dataset, const Config&
distances[i + offset] = static_cast<float>(i_distances[i + offset]);
}
}
} else if constexpr (std::is_same<IndexType, faiss::IndexIVFFlat>::value) {
} else if constexpr (std::is_same<IndexType, faiss::IndexIVFFlatCC>::value) {
auto cur_query = (const float*)data + index * dim;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
cur_query = copied_query.get();
}

faiss::IVFSearchParameters ivf_search_params;
ivf_search_params.nprobe = nprobe;
ivf_search_params.max_codes = 0;

ivf_search_params.sel = id_selector;
// Forward the per-search ensure_topk_full option to the faiss search parameters.
ivf_search_params.ensure_topk_full = ivf_cfg.ensure_topk_full.value();
if (ivf_search_params.ensure_topk_full) {
// Allow every inverted list to be probed; the requested nprobe is only used
// below to derive the early-termination budget.
ivf_search_params.nprobe = index_->nlist;
// Use max_codes for early termination: the budget is the nprobe/nlist
// fraction of (ntotal - bitset.count()).
// NOTE(review): the floating-point product is implicitly converted on
// assignment; if it truncates to 0, confirm how faiss interprets
// max_codes == 0 (the else-branch below passes 0 on purpose).
ivf_search_params.max_codes =
(nprobe * 1.0 / index_->nlist) * (index_->ntotal - bitset.count());
} else {
// Option disabled: probe exactly nprobe lists and leave max_codes at 0.
ivf_search_params.nprobe = nprobe;
ivf_search_params.max_codes = 0;
}

index_->search(1, cur_query, k, distances + offset, ids + offset, &ivf_search_params);
} else if constexpr (std::is_same<IndexType, faiss::IndexScaNN>::value) {
Expand Down
5 changes: 5 additions & 0 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class IvfConfig : public BaseConfig {
CFG_INT nlist;
CFG_INT nprobe;
CFG_BOOL use_elkan;
CFG_BOOL ensure_topk_full;
KNOHWERE_DECLARE_CONFIG(IvfConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(nlist)
.set_default(128)
Expand All @@ -36,6 +37,10 @@ class IvfConfig : public BaseConfig {
.set_default(true)
.description("whether to use elkan algorithm")
.for_train();
KNOWHERE_CONFIG_DECLARE_FIELD(ensure_topk_full)
.set_default(true)
.description("whether to make sure topk results full")
.for_search();
}
};

Expand Down
1 change: 1 addition & 0 deletions tests/ut/test_ivfflat_cc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 128;
json[knowhere::indexparam::NPROBE] = 16;
json[knowhere::indexparam::ENSURE_TOPK_FULL] = false;
return json;
};

Expand Down
26 changes: 26 additions & 0 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,32 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
REQUIRE(recall > kBruteForceRecallThreshold);
}

// Builds an IVF_FLAT_CC index and searches it with topk = nb while nprobe is
// only 1 of 16 lists, then requires the recall against brute force to exceed
// kBruteForceRecallThreshold.
SECTION("Test Search with IVFFLATCC ensure topk full") {
using std::make_tuple;
// Config generator: starts from base_gen() and overrides nlist, nprobe,
// ssize and topk. ENSURE_TOPK_FULL is not set here, so the search relies on
// the config default (declared as true in IvfConfig).
// NOTE(review): nb is presumably the number of base vectors — confirm
// against the enclosing TEST_CASE.
auto ivfflatcc_gen_ = [base_gen, nb]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 16;
json[knowhere::indexparam::NPROBE] = 1;
json[knowhere::indexparam::SSIZE] = 48;
json[knowhere::meta::TOPK] = nb;
return json;
};
// Single-row table: only the IVF_FLAT_CC index type is exercised.
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
// Round-trip the config through its string dump; the dump is also captured
// so it is reported alongside any failing assertion.
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);

// Search without a bitset (nullptr) and compare with brute-force results
// computed from the same config.
auto results = idx.Search(*query_ds, json, nullptr);
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, json, nullptr);
float recall = GetKNNRecall(*gt.value(), *results.value());
REQUIRE(recall > kBruteForceRecallThreshold);
}

SECTION("Test Search with Bitset") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
Expand Down
16 changes: 11 additions & 5 deletions thirdparty/faiss/faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ void IndexIVF::search_preassigned(

const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
idx_t max_codes = params ? params->max_codes : this->max_codes;
bool ensure_topk_full = params ? params->ensure_topk_full : false;
IDSelector* sel = params ? params->sel : nullptr;
const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
if (selr) {
Expand Down Expand Up @@ -545,7 +546,7 @@ void IndexIVF::search_preassigned(

return list_size;
} else {
size_t scan_cnt = 0;
size_t scan_cnt = 0; // only record valid cnt

size_t segment_num = invlists->get_segment_num(key);
for (size_t segment_idx = 0; segment_idx < segment_num; segment_idx++) {
Expand All @@ -570,8 +571,8 @@ void IndexIVF::search_preassigned(
ids,
simi,
idxi,
k);
scan_cnt += segment_size;
k,
scan_cnt);
}

return scan_cnt;
Expand Down Expand Up @@ -613,7 +614,9 @@ void IndexIVF::search_preassigned(
simi,
idxi,
max_codes - nscan);
if (nscan >= max_codes) {

// If ensure_topk_full is enabled, stop only once nscan >= k as well, so the search does not end before k valid codes have been scanned.
if (nscan >= max_codes && (!ensure_topk_full || nscan >= k)) {
break;
}
}
Expand Down Expand Up @@ -1306,13 +1309,15 @@ size_t InvertedListScanner::scan_codes(
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const {
size_t k,
size_t& scan_cnt) const {
size_t nup = 0;

if (!keep_max) {
for (size_t j = 0; j < list_size; j++) {
// // todo aguzhva: use int64_t id instead of j ?
if (!sel || sel->is_member(j)) {
scan_cnt++;
float dis = distance_to_code(codes);
if (code_norms) {
dis /= code_norms[j];
Expand All @@ -1329,6 +1334,7 @@ size_t InvertedListScanner::scan_codes(
for (size_t j = 0; j < list_size; j++) {
// // todo aguzhva: use int64_t id instead of j ?
if (!sel || sel->is_member(j)) {
scan_cnt++;
float dis = distance_to_code(codes);
if (code_norms) {
dis /= code_norms[j];
Expand Down
4 changes: 3 additions & 1 deletion thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ struct Level1Quantizer {
struct SearchParametersIVF : SearchParameters {
size_t nprobe = 1; ///< number of probes at query time
size_t max_codes = 0; ///< max nb of codes to visit to do a query
bool ensure_topk_full = false; ///< indicate whether we make sure topk result is full
SearchParameters* quantizer_params = nullptr;

virtual ~SearchParametersIVF() {}
Expand Down Expand Up @@ -493,7 +494,8 @@ struct InvertedListScanner {
const idx_t* ids,
float* distances,
idx_t* labels,
size_t k) const;
size_t k,
size_t& scan_cnt) const;

// same as scan_codes, using an iterator
virtual size_t iterate_codes(
Expand Down
12 changes: 8 additions & 4 deletions thirdparty/faiss/faiss/IndexIVFFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,18 +286,20 @@ struct IVFFlatScanner : InvertedListScanner {
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
const float* list_vecs = (const float*)codes;
size_t nup = 0;

// the lambda that filters acceptable elements.
auto filter =
[&](const size_t j) { return (!use_sel || sel->is_member(ids[j])); };

// the lambda that applies a filtered element.
// the lambda that applies a valid element.
auto apply =
[&](const float dis_in, const size_t j) {
const float dis = (code_norms == nullptr) ? dis_in : (dis_in / code_norms[j]);
scan_cnt++;
if (C::cmp(simi[0], dis)) {
const int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
heap_replace_top<C>(k, simi, idxi, dis, id);
Expand Down Expand Up @@ -389,18 +391,20 @@ struct IVFFlatBitsetViewScanner : InvertedListScanner {
const idx_t* __restrict ids,
float* __restrict simi,
idx_t* __restrict idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
const float* list_vecs = (const float*)codes;
size_t nup = 0;

// the lambda that filters acceptable elements.
auto filter =
[&](const size_t j) { return (!use_sel || !bitset.test(ids[j])); };

// the lambda that applies a filtered element.
// the lambda that applies a valid element.
auto apply =
[&](const float dis_in, const size_t j) {
const float dis = (code_norms == nullptr) ? dis_in : (dis_in / code_norms[j]);
scan_cnt++;
if (C::cmp(simi[0], dis)) {
const int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
heap_replace_top<C>(k, simi, idxi, dis, id);
Expand Down
3 changes: 2 additions & 1 deletion thirdparty/faiss/faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,8 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
const idx_t* ids,
float* heap_sim,
idx_t* heap_ids,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
KnnSearchResults<C, use_sel> res = {
/* key */ this->key,
/* ids */ this->store_pairs ? nullptr : ids,
Expand Down
5 changes: 3 additions & 2 deletions thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,13 @@ struct IVFScanner : InvertedListScanner {
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
size_t nup = 0;
for (size_t j = 0; j < list_size; j++) {
if (!sel || sel->is_member(ids[j])) {
float dis = hc.compute(codes);

scan_cnt++;
if (dis < simi[0]) {
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
maxheap_replace_top(k, simi, idxi, dis, id);
Expand Down
3 changes: 2 additions & 1 deletion thirdparty/faiss/faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ void IndexScalarQuantizer::search(
minheap_heapify(k, D, I);
}
scanner->set_query(x + i * d);
scanner->scan_codes(ntotal, codes.data(), nullptr, nullptr, D, I, k);
size_t scan_cnt = 0;
scanner->scan_codes(ntotal, codes.data(), nullptr, nullptr, D, I, k, scan_cnt);

// re-order heap
if (metric_type == METRIC_L2) {
Expand Down
6 changes: 4 additions & 2 deletions thirdparty/faiss/faiss/impl/ScalarQuantizerScanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ struct IVFSQScannerIP : InvertedListScanner {
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
size_t nup = 0;

for (size_t j = 0; j < list_size; j++, codes += code_size) {
Expand Down Expand Up @@ -215,7 +216,8 @@ struct IVFSQScannerL2 : InvertedListScanner {
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
size_t nup = 0;

// // baseline
Expand Down
14 changes: 9 additions & 5 deletions thirdparty/faiss/tests/test_lowlevel_ivf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,15 @@ void test_lowlevel_access(const char* index_key, MetricType metric) {

// here we get the inverted lists from the InvertedLists
// object but they could come from anywhere

size_t scan_cnt = 0;
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
D.data(),
I.data(),
k);
k,
scan_cnt);

if (j == 0) {
// all results so far come from list_no, so let's check if
Expand Down Expand Up @@ -338,14 +339,15 @@ void test_lowlevel_access_binary(const char* index_key) {

// here we get the inverted lists from the InvertedLists
// object but they could come from anywhere

size_t scan_cnt = 0;
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
D.data(),
I.data(),
k);
k,
scan_cnt);

if (j == 0) {
// all results so far come from list_no, so let's check if
Expand Down Expand Up @@ -500,13 +502,15 @@ void test_threaded_search(const char* index_key, MetricType metric) {
continue;
scanner->set_list(list_no, q_dis[i * nprobe + j]);

size_t scan_cnt = 0;
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
local_D,
local_I,
k);
k,
scan_cnt);
}
};

Expand Down

0 comments on commit 433de97

Please sign in to comment.