diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc
index 9d9a0994d..3c79b4cfa 100644
--- a/ngraph_bridge/ngraph_builder.cc
+++ b/ngraph_bridge/ngraph_builder.cc
@@ -3327,8 +3327,19 @@ static Status TranslateQuantizeAndDequantizeV2Op(
       op->name(), ng_r_et, ng::Shape(), std::vector({scale}));
   auto ng_offset = ConstructNgNode(
       op->name(), ng_q_et, ng::Shape(), std::vector({0}));
-  ng::op::Quantize::RoundMode ng_round_mode =
-      ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY;
+  // Translate the TF "round_mode" attribute ("HALF_UP" or "HALF_TO_EVEN")
+  // into the corresponding nGraph rounding mode.
+  ng::op::Quantize::RoundMode ng_round_mode;
+  string round_mode_string;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(op->attrs(), "round_mode", &round_mode_string));
+  if (round_mode_string == "HALF_UP") {
+    ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_UPWARD;
+  } else if (round_mode_string == "HALF_TO_EVEN") {
+    ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
+  } else {
+    return errors::Internal("Tensorflow Rounding Mode not supported by Ngraph");
+  }
   auto ng_quant = ConstructNgNode(
       op->name(), ng_input, ng_scale, ng_offset, ng_q_et, ng::AxisSet(),
       ng_round_mode);
@@ -3608,10 +3619,20 @@ static Status TranslateQuantizeV2Op(const Node* op,
   ng::element::Type ng_et;
   TF_RETURN_IF_ERROR(TFDataTypeToNGraphElementType(dtype, &ng_et));
 
-  // TODO: Only RoundMode = ROUND_NEAREST_TOWARD_EVEN is supported, for now.
-  // Support other modes later
-  ng::op::Quantize::RoundMode ng_round_mode =
-      ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
+  // Translate the TF "round_mode" attribute into an nGraph rounding mode.
+  // QuantizeV2 accepts "HALF_AWAY_FROM_ZERO" (its default) and
+  // "HALF_TO_EVEN"; it has no "HALF_UP" mode.
+  ng::op::Quantize::RoundMode ng_round_mode;
+  string round_mode_string;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(op->attrs(), "round_mode", &round_mode_string));
+  if (round_mode_string == "HALF_AWAY_FROM_ZERO") {
+    ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY;
+  } else if (round_mode_string == "HALF_TO_EVEN") {
+    ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
+  } else {
+    return errors::Internal("Tensorflow Rounding Mode not supported by Ngraph");
+  }
   auto ng_node = ng::builder::ScaledQuantize(ng_input, ng_min, ng_max, ng_et,
                                              ng::AxisSet(), ng_round_mode);
 
diff --git a/test/test_array_ops.cpp b/test/test_array_ops.cpp
index 7bbdbbaa8..a07e26ec9 100644
--- a/test/test_array_ops.cpp
+++ b/test/test_array_ops.cpp
@@ -797,7 +797,63 @@ TEST(ArrayOps, QuantizeAndDequantizeV2x8xtruexfalse) {
                         output_datatypes, sess_run_fetchoutputs);
 
   opexecuter.RunTest();
 }  // end of test op QuantizeAndDequantizeV2x8xtruexfalse
+
+// Runs QuantizeAndDequantizeV2 with round_mode "HALF_UP" through OpExecuter.
+TEST(ArrayOps, QuantizeAndDequantizeV2RoundingMode1) {
+  Scope root = Scope::NewRootScope();
+  int dim1 = 2;
+  int dim2 = 3;
+
+  Tensor A(DT_FLOAT, TensorShape({dim1, dim2}));
+  AssignInputValues(A, {0.9, 3.4, 2.6, 5.4, 4.2, 4.5});
+
+  auto attrs = ops::QuantizeAndDequantizeV2::Attrs();
+  attrs.num_bits_ = 8;
+  attrs.range_given_ = true;
+  attrs.signed_input_ = true;
+  attrs.round_mode_ = "HALF_UP";
+
+  vector static_input_indexes = {1, 2};
+  ops::QuantizeAndDequantizeV2 R =
+      ops::QuantizeAndDequantizeV2(root, A, 0.0f, 127.0f, attrs);
+
+  vector output_datatypes = {DT_FLOAT};
+
+  std::vector sess_run_fetchoutputs = {R.output};
+  OpExecuter opexecuter(root, "QuantizeAndDequantizeV2", static_input_indexes,
+                        output_datatypes, sess_run_fetchoutputs);
+
+  opexecuter.RunTest();
+}  // end of test op QuantizeAndDequantizeV2RoundingMode1
+
+// Same setup as above, but with round_mode "HALF_TO_EVEN".
+TEST(ArrayOps, QuantizeAndDequantizeV2RoundingMode2) {
+  Scope root = Scope::NewRootScope();
+  int dim1 = 2;
+  int dim2 = 3;
+
+  Tensor A(DT_FLOAT, TensorShape({dim1, dim2}));
+  AssignInputValues(A, {0.9, 3.4, 2.6, 5.4, 4.2, 4.5});
+
+  auto attrs = ops::QuantizeAndDequantizeV2::Attrs();
+  attrs.num_bits_ = 8;
+  attrs.range_given_ = true;
+  attrs.signed_input_ = true;
+  attrs.round_mode_ = "HALF_TO_EVEN";
+
+  vector static_input_indexes = {1, 2};
+  ops::QuantizeAndDequantizeV2 R =
+      ops::QuantizeAndDequantizeV2(root, A, 0.0f, 127.0f, attrs);
+
+  vector output_datatypes = {DT_FLOAT};
+
+  std::vector sess_run_fetchoutputs = {R.output};
+  OpExecuter opexecuter(root, "QuantizeAndDequantizeV2", static_input_indexes,
+                        output_datatypes, sess_run_fetchoutputs);
+
+  opexecuter.RunTest();
+}  // end of test op QuantizeAndDequantizeV2RoundingMode2
 
 // CPU only supports QuantizedConcat with DT_QINT32 and DT_QUINT8
 TEST(ArrayOps, QuantizedConcat) {