Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement BF16 scalar quantizer for faiss #570

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions thirdparty/faiss/benchs/bench_fw/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def optimize_codec(
[
(None, "Flat"),
(None, "SQfp16"),
(None, "SQbf16"),
(None, "SQ8"),
] + [
(f"OPQ{M}_{M * dim}", f"PQ{M}x{b}")
Expand Down
1 change: 1 addition & 0 deletions thirdparty/faiss/c_api/IndexScalarQuantizer_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ typedef enum FaissQuantizerType {
QT_fp16,
QT_8bit_direct, ///< fast indexing of uint8s
QT_6bit, ///< 6 bits per component
QT_bf16,
} FaissQuantizerType;

// forward declaration
Expand Down
3 changes: 3 additions & 0 deletions thirdparty/faiss/contrib/factory_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def get_code_size(d, indexkey):
return (d * 6 + 7) // 8
elif indexkey == 'SQfp16':
return d * 2
elif indexkey == 'SQbf16':
return d * 2

mo = re.match('PCAR?(\\d+),(.*)$', indexkey)
if mo:
Expand Down Expand Up @@ -123,6 +125,7 @@ def reverse_index_factory(index):
faiss.ScalarQuantizer.QT_4bit: "4",
faiss.ScalarQuantizer.QT_6bit: "6",
faiss.ScalarQuantizer.QT_fp16: "fp16",
faiss.ScalarQuantizer.QT_bf16: "bf16",
}
return f"SQ{sqtypes[index.sq.qtype]}"

Expand Down
3 changes: 2 additions & 1 deletion thirdparty/faiss/faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ IndexScalarQuantizer::IndexScalarQuantizer(
MetricType metric)
: IndexFlatCodes(0, d, metric), sq(d, qtype) {
is_trained = qtype == ScalarQuantizer::QT_fp16 ||
qtype == ScalarQuantizer::QT_8bit_direct;
qtype == ScalarQuantizer::QT_8bit_direct ||
qtype == ScalarQuantizer::QT_bf16;
code_size = sq.code_size;
}

Expand Down
5 changes: 5 additions & 0 deletions thirdparty/faiss/faiss/impl/ScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ void ScalarQuantizer::set_derived_sizes() {
code_size = d * 2;
bits = 16;
break;
case QT_bf16:
code_size = d * 2;
bits = 16;
break;
}
}

Expand Down Expand Up @@ -127,6 +131,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
break;
case QT_fp16:
case QT_8bit_direct:
case QT_bf16:
// no training necessary
break;
}
Expand Down
3 changes: 2 additions & 1 deletion thirdparty/faiss/faiss/impl/ScalarQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ struct ScalarQuantizer : Quantizer {
QT_4bit_uniform,
QT_fp16,
QT_8bit_direct, ///< fast indexing of uint8s
QT_6bit, ///< 6 bits per component
QT_6bit, ///< 6 bits per component,
QT_bf16,
};

QuantizerType qtype = QT_8bit;
Expand Down
43 changes: 43 additions & 0 deletions thirdparty/faiss/faiss/impl/ScalarQuantizerCodec.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/ScalarQuantizer.h>
#include <faiss/impl/ScalarQuantizerOp.h>
#include <faiss/utils/bf16.h>
#include <faiss/utils/fp16.h>
#include <faiss/utils/utils.h>

Expand Down Expand Up @@ -227,6 +228,37 @@ struct QuantizerFP16<1> : SQuantizer {
}
};

/*******************************************************************
* BF16 quantizer
*******************************************************************/

template <int SIMDWIDTH>
struct QuantizerBF16 {};

template <>
struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer {
const size_t d;

QuantizerBF16(size_t d, const std::vector<float>& /* unused */) : d(d) {}

void encode_vector(const float* x, uint8_t* code) const final {
for (size_t i = 0; i < d; i++) {
((uint16_t*)code)[i] = encode_bf16(x[i]);
}
}

void decode_vector(const uint8_t* code, float* x) const final {
for (size_t i = 0; i < d; i++) {
x[i] = decode_bf16(((uint16_t*)code)[i]);
}
}

FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
const {
return decode_bf16(((uint16_t*)code)[i]);
}
};

/*******************************************************************
* 8bit_direct quantizer
*******************************************************************/
Expand Down Expand Up @@ -282,6 +314,8 @@ SQuantizer* select_quantizer_1(
d, trained);
case ScalarQuantizer::QT_fp16:
return new QuantizerFP16<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_bf16:
return new QuantizerBF16<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_8bit_direct:
return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
}
Expand Down Expand Up @@ -511,6 +545,10 @@ SQDistanceComputer* select_distance_computer(
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
d, trained);

