diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 9a50431a2fa9c9..7f7508def6c5a3 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -176,7 +176,7 @@ struct LaunchXsmmBackwardInputConvolution {
     desc.filter_format =
         LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;  // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
-    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE;
+    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
     desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
 
     auto input_ptr = input_backward.data();
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index b3803778c86765..24aa3594904af0 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -228,7 +228,7 @@ class LaunchXsmmConvOp {
     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
     desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
-    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE;
+    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
     desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
 
     if (!CanUseXsmmConv2D(desc, data_format)) {
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index 61bd6593c37f8b..098b2d650007b2 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -31,11 +31,11 @@ namespace internal {
 // in the lower 16-bits of input
 template <typename Packet>
 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
-  tensorflow::uint32 tmp; 
+  tensorflow::uint32 tmp;
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) ) & 0xffff0000;
-#else 
-  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000; 
+  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
+#else
+  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
 #endif
   return reinterpret_cast<const float&>(tmp);
 }
@@ -44,12 +44,12 @@ EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
 // in the upper 16-bits of input
 template <typename Packet>
 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
-  tensorflow::uint32 tmp; 
+  tensorflow::uint32 tmp;
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16 ) & 0xffff0000;
+  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
 #else
-  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000; 
-#endif 
+  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
+#endif
   return reinterpret_cast<const float&>(tmp);
 }
 
@@ -61,12 +61,12 @@ EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
   float r[4];
   tensorflow::uint32 p[4];
   pstoreu(r, from);
-  tensorflow::uint32 * ir = reinterpret_cast<tensorflow::uint32*>(r);
+  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
   p[0] = (ir[0] << 16) & 0xffff0000;
-  p[1] = ir[0]& 0xffff0000;
+  p[1] = ir[0] & 0xffff0000;
   p[2] = (ir[1] << 16) & 0xffff0000;
   p[3] = ir[1] & 0xffff0000;
-  return ploadu<Packet4f>(reinterpret_cast<float *>(p));
+  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
 }
 
 template <typename Packet>
@@ -74,12 +74,12 @@ EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
   float r[4];
   tensorflow::uint32 p[4];
   pstoreu(r, from);
-  tensorflow::uint32 * ir = reinterpret_cast<tensorflow::uint32*>(r);
+  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
   p[0] = (ir[2] << 16) & 0xffff0000;
   p[1] = ir[2] & 0xffff0000;
   p[2] = (ir[3] << 16) & 0xffff0000;
   p[3] = ir[3] & 0xffff0000;
-  return ploadu<Packet4f>(reinterpret_cast<float *>(p));
+  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
 }
 #endif
 
@@ -131,23 +131,25 @@ EIGEN_DEVICE_FUNC inline Packet pload2bf16(
 template <>
 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
   tensorflow::uint32 p[4];
-  const tensorflow::uint32* ir = reinterpret_cast<const tensorflow::uint32*>(from);
+  const tensorflow::uint32* ir =
+      reinterpret_cast<const tensorflow::uint32*>(from);
   p[0] = (ir[0] << 16) & 0xffff0000;
-  p[1] = ir[0]& 0xffff0000;
+  p[1] = ir[0] & 0xffff0000;
   p[2] = (ir[1] << 16) & 0xffff0000;
   p[3] = ir[1] & 0xffff0000;
-  return ploadu<Packet4f>(reinterpret_cast<float *>(p));
+  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
 }
 
 template <>
 EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
   tensorflow::uint32 p[4];
-  const tensorflow::uint32* ir = reinterpret_cast<const tensorflow::uint32*>(from);
+  const tensorflow::uint32* ir =
+      reinterpret_cast<const tensorflow::uint32*>(from);
   p[0] = (ir[0] << 16) & 0xffff0000;
-  p[1] = ir[0]& 0xffff0000;
+  p[1] = ir[0] & 0xffff0000;
   p[2] = (ir[0] << 16) & 0xffff0000;
   p[3] = ir[0] & 0xffff0000;
