Skip to content
Permalink
Browse files Browse the repository at this point in the history
Merge pull request #55274 from yongtang:55263-tf.signal.rfft2d
PiperOrigin-RevId: 436736690
  • Loading branch information
tensorflower-gardener committed Mar 23, 2022
2 parents cc33f17 + f4348ad commit 0a8a781
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/fft_ops.cc
Expand Up @@ -66,6 +66,10 @@ class FFTBase : public OpKernel {

auto fft_length_as_vec = fft_length.vec<int32>();
for (int i = 0; i < fft_rank; ++i) {
OP_REQUIRES(ctx, fft_length_as_vec(i) >= 0,
errors::InvalidArgument(
"fft_length[", i,
"] must >= 0, but got: ", fft_length_as_vec(i)));
fft_shape[i] = fft_length_as_vec(i);
// Each input dimension must have length of at least fft_shape[i]. For
// IRFFTs, the inner-most input dimension must have length of at least
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/python/kernel_tests/signal/fft_ops_test.py
Expand Up @@ -609,6 +609,15 @@ def test_grad_random(self, rank, extra_dims, size, np_rtype):
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
rtol=tol, atol=tol)

def test_invalid_args(self):
# Test case for GitHub issue 55263
a = np.empty([6, 0])
b = np.array([1, -1])
with self.assertRaisesRegex(errors.InvalidArgumentError, "must >= 0"):
with self.session():
v = fft_ops.rfft2d(input_tensor=a, fft_length=b)
self.evaluate(v)


@test_util.run_all_in_graph_and_eager_modes
class FFTShiftTest(test.TestCase, parameterized.TestCase):
Expand Down

0 comments on commit 0a8a781

Please sign in to comment.