case ScalarQuantizer::QT_bf16:
return new DCTemplate<QuantizerBF16<SIMDWIDTH>, Sim, SIMDWIDTH>(
d, trained);

case ScalarQuantizer::QT_8bit_direct:
if (d % 16 == 0) {
return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
Expand Down Expand Up @@ -613,6 +651,11 @@ InvertedListScanner* sel1_InvertedListScanner(
QuantizerFP16<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_bf16:
return sel2_InvertedListScanner<DCTemplate<
QuantizerBF16<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_8bit_direct:
if (sq->d % 16 == 0) {
return sel2_InvertedListScanner<
Expand Down
40 changes: 40 additions & 0 deletions thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,33 @@ struct QuantizerFP16_avx<8> : public QuantizerFP16<1> {
}
};

/*******************************************************************
* BF16 quantizer
*******************************************************************/

template <int SIMDWIDTH>
struct QuantizerBF16_avx {};

template <>
struct QuantizerBF16_avx<1> : public QuantizerBF16<1> {
QuantizerBF16_avx(size_t d, const std::vector<float>& unused)
: QuantizerBF16<1>(d, unused) {}
};

template <>
struct QuantizerBF16_avx<8> : public QuantizerBF16<1> {
QuantizerBF16_avx(size_t d, const std::vector<float>& trained)
: QuantizerBF16<1>(d, trained) {}

FAISS_ALWAYS_INLINE __m256
reconstruct_8_components(const uint8_t* code, int i) const {
__m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i));
__m256i code_256i = _mm256_cvtepu16_epi32(code_128i);
code_256i = _mm256_slli_epi32(code_256i, 16);
return _mm256_castsi256_ps(code_256i);
}
};

/*******************************************************************
* 8bit_direct quantizer
*******************************************************************/
Expand Down Expand Up @@ -239,6 +266,8 @@ SQuantizer* select_quantizer_1_avx(
d, trained);
case QuantizerType::QT_fp16:
return new QuantizerFP16_avx<SIMDWIDTH>(d, trained);
case QuantizerType::QT_bf16:
return new QuantizerBF16_avx<SIMDWIDTH>(d, trained);
case QuantizerType::QT_8bit_direct:
return new Quantizer8bitDirect_avx<SIMDWIDTH>(d, trained);
}
Expand Down Expand Up @@ -581,6 +610,12 @@ SQDistanceComputer* select_distance_computer_avx(
Sim,
SIMDWIDTH>(d, trained);

case QuantizerType::QT_bf16:
return new DCTemplate_avx<
QuantizerBF16_avx<SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);

