Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add more validation to RequantizationRangePerChannel.
PiperOrigin-RevId: 387693946
Change-Id: Ife8dcbdb021bec4787eef6a4361dd08f17c14bd6
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 29, 2021
1 parent e2c9d55 commit 9e62869
Showing 1 changed file with 14 additions and 0 deletions.
Expand Up @@ -57,6 +57,20 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
ctx, input_max.dim_size(0) == depth,
errors::InvalidArgument("input_max has incorrect size, expected ",
depth, " was ", input_max.dim_size(0)));
OP_REQUIRES(
ctx, input_min.NumElements() == depth,
errors::InvalidArgument("input_min must have the same number of "
"elements as input_max, got ",
input_min.NumElements(), " and ", depth));
OP_REQUIRES(ctx, input.NumElements() > 0,
errors::InvalidArgument("input must not be empty"));
OP_REQUIRES(ctx, input.dims() == 4,
errors::InvalidArgument("input must be in NHWC format"));
OP_REQUIRES(
ctx, input.dim_size(3) == depth,
errors::InvalidArgument(
"input must have same number of channels as length of input_min: ",
input.dim_size(3), " vs ", depth));

const float* input_min_data = input_min.flat<float>().data();
const float* input_max_data = input_max.flat<float>().data();
Expand Down

0 comments on commit 9e62869

Please sign in to comment.