From 415ab8576927063f16afa9c0da2464c7e786a3bf Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Tue, 19 Mar 2019 16:43:26 -0700 Subject: [PATCH 01/12] Add new ops to support quantized Matmul with some fusions New ops include: - QuantizedMatmul + BiasAdd - QuantizedMatmul + BiasAdd + Relu - QuantizedMatmul + BiasAdd + Relu + Requantize --- tensorflow/core/BUILD | 6 +- .../api_def_QuantizedMatMulWithBias.pbtxt | 4 + ...i_def_QuantizedMatMulWithBiasAndRelu.pbtxt | 4 + ...edMatMulWithBiasAndReluAndRequantize.pbtxt | 4 + tensorflow/core/api_def/excluded_ops.cc | 30 +- tensorflow/core/kernels/BUILD | 23 + tensorflow/core/kernels/mkl_qmatmul_op.cc | 875 ++++++++++++++++++ tensorflow/core/ops/mkl_nn_ops.cc | 134 +++ tensorflow/core/ops/nn_ops.cc | 111 ++- tensorflow/core/util/mkl_util.h | 58 +- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 12 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 12 + 12 files changed, 1231 insertions(+), 42 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBias.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBiasAndRelu.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBiasAndReluAndRequantize.pbtxt create mode 100644 tensorflow/core/kernels/mkl_qmatmul_op.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index d2aeadb48a4a13..946043b0ccd0db 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1516,6 +1516,7 @@ cc_library( "//tensorflow/core/kernels:quantized_ops", "//tensorflow/core/kernels/neon:neon_depthwise_conv_op", ]) + if_mkl([ + "//tensorflow/core/kernels:mkl_aggregate_ops", "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", "//tensorflow/core/kernels:mkl_cwise_ops_common", @@ -1523,15 +1524,15 @@ cc_library( "//tensorflow/core/kernels:mkl_identity_op", "//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_lrn_op", - "//tensorflow/core/kernels:mkl_requantize_ops", "//tensorflow/core/kernels:mkl_pooling_ops", + "//tensorflow/core/kernels:mkl_qmatmul_op", + "//tensorflow/core/kernels:mkl_requantize_ops", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_slice_op", "//tensorflow/core/kernels:mkl_softmax_op", "//tensorflow/core/kernels:mkl_transpose_op", "//tensorflow/core/kernels:mkl_tfconv_op", - "//tensorflow/core/kernels:mkl_aggregate_ops", ]) + if_cuda([ "//tensorflow/core/grappler/optimizers:gpu_swapping_kernels", "//tensorflow/core/grappler/optimizers:gpu_swapping_ops", @@ -4208,6 +4209,7 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_lrn_op", "//tensorflow/core/kernels:mkl_pooling_ops", + "//tensorflow/core/kernels:mkl_qmatmul_op", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_slice_op", diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBias.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBias.pbtxt new file mode 100644 index 00000000000000..0e636b4fe340e8 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBias.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "QuantizedMatMulWithBias" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBiasAndRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBiasAndRelu.pbtxt new file mode 100644 index 
00000000000000..9cca7b13ba91fc --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBiasAndRelu.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "QuantizedMatMulWithBiasAndRelu" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBiasAndReluAndRequantize.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBiasAndReluAndRequantize.pbtxt new file mode 100644 index 00000000000000..691b43371ff5d9 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBiasAndReluAndRequantize.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "QuantizedMatMulWithBiasAndReluAndRequantize" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/excluded_ops.cc b/tensorflow/core/api_def/excluded_ops.cc index 65d2102ac80579..0b2e2e2b8e4c7d 100644 --- a/tensorflow/core/api_def/excluded_ops.cc +++ b/tensorflow/core/api_def/excluded_ops.cc @@ -19,21 +19,25 @@ namespace tensorflow { const std::unordered_set* GetExcludedOps() { static std::unordered_set* excluded_ops = - new std::unordered_set( - {"BigQueryReader", "GenerateBigQueryReaderPartitions", - "GcsConfigureBlockCache", "GcsConfigureCredentials", + new std::unordered_set({ + "BigQueryReader", "GenerateBigQueryReaderPartitions", + "GcsConfigureBlockCache", "GcsConfigureCredentials", #ifdef INTEL_MKL - // QuantizedFusedOps for Intel CPU - "QuantizedConcatV2", "QuantizedConv2DAndRequantize", - "QuantizedConv2DWithBias", "QuantizedConv2DWithBiasAndRequantize", - "QuantizedConv2DAndRelu", "QuantizedConv2DAndReluAndRequantize", - "QuantizedConv2DWithBiasAndRelu", - "QuantizedConv2DWithBiasAndReluAndRequantize", - "QuantizedConv2DWithBiasSumAndRelu", - "QuantizedConv2DWithBiasSumAndReluAndRequantize", - "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" + // QuantizedFusedOps for Intel CPU + "QuantizedConcatV2", "QuantizedConv2DAndRequantize", + "QuantizedConv2DWithBias", "QuantizedConv2DWithBiasAndRequantize", + "QuantizedConv2DAndRelu", "QuantizedConv2DAndReluAndRequantize", + "QuantizedConv2DWithBiasAndRelu", + "QuantizedConv2DWithBiasAndReluAndRequantize", + "QuantizedConv2DWithBiasSumAndRelu", + "QuantizedConv2DWithBiasSumAndReluAndRequantize", + "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", + "QuantizedMatMulWithBias" + "QuantizedMatMulWithBiasAndRelu" + "QuantizedMatMulWithBiasAndReluAndRequantize", + #endif // INTEL_MKL - }); + }); return excluded_ops; } } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5d107f43978dd1..ee62db490c58ba 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6832,6 +6832,29 @@ tf_cc_test( ], ) +tf_mkl_kernel_library( + name = "mkl_qmatmul_op", + srcs = [ + "mkl_qmatmul_op.cc", + ], + hdrs = ["mkl_quantized_conv_ops.h", + "no_op.h", + ], + deps = [ + ":bounds_check", + ":matmul_op", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:mkl_nn_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + ] + mkl_deps(), +) + + tf_mkl_kernel_library( name = "mkl_conv_op", hdrs = [ diff --git a/tensorflow/core/kernels/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl_qmatmul_op.cc new file mode 100644 index 00000000000000..9391b084c4426f --- /dev/null +++ b/tensorflow/core/kernels/mkl_qmatmul_op.cc @@ -0,0 +1,875 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implements a quantized eight-bit version of the matmul operation with
+// bias, relu and requantization fusion support, using the MKL-DNN u8s8s32
+// inner product API.
+// Currently, this kernel supports inputs quantized to uint8 via either
+// MIN_FIRST or SCALE mode, weights quantized to int8 via SCALE mode, and a
+// mandatory bias; other input combinations are not supported yet.
+// When the input is quantized to uint8 via MIN_FIRST, the bias needs
+// compensation. The detailed algorithm is illustrated below:
+//
+// Af32 is the original fp32 activation tensor.
+// Min(Af32) is the minimum scalar value of Af32.
+// Max(Af32) is the maximum scalar value of Af32.
+// MaxAbs(Af32) is the maximum absolute scalar value of Af32.
+// Qa is the quantization scale of the activation.
+// Au8 is the quantized unsigned int8 activation tensor.
+// With SCALE quantization, Qa and Au8 are calculated as below:
+//    Qa  = 255 / MaxAbs(Af32)
+//    Au8 = || Qa * Af32 ||
+// With MIN_FIRST quantization, Q'a and A'u8 are calculated as below:
+//    Q'a  = 255 / (Max(Af32) - Min(Af32))
+//    A'u8 = || Q'a * (Af32 - Min(Af32) * 1) ||
+// where 1 is a vector of all 1s used for broadcasting, and || . || denotes
+// rounding to the nearest integer.
+//
+// Wf32 is the original fp32 weight tensor.
+// MaxAbs(Wf32) is the maximum absolute scalar value of Wf32.
+// Qw is the quantization scale of the weight.
+// Ws8 is the quantized signed int8 weight tensor.
+// Qw and Ws8 are calculated as below:
+//    Qw  = 127 / MaxAbs(Wf32)
+//    Ws8 = || Qw * Wf32 ||
+//
+// Bf32 is the original fp32 bias tensor.
+// Bs32 is the converted 32-bit integer bias tensor.
+// With SCALE quantization of the activation, Bs32 is calculated as below:
+//    Bs32 = Qa * Qw * Bf32
+// With MIN_FIRST quantization of the activation, B'f32 is the fp32 bias
+// tensor with compensation and B's32 is the converted 32-bit integer bias
+// tensor; they are calculated as below:
+//    B'f32 = Bf32 + Min(Af32) * Wf32 * 1
+//    B's32 = Q'a * Qw * Bf32 + Q'a * Min(Af32) * Ws8 * 1
+// where Q'a * Qw, the product of Q'a and Qw, is also called the output
+// quantization scale.
+//
+// With Au8, Ws8 and B's32 inputs, the QuantizedMatMulWithBias op computes
+// the 32-bit integer output as below:
+//
+// With MIN_FIRST activation quantization:
+//    Xs32 = Ws8 * A'u8 + B's32
+//         = Q'a * Qw * Wf32 * (Af32 - Min(Af32) * 1)
+//           + Q'a * Min(Af32) * Ws8 * 1 + Q'a * Qw * Bf32
+//         = Q'a * Qw * (Wf32 * Af32 + Bf32) = Q'a * Qw * Xf32
+// With SCALE activation quantization:
+//    Xs32 = Ws8 * Au8 + Bs32
+//         = Qa * Qw * Wf32 * Af32 + Qa * Qw * Bf32
+//         = Qa * Qw * (Wf32 * Af32 + Bf32) = Qa * Qw * Xf32
+//
+// The QuantizedMatMulWithBiasAndRelu op does the same calculation as above,
+// except that it applies a relu to the 32-bit integer output.
+//
+// The QuantizedMatMulWithBiasAndReluAndRequantize op performs one more
+// requantization step on top of that. The requantization scale Qr is
+// calculated from offline calibration.
+// Qr = 255/MaxAbs(X𝑓32) +// Xu8 = QrXs32 +// +// More information of this implmentation can be referred from +// https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training +#ifdef INTEL_MKL + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/mkl_quantized_conv_ops.h" +#include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/mkl_util.h" + +#include "mkldnn.h" + +using mkldnn::prop_kind; +using mkldnn::stream; +using mkldnn::inner_product_forward; + +namespace { +enum { + QUANTIZE_MODE_MIN_FIRST, + QUANTIZE_MODE_SCALED, +}; +} // namespace + +namespace tensorflow { + +// This structure aggregates multiple inputs to MklDnnMatMul* methods. +struct MklDnnMatMulFwdParams { + memory::dims src_dims; + memory::dims weight_dims; + memory::dims bias_dims; + memory::dims dst_dims; + string dtypes = string(""); + struct PostOpParam { + string name; + std::vector param; + }; + std::vector post_op_params; + + MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims, + memory::dims bias_dims, memory::dims dst_dims) + : src_dims(src_dims), + weight_dims(weight_dims), + bias_dims(bias_dims), + dst_dims(dst_dims) {} +}; +// With quantization, input, weight, and output can have different types +// so we use differnt template parameters for each type +template +class MklDnnMatMulFwdPrimitive : public MklPrimitive { + public: + explicit MklDnnMatMulFwdPrimitive( + const MklDnnMatMulFwdParams& matmulFwdParams) + : cpu_engine_(engine::cpu, 0) { + context_.fwd_stream.reset(new stream(stream::kind::eager)); + // create matmul primitive + if (context_.matmul_fwd == nullptr) { + Setup(matmulFwdParams); + } + } + + ~MklDnnMatMulFwdPrimitive() {} + + // inner-product forward execute with bias + // src_data: input data buffer of src + // weight_data: input data buffer of weight + // bias_data: input data buffer of bias + // dst_data: output data buffer of dst + void Execute(const Tinput* src_data, const Tweight* weight_data, + const Tbias* bias_data, Toutput* dst_data) { + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data))); + context_.weight_mem->set_data_handle( + static_cast(const_cast(weight_data))); + context_.bias_mem->set_data_handle( + static_cast(const_cast(bias_data))); + context_.dst_mem->set_data_handle(static_cast(dst_data)); + context_.fwd_stream->submit(context_.fwd_primitives); + + // after execution, set data handle back + context_.src_mem->set_data_handle(DummyData); + context_.weight_mem->set_data_handle(DummyData); + context_.bias_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + } + + memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } + memory::format GetweightMemoryFormat() const { return context_.weight_fmt; } + std::shared_ptr + GetPrimitiveDesc() const { + return context_.fwd_pd; + } + + private: + // Primitive reuse context for inner-product Fwd op + struct MklDnnMatMulFwdContext { + // expected memory format for this primitive instance + memory::format src_fmt; + memory::format weight_fmt; + + // MKLDNN memory + std::shared_ptr src_mem; + std::shared_ptr weight_mem; + std::shared_ptr bias_mem; + std::shared_ptr dst_mem; + + // desc & primitive desc + std::shared_ptr fwd_desc; + + // memory desc + std::shared_ptr 
src_md; + std::shared_ptr weight_md; + std::shared_ptr bias_md; + std::shared_ptr dst_md; + + // inner-product primitive + std::shared_ptr fwd_pd; + std::shared_ptr matmul_fwd; + + std::shared_ptr fwd_stream; + std::vector fwd_primitives; + + MklDnnMatMulFwdContext() + : src_fmt(memory::format::any), + weight_fmt(memory::format::any), + src_mem(nullptr), + weight_mem(nullptr), + bias_mem(nullptr), + dst_mem(nullptr), + fwd_desc(nullptr), + src_md(nullptr), + weight_md(nullptr), + bias_md(nullptr), + fwd_pd(nullptr), + matmul_fwd(nullptr), + fwd_stream(nullptr) {} + }; + + void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) { + // create memory descriptors for inner-product data with no specified format + context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims}, + MklDnnType(), + memory::format::any)); + + context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims}, + MklDnnType(), + memory::format::any)); + + context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims}, + MklDnnType(), + memory::format::any)); + + context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims}, + MklDnnType(), + memory::format::any)); + // create an inner-product + context_.fwd_desc.reset(new inner_product_forward::desc( + prop_kind::forward_inference, *context_.src_md, *context_.weight_md, + *context_.bias_md, *context_.dst_md)); + + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); + + // Check if there is any fusion as post-ops + auto const& post_op_params = matmul_fwd_params.post_op_params; + mkldnn::primitive_attr post_ops_attr; + mkldnn::post_ops post_ops; + if (!post_op_params.empty()) { + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "relu") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha, + op_beta); + } else if (post_op_param.name == "output_scale") { + DCHECK_EQ(post_op_param.param.size(), 1); + std::vector scales; + scales.push_back(post_op_param.param[0]); + post_ops_attr.set_output_scales(0, scales); + } else { + DCHECK((post_op_param.name == "relu") || + (post_op_param.name == "output_scale")); + } + } + post_ops_attr.set_post_ops(post_ops); + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + *context_.fwd_desc, post_ops_attr, cpu_engine_)); + } else { + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); + } + + // store the expected memory format + context_.src_fmt = static_cast( + context_.fwd_pd.get()->src_primitive_desc().desc().data.format); + + context_.weight_fmt = static_cast( + context_.fwd_pd.get()->weights_primitive_desc().desc().data.format); + + // create memory primitive based on dummy data + context_.src_mem.reset( + new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData)); + context_.weight_mem.reset( + new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); + context_.dst_mem.reset( + new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); + + context_.bias_mem.reset(new memory( + {{{matmul_fwd_params.bias_dims}, MklDnnType(), memory::format::x}, + cpu_engine_}, + DummyData)); + + // create inner-product primitive + context_.matmul_fwd.reset(new inner_product_forward( + *context_.fwd_pd, *context_.src_mem, *context_.weight_mem, + *context_.bias_mem, 
*context_.dst_mem)); + + context_.fwd_primitives.push_back(*context_.matmul_fwd); + return; + } + + struct MklDnnMatMulFwdContext context_; + engine cpu_engine_; +}; + +template +class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklDnnMatMulFwdPrimitive* Get( + const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) { + MklDnnMatMulFwdPrimitive* matmul_fwd = + nullptr; + + if (do_not_cache) { + // Always create new primitive + matmul_fwd = + new MklDnnMatMulFwdPrimitive( + mkldnn_matmul_fwd_dims); + } else { + // try to find a suitable one in pool + matmul_fwd = dynamic_cast< + MklDnnMatMulFwdPrimitive*>( + MklDnnMatMulFwdPrimitiveFactory::GetInstance() + .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims)); + if (matmul_fwd == nullptr) { + matmul_fwd = + new MklDnnMatMulFwdPrimitive( + mkldnn_matmul_fwd_dims); + MklDnnMatMulFwdPrimitiveFactory::GetInstance() + .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd); + } + } + + return matmul_fwd; + } + + private: + MklDnnMatMulFwdPrimitiveFactory() {} + ~MklDnnMatMulFwdPrimitiveFactory() {} + + static MklDnnMatMulFwdPrimitiveFactory& GetInstance() { + static MklDnnMatMulFwdPrimitiveFactory instance_; + return instance_; + } + + static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) { + string prefix = "matmul_fwd_"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes); + + // Generate keys for post-ops + for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) { + if (post_op_param.name == "relu") { + DCHECK_EQ(post_op_param.param.size(), 3); + key_creator.AddAsKey(post_op_param.name); + key_creator.AddAsKey(post_op_param.param[0]); + key_creator.AddAsKey(post_op_param.param[1]); + key_creator.AddAsKey(post_op_param.param[2]); + } else if (post_op_param.name == "output_scale") { + DCHECK_EQ(post_op_param.param.size(), 1); + key_creator.AddAsKey(post_op_param.name); + key_creator.AddAsKey(post_op_param.param[0]); + } else { + return string("not_a_key"); + } + } + + return key_creator.GetKey(); + } + + MklPrimitive* GetMklDnnMatMulFwd( + const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) { + string key = CreateKey(mkldnn_matmul_fwd_dims); + return this->GetOp(key); + } + + void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, + MklPrimitive* op) { + string key = CreateKey(mkldnn_matmul_fwd_dims); + this->SetOp(key, op); + } +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class MklDnnQuantizedMatMulOp : public OpKernel { + public: + virtual ~MklDnnQuantizedMatMulOp() { + if (this->input_bias_ != nullptr) { + delete this->input_bias_; + input_bias_ = nullptr; + } + if (this->scaled_bias_ != nullptr) { + delete this->scaled_bias_; + scaled_bias_ = nullptr; + } + if (this->comp_bias_ != nullptr) { + delete this->comp_bias_; + comp_bias_ = nullptr; + } + } + + float* GetCompBiasBuffer(int size) { + if (!comp_bias_) { + comp_bias_ = new float[size]; + } + return comp_bias_; + } + + explicit MklDnnQuantizedMatMulOp(OpKernelConstruction* context) + : OpKernel(context) { + string mode_string; + OP_REQUIRES_OK(context, context->GetAttr("input_quant_mode", &mode_string)); + if (mode_string == "MIN_FIRST") { + mode_ = 
QUANTIZE_MODE_MIN_FIRST; + } else if (mode_string == "SCALED") { + mode_ = QUANTIZE_MODE_SCALED; + } + } + + void Compute(OpKernelContext* context) override { + try { + // Input tensors + const Tensor& src_tensor = MklGetInput(context, kInputIndexSrc); + const Tensor& weight_tensor = MklGetInput(context, kInputIndexWeight); + const Tensor& bias_tensor = MklGetInput(context, kInputIndexBias); + + MklDnnShape src_mkl_shape, weight_mkl_shape; + GetMklShape(context, kInputIndexSrc, &src_mkl_shape); + GetMklShape(context, kInputIndexWeight, &weight_mkl_shape); + OP_REQUIRES(context, weight_mkl_shape.IsMklTensor() == false, + errors::InvalidArgument("weight should not be in " + "Mkl Layout")); + + MklDnnData src(&cpu_engine_); + MklDnnData weight(&cpu_engine_); + + memory::dims src_dims, weight_dims; + memory::dims dst_dims_tf_order, dst_dims_mkl_order; + + // Get shapes of input tensors in MKL-DNN order + auto src_tf_shape = src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetTfShape() + : src_tensor.shape(); + auto weight_tf_shape = weight_mkl_shape.IsMklTensor() + ? weight_mkl_shape.GetTfShape() + : weight_tensor.shape(); + + src_dims = TFShapeToMklDnnDims(src_tf_shape); + weight_dims = TFShapeToMklDnnDims(weight_tf_shape); + dst_dims_mkl_order = {static_cast(src_tf_shape.dim_size(0)), + static_cast(weight_tf_shape.dim_size(1))}; + + // weight dims need to be reversed to create inner-product forward + // descriptor + weight_dims = {static_cast(weight_tf_shape.dim_size(1)), + static_cast(weight_tf_shape.dim_size(0))}; + + // Create memory for user data. + // Describe how the inputs and outputs of inner-product look like. Also + // specify buffers containing actual input and output data. + Tensor* dst_tensor = nullptr; + auto input_output_fmt = memory::format::nc; + + // If input is in MKL layout, then simply take input layout; otherwise, + // construct input Tf layout. For TF layout, although input shape + // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's + // layout depending on data format. + auto src_md = + src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetMklLayout() + : memory::desc(src_dims, MklDnnType(), input_output_fmt); + src.SetUsrMem(src_md, &src_tensor); + + // Although weight shape (weight_dims) required is in MKL-DNN order, + // the layout is Tensorflow's layout. + auto weight_md = weight_mkl_shape.IsMklTensor() + ? weight_mkl_shape.GetMklLayout() + : memory::desc(weight_dims, MklDnnType(), + memory::format::io); + weight.SetUsrMem(weight_md, &weight_tensor); + + MklDnnMatMulFwdPrimitive* + matmul_fwd = nullptr; + memory::dims bias_dims = {}; + bias_dims = {static_cast(bias_tensor.dim_size(0))}; + + MklDnnMatMulFwdParams matmul_fwd_dims(src_dims, weight_dims, bias_dims, + dst_dims_mkl_order); + + // Extend the basic parameters for data types and fusions + this->ExtendMklDnnMatMulFwdParams(context, matmul_fwd_dims); + + // get a MatMul fwd from primitive pool + matmul_fwd = + MklDnnMatMulFwdPrimitiveFactory::Get(matmul_fwd_dims, 0); + + // Allocate output Tensor. + std::shared_ptr + matmul_fwd_pd = matmul_fwd->GetPrimitiveDesc(); + AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order, + input_output_fmt, &dst_tensor); + + Toutput* dst_data = + reinterpret_cast(dst_tensor->flat().data()); + + // check if src and weight data need to be reordered. 
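+      // A reorder is needed only when the layout of the user-provided tensor
+      // (plain nc for the source, io for the weight) differs from the layout
+      // the inner-product primitive selected for this shape; otherwise the
+      // original buffer is used directly.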
+ Tinput* src_data = nullptr; + if (src_md.data.format != matmul_fwd->GetSrcMemoryFormat()) { + src.SetUsrMem(src_md, &src_tensor); + src.CheckReorderToOpMem(matmul_fwd_pd.get()->src_primitive_desc()); + src_data = static_cast(src.GetOpMem().get_data_handle()); + } else { + src_data = static_cast( + const_cast(src_tensor.flat().data())); + } + Tweight* weight_data = nullptr; + if (weight_md.data.format != matmul_fwd->GetweightMemoryFormat()) { + weight.SetUsrMem(weight_md, &weight_tensor); + weight.CheckReorderToOpMem( + matmul_fwd_pd.get()->weights_primitive_desc()); + weight_data = + static_cast(weight.GetOpMem().get_data_handle()); + } else { + weight_data = static_cast( + const_cast(weight_tensor.flat().data())); + } + + // execute inner-product + Tbias* bias_data = this->GetBiasHandle(context, matmul_fwd_pd, + bias_tensor, weight_tensor); + matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data); + } catch (mkldnn::error& e) { + string error_msg = tensorflow::strings::StrCat( + "Status: ", e.status, ", message: ", string(e.message), ", in file ", + __FILE__, ":", __LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); + } + + // Compute additional outputs: min/max scalars. + const float min_input = context->input(3).flat()(0); + const float max_input = context->input(4).flat()(0); + const float min_weight = context->input(5).flat()(0); + const float max_weight = context->input(6).flat()(0); + + float min_output_value; + float max_output_value; + if (std::is_same::value || + std::is_same::value) { + // This is the case the inner-product and requantization are fused. + // min_freezed_output and max_freezed_output are the actual range + // for the output + min_output_value = context->input(7).flat()(0); + max_output_value = context->input(8).flat()(0); + } else { + MklQuantizationRangeForMultiplication( + min_input, max_input, min_weight, max_weight, &min_output_value, + &max_output_value); + } + + Tensor* output_min = nullptr; + Tensor* output_max = nullptr; + MklDnnShape output_min_mkl_shape, output_max_mkl_shape; + output_min_mkl_shape.SetMklTensor(false); + output_max_mkl_shape.SetMklTensor(false); + AllocateOutputSetMklShape(context, 1, &output_min, {}, + output_min_mkl_shape); + AllocateOutputSetMklShape(context, 2, &output_max, {}, + output_max_mkl_shape); + output_min->flat()(0) = min_output_value; + output_max->flat()(0) = max_output_value; + } + + protected: + virtual void ExtendMklDnnMatMulFwdParams(OpKernelContext* context, + MklDnnMatMulFwdParams& params) { + // Append data type names of input, weight, bias, and output. + params.dtypes.append(typeid(Tinput).name()); + params.dtypes.append(typeid(Tweight).name()); + params.dtypes.append(typeid(Tbias).name()); + params.dtypes.append(typeid(Toutput).name()); + + // When the output type is quint8, the output data is requantized + // into quint8. A post_op "output_scale" is added to do the conversion. 
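+    // The scale is derived from ranges that are already available as inputs:
+    // scale_int32 below reflects the qint32 accumulator range implied by the
+    // input and weight min/max pairs, while scale_eightbit reflects the
+    // calibrated (frozen) output range. Their ratio, divided by a
+    // power-of-two factor (2^23 or 2^24, depending on the signedness of the
+    // 8-bit output type), is passed to MKL-DNN as the single "output_scale"
+    // value.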
+ if (std::is_same::value || + std::is_same::value) { + const float min_input = context->input(3).flat()(0); + const float max_input = context->input(4).flat()(0); + const float min_weight = context->input(5).flat()(0); + const float max_weight = context->input(6).flat()(0); + const float min_freezed_output = context->input(7).flat()(0); + const float max_freezed_output = context->input(8).flat()(0); + + float min_output_value; + float max_output_value; + MklQuantizationRangeForMultiplication( + min_input, max_input, min_weight, max_weight, &min_output_value, + &max_output_value); + float scale_int32 = + std::max(std::abs(min_output_value), std::abs(max_output_value)); + float scale_eightbit = + std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); + float scale = 1.0; + if (std::is_same::value) + scale = scale_int32 / scale_eightbit / static_cast(1 << 23); + else + scale = scale_int32 / scale_eightbit / static_cast(1 << 24); + + std::vector output_scale; + output_scale.push_back(scale); + params.post_op_params.push_back({"output_scale", output_scale}); + } + } + + // This function handles bias conversion and compensation + // for MIN_FIRST and SCALE mode + // If input is quantized via MIN_FIRST + // Bs32=QaQwB𝑓32 + QaMin(A𝑓32) Ws8*1 + // If input is quantized via SCALE + // Bs32=QaQwB𝑓32 + // where QaQw is the multiply of Qa and Qw, + // also is called output quantize scale + Tbias* GetBiasHandle( + OpKernelContext* context, + std::shared_ptr& + mkldnn_matmul_fwd_pd, + const Tensor& bias_tensor, const Tensor& weight_tensor) { + // If the bias is int32, it means the bias is already be converted offline. + // and it can be added to matmul output directly. + if (std::is_same::value) { + return static_cast( + const_cast(bias_tensor.flat().data())); + } else { + // If the bias is fp32, then need to calculate the bias + const float min_input = context->input(3).flat()(0); + const float max_input = context->input(4).flat()(0); + const float min_weight = context->input(5).flat()(0); + const float max_weight = context->input(6).flat()(0); + + std::vector net; + float out_scale; + // If the bias is float and input quantize is MIN_FIRST + // bias has to be compensated with + // Bs32=QaQwB𝑓32 + QaMin(A𝑓32) Ws8*1 + if (mode_ == QUANTIZE_MODE_MIN_FIRST) { + int k = weight_tensor.dim_size(0); + int n = weight_tensor.dim_size(1); + float* comp_bias = GetCompBiasBuffer(n); + + qint8* wt_buf = static_cast( + const_cast(weight_tensor.flat().data())); + + const float* bias_buf = static_cast( + const_cast(bias_tensor.flat().data())); + + float qa_amin = 255 * min_input / (max_input - min_input); + + out_scale = (255.0 * 127.0) / + ((max_input - min_input) * + std::max(std::abs(max_weight), std::abs(min_weight))); + +#pragma omp parallel for schedule(static) + for (int j = 0; j < n; j++) { + int x = 0; + for (int i = 0; i < k; i++) { + x += wt_buf[i * n + j]; + } + comp_bias[j] = + ((bias_buf[j] * out_scale) + static_cast(x * qa_amin)); + } + + return reinterpret_cast(comp_bias_); + + } else { + // If the bias is float and input quantize is SCALE + // bias has to be compensated with + // Bs32=QaQwBf32 + out_scale = 255.0 * 127.0 / + (std::max(std::abs(max_input), std::abs(min_input)) * + std::max(std::abs(max_weight), std::abs(min_weight))); + + std::vector scales; + scales.push_back(out_scale); + mkldnn::primitive_attr bias_attr; + bias_attr.set_output_scales(0, scales); + + void* bias_buf = static_cast( + const_cast(bias_tensor.flat().data())); + input_bias_ = + new 
memory(mkldnn_matmul_fwd_pd->bias_primitive_desc(), bias_buf); + scaled_bias_ = new memory(mkldnn_matmul_fwd_pd->bias_primitive_desc()); + auto reorder_desc = mkldnn::reorder::primitive_desc( + input_bias_->get_primitive_desc(), + scaled_bias_->get_primitive_desc(), bias_attr); + net.push_back( + mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_)); + stream(stream::kind::eager).submit(net).wait(); + return reinterpret_cast(scaled_bias_->get_data_handle()); + } + } + } + + // Allocate output tensor. + virtual void AllocateOutputTensor( + OpKernelContext* context, + const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc, + const memory::dims& output_dims_mkl_order, + memory::format output_tf_format, Tensor** output_tensor) { + CHECK_NOTNULL(output_tensor); + auto dst_pd = mkldnn_matmul_prim_desc.dst_primitive_desc(); + + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SetMklLayout(&dst_pd); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout2D(output_dims_mkl_order.size(), + output_dims_mkl_order, output_tf_format); + + TensorShape output_tf_shape; + output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput))); + + // Allocate Output Tensor + AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor, + output_tf_shape, output_mkl_shape); + } + + engine cpu_engine_ = engine(engine::cpu, 0); + + private: + memory* input_bias_ = nullptr; + memory* scaled_bias_ = nullptr; + + // buffer to save the compensated bias + float* comp_bias_ = nullptr; + + const int kInputIndexSrc = 0, kInputIndexWeight = 1, kInputIndexBias = 2; + const int kOutputIndexDst = 0; + + int mode_; +}; + +template +class MklDnnQuantizedMatMulReluOp + : public MklDnnQuantizedMatMulOp { + public: + virtual ~MklDnnQuantizedMatMulReluOp() {} + + explicit MklDnnQuantizedMatMulReluOp(OpKernelConstruction* context) + : MklDnnQuantizedMatMulOp( + context) {} + + protected: + void ExtendMklDnnMatMulFwdParams(OpKernelContext* context, + MklDnnMatMulFwdParams& params) override { + MklDnnQuantizedMatMulOp::ExtendMklDnnMatMulFwdParams(context, + params); + params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}}); + } +}; + +// kernel registration +// Register NoOp kernel for QuantizedMatMulWithBias to get a python interface. +// This kernel will be replaced by an MKL kernel during graph +// optimization pass. +REGISTER_KERNEL_BUILDER(Name("QuantizedMatMulWithBias") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Toutput"), + NoOp); + +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedMatMulWithBias") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Tbias") + .TypeConstraint("Toutput") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklDnnQuantizedMatMulOp); +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedMatMulWithBias") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Tbias") + .TypeConstraint("Toutput") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklDnnQuantizedMatMulOp); + +// Register NoOp kernel for QuantizedMatMulWithBiasAndRelu to get a python +// interface. +// This kernel will be replaced by an MKL kernel during graph-optimization pass. +REGISTER_KERNEL_BUILDER(Name("QuantizedMatMulWithBiasAndRelu") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Toutput"), + NoOp); + +// Register NoOp kernel for QuantizedIPWithBiasAndReluAndRequantize +// to get a python interface. 
+// This kernel will be replaced by an MKL kernel during graph-optimization pass. +REGISTER_KERNEL_BUILDER(Name("QuantizedMatMulWithBiasAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Tbias") + .TypeConstraint("Toutput"), + NoOp); +REGISTER_KERNEL_BUILDER(Name("QuantizedMatMulWithBiasAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Tbias") + .TypeConstraint("Toutput"), + NoOp); + +// Register a templatized implementation of _MklQuantizedMatMulWithBiasAndRelu. +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedMatMulWithBiasAndRelu") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Toutput") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklDnnQuantizedMatMulReluOp); + +// Register a templatized implementation of +// _MklQuantizedMatMulWithBiasAndReluAndRequantize. +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedMatMulWithBiasAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Tbias") + .TypeConstraint("Toutput") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklDnnQuantizedMatMulReluOp); +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedMatMulWithBiasAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("T1") + .TypeConstraint("T2") + .TypeConstraint("Tbias") + .TypeConstraint("Toutput") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklDnnQuantizedMatMulReluOp); + +} // namespace tensorflow +#endif // INTEL_MKL diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc index 0e6ad9162a54c4..70a36bb792ab63 100644 --- a/tensorflow/core/ops/mkl_nn_ops.cc +++ b/tensorflow/core/ops/mkl_nn_ops.cc @@ -758,6 +758,140 @@ REGISTER_OP("_MklDepthwiseConv2dNativeBackpropFilter") return Status::OK(); }); +REGISTER_OP("_MklQuantizedMatMulWithBias") + .Input("a: T1") + .Input("b: T2") + .Input("bias: Tbias") + .Input("min_a: float") + .Input("max_a: float") + .Input("min_b: float") + .Input("max_b: float") + .Input("mkl_a: uint8") // MKl second tensor + .Input("mkl_b: uint8") // MKl second tensor + .Input("mkl_bias: uint8") // MKl second tensor + .Input("mkl_min_a: uint8") // MKl second tensor + .Input("mkl_max_a: uint8") // MKl second tensor + .Input("mkl_min_b: uint8") // MKl second tensor + .Input("mkl_max_b: uint8") // MKl second tensor + .Output("out: Toutput") + .Output("min_out: float") + .Output("max_out: float") + .Output("mkl_out: uint8") // MKl second tensor + .Output("mkl_min_out: uint8") // MKl second tensor + .Output("mkl_max_out: uint8") // MKl second tensor + .Attr("T1: quantizedtype") + .Attr("T2: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("T: quantizedtype") // Additional attr "T" for MklToTf conversion + .Attr("Toutput: quantizedtype = DT_QINT32") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + 
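+// Same as _MklQuantizedMatMulWithBias above, but with a Relu fused on the
+// 32-bit integer output. Note that the bias input of this variant is float
+// only.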
+REGISTER_OP("_MklQuantizedMatMulWithBiasAndRelu") + .Input("a: T1") + .Input("b: T2") + .Input("bias: float") + .Input("min_a: float") + .Input("max_a: float") + .Input("min_b: float") + .Input("max_b: float") + .Input("mkl_a: uint8") // MKl second tensor + .Input("mkl_b: uint8") // MKl second tensor + .Input("mkl_bias: uint8") // MKl second tensor + .Input("mkl_min_a: uint8") // MKl second tensor + .Input("mkl_max_a: uint8") // MKl second tensor + .Input("mkl_min_b: uint8") // MKl second tensor + .Input("mkl_max_b: uint8") // MKl second tensor + .Output("out: Toutput") + .Output("min_out: float") + .Output("max_out: float") + .Output("mkl_out: uint8") // MKl second tensor + .Output("mkl_min_out: uint8") // MKl second tensor + .Output("mkl_max_out: uint8") // MKl second tensor + .Attr("T1: quantizedtype") + .Attr("T2: quantizedtype") + .Attr("T: quantizedtype") // Additional attr "T" for MklToTf conversion + .Attr("Toutput: quantizedtype = DT_QINT32") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("_MklQuantizedMatMulWithBiasAndReluAndRequantize") + .Input("a: T1") + .Input("b: T2") + .Input("bias: Tbias") + .Input("min_a: float") + .Input("max_a: float") + .Input("min_b: float") + .Input("max_b: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Input("mkl_a: uint8") // MKl second tensor + .Input("mkl_b: uint8") // MKl second tensor + .Input("mkl_bias: uint8") // MKl second tensor + .Input("mkl_min_a: uint8") // MKl second tensor + .Input("mkl_max_a: uint8") // MKl second tensor + .Input("mkl_min_b: uint8") // MKl second tensor + .Input("mkl_max_b: uint8") // MKl second tensor + .Input("mkl_min_freezed_output: uint8") // MKl second tensor + .Input("mkl_max_freezed_output: uint8") // MKl second tensor + .Output("out: Toutput") + .Output("min_out: float") + .Output("max_out: float") + .Output("mkl_out: uint8") // MKl second tensor + .Output("mkl_min_out: uint8") // MKl second tensor + .Output("mkl_max_out: uint8") // MKl second tensor + .Attr("T1: quantizedtype") + .Attr("T2: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("T: quantizedtype") // Additional attr "T" for MklToTf conversion + .Attr("Toutput: quantizedtype = DT_QUINT8") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); + + c->set_output(1, c->Scalar()); + 
c->set_output(2, c->Scalar()); + return Status::OK(); + }); + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 2b1d031be86c9f..71be6567936081 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1216,9 +1216,9 @@ Status TopKShapeFn(InferenceContext* c) { DimensionHandle last_dim = c->Dim(input, -1); if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) && c->Value(last_dim) < c->Value(k_dim)) { - return errors::InvalidArgument( - "input must have last dimension >= k = ", c->Value(k_dim), " but is ", - c->Value(last_dim)); + return errors::InvalidArgument("input must have last dimension >= k = ", + c->Value(k_dim), " but is ", + c->Value(last_dim)); } // Replace last_dim with k_dim. @@ -1272,9 +1272,9 @@ REGISTER_OP("NthElement") DimensionHandle last_dim = c->Dim(input, -1); if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) && c->Value(last_dim) <= c->Value(n_dim)) { - return errors::InvalidArgument( - "Input must have last dimension > n = ", c->Value(n_dim), - " but is ", c->Value(last_dim)); + return errors::InvalidArgument("Input must have last dimension > n = ", + c->Value(n_dim), " but is ", + c->Value(last_dim)); } // Reduce last_dim for output tensor @@ -2891,4 +2891,103 @@ REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize") return Status::OK(); }); +// Fusion of Quantized MatMul and BiasAdd. +REGISTER_OP("QuantizedMatMulWithBias") + .Input("a: T1") + .Input("b: T2") + .Input("bias: Tbias") + .Input("min_a: float") + .Input("max_a: float") + .Input("min_b: float") + .Input("max_b: float") + .Output("out: Toutput") + .Output("min_out: float") + .Output("max_out: float") + .Attr("T1: quantizedtype") + .Attr("T2: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("Toutput: quantizedtype = DT_QINT32") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("QuantizedMatMulWithBiasAndRelu") + .Input("a: T1") + .Input("b: T2") + .Input("bias: float") + .Input("min_a: float") + .Input("max_a: float") + .Input("min_b: float") + .Input("max_b: float") + .Output("out: Toutput") + .Output("min_out: float") + .Output("max_out: float") + .Attr("T1: quantizedtype") + .Attr("T2: quantizedtype") + .Attr("Toutput: quantizedtype = DT_QINT32") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return 
Status::OK(); + }); + +REGISTER_OP("QuantizedMatMulWithBiasAndReluAndRequantize") + .Input("a: T1") + .Input("b: T2") + .Input("bias: Tbias") + .Input("min_a: float") + .Input("max_a: float") + .Input("min_b: float") + .Input("max_b: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Output("out: Toutput") + .Output("min_out: float") + .Output("max_out: float") + .Attr("T1: quantizedtype") + .Attr("T2: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("Toutput: quantizedtype = DT_QUINT8") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); + + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); } // namespace tensorflow diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index fcd2e18944a26e..c343ca7292bcb5 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -295,32 +295,32 @@ class MklShape { CHECK_EQ(dnnDelete_F32(convert), E_SUCCESS); } - // The following methods are used for serializing and de-serializing the - // contents of the mklshape object. - // The data is serialized in this order - // isMklTensor_ - // dimension_ - // sizes_ - // strides_ - // mklLayout_ - // tfLayout_ - // tf_to_mkl_dim_map_ +// The following methods are used for serializing and de-serializing the +// contents of the mklshape object. +// The data is serialized in this order +// isMklTensor_ +// dimension_ +// sizes_ +// strides_ +// mklLayout_ +// tfLayout_ +// tf_to_mkl_dim_map_ #define SIZE_OF_MKL_DNN_BUF \ (dnnLayoutSerializationBufferSize_F32()) // Size of buffer needed to // serialize dnn_layout pointer - // Size of buffer to hold the serialized object, the size is computed as - // follows sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) + - // sizeof(strides_) - // + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer) - // + sizeof(tf_to_mkl_dim_map_) +// Size of buffer to hold the serialized object, the size is computed as +// follows sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) + +// sizeof(strides_) +// + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer) +// + sizeof(tf_to_mkl_dim_map_) #define SIZE_OF_MKL_SERIAL_DATA(dims) \ (2 * sizeof(size_t) + 3 * dims * sizeof(size_t) + 2 * SIZE_OF_MKL_DNN_BUF) - // First we need to define some macro for offsets into the serial buffer where - // different elements of Mklshape is written/read from +// First we need to define some macro for offsets into the serial buffer where +// different elements of Mklshape is written/read from #define IS_MKL_TENSOR_OFFSET 0 // Location from start of buffer where isMklTensor_ is serialized @@ -633,7 +633,8 @@ class MklDnnShape { /// also be Blocked format. 
inline void SetTfLayout(size_t dims, const memory::dims& sizes, memory::format format) { - CHECK_EQ(dims, sizes.size()); + DCHECK(dims != sizes.size()) << "SetTfLayout: Number of dimensions is not" + "match with dimension array"; data_.dimension_ = dims; for (size_t ii = 0; ii < dims; ii++) { data_.sizes_[ii] = sizes[ii]; @@ -644,6 +645,21 @@ class MklDnnShape { } } + inline void SetTfLayout2D(size_t dims, const memory::dims& sizes, + memory::format format) { + DCHECK(dims != sizes.size()) << "SetTfLayout2D: Number of dimensions is" + "match with dimension array"; + data_.dimension_ = dims; + for (size_t ii = 0; ii < dims; ii++) { + data_.sizes_[ii] = sizes[ii]; + } + data_.tf_data_format_ = format; + if (format != memory::format::blocked) { + data_.map_[0] = MklDnnDims::Dim_N; + data_.map_[1] = MklDnnDims::Dim_C; + } + } + inline const memory::desc GetTfLayout() const { memory::dims dims; for (size_t ii = 0; ii < data_.dimension_; ii++) { @@ -862,9 +878,9 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape)); } } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); LOG(FATAL) << "Operation received an exception: " << error_msg; } return output_tensor; diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 717bb8c72ef78c..303c6a26b49de0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -2480,6 +2480,18 @@ tf_module { name: "QuantizedMatMul" argspec: "args=[\'a\', \'b\', \'min_a\', \'max_a\', \'min_b\', \'max_b\', \'Toutput\', \'transpose_a\', \'transpose_b\', \'Tactivation\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'False\', \'False\', \"\", \'None\'], " } + member_method { + name: "QuantizedMatMulWithBias" + argspec: "args=[\'a\', \'b\', \'bias\', \'min_a\', \'max_a\', \'min_b\', \'max_b\', \'Toutput\', \'transpose_a\', \'transpose_b\', \'input_quant_mode\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "QuantizedMatMulWithBiasAndRelu" + argspec: "args=[\'a\', \'b\', \'bias\', \'min_a\', \'max_a\', \'min_b\', \'max_b\', \'Toutput\', \'transpose_a\', \'transpose_b\', \'input_quant_mode\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "QuantizedMatMulWithBiasAndReluAndRequantize" + argspec: "args=[\'a\', \'b\', \'bias\', \'min_a\', \'max_a\', \'min_b\', \'max_b\', \'min_freezed_output\', \'max_freezed_output\', \'Toutput\', \'transpose_a\', \'transpose_b\', \'input_quant_mode\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "QuantizedMaxPool" argspec: "args=[\'input\', \'min_input\', \'max_input\', \'ksize\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 717bb8c72ef78c..303c6a26b49de0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -2480,6 +2480,18 @@ tf_module { name: "QuantizedMatMul" 
argspec: "args=[\'a\', \'b\', \'min_a\', \'max_a\', \'min_b\', \'max_b\', \'Toutput\', \'transpose_a\', \'transpose_b\', \'Tactivation\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'False\', \'False\', \"\", \'None\'], " } + member_method { + name: "QuantizedMatMulWithBias" + argspec: "args=[\'a\', \'b\', \'bias\', \'min_a\', \'max_a\', \'min_b\', \'max_b\', \'Toutput\', \'transpose_a\', \'transpose_b\', \'input_quant_mode\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "QuantizedMatMulWithBiasAndRelu" + argspec: "args=[\'a\', \'b\', \'bias\', \'min_a\', \'max_a\', \'min_b\', \'max_b\', \'Toutput\', \'transpose_a\', \'transpose_b\', \'input_quant_mode\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "QuantizedMatMulWithBiasAndReluAndRequantize" + argspec: "args=[\'a\', \'b\', \'bias\', \'min_a\', \'max_a\', \'min_b\', \'max_b\', \'min_freezed_output\', \'max_freezed_output\', \'Toutput\', \'transpose_a\', \'transpose_b\', \'input_quant_mode\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "QuantizedMaxPool" argspec: "args=[\'input\', \'min_input\', \'max_input\', \'ksize\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " From f4b3f9d119fd1d2ab1469fd88cad718009a6590d Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Thu, 9 May 2019 09:50:00 -0700 Subject: [PATCH 02/12] Addressed PR-review comments. --- .../api_def_QuantizedMatMulWithBias.pbtxt | 81 ++++++++++++++++ ...i_def_QuantizedMatMulWithBiasAndRelu.pbtxt | 82 ++++++++++++++++ ...edMatMulWithBiasAndReluAndRequantize.pbtxt | 96 +++++++++++++++++++ tensorflow/core/kernels/BUILD | 9 +- tensorflow/core/kernels/mkl_qmatmul_op.cc | 41 ++++---- tensorflow/core/ops/mkl_nn_ops.cc | 64 ++++++------- tensorflow/core/util/mkl_util.h | 10 +- 7 files changed, 320 insertions(+), 63 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBias.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBias.pbtxt index 0e636b4fe340e8..4f78565e3fcc1c 100644 --- a/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBias.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_QuantizedMatMulWithBias.pbtxt @@ -1,4 +1,85 @@ op { graph_op_name: "QuantizedMatMulWithBias" visibility: HIDDEN + in_arg { + name: "a" + description: <