Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent division by 0 in resource_variable_ops.cc
PiperOrigin-RevId: 387939939
Change-Id: Ib04902d63756633999959a70613f2eaa30c2c151
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 31, 2021
1 parent 3a73627 commit ac117ee
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tensorflow/core/kernels/resource_variable_ops.cc
Expand Up @@ -710,7 +710,8 @@ class ResourceGatherOp : public OpKernel {
copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
indices.flat<Index>());

AddBatchOffsets(&tmp_indices, params);
AddBatchOffsets(c, &tmp_indices, params);
if (!c->status().ok()) return;
op_indices = &tmp_indices;
}

Expand Down Expand Up @@ -742,11 +743,17 @@ class ResourceGatherOp : public OpKernel {
// Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]]
// If indexing into a params dimension of size 4, then the indices will become
// [0, 1, 2, 4, 5, 6]
void AddBatchOffsets(Tensor* indices, const Tensor& params) {
void AddBatchOffsets(OpKernelContext* ctx, Tensor* indices,
const Tensor& params) {
int64_t batch_size = 1; // The size of all batch dimensions.
for (int idx = 0; idx < batch_dims_; ++idx) {
batch_size *= params.dim_size(idx);
}
OP_REQUIRES(
ctx, batch_size != 0,
errors::InvalidArgument(
"Inner size of indices would result in batch_size of 0 and a ",
"division by 0 in the implementation. This is illegal"));

auto indices_flat = indices->flat<Index>();
int64_t const index_inner_size = indices->NumElements() / batch_size;
Expand Down

0 comments on commit ac117ee

Please sign in to comment.