Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent integer overflow in OpLevelCostEstimator::CalculateTensorSize.
In order to not change the API, we return a negative value in case of overflow. A better fix is to change the API to return a status instead.

PiperOrigin-RevId: 408713061
Change-Id: I3771475b0c72a2844a3854086966562fd33f2da5
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Nov 9, 2021
1 parent 29e8998 commit fcd18ce
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tensorflow/core/grappler/costs/op_level_cost_estimator.cc
Expand Up @@ -1555,7 +1555,13 @@ int64_t OpLevelCostEstimator::CalculateTensorSize(
int64_t count = CalculateTensorElementCount(tensor, found_unknown_shapes);
int size = DataTypeSize(BaseType(tensor.dtype()));
VLOG(2) << "Count: " << count << " DataTypeSize: " << size;
return count * size;
int64_t tensor_size = MultiplyWithoutOverflow(count, size);
if (tensor_size < 0) {
VLOG(1) << "Overflow encountered when computing tensor size, multiplying "
<< count << " with " << size;
return -1;
}
return tensor_size;
}

int64_t OpLevelCostEstimator::CalculateInputSize(const OpInfo& op_info,
Expand Down

0 comments on commit fcd18ce

Please sign in to comment.