-  return ploadu<Packet4f>(reinterpret_cast<float *>(p));
+  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
 }
 
 #endif
@@ -255,12 +257,13 @@ EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
 }
 template <>
 EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
-  Packet2d a = _mm512_extractf32x4_ps(a_in, 1);
+  Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
   return _mm512_broadcastsd_pd(a);
 }
 template <>
 EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
-  Packet2d a = _mm_permute_pd(_mm512_extractf32x4_ps(a_in, 1), 3);
+  Packet2d a =
+      _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
   return _mm512_broadcastsd_pd(a);
 }
 template <>
@@ -417,14 +420,17 @@ EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
 
 template <typename Packet>
 EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
-  return _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(from)),
-                           16);
+  return _mm512_castsi512_ps(_mm512_slli_epi32(
+      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
+      16));
 }
 
 template <typename Packet>
 EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
-  return _mm512_slli_epi32(
-      _mm512_cvtepu16_epi32(_mm512_extractf64x4_pd(from, 1)), 16);
+  Packet16i tmp = _mm512_castps_si512(from);
+  Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
+  return _mm512_castsi512_ps(_mm512_slli_epi32(
+      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
 }
 
 #endif
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc
index 7936cbcd46f071..6d403b4dfb5c95 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d.cc
@@ -131,32 +131,7 @@ class libxsmm_dnn_conv_desc_wrap {
 
   struct HashFunction {
     std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
-      // unsigned char ptr[sizeof(&w.d)];
-
-      // memcpy(ptr, (unsigned char *)&w.d, sizeof(&w.d))
-
-      //
-      /*
-      std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw;
-
-      N << w.d.N; C << w.d.C;
-      H << w.d.H; W << w.d.W;
-      K << w.d.K; R << w.d.R;
-      S << w.d.S; u << w.d.u;
-      v << w.d.v; padh << w.d.pad_h_in;
-      padw << w.d.pad_w_in;
-
-
-      std::string out_ = N.str() + C.str()\
-                       + H.str() + W.str()\
-                       + K.str() + R.str()\
-                       + S.str() + u.str()\
-                       + v.str() + padh.str()\
-                       + padw.str();
-      //
-      //
-      */
-      return (std::hash<unsigned long long>()((unsigned long long)&(w.d)));
+      return libxsmm_hash(&w.d, sizeof(w.d), 25071975);
     }
   };
 
@@ -221,8 +196,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
 
   status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
   if (status == LIBXSMM_DNN_WARN_FALLBACK) {
-    chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
-                    "Destroy handle");
     return false;  // Use non-libxsmm code
   }
   chk_libxsmm_err(status, "Check codegen status");
@@ -324,8 +297,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
     chk_libxsmm_err(status, "Link filter");
   }
   if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
-    chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");
-
     chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                             LIBXSMM_DNN_REGULAR_INPUT),
                     "Bind input forward");
diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD
index f9f1ea1085712d..4124f2db637689 100644
--- a/third_party/libxsmm.BUILD
+++ b/third_party/libxsmm.BUILD
@@ -11,19 +11,8 @@ exports_files(["LICENSE"])
 libxsmm_interface_arguments = "0 1"
 
 # Arguments to ./scripts/libxsmm_config.py, see that file for detailed description.
-# ilp64: no
-# big: no
-# offload: no
-# alignment [b]
-# prefetch: 1 (auto)
-# threshold: fallback to BLAS if n*m*k above this
-# synchronize: yes
-# jit: yes
-# flags
-# alpha = 1
-# beta = 1
-# gemm = 2
-libxsmm_config_arguments = "0 0 0 64 1 0 1 1 0 1 1 2"
+# rely on default arguments
+libxsmm_config_arguments = ""
 
 # Arguments to ./scripts/libxsmm_dispatch.py, see that file for detailed description.
 # (dummy argument)