Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add sparse index support to knowhere #199

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/knowhere/comp/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ class BruteForce {
static expected<DataSetPtr>
RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset);

// Perform row oriented sparse vector brute force search.
// For unit test purpose only, assumes that the tensor is csr matrix with types:
// shape: int64
// indptr: int64
// indices: int32
// data: float
static expected<DataSetPtr>
SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset);
};

} // namespace knowhere
Expand Down
6 changes: 6 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA";
constexpr const char* INDEX_HNSW = "HNSW";
constexpr const char* INDEX_DISKANN = "DISKANN";

constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND";
} // namespace IndexEnum

namespace meta {
Expand Down Expand Up @@ -88,6 +90,10 @@ constexpr const char* HNSW_M = "M";
constexpr const char* EF = "ef";
constexpr const char* SEED_EF = "seed_ef";
constexpr const char* OVERVIEW_LEVELS = "overview_levels";

// Sparse Params
constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build";
constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search";
} // namespace indexparam

using MetricType = std::string;
Expand Down
5 changes: 5 additions & 0 deletions include/knowhere/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class DataSet {
this->data_[meta::IDS] = Var(std::in_place_index<2>, ids);
}

/**
* For dense float vector, tensor is a rows * dim float array
* For sparse float vector, tensor is a CSR matrix. See namespace sparse in utils.h for details.
* rows and dim should be set for both dense/sparse dataset.
*/
void
SetTensor(const void* tensor) {
std::unique_lock lock(mutex_);
Expand Down
153 changes: 153 additions & 0 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <strings.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

#include "knowhere/binaryset.h"
Expand Down Expand Up @@ -92,4 +94,155 @@ readBinaryPOD(R& in, T& podRef) {
in.read((char*)&podRef, sizeof(T));
}

// utilities for sparse index
namespace sparse {

// type used to represent the id of a vector in the index interface.
// this is same as other index types.
using label_t = int64_t;
// type used to represent the id of a vector inside the index.
// NOTE(review): 32-bit, so a single index presumably holds at most ~4 billion
// vectors — confirm against the index implementations.
using table_t = uint32_t;

/**
CSR(Compressed Sparse Row) Matrix Format:

+------------+-----------+------------+------------+--------------+-------------+-------------+
| | rows | cols | nnz | indptr | indices | data |
+------------+-----------+------------+------------+--------------+-------------+-------------+
| Type | ShapeT | ShapeT | ShapeT | IntPtrT | IndicesT | ValueT |
+------------+-----------+------------+------------+--------------+-------------+-------------+
| elem count | 1 | 1 | 1 | rows + 1 | nnz | nnz |
+------------+-----------+------------+------------+--------------+-------------+-------------+

*/

// Decode a packed CSR matrix laid out as [rows, cols, nnz | indptr | indices | data]
// (see the layout table above). The returned indptr/indices/data pointers alias
// the csr_matrix buffer, so they must not be freed by the caller and csr_matrix
// must outlive them.
template <typename ValueT, typename IndPtrT = int64_t, typename IndicesT = int32_t, typename ShapeT = int64_t>
void
parse_csr_matrix(const void* csr_matrix, size_t& rows, size_t& cols, size_t& nnz, const IndPtrT*& indptr,
                 const IndicesT*& indices, const ValueT*& data) {
    const auto* shape = static_cast<const ShapeT*>(csr_matrix);
    rows = shape[0];
    cols = shape[1];
    nnz = shape[2];

    // The three data sections follow the 3-element shape header back to back.
    const char* cursor = reinterpret_cast<const char*>(shape + 3);
    indptr = reinterpret_cast<const IndPtrT*>(cursor);
    cursor += (rows + 1) * sizeof(IndPtrT);
    indices = reinterpret_cast<const IndicesT*>(cursor);
    cursor += nnz * sizeof(IndicesT);
    data = reinterpret_cast<const ValueT*>(cursor);
}

// Extract row `idx` of a packed CSR matrix (layout documented above parse_csr_matrix).
// indices and data alias the original buffer, so they must not be freed by the
// caller and csr_matrix must outlive them.
// For an out-of-range idx the outputs are len = 0 and null pointers.
template <typename ValueT, typename IndPtrT = int64_t, typename IndicesT = int32_t, typename ShapeT = int64_t>
void
get_row(const void* csr_matrix, size_t idx, size_t& len, const IndicesT*& indices, const ValueT*& data) {
    const ShapeT* header = reinterpret_cast<const ShapeT*>(csr_matrix);
    size_t n_rows = header[0];
    if (idx >= n_rows) {
        len = 0;
        indices = nullptr;
        data = nullptr;
        return;
    }
    const IndPtrT* indptr = reinterpret_cast<const IndPtrT*>(header + 3);
    const IndicesT* csr_indices = reinterpret_cast<const IndicesT*>(indptr + n_rows + 1);
    // header[2] is nnz: the data section starts right after nnz indices.
    const ValueT* csr_data = reinterpret_cast<const ValueT*>(csr_indices + header[2]);

    len = static_cast<size_t>(indptr[idx + 1] - indptr[idx]);
    // The output parameters are pointer-to-const, so plain pointer arithmetic
    // suffices (the original const_casts were no-ops that only obscured intent).
    indices = csr_indices + indptr[idx];
    data = csr_data + indptr[idx];
}

// A (vector id, distance) result pair. Ordering is by distance with id as a
// tie-breaker, making operator< a total (and therefore strict-weak) order.
template <typename dist_t = float>
struct Neighbor {
    table_t id;
    dist_t distance;

    Neighbor() = default;
    Neighbor(table_t id, dist_t distance) : id(id), distance(distance) {
    }

    inline friend bool
    operator<(const Neighbor& lhs, const Neighbor& rhs) {
        return lhs.distance < rhs.distance || (lhs.distance == rhs.distance && lhs.id < rhs.id);
    }
    inline friend bool
    operator>(const Neighbor& lhs, const Neighbor& rhs) {
        // Fix: this was `!(lhs < rhs)`, which is ">=": it returned true for
        // equal elements and is not a strict ordering, so using it as a
        // comparator for std heaps/sorts would be undefined behavior.
        return rhs < lhs;
    }
};

// Bounded "keep the k smallest" max-heap: when pushing, only the `capacity`
// smallest elements seen so far are retained. pop()/top() return the LARGEST
// of those retained elements, so draining the heap fills a result array from
// back to front in ascending order.
template <typename T = float>
class MinMaxHeap {
 public:
    explicit MinMaxHeap(int capacity) : capacity_(capacity), pool_(capacity) {
    }
    // Insert (id, dist). When full, the element is kept only if it is smaller
    // than the current maximum, which it then replaces in place.
    void
    push(table_t id, T dist) {
        if (size_ < capacity_) {
            pool_[size_] = {id, dist};
            std::push_heap(pool_.begin(), pool_.begin() + ++size_);
        } else if (dist < pool_[0].distance) {
            // Replace the root and restore the heap property directly,
            // avoiding a pop_heap/push_heap pair.
            sift_down(id, dist);
        }
    }
    // Remove and return the id of the largest retained element.
    // Precondition: !empty() — popping an empty heap is undefined behavior.
    table_t
    pop() {
        // Keep the pop_heap and the size decrement as separate statements
        // (requested in review) so the sequencing is obvious.
        std::pop_heap(pool_.begin(), pool_.begin() + size_);
        size_ -= 1;
        return pool_[size_].id;
    }
    [[nodiscard]] size_t
    size() const {
        return size_;
    }
    [[nodiscard]] bool
    empty() const {
        return size() == 0;
    }
    // Largest retained element. Precondition: !empty().
    Neighbor<T>
    top() const {
        return pool_[0];
    }
    [[nodiscard]] bool
    full() const {
        return size_ == capacity_;
    }

 private:
    // Replace the root with (id, dist) and sift it down to its heap position.
    void
    sift_down(table_t id, T dist) {
        size_t i = 0;
        for (; 2 * i + 1 < size_;) {
            size_t j = i;
            size_t l = 2 * i + 1, r = 2 * i + 2;
            if (pool_[l].distance > dist) {
                j = l;
            }
            if (r < size_ && pool_[r].distance > std::max(pool_[l].distance, dist)) {
                j = r;
            }
            if (i == j) {
                break;
            }
            pool_[i] = pool_[j];
            i = j;
        }
        pool_[i] = {id, dist};
    }

    size_t size_ = 0, capacity_;
    std::vector<Neighbor<T>> pool_;
};  // class MinMaxHeap

} // namespace sparse

} // namespace knowhere
92 changes: 92 additions & 0 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,4 +343,96 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
return GenResultDataSet(nq, ids, distances, lims);
}

