From c0f90d36b9a37c1e12a4b712b8effcd6871eea16 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Wed, 17 Jan 2024 15:33:40 +0800 Subject: [PATCH] Add sparse index support to knowhere with linscan and wand algorithm; added related unit test Signed-off-by: Buqian Zheng --- include/knowhere/comp/brute_force.h | 9 + include/knowhere/comp/index_param.h | 6 + include/knowhere/dataset.h | 19 +- include/knowhere/operands.h | 2 + include/knowhere/sparse_utils.h | 261 ++++++++++ src/common/comp/brute_force.cc | 86 ++++ src/common/factory.cc | 2 +- src/index/sparse/sparse_index_node.cc | 217 ++++++++ src/index/sparse/sparse_inverted_index.h | 464 ++++++++++++++++++ .../sparse/sparse_inverted_index_config.h | 47 ++ tests/ut/test_sparse.cc | 260 ++++++++++ tests/ut/utils.h | 45 ++ 12 files changed, 1416 insertions(+), 2 deletions(-) create mode 100644 include/knowhere/sparse_utils.h create mode 100644 src/index/sparse/sparse_index_node.cc create mode 100644 src/index/sparse/sparse_inverted_index.h create mode 100644 src/index/sparse/sparse_inverted_index_config.h create mode 100644 tests/ut/test_sparse.cc diff --git a/include/knowhere/comp/brute_force.h b/include/knowhere/comp/brute_force.h index e66217d6..00548ad9 100644 --- a/include/knowhere/comp/brute_force.h +++ b/include/knowhere/comp/brute_force.h @@ -33,6 +33,15 @@ class BruteForce { static expected RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset); + + // Perform row oriented sparse vector brute force search. + static expected + SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); + + static Status + SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, sparse::label_t* ids, float* dis, + const Json& config, const BitsetView& bitset); }; } // namespace knowhere diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 1b922fba..06e51ac9 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -45,6 +45,8 @@ constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA"; constexpr const char* INDEX_HNSW = "HNSW"; constexpr const char* INDEX_DISKANN = "DISKANN"; +constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX"; +constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND"; } // namespace IndexEnum namespace meta { @@ -123,6 +125,10 @@ constexpr const char* HNSW_M = "M"; constexpr const char* EF = "ef"; constexpr const char* SEED_EF = "seed_ef"; constexpr const char* OVERVIEW_LEVELS = "overview_levels"; + +// Sparse Params +constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build"; +constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search"; } // namespace indexparam using MetricType = std::string; diff --git a/include/knowhere/dataset.h b/include/knowhere/dataset.h index 64324ade..cb377d14 100644 --- a/include/knowhere/dataset.h +++ b/include/knowhere/dataset.h @@ -21,6 +21,7 @@ #include #include "comp/index_param.h" +#include "knowhere/sparse_utils.h" namespace knowhere { @@ -54,7 +55,11 @@ class DataSet : public std::enable_shared_from_this { { auto ptr = std::get_if<3>(&x.second); if (ptr != nullptr) { - delete[](char*)(*ptr); + if (is_sparse) { + delete[](sparse::SparseRow*)(*ptr); + } else { + delete[](char*)(*ptr); + } } } } @@ -78,6 +83,11 @@ class DataSet : public std::enable_shared_from_this { this->data_[meta::IDS] = Var(std::in_place_index<2>, ids); } + /** + * For dense 
float vector, tensor is a rows * dim float array + * For sparse float vector, tensor is pointer to sparse::Sparse* + * and values in each row should be sorted by column id. + */ void SetTensor(const void* tensor) { std::unique_lock lock(mutex_); @@ -202,6 +212,12 @@ class DataSet : public std::enable_shared_from_this { this->is_owner = is_owner; } + void + SetIsSparse(bool is_sparse) { + std::unique_lock lock(mutex_); + this->is_sparse = is_sparse; + } + // deprecated API template void @@ -225,6 +241,7 @@ class DataSet : public std::enable_shared_from_this { mutable std::shared_mutex mutex_; std::map data_; bool is_owner = true; + bool is_sparse = false; }; using DataSetPtr = std::shared_ptr; inline DataSetPtr diff --git a/include/knowhere/operands.h b/include/knowhere/operands.h index eb115fc9..e8ad14ed 100644 --- a/include/knowhere/operands.h +++ b/include/knowhere/operands.h @@ -16,6 +16,8 @@ #define OPERANDS_H #include +#include + namespace { union fp32_bits { uint32_t as_bits; diff --git a/include/knowhere/sparse_utils.h b/include/knowhere/sparse_utils.h new file mode 100644 index 00000000..b0ff96a4 --- /dev/null +++ b/include/knowhere/sparse_utils.h @@ -0,0 +1,261 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// valributed under the License is valributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "knowhere/operands.h" + +namespace knowhere::sparse { + +// integer type in SparseRow +using table_t = uint32_t; +// type used to represent the id of a vector in the index interface. +// this is same as other index types. +using label_t = int64_t; + +template +struct IdVal { + table_t id; + T val; + + IdVal() = default; + IdVal(table_t id, T val) : id(id), val(val) { + } + + inline friend bool + operator<(const IdVal& lhs, const IdVal& rhs) { + return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.id < rhs.id); + } + inline friend bool + operator>(const IdVal& lhs, const IdVal& rhs) { + return !(lhs < rhs); + } + + inline friend bool + operator==(const IdVal& lhs, const IdVal& rhs) { + return lhs.id == rhs.id && lhs.val == rhs.val; + } +}; + +template +class SparseRow { + static_assert(std::is_same_v, "SparseRow supports float only"); + + public: + // construct an SparseRow with memory allocated to hold `count` elements. + SparseRow(size_t count = 0) + : data_(count ? 
new uint8_t[count * element_size()] : nullptr), count_(count), own_data_(true) { + } + + SparseRow(size_t count, uint8_t* data, bool own_data) : data_(data), count_(count), own_data_(own_data) { + } + + // copy constructor and copy assignment operator perform deep copy + SparseRow(const SparseRow& other) : SparseRow(other.count_) { + std::copy(other.data_, other.data_ + count_ * element_size(), data_); + } + + SparseRow(SparseRow&& other) noexcept : SparseRow() { + swap(*this, other); + } + + SparseRow& + operator=(const SparseRow& other) { + if (this != &other) { + SparseRow tmp(other); + swap(*this, tmp); + } + return *this; + } + + SparseRow& + operator=(SparseRow&& other) noexcept { + swap(*this, other); + return *this; + } + + ~SparseRow() { + if (own_data_ && data_ != nullptr) { + delete[] data_; + data_ = nullptr; + } + } + + size_t + size() const { + return count_; + } + + size_t + memory_usage() const { + return count_ * element_size() + sizeof(*this); + } + + void* + data() { + return data_; + } + + const void* + data() const { + return data_; + } + + // dim of a sparse vector is the max index + 1, or 0 for an empty vector. + int64_t + dim() const { + if (count_ == 0) { + return 0; + } + auto* elem = reinterpret_cast(data_) + count_ - 1; + return elem->index + 1; + } + + IdVal + operator[](size_t i) const { + auto* elem = reinterpret_cast(data_) + i; + return {elem->index, elem->value}; + } + + void + set_at(size_t i, table_t index, T value) { + auto* elem = reinterpret_cast(data_) + i; + elem->index = index; + elem->value = value; + } + + float + dot(const SparseRow& other) const { + float product_sum = 0.0f; + size_t i = 0, j = 0; + while (i < count_ && j < other.count_) { + auto* left = reinterpret_cast(data_) + i; + auto* right = reinterpret_cast(other.data_) + j; + + if (left->index < right->index) { + ++i; + } else if (left->index > right->index) { + ++j; + } else { + product_sum += left->value * right->value; + ++i; + ++j; + } + } + return product_sum; + } + + friend void + swap(SparseRow& left, SparseRow& right) { + using std::swap; + swap(left.count_, right.count_); + swap(left.data_, right.data_); + swap(left.own_data_, right.own_data_); + } + + static inline size_t + element_size() { + return sizeof(table_t) + sizeof(T); + } + + private: + // ElementProxy is used to access elements in the data_ array and should + // never be actually constructed. + struct __attribute__((packed)) ElementProxy { + table_t index; + T value; + ElementProxy() = delete; + ElementProxy(const ElementProxy&) = delete; + }; + // data_ must be sorted by column id. use raw pointer for easy mmap and zero + // copy. + uint8_t* data_; + size_t count_; + bool own_data_; +}; + +// When pushing new elements into a MaxMinHeap, only `capacity` elements with the +// largest val are kept. pop()/top() returns the smallest element out of them. 
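// Illustrative example of the semantics: with capacity 2, after push(1, 0.5),
// push(2, 0.9) and push(3, 0.7) the heap keeps only (2, 0.9) and (3, 0.7);
// top() then returns (3, 0.7), and pop() removes it before (2, 0.9).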
+template +class MaxMinHeap { + public: + explicit MaxMinHeap(int capacity) : capacity_(capacity), pool_(capacity) { + } + void + push(table_t id, T val) { + if (size_ < capacity_) { + pool_[size_] = {id, val}; + size_ += 1; + std::push_heap(pool_.begin(), pool_.begin() + size_, std::greater>()); + } else if (val > pool_[0].val) { + sift_down(id, val); + } + } + table_t + pop() { + std::pop_heap(pool_.begin(), pool_.begin() + size_, std::greater>()); + size_ -= 1; + return pool_[size_].id; + } + [[nodiscard]] size_t + size() const { + return size_; + } + [[nodiscard]] bool + empty() const { + return size() == 0; + } + IdVal + top() const { + return pool_[0]; + } + [[nodiscard]] bool + full() const { + return size_ == capacity_; + } + + private: + void + sift_down(table_t id, T val) { + size_t i = 0; + for (; 2 * i + 1 < size_;) { + size_t j = i; + size_t l = 2 * i + 1, r = 2 * i + 2; + if (pool_[l].val < val) { + j = l; + } + if (r < size_ && pool_[r].val < std::min(pool_[l].val, val)) { + j = r; + } + if (i == j) { + break; + } + pool_[i] = pool_[j]; + i = j; + } + pool_[i] = {id, val}; + } + + size_t size_ = 0, capacity_; + std::vector> pool_; +}; // class MaxMinHeap + +} // namespace knowhere::sparse diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index 47bd6adc..6bf2ca9d 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -23,6 +23,7 @@ #include "knowhere/config.h" #include "knowhere/expected.h" #include "knowhere/log.h" +#include "knowhere/sparse_utils.h" #include "knowhere/utils.h" namespace knowhere { @@ -352,6 +353,91 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims); return GenResultDataSet(nq, ids, distances, lims); } + +Status +BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, sparse::label_t* labels, + float* distances, const Json& config, const BitsetView& bitset) { + auto base = static_cast*>(base_dataset->GetTensor()); + auto rows = base_dataset->GetRows(); + + auto xq = static_cast*>(query_dataset->GetTensor()); + auto nq = query_dataset->GetRows(); + + BruteForceConfig cfg; + std::string msg; + auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg); + if (status != Status::success) { + LOG_KNOWHERE_ERROR_ << "Failed to load config, msg is: " << msg; + return status; + } + + std::string metric_str = cfg.metric_type.value(); + auto result = Str2FaissMetricType(metric_str); + if (result.error() != Status::success) { + LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value(); + return result.error(); + } + if (!IsMetricType(metric_str, metric::IP)) { + LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value(); + return Status::invalid_metric_type; + } + + int topk = cfg.k.value(); + std::fill(distances, distances + nq * topk, std::numeric_limits::quiet_NaN()); + std::fill(labels, labels + nq * topk, -1); + + auto pool = ThreadPool::GetGlobalSearchThreadPool(); + std::vector> futs; + futs.reserve(nq); + for (int64_t i = 0; i < nq; ++i) { + futs.emplace_back(pool->push([&, index = i] { + auto cur_labels = labels + topk * index; + auto cur_distances = distances + topk * index; + + const auto& row = xq[index]; + if (row.size() == 0) { + return; + } + sparse::MaxMinHeap heap(topk); + for (int64_t j = 0; j < rows; ++j) { + if (!bitset.empty() && bitset.test(j)) { + continue; + } + float dist = 
row.dot(base[j]); + if (dist > 0) { + heap.push(j, dist); + } + } + int result_size = heap.size(); + for (int j = result_size - 1; j >= 0; --j) { + cur_labels[j] = heap.top().id; + cur_distances[j] = heap.top().val; + heap.pop(); + } + })); + } + WaitAllSuccess(futs); + return Status::success; +} + +expected +BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset) { + auto nq = query_dataset->GetRows(); + BruteForceConfig cfg; + std::string msg; + auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg); + if (status != Status::success) { + return expected::Err(status, msg); + } + int topk = cfg.k.value(); + auto labels = std::make_unique(nq * topk); + auto distances = std::make_unique(nq * topk); + + SearchSparseWithBuf(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset); + return GenResultDataSet(nq, topk, labels.release(), distances.release()); +} + } // namespace knowhere template knowhere::expected knowhere::BruteForce::Search(const knowhere::DataSetPtr base_dataset, diff --git a/src/common/factory.cc b/src/common/factory.cc index 40b0c452..239c0b7c 100644 --- a/src/common/factory.cc +++ b/src/common/factory.cc @@ -20,7 +20,7 @@ IndexFactory::Create(const std::string& name, const int32_t& version, const Obje auto& func_mapping_ = MapInstance(); auto key = GetIndexKey(name); assert(func_mapping_.find(key) != func_mapping_.end()); - LOG_KNOWHERE_INFO_ << "use key" << key << " to create knowhere index " << name << " with version " << version; + LOG_KNOWHERE_INFO_ << "use key " << key << " to create knowhere index " << name << " with version " << version; auto fun_map_v = (FunMapValue>*)(func_mapping_[key].get()); return fun_map_v->fun_value(version, object); } diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc new file mode 100644 index 00000000..b00408b9 --- /dev/null +++ b/src/index/sparse/sparse_index_node.cc @@ -0,0 +1,217 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
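As a point of reference, here is a minimal sketch of driving the brute-force sparse search added above; the helper name and parameter values are illustrative assumptions, and the dataset setup mirrors the unit tests later in this patch:

    #include "knowhere/comp/brute_force.h"
    #include "knowhere/comp/index_param.h"
    #include "knowhere/dataset.h"

    // Illustrative only: `base` and `queries` are row-oriented SparseRow<float>
    // arrays whose elements are sorted by column id.
    knowhere::DataSetPtr
    SparseGroundTruth(knowhere::sparse::SparseRow<float>* base, int64_t nb,
                      knowhere::sparse::SparseRow<float>* queries, int64_t nq, int64_t dim) {
        auto base_ds = knowhere::GenDataSet(nb, dim, base);
        base_ds->SetIsSparse(true);
        base_ds->SetIsOwner(false);  // caller keeps ownership of the rows
        auto query_ds = knowhere::GenDataSet(nq, dim, queries);
        query_ds->SetIsSparse(true);
        query_ds->SetIsOwner(false);
        knowhere::Json conf = {
            {knowhere::meta::METRIC_TYPE, knowhere::metric::IP},  // only IP is supported
            {knowhere::meta::TOPK, 10},
        };
        auto res = knowhere::BruteForce::SearchSparse(base_ds, query_ds, conf, nullptr);
        return res.has_value() ? res.value() : nullptr;
    }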
+ +#include "index/hnsw/hnsw_config.h" +#include "index/sparse/sparse_inverted_index.h" +#include "index/sparse/sparse_inverted_index_config.h" +#include "io/memory_io.h" +#include "knowhere/comp/thread_pool.h" +#include "knowhere/config.h" +#include "knowhere/dataset.h" +#include "knowhere/expected.h" +#include "knowhere/factory.h" +#include "knowhere/index_node.h" +#include "knowhere/log.h" +#include "knowhere/sparse_utils.h" +#include "knowhere/utils.h" + +namespace knowhere { + +template +class SparseInvertedIndexNode : public IndexNode { + static_assert(std::is_same_v, "SparseInvertedIndexNode only support float"); + + public: + explicit SparseInvertedIndexNode(const int32_t& /*version*/, const Object& /*object*/) + : search_pool_(ThreadPool::GetGlobalSearchThreadPool()) { + } + + ~SparseInvertedIndexNode() override { + if (index_ != nullptr) { + delete index_; + index_ = nullptr; + } + } + + Status + Train(const DataSet& dataset, const Config& config) override { + auto cfg = static_cast(config); + if (!IsMetricType(cfg.metric_type.value(), metric::IP)) { + LOG_KNOWHERE_ERROR_ << Type() << " only support metric_type: IP"; + return Status::invalid_metric_type; + } + auto drop_ratio_build = cfg.drop_ratio_build.value_or(0.0f); + auto index = new sparse::InvertedIndex(); + index->SetUseWand(use_wand); + index->Train(static_cast*>(dataset.GetTensor()), dataset.GetRows(), + drop_ratio_build); + if (index_ != nullptr) { + LOG_KNOWHERE_WARNING_ << Type() << " deleting old index during train"; + delete index_; + } + index_ = index; + return Status::success; + } + + Status + Add(const DataSet& dataset, const Config& config) override { + if (!index_) { + LOG_KNOWHERE_ERROR_ << "Could not add data to empty " << Type(); + return Status::empty_index; + } + return index_->Add(static_cast*>(dataset.GetTensor()), dataset.GetRows(), + dataset.GetDim()); + } + + [[nodiscard]] expected + Search(const DataSet& dataset, const Config& config, const BitsetView& bitset) const override { + if (!index_) { + LOG_KNOWHERE_ERROR_ << "Could not search empty " << Type(); + return expected::Err(Status::empty_index, "index not loaded"); + } + auto cfg = static_cast(config); + auto nq = dataset.GetRows(); + auto queries = static_cast*>(dataset.GetTensor()); + auto k = cfg.k.value(); + auto refine_factor = cfg.refine_factor.value_or(10); + auto drop_ratio_search = cfg.drop_ratio_search.value_or(0.0f); + + auto p_id = std::make_unique(nq * k); + auto p_dist = std::make_unique(nq * k); + + std::vector> futs; + futs.reserve(nq); + for (int64_t idx = 0; idx < nq; ++idx) { + futs.emplace_back(search_pool_->push([&, idx = idx, p_id = p_id.get(), p_dist = p_dist.get()]() { + index_->Search(queries[idx], k, drop_ratio_search, p_dist + idx * k, p_id + idx * k, refine_factor, + bitset); + })); + } + WaitAllSuccess(futs); + return GenResultDataSet(nq, k, p_id.release(), p_dist.release()); + } + + [[nodiscard]] expected>> + AnnIterator(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { + throw std::runtime_error("annIterator not supported for current index type"); + } + + [[nodiscard]] expected + RangeSearch(const DataSet& dataset, const Config& config, const BitsetView& bitset) const override { + throw std::runtime_error("RangeSearch not supported for current index type"); + } + + [[nodiscard]] expected + GetVectorByIds(const DataSet& dataset) const override { + if (!index_) { + return expected::Err(Status::empty_index, "index not loaded"); + } + + auto rows = dataset.GetRows(); + auto ids = 
dataset.GetIds(); + + auto data = std::make_unique[]>(rows); + int64_t dim = 0; + try { + for (int64_t i = 0; i < rows; ++i) { + auto& target = data[i]; + index_->GetVectorById(ids[i], target); + dim = std::max(dim, target.dim()); + } + } catch (std::exception& e) { + return expected::Err(Status::invalid_args, "GetVectorByIds failed"); + } + auto res = GenResultDataSet(rows, dim, data.release()); + res->SetIsSparse(true); + return res; + } + + [[nodiscard]] bool + HasRawData(const std::string& metric_type) const override { + return true; + } + + [[nodiscard]] expected + GetIndexMeta(const Config& cfg) const override { + throw std::runtime_error("GetIndexMeta not supported for current index type"); + } + + Status + Serialize(BinarySet& binset) const override { + if (!index_) { + LOG_KNOWHERE_ERROR_ << "Could not serialize empty " << Type(); + return Status::empty_index; + } + MemoryIOWriter writer; + RETURN_IF_ERROR(index_->Save(writer)); + std::shared_ptr data(writer.data()); + binset.Append(Type(), data, writer.tellg()); + return Status::success; + } + + Status + Deserialize(const BinarySet& binset, const Config& config) override { + if (index_) { + LOG_KNOWHERE_WARNING_ << Type() << " has already been created, deleting old"; + delete index_; + index_ = nullptr; + } + auto binary = binset.GetByName(Type()); + if (binary == nullptr) { + LOG_KNOWHERE_ERROR_ << "Invalid BinarySet."; + return Status::invalid_binary_set; + } + MemoryIOReader reader(binary->data.get(), binary->size); + index_ = new sparse::InvertedIndex(); + // no need to set use_wand_ of index_, since it will be set in Load() + return index_->Load(reader); + } + + Status + DeserializeFromFile(const std::string& filename, const Config& config) override { + throw std::runtime_error("DeserializeFromFile not supported for current index type"); + } + + [[nodiscard]] std::unique_ptr + CreateConfig() const override { + return std::make_unique(); + } + + // note that the Dim of a sparse vector index may change as new vectors are added + [[nodiscard]] int64_t + Dim() const override { + return index_ ? index_->n_cols() : 0; + } + + [[nodiscard]] int64_t + Size() const override { + return index_ ? index_->size() : 0; + } + + [[nodiscard]] int64_t + Count() const override { + return index_ ? index_->n_rows() : 0; + } + + [[nodiscard]] std::string + Type() const override { + return use_wand ? knowhere::IndexEnum::INDEX_SPARSE_WAND : knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX; + } + + private: + sparse::InvertedIndex* index_{}; + std::shared_ptr search_pool_; +}; // class SparseInvertedIndexNode + +KNOWHERE_SIMPLE_REGISTER_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, fp32, /*use_wand=*/false); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(SPARSE_WAND, SparseInvertedIndexNode, fp32, /*use_wand=*/true); + +} // namespace knowhere diff --git a/src/index/sparse/sparse_inverted_index.h b/src/index/sparse/sparse_inverted_index.h new file mode 100644 index 00000000..18b940d5 --- /dev/null +++ b/src/index/sparse/sparse_inverted_index.h @@ -0,0 +1,464 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef SPARSE_INVERTED_INDEX_H +#define SPARSE_INVERTED_INDEX_H + +#include +#include +#include +#include +#include + +#include "io/memory_io.h" +#include "knowhere/bitsetview.h" +#include "knowhere/expected.h" +#include "knowhere/log.h" +#include "knowhere/sparse_utils.h" +#include "knowhere/utils.h" + +namespace knowhere::sparse { +template +class InvertedIndex { + public: + explicit InvertedIndex() { + } + + void + SetUseWand(bool use_wand) { + std::unique_lock lock(mu_); + use_wand_ = use_wand; + } + + Status + Save(MemoryIOWriter& writer) { + /** + * zero copy is not yet implemented, now serializing in a zero copy + * compatible way while still copying during deserialization. + * + * Layout: + * + * 1. int32_t rows, sign indicates whether to use wand + * 2. int32_t cols + * 3. for each row: + * 1. int32_t len + * 2. for each non-zero value: + * 1. table_t idx + * 2. T val + * With zero copy deserization, each SparseRow object should + * reference(not owning) the memory address of the first element. + * + * inverted_lut_ and max_in_dim_ not serialized, they will be + * constructed dynamically during deserialization. + * + * Data are densly packed in serialized bytes and no padding is added. + */ + std::shared_lock lock(mu_); + writeBinaryPOD(writer, n_rows_internal() * (use_wand_ ? 1 : -1)); + writeBinaryPOD(writer, n_cols_internal()); + writeBinaryPOD(writer, value_threshold_); + for (size_t i = 0; i < n_rows_internal(); ++i) { + auto& row = raw_data_[i]; + writeBinaryPOD(writer, row.size()); + if (row.size() == 0) { + continue; + } + writer.write(row.data(), row.size() * SparseRow::element_size()); + } + return Status::success; + } + + Status + Load(MemoryIOReader& reader) { + std::unique_lock lock(mu_); + int64_t rows; + readBinaryPOD(reader, rows); + use_wand_ = rows > 0; + rows = std::abs(rows); + size_t dim; + readBinaryPOD(reader, dim); + readBinaryPOD(reader, value_threshold_); + + raw_data_.reserve(rows); + inverted_lut_.resize(dim); + if (use_wand_) { + max_in_dim_.resize(dim); + } + + for (int64_t i = 0; i < rows; ++i) { + size_t count; + readBinaryPOD(reader, count); + raw_data_.emplace_back(count); + if (count == 0) { + continue; + } + reader.read(raw_data_[i].data(), count * SparseRow::element_size()); + add_row_to_index(raw_data_[i], i); + } + + return Status::success; + } + + // Non zero drop ratio is only supported for static index, i.e. data should + // include all rows that'll be added to the index. + Status + Train(const SparseRow* data, size_t rows, float drop_ratio_build) { + if (drop_ratio_build == 0.0f) { + return Status::success; + } + // TODO: maybe i += 10 to down sample to speed up. 
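        // What the block below is meant to do: gather the absolute value of
        // every element across the input rows, then use nth_element to pick a
        // threshold at roughly the drop_ratio_build-th percentile; that value
        // becomes value_threshold_, and add_row_to_index() skips any element
        // whose absolute value falls below it when drop_during_build_ is set.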
+ size_t amount = 0; + for (size_t i = 0; i < rows; ++i) { + amount += data[i].size(); + } + std::vector vals(amount); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < data[i].size(); ++j) { + vals.push_back(fabs(data[i][j].val)); + } + } + auto pos = vals.begin() + static_cast(drop_ratio_build * vals.size()); + std::nth_element(vals.begin(), pos, vals.end()); + + std::unique_lock lock(mu_); + value_threshold_ = *pos; + drop_during_build_ = true; + return Status::success; + } + + Status + Add(const SparseRow* data, size_t rows, int64_t dim) { + std::unique_lock lock(mu_); + auto current_rows = n_rows_internal(); + if (current_rows > 0 && drop_during_build_) { + LOG_KNOWHERE_ERROR_ << "Not allowed to add data to a built index with drop_ratio_build > 0."; + return Status::invalid_args; + } + + if (inverted_lut_.size() < (size_t)dim) { + inverted_lut_.resize(dim); + if (use_wand_) { + max_in_dim_.resize(dim); + } + } + + raw_data_.insert(raw_data_.end(), data, data + rows); + for (size_t i = 0; i < rows; ++i) { + add_row_to_index(data[i], current_rows + i); + } + return Status::success; + } + + void + Search(const SparseRow& query, size_t k, float drop_ratio_search, float* distances, label_t* labels, + size_t refine_factor, const BitsetView& bitset) const { + // initially set result distances to NaN and labels to -1 + std::fill(distances, distances + k, std::numeric_limits::quiet_NaN()); + std::fill(labels, labels + k, -1); + if (query.size() == 0) { + return; + } + + std::vector values(query.size()); + for (size_t i = 0; i < query.size(); ++i) { + values[i] = std::abs(query[i].val); + } + auto pos = values.begin() + static_cast(drop_ratio_search * values.size()); + std::nth_element(values.begin(), pos, values.end()); + auto q_threshold = *pos; + + std::shared_lock lock(mu_); + // if no data was dropped during both build and search, no refinement is + // needed. + if (!drop_during_build_ && drop_ratio_search == 0) { + refine_factor = 1; + } + MaxMinHeap heap(k * refine_factor); + if (!use_wand_) { + search_brute_force(query, q_threshold, heap, bitset); + } else { + search_wand(query, q_threshold, heap, bitset); + } + + if (refine_factor == 1) { + collect_result(heap, distances, labels); + } else { + refine_and_collect(query, heap, k, distances, labels); + } + } + + void + GetVectorById(const label_t id, SparseRow& output) const { + output = raw_data_[id]; + } + + [[nodiscard]] size_t + size() const { + std::shared_lock lock(mu_); + size_t res = sizeof(*this); + res += sizeof(SparseRow) * n_rows_internal(); + for (auto& row : raw_data_) { + res += row.memory_usage(); + } + + res += sizeof(std::vector>) * inverted_lut_.capacity(); + for (auto& lut : inverted_lut_) { + res += sizeof(IdVal) * lut.capacity(); + } + if (use_wand_) { + res += sizeof(T) * max_in_dim_.capacity(); + } + return res; + } + + [[nodiscard]] size_t + n_rows() const { + std::shared_lock lock(mu_); + return n_rows_internal(); + } + + [[nodiscard]] size_t + n_cols() const { + std::shared_lock lock(mu_); + return n_cols_internal(); + } + + private: + size_t + n_rows_internal() const { + return raw_data_.size(); + } + + size_t + n_cols_internal() const { + return inverted_lut_.size(); + } + + // find the top-k candidates using brute force search, k as specified by the capacity of the heap. + // any value in q_vec that is smaller than q_threshold and any value with dimension >= n_cols() will be ignored. 
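    // e.g. (illustrative) with q_vec = {(2, 0.1), (5, 0.9)} and q_threshold = 0.2,
    // only dimension 5 is traversed and contributes to the accumulated scores.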
+ void + search_brute_force(const SparseRow& q_vec, T q_threshold, MaxMinHeap& heap, const BitsetView& bitset) const { + std::vector scores(n_rows_internal(), 0.0f); + for (size_t idx = 0; idx < q_vec.size(); ++idx) { + auto [i, v] = q_vec[idx]; + if (v < q_threshold || i >= n_cols_internal()) { + continue; + } + // TODO: improve with SIMD + for (size_t j = 0; j < inverted_lut_[i].size(); j++) { + auto [idx, val] = inverted_lut_[i][j]; + scores[idx] += v * float(val); + } + } + for (size_t i = 0; i < n_rows_internal(); ++i) { + if ((bitset.empty() || !bitset.test(i)) && scores[i] != 0) { + heap.push(i, scores[i]); + } + } + } + + // LUT supports size() and operator[] which returns an IdVal. + template + class Cursor { + public: + Cursor(const LUT& lut, size_t num_vec, float max_score, float q_value, const BitsetView bitset) + : lut_(lut), num_vec_(num_vec), max_score_(max_score), q_value_(q_value), bitset_(bitset) { + while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) { + loc_++; + } + } + Cursor(const Cursor& rhs) = delete; + + void + next() { + loc_++; + while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) { + loc_++; + } + } + // advance loc until cur_vec_id() >= vec_id + void + seek(table_t vec_id) { + while (loc_ < lut_.size() && cur_vec_id() < vec_id) { + next(); + } + } + [[nodiscard]] table_t + cur_vec_id() const { + if (is_end()) { + return num_vec_; + } + return lut_[loc_].id; + } + T + cur_distance() const { + return lut_[loc_].val; + } + [[nodiscard]] bool + is_end() const { + return loc_ >= size(); + } + [[nodiscard]] float + q_value() const { + return q_value_; + } + [[nodiscard]] size_t + size() const { + return lut_.size(); + } + [[nodiscard]] float + max_score() const { + return max_score_; + } + + private: + const LUT& lut_; + size_t loc_ = 0; + size_t num_vec_ = 0; + float max_score_ = 0.0f; + float q_value_ = 0.0f; + const BitsetView bitset_; + }; // class Cursor + + // any value in q_vec that is smaller than q_threshold will be ignored. 
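    // Sketch of the pruning done below: one cursor per query dimension is kept
    // sorted by its current vector id, and max_score() bounds that dimension's
    // possible contribution. Cursors are accumulated until the partial upper
    // bound could beat the current heap threshold; the vector id reached at
    // that point is the pivot. If the first cursor already sits on the pivot,
    // the pivot's exact score is computed and pushed into the heap; otherwise
    // a lagging cursor seeks forward to the pivot and the process repeats.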
+ void + search_wand(const SparseRow& q_vec, T q_threshold, MaxMinHeap& heap, const BitsetView& bitset) const { + auto q_dim = q_vec.size(); + std::vector>>>> cursors(q_dim); + auto valid_q_dim = 0; + for (size_t i = 0; i < q_dim; ++i) { + auto [idx, val] = q_vec[i]; + if (std::abs(val) < q_threshold || idx >= n_cols_internal()) { + continue; + } + cursors[valid_q_dim++] = std::make_shared>>>( + inverted_lut_[idx], n_rows_internal(), max_in_dim_[idx] * val, val, bitset); + } + if (valid_q_dim == 0) { + return; + } + cursors.resize(valid_q_dim); + auto sort_cursors = [&cursors] { + std::sort(cursors.begin(), cursors.end(), + [](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); }); + }; + sort_cursors(); + auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().val; }; + while (true) { + float upper_bound = 0; + size_t pivot; + bool found_pivot = false; + for (pivot = 0; pivot < cursors.size(); ++pivot) { + if (cursors[pivot]->is_end()) { + break; + } + upper_bound += cursors[pivot]->max_score(); + if (score_above_threshold(upper_bound)) { + found_pivot = true; + break; + } + } + if (!found_pivot) { + break; + } + table_t pivot_id = cursors[pivot]->cur_vec_id(); + if (pivot_id == cursors[0]->cur_vec_id()) { + float score = 0; + for (auto& cursor : cursors) { + if (cursor->cur_vec_id() != pivot_id) { + break; + } + score += cursor->cur_distance() * cursor->q_value(); + cursor->next(); + } + heap.push(pivot_id, score); + sort_cursors(); + } else { + size_t next_list = pivot; + for (; cursors[next_list]->cur_vec_id() == pivot_id; --next_list) { + } + cursors[next_list]->seek(pivot_id); + for (size_t i = next_list + 1; i < cursors.size(); ++i) { + if (cursors[i]->cur_vec_id() >= cursors[i - 1]->cur_vec_id()) { + break; + } + std::swap(cursors[i], cursors[i - 1]); + } + } + } + } + + void + refine_and_collect(const SparseRow& q_vec, MaxMinHeap& inaccurate, size_t k, float* distances, + label_t* labels) const { + std::priority_queue, std::vector>, std::greater>> heap; + + while (!inaccurate.empty()) { + auto [u, d] = inaccurate.top(); + inaccurate.pop(); + + auto dist_acc = q_vec.dot(raw_data_[u]); + if (heap.size() < k) { + heap.emplace(u, dist_acc); + } else if (heap.top().val < dist_acc) { + heap.pop(); + heap.emplace(u, dist_acc); + } + } + collect_result(heap, distances, labels); + } + + template + void + collect_result(HeapType& heap, float* distances, label_t* labels) const { + int cnt = heap.size(); + for (auto i = cnt - 1; i >= 0; --i) { + labels[i] = heap.top().id; + distances[i] = heap.top().val; + heap.pop(); + } + } + + inline void + add_row_to_index(const SparseRow& row, table_t id) { + for (size_t j = 0; j < row.size(); ++j) { + auto [idx, val] = row[j]; + // Skip values close enough to zero(which contributes little to + // the total IP score). + if (drop_during_build_ && fabs(val) < value_threshold_) { + continue; + } + inverted_lut_[idx].emplace_back(id, val); + if (use_wand_) { + max_in_dim_[idx] = std::max(max_in_dim_[idx], val); + } + } + } + + std::vector> raw_data_; + mutable std::shared_mutex mu_; + + std::vector>> inverted_lut_; + bool use_wand_ = false; + // If we want to drop small values during build, we must first train the + // index with all the data to compute value_threshold_. + bool drop_during_build_ = false; + // when drop_during_build_ is true, any value smaller than value_threshold_ + // will not be added to inverted_lut_. 
value_threshold_ is set to the + // drop_ratio_build-th percentile of all absolute values in the index. + T value_threshold_ = 0.0f; + std::vector max_in_dim_; + +}; // class InvertedIndex + +} // namespace knowhere::sparse + +#endif // SPARSE_INVERTED_INDEX_H diff --git a/src/index/sparse/sparse_inverted_index_config.h b/src/index/sparse/sparse_inverted_index_config.h new file mode 100644 index 00000000..d142a949 --- /dev/null +++ b/src/index/sparse/sparse_inverted_index_config.h @@ -0,0 +1,47 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef SPARSE_INVERTED_INDEX_CONFIG_H +#define SPARSE_INVERTED_INDEX_CONFIG_H + +#include "knowhere/comp/index_param.h" +#include "knowhere/config.h" + +namespace knowhere { + +class SparseInvertedIndexConfig : public BaseConfig { + public: + CFG_FLOAT drop_ratio_build; + CFG_FLOAT drop_ratio_search; + CFG_INT refine_factor; + KNOHWERE_DECLARE_CONFIG(SparseInvertedIndexConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(drop_ratio_build) + .description("drop ratio for build") + .set_default(0.0f) + .set_range(0.0f, 1.0f) + .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(drop_ratio_search) + .description("drop ratio for search") + .set_default(0.0f) + .set_range(0.0f, 1.0f) + .for_search() + .for_range_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(refine_factor) + .description("refine factor") + .set_default(10) + .for_search() + .for_range_search(); + } +}; // class SparseInvertedIndexConfig + +} // namespace knowhere + +#endif // SPARSE_INVERTED_INDEX_CONFIG_H diff --git a/tests/ut/test_sparse.cc b/tests/ut/test_sparse.cc new file mode 100644 index 00000000..ee66f676 --- /dev/null +++ b/tests/ut/test_sparse.cc @@ -0,0 +1,260 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
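The tests below exercise these fields; for reference, a typical set of build/search parameters would look like the following (the concrete values are arbitrary):

    // Illustrative params for SPARSE_INVERTED_INDEX or SPARSE_WAND.
    knowhere::Json cfg = {
        {knowhere::meta::METRIC_TYPE, knowhere::metric::IP},  // only IP is supported
        {knowhere::meta::TOPK, 10},
        {knowhere::indexparam::DROP_RATIO_BUILD, 0.1},   // drop the smallest ~10% of values (by magnitude) at build time
        {knowhere::indexparam::DROP_RATIO_SEARCH, 0.2},  // drop the smallest ~20% of query values at search time
    };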
+ +#include "catch2/catch_test_macros.hpp" +#include "catch2/generators/catch_generators.hpp" +#include "knowhere/bitsetview.h" +#include "knowhere/comp/brute_force.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/knowhere_config.h" +#include "knowhere/factory.h" +#include "utils.h" + +TEST_CASE("Test Mem Sparse Index With Float Vector", "[float metrics]") { + auto [nb, dim, doc_sparsity, query_sparsity] = GENERATE(table({ + // 300 dim, avg doc nnz 12, avg query nnz 9 + {2000, 300, 0.95, 0.97}, + // 300 dim, avg doc nnz 9, avg query nnz 3 + {2000, 300, 0.97, 0.99}, + // 3000 dim, avg doc nnz 90, avg query nnz 30 + {20000, 3000, 0.97, 0.99}, + })); + auto topk = 5; + int64_t nq = GENERATE(10, 100); + + auto [drop_ratio_build, drop_ratio_search] = GENERATE(table({ + {0.0, 0.0}, + {0.0, 0.15}, + {0.15, 0.3}, + })); + + auto metric = knowhere::metric::IP; + auto version = GenTestVersionList(); + + auto base_gen = [=]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric; + json[knowhere::meta::TOPK] = topk; + return json; + }; + + auto sparse_inverted_index_gen = [base_gen, drop_ratio_build, drop_ratio_search]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::DROP_RATIO_BUILD] = drop_ratio_build; + json[knowhere::indexparam::DROP_RATIO_SEARCH] = drop_ratio_search; + return json; + }; + + const auto train_ds = GenSparseDataSet(nb, dim, doc_sparsity); + // it is possible the query has more dims than the train dataset. + const auto query_ds = GenSparseDataSet(nq, dim + 20, query_sparsity); + + const knowhere::Json conf = { + {knowhere::meta::METRIC_TYPE, metric}, + {knowhere::meta::TOPK, topk}, + }; + + auto check_distance_decreasing = [](const knowhere::DataSet& ds) { + auto nq = ds.GetRows(); + auto k = ds.GetDim(); + auto* distances = ds.GetDistance(); + auto* ids = ds.GetIds(); + for (auto i = 0; i < nq; ++i) { + for (auto j = 0; j < k - 1; ++j) { + if (ids[i * k + j] == -1 || ids[i * k + j + 1] == -1) { + break; + } + REQUIRE(distances[i * k + j] >= distances[i * k + j + 1]); + } + } + }; + + auto check_result_match_filter = [](const knowhere::DataSet& ds, const knowhere::BitsetView& bitset) { + auto nq = ds.GetRows(); + auto k = ds.GetDim(); + auto* ids = ds.GetIds(); + for (auto i = 0; i < nq; ++i) { + for (auto j = 0; j < k; ++j) { + if (ids[i * k + j] == -1) { + break; + } + REQUIRE(!bitset.test(ids[i * k + j])); + } + } + }; + + auto gt = knowhere::BruteForce::SearchSparse(train_ds, query_ds, conf, nullptr); + check_distance_decreasing(*gt.value()); + + SECTION("Test Search") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX, sparse_inverted_index_gen), + make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND, sparse_inverted_index_gen), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + REQUIRE(idx.Size() > 0); + REQUIRE(idx.Count() == nb); + + knowhere::BinarySet bs; + REQUIRE(idx.Serialize(bs) == knowhere::Status::success); + REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + + auto results = idx.Search(*query_ds, json, nullptr); + REQUIRE(results.has_value()); + float recall = GetKNNRecall(*gt.value(), *results.value()); + 
check_distance_decreasing(*results.value()); + auto drop_ratio_build = json[knowhere::indexparam::DROP_RATIO_BUILD].get(); + auto drop_ratio_search = json[knowhere::indexparam::DROP_RATIO_SEARCH].get(); + if (drop_ratio_build == 0 && drop_ratio_search == 0) { + REQUIRE(recall == 1); + } else { + // most test cases are above 0.95, only a few between 0.9 and 0.95 + REQUIRE(recall >= 0.85); + } + } + + SECTION("Test Search with Bitset") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX, sparse_inverted_index_gen), + make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND, sparse_inverted_index_gen), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + REQUIRE(idx.Size() > 0); + REQUIRE(idx.Count() == nb); + + auto gen_bitset_fn = GENERATE(GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet); + auto bitset_percentages = GENERATE(0.4f, 0.9f); + + auto bitset_data = gen_bitset_fn(nb, bitset_percentages * nb); + knowhere::BitsetView bitset(bitset_data.data(), nb); + auto filter_gt = knowhere::BruteForce::SearchSparse(train_ds, query_ds, conf, bitset); + check_result_match_filter(*filter_gt.value(), bitset); + + auto results = idx.Search(*query_ds, json, bitset); + check_result_match_filter(*results.value(), bitset); + + REQUIRE(results.has_value()); + float recall = GetKNNRecall(*filter_gt.value(), *results.value()); + check_distance_decreasing(*results.value()); + + auto drop_ratio_build = json[knowhere::indexparam::DROP_RATIO_BUILD].get(); + auto drop_ratio_search = json[knowhere::indexparam::DROP_RATIO_SEARCH].get(); + if (drop_ratio_build == 0 && drop_ratio_search == 0) { + REQUIRE(recall == 1); + } else { + REQUIRE(recall >= 0.8); + } + } +} + +TEST_CASE("Test Mem Sparse Index GetVectorByIds", "[float metrics]") { + auto [nb, dim, doc_sparsity, query_sparsity] = GENERATE(table({ + // 300 dim, avg doc nnz 12, avg query nnz 9 + {2000, 300, 0.95, 0.97}, + // 300 dim, avg doc nnz 9, avg query nnz 3 + {2000, 300, 0.97, 0.99}, + // 3000 dim, avg doc nnz 90, avg query nnz 30 + {20000, 3000, 0.97, 0.99}, + })); + int64_t nq = GENERATE(10, 100); + + auto [drop_ratio_build, drop_ratio_search] = GENERATE(table({ + {0.0, 0.0}, + {0.32, 0.0}, + })); + + auto metric = knowhere::metric::IP; + auto version = GenTestVersionList(); + + auto base_gen = [=]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric; + json[knowhere::meta::TOPK] = 1; + return json; + }; + + auto sparse_inverted_index_gen = [base_gen, drop_ratio_build, drop_ratio_search]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::DROP_RATIO_BUILD] = drop_ratio_build; + json[knowhere::indexparam::DROP_RATIO_SEARCH] = drop_ratio_search; + return json; + }; + + const auto train_ds = GenSparseDataSet(nb, dim, doc_sparsity); + + const knowhere::Json conf = { + {knowhere::meta::METRIC_TYPE, metric}, + {knowhere::meta::TOPK, 1}, + }; + + SECTION("Test GetVectorByIds") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX, sparse_inverted_index_gen), + make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND, sparse_inverted_index_gen), + })); + auto idx = 
knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + + auto ids_ds = GenIdsDataSet(nb, nq); + REQUIRE(idx.Type() == name); + auto res = idx.Build(*train_ds, json); + if (!idx.HasRawData(metric)) { + return; + } + REQUIRE(res == knowhere::Status::success); + knowhere::BinarySet bs; + idx.Serialize(bs); + + auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); + idx_new.Deserialize(bs); + + auto retrieve_task = [&]() { + auto results = idx_new.GetVectorByIds(*ids_ds); + REQUIRE(results.has_value()); + auto xb = (knowhere::sparse::SparseRow*)train_ds->GetTensor(); + auto res_data = (knowhere::sparse::SparseRow*)results.value()->GetTensor(); + for (int i = 0; i < nq; ++i) { + const auto id = ids_ds->GetIds()[i]; + const auto& truth_row = xb[id]; + const auto& res_row = res_data[i]; + REQUIRE(truth_row.size() == res_row.size()); + for (size_t j = 0; j < truth_row.size(); ++j) { + REQUIRE(truth_row[j] == res_row[j]); + } + } + }; + + std::vector> retrieve_task_list; + for (int i = 0; i < 20; i++) { + retrieve_task_list.push_back(std::async(std::launch::async, [&] { return retrieve_task(); })); + } + for (auto& task : retrieve_task_list) { + task.wait(); + } + } +} diff --git a/tests/ut/utils.h b/tests/ut/utils.h index 61e6fdaa..82092e06 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -232,3 +232,48 @@ inline auto GenTestVersionList() { return GENERATE(as{}, knowhere::Version::GetCurrentVersion().VersionNumber()); } + +// Generate a sparse dataset with given sparsity. +inline knowhere::DataSetPtr +GenSparseDataSet(int32_t rows, int32_t cols, float sparsity, int seed = 42) { + int32_t num_elements = static_cast(rows * cols * (1.0f - sparsity)); + + std::mt19937 rng(seed); + auto real_distrib = std::uniform_real_distribution(0, 1); + auto row_distrib = std::uniform_int_distribution(0, rows - 1); + auto col_distrib = std::uniform_int_distribution(0, cols - 1); + + std::vector> data(rows); + + for (int32_t i = 0; i < num_elements; ++i) { + auto row = row_distrib(rng); + while (data[row].size() == (size_t)cols) { + row = row_distrib(rng); + } + auto col = col_distrib(rng); + while (data[row].find(col) != data[row].end()) { + col = col_distrib(rng); + } + auto val = real_distrib(rng); + data[row][col] = val; + } + + auto tensor = std::make_unique[]>(rows); + + for (int32_t i = 0; i < rows; ++i) { + if (data[i].size() == 0) { + continue; + } + knowhere::sparse::SparseRow row(data[i].size()); + size_t j = 0; + for (auto& [idx, val] : data[i]) { + row.set_at(j++, idx, val); + } + tensor[i] = std::move(row); + } + + auto ds = knowhere::GenDataSet(rows, cols, tensor.release()); + ds->SetIsOwner(true); + ds->SetIsSparse(true); + return ds; +}
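A short usage sketch of this helper together with the new index types, following the same pattern as the test cases above (all values are arbitrary):

    // Illustrative: ~15 non-zeros per row on average (300 dims, 95% sparsity).
    auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
    auto train_ds = GenSparseDataSet(1000, 300, 0.95f);
    auto query_ds = GenSparseDataSet(10, 300, 0.97f);
    auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_SPARSE_WAND, version);
    knowhere::Json json = {
        {knowhere::meta::DIM, 300},
        {knowhere::meta::METRIC_TYPE, knowhere::metric::IP},
        {knowhere::meta::TOPK, 5},
        {knowhere::indexparam::DROP_RATIO_BUILD, 0.0},
        {knowhere::indexparam::DROP_RATIO_SEARCH, 0.0},
    };
    REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
    auto results = idx.Search(*query_ds, json, nullptr);
    REQUIRE(results.has_value());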