[Intel MKL] Changes in common files to enable MKL Quantized ops with native format #45107

125 changes: 67 additions & 58 deletions tensorflow/core/common_runtime/mkl_layout_pass.cc
@@ -456,9 +456,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.depthwise_conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.dequantize,
mkl_op_registry::GetMklOpName(csinfo_.dequantize),
CopyAttrsAll, DequantizeRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
CopyAttrsAll, DequantizeRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
@@ -553,114 +553,119 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_avg_pool,
mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_concatv2,
mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_conv2d,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
GetRewriteCause()});
kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_conv2d_per_channel,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
GetRewriteCause()});
kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_conv2d_with_bias,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
GetRewriteCause()});
kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_conv2d_and_relu,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
GetRewriteCause()});
kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_and_relu),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
GetRewriteCause()});
kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_max_pool,
mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_sum_and_relu),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
GetRewriteCause()});
kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_matmul_with_bias,
mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_matmul_with_bias_and_relu),
CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_matmul_with_bias_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_matmul_with_bias_and_relu_and_requantize),
CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_matmul_with_bias_and_requantize),
CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_matmul_with_bias_and_dequantize),
CopyAttrsQuantizedMatMulWithBiasAndDequantize,
AlwaysRewrite});
AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_depthwise_conv2d,
mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_depthwise_conv2d_with_bias),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
GetRewriteCause()});
kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_
.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.quantize_v2,
mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
CopyAttrsAll, QuantizeOpRewrite, GetRewriteCause()});
CopyAttrsAll, QuantizeOpRewrite,
kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.relu_grad,
@@ -2356,17 +2361,10 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
return nn_slot_idx;
}

Status MklLayoutRewritePass::SetUpInputs(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, const Node* old_node) {
// Let's check if we need to add workspace tensors for this node.
// We add workspace edge only for MaxPool, LRN and BatchNorm.
std::vector<NodeBuilder::NodeOut> workspace_tensors;
bool are_workspace_tensors_available = false;

// Avoid workspace check for QuantizedConv2D and the fused
// Ops as they don't have attribute: "T".
// Returns true if the workspace check is needed for this node. Workspace
// tensors are not used by quantized ops, and checking for them would fail
// because quantized ops don't have the attribute "T".
bool IsWorkspaceCheckNeeded(const Node* node) {
std::vector<string> quant_ops{
"Dequantize",
"QuantizeV2",
@@ -2391,12 +2389,23 @@ Status MklLayoutRewritePass::SetUpInputs(
"QuantizedDepthwiseConv2DWithBias",
"QuantizedDepthwiseConv2DWithBiasAndRelu",
"QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"};
bool should_check_workspace =
std::find(std::begin(quant_ops), std::end(quant_ops),
old_node->type_string()) == std::end(quant_ops);
if (should_check_workspace)
return std::find(std::begin(quant_ops), std::end(quant_ops),
node->type_string()) == std::end(quant_ops);
}

Status MklLayoutRewritePass::SetUpInputs(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, const Node* old_node) {
// Let's check if we need to add workspace tensors for this node.
// We add workspace edge only for MaxPool, LRN and BatchNorm.
std::vector<NodeBuilder::NodeOut> workspace_tensors;
bool are_workspace_tensors_available = false;

if (IsWorkspaceCheckNeeded(old_node)) {
AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
&are_workspace_tensors_available);
}

int new_node_input_slots = 0;
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
@@ -2767,7 +2776,6 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
nb->Attr("is_filter_const", filter_node->IsConstant());
nb->Attr("strides", strides);
nb->Attr("dilations", dilations);
nb->Attr("T", out_type); // added "T" for facilitating MklToTf conversion.
nb->Attr("data_format", data_format);
if (has_padding_list) {
nb->Attr("padding_list", padding_list);
@@ -2787,11 +2795,6 @@ void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize(
Node* filter_node = nullptr;
TF_CHECK_OK(orig_node->input_node(1, &filter_node));
nb->Attr("is_weight_const", filter_node->IsConstant());

// Add "T" for facilitating MklToTf conversion.
DataType T1;
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
nb->Attr("T", T1);
}

void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
@@ -2811,7 +2814,6 @@ void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
nb->Attr("T2", T2);
nb->Attr("Toutput", Toutput);
nb->Attr("is_weight_const", weight_node->IsConstant());
nb->Attr("T", Toutput); // added "T" for facilitating MklToTf conversion.

// Requantization attr Tbias
DataType Tbias;
Expand Down Expand Up @@ -3583,11 +3585,13 @@ Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(

std::vector<NodeBuilder::NodeOut> workspace_tensors;
bool are_workspace_tensors_available = false;
AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
&are_workspace_tensors_available);
if (are_workspace_tensors_available) {
DCHECK_EQ(workspace_tensors.size(), 1);
nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
if (IsWorkspaceCheckNeeded(orig_node)) {
AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
&are_workspace_tensors_available);
if (are_workspace_tensors_available) {
DCHECK_EQ(workspace_tensors.size(), 1);
nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
}
}

if (!NativeFormatEnabled()) {
@@ -3596,7 +3600,12 @@
ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
}

nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
if (DataTypeIsQuantized(orig_node->input_type(0)) ||
DataTypeIsQuantized(orig_node->output_type(0))) {
nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
} else {
nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
}

// Finalize graph and get new node.
s = nb.Finalize(&**g, new_node);
@@ -3681,12 +3690,12 @@ MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {

if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
mkl_op_registry::IsMklLayoutDependentOp(
mkl_op_registry::IsMklQuantizedOp(
mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
type_attrs_present = true;
} else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
TryGetNodeAttr(n->def(), "T2", &T2) &&
mkl_op_registry::IsMklLayoutDependentOp(
mkl_op_registry::IsMklQuantizedOp(
mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
type_attrs_present = true;
}
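For readers tracing the new "_kernel" logic above: RewriteNodeForJustOpNameChange now labels a rewritten node as a quantized MKL kernel whenever its first input or first output carries a quantized dtype, and falls back to the plain name-change label otherwise. A minimal self-contained sketch of that rule, using simplified stand-ins for TensorFlow's DataType enum and the registry label constants (the string values below are placeholders, not the actual label definitions):

#include <iostream>
#include <string>

// Simplified stand-ins for TensorFlow's DataType and quantized-type check.
enum DataType { DT_FLOAT, DT_QINT8, DT_QUINT8, DT_QINT32 };

bool DataTypeIsQuantized(DataType t) {
  return t == DT_QINT8 || t == DT_QUINT8 || t == DT_QINT32;
}

// Mirrors the branch added above: quantized ops get the quantized kernel
// label, everything else keeps the plain name-change label.
std::string SelectKernelLabel(DataType input0, DataType output0) {
  if (DataTypeIsQuantized(input0) || DataTypeIsQuantized(output0)) {
    return "QuantizedMklOp";  // placeholder for mkl_op_registry::kMklQuantizedOpLabel
  }
  return "MklNameChangeOp";   // placeholder for mkl_op_registry::kMklNameChangeOpLabel
}

int main() {
  std::cout << SelectKernelLabel(DT_QUINT8, DT_QINT32) << "\n";  // QuantizedMklOp
  std::cout << SelectKernelLabel(DT_FLOAT, DT_FLOAT) << "\n";    // MklNameChangeOp
}

The real pass reads the dtypes with orig_node->input_type(0) and orig_node->output_type(0), as shown in the diff.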
45 changes: 45 additions & 0 deletions tensorflow/core/framework/common_shape_fns.cc
@@ -2452,6 +2452,51 @@ Status SparseReduceShapeFn(InferenceContext* c) {
return UnknownShape(c);
}

Status QuantizedConv2DShape(InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
return Status::OK();
}

Status QuantizedAvgPoolShape(InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
return Status::OK();
}

Status QuantizeV2Shape(InferenceContext* c) {
int axis = -1;
Status s = c->GetAttr("axis", &axis);
if (!s.ok() && s.code() != error::NOT_FOUND) {
return s;
}
const int minmax_rank = (axis == -1) ? 0 : 1;
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
ShapeHandle minmax;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
if (axis != -1) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
DimensionHandle depth;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
}
c->set_output(1, minmax);
c->set_output(2, minmax);
return Status::OK();
}

} // namespace shape_inference

} // namespace tensorflow
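The three helpers added above follow the usual TensorFlow pattern: run the underlying float shape function, require the min/max range inputs to be scalars (or rank-1 vectors in the per-channel case), and emit scalar min/max outputs. A hedged sketch of how QuantizedAvgPoolShape could be attached to an op registration; the op name below is illustrative only, since the actual MKL-native registrations are not part of this diff:

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {

// Illustrative only: mirrors the public QuantizedAvgPool signature so that
// shape_inference::QuantizedAvgPoolShape can validate the scalar min/max
// inputs (inputs 1 and 2) and set outputs 1 and 2 to scalars.
REGISTER_OP("_ExampleQuantizedAvgPool")
    .Input("input: T")
    .Input("min_input: float")
    .Input("max_input: float")
    .Output("output: T")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("T: quantizedtype")
    .Attr("ksize: list(int)")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::QuantizedAvgPoolShape);

}  // namespace tensorflow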
9 changes: 9 additions & 0 deletions tensorflow/core/framework/common_shape_fns.h
@@ -270,6 +270,15 @@ Status ExplicitShapes(InferenceContext* c);
// Shape function for SparseReduceMax and SparseReduceSum.
Status SparseReduceShapeFn(InferenceContext* c);

// Shape function for QuantizedConv2D op.
Status QuantizedConv2DShape(InferenceContext* c);

// Shape function for QuantizedAvgPool op.
Status QuantizedAvgPoolShape(InferenceContext* c);

// Shape function for QuantizeV2 op.
Status QuantizeV2Shape(InferenceContext* c);

} // namespace shape_inference

} // namespace tensorflow
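A quick worked example of the axis handling behind the QuantizeV2Shape declaration above: with the default axis of -1 the min and max tensors must be scalars, while per-channel quantization (axis >= 0) requires rank-1 min/max vectors whose length matches the input dimension at axis. A tiny standalone sketch of the rank rule, with a hypothetical helper name:

#include <cassert>

// Mirrors the minmax_rank computation in QuantizeV2Shape: scalars for
// per-tensor quantization, rank-1 vectors for per-channel quantization.
int ExpectedMinMaxRank(int axis) { return axis == -1 ? 0 : 1; }

int main() {
  assert(ExpectedMinMaxRank(-1) == 0);  // per-tensor: scalar min/max
  assert(ExpectedMinMaxRank(3) == 1);   // per-channel on dim 3: vector min/max
  // For axis == 3 and input shape [N, H, W, C], the min/max vectors must
  // have length C; QuantizeV2Shape enforces this via c->Merge(...).
  return 0;
}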
17 changes: 10 additions & 7 deletions tensorflow/core/graph/mkl_graph_util.h
@@ -157,7 +157,9 @@ inline string GetMklNativeOpName(const string& name) {
(0 == name.compare("ConjugateTranspose") ||
0 == name.compare("BatchMatMul") || 0 == name.compare("BatchMatMulV2") ||
0 == name.compare("Einsum") || 0 == name.compare("MatMul") ||
0 == name.compare("Transpose"));
0 == name.compare("Transpose") || 0 == name.compare("QuantizeV2") ||
0 == name.compare("Dequantize") || 0 == name.rfind("Quantized", 0));

if (result) {
return string(kMklOpPrefix) + name;
} else {
@@ -203,10 +205,6 @@ static inline void BF16UnsupportedWarning() {
static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
string kernel = KernelsRegisteredForOp(op_name);

// Restrict quantized ops to QUINT8 and QINT8 for now
if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
}
// Restrict regular ops to FLOAT and BFLOAT16
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
if (T == DT_FLOAT) return true;
@@ -228,8 +226,8 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
// TODO(mdfaijul): QuantizedConv2D is registered with input: QUINT8
// filter:QINT8 for mkldnn integration. First a dummy kernel is created
// and then it is replaced by an actual kernel.
static inline bool IsMklLayoutDependentOp(const string& op_name,
DataType Tinput, DataType Tfilter) {
static inline bool IsMklQuantizedOp(const string& op_name, DataType Tinput,
DataType Tfilter) {
string kernel = KernelsRegisteredForOp(op_name);

// Restrict quantized ops to QUINT8 and QINT8 for now
Expand All @@ -256,6 +254,11 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
// device='CPU'; label='MklNameChangeOp'; T in [DT_DOUBLE]
// device='CPU'; label='MklNameChangeOp'; T in [DT_FLOAT]

if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
// Restrict quantized ops to DT_QUINT8, DT_QINT8, and DT_QINT32
return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
}

// Now we just construct a search string to match what we are looking for.
string search_string = kMklNameChangeOpLabelPattern;
search_string += string(";") + string(" T in [");
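The net effect of the mkl_graph_util.h changes is that the quantized-dtype restriction moves out of IsMklLayoutDependentOp and into the renamed IsMklQuantizedOp and the quantized-label branch of IsMklNameChangeOp, matching the layout pass's new treatment of quantized ops as name-change rewrites. A small sketch of the dtype gate, assuming the standard TensorFlow type headers; IsAllowedQuantizedType is a hypothetical helper name, not part of this diff:

#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Mirrors the gate added to IsMklNameChangeOp(): kernels registered under
// the quantized-op label pattern accept only the quantized dtypes, checked
// before the generic "T in [...]" search string is built.
inline bool IsAllowedQuantizedType(DataType T) {
  return T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32;
}

}  // namespace tensorflow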