// Brute force search over row-oriented sparse (CSR) float vectors.
// Only the IP metric is supported. For each query row, every base row passing
// the bitset filter is scored by sparse dot product and the topk largest inner
// products are returned. Unfilled result slots hold id -1 / distance NaN.
expected<DataSetPtr>
BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
                         const BitsetView& bitset) {
    auto base_csr = base_dataset->GetTensor();
    size_t rows, cols, nnz;
    const int64_t* indptr;
    const int32_t* indices;
    const float* data;
    sparse::parse_csr_matrix(base_csr, rows, cols, nnz, indptr, indices, data);

    auto xq = query_dataset->GetTensor();
    auto nq = query_dataset->GetRows();

    BruteForceConfig cfg;
    std::string msg;
    auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg);
    if (status != Status::success) {
        return expected<DataSetPtr>::Err(status, std::move(msg));
    }

    std::string metric_str = cfg.metric_type.value();
    auto result = Str2FaissMetricType(metric_str);
    if (result.error() != Status::success) {
        return expected<DataSetPtr>::Err(result.error(), result.what());
    }
    if (!IsMetricType(metric_str, metric::IP)) {
        return expected<DataSetPtr>::Err(Status::invalid_metric_type,
                                         "Only IP metric type is supported for sparse vector");
    }

    int topk = cfg.k.value();
    auto labels = new sparse::label_t[nq * topk];
    auto distances = new float[nq * topk];

    auto pool = ThreadPool::GetGlobalSearchThreadPool();
    std::vector<folly::Future<Status>> futs;
    futs.reserve(nq);
    for (int i = 0; i < nq; ++i) {
        futs.emplace_back(pool->push([&, index = i] {
            ThreadPool::ScopedOmpSetter setter(1);
            auto cur_labels = labels + topk * index;
            auto cur_distances = distances + topk * index;
            // Pre-fill sentinels so queries with fewer than topk hits still
            // produce a well-defined result block.
            std::fill(cur_labels, cur_labels + topk, -1);
            std::fill(cur_distances, cur_distances + topk, std::numeric_limits<float>::quiet_NaN());

            size_t len;
            const int32_t* cur_indices;
            const float* cur_data;
            sparse::get_row(xq, index, len, cur_indices, cur_data);
            if (len == 0) {
                return Status::success;
            }
            // Hash the query row once so each base row is scored in O(row_nnz).
            std::unordered_map<int64_t, float> query;
            for (size_t j = 0; j < len; ++j) {
                query[cur_indices[j]] = cur_data[j];
            }
            sparse::MinMaxHeap<float> heap(topk);
            for (size_t j = 0; j < rows; ++j) {
                if (!bitset.empty() && bitset.test(j)) {
                    continue;
                }
                float dist = 0.0f;
                for (int64_t k = indptr[j]; k < indptr[j + 1]; ++k) {
                    auto it = query.find(indices[k]);
                    if (it != query.end()) {
                        dist += it->second * data[k];
                    }
                }
                // Rows with zero overlap score 0 and can never be IP results.
                if (dist > 0) {
                    // The heap keeps the smallest values; push negated
                    // similarities so it retains the topk LARGEST IPs.
                    heap.push(j, -dist);
                }
            }
            // Drain back-to-front so results come out in descending IP order.
            int result_size = heap.size();
            for (int64_t j = result_size - 1; j >= 0; --j) {
                cur_labels[j] = heap.top().id;
                cur_distances[j] = -heap.top().distance;
                heap.pop();
            }
            return Status::success;
        }));
    }
    // Wait for ALL tasks before inspecting any result: returning early while
    // tasks are still running would leave them referencing dead stack locals
    // (rows, indptr, labels, ... are captured by reference).
    for (auto& fut : futs) {
        fut.wait();
    }
    for (auto& fut : futs) {
        auto ret = fut.result().value();
        if (ret != Status::success) {
            // Fix: the original leaked labels/distances on this error path.
            delete[] labels;
            delete[] distances;
            return expected<DataSetPtr>::Err(ret, "failed to brute force search");
        }
    }
    return GenResultDataSet(nq, cfg.k.value(), labels, distances);
}

} // namespace knowhere
Loading
Loading