Skip to content
Permalink
Browse files Browse the repository at this point in the history
Disallow division by zero FPE in tf.raw_ops.ResourceScatterDiv
Had to update a test that was broken.

PiperOrigin-RevId: 388516976
Change-Id: Ic358e6bf0559e011539974d453fc7aa18b427e9c
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Aug 3, 2021
1 parent 0a237f7 commit 4aacb30
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
35 changes: 35 additions & 0 deletions tensorflow/core/kernels/resource_variable_ops.cc
Expand Up @@ -873,6 +873,35 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

namespace {

template <typename Device>
bool isCPUDevice() {
return false;
}

template <>
bool isCPUDevice<CPUDevice>() {
return true;
}

template <typename T>
bool ValidateInput(const Tensor& updates) {
const auto updates_flat = updates.flat<T>();
const T zero(0);
for (int i = 0; i < updates.NumElements(); i++) {
if (updates_flat(i) == zero) return false;
}
return true;
}

template <>
bool ValidateInput<Variant>(const Tensor& updates) {
return true;
}

} // namespace

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
public:
Expand Down Expand Up @@ -939,6 +968,12 @@ class ResourceScatterUpdateOp : public OpKernel {
" indexing: ", params->dim_size(0), " > ",
std::numeric_limits<Index>::max()));

// Prevent division by 0
if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
OP_REQUIRES(c, ValidateInput<T>(updates),
errors::InvalidArgument("updates must not contain 0"));
}

if (N > 0) {
auto indices_flat = indices.flat<Index>();
auto params_flat = params->flat_outer_dims<T>();
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/python/distribute/sharded_variable_test.py
Expand Up @@ -175,8 +175,9 @@ def func():
'scatter_update')
def test_scatter_ops_even_partition(self, op):
v = variables_lib.Variable(array_ops.zeros((30, 1)))
# Make sure values does not contain 0 due to testing `scatter_div`!
sparse_delta = ops.IndexedSlices(
values=constant_op.constant([[0.], [1.], [2.], [3.], [4.]]),
values=constant_op.constant([[1.], [2.], [3.], [4.], [5.]]),
indices=constant_op.constant([0, 10, 12, 21, 22]))

v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
Expand Down

0 comments on commit 4aacb30

Please sign in to comment.