Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add missing valuidation to FusedBatchNorm.
PiperOrigin-RevId: 372460336
Change-Id: Ic8c4e4de67c58a741bd87f2e182bed07247d1126
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed May 7, 2021
1 parent 57d86e0 commit 6972f9d
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion tensorflow/core/kernels/fused_batch_norm_op.cc
Expand Up @@ -1282,6 +1282,32 @@ class FusedBatchNormOpBase : public OpKernel {
errors::InvalidArgument("Error during tensor copy."));
}

const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
OP_REQUIRES(
context, scale.NumElements() == num_channels,
errors::InvalidArgument("scale must have the same number of elements "
"as the channels of x, got ",
scale.NumElements(), " and ", num_channels));
OP_REQUIRES(
context, offset.NumElements() == num_channels,
errors::InvalidArgument("offset must have the same number of elements "
"as the channels of x, got ",
offset.NumElements(), " and ", num_channels));
if (estimated_mean.NumElements() != 0) {
OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
errors::InvalidArgument(
"mean must be empty or have the same number of "
"elements as the channels of x, got ",
estimated_mean.NumElements(), " and ", num_channels));
}
if (estimated_variance.NumElements() != 0) {
OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
errors::InvalidArgument(
"variance must be empty or have the same number of "
"elements as the channels of x, got ",
estimated_variance.NumElements(), " and ", num_channels));
}

if (has_side_input_) {
OP_REQUIRES(context, side_input->shape() == x.shape(),
errors::InvalidArgument(
Expand All @@ -1294,7 +1320,7 @@ class FusedBatchNormOpBase : public OpKernel {
// NOTE(ezhulenev): This requirement is coming from implementation
// details of cudnnBatchNormalizationForwardTrainingEx.
OP_REQUIRES(
context, !is_training_ || x.dim_size(3) % 4 == 0,
context, !is_training_ || num_channels % 4 == 0,
errors::InvalidArgument("FusedBatchNorm with activation requires "
"channel dimension to be a multiple of 4."));
}
Expand Down

0 comments on commit 6972f9d

Please sign in to comment.