Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent integer overflow in OpLevelCostEstimator::CalculateOutputSize.
In order to not change the API, we return a negative value in case of overflow. A better fix is to change the API to return a status instead.

PiperOrigin-RevId: 408701427
Change-Id: Idf31e7f0bf18ca824d084fdd355e1f653f145c20
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Nov 9, 2021
1 parent 9fb7e81 commit b9bd6cf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
1 change: 1 addition & 0 deletions tensorflow/core/grappler/costs/BUILD
Expand Up @@ -340,6 +340,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/clusters:utils",
"//tensorflow/core/util:overflow",
] + tf_protos_grappler(),
)

Expand Down
10 changes: 9 additions & 1 deletion tensorflow/core/grappler/costs/op_level_cost_estimator.cc
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {
namespace grappler {
Expand Down Expand Up @@ -1607,7 +1608,14 @@ int64_t OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info,
auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
found_unknown_shapes);
for (const auto& dim : output_shape.dim()) {
output_size *= dim.size();
int64_t new_output_size =
MultiplyWithoutOverflow(output_size, dim.size());
if (new_output_size < 0) {
VLOG(1) << "Overflow encountered when estimating cost, multiplying "
<< output_size << " with " << dim.size();
return -1;
}
output_size = new_output_size;
}
total_output_size += output_size;
VLOG(1) << "Output Size: " << output_size
Expand Down

0 comments on commit b9bd6cf

Please sign in to comment.