hnsw support fp16/bf16
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
cqy123456 committed Apr 19, 2024
1 parent 818e12f commit 4f4908e
Showing 11 changed files with 318 additions and 183 deletions.
22 changes: 18 additions & 4 deletions include/knowhere/utils.h
@@ -36,17 +36,21 @@ IsFlatIndex(const knowhere::IndexType& index_type) {
return std::find(flat_index_list.begin(), flat_index_list.end(), index_type) != flat_index_list.end();
}

template <typename DataType>
extern float
NormalizeVec(float* x, int32_t d);
NormalizeVec(DataType* x, int32_t d);

template <typename DataType>
extern std::vector<float>
NormalizeVecs(float* x, size_t rows, int32_t dim);
NormalizeVecs(DataType* x, size_t rows, int32_t dim);

template <typename DataType = knowhere::fp32>
extern void
Normalize(const DataSet& dataset);

extern std::unique_ptr<float[]>
CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim);
template <typename DataType>
extern std::unique_ptr<DataType[]>
CopyAndNormalizeVecs(const DataType* x, size_t rows, int32_t dim);

constexpr inline uint64_t seed = 0xc70f6907UL;

@@ -78,6 +82,16 @@ hash_binary_vec(const uint8_t* x, size_t d) {
return h;
}

inline uint64_t
hash_half_precision_float(const void* x, size_t d) {
uint64_t h = seed;
auto u16_x = (const uint16_t*)(x);
for (size_t i = 0; i < d; ++i) {
h = h * 13331 + u16_x[i];
}
return h;
}

template <typename DataType>
inline std::string
GetIndexKey(const std::string& name) {
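The new hash_half_precision_float covers fp16 and bf16 with one function: it never interprets the bits as floating point, it just mixes the raw 16-bit words using the same seed/multiplier scheme as hash_binary_vec. A minimal caller sketch (the bit patterns below are illustrative fp16 encodings, not taken from this patch):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

#include "knowhere/utils.h"

int main() {
    // fp16 bit patterns for 1.0, 2.0, 3.0; bf16 words would be hashed the same way,
    // since the function only sees 16-bit storage, not the float semantics.
    std::vector<uint16_t> vec = {0x3C00, 0x4000, 0x4200};
    uint64_t h = knowhere::hash_half_precision_float(vec.data(), vec.size());
    std::cout << std::hex << h << "\n";
    return 0;
}
```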
59 changes: 54 additions & 5 deletions src/common/utils.cc
@@ -27,6 +27,27 @@ namespace knowhere {
const float FloatAccuracy = 0.00001;

// normalize one vector and return its norm
// todo(cqy123456): template specializations for fp16/bf16;
// float16 should use the smallest positive normal float16 value (6.1 x 10^(-5)) as FloatAccuracy;
// bfloat16 should use the same FloatAccuracy as float32;
template <typename DataType>
float
NormalizeVec(DataType* x, int32_t d) {
float norm_l2_sqr = 0.0;
for (auto i = 0; i < d; i++) {
norm_l2_sqr += (float)x[i] * (float)x[i];
}
if (norm_l2_sqr > 0 && std::abs(1.0f - norm_l2_sqr) > FloatAccuracy) {
float norm_l2 = std::sqrt(norm_l2_sqr);
for (int32_t i = 0; i < d; i++) {
x[i] = (DataType)((float)x[i] / norm_l2);
}
return norm_l2;
}
return 1.0f;
}

template <>
float
NormalizeVec(float* x, int32_t d) {
float norm_l2_sqr = faiss::fvec_norm_L2sqr(x, d);
@@ -41,20 +62,22 @@ NormalizeVec(float* x, int32_t d) {
}

// normalize all vectors and return their norms
template <typename DataType>
std::vector<float>
NormalizeVecs(float* x, size_t rows, int32_t dim) {
NormalizeVecs(DataType* x, size_t rows, int32_t dim) {
std::vector<float> norms(rows);
for (size_t i = 0; i < rows; i++) {
norms[i] = NormalizeVec(x + i * dim, dim);
}
return norms;
}

template <typename DataType>
void
Normalize(const DataSet& dataset) {
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
float* data = (float*)dataset.GetTensor();
auto data = (DataType*)dataset.GetTensor();

LOG_KNOWHERE_DEBUG_ << "vector normalize, rows " << rows << ", dim " << dim;

@@ -64,9 +87,10 @@ Normalize(const DataSet& dataset) {
}

// copy and return normalized vectors
std::unique_ptr<float[]>
CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim) {
auto x_normalized = std::make_unique<float[]>(rows * dim);
template <typename DataType>
std::unique_ptr<DataType[]>
CopyAndNormalizeVecs(const DataType* x, size_t rows, int32_t dim) {
auto x_normalized = std::make_unique<DataType[]>(rows * dim);
std::copy_n(x, rows * dim, x_normalized.get());
NormalizeVecs(x_normalized.get(), rows, dim);
return x_normalized;
@@ -120,4 +144,29 @@ UseDiskLoad(const std::string& index_type, const int32_t& version) {
#endif
}

template float
NormalizeVec<fp16>(fp16* x, int32_t d);
template float
NormalizeVec<bf16>(bf16* x, int32_t d);

template std::vector<float>
NormalizeVecs<fp32>(fp32* x, size_t rows, int32_t dim);
template std::vector<float>
NormalizeVecs<fp16>(fp16* x, size_t rows, int32_t dim);
template std::vector<float>
NormalizeVecs<bf16>(bf16* x, size_t rows, int32_t dim);

template void
Normalize<fp32>(const DataSet& dataset);
template void
Normalize<fp16>(const DataSet& dataset);
template void
Normalize<bf16>(const DataSet& dataset);

template std::unique_ptr<fp32[]>
CopyAndNormalizeVecs(const fp32* x, size_t rows, int32_t dim);
template std::unique_ptr<fp16[]>
CopyAndNormalizeVecs(const fp16* x, size_t rows, int32_t dim);
template std::unique_ptr<bf16[]>
CopyAndNormalizeVecs(const bf16* x, size_t rows, int32_t dim);
} // namespace knowhere
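Regarding the todo in NormalizeVec: a sketch of what the fp16 specialization it describes might look like, assuming 2^-14 ≈ 6.1 x 10^-5 (the smallest positive normal fp16) as the tolerance. This would sit in utils.cc next to the fp32 specialization; it is not part of this commit:

```cpp
// Hypothetical specialization sketched from the todo above; NOT in this commit.
// fp16 carries ~3 decimal digits of precision, so the fp32 threshold (1e-5)
// is finer than what a normalized fp16 vector can express; 2^-14 ≈ 6.1e-5
// (the smallest positive normal fp16) is a more suitable tolerance.
template <>
float
NormalizeVec(fp16* x, int32_t d) {
    constexpr float kFp16Accuracy = 6.1e-5f;
    float norm_l2_sqr = 0.0f;
    for (int32_t i = 0; i < d; i++) {
        norm_l2_sqr += (float)x[i] * (float)x[i];
    }
    if (norm_l2_sqr > 0 && std::abs(1.0f - norm_l2_sqr) > kFp16Accuracy) {
        float norm_l2 = std::sqrt(norm_l2_sqr);
        for (int32_t i = 0; i < d; i++) {
            x[i] = (fp16)((float)x[i] / norm_l2);
        }
        return norm_l2;
    }
    return 1.0f;
}
```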
82 changes: 45 additions & 37 deletions src/index/hnsw/hnsw.cc
@@ -34,9 +34,6 @@ using hnswlib::QuantType;

template <typename DataType, QuantType quant_type = QuantType::None>
class HnswIndexNode : public IndexNode {
static_assert(std::is_same_v<DataType, fp32> || std::is_same_v<DataType, bin1>,
"HnswIndexNode only supports float/binary");

public:
using DistType = float;
HnswIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) {
@@ -49,22 +46,33 @@ class HnswIndexNode : public IndexNode {
auto dim = dataset.GetDim();
auto hnsw_cfg = static_cast<const HnswConfig&>(cfg);
hnswlib::SpaceInterface<DistType>* space = nullptr;
if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) {
space = new (std::nothrow) hnswlib::L2Space(dim);
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) {
space = new (std::nothrow) hnswlib::InnerProductSpace(dim);
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) {
space = new (std::nothrow) hnswlib::CosineSpace(dim);
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING)) {
space = new (std::nothrow) hnswlib::HammingSpace(dim);
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) {
space = new (std::nothrow) hnswlib::JaccardSpace(dim);
if constexpr (KnowhereFloatTypeCheck<DataType>::value) {
if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) {
space = new (std::nothrow) hnswlib::L2Space<DataType, DistType>(dim);
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) {
space = new (std::nothrow) hnswlib::InnerProductSpace<DataType, DistType>(dim);
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) {
space = new (std::nothrow) hnswlib::CosineSpace<DataType, DistType>(dim);
} else {
LOG_KNOWHERE_WARNING_
<< "metric type and data type(float32, float16 and bfloat16) are not match in hnsw: "
<< hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
}
} else {
LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING)) {
space = new (std::nothrow) hnswlib::HammingSpace(dim);
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) {
space = new (std::nothrow) hnswlib::JaccardSpace(dim);
} else {
LOG_KNOWHERE_WARNING_ << "metric type and data type(binary) are not match in hnsw: "
<< hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
}
}
auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space, rows, hnsw_cfg.M.value(),
hnsw_cfg.efConstruction.value());

auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(
space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value());
if (index == nullptr) {
LOG_KNOWHERE_WARNING_ << "memory malloc error.";
return Status::malloc_error;
@@ -75,7 +83,7 @@ class HnswIndexNode : public IndexNode {
}
this->index_ = index;
if constexpr (quant_type != QuantType::None) {
this->index_->trainSQuant((const float*)dataset.GetTensor(), rows);
this->index_->trainSQuant((const DataType*)dataset.GetTensor(), rows);
}
return Status::success;
}
@@ -225,11 +233,11 @@ class HnswIndexNode : public IndexNode {
private:
class iterator : public IndexIterator {
public:
iterator(const hnswlib::HierarchicalNSW<DistType, quant_type>* index, const char* query, const bool transform,
const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf,
const float refine_ratio = 0.5f)
: IndexIterator(transform, (hnswlib::HierarchicalNSW<DistType, quant_type>::sq_enabled &&
hnswlib::HierarchicalNSW<DistType, quant_type>::has_raw_data)
iterator(const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index, const char* query,
const bool transform, const BitsetView& bitset, const bool for_tuning = false,
const size_t seed_ef = kIteratorSeedEf, const float refine_ratio = 0.5f)
: IndexIterator(transform, (hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::sq_enabled &&
hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::has_raw_data)
? refine_ratio
: 0.0f),
index_(index),
@@ -251,15 +259,15 @@ class HnswIndexNode : public IndexNode {
}
float
raw_distance(int64_t id) override {
if constexpr (hnswlib::HierarchicalNSW<DistType, quant_type>::sq_enabled &&
hnswlib::HierarchicalNSW<DistType, quant_type>::has_raw_data) {
if constexpr (hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::sq_enabled &&
hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::has_raw_data) {
return (transform_ ? -1 : 1) * index_->calcRefineDistance(workspace_->raw_query_data.get(), id);
}
throw std::runtime_error("raw_distance not supported: index does not have raw data or sq is not enabled");
}

private:
const hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index_;
const bool transform_;
std::unique_ptr<hnswlib::IteratorWorkspace> workspace_;
};
@@ -466,8 +474,8 @@ class HnswIndexNode : public IndexNode {

MemoryIOReader reader(binary->data.get(), binary->size);

hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
hnswlib::SpaceInterface<DistType>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space);
index_->loadIndex(reader);
LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_
@@ -486,8 +494,8 @@ class HnswIndexNode : public IndexNode {
delete index_;
}
try {
hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
hnswlib::SpaceInterface<DistType>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space);
index_->loadIndex(filename, config);
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
@@ -581,7 +589,7 @@ class HnswIndexNode : public IndexNode {
}

private:
hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index_;
std::shared_ptr<ThreadPool> search_pool_;
};

@@ -592,14 +600,14 @@ KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, fp16);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, bf16);
#else
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16);
#endif

KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp32, QuantType::SQ8);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp32, QuantType::SQ8Refine);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine);
} // namespace knowhere
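With the fp16/bf16 registrations above promoted from MOCK to SIMPLE, half-precision HNSW is a first-class build target. A hedged end-to-end sketch — the header paths, the "M"/"efConstruction" JSON keys, and the fp16 dataset helper are assumptions inferred from the config fields and test helpers visible in this commit, not verified API:

```cpp
#include "knowhere/comp/index_param.h"
#include "knowhere/factory.h"
#include "knowhere/version.h"

// Build and search an fp16 HNSW index. The caller supplies DataSets holding
// knowhere::fp16 vectors (the helper producing them is out of scope here).
knowhere::Status
BuildAndSearchFp16Hnsw(const knowhere::DataSetPtr& train_ds, const knowhere::DataSetPtr& query_ds) {
    auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
    auto idx = knowhere::IndexFactory::Instance()
                   .Create<knowhere::fp16>(knowhere::IndexEnum::INDEX_HNSW, version)
                   .value();
    knowhere::Json json;
    json[knowhere::meta::METRIC_TYPE] = "COSINE";  // served by hnswlib::CosineSpace<fp16, float>
    json[knowhere::meta::TOPK] = 10;
    json["M"] = 16;                // assumed key mirroring hnsw_cfg.M
    json["efConstruction"] = 200;  // assumed key mirroring hnsw_cfg.efConstruction
    auto status = idx.Build(*train_ds, json);
    if (status != knowhere::Status::success) {
        return status;
    }
    auto res = idx.Search(*query_ds, json, nullptr);  // expected<DataSetPtr>
    return res.has_value() ? knowhere::Status::success : res.error();
}
```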
10 changes: 5 additions & 5 deletions tests/ut/test_iterator.cc
@@ -323,7 +323,7 @@ TEST_CASE("Test Iterator IVFFlatCC With Newly Insert Vectors", "[float metrics]
}
}

TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") {
TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[binary metrics]") {
using Catch::Approx;

const int64_t nb = 1000, nq = 10;
@@ -348,21 +348,21 @@ TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") {
json[knowhere::indexparam::SEED_EF] = 64;
return json;
};
const auto train_ds = GenDataSet(nb, dim);
const auto query_ds = GenDataSet(nq, dim);
const auto train_ds = GenBinDataSet(nb, dim);
const auto query_ds = GenBinDataSet(nq, dim);

const knowhere::Json conf = {
{knowhere::meta::METRIC_TYPE, metric},
{knowhere::meta::TOPK, topk},
};

auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, conf, nullptr);
auto gt = knowhere::BruteForce::Search<knowhere::bin1>(train_ds, query_ds, conf, nullptr);
SECTION("Test Search using iterator") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::bin1>(name, version).value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
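The test fix above pairs binary data with the bin1 type end to end. Its mirror image follows from the new type/metric check in hnsw.cc: a float-typed index with a binary metric is now rejected at build time. A hypothetical negative check in the same Catch2 style (train_ds assumed to hold fp32 vectors; JSON keys assumed as in the sketch after hnsw.cc):

```cpp
auto idx = knowhere::IndexFactory::Instance()
               .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_HNSW, version)
               .value();
knowhere::Json json;
json[knowhere::meta::METRIC_TYPE] = "HAMMING";
json["M"] = 16;
json["efConstruction"] = 200;
// KnowhereFloatTypeCheck<fp32> holds, and HAMMING is not in the float branch,
// so Build fails with invalid_metric_type after this commit.
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::invalid_metric_type);
```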