Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix empty batch issue in svd.
If there are zero batches, the GPU kernel previously failed with a `work_element_count > 0`
check failure.  In the zero batch case, there is no output, so simply return.

PiperOrigin-RevId: 462210670
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jul 20, 2022
1 parent 1744280 commit c55b476
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/linalg/svd_op_gpu.cu.cc
Expand Up @@ -395,6 +395,12 @@ class SvdOpGpu : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(context, context->allocate_output(2, shapeV, &outputV),
done);

// If there are zero batches, we are done.
if (shapeRaw.num_elements() == 0) {
done();
return;
}

if (n == 0 || m == 0) {
if (n == m || !compute_uv_ || !full_matrices_) {
// S, U, and V are all empty. Nothing to do.
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/python/kernel_tests/linalg/svd_op_test.py
Expand Up @@ -108,6 +108,14 @@ def testExecuteMultipleWithoutError(self):
for i in range(0, len(val), 2):
self.assertAllEqual(val[i], val[i + 1])

@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testEmptyBatches(self):
matrices = constant_op.constant(1.0, shape=[0, 2, 2])
s, u, v = self.evaluate(linalg_ops.svd(matrices))
self.assertAllEqual(s, np.zeros([0, 2]))
self.assertAllEqual(u, np.zeros([0, 2, 2]))
self.assertAllEqual(v, np.zeros([0, 2, 2]))


def _GetSvdOpTest(dtype_, shape_, use_static_shape_, compute_uv_,
full_matrices_):
Expand Down

0 comments on commit c55b476

Please sign in to comment.