Skip to content

Commit 0a8a781

Browse files
Merge pull request #55274 from yongtang:55263-tf.signal.rfft2d
PiperOrigin-RevId: 436736690
2 parents cc33f17 + f4348ad commit 0a8a781

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

Diff for: tensorflow/core/kernels/fft_ops.cc

+4
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ class FFTBase : public OpKernel {
6666

6767
auto fft_length_as_vec = fft_length.vec<int32>();
6868
for (int i = 0; i < fft_rank; ++i) {
69+
OP_REQUIRES(ctx, fft_length_as_vec(i) >= 0,
70+
errors::InvalidArgument(
71+
"fft_length[", i,
72+
"] must >= 0, but got: ", fft_length_as_vec(i)));
6973
fft_shape[i] = fft_length_as_vec(i);
7074
// Each input dimension must have length of at least fft_shape[i]. For
7175
// IRFFTs, the inner-most input dimension must have length of at least

Diff for: tensorflow/python/kernel_tests/signal/fft_ops_test.py

+9
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,15 @@ def test_grad_random(self, rank, extra_dims, size, np_rtype):
609609
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
610610
rtol=tol, atol=tol)
611611

612+
def test_invalid_args(self):
613+
# Test case for GitHub issue 55263
614+
a = np.empty([6, 0])
615+
b = np.array([1, -1])
616+
with self.assertRaisesRegex(errors.InvalidArgumentError, "must >= 0"):
617+
with self.session():
618+
v = fft_ops.rfft2d(input_tensor=a, fft_length=b)
619+
self.evaluate(v)
620+
612621

613622
@test_util.run_all_in_graph_and_eager_modes
614623
class FFTShiftTest(test.TestCase, parameterized.TestCase):

0 commit comments

Comments
 (0)