Skip to content

Commit

Permalink
Make: Upgrade to the newest SimSIMD
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Jul 30, 2023
1 parent 11d7844 commit 368d853
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
19 changes: 13 additions & 6 deletions include/usearch/index_plugins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1173,27 +1173,34 @@ class metric_punned_t {
};
}

template <typename scalar_at>
static stl_function_t pun_stl_(std::function<result_t(scalar_at const*, scalar_at const*)> typed) {
return [=](byte_t const* a, byte_t const* b) -> result_t {
return typed((scalar_at const*)a, (scalar_at const*)b);
};
}

// clang-format off
static metric_punned_t ip_metric_f32_(std::size_t bytes_per_vector) {
#if USEARCH_USE_SIMSIMD
if (hardware_supports(isa_kind_t::sve_k)) return {[=](f32_t const* a, f32_t const* b) { return simsimd_dot_f32sve(a, b, bytes_per_vector / 4); }, bytes_per_vector, metric_kind_t::ip_k, scalar_kind_t::f32_k, isa_kind_t::sve_k};
if (hardware_supports(isa_kind_t::neon_k) && bytes_per_vector % 16 == 0) return {[=](f32_t const* a, f32_t const* b) { return simsimd_dot_f32x4neon(a, b, bytes_per_vector / 4); }, bytes_per_vector, metric_kind_t::ip_k, scalar_kind_t::f32_k, isa_kind_t::neon_k};
if (hardware_supports(isa_kind_t::avx2_k) && bytes_per_vector % 16 == 0) return {[=](f32_t const* a, f32_t const* b) { return simsimd_dot_f32x4avx2(a, b, bytes_per_vector / 4); }, bytes_per_vector, metric_kind_t::ip_k, scalar_kind_t::f32_k, isa_kind_t::avx2_k};
if (hardware_supports(isa_kind_t::sve_k)) return {pun_stl_<f32_t>([=](f32_t const* a, f32_t const* b) { return simsimd_dot_f32_sve(a, b, bytes_per_vector / 4); }), bytes_per_vector, metric_kind_t::ip_k, scalar_kind_t::f32_k, isa_kind_t::sve_k};
if (hardware_supports(isa_kind_t::neon_k) && bytes_per_vector % 16 == 0) return {pun_stl_<f32_t>([=](f32_t const* a, f32_t const* b) { return simsimd_dot_f32x4_neon(a, b, bytes_per_vector / 4); }), bytes_per_vector, metric_kind_t::ip_k, scalar_kind_t::f32_k, isa_kind_t::neon_k};
if (hardware_supports(isa_kind_t::avx2_k) && bytes_per_vector % 16 == 0) return {pun_stl_<f32_t>([=](f32_t const* a, f32_t const* b) { return simsimd_dot_f32x4_avx2(a, b, bytes_per_vector / 4); }), bytes_per_vector, metric_kind_t::ip_k, scalar_kind_t::f32_k, isa_kind_t::avx2_k};
#endif
return {to_stl_<metric_ip_gt<f32_t>>(bytes_per_vector), bytes_per_vector, metric_kind_t::ip_k, scalar_kind_t::f32_k, isa_kind_t::auto_k};
}

static metric_punned_t cos_metric_f16_(std::size_t bytes_per_vector) {
#if USEARCH_USE_SIMSIMD
if (hardware_supports(isa_kind_t::avx512_k) && bytes_per_vector % 32 == 0) return {[=](simsimd_f16_t const* a, simsimd_f16_t const* b) { return simsimd_cos_f16x16avx512(a, b, bytes_per_vector / 2); }, bytes_per_vector, metric_kind_t::cos_k, scalar_kind_t::f16_k, isa_kind_t::avx512_k};
if (hardware_supports(isa_kind_t::neon_k) && bytes_per_vector % 8 == 0) return {[=](simsimd_f16_t const* a, simsimd_f16_t const* b) { return simsimd_cos_f16x4neon(a, b, bytes_per_vector / 2); }, bytes_per_vector, metric_kind_t::cos_k, scalar_kind_t::f16_k, isa_kind_t::neon_k};
if (hardware_supports(isa_kind_t::avx512_k) && bytes_per_vector % 32 == 0) return {pun_stl_<simsimd_f16_t>([=](simsimd_f16_t const* a, simsimd_f16_t const* b) { return simsimd_cos_f16x16_avx512(a, b, bytes_per_vector / 2); }), bytes_per_vector, metric_kind_t::cos_k, scalar_kind_t::f16_k, isa_kind_t::avx512_k};
if (hardware_supports(isa_kind_t::neon_k) && bytes_per_vector % 8 == 0) return {pun_stl_<simsimd_f16_t>([=](simsimd_f16_t const* a, simsimd_f16_t const* b) { return simsimd_cos_f16x4_neon(a, b, bytes_per_vector / 2); }), bytes_per_vector, metric_kind_t::cos_k, scalar_kind_t::f16_k, isa_kind_t::neon_k};
#endif
return {to_stl_<metric_cos_gt<f16_t, f32_t>>(bytes_per_vector), bytes_per_vector, metric_kind_t::cos_k, scalar_kind_t::f16_k, isa_kind_t::auto_k};
}

static metric_punned_t cos_metric_f8_(std::size_t bytes_per_vector) {
#if USEARCH_USE_SIMSIMD
if (hardware_supports(isa_kind_t::neon_k) && bytes_per_vector % 16 == 0) return {[=](int8_t const* a, int8_t const* b) { return simsimd_cos_i8x16neon(a, b, bytes_per_vector); }, bytes_per_vector, metric_kind_t::cos_k, scalar_kind_t::f8_k, isa_kind_t::neon_k};
if (hardware_supports(isa_kind_t::neon_k) && bytes_per_vector % 16 == 0) return {pun_stl_<int8_t>([=](int8_t const* a, int8_t const* b) { return simsimd_cos_i8x16_neon(a, b, bytes_per_vector); }), bytes_per_vector, metric_kind_t::cos_k, scalar_kind_t::f8_k, isa_kind_t::neon_k};
#endif
return {to_stl_<cos_f8_t>(bytes_per_vector), bytes_per_vector, metric_kind_t::cos_k, scalar_kind_t::f8_k, isa_kind_t::auto_k};
}
Expand Down

0 comments on commit 368d853

Please sign in to comment.