Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicit error when coordinates are not numerical #190

Merged
merged 8 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions xrft/tests/test_xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,3 +1344,21 @@ def test_nondim_coords():
xrft.power_spectrum(da)

xrft.power_spectrum(da, dim=["time", "y"])


def test_non_numerical_or_datetime_coords():
    """A ValueError is expected when a transformed coordinate is neither numerical nor datetime.

    The array has a datetime ``time`` coordinate, an integer ``x`` coordinate
    and a string-labelled ``y`` coordinate. Requesting the spectrum without
    ``dim`` must raise; restricting ``dim`` to ``time`` and ``x`` must not.
    """
    dim_names = ["time", "x", "y"]
    values = np.random.rand(2, 5, 3)
    coord_values = {
        "time": np.array(["2019-04-18", "2019-04-19"], dtype="datetime64"),
        "x": range(5),
        "y": ["a", "b", "c"],  # string labels: neither numerical nor datetime
    }
    da = xr.DataArray(values, dims=dim_names, coords=coord_values)

    # No `dim` given: the call must be rejected because of the "y" labels.
    with pytest.raises(ValueError):
        xrft.power_spectrum(da)

    # Only the datetime and integer dimensions are requested: must not raise.
    xrft.power_spectrum(da, dim=["time", "x"])
45 changes: 37 additions & 8 deletions xrft/xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import scipy.linalg as spl

from .detrend import detrend as _detrend

from pandas.api.types import is_numeric_dtype, is_datetime64_any_dtype

__all__ = [
"fft",
Expand Down Expand Up @@ -230,9 +230,9 @@ def _lag_coord(coord):
decoded_time = cftime.date2num(lag, ref_units, calendar)
return decoded_time
elif pd.api.types.is_datetime64_dtype(v0):
return lag.astype("timedelta64[s]").astype("f8").data
return lag.astype("timedelta64[s]").astype("f8")
else:
return lag.data
return lag


def dft(
Expand Down Expand Up @@ -330,7 +330,6 @@ def fft(
daft : `xarray.DataArray`
The output of the Fourier transformation, with appropriate dimensions.
"""

if dim is None:
dim = list(da.dims)
else:
Expand All @@ -352,6 +351,20 @@ def fft(
real_dim
] # real dim has to be moved or added at the end !

if not np.all(
[
(
is_numeric_dtype(da.coords[d])
or is_datetime64_any_dtype(da.coords[d])
or bool(getattr(da.coords[d][0].item(), "calendar", False))
Copy link
Member

@roxyboy roxyboy Feb 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lanougue Could you explain what this new condition with the calendar is for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xrft.fft can compute Fourier transforms of data whose coordinates are cftime objects (temporal data with a defined calendar). The test_xrft.py file has checks that use this kind of coordinate with the "julian", "365_day" and "360_day" calendars.
When coordinates are of cftime type, their dtype is "object", which the pandas API considers neither "numerical" nor "datetime". A special check is therefore needed to accept this kind of data if we want those tests to pass.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, sounds good. I'll go ahead and merge this :)

)
for d in dim
]
): # checking if coodinates are numerical or datetime
raise ValueError(
"All transformed dimensions coordinates must be numerical or datetime."
)

if chunks_to_segments:
da = _stack_chunks(da, dim)

Expand Down Expand Up @@ -452,7 +465,7 @@ def fft(
dims=up_dim,
coords={up_dim: newcoords[up_dim]},
) # taking advantage of xarray broadcasting and ordered coordinates
daft[up_dim].attrs.update({"direct_lag": lag.obj})
daft[up_dim].attrs.update({"direct_lag": lag})

if true_amplitude:
daft = daft * np.prod(delta_x)
Expand Down Expand Up @@ -520,7 +533,6 @@ def ifft(
da : `xarray.DataArray`
The output of the Inverse Fourier transformation, with appropriate dimensions.
"""

if dim is None:
dim = list(daft.dims)
else:
Expand All @@ -540,6 +552,21 @@ def ifft(
dim = [d for d in dim if d != real_dim] + [
real_dim
] # real dim has to be moved or added at the end !

if not np.all(
[
(
is_numeric_dtype(daft.coords[d])
or is_datetime64_any_dtype(daft.coords[d])
or bool(getattr(daft.coords[d][0].item(), "calendar", False))
)
for d in dim
]
): # checking if coodinates are numerical or datetime
raise ValueError(
"All transformed dimensions coordinates must be numerical or datetime."
)

if lag is None:
lag = [daft[d].attrs.get("direct_lag", 0.0) for d in dim]
msg = "Default ifft's behaviour (lag=None) changed! Default value of lag was zero (centered output coordinates) and is now set to transformed coordinate's attribute: 'direct_lag'."
Expand Down Expand Up @@ -898,8 +925,10 @@ def cross_phase(da1, da2, dim=None, true_phase=True, **kwargs):
kwargs : dict : see xrft.fft for argument list
"""

cp = xr.ufuncs.angle(
cross_spectrum(da1, da2, dim=dim, true_phase=true_phase, **kwargs)
cp = xr.apply_ufunc(
np.angle,
cross_spectrum(da1, da2, dim=dim, true_phase=true_phase, **kwargs),
dask="allowed",
)

if da1.name and da2.name:
Expand Down