Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

polyval: Use Horner's algorithm + support chunked inputs #6548

Merged
merged 38 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e309682
new polyval algo
headtr1ck Apr 30, 2022
88e476a
polyval improved typing with datasets
headtr1ck Apr 30, 2022
261aadc
more polyval unit tests
headtr1ck Apr 30, 2022
2a6a633
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2022
a2701b8
support for Dataset coord in polyval
headtr1ck Apr 30, 2022
553de10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2022
4fbca23
fix bug in polyval broadcasting
headtr1ck Apr 30, 2022
b335679
Merge branch 'main' into main
max-sixty Apr 30, 2022
b54a4a3
Merge branch 'main' into main
max-sixty Apr 30, 2022
7001e4e
Merge branch 'main' into main
max-sixty Apr 30, 2022
212505f
support for datetime values in polyval
headtr1ck May 1, 2022
595c83c
Merge branch 'main' of github.com:headtr1ck/xarray into main
headtr1ck May 1, 2022
945cbfb
polyval update in whats-new
headtr1ck May 1, 2022
5945537
fix dask polyval unit tests
headtr1ck May 1, 2022
a0964ed
fix bug in polyval unit tests
headtr1ck May 1, 2022
82db394
add polyval benchmark
headtr1ck May 1, 2022
ca5a7f7
add breaking change of polyval tp whats-new
headtr1ck May 1, 2022
7a70831
move _ensure_numeric to its own function
headtr1ck May 1, 2022
401e126
fix import error in _ensure_numeric
headtr1ck May 1, 2022
3c21a64
add raise_if_dask_computes to polyval unit tests
headtr1ck May 1, 2022
ff37fe2
chunk coord arg as well for polyval unit tests
headtr1ck May 1, 2022
63ed137
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2022
8343b29
simplify polyval algo with sortby
headtr1ck May 1, 2022
7ad9168
Merge branch 'main' of github.com:headtr1ck/xarray into main
headtr1ck May 1, 2022
0ef3462
comment some code until PR#6556
headtr1ck May 1, 2022
e6f4675
Revert "comment some code until PR#6556"
headtr1ck May 1, 2022
3b88386
fix dask issues with polyval unit tests
headtr1ck May 1, 2022
ce392a7
remove unused import
headtr1ck May 1, 2022
4e8b72e
make polyval benchmark backwards compatible
headtr1ck May 1, 2022
70c4419
another bugfix for polyval benchmark
headtr1ck May 1, 2022
c4ced87
Update asv_bench/benchmarks/polyfit.py
dcherian May 2, 2022
7a73a42
Actually compute dask arrays in benchmark.
dcherian May 2, 2022
a824ad2
Minor cleanup
dcherian May 2, 2022
047cd04
simplify polyval algo using reindex
headtr1ck May 3, 2022
0d0bb8e
don't copy coeffs if not necessary
headtr1ck May 3, 2022
ef49710
Merge branch 'pydata:main' into main
headtr1ck May 3, 2022
05e0266
Make sure degree_dim is an indexed coordinate of int dtype
dcherian May 4, 2022
bd3dd81
Fix benchmark
dcherian May 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions asv_bench/benchmarks/polyfit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import xarray as xr

from . import parameterized, randn, requires_dask

# Lengths of the "degree" dimension to benchmark, i.e. the number of
# polynomial coefficients handed to ``xr.polyval``.
ndegs = (2, 5, 20)
# Numbers of evaluation points along the "x" dimension.
nxs = (10**2, 10**6)

# Inputs are built once at import time and keyed by size; the benchmark
# methods only look them up.
# NOTE(review): ``randn`` is a helper from the benchmark package; presumably it
# returns an array of the requested shape -- confirm its seeding behaviour there.
xs = {nx: xr.DataArray(randn((nx,)), dims="x", name="x") for nx in nxs}
coeffs = {ndeg: xr.DataArray(randn((ndeg,)), dims="degree") for ndeg in ndegs}


class Polyval:
    """Time and peak-memory benchmarks for ``xr.polyval``.

    Each benchmark is run for every combination of ``nx`` (number of
    evaluation points) and ``ndeg`` (number of coefficients).
    """

    def setup(self, *args, **kwargs):
        # Bind the module-level inputs to the instance so that subclasses can
        # substitute their own variants (e.g. chunked arrays).
        self.xs = xs
        self.coeffs = coeffs

    @parameterized(["nx", "ndeg"], [nxs, ndegs])
    def time_polyval(self, nx, ndeg):
        points, poly = self.xs[nx], self.coeffs[ndeg]
        # ``compute`` makes sure lazily evaluated results are materialised
        # inside the measured region.
        xr.polyval(points, poly).compute()

    @parameterized(["nx", "ndeg"], [nxs, ndegs])
    def peakmem_polyval(self, nx, ndeg):
        points, poly = self.xs[nx], self.coeffs[ndeg]
        xr.polyval(points, poly).compute()


