22#include " dequantize.hpp"
33#include " presets.hpp"
44
5+ #if defined(__INTEL_LLVM_COMPILER)
6+ #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
7+ #include < sycl/ext/oneapi/bfloat16.hpp>
8+ #define GGML_SYCL_HAS_BF16
9+ #endif
10+ #endif
11+
512template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t >
613static void dequantize_block (const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
714 const sycl::nd_item<3 > &item_ct1) {
@@ -566,6 +573,10 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
566573 return dequantize_row_iq4_nl_sycl;
567574 case GGML_TYPE_F32:
568575 return convert_unary_sycl<float >;
576+ #ifdef GGML_SYCL_HAS_BF16
577+ case GGML_TYPE_BF16:
578+ return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
579+ #endif
569580 default :
570581 return nullptr ;
571582 }
@@ -627,6 +638,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
627638 return dequantize_row_iq4_nl_sycl;
628639 case GGML_TYPE_F16:
629640 return convert_unary_sycl<sycl::half>;
641+ #ifdef GGML_SYCL_HAS_BF16
642+ case GGML_TYPE_BF16:
643+ return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
644+ #endif
645+
630646 default :
631647 return nullptr ;
632648 }
@@ -636,7 +652,11 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
636652 switch (type) {
637653 case GGML_TYPE_F32:
638654 return convert_unary_nc_sycl<float >;
655+ #ifdef GGML_SYCL_HAS_BF16
656+ case GGML_TYPE_BF16:
657+ return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
658+ #endif
639659 default :
640660 return nullptr ;
641661 }
642- }
662+ }
0 commit comments