
hnsw support fp16/bf16
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
cqy123456 committed Apr 19, 2024
1 parent ad6f71f commit 3ef5738
Showing 9 changed files with 318 additions and 175 deletions.
22 changes: 18 additions & 4 deletions include/knowhere/utils.h
@@ -36,17 +36,21 @@ IsFlatIndex(const knowhere::IndexType& index_type) {
     return std::find(flat_index_list.begin(), flat_index_list.end(), index_type) != flat_index_list.end();
 }
 
+template <typename DataType>
 extern float
-NormalizeVec(float* x, int32_t d);
+NormalizeVec(DataType* x, int32_t d);
 
+template <typename DataType>
 extern std::vector<float>
-NormalizeVecs(float* x, size_t rows, int32_t dim);
+NormalizeVecs(DataType* x, size_t rows, int32_t dim);
 
+template <typename DataType = knowhere::fp32>
 extern void
 Normalize(const DataSet& dataset);
 
-extern std::unique_ptr<float[]>
-CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim);
+template <typename DataType>
+extern std::unique_ptr<DataType[]>
+CopyAndNormalizeVecs(const DataType* x, size_t rows, int32_t dim);
 
 constexpr inline uint64_t seed = 0xc70f6907UL;
 
@@ -78,6 +82,16 @@ hash_binary_vec(const uint8_t* x, size_t d) {
     return h;
 }
 
+inline uint64_t
+hash_half_precision_float(const void* x, size_t d) {
+    uint64_t h = seed;
+    auto u16_x = (uint16_t*)(x);
+    for (size_t i = 0; i < d; ++i) {
+        h = h * 13331 + u16_x[i];
+    }
+    return h;
+}
+
 template <typename DataType>
 inline std::string
 GetIndexKey(const std::string& name) {
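
A note on hash_half_precision_float above: fp16 and bf16 are both 2-byte types, so the helper hashes their raw bit patterns as uint16_t with the same polynomial rolling hash (multiplier 13331, seed 0xc70f6907) already used for binary vectors, and no conversion to float is needed. A minimal self-contained sketch of the same idea (the seed constant is copied from utils.h; everything else is standard C++):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hash a half-precision (fp16/bf16) vector via its raw 16-bit patterns.
// Same scheme as knowhere's hash_half_precision_float: h = h * 13331 + bits.
inline uint64_t
hash_half_precision_float(const void* x, size_t d) {
    constexpr uint64_t seed = 0xc70f6907UL;  // value copied from knowhere/utils.h
    uint64_t h = seed;
    auto u16_x = static_cast<const uint16_t*>(x);
    for (size_t i = 0; i < d; ++i) {
        h = h * 13331 + u16_x[i];
    }
    return h;
}

int main() {
    // Two raw fp16 bit patterns (1.0 and 2.0 in IEEE half precision).
    std::vector<uint16_t> vec = {0x3C00, 0x4000};
    std::printf("%llu\n", (unsigned long long)hash_half_precision_float(vec.data(), vec.size()));
    return 0;
}
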
59 changes: 54 additions & 5 deletions src/common/utils.cc
@@ -27,6 +27,27 @@ namespace knowhere {
 const float FloatAccuracy = 0.00001;
 
 // normalize one vector and return its norm
+// todo(cqy123456): Template specialization for fp16/bf16;
+// float16 uses the smallest representable positive float16 value(6.1 x 10^(-5)) as FloatAccuracy;
+// bfloat16 uses the same FloatAccuracy as float32;
+template <typename DataType>
+float
+NormalizeVec(DataType* x, int32_t d) {
+    float norm_l2_sqr = 0.0;
+    for (auto i = 0; i < d; i++) {
+        norm_l2_sqr += (float)x[i] * (float)x[i];
+    }
+    if (norm_l2_sqr > 0 && std::abs(1.0f - norm_l2_sqr) > FloatAccuracy) {
+        float norm_l2 = std::sqrt(norm_l2_sqr);
+        for (int32_t i = 0; i < d; i++) {
+            x[i] = (DataType)((float)x[i] / norm_l2);
+        }
+        return norm_l2;
+    }
+    return 1.0f;
+}
+
+template <>
 float
 NormalizeVec(float* x, int32_t d) {
     float norm_l2_sqr = faiss::fvec_norm_L2sqr(x, d);
@@ -41,20 +62,22 @@ NormalizeVec(float* x, int32_t d) {
 }
 
 // normalize all vectors and return their norms
+template <typename DataType>
 std::vector<float>
-NormalizeVecs(float* x, size_t rows, int32_t dim) {
+NormalizeVecs(DataType* x, size_t rows, int32_t dim) {
     std::vector<float> norms(rows);
     for (size_t i = 0; i < rows; i++) {
         norms[i] = NormalizeVec(x + i * dim, dim);
     }
     return norms;
 }
 
+template <typename DataType>
 void
 Normalize(const DataSet& dataset) {
     auto rows = dataset.GetRows();
     auto dim = dataset.GetDim();
-    float* data = (float*)dataset.GetTensor();
+    auto data = (DataType*)dataset.GetTensor();
 
     LOG_KNOWHERE_DEBUG_ << "vector normalize, rows " << rows << ", dim " << dim;
 
@@ -64,9 +87,10 @@ Normalize(const DataSet& dataset) {
 }
 
 // copy and return normalized vectors
-std::unique_ptr<float[]>
-CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim) {
-    auto x_normalized = std::make_unique<float[]>(rows * dim);
+template <typename DataType>
+std::unique_ptr<DataType[]>
+CopyAndNormalizeVecs(const DataType* x, size_t rows, int32_t dim) {
+    auto x_normalized = std::make_unique<DataType[]>(rows * dim);
     std::copy_n(x, rows * dim, x_normalized.get());
     NormalizeVecs(x_normalized.get(), rows, dim);
     return x_normalized;
@@ -120,4 +144,29 @@ UseDiskLoad(const std::string& index_type, const int32_t& version) {
 #endif
 }
 
+template float
+NormalizeVec<fp16>(fp16* x, int32_t d);
+template float
+NormalizeVec<bf16>(bf16* x, int32_t d);
+
+template std::vector<float>
+NormalizeVecs<fp32>(fp32* x, size_t rows, int32_t dim);
+template std::vector<float>
+NormalizeVecs<fp16>(fp16* x, size_t rows, int32_t dim);
+template std::vector<float>
+NormalizeVecs<bf16>(bf16* x, size_t rows, int32_t dim);
+
+template void
+Normalize<fp32>(const DataSet& dataset);
+template void
+Normalize<fp16>(const DataSet& dataset);
+template void
+Normalize<bf16>(const DataSet& dataset);
+
+template std::unique_ptr<fp32[]>
+CopyAndNormalizeVecs(const fp32* x, size_t rows, int32_t dim);
+template std::unique_ptr<fp16[]>
+CopyAndNormalizeVecs(const fp16* x, size_t rows, int32_t dim);
+template std::unique_ptr<bf16[]>
+CopyAndNormalizeVecs(const bf16* x, size_t rows, int32_t dim);
 } // namespace knowhere
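
The generic NormalizeVec above accumulates the squared norm in float even when the elements are fp16/bf16, and skips rescaling when the squared norm is already within FloatAccuracy of 1; the fp32 specialization keeps the faiss::fvec_norm_L2sqr fast path. A self-contained sketch of the same contract on plain float data (names are local to this example, not knowhere API):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Scale each row to unit L2 norm and return the original norms.
// Mirrors the NormalizeVecs contract above: accumulate in float,
// and skip rows whose norm is already 1 within a small tolerance.
std::vector<float>
normalize_rows(float* x, size_t rows, int dim) {
    const float kFloatAccuracy = 0.00001f;  // same tolerance as utils.cc
    std::vector<float> norms(rows, 1.0f);
    for (size_t i = 0; i < rows; i++) {
        float* row = x + i * dim;
        float sqr = 0.0f;
        for (int j = 0; j < dim; j++) {
            sqr += row[j] * row[j];
        }
        if (sqr > 0 && std::fabs(1.0f - sqr) > kFloatAccuracy) {
            float norm = std::sqrt(sqr);
            for (int j = 0; j < dim; j++) {
                row[j] /= norm;
            }
            norms[i] = norm;
        }
    }
    return norms;
}

int main() {
    std::vector<float> data = {3.0f, 4.0f};  // one 2-d vector with norm 5
    auto norms = normalize_rows(data.data(), 1, 2);
    std::printf("norm=%.1f -> (%.1f, %.1f)\n", norms[0], data[0], data[1]);  // norm=5.0 -> (0.6, 0.8)
    return 0;
}
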
73 changes: 41 additions & 32 deletions src/index/hnsw/hnsw.cc
@@ -34,9 +34,6 @@ using hnswlib::QuantType;
 
 template <typename DataType, QuantType quant_type = QuantType::None>
 class HnswIndexNode : public IndexNode {
-    static_assert(std::is_same_v<DataType, fp32> || std::is_same_v<DataType, bin1>,
-                  "HnswIndexNode only support float/bianry");
-
  public:
     using DistType = float;
     HnswIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) {
@@ -49,22 +46,33 @@ class HnswIndexNode : public IndexNode {
         auto dim = dataset.GetDim();
         auto hnsw_cfg = static_cast<const HnswConfig&>(cfg);
         hnswlib::SpaceInterface<DistType>* space = nullptr;
-        if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) {
-            space = new (std::nothrow) hnswlib::L2Space(dim);
-        } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) {
-            space = new (std::nothrow) hnswlib::InnerProductSpace(dim);
-        } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) {
-            space = new (std::nothrow) hnswlib::CosineSpace(dim);
-        } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING)) {
-            space = new (std::nothrow) hnswlib::HammingSpace(dim);
-        } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) {
-            space = new (std::nothrow) hnswlib::JaccardSpace(dim);
-        } else {
-            LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value();
-            return Status::invalid_metric_type;
+        if constexpr (KnowhereFloatTypeCheck<DataType>::value) {
+            if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) {
+                space = new (std::nothrow) hnswlib::L2Space<DataType, DistType>(dim);
+            } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) {
+                space = new (std::nothrow) hnswlib::InnerProductSpace<DataType, DistType>(dim);
+            } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) {
+                space = new (std::nothrow) hnswlib::CosineSpace<DataType, DistType>(dim);
+            } else {
+                LOG_KNOWHERE_WARNING_
+                    << "metric type and data type(float32, float16 and bfloat16) are not match in hnsw: "
+                    << hnsw_cfg.metric_type.value();
+                return Status::invalid_metric_type;
+            }
+        } else {
+            if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING)) {
+                space = new (std::nothrow) hnswlib::HammingSpace(dim);
+            } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) {
+                space = new (std::nothrow) hnswlib::JaccardSpace(dim);
+            } else {
+                LOG_KNOWHERE_WARNING_ << "metric type and data type(binary) are not match in hnsw: "
+                                      << hnsw_cfg.metric_type.value();
+                return Status::invalid_metric_type;
+            }
         }
-        auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space, rows, hnsw_cfg.M.value(),
-                                                                                       hnsw_cfg.efConstruction.value());
+
+        auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(
+            space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value());
         if (index == nullptr) {
             LOG_KNOWHERE_WARNING_ << "memory malloc error.";
             return Status::malloc_error;
@@ -75,7 +83,7 @@ class HnswIndexNode : public IndexNode {
         }
         this->index_ = index;
         if constexpr (quant_type != QuantType::None) {
-            this->index_->trainSQuant((const float*)dataset.GetTensor(), rows);
+            this->index_->trainSQuant((const DataType*)dataset.GetTensor(), rows);
         }
         return Status::success;
     }
@@ -225,8 +233,9 @@ class HnswIndexNode : public IndexNode {
  private:
     class iterator : public IndexNode::iterator {
      public:
-        iterator(const hnswlib::HierarchicalNSW<DistType, quant_type>* index, const char* query, const bool transform,
-                 const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf)
+        iterator(const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index, const char* query,
+                 const bool transform, const BitsetView& bitset, const bool for_tuning = false,
+                 const size_t seed_ef = kIteratorSeedEf)
            : index_(index),
              transform_(transform),
              workspace_(index_->getIteratorWorkspace(query, seed_ef, for_tuning, bitset)) {
@@ -258,7 +267,7 @@ class HnswIndexNode : public IndexNode {
                has_next_ = false;
            }
        }
-        const hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
+        const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index_;
        const bool transform_;
        std::unique_ptr<hnswlib::IteratorWorkspace> workspace_;
        bool has_next_;
@@ -466,8 +475,8 @@ class HnswIndexNode : public IndexNode {
 
        MemoryIOReader reader(binary->data.get(), binary->size);
 
-        hnswlib::SpaceInterface<float>* space = nullptr;
-        index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
+        hnswlib::SpaceInterface<DistType>* space = nullptr;
+        index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space);
        index_->loadIndex(reader);
        LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_
                           << " #max level:" << index_->maxlevel_
@@ -486,8 +495,8 @@ class HnswIndexNode : public IndexNode {
            delete index_;
        }
        try {
-            hnswlib::SpaceInterface<float>* space = nullptr;
-            index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
+            hnswlib::SpaceInterface<DistType>* space = nullptr;
+            index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space);
            index_->loadIndex(filename, config);
        } catch (std::exception& e) {
            LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
@@ -581,7 +590,7 @@ class HnswIndexNode : public IndexNode {
    }
 
  private:
-    hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
+    hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index_;
    std::shared_ptr<ThreadPool> search_pool_;
 };
 
@@ -592,14 +601,14 @@ KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, fp16);
 KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, bf16);
 #else
 KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16);
 #endif
 
 KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp32, QuantType::SQ8);
 KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp32, QuantType::SQ8Refine);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine);
 } // namespace knowhere
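
With the registrations switched from KNOWHERE_MOCK_REGISTER_GLOBAL to KNOWHERE_SIMPLE_REGISTER_GLOBAL, fp16 and bf16 become native HNSW data types instead of going through an fp32-converting mock wrapper. A hedged sketch of how a caller might build an fp16 HNSW index through the factory, modeled on the test code below; the include paths and the DIM/HNSW_M/EFCONSTRUCTION key names are assumptions from knowhere conventions, not verified against this exact revision:

#include "knowhere/comp/index_param.h"     // assumed location of meta/indexparam keys
#include "knowhere/index/index_factory.h"  // assumed location of IndexFactory

// Build an fp16 HNSW index via the factory, as the tests do for fp32/bin1.
knowhere::Status
BuildFp16Hnsw(const knowhere::DataSetPtr& train_ds, int32_t version) {
    auto idx = knowhere::IndexFactory::Instance()
                   .Create<knowhere::fp16>(knowhere::IndexEnum::INDEX_HNSW, version)
                   .value();
    knowhere::Json json;
    json[knowhere::meta::METRIC_TYPE] = knowhere::metric::COSINE;  // L2/IP/COSINE for float types
    json[knowhere::meta::DIM] = train_ds->GetDim();
    json[knowhere::indexparam::HNSW_M] = 16;
    json[knowhere::indexparam::EFCONSTRUCTION] = 200;
    return idx.Build(*train_ds, json);
}
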
10 changes: 5 additions & 5 deletions tests/ut/test_iterator.cc
@@ -317,7 +317,7 @@ TEST_CASE("Test Iterator IVFFlatCC With Newly Insert Vectors", "[float metrics]
     }
 }
 
-TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") {
+TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[binary metrics]") {
     using Catch::Approx;
 
     const int64_t nb = 1000, nq = 10;
@@ -342,21 +342,21 @@ TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") {
         json[knowhere::indexparam::SEED_EF] = 64;
         return json;
     };
-    const auto train_ds = GenDataSet(nb, dim);
-    const auto query_ds = GenDataSet(nq, dim);
+    const auto train_ds = GenBinDataSet(nb, dim);
+    const auto query_ds = GenBinDataSet(nq, dim);
 
     const knowhere::Json conf = {
         {knowhere::meta::METRIC_TYPE, metric},
         {knowhere::meta::TOPK, topk},
     };
 
-    auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, conf, nullptr);
+    auto gt = knowhere::BruteForce::Search<knowhere::bin1>(train_ds, query_ds, conf, nullptr);
     SECTION("Test Search using iterator") {
         using std::make_tuple;
         auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
             make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
         }));
-        auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
+        auto idx = knowhere::IndexFactory::Instance().Create<knowhere::bin1>(name, version).value();
         auto cfg_json = gen().dump();
         CAPTURE(name, cfg_json);
         knowhere::Json json = knowhere::Json::parse(cfg_json);
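
The test now pairs the HNSW iterator with bit-packed binary data: a dim-bit vector occupies dim/8 bytes and is compared with HAMMING or JACCARD rather than float metrics, hence the switch to GenBinDataSet and knowhere::bin1. A minimal sketch of generating such bit-packed data (self-contained; GenBinDataSet itself lives in the test utilities):

#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

// Generate nb bit-packed binary vectors of dim bits each (dim % 8 == 0),
// mirroring what GenBinDataSet supplies to the test above.
std::vector<uint8_t>
gen_bin_vectors(int64_t nb, int64_t dim, uint64_t seed = 42) {
    std::mt19937_64 rng(seed);
    std::vector<uint8_t> data(nb * dim / 8);
    for (auto& b : data) {
        b = static_cast<uint8_t>(rng());
    }
    return data;
}

int main() {
    auto data = gen_bin_vectors(1000, 1024);  // 1000 vectors, 1024 bits each
    std::printf("%zu bytes\n", data.size());  // 128000 bytes: 128 bytes per vector
    return 0;
}
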