class PolyvalDask(Polyval):
    """Same benchmarks as ``Polyval`` but with chunked evaluation points."""

    def setup(self, *args, **kwargs):
        # Bail out before doing any work when dask is not available.
        requires_dask()
        super().setup(*args, **kwargs)

        # Replace every evaluation-point array by a chunked version of itself.
        chunked = {}
        for size in nxs:
            chunked[size] = self.xs[size].chunk({"x": 10000})
        self.xs = chunked
7 changes: 7 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ New Features
- Allow passing chunks in ``**kwargs`` form to :py:meth:`Dataset.chunk`, :py:meth:`DataArray.chunk`, and
:py:meth:`Variable.chunk`. (:pull:`6471`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- :py:meth:`xr.polyval` now supports :py:class:`Dataset` and :py:class:`DataArray` args of any shape,
is faster and requires less memory. (:pull:`6548`)
By `Michael Niklas <https://github.com/headtr1ck>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand All @@ -54,6 +57,10 @@ Breaking changes
- Xarray's ufuncs have been removed, now that they can be replaced by numpy's ufuncs in all
supported versions of numpy.
By `Maximilian Roos <https://github.com/max-sixty>`_.
- :py:meth:`xr.polyval` now uses the ``coord`` argument directly instead of its index coordinate.
(:pull:`6548`)
By `Michael Niklas <https://github.com/headtr1ck>`_.


Deprecations
~~~~~~~~~~~~
Expand Down
97 changes: 80 additions & 17 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
Iterable,
Mapping,
Sequence,
overload,
)

import numpy as np

from . import dtypes, duck_array_ops, utils
from .alignment import align, deep_align
from .common import zeros_like
from .duck_array_ops import datetime_to_numeric
from .indexes import Index, filter_indexes_from_coords
from .merge import merge_attrs, merge_coordinates_without_align
from .options import OPTIONS, _get_keep_attrs
Expand Down Expand Up @@ -1843,36 +1846,96 @@ def where(cond, x, y, keep_attrs=None):
)


@overload
def polyval(
    coord: DataArray, coeffs: DataArray, degree_dim: Hashable = "degree"
) -> DataArray:
    ...


@overload
def polyval(
    coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable = "degree"
) -> Dataset:
    ...


@overload
def polyval(
    coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable = "degree"
) -> Dataset:
    ...


def polyval(
    coord: T_Xarray, coeffs: T_Xarray, degree_dim: Hashable = "degree"
) -> T_Xarray:
    """Evaluate a polynomial at specific values

    The values of ``coord`` itself are used (not its index coordinate).
    Degrees that are missing from ``coeffs[degree_dim]`` are treated as
    having a coefficient of zero.

    Parameters
    ----------
    coord : DataArray or Dataset
        Values at which to evaluate the polynomial.
    coeffs : DataArray or Dataset
        Coefficients of the polynomial.
    degree_dim : Hashable, default: "degree"
        Name of the polynomial degree dimension in `coeffs`.

    Returns
    -------
    DataArray or Dataset
        Evaluated polynomial.

    See Also
    --------
    xarray.DataArray.polyfit
    numpy.polynomial.polynomial.polyval
    """
    # Sort ascending so that position -1 along ``degree_dim`` holds the
    # highest degree, whatever order the caller supplied.
    coeffs = coeffs.sortby(degree_dim)
    deg_coord = coeffs[degree_dim]
    max_deg = int(deg_coord[-1])

    # Datetime-like inputs have to be turned into plain numbers first.
    x = _ensure_numeric(coord)

    # using Horner's method
    # https://en.wikipedia.org/wiki/Horner%27s_method
    # Adding ``zeros_like(x)`` broadcasts the leading coefficient against x.
    res = coeffs.isel({degree_dim: -1}, drop=True) + zeros_like(x)
    deg_idx = len(deg_coord) - 2  # position of the next-lower provided degree
    for deg in range(max_deg - 1, -1, -1):
        res *= x
        if deg_idx >= 0 and deg == int(deg_coord[deg_idx]):
            # A coefficient for this degree is provided; otherwise it is
            # assumed to be 0 and nothing needs to be added.
            res += coeffs.isel({degree_dim: deg_idx}, drop=True)
            deg_idx -= 1

    return res


