Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding idft #129

Merged
merged 35 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
256824f
Merge pull request #1 from xgcm/master
lanougue Nov 10, 2020
e7d8f2b
Merge pull request #2 from xgcm/master
lanougue Nov 17, 2020
7849305
Merge pull request #3 from xgcm/master
lanougue Dec 4, 2020
bb59a8a
updated idft
Dec 4, 2020
148fb26
update idft + fft + ifft
Dec 10, 2020
0ee8ad5
debug spacing_tol
Dec 11, 2020
0515c57
adding tests and warning
Dec 11, 2020
5bf333a
adding dft parseval tests + some typos
Dec 11, 2020
eeec439
moving test to parseval function
Dec 11, 2020
8b97038
correction idft with real
Dec 13, 2020
0502b21
simplification power spectrum
Dec 13, 2020
792cc2b
simplification spectrum and adding true_flags
Dec 14, 2020
6333630
simplification cross phase
Dec 14, 2020
d9f3d87
debug spectrum with False density
Dec 14, 2020
5ff68a7
debug cross-spectrum test
Dec 14, 2020
1e33e06
debug test cross_phase
Dec 14, 2020
f8a04f0
debug test cross_phase 2
Dec 14, 2020
b49f132
adding cross phase tests
Dec 14, 2020
febecb5
adding cross phase tests 2
Dec 14, 2020
b7684d1
adding test_real_dft_true_phase + scaling parameter/density deprecation
Dec 15, 2020
ac511fc
rm old code
Dec 15, 2020
886850e
typo + unknown sclaing
Dec 15, 2020
c4bc754
debug test_isotropize
Dec 15, 2020
c325c54
flag warning + adding test
Dec 15, 2020
7af46ce
update test density error
Dec 15, 2020
a8275f9
test_coordinates + idft default behaviour
Dec 15, 2020
644c731
restoring default dft true_flags
Dec 15, 2020
186a4aa
correct indentation
Dec 15, 2020
5ca6ae7
adding warning
Dec 15, 2020
5beb1aa
user warning
Dec 15, 2020
fbb9d06
density and scaling flag handling
Dec 15, 2020
9a3caaa
restoring phase warning in cross_spectrum
Dec 15, 2020
dcb40ee
function signature
Dec 15, 2020
41ec321
function signature2 + rm comments
Dec 16, 2020
5cb5cd2
dask debug in test
Dec 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions xrft/tests/test_xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,3 +1033,53 @@ def test_idft_dft():
FTs, shift=True, true_phase=True, true_amplitude=True, lag=mean_lag
)
xrt.assert_allclose(s, IFTs)

roxyboy marked this conversation as resolved.
Show resolved Hide resolved

def test_parseval_dft1d():
"""Testing parseval identity in 1D"""
Nx = 40
dx = np.random.rand()

