Skip to content

Commit

Permalink
Add np.fft.fftshift/ifftshift (google#1850)
Browse files Browse the repository at this point in the history
  • Loading branch information
adler-j committed Feb 4, 2020
1 parent ffc55ee commit 4080a1c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
28 changes: 28 additions & 0 deletions jax/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,34 @@ def rfftfreq(n, d=1.0):
return k / (d * n)


@_wraps(onp.fft.fftshift)
def fftshift(x, axes=None):
x = np.asarray(x)
if axes is None:
axes = tuple(range(x.ndim))
shift = [dim // 2 for dim in x.shape]
elif isinstance(axes, int):
shift = x.shape[axes] // 2
else:
shift = [x.shape[ax] // 2 for ax in axes]

return np.roll(x, shift, axes)


@_wraps(onp.fft.ifftshift)
def ifftshift(x, axes=None):
x = np.asarray(x)
if axes is None:
axes = tuple(range(x.ndim))
shift = [-(dim // 2) for dim in x.shape]
elif isinstance(axes, int):
shift = -(x.shape[axes] // 2)
else:
shift = [-(x.shape[ax] // 2) for ax in axes]

return np.roll(x, shift, axes)


for func in get_module_functions(onp.fft):
if func.__name__ not in globals():
globals()[func.__name__] = _not_implemented(func)
32 changes: 31 additions & 1 deletion tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def testFftn(self, inverse, real, shape, dtype, axes, rng_factory):
onp_op = _get_fftn_func(onp.fft, inverse, real)
np_fn = lambda a: np_op(a, axes=axes)
onp_fn = lambda a: onp_op(a, axes=axes) if axes is None or axes else a
# Numpy promotes to complex128 aggressively.
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
Expand Down Expand Up @@ -350,5 +350,35 @@ def testRfftfreqErrors(self, n):
lambda: func(n=10, d=n)
)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "dtype={}_axes={}".format(
jtu.format_shape_dtype_string(shape, dtype), axes),
"dtype": dtype, "shape": shape, "rng_factory": rng_factory, "axes": axes}
for rng_factory in [jtu.rand_default]
for dtype in all_dtypes
for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
for axes in _get_fftn_test_axes(shape)))
def testFftshift(self, shape, dtype, rng_factory, axes):
rng = rng_factory()
args_maker = lambda: (rng(shape, dtype),)
np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
onp_fn = lambda arg: onp.fft.fftshift(arg, axes=axes)
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "dtype={}_axes={}".format(
jtu.format_shape_dtype_string(shape, dtype), axes),
"dtype": dtype, "shape": shape, "rng_factory": rng_factory, "axes": axes}
for rng_factory in [jtu.rand_default]
for dtype in all_dtypes
for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
for axes in _get_fftn_test_axes(shape)))
def testIfftshift(self, shape, dtype, rng_factory, axes):
rng = rng_factory()
args_maker = lambda: (rng(shape, dtype),)
np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
onp_fn = lambda arg: onp.fft.ifftshift(arg, axes=axes)
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=True)

if __name__ == "__main__":
absltest.main()

0 comments on commit 4080a1c

Please sign in to comment.