def _ensure_numeric(data: T_Xarray) -> T_Xarray:
    """Converts all datetime64 and timedelta64 variables to float64

    Parameters
    ----------
    data : DataArray or Dataset
        Variables with possible datetime or timedelta dtypes.

    Returns
    -------
    DataArray or Dataset
        Variables with datetime64 dtypes converted to float64 nanoseconds
        since 1970-01-01 and timedelta64 dtypes converted to float64.
        All other variables are returned unchanged.
    """
    from .dataset import Dataset

    def to_floatable(x: DataArray) -> DataArray:
        if x.dtype.kind == "M":
            # datetimes: express as nanoseconds relative to the Unix epoch
            return x.copy(
                data=datetime_to_numeric(
                    x.data,
                    offset=np.datetime64("1970-01-01"),
                    datetime_unit="ns",
                ),
            )
        if x.dtype.kind == "m":
            # timedeltas: these are durations already, so there is no epoch to
            # subtract (a datetime offset cannot be subtracted from a
            # timedelta); a plain cast keeps the count in the array's own unit.
            return x.astype(float)
        return x

    if isinstance(data, Dataset):
        return data.map(to_floatable)
    else:
        return to_floatable(data)


def _calc_idxminmax(
Expand Down
102 changes: 72 additions & 30 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1933,37 +1933,79 @@ def test_where_attrs() -> None:
assert actual.attrs == {}


@pytest.mark.parametrize("use_dask", [False, True])
@pytest.mark.parametrize(
    ["x", "coeffs", "expected"],
    [
        pytest.param(
            xr.DataArray([1, 2, 3], dims="x"),
            xr.DataArray([2, 3, 4], dims="degree"),
            xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"),
            id="simple",
        ),
        pytest.param(
            xr.DataArray([1, 2, 3], dims="x"),
            xr.DataArray([[0, 1], [0, 1]], dims=("y", "degree")),
            xr.DataArray([[1, 2, 3], [1, 2, 3]], dims=("y", "x")),
            id="broadcast-x",
        ),
        pytest.param(
            xr.DataArray([1, 2, 3], dims="x"),
            xr.DataArray([[0, 1], [1, 0], [1, 1]], dims=("x", "degree")),
            xr.DataArray([1, 1, 1 + 3], dims="x"),
            id="shared-dim",
        ),
        pytest.param(
            xr.DataArray([1, 2, 3], dims="x"),
            xr.DataArray([1, 0, 0], dims="degree", coords={"degree": [2, 1, 0]}),
            xr.DataArray([1, 2**2, 3**2], dims="x"),
            id="reordered-index",
        ),
        pytest.param(
            xr.DataArray([1, 2, 3], dims="x"),
            xr.DataArray([5], dims="degree", coords={"degree": [3]}),
            xr.DataArray([5, 5 * 2**3, 5 * 3**3], dims="x"),
            id="sparse-index",
        ),
        pytest.param(
            xr.DataArray([1, 2, 3], dims="x"),
            xr.Dataset({"a": ("degree", [0, 1]), "b": ("degree", [1, 0])}),
            xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [1, 1, 1])}),
            id="array-dataset",
        ),
        pytest.param(
            xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [2, 3, 4])}),
            xr.DataArray([1, 1], dims="degree"),
            xr.Dataset({"a": ("x", [2, 3, 4]), "b": ("x", [3, 4, 5])}),
            id="dataset-array",
        ),
        pytest.param(
            xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [2, 3, 4])}),
            xr.Dataset({"a": ("degree", [0, 1]), "b": ("degree", [1, 1])}),
            xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [3, 4, 5])}),
            id="dataset-dataset",
        ),
        pytest.param(
            xr.DataArray(pd.date_range("1970-01-01", freq="s", periods=3), dims="x"),
            xr.DataArray([0, 1], dims="degree"),
            xr.DataArray(
                [0, 1e9, 2e9],
                dims="x",
                coords={"x": pd.date_range("1970-01-01", freq="s", periods=3)},
            ),
            id="datetime",
        ),
    ],
)
def test_polyval(use_dask, x, coeffs, expected) -> None:
    """Check ``xr.polyval`` against hand-computed results.

    With ``use_dask`` both inputs are chunked and the evaluation itself must
    not trigger a dask computation.
    """
    if use_dask:
        if not has_dask:
            pytest.skip("requires dask")
        coeffs = coeffs.chunk({"degree": 2})
        x = x.chunk({"x": 2})
    with raise_if_dask_computes():
        actual = xr.polyval(x, coeffs)
    xr.testing.assert_allclose(actual, expected)


@pytest.mark.parametrize("use_dask", [False, True])
Expand Down