Skip to content
Permalink
Browse files Browse the repository at this point in the history
Reorganize and add more validation to MKL requantization
PiperOrigin-RevId: 387901341
Change-Id: I2515b9034c64e113db0bcec8337d30643ab0a0f1
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 30, 2021
1 parent aff0d5b commit 2032145
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc
Expand Up @@ -49,35 +49,45 @@ class MklRequantizePerChannelOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
try {
const Tensor& input = ctx->input(kInputTensorIndex);
OP_REQUIRES(
ctx, input.dims() == 4,
errors::InvalidArgument("Current RequantizePerChannel operator"
"supports 4D tensors only."));

const Tensor& input_min_vec = ctx->input(kInputMinVecIndex);
size_t depth = input_min_vec.NumElements();
float* input_min_vec_data = (float*)const_cast<void*>(
static_cast<const void*>(input_min_vec.flat<float>().data()));

const Tensor& input_max_vec = ctx->input(kInputMaxVecIndex);
OP_REQUIRES(
ctx, input_max_vec.NumElements() == depth,
errors::InvalidArgument("input_max has incorrect size, expected ",
depth, " was ", input_max_vec.NumElements()));
float* input_max_vec_data = (float*)const_cast<void*>(
static_cast<const void*>(input_max_vec.flat<float>().data()));

const Tensor& input_requested_min = ctx->input(this->kRequestMinIndex);
OP_REQUIRES(
ctx, input_requested_min.NumElements() == 1,
errors::InvalidArgument("requested_output_min must be a scalar"));
const float input_requested_min_float =
input_requested_min.flat<float>()(0);

const Tensor& input_requested_max = ctx->input(this->kRequestMaxIndex);
OP_REQUIRES(
ctx, input_requested_min.NumElements() == 1,
errors::InvalidArgument("requested_output_max must be a scalar"));
const float input_requested_max_float =
input_requested_max.flat<float>()(0);

size_t depth = input_min_vec.NumElements();
OP_REQUIRES(
ctx, input.dims() == 4,
errors::InvalidArgument("Current RequantizePerChannel operator"
"supports 4D tensors only."));
OP_REQUIRES(
ctx, input_min_vec.dim_size(0) == depth,
errors::InvalidArgument("input_min has incorrect size, expected ",
depth, " was ", input_min_vec.dim_size(0)));
OP_REQUIRES(
ctx, input_max_vec.dim_size(0) == depth,
errors::InvalidArgument("input_max has incorrect size, expected ",
depth, " was ", input_max_vec.dim_size(0)));

if (out_type_ == DT_QINT8) DCHECK(input_requested_min_float < 0.0f);
if (out_type_ == DT_QINT8) {
OP_REQUIRES(ctx, input_requested_min_float < 0.0f,
errors::InvalidArgument(
"If out_type is QINT8, requested_output_max must be "
"non negative, got ",
input_requested_min_float));
}

const float factor = (out_type_ == DT_QINT8) ? 127.0f : 255.0f;
const float requested_min_max =
Expand Down

0 comments on commit 2032145

Please sign in to comment.