Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions ngraph_bridge/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3327,8 +3327,17 @@ static Status TranslateQuantizeAndDequantizeV2Op(
op->name(), ng_r_et, ng::Shape(), std::vector<float>({scale}));
auto ng_offset = ConstructNgNode<ng::op::Constant>(
op->name(), ng_q_et, ng::Shape(), std::vector<int>({0}));
ng::op::Quantize::RoundMode ng_round_mode =
ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY;
ng::op::Quantize::RoundMode ng_round_mode;
string round_mode_string;
TF_RETURN_IF_ERROR(
GetNodeAttr(op->attrs(), "round_mode", &round_mode_string));
if (round_mode_string == "HALF_UP") {
ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_UPWARD;
} else if (round_mode_string == "HALF_TO_EVEN") {
ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
} else {
return errors::Internal("Tensorflow Rounding Mode not supported by Ngraph");
}
auto ng_quant = ConstructNgNode<ng::op::Quantize>(
op->name(), ng_input, ng_scale, ng_offset, ng_q_et, ng::AxisSet(),
ng_round_mode);
Expand Down Expand Up @@ -3608,10 +3617,17 @@ static Status TranslateQuantizeV2Op(const Node* op,
ng::element::Type ng_et;
TF_RETURN_IF_ERROR(TFDataTypeToNGraphElementType(dtype, &ng_et));

// TODO: Only RoundMode = ROUND_NEAREST_TOWARD_EVEN is supported, for now.
// Support other modes later
ng::op::Quantize::RoundMode ng_round_mode =
ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
ng::op::Quantize::RoundMode ng_round_mode;
string round_mode_string;
TF_RETURN_IF_ERROR(
GetNodeAttr(op->attrs(), "round_mode", &round_mode_string));
if (round_mode_string == "HALF_UP") {
ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_UPWARD;
} else if (round_mode_string == "HALF_TO_EVEN") {
ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
} else {
return errors::Internal("Tensorflow Rounding Mode not supported by Ngraph");
}

auto ng_node = ng::builder::ScaledQuantize(ng_input, ng_min, ng_max, ng_et,
ng::AxisSet(), ng_round_mode);
Expand Down
56 changes: 55 additions & 1 deletion test/test_array_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,61 @@ TEST(ArrayOps, QuantizeAndDequantizeV2x8xtruexfalse) {
output_datatypes, sess_run_fetchoutputs);

opexecuter.RunTest();
} // end of test op QuantizeAndDequantizeV2x8xtruexfalse
}

TEST(ArrayOps, QuantizeAndDequantizeV2RoundingMode1) {
Scope root = Scope::NewRootScope();
int dim1 = 2;
int dim2 = 3;

Tensor A(DT_FLOAT, TensorShape({dim1, dim2}));
AssignInputValues<float>(A, {0.9, 3.4, 2.6, 5.4, 4.2, 4.5});

auto attrs = ops::QuantizeAndDequantizeV2::Attrs();
attrs.num_bits_ = 8;
attrs.range_given_ = true;
attrs.signed_input_ = true;
attrs.round_mode_ = "HALF_UP";

vector<int> static_input_indexes = {1, 2};
ops::QuantizeAndDequantizeV2 R =
ops::QuantizeAndDequantizeV2(root, A, 0.0f, 127.0f, attrs);

vector<DataType> output_datatypes = {DT_FLOAT};

std::vector<Output> sess_run_fetchoutputs = {R.output};
OpExecuter opexecuter(root, "QuantizeAndDequantizeV2", static_input_indexes,
output_datatypes, sess_run_fetchoutputs);

opexecuter.RunTest();
}

TEST(ArrayOps, QuantizeAndDequantizeV2RoundingMode2) {
Scope root = Scope::NewRootScope();
int dim1 = 2;
int dim2 = 3;

Tensor A(DT_FLOAT, TensorShape({dim1, dim2}));
AssignInputValues<float>(A, {0.9, 3.4, 2.6, 5.4, 4.2, 4.5});

auto attrs = ops::QuantizeAndDequantizeV2::Attrs();
attrs.num_bits_ = 8;
attrs.range_given_ = true;
attrs.signed_input_ = true;
attrs.round_mode_ = "HALF_TO_EVEN";

vector<int> static_input_indexes = {1, 2};
ops::QuantizeAndDequantizeV2 R =
ops::QuantizeAndDequantizeV2(root, A, 0.0f, 127.0f, attrs);

vector<DataType> output_datatypes = {DT_FLOAT};

std::vector<Output> sess_run_fetchoutputs = {R.output};
OpExecuter opexecuter(root, "QuantizeAndDequantizeV2", static_input_indexes,
output_datatypes, sess_run_fetchoutputs);

opexecuter.RunTest();
} // end of test op QuantizeAndDequantizeV2x8xtruextrue

// CPU only supports QuantizedConcat with DT_QINT32 and DT_QUINT8
TEST(ArrayOps, QuantizedConcat) {
Expand Down