case QuantizerType::QT_8bit_direct:
if (d % 16 == 0) {
return new DistanceComputerByte_avx<Sim, SIMDWIDTH>(d, trained);
Expand Down Expand Up @@ -659,6 +694,11 @@ InvertedListScanner* sel1_InvertedListScanner_avx(
QuantizerFP16_avx<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case QuantizerType::QT_bf16:
return sel2_InvertedListScanner_avx<DCTemplate_avx<
QuantizerBF16_avx<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case QuantizerType::QT_8bit_direct:
if (sq->d % 16 == 0) {
return sel2_InvertedListScanner_avx<
Expand Down
46 changes: 46 additions & 0 deletions thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,39 @@ struct QuantizerFP16_avx512<16> : public QuantizerFP16_avx<8> {
}
};

/*******************************************************************
* BF16 quantizer
*******************************************************************/

template <int SIMDWIDTH>
struct QuantizerBF16_avx512 {};

template <>
struct QuantizerBF16_avx512<1> : public QuantizerBF16_avx<1> {
QuantizerBF16_avx512(size_t d, const std::vector<float>& unused)
: QuantizerBF16_avx<1>(d, unused) {}
};

template <>
struct QuantizerBF16_avx512<8> : public QuantizerBF16_avx<8> {
QuantizerBF16_avx512(size_t d, const std::vector<float>& trained)
: QuantizerBF16_avx<8>(d, trained) {}
};

template <>
struct QuantizerBF16_avx512<16> : public QuantizerBF16_avx<8> {
QuantizerBF16_avx512(size_t d, const std::vector<float>& trained)
: QuantizerBF16_avx<8>(d, trained) {}

FAISS_ALWAYS_INLINE __m512
reconstruct_16_components(const uint8_t* code, int i) const {
__m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
__m512i code_512i = _mm512_cvtepu16_epi32(code_256i);
code_512i = _mm512_slli_epi32(code_512i, 16);
return _mm512_castsi512_ps(code_512i);
}
};

/*******************************************************************
* 8bit_direct quantizer
*******************************************************************/
Expand Down Expand Up @@ -269,6 +302,8 @@ SQuantizer* select_quantizer_1_avx512(
SIMDWIDTH>(d, trained);
case QuantizerType::QT_fp16:
return new QuantizerFP16_avx512<SIMDWIDTH>(d, trained);
case QuantizerType::QT_bf16:
return new QuantizerBF16_avx512<SIMDWIDTH>(d, trained);
case QuantizerType::QT_8bit_direct:
return new Quantizer8bitDirect_avx512<SIMDWIDTH>(d, trained);
}
Expand Down Expand Up @@ -653,6 +688,12 @@ SQDistanceComputer* select_distance_computer_avx512(
Sim,
SIMDWIDTH>(d, trained);

case QuantizerType::QT_bf16:
return new DCTemplate_avx512<
QuantizerBF16_avx512<SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);

case QuantizerType::QT_8bit_direct:
if (d % 16 == 0) {
return new DistanceComputerByte_avx512<Sim, SIMDWIDTH>(
Expand Down Expand Up @@ -732,6 +773,11 @@ InvertedListScanner* sel1_InvertedListScanner_avx512(
QuantizerFP16_avx512<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case QuantizerType::QT_bf16:
return sel2_InvertedListScanner_avx512<DCTemplate_avx512<
QuantizerBF16_avx512<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case QuantizerType::QT_8bit_direct:
if (sq->d % 16 == 0) {
return sel2_InvertedListScanner_avx512<
Expand Down
40 changes: 40 additions & 0 deletions thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,33 @@ struct QuantizerFP16_neon<8> : public QuantizerFP16<1> {
}
};

/*******************************************************************
* BF16 quantizer
*******************************************************************/

template <int SIMDWIDTH>
struct QuantizerBF16_neon {};

template <>
struct QuantizerBF16_neon<1> : public QuantizerBF16<1> {
QuantizerBF16_neon(size_t d, const std::vector<float>& unused)
: QuantizerBF16<1>(d, unused) {}
};

template <>
struct QuantizerBF16_neon<8> : public QuantizerBF16<1> {
QuantizerBF16_neon(size_t d, const std::vector<float>& trained)
: QuantizerBF16<1>(d, trained) {}

FAISS_ALWAYS_INLINE float32x4x2_t
reconstruct_8_components(const uint8_t* code, int i) const {
uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
vreinterpretq_f32_u32(
vshlq_n_u32(vmovl_u16(codei.val[1]), 16))};
}
};

/*******************************************************************
* 8bit_direct quantizer
*******************************************************************/
Expand Down Expand Up @@ -212,6 +239,8 @@ SQuantizer* select_quantizer_1_neon(
d, trained);
case QuantizerType::QT_fp16:
return new QuantizerFP16_neon<SIMDWIDTH>(d, trained);
case QuantizerType::QT_bf16:
return new QuantizerBF16_neon<SIMDWIDTH>(d, trained);
case QuantizerType::QT_8bit_direct:
return new Quantizer8bitDirect_neon<SIMDWIDTH>(d, trained);
}
Expand Down Expand Up @@ -556,6 +585,12 @@ SQDistanceComputer* select_distance_computer_neon(
Sim,
SIMDWIDTH>(d, trained);

case QuantizerType::QT_bf16:
return new DCTemplate_neon<
QuantizerBF16_neon<SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);

case QuantizerType::QT_8bit_direct:
if (d % 16 == 0) {
return new DistanceComputerByte_neon<Sim, SIMDWIDTH>(d, trained);
Expand Down Expand Up @@ -634,6 +669,11 @@ InvertedListScanner* sel1_InvertedListScanner_neon(
QuantizerFP16_neon<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case QuantizerType::QT_bf16:
return sel2_InvertedListScanner_neon<DCTemplate_neon<
QuantizerBF16_neon<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case QuantizerType::QT_8bit_direct:
if (sq->d % 16 == 0) {
return sel2_InvertedListScanner_neon<
Expand Down
3 changes: 2 additions & 1 deletion thirdparty/faiss/faiss/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
{"SQ4", ScalarQuantizer::QT_4bit},
{"SQ6", ScalarQuantizer::QT_6bit},
{"SQfp16", ScalarQuantizer::QT_fp16},
{"SQbf16", ScalarQuantizer::QT_bf16},
};
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16|SQbf16)";

std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
Expand Down
Loading
Loading