Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dft reverse axis #149

Merged
merged 8 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions xrft/tests/test_xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,3 +1252,17 @@ def test_constant_coordinates():

with pytest.raises(ValueError):
xrft.idft(s)


def test_reversed_coordinates():
"""Reversed coordinates should not impact dft with true_phase = True"""
N = 20
s = xr.DataArray(
np.random.rand(N) + 1j * np.random.rand(N),
dims="x",
coords={"x": np.arange(N // 2, -N // 2, -1) + 2},
)
s2 = s.sortby("x")
xrt.assert_allclose(
xrft.dft(s, dim="x", true_phase=True), xrft.dft(s2, dim="x", true_phase=True)
)
15 changes: 12 additions & 3 deletions xrft/xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,11 @@ def _lag_coord(coord):

v0 = coord.values[0]
calendar = getattr(v0, "calendar", None)
lag = coord[(len(coord.data)) // 2]
if coord[-1] > coord[0]:
coord_data = coord.data
else:
coord_data = np.flip(coord.data, axis=-1)
lag = coord_data[len(coord.data) // 2]
if calendar:
import cftime

Expand Down Expand Up @@ -408,7 +412,12 @@ def dft(
_, da = _apply_window(da, dim, window_type=window)

if true_phase:
f = fft_fn(fft.ifftshift(da.data, axes=axis_num), axes=axis_num)
reversed_axis = [
da.get_axis_num(d) for d in dim if da[d][-1] < da[d][0]
] # handling decreasing coordinates
f = fft_fn(
fft.ifftshift(np.flip(da, axis=reversed_axis), axes=axis_num), axes=axis_num
)
else:
f = fft_fn(da.data, axes=axis_num)

Expand Down Expand Up @@ -566,7 +575,7 @@ def idft(
delta = np.abs(diff[0])
l = _lag_coord(daft[d]) if d is not real_dim else daft[d][0].data
if not np.allclose(
diff, diff[0], rtol=spacing_tol
diff, delta, rtol=spacing_tol
): # means that input is not on regular increasing grid
reordered_coord = daft[d].copy()
reordered_coord = reordered_coord.sortby(d)
Expand Down