Skip to content

Commit

Permalink
Fix empty input crash for SparseFillEmptyRowsGrad.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 478085721
  • Loading branch information
cantonios authored and tensorflower-gardener committed Sep 30, 2022
1 parent 2c04e3f commit af4a6a3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
43 changes: 26 additions & 17 deletions tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc
Expand Up @@ -297,9 +297,12 @@ struct SparseFillEmptyRows<GPUDevice, T, Tindex> {
empty_row_indicator = empty_row_indicator_t.vec<bool>().data();
}

TF_RETURN_IF_ERROR(wrap_kernel_call(ComputeEmptyRowIndicatorKernel<Tindex>,
/*device=*/device, /*size=*/dense_rows,
elements_per_row, empty_row_indicator));
if (dense_rows > 0) {
TF_RETURN_IF_ERROR(
wrap_kernel_call(ComputeEmptyRowIndicatorKernel<Tindex>,
/*device=*/device, /*size=*/dense_rows,
elements_per_row, empty_row_indicator));
}

// For each row, the number of empty rows up to and including that row.
Tensor num_empty_rows_through_t;
Expand Down Expand Up @@ -405,14 +408,16 @@ struct SparseFillEmptyRows<GPUDevice, T, Tindex> {
done);
}

OP_REQUIRES_OK_ASYNC(
context,
wrap_kernel_call(ScatterNewElementsKernel<T, Tindex>,
/*device=*/device, /*size=*/dense_rows, rank,
default_value, num_empty_rows_through,
input_row_ends, empty_row_indicator, output_indices,
output_values),
done);
if (dense_rows > 0) {
OP_REQUIRES_OK_ASYNC(
context,
wrap_kernel_call(ScatterNewElementsKernel<T, Tindex>,
/*device=*/device, /*size=*/dense_rows, rank,
default_value, num_empty_rows_through,
input_row_ends, empty_row_indicator,
output_indices, output_values),
done);
}

done();
};
Expand Down Expand Up @@ -461,9 +466,11 @@ struct SparseFillEmptyRows<GPUDevice, T, Tindex> {
TF_RETURN_IF_ERROR(
context->allocate_temp(index_type, TensorShape({N}), &row_indices_t));
auto row_indices = row_indices_t.flat<Tindex>();
TF_RETURN_IF_ERROR(wrap_kernel_call(CopyRowIndicesKernel<Tindex>,
/*device=*/device, /*size=*/N, rank,
indices, row_indices));
if (N > 0) {
TF_RETURN_IF_ERROR(wrap_kernel_call(CopyRowIndicesKernel<Tindex>,
/*device=*/device, /*size=*/N, rank,
indices, row_indices));
}
// Allocate input_index_map.
TF_RETURN_IF_ERROR(context->allocate_temp(index_type, TensorShape({N}),
input_index_map_t));
Expand Down Expand Up @@ -528,9 +535,11 @@ struct SparseFillEmptyRowsGrad<GPUDevice, T, Tindex> {
auto visited = visited_t.vec<bool>();
visited.device(device) = visited.constant(false);

TF_RETURN_IF_ERROR(wrap_kernel_call(
GatherOriginalGradValuesKernel<T, Tindex>, /*device=*/device,
/*size=*/N, reverse_index_map, grad_values, d_values, visited));
if (N > 0) {
TF_RETURN_IF_ERROR(wrap_kernel_call(
GatherOriginalGradValuesKernel<T, Tindex>, /*device=*/device,
/*size=*/N, reverse_index_map, grad_values, d_values, visited));
}

// Now we mask out the visited values and sum the remaining ones (which
// correspond to the empty rows in the forward input) to compute
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/python/kernel_tests/sparse_ops/sparse_ops_test.py
Expand Up @@ -514,6 +514,13 @@ def testFillNumber(self):
self.assertAllEqual(empty_row_indicator_out,
np.array([0, 0, 1, 0, 1]).astype(np.bool_))

def testSparseFillEmptyRowsGradEmpty(self):
with test_util.use_gpu():
grad, _ = self.evaluate(
sparse_ops.sparse_fill_empty_rows_grad(
reverse_index_map=[], grad_values=[]))
self.assertAllEqual(grad, [])

@test_util.run_deprecated_v1
def testFillFloat(self):
with self.session():
Expand Down

0 comments on commit af4a6a3

Please sign in to comment.