@@ -15,6 +15,8 @@ limitations under the License.
1515
1616// See docs in ../ops/nn_ops.cc.
1717
18+ #include " tensorflow/core/framework/op_requires.h"
19+ #include " tensorflow/core/platform/errors.h"
1820#define EIGEN_USE_THREADS
1921
2022#include " third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -117,6 +119,18 @@ class QuantizedMaxPoolingOp : public MaxPoolingOp<Device, T> {
117119 : MaxPoolingOp<Device, T>(context) {}
118120
119121 void Compute (OpKernelContext* context) override {
122+ auto min_input_tensor = context->input (1 );
123+ auto max_input_tensor = context->input (2 );
124+ OP_REQUIRES (
125+ context, min_input_tensor.NumElements () == 1 ,
126+ errors::InvalidArgument (
127+ " min_input must be a scalar float value, got tensor with shape " ,
128+ min_input_tensor.shape ()));
129+ OP_REQUIRES (
130+ context, max_input_tensor.NumElements () == 1 ,
131+ errors::InvalidArgument (
132+ " max_input must be a scalar float value, got tensor with shape " ,
133+ max_input_tensor.shape ()));
120134 const float min_input = context->input (1 ).flat <float >()(0 );
121135 const float max_input = context->input (2 ).flat <float >()(0 );
122136 MaxPoolingOp<Device, T>::Compute (context);
0 commit comments