
Commit

add a common subclass for IndexNode::iterator, so that different indexes only need to implement next_batch; add impl for hnsw and ivf

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
zhengbuqian committed Apr 19, 2024
1 parent ad6f71f commit 7e667ff
Showing 11 changed files with 378 additions and 349 deletions.
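
For orientation (a minimal sketch, not code from this commit): after this change, an index-specific iterator derives from the new IndexIterator base shown below and implements only next_batch; Next(), HasNext(), result ordering, and the optional quantization refine step all live in the base class. Everything here other than IndexIterator and DistId is hypothetical.

#include "knowhere/index_node.h"  // declares knowhere::IndexIterator after this commit

// Hypothetical index-specific iterator: only next_batch is required.
class ToyIterator : public knowhere::IndexIterator {
 public:
    // larger_is_closer=false: smaller distance means closer (e.g. L2).
    ToyIterator() : IndexIterator(/*larger_is_closer=*/false) {
    }

 protected:
    // Expand the search one step and hand the newly discovered
    // (id, distance) pairs to the base class, which queues and orders them.
    void
    next_batch(std::function<void(const std::vector<knowhere::DistId>&)> batch_handler) override {
        std::vector<knowhere::DistId> batch = ExpandOneStep();  // hypothetical helper
        batch_handler(batch);
    }

 private:
    std::vector<knowhere::DistId>
    ExpandOneStep() {
        return {};  // a real index would widen its search frontier here
    }
};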
6 changes: 6 additions & 0 deletions include/knowhere/config.h
@@ -635,6 +635,7 @@ class BaseConfig : public Config {
CFG_INT trace_flags;
CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE materialized_view_search_info;
CFG_STRING opt_fields_path;
CFG_FLOAT iterator_refine_ratio;
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type)
.set_default("L2")
@@ -717,6 +718,11 @@ class BaseConfig : public Config {
.description("materialized view optional fields path")
.allow_empty_without_default()
.for_train();
KNOWHERE_CONFIG_DECLARE_FIELD(iterator_refine_ratio)
.set_default(0.5)
.description("refine ratio for iterator")
.for_iterator()
.for_range_search();
}
};
} // namespace knowhere
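Because iterator_refine_ratio is declared on BaseConfig, it can be supplied through the regular JSON config like any other knowhere parameter. A hedged usage sketch (surrounding index and dataset setup elided; key name as declared above):

// Sketch: overriding the new knob; it falls back to the 0.5 default when omitted.
knowhere::Json json;
json["metric_type"] = "L2";
json["iterator_refine_ratio"] = 0.8;
// e.g. idx.AnnIterator(*query_ds, json, nullptr) would then re-rank roughly
// that fraction of the queued candidates against raw distances.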
101 changes: 99 additions & 2 deletions include/knowhere/index_node.h
@@ -12,6 +12,11 @@
#ifndef INDEX_NODE_H
#define INDEX_NODE_H

#include <functional>
#include <queue>
#include <utility>
#include <vector>

#include "knowhere/binaryset.h"
#include "knowhere/bitsetview.h"
#include "knowhere/config.h"
@@ -22,6 +27,7 @@
#include "knowhere/version.h"

namespace knowhere {

class IndexNode : public Object {
public:
IndexNode(const int32_t ver) : version_(ver) {
@@ -57,7 +63,7 @@ class IndexNode : public Object {
virtual std::pair<int64_t, float>
Next() = 0;
[[nodiscard]] virtual bool
- HasNext() const = 0;
+ HasNext() = 0;
virtual ~iterator() {
}
};
@@ -115,6 +121,97 @@ class IndexNode : public Object {
Version version_;
};

// Common base class for iterators that expand the search range as needed. Subclasses only
// need to override `next_batch`, which adds the newly expanded candidates to the results.
// Indexes with quantization should also override `raw_distance`.
class IndexIterator : public IndexNode::iterator {
public:
IndexIterator(bool larger_is_closer, float refine_ratio = 0.0f)
: refine_ratio_(refine_ratio), refine_(refine_ratio != 0.0f), sign_(larger_is_closer ? -1 : 1) {
}

std::pair<int64_t, float>
Next() override {
if (!initialized_) {
throw std::runtime_error("Next should not be called before initialization");
}
auto& q = refined_res_.empty() ? res_ : refined_res_;
if (q.empty()) {
throw std::runtime_error("No more elements");
}
auto ret = q.top();
q.pop();
UpdateNext();
return std::make_pair(ret.id, ret.val * sign_);
}

[[nodiscard]] bool
HasNext() override {
if (!initialized_) {
throw std::runtime_error("HasNext should not be called before initialization");
}
return !res_.empty() || !refined_res_.empty();
}

virtual void
initialize() {
if (initialized_) {
throw std::runtime_error("initialize should not be called twice");
}
UpdateNext();
initialized_ = true;
}

protected:
virtual void
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) = 0;
// will be called only if refine_ratio_ is not 0.
virtual float
raw_distance(int64_t id) {
if (!refine_) {
throw std::runtime_error("raw_distance should not be called for indexes without quantization");
}
throw std::runtime_error("raw_distance not implemented");
}

const float refine_ratio_;
const bool refine_;

std::priority_queue<DistId, std::vector<DistId>, std::greater<DistId>> res_;
// unused if refine_ is false
std::priority_queue<DistId, std::vector<DistId>, std::greater<DistId>> refined_res_;

private:
inline size_t
min_refine_size() const {
// TODO: maybe make this configurable
return std::max((size_t)20, (size_t)(res_.size() * refine_ratio_));
}

void
UpdateNext() {
auto batch_handler = [this](const std::vector<DistId>& batch) {
if (batch.empty()) {
return;
}
for (const auto& dist_id : batch) {
res_.emplace(dist_id.id, dist_id.val * sign_);
}
if (refine_) {
while (!res_.empty() && (refined_res_.empty() || refined_res_.size() < min_refine_size())) {
auto pair = res_.top();
res_.pop();
refined_res_.emplace(pair.id, raw_distance(pair.id) * sign_);
}
}
};
next_batch(batch_handler);
}

bool initialized_ = false;
const int64_t sign_;
};
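
A brief gloss on the class above (our summary, not text from the commit): initialize() must be called exactly once before HasNext()/Next(), which is why both guard on initialized_. When refine_ratio is non-zero, UpdateNext() pops candidates off the quantized queue res_, re-scores them with raw_distance(), and tops refined_res_ up to max(20, res_.size() * refine_ratio_) entries when enough candidates are available; with refine_ratio_ = 0.5 and 100 queued candidates, for example, up to 50 are re-scored. Storing val * sign_ (sign_ = -1 when larger is closer) lets one min-heap serve both metric directions, and Next() multiplies by sign_ again to return the caller's original distance. A minimal usage sketch (SomeIterator is a hypothetical IndexIterator subclass):

auto it = std::make_unique<SomeIterator>(/*larger_is_closer=*/false, /*refine_ratio=*/0.5f);
it->initialize();                  // required first; Next()/HasNext() throw otherwise
while (it->HasNext()) {
    auto [id, dist] = it->Next();  // candidates arrive best-first
    // ... consume (id, dist) ...
}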

// An iterator implementation that accepts a list of distances and ids and returns them in order.
class PrecomputedDistanceIterator : public IndexNode::iterator {
public:
@@ -147,7 +244,7 @@ class PrecomputedDistanceIterator : public IndexNode::iterator {
}

[[nodiscard]] bool
- HasNext() const override {
+ HasNext() override {
return next_ < results_.size() && results_[next_].id != -1;
}

62 changes: 31 additions & 31 deletions src/index/hnsw/hnsw.cc
@@ -223,47 +223,45 @@ class HnswIndexNode : public IndexNode {
}

private:
- class iterator : public IndexNode::iterator {
+ class iterator : public IndexIterator {
public:
iterator(const hnswlib::HierarchicalNSW<DistType, quant_type>* index, const char* query, const bool transform,
- const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf)
- : index_(index),
+ const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf,
+ const float refine_ratio = 0.5f)
+ : IndexIterator(transform, (hnswlib::HierarchicalNSW<DistType, quant_type>::sq_enabled &&
+ hnswlib::HierarchicalNSW<DistType, quant_type>::has_raw_data)
+ ? refine_ratio
+ : 0.0f),
+ index_(index),
transform_(transform),
workspace_(index_->getIteratorWorkspace(query, seed_ef, for_tuning, bitset)) {
- UpdateNext();
}

- std::pair<int64_t, DistType>
- Next() override {
- auto ret = std::make_pair(next_id_, next_dist_);
- UpdateNext();
- return ret;
- }

- [[nodiscard]] bool
- HasNext() const override {
- return has_next_;
- }

- private:
+ protected:
void
- UpdateNext() {
- auto next = index_->getIteratorNext(workspace_.get());
- if (next.has_value()) {
- auto [dist, id] = next.value();
- next_dist_ = transform_ ? (-dist) : dist;
- next_id_ = id;
- has_next_ = true;
- } else {
- has_next_ = false;
+ next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
+ index_->getIteratorNextBatch(workspace_.get());
+ if (transform_) {
+ for (auto& p : workspace_->dists) {
+ p.val = -p.val;
+ }
}
+ batch_handler(workspace_->dists);
+ workspace_->dists.clear();
}
+ float
+ raw_distance(int64_t id) override {
+ if constexpr (hnswlib::HierarchicalNSW<DistType, quant_type>::sq_enabled &&
+ hnswlib::HierarchicalNSW<DistType, quant_type>::has_raw_data) {
+ return (transform_ ? -1 : 1) * index_->calcRefineDistance(workspace_->raw_query_data.get(), id);
+ }
+ throw std::runtime_error("raw_distance not supported: index does not have raw data or sq is not enabled");
+ }

+ private:
const hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
const bool transform_;
std::unique_ptr<hnswlib::IteratorWorkspace> workspace_;
- bool has_next_;
- DistType next_dist_;
- int64_t next_id_;
};

public:
@@ -287,8 +285,10 @@ class HnswIndexNode : public IndexNode {
for (int i = 0; i < nq; ++i) {
futs.emplace_back(search_pool_->push([&, i]() {
auto single_query = (const char*)xq + i * index_->data_size_;
- vec[i].reset(new iterator(this->index_, single_query, transform, bitset, hnsw_cfg.for_tuning.value(),
- hnsw_cfg.seed_ef.value()));
+ auto it = new iterator(this->index_, single_query, transform, bitset, hnsw_cfg.for_tuning.value(),
+ hnsw_cfg.seed_ef.value(), hnsw_cfg.iterator_refine_ratio.value());
+ it->initialize();
+ vec[i].reset(it);
}));
}
// wait for initial search(in top layers and search for seed_ef in base layer) to finish
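A note on the lifecycle change above (our gloss, not the commit's wording): the first search step no longer runs in the iterator's constructor; AnnIterator must call initialize() explicitly before handing the iterator out, and Next()/HasNext() throw if it is skipped. Condensed from the hunk above:

// Two-phase construction, as AnnIterator now does it (names from hnsw.cc):
auto it = new iterator(this->index_, single_query, transform, bitset,
                       hnsw_cfg.for_tuning.value(), hnsw_cfg.seed_ef.value(),
                       hnsw_cfg.iterator_refine_ratio.value());
it->initialize();  // performs the initial next_batch
vec[i].reset(it);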
58 changes: 25 additions & 33 deletions src/index/ivf/ivf.cc
@@ -247,14 +247,25 @@ class IvfIndexNode : public IndexNode {
Status
TrainInternal(const DataSet& dataset, const Config& cfg);

+ static constexpr bool
+ IsQuantized() {
+ return std::is_same_v<IndexType, faiss::IndexIVFPQ> ||
+ std::is_same_v<IndexType, faiss::IndexIVFScalarQuantizer> ||
+ std::is_same_v<IndexType, faiss::IndexIVFScalarQuantizerCC> ||
+ std::is_same_v<IndexType, faiss::IndexScaNN>;
+ }
+
private:
// only support IVFFlat and IVFFlatCC
- // iterator will own the copied_norm_query
- class iterator : public IndexNode::iterator {
+ // TODO: iterator should copy and own query data.
+ class iterator : public IndexIterator {
public:
iterator(const IndexType* index, const float* query_data, std::unique_ptr<float[]>&& copied_norm_query,
- const BitsetView& bitset, size_t nprobe)
- : index_(index), copied_norm_query_(std::move(copied_norm_query)) {
+ const BitsetView& bitset, size_t nprobe, bool larger_is_closer, const float refine_ratio = 0.5f)
+ : IndexIterator(larger_is_closer, IsQuantized() ? refine_ratio : 0.0f),
+ index_(index),
+ copied_norm_query_(std::move(copied_norm_query)) {
if (copied_norm_query_ != nullptr) {
query_data = copied_norm_query_.get();
}
@@ -268,44 +279,22 @@
ivf_search_params_.max_codes = 0;

workspace_ = index_->getIteratorWorkspace(query_data, &ivf_search_params_);
- UpdateNext();
}

- std::pair<int64_t, float>
- Next() override {
- auto ret = std::make_pair(next_id_, next_dist_);
- UpdateNext();
- return ret;
- }

- [[nodiscard]] bool
- HasNext() const override {
- return has_next_;
- }

- private:
+ protected:
void
- UpdateNext() {
- auto next = index_->getIteratorNext(workspace_.get());
- if (next.has_value()) {
- auto [dist, id] = next.value();
- next_dist_ = dist;
- next_id_ = id;
- has_next_ = true;
- } else {
- has_next_ = false;
- }
+ next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
+ index_->getIteratorNextBatch(workspace_.get(), res_.size());
+ batch_handler(workspace_->dists);
+ workspace_->dists.clear();
}

+ private:
const IndexType* index_ = nullptr;
std::unique_ptr<faiss::IVFFlatIteratorWorkspace> workspace_ = nullptr;
std::unique_ptr<float[]> copied_norm_query_ = nullptr;
std::unique_ptr<BitsetViewIDSelector> bw_idselector_ = nullptr;
faiss::IVFSearchParameters ivf_search_params_;

- bool has_next_ = false;
- float next_dist_ = 0.0f;
- int64_t next_id_ = -1;
};

std::unique_ptr<IndexType> index_;
@@ -906,6 +895,7 @@ IvfIndexNode<DataType, IndexType>::AnnIterator(const DataSet& dataset, const Con

const IvfConfig& ivf_cfg = static_cast<const IvfConfig&>(cfg);
bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE);
auto larger_is_closer = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::IP) || is_cosine;

size_t nprobe = ivf_cfg.nprobe.value();

@@ -921,8 +911,10 @@
}

// the iterator only owns the copied_norm_query.
- vec[index].reset(
- new iterator(index_.get(), cur_query, std::move(copied_norm_query), bitset, nprobe));
+ auto it = new iterator(index_.get(), cur_query, std::move(copied_norm_query), bitset, nprobe,
+ larger_is_closer, ivf_cfg.iterator_refine_ratio.value());
+ it->initialize();
+ vec[index].reset(it);
}));
}

8 changes: 7 additions & 1 deletion tests/ut/test_iterator.cc
@@ -129,7 +129,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
return json;
};

- auto rand = GENERATE(1, 2, 3, 5);
+ auto rand = GENERATE(1, 2);

const auto train_ds = GenDataSet(nb, dim, rand);
const auto query_ds = GenDataSet(nq, dim, rand + 777);
@@ -138,6 +138,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
@@ -169,6 +171,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
@@ -203,6 +207,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();