Internal change
PiperOrigin-RevId: 411896058
Change-Id: Ia031058247e3cf382957a6662d3f9e1cbb481ca2
ishark authored and tensorflower-gardener committed Nov 23, 2021
1 parent 05e7d51 commit 3218043
Showing 4 changed files with 83 additions and 17 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/grappler/costs/BUILD
@@ -355,6 +355,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:status_matchers",
],
)

37 changes: 24 additions & 13 deletions tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -2153,7 +2153,7 @@ OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
}

/* static */
OpLevelCostEstimator::ConvolutionDimensions
StatusOr<OpLevelCostEstimator::ConvolutionDimensions>
OpLevelCostEstimator::OpDimensionsFromInputs(
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes) {
@@ -2190,6 +2190,11 @@ OpLevelCostEstimator::OpDimensionsFromInputs(
std::vector<int64_t> strides = GetStrides(op_info);
int64_t sx = strides[x_index];
int64_t sy = strides[y_index];
if (sx == 0 || sy == 0) {
return errors::InvalidArgument(
"Stride must be > 0 for Height and Width, but got (", sy, ", ", sx,
")");
}
const auto padding = GetPadding(op_info);

int64_t ox = GetOutputSize(ix, kx, sx, padding);
@@ -2206,8 +2211,9 @@ Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context,
bool found_unknown_shapes = false;
const auto& op_info = op_context.op_info;
// x: op_info.inputs(0)
ConvolutionDimensions dims = OpDimensionsFromInputs(
op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
&found_unknown_shapes));
// kx * ky - 1 comparisons per output (kx * ky > 1)
// or 1 copy per output (kx * ky = 1).
int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
@@ -2248,8 +2254,9 @@ Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context,
op_info.ShortDebugString());
}

ConvolutionDimensions dims = OpDimensionsFromInputs(
op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
&found_unknown_shapes));

int64_t ops = 0;
if (dims.kx == 1 && dims.ky == 1) {
@@ -2324,8 +2331,9 @@ Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context,
bool found_unknown_shapes = false;
const auto& op_info = op_context.op_info;
// x: op_info.inputs(0)
ConvolutionDimensions dims = OpDimensionsFromInputs(
op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
&found_unknown_shapes));

// kx * ky - 1 additions and 1 multiplication per output.
int64_t ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
@@ -2382,8 +2390,9 @@ Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context,
found_unknown_shapes = true;
}

ConvolutionDimensions dims =
OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);
TF_ASSIGN_OR_RETURN(
ConvolutionDimensions dims,
OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes));

int64_t ops = 0;
if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
@@ -2409,8 +2418,9 @@ Status OpLevelCostEstimator::PredictFusedBatchNorm(
// offset: op_info.inputs(2)
// mean: op_info.inputs(3) --> only for inference
// variance: op_info.inputs(4) --> only for inference
ConvolutionDimensions dims = OpDimensionsFromInputs(
op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
&found_unknown_shapes));
const bool is_training = IsTraining(op_info);

int64_t ops = 0;
@@ -2459,8 +2469,9 @@ Status OpLevelCostEstimator::PredictFusedBatchNormGrad(
// scale: op_info.inputs(2)
// mean: op_info.inputs(3)
// variance or inverse of variance: op_info.inputs(4)
ConvolutionDimensions dims = OpDimensionsFromInputs(
op_info.inputs(1).shape(), op_info, &found_unknown_shapes);
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
OpDimensionsFromInputs(op_info.inputs(1).shape(), op_info,
&found_unknown_shapes));

int64_t ops = 0;
const auto rsqrt_cost = Eigen::internal::functor_traits<
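The op_level_cost_estimator.cc changes above all follow one pattern: OpDimensionsFromInputs now returns StatusOr<ConvolutionDimensions>, a zero height or width stride is reported as an InvalidArgument error before it can reach the output-size division, and each caller unwraps the result with TF_ASSIGN_OR_RETURN so the error propagates out of the Predict* method. The sketch below is not part of the commit; it spells the same pattern out by hand (roughly what the macro does at each call site), with a hypothetical Dims struct and DimsFromStrides helper standing in for ConvolutionDimensions and OpDimensionsFromInputs.

#include <cstdint>
#include <utility>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"

namespace {

// Hypothetical stand-in for ConvolutionDimensions; only what the sketch needs.
struct Dims {
  int64_t ox;
  int64_t oy;
};

// Mirrors the new validation: a zero stride becomes an InvalidArgument error
// instead of a division by zero when computing output sizes.
tensorflow::StatusOr<Dims> DimsFromStrides(int64_t ix, int64_t iy, int64_t sx,
                                           int64_t sy) {
  if (sx == 0 || sy == 0) {
    return tensorflow::errors::InvalidArgument(
        "Stride must be > 0 for Height and Width, but got (", sy, ", ", sx,
        ")");
  }
  return Dims{ix / sx, iy / sy};
}

// Roughly what TF_ASSIGN_OR_RETURN(Dims dims, DimsFromStrides(...)) does at
// the call sites above: return the error Status early, otherwise keep going
// with the contained value.
tensorflow::StatusOr<int64_t> OutputArea(int64_t ix, int64_t iy, int64_t sx,
                                         int64_t sy) {
  tensorflow::StatusOr<Dims> dims_or = DimsFromStrides(ix, iy, sx, sy);
  if (!dims_or.ok()) return dims_or.status();
  Dims dims = std::move(dims_or).value();
  return dims.ox * dims.oy;
}

}  // namespace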
2 changes: 1 addition & 1 deletion tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -290,7 +290,7 @@ class OpLevelCostEstimator {
bool* found_unknown_shapes);

// For Pooling, FusedBatchNorm, and their grad ops.
static ConvolutionDimensions OpDimensionsFromInputs(
static StatusOr<ConvolutionDimensions> OpDimensionsFromInputs(
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);

60 changes: 57 additions & 3 deletions tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

@@ -558,9 +559,10 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
}

bool found_unknown_shapes;
auto dims = OpLevelCostEstimator::OpDimensionsFromInputs(
op_context.op_info.inputs(0).shape(), op_context.op_info,
&found_unknown_shapes);
TF_ASSERT_OK_AND_ASSIGN(
auto dims, OpLevelCostEstimator::OpDimensionsFromInputs(
op_context.op_info.inputs(0).shape(), op_context.op_info,
&found_unknown_shapes));
Padding padding_enum;
if (padding == "VALID") {
padding_enum = Padding::VALID;
@@ -581,6 +583,38 @@
EXPECT_EQ(padding_enum, dims.padding);
}

StatusOr<OpLevelCostEstimator::ConvolutionDimensions>
CallOpDimensionsFromInputs(const int n, const int h, const int w, const int c,
const int kx, const int ky, const int sx,
const int sy, const string& data_format,
const string& padding) {
OpContext op_context;

const std::vector<int> x = {n, h, w, c};
const std::vector<int> ksize = {1, kx, ky, 1};
std::vector<int> strides;
if (data_format == "NHWC") {
strides = {1, sy, sx, 1};
} else {
strides = {1, 1, sy, sx};
}

auto& op_info = op_context.op_info;
SetCpuDevice(&op_info);
op_info.set_op("MaxPool");

DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
auto* attr = op_info.mutable_attr();
SetAttrValue(data_format, &(*attr)["data_format"]);
SetAttrValue(padding, &(*attr)["padding"]);
SetAttrValue(strides, &(*attr)["strides"]);
SetAttrValue(ksize, &(*attr)["ksize"]);
bool found_unknown_shapes;
return OpLevelCostEstimator::OpDimensionsFromInputs(
op_context.op_info.inputs(0).shape(), op_context.op_info,
&found_unknown_shapes);
}

OpLevelCostEstimator estimator_;
};

@@ -1383,6 +1417,26 @@ TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
}
}

TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputsError) {
std::vector<string> paddings = {"VALID", "SAME"};
std::vector<string> formats = {"NHWC", "NCHW"};
for (const auto& p : paddings) {
for (const auto& f : formats) {
// n, h, w, c, kx, ky, sx, sy, data_format, padding.
ASSERT_THAT(
CallOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 0, 2, f, p),
testing::StatusIs(
error::INVALID_ARGUMENT,
"Stride must be > 0 for Height and Width, but got (2, 0)"));
ASSERT_THAT(
CallOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 2, 0, f, p),
testing::StatusIs(
error::INVALID_ARGUMENT,
"Stride must be > 0 for Height and Width, but got (0, 2)"));
}
}
}

TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
auto predict_max_pool = [this](const int n, const int in, const int c,
const int k, const int s,
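The new OpDimensionsFromInputsError test leans on two utilities: TF_ASSERT_OK_AND_ASSIGN unwraps a StatusOr on the success path (failing the test if it holds an error), and testing::StatusIs from status_matchers.h checks the canonical error code and message on the failure path. Below is a minimal sketch of that pattern outside this commit; the HalveEven helper is hypothetical, and the include list is an assumption (in the real test file the macro and matcher come in through the headers it already includes).

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/statusor.h"  // assumed source of TF_ASSERT_OK_AND_ASSIGN
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

// Hypothetical helper under test: fails with InvalidArgument on odd input.
StatusOr<int> HalveEven(int x) {
  if (x % 2 != 0) {
    return errors::InvalidArgument("Expected an even value, but got ", x);
  }
  return x / 2;
}

TEST(StatusMatchersSketchTest, SuccessAndErrorPaths) {
  // Success path: bind the contained value, or fail the test on error.
  TF_ASSERT_OK_AND_ASSIGN(int half, HalveEven(4));
  EXPECT_EQ(half, 2);

  // Failure path: StatusIs matches the code and message, the same way
  // OpDimensionsFromInputsError checks the stride error above.
  EXPECT_THAT(HalveEven(3),
              testing::StatusIs(error::INVALID_ARGUMENT,
                                "Expected an even value, but got 3"));
}

}  // namespace
}  // namespace tensorflow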