s = xr.DataArray(
np.random.rand(Nx) + 1j * np.random.rand(Nx),
dims="x",
coords={
"x": dx
* (
np.arange(-Nx // 2, -Nx // 2 + Nx)
+ np.random.randint(-Nx // 2, Nx // 2)
)
},
)
FTs = xrft.dft(s, dim="x", true_phase=True, true_amplitude=True)
npt.assert_almost_equal(
(np.abs(s) ** 2).sum() * dx, (np.abs(FTs) ** 2).sum() * FTs["freq_x"].spacing
)


def test_parseval_dft2d():
"""Testing parseval identity in 2D"""
Nx, Ny = 40, 60
dx, dy = np.random.rand(), np.random.rand()

s = xr.DataArray(
np.random.rand(Nx, Ny) + 1j * np.random.rand(Nx, Ny),
dims=("x", "y"),
coords={
"x": dx
* (
np.arange(-Nx // 2, -Nx // 2 + Nx)
+ np.random.randint(-Nx // 2, Nx // 2)
),
"y": dy
* (
np.arange(-Ny // 2, -Ny // 2 + Ny)
+ np.random.randint(-Ny // 2, Ny // 2)
),
},
)
FTs = xrft.dft(s, dim=("x", "y"), true_phase=True, true_amplitude=True)
npt.assert_almost_equal(
(np.abs(s) ** 2).sum() * dx * dy,
(np.abs(FTs) ** 2).sum() * FTs["freq_x"].spacing * FTs["freq_y"].spacing,
)
84 changes: 42 additions & 42 deletions xrft/xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,28 +186,28 @@ def fft(da, **kwargs):
See xrft.dft for argument list
"""
if kwargs.pop("true_phase", False):
print("true_phase argument is ignored in xrft.fft")
warnings.warn("true_phase argument is ignored in xrft.fft")
if kwargs.pop("true_amplitude", False):
print("true_amplitude argument is ignored in xrft.fft")
warnings.warn("true_amplitude argument is ignored in xrft.fft")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return dft(da, true_phase=False, true_amplitude=False, **kwargs)


def ifft(da, **kwargs):
def ifft(daft, **kwargs):
"""
See xrft.idft for argument list
"""
if kwargs.pop("true_phase", False):
print("true_phase argument is ignored in xrft.ifft")
warnings.warn("true_phase argument is ignored in xrft.ifft")
if kwargs.pop("true_amplitude", False):
print("true_amplitude argument is ignored in xrft.ifft")
warnings.warn("true_amplitude argument is ignored in xrft.ifft")
if kwargs.pop("lag", False):
print("lag argument is ignored in xrft.ifft")
msg = "xrft.ifft do not guaranty output coordinate phasing. Prefer xrft.dft and xrft.idft as forward and backward Fourier Transforms with true_phase flag set to True for accurate coordinates handling."
warnings.warn("lag argument is ignored in xrft.ifft")
msg = "xrft.ifft does not guarantee correct coordinate phasing for its output. We recommend xrft.dft and xrft.idft as forward and backward Fourier Transforms with true_phase flags set to True for accurate coordinate handling."
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return idft(da, true_phase=False, true_amplitude=False, **kwargs)
return idft(daft, true_phase=False, true_amplitude=False, **kwargs)


def dft(
Expand Down Expand Up @@ -274,7 +274,7 @@ def dft(
"""

if not true_phase and not true_amplitude:
msg = "xrft.dft default behaviour will be modified in future versions of xrft. Prefer xrft.fft to ensure future compatibility and deactivate this warning. Consider using xrft.dft for accurate coordinates phasing and FT amplitude handling."
msg = "Flags true_phase and true_amplitude will be set to True in future versions of xrft to preserve the theoretical phasing and amplitude of FT. Consider using xrft.fft to ensure future compatibility with numpy.fft like behavior and to deactivate this warning."
warnings.warn(msg, FutureWarning)

if dim is None:
Expand Down Expand Up @@ -377,7 +377,7 @@ def dft(


def idft(
da,
daft,
spacing_tol=1e-3,
dim=None,
real=None,
Expand All @@ -391,15 +391,15 @@ def idft(
lag=None,
):
"""
Perform inverse discrete Fourier transform of xarray data-array `da` along the
Perform inverse discrete Fourier transform of xarray data-array `daft` along the
specified dimensions.

.. math::
daft = \mathbb{F}(da - \overline{da})
da = \mathbb{F}(daft - \overline{daft})

Parameters
----------
da : `xarray.DataArray`
daft : `xarray.DataArray`
The data to be transformed
spacing_tol: float, optional
Spacing tolerance. Fourier transform should not be applied to uneven grid but
Expand Down Expand Up @@ -440,7 +440,7 @@ def idft(

Returns
-------
daft : `xarray.DataArray`
da : `xarray.DataArray`
The output of the Inverse Fourier transformation, with appropriate dimensions.
"""

Expand All @@ -449,13 +449,13 @@ def idft(
warnings.warn(msg, FutureWarning)

if dim is None:
dim = list(da.dims)
dim = list(daft.dims)
else:
if isinstance(dim, str):
dim = [dim]

if real is not None:
if real not in da.dims:
if real not in daft.dims:
raise ValueError(
"The dimension along which real FT is taken must be one of the existing dimensions."
)
Expand All @@ -474,19 +474,19 @@ def idft(
warnings.warn(msg, Warning)

for d, l in zip(dim, lag):
da = da * np.exp(1j * 2.0 * np.pi * da[d] * l)
daft = daft * np.exp(1j * 2.0 * np.pi * daft[d] * l)

if chunks_to_segments:
da = _stack_chunks(da, dim)
daft = _stack_chunks(daft, dim)

rawdims = da.dims # take care of segmented dimesions, if any
rawdims = daft.dims # take care of segmented dimesions, if any

if real is not None:
da = da.transpose(
*[d for d in da.dims if d not in [real]] + [real]
daft = daft.transpose(
*[d for d in daft.dims if d not in [real]] + [real]
) # dimension for real transformed is moved at the end

fftm = _fft_module(da)
fftm = _fft_module(daft)

if real is None:
fft_fn = fftm.ifftn
Expand All @@ -495,27 +495,27 @@ def idft(
fft_fn = fftm.irfftn
lanougue marked this conversation as resolved.
Show resolved Hide resolved

# the axes along which to take ffts
axis_num = [da.get_axis_num(d) for d in dim]
axis_num = [daft.get_axis_num(d) for d in dim]

N = [da.shape[n] for n in axis_num]
N = [daft.shape[n] for n in axis_num]

# verify even spacing of input coordinates (It handle fftshifted grids)
delta_x = []
for d in dim:
diff = _diff_coord(da[d])
diff = _diff_coord(daft[d])
delta = np.abs(diff[0])
l = _lag_coord(da[d])
l = _lag_coord(daft[d])
if not np.allclose(
diff, diff[0], rtol=spacing_tol
): # means that input is not on regular increasing grid
reordered_coord = da[d].copy()
reordered_coord = daft[d].copy()
reordered_coord = reordered_coord.sortby(d)
diff = _diff_coord(reordered_coord)
l = _lag_coord(reordered_coord)
if np.allclose(
diff, diff[0], rtol=spacing_tol
): # means that input is on fftshifted grid
da = da.sortby(d) # reordering the input
daft = daft.sortby(d) # reordering the input
else:
raise ValueError(
"Can't take Fourier transform because "
Expand All @@ -529,13 +529,13 @@ def idft(
delta_x.append(delta)

if detrend:
da = _apply_detrend(da, dim, axis_num, detrend)
daft = _apply_detrend(daft, dim, axis_num, detrend)

if window:
da = _apply_window(da, dim)
daft = _apply_window(daft, dim)

f = fftm.ifftshift(
da.data, axes=axis_num
daft.data, axes=axis_num
) # Force to be on fftshift grid before Fourier Transform
f = fft_fn(f, axes=axis_num)

Expand All @@ -547,29 +547,29 @@ def idft(

k = _freq(N, delta_x, real, shift)

newcoords, swap_dims = _new_dims_and_coords(da, dim, k, prefix)
daft = xr.DataArray(
f, dims=da.dims, coords=dict([c for c in da.coords.items() if c[0] not in dim])
newcoords, swap_dims = _new_dims_and_coords(daft, dim, k, prefix)
da = xr.DataArray(
f,
dims=daft.dims,
coords=dict([c for c in daft.coords.items() if c[0] not in dim]),
)
daft = daft.swap_dims(swap_dims).assign_coords(newcoords)
daft = daft.drop([d for d in dim if d in daft.coords])
da = da.swap_dims(swap_dims).assign_coords(newcoords)
da = da.drop([d for d in dim if d in da.coords])

if lag is not None:
with xr.set_options(
keep_attrs=True
): # This line ensures keeping spacing attribute in output coordinates
for d, l in zip(dim, lag):
tfd = swap_dims[d]
daft = daft.assign_coords({tfd: daft[tfd] + l})
da = da.assign_coords({tfd: da[tfd] + l})

if true_amplitude:
daft = daft / np.prod(
[float(daft[up_dim].spacing) for up_dim in swap_dims.values()]
)
da = da / np.prod([float(da[up_dim].spacing) for up_dim in swap_dims.values()])

return daft.transpose(
return da.transpose(
*[swap_dims.get(d, d) for d in rawdims]
) # Do nothing if da was not transposed
) # Do nothing if daft was not transposed


def power_spectrum(
Expand Down