Skip to content

Commit

Permalink
add iterator for ivf sq
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
  • Loading branch information
cqy123456 committed Mar 27, 2024
1 parent e141e8b commit 47db8bd
Show file tree
Hide file tree
Showing 7 changed files with 474 additions and 369 deletions.
11 changes: 7 additions & 4 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ class IvfIndexNode : public IndexNode {
}

const IndexType* index_ = nullptr;
std::unique_ptr<faiss::IVFFlatIteratorWorkspace> workspace_ = nullptr;
std::unique_ptr<faiss::IVFIteratorWorkspace> workspace_ = nullptr;
std::unique_ptr<float[]> copied_norm_query_ = nullptr;
std::unique_ptr<BitsetViewIDSelector> bw_idselector_ = nullptr;
faiss::IVFSearchParameters ivf_search_params_;
Expand Down Expand Up @@ -891,10 +891,13 @@ IvfIndexNode<DataType, IndexType>::AnnIterator(const DataSet& dataset, const Con
return expected<std::vector<std::shared_ptr<IndexNode::iterator>>>::Err(Status::index_not_trained,
"index not trained");
}
// only support IVFFlat and IVFFlatCC;
// only support IVFFlat, IVFFlatCC, IndexIVFScalarQuantizer and IndexIVFScalarQuantizerCC;
if constexpr (!std::is_same<faiss::IndexIVFFlatCC, IndexType>::value &&
!std::is_same<faiss::IndexIVFFlat, IndexType>::value) {
LOG_KNOWHERE_WARNING_ << "Current index_type: " << Type() << ", only IVFFlat and IVFFlatCC support Iterator.";
!std::is_same<faiss::IndexIVFFlat, IndexType>::value &&
!std::is_same<faiss::IndexIVFScalarQuantizer, IndexType>::value &&
!std::is_same<faiss::IndexIVFScalarQuantizerCC, IndexType>::value) {
LOG_KNOWHERE_WARNING_ << "Current index_type: " << Type()
<< ", only IVFFlat, IVFFlatCC and IVFSQ8 support Iterator.";
return expected<std::vector<std::shared_ptr<IndexNode::iterator>>>::Err(Status::not_implemented,
"index not supported");
} else {
Expand Down
38 changes: 30 additions & 8 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
return json;
};

auto ivfflat_gen = [&base_gen]() {
auto ivf_base_gen = [&base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NPROBE] = 16;
json[knowhere::indexparam::NLIST] = 24;
Expand All @@ -129,6 +129,15 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
return json;
};

auto ivf_sq_cc_gen = [&base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NPROBE] = 18;
json[knowhere::indexparam::NLIST] = 24;
json[knowhere::indexparam::SSIZE] = 32;
json[knowhere::indexparam::CODE_SIZE] = 8;
return json;
};

auto rand = GENERATE(1, 2, 3, 5);

const auto train_ds = GenDataSet(nb, dim, rand);
Expand All @@ -138,8 +147,10 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
Expand Down Expand Up @@ -169,8 +180,10 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
Expand Down Expand Up @@ -203,8 +216,10 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
Expand Down Expand Up @@ -263,6 +278,12 @@ TEST_CASE("Test Iterator IVFFlatCC With Newly Insert Vectors", "[float metrics]
return json;
};

auto ivf_sq_cc_gen = [&ivfflatcc_gen]() {
knowhere::Json json = ivfflatcc_gen();
json[knowhere::indexparam::CODE_SIZE] = 8;
return json;
};

auto rand = GENERATE(1, 3, 5);

const auto train_ds = GenDataSet(nb, dim, rand);
Expand All @@ -272,7 +293,8 @@ TEST_CASE("Test Iterator IVFFlatCC With Newly Insert Vectors", "[float metrics]
SECTION("Test Search using iterator with newly inserted vectors") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
{make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json, nb_new);
Expand Down
Loading

0 comments on commit 47db8bd

Please sign in to comment.