Skip to content

Commit

Permalink
add iterator for ivf sq
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
  • Loading branch information
cqy123456 committed Feb 27, 2024
1 parent 60a5c9c commit ca8f376
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 230 deletions.
8 changes: 5 additions & 3 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class IvfIndexNode : public IndexNode {
}

const IndexType* index_ = nullptr;
std::unique_ptr<faiss::IVFFlatIteratorWorkspace> workspace_ = nullptr;
std::unique_ptr<faiss::IVFIteratorWorkspace> workspace_ = nullptr;
std::unique_ptr<float[]> copied_norm_query_ = nullptr;
std::unique_ptr<BitsetViewIDSelector> bw_idselector_ = nullptr;
faiss::IVFSearchParameters ivf_search_params_;
Expand Down Expand Up @@ -829,8 +829,10 @@ IvfIndexNode<DataType, IndexType>::AnnIterator(const DataSet& dataset, const Con
}
// only support IVFFlat and IVFFlatCC;
if constexpr (!std::is_same<faiss::IndexIVFFlatCC, IndexType>::value &&
!std::is_same<faiss::IndexIVFFlat, IndexType>::value) {
LOG_KNOWHERE_WARNING_ << "Current index_type: " << Type() << ", only IVFFlat and IVFFlatCC support Iterator.";
!std::is_same<faiss::IndexIVFFlat, IndexType>::value &&
!std::is_same<faiss::IndexIVFScalarQuantizer, IndexType>::value) {
LOG_KNOWHERE_WARNING_ << "Current index_type: " << Type()
<< ", only IVFFlat, IVFFlatCC and IVFSQ8 support Iterator.";
return expected<std::vector<std::shared_ptr<IndexNode::iterator>>>::Err(Status::not_implemented,
"index not supported");
} else {
Expand Down
9 changes: 6 additions & 3 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfflat_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
Expand Down Expand Up @@ -170,7 +171,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfflat_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
Expand Down Expand Up @@ -204,7 +206,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfflat_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
Expand Down
178 changes: 176 additions & 2 deletions thirdparty/faiss/faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,6 @@ void IndexIVF::range_search_preassigned(
bool store_pairs,
const IVFSearchParameters* params,
IndexIVFStats* stats) const {

// Knowhere-specific code:
// only "parallel_mode == 0" branch is supported.

Expand Down Expand Up @@ -933,6 +932,181 @@ void IndexIVF::range_search_preassigned(
}
}

// Creates the per-query state object consumed by getIteratorNext().
// Both pointers are forwarded to the IVFIteratorWorkspace constructor as-is.
std::unique_ptr<IVFIteratorWorkspace> IndexIVF::getIteratorWorkspace(
        const float* query_data,
        const IVFSearchParameters* ivfsearchParams) const {
    auto iterator_workspace =
            std::make_unique<IVFIteratorWorkspace>(query_data, ivfsearchParams);
    return iterator_workspace;
}

// Returns the next (distance, id) pair for the query held in `workspace`, or
// std::nullopt once every coarse list has been visited and the backup heap is
// empty.
//
// On the first call the coarse quantizer ranks all `nlist` centroids and a
// snapshot of every inverted-list size is taken. Afterwards, each call tops
// the backup heap up by scanning further coarse lists (in quantizer order)
// whenever it holds fewer than `backup_count_threshold` entries, then pops the
// heap root.
//
// @param workspace  per-query state created by getIteratorWorkspace();
//                   mutated by this call.
std::optional<std::pair<float, idx_t>> IndexIVF::getIteratorNext(
        IVFIteratorWorkspace* workspace) const {
    // May be null: every access below goes through this null-checked alias.
    const IVFSearchParameters* params = workspace->search_params;

    // Scans at most `max_codes` codes of one inverted list and pushes the
    // results onto the (distances, labels) heap, growing `counter_back`.
    auto scan_one_list_then_add_to_backup =
            [&](idx_t list_no,
                float coarse_list_centroid_dist, // no use, dist for residual.
                float* distances,
                idx_t* labels,
                size_t& counter_back,
                size_t max_codes) {
                if (list_no < 0) {
                    // not enough centroids for multiprobe
                    return;
                }
                FAISS_THROW_IF_NOT_FMT(
                        list_no < (idx_t)nlist,
                        "Invalid list_no=%" PRId64 " nlist=%zd\n",
                        list_no,
                        nlist);

                // don't waste time on empty lists
                if (invlists->is_empty(list_no)) {
                    return;
                }

                // get scanner
                IDSelector* sel = params ? params->sel : nullptr;
                InvertedListScanner* scanner =
                        get_InvertedListScanner(false, sel);
                // Guard the scanner immediately so it is released even if
                // set_query()/set_list() throw.
                ScopeDeleter1<InvertedListScanner> del(scanner);
                scanner->set_query(workspace->query_data);
                scanner->set_list(list_no, coarse_list_centroid_dist);

                const size_t segment_num = invlists->get_segment_num(list_no);
                size_t scan_cnt = 0;
                for (size_t segment_idx = 0; segment_idx < segment_num;
                     segment_idx++) {
                    const size_t segment_size =
                            invlists->get_segment_size(list_no, segment_idx);
                    // scan_cnt never exceeds max_codes, so the subtraction
                    // cannot wrap.
                    const size_t should_scan_size =
                            std::min(segment_size, max_codes - scan_cnt);
                    if (should_scan_size == 0) {
                        break;
                    }
                    scan_cnt += should_scan_size;

                    const size_t segment_offset =
                            invlists->get_segment_offset(list_no, segment_idx);
                    InvertedLists::ScopedCodes scodes(
                            invlists, list_no, segment_offset);
                    InvertedLists::ScopedCodeNorms scode_norms(
                            invlists, list_no, segment_offset);
                    InvertedLists::ScopedIds sids(
                            invlists, list_no, segment_offset);

                    scanner->scan_codes_and_push_back(
                            should_scan_size,
                            scodes.get(),
                            scode_norms.get(),
                            sids.get(),
                            distances,
                            labels,
                            counter_back);
                }
            };

    if (!workspace->initial_search_done) {
        // snapshot of list_sizes;
        auto coarse_list_sizes = std::make_unique<size_t[]>(nlist);
        size_t count = 0;
        for (size_t list_no = 0; list_no < nlist; ++list_no) {
            const size_t list_size = invlists->list_size(list_no);
            coarse_list_sizes[list_no] = list_size;
            count += list_size;
            workspace->max_coarse_list_size =
                    std::max(workspace->max_coarse_list_size, list_size);
        }

        // compute backup_count_threshold - (nprobe / nlist) * count
        // search_params is optional: fall back to the index-level nprobe
        // when it is absent or leaves nprobe at 0.
        size_t effective_nprobe =
                (params && params->nprobe) ? params->nprobe : this->nprobe;
        effective_nprobe = std::min(nlist, effective_nprobe);
        workspace->backup_count_threshold = count * effective_nprobe / nlist;
        workspace->max_backup_count = workspace->max_coarse_list_size +
                workspace->backup_count_threshold;

        // compute distances of all centroids
        auto coarse_idx = std::make_unique<idx_t[]>(nlist);
        auto coarse_dis = std::make_unique<float[]>(nlist);
        quantizer->search(
                1,
                workspace->query_data,
                nlist,
                coarse_dis.get(),
                coarse_idx.get(),
                params ? params->quantizer_params : nullptr);

        invlists->prefetch_lists(coarse_idx.get(), effective_nprobe);

        // The heap starts empty; the shared refill loop below performs the
        // first scans, so no separate initial scan loop is needed.
        workspace->labels =
                std::make_unique<idx_t[]>(workspace->max_backup_count);
        workspace->distances =
                std::make_unique<float[]>(workspace->max_backup_count);
        workspace->backup_count = 0;
        workspace->next_visit_coarse_list_idx = 0;
        workspace->coarse_idx = std::move(coarse_idx);
        workspace->coarse_dis = std::move(coarse_dis);
        workspace->coarse_list_sizes = std::move(coarse_list_sizes);

        workspace->initial_search_done = true;
    }

    // Refill the heap while it is below the threshold. The extra
    // `backup_count == 0` clause keeps scanning when the threshold itself is
    // 0, so a non-empty index is never reported as exhausted prematurely.
    // Capacity is safe: a scan only starts with fewer than
    // backup_count_threshold entries (or none) and adds at most
    // max_coarse_list_size entries.
    while ((workspace->backup_count < workspace->backup_count_threshold ||
            workspace->backup_count == 0) &&
           workspace->next_visit_coarse_list_idx < nlist) {
        const size_t next_list_idx = workspace->next_visit_coarse_list_idx;
        const idx_t list_no = workspace->coarse_idx[next_list_idx];
        // The quantizer may return -1 when it has fewer centroids than
        // requested; never use such an id to index the size snapshot.
        const size_t max_codes =
                (list_no >= 0 && list_no < static_cast<idx_t>(nlist))
                ? workspace->coarse_list_sizes[list_no]
                : 0;
        scan_one_list_then_add_to_backup(
                list_no,
                workspace->coarse_dis[next_list_idx],
                workspace->distances.get(),
                workspace->labels.get(),
                workspace->backup_count,
                max_codes);
        workspace->next_visit_coarse_list_idx++;
    }

    // terminate when no backup nodes.
    // Checked after the refill so an empty heap is never popped: reaching
    // here with backup_count == 0 implies every coarse list was visited.
    if (workspace->backup_count == 0) {
        return std::nullopt;
    }

    const float next_dis = workspace->distances[0];
    const idx_t next_id = workspace->labels[0];
    if (metric_type == METRIC_INNER_PRODUCT) {
        heap_pop<CMax<float, int64_t>>(
                workspace->backup_count,
                workspace->distances.get(),
                workspace->labels.get());
    } else {
        heap_pop<CMin<float, int64_t>>(
                workspace->backup_count,
                workspace->distances.get(),
                workspace->labels.get());
    }
    workspace->backup_count--;

    return std::make_optional(std::make_pair(next_dis, next_id));
}

InvertedListScanner* IndexIVF::get_InvertedListScanner(
bool /*store_pairs*/,
const IDSelector* /* sel */) const {
Expand Down Expand Up @@ -1351,7 +1525,7 @@ size_t InvertedListScanner::scan_codes(
return nup;
}

size_t InvertedListScanner::scan_codes_and_push_back(
void InvertedListScanner::scan_codes_and_push_back(
size_t list_size,
const uint8_t* codes,
const float* code_norms,
Expand Down
40 changes: 39 additions & 1 deletion thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <stdint.h>
#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>

Expand Down Expand Up @@ -89,6 +90,30 @@ struct InvertedListScanner;
struct IndexIVFStats;
struct CodePacker;

// Mutable per-query state shared between successive iterator steps of an IVF
// index. The query vector and search parameters are stored as raw pointers
// (not copied), so they must stay valid for as long as the workspace is used.
struct IVFIteratorWorkspace {
    IVFIteratorWorkspace(
            const float* query_data,
            const IVFSearchParameters* search_params)
            : query_data{query_data}, search_params{search_params} {}

    // the single query vector being iterated over
    const float* query_data{nullptr};
    // optional per-search parameters
    const IVFSearchParameters* search_params{nullptr};
    // flipped to true once the first-call setup has been performed
    bool initial_search_done{false};

    // backup heap: distances and the matching ids
    std::unique_ptr<float[]> distances;
    std::unique_ptr<idx_t[]> labels;

    // number of entries currently held in the backup heap; another coarse
    // list is scanned when it drops below backup_count_threshold
    size_t backup_count{0};
    // capacity of the backup heap arrays
    size_t max_backup_count{0};
    // refill threshold: count * nprobe / nlist
    size_t backup_count_threshold{0};
    // position of the next coarse list to scan
    size_t next_visit_coarse_list_idx{0};

    // coarse centroid distances for the query
    std::unique_ptr<float[]> coarse_dis;
    // coarse centroid ids for the query
    std::unique_ptr<idx_t[]> coarse_idx;
    // per-list size snapshot, indexed by list number
    std::unique_ptr<size_t[]> coarse_list_sizes;
    // largest value recorded in the size snapshot
    size_t max_coarse_list_size{0};
};

struct IndexIVFInterface : Level1Quantizer {
size_t nprobe = 1; ///< number of probes at query time
size_t max_codes = 0; ///< max nb of codes to visit to do a query
Expand Down Expand Up @@ -313,6 +338,19 @@ struct IndexIVF : Index, IndexIVFInterface {
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

// Creates the per-query workspace that getIteratorNext() operates on.
// `query_data` (a single query vector) and `ivfsearchParams` are handed to
// the workspace as pointers.
std::unique_ptr<IVFIteratorWorkspace> getIteratorWorkspace(
const float* query_data,
const IVFSearchParameters* ivfsearchParams) const;

// Unlike a regular knn-search, the iterator does not know the size `k` of
// the result in advance.
// The workspace therefore keeps a backup heap and refills it whenever it
// holds fewer than (total stored vectors * nprobe / nlist) entries, by
// scanning the next coarse list in quantizer order.
// Each call pops and returns one (distance, id) pair from that heap;
// std::nullopt is returned once the heap is empty and every coarse list has
// been visited.
std::optional<std::pair<float, idx_t>> getIteratorNext(
IVFIteratorWorkspace* workspace) const;

/** Get a scanner for this index (store_pairs means ignore labels)
*
* The default search implementation uses this to compute the distances
Expand Down Expand Up @@ -516,7 +554,7 @@ struct InvertedListScanner {
* @param counter_back heap size (will increase)
* @return number of heap pushes performed
*/
virtual size_t scan_codes_and_push_back(
virtual void scan_codes_and_push_back(
size_t list_size,
const uint8_t* codes,
const float* code_norms,
Expand Down
Loading

0 comments on commit ca8f376

Please sign in to comment.