Skip to content

Commit

Permalink
use IdVal instead of std::pair
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian committed Apr 17, 2024
1 parent fc5fe59 commit bdf16b5
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 29 deletions.
21 changes: 9 additions & 12 deletions include/knowhere/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,7 @@ class IndexNode : public Object {
class IndexIterator : public IndexNode::iterator {
public:
IndexIterator(bool larger_is_closer, float refine_ratio = 0.0f)
: refine_ratio_(refine_ratio),
refine_(refine_ratio != 0.0f),
// std::priotity_queue is by default a max heap, so we need to reverse the comparator
res_(PairComparator(!larger_is_closer)),
refined_res_(PairComparator(!larger_is_closer)) {
: refine_ratio_(refine_ratio), refine_(refine_ratio != 0.0f), sign_(larger_is_closer ? -1 : 1) {
}

std::pair<int64_t, float>
Expand All @@ -150,7 +146,7 @@ class IndexIterator : public IndexNode::iterator {
auto ret = q.top();
q.pop();
UpdateNext();
return std::make_pair(ret.second, ret.first);
return std::make_pair(ret.id, ret.val * sign_);
}

[[nodiscard]] bool
Expand All @@ -174,12 +170,12 @@ class IndexIterator : public IndexNode::iterator {
throw std::runtime_error("raw_distance not implemented");
}

const bool refine_ratio_;
const float refine_ratio_;
const bool refine_;

std::priority_queue<std::pair<float, int64_t>, std::vector<std::pair<float, int64_t>>, PairComparator> res_;
std::priority_queue<DistId, std::vector<DistId>, std::greater<DistId>> res_;
// unused if refine_ is false
std::priority_queue<std::pair<float, int64_t>, std::vector<std::pair<float, int64_t>>, PairComparator> refined_res_;
std::priority_queue<DistId, std::vector<DistId>, std::greater<DistId>> refined_res_;

private:
inline size_t
Expand All @@ -194,21 +190,22 @@ class IndexIterator : public IndexNode::iterator {
if (batch.empty()) {
return;
}
for (const auto& pair : batch) {
res_.emplace(pair);
for (const auto& dist_id : batch) {
res_.emplace(dist_id.id, dist_id.val * sign_);
}
if (refine_) {
while (!res_.empty() && (refined_res_.empty() || refined_res_.size() < min_refine_size())) {
auto pair = res_.top();
res_.pop();
refined_res_.emplace(raw_distance(pair.id), pair.id);
refined_res_.emplace(pair.id, raw_distance(pair.id) * sign_);
}
}
};
next_batch(batch_handler);
}

bool initialized_ = false;
const int64_t sign_;
};

// An iterator implementation that accepts a list of distances and ids and returns them in order.
Expand Down
4 changes: 2 additions & 2 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,11 @@ class HnswIndexNode : public IndexNode {

protected:
void
next_batch(std::function<void(const std::vector<std::pair<float, int64_t>>&)> batch_handler) override {
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
index_->getIteratorNextBatch(workspace_.get());
if (transform_) {
for (auto& p : workspace_->dists) {
p.first = -p.first;
p.val = -p.val;
}
}
batch_handler(workspace_->dists);
Expand Down
7 changes: 4 additions & 3 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,12 @@ class IvfIndexNode : public IndexNode {
Status
TrainInternal(const DataSet& dataset, const Config& cfg);

static bool
static constexpr bool
IsQuantized() {
return std::is_same_v<IndexType, faiss::IndexIVFPQ> ||
std::is_same_v<IndexType, faiss::IndexIVFScalarQuantizer> ||
std::is_same_v<IndexType, faiss::IndexIVFScalarQuantizerCC>;
std::is_same_v<IndexType, faiss::IndexIVFScalarQuantizerCC> ||
std::is_same_v<IndexType, faiss::IndexScaNN>;
}

private:
Expand Down Expand Up @@ -282,7 +283,7 @@ class IvfIndexNode : public IndexNode {

protected:
void
next_batch(std::function<void(const std::vector<std::pair<float, int64_t>>&)> batch_handler) override {
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
index_->getIteratorNextBatch(workspace_.get(), res_.size());
batch_handler(workspace_->dists);
workspace_->dists.clear();
Expand Down
4 changes: 3 additions & 1 deletion thirdparty/faiss/faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>

#include "knowhere/object.h"

namespace faiss {

using ScopedIds = InvertedLists::ScopedIds;
Expand Down Expand Up @@ -1369,7 +1371,7 @@ void InvertedListScanner::scan_codes_and_return(
const uint8_t* codes,
const float* code_norms,
const idx_t* ids,
std::vector<std::pair<float, int64_t>>& out) const {
std::vector<knowhere::DistId>& out) const {
FAISS_THROW_MSG("Not implemented.");
}

Expand Down
4 changes: 3 additions & 1 deletion thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <faiss/invlists/InvertedLists.h>
#include <faiss/utils/Heap.h>

#include "knowhere/object.h"

namespace faiss {

/** Encapsulates a quantizer object for the IndexIVF
Expand Down Expand Up @@ -521,7 +523,7 @@ struct InvertedListScanner {
const uint8_t* codes,
const float* code_norms,
const idx_t* ids,
std::vector<std::pair<float, int64_t>>& out) const;
std::vector<knowhere::DistId>& out) const;

// same as scan_codes, using an iterator
virtual size_t iterate_codes(
Expand Down
9 changes: 5 additions & 4 deletions thirdparty/faiss/faiss/IndexIVFFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <cinttypes>
#include <cstdio>

#include "knowhere/object.h"
#include "knowhere/utils.h"
#include "knowhere/bitsetview_idselector.h"

Expand Down Expand Up @@ -324,7 +325,7 @@ struct IVFFlatScanner : InvertedListScanner {
const uint8_t* codes,
const float* code_norms,
const idx_t* ids,
std::vector<std::pair<float, int64_t>>& out) const override {
std::vector<knowhere::DistId>& out) const override {
const float* list_vecs = (const float*)codes;

// the lambda that filters acceptable elements.
Expand All @@ -335,7 +336,7 @@ struct IVFFlatScanner : InvertedListScanner {
auto apply = [&](const float dis_in, const size_t j) {
const float dis =
(code_norms == nullptr) ? dis_in : (dis_in / code_norms[j]);
out.emplace_back(dis, ids[j]);
out.emplace_back(ids[j], dis);
};
if constexpr (metric == METRIC_INNER_PRODUCT) {
fvec_inner_products_ny_if(
Expand Down Expand Up @@ -455,7 +456,7 @@ struct IVFFlatBitsetViewScanner : InvertedListScanner {
const uint8_t* codes,
const float* code_norms,
const idx_t* ids,
std::vector<std::pair<float, int64_t>>& out) const override {
std::vector<knowhere::DistId>& out) const override {
const float* list_vecs = (const float*)codes;

// the lambda that filters acceptable elements.
Expand All @@ -466,7 +467,7 @@ struct IVFFlatBitsetViewScanner : InvertedListScanner {
auto apply = [&](const float dis_in, const size_t j) {
const float dis =
(code_norms == nullptr) ? dis_in : (dis_in / code_norms[j]);
out.emplace_back(dis, ids[j]);
out.emplace_back(ids[j], dis);
};
if constexpr (metric == METRIC_INNER_PRODUCT) {
fvec_inner_products_ny_if(
Expand Down
4 changes: 3 additions & 1 deletion thirdparty/faiss/faiss/IndexIVFFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include <faiss/IndexIVF.h>

#include "knowhere/object.h"

namespace faiss {
struct IVFFlatIteratorWorkspace {
IVFFlatIteratorWorkspace(
Expand All @@ -27,7 +29,7 @@ struct IVFFlatIteratorWorkspace {
const IVFSearchParameters* search_params = nullptr;
size_t nprobe = 0;
size_t backup_count_threshold = 0; // count * nprobe / nlist
std::vector<std::pair<float, int64_t>> dists; // should be cleared after each use
std::vector<knowhere::DistId> dists; // should be cleared after each use
size_t next_visit_coarse_list_idx = 0;
std::unique_ptr<float[]> coarse_dis = nullptr; // backup coarse centroids distances (heap)
std::unique_ptr<idx_t[]> coarse_idx = nullptr; // backup coarse centroids ids (heap)
Expand Down
4 changes: 2 additions & 2 deletions thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1468,7 +1468,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
workspace->dists.reserve(retset.size());
for (int i = 0; i < retset.size(); i++) {
workspace->dists.emplace_back(retset[i].distance, retset[i].id);
workspace->dists.emplace_back(retset[i].id, retset[i].distance);
}
workspace->initial_search_done = true;
return;
Expand All @@ -1493,7 +1493,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
add_search_candidate, feder_result);
}
if (!has_deletions || !workspace->bitset.test((int64_t)top.id)) {
workspace->dists.emplace_back(top.distance, top.id);
workspace->dists.emplace_back(top.id, top.distance);
return;
}
}
Expand Down
8 changes: 5 additions & 3 deletions thirdparty/hnswlib/hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ AVX512Capable() {
}
#endif

#include <knowhere/bitsetview.h>
#include <knowhere/feder/HNSW.h>
#include <string.h>

#include <fstream>
Expand All @@ -136,6 +134,10 @@ AVX512Capable() {
#include "io/memory_io.h"
#include "neighbor.h"

#include "knowhere/bitsetview.h"
#include "knowhere/feder/HNSW.h"
#include "knowhere/object.h"

namespace hnswlib {
typedef int64_t labeltype;

Expand Down Expand Up @@ -205,7 +207,7 @@ struct IteratorWorkspace {
// iteration request, we cannot use the visited list in the shared visited list pool,
// thus creating a new visited list for every new iteration request.
std::vector<bool> visited;
std::vector<std::pair<float, int64_t>> dists;
std::vector<knowhere::DistId> dists;
const size_t seed_ef;
std::unique_ptr<SearchParam> param;
// though named raw_query_vector, it is normalized for cosine metric. used
Expand Down

0 comments on commit bdf16b5

Please sign in to comment.