
Map block reduction #9029

Closed
ernimd opened this issue May 15, 2024 · 2 comments
ernimd commented May 15, 2024

What is your issue?

Is there a way to do reductions over chunked data? I am working with a large dataset and cannot afford to re-chunk along the time dimension. I'd like something similar to a reduce function, but operating over a dataset, applying a function along multiple data variables at the same time. There is a clear example of where the current map_blocks falls short in discussion #5774:

import xarray as xr
import numpy as np

nt = 4
da = xr.DataArray(np.arange(nt), coords={"t": np.arange(nt)}, dims=["t"])


def getsum_da(da, sumdims):
    sumda = da.sum(dim=sumdims, skipna=True)
    return sumda


da.sum(dim="t").compute()  # 6 = 0 + 1 + 2 + 3

result = xr.map_blocks(
    getsum_da, da.chunk(chunks={"t": -1}), args=["t"]
)  # no chunking along summation index
print(result.compute())  # prints 6

result = xr.map_blocks(
    getsum_da, da.chunk(chunks={"t": 1}), args=["t"]
)  # with chunking along summation index
print(result.compute())  # prints 3, the value for the last of 4 chunks
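For this toy example at least, no map_blocks wrapper should be needed: xarray's built-in reductions dispatch to dask's own tree reduction, which handles chunking along the reduced dimension. A minimal sketch reusing the `da` above:

```python
import numpy as np
import xarray as xr

nt = 4
da = xr.DataArray(np.arange(nt), coords={"t": np.arange(nt)}, dims=["t"])

# xarray's native .sum() dispatches to dask's tree reduction, so chunking
# along the summation dimension is handled correctly without map_blocks.
total = int(da.chunk({"t": 1}).sum(dim="t").compute())
print(total)  # 6
```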

In my case I'm doing something more intricate:

import xarray as xr
import numpy as np

N_TS = 10000
N_LAT = 10
N_LON = 10
N_HEIGHT = 10
data = lambda: np.random.rand(N_TS, N_LAT, N_LON, N_HEIGHT)
coords = {
    "time": np.arange(N_TS),
    "latitude": np.arange(N_LAT),
    "longitude": np.arange(N_LON),
    "height": np.arange(N_HEIGHT),
}
dims = ["time", "latitude", "longitude", "height"]
ds = xr.Dataset(
    {
        "WS": (dims, 3 * np.random.randn(N_TS, N_LAT, N_LON, N_HEIGHT) + 6),
        "WD": (dims, np.random.rand(N_TS, N_LAT, N_LON, N_HEIGHT) * 360),
    },
    coords=coords,
)
ds = ds.chunk(
    {
        "time": 1000,
        "latitude": 2,
        "longitude": 3,
        "height": 4,
    }
)

N_WD_SECTORS = 12
N_WS_BINS = 100
WS_BINS = np.linspace(0, 25, N_WS_BINS + 1)
WD_BINS = np.linspace(0, 360, N_WD_SECTORS + 1)


def block_func(obj):
    def _hist(x, y):
        return np.histogram2d(x, y, bins=[WS_BINS, WD_BINS])[0]

    if obj.WS.size == 0:
        return xr.DataArray(
            np.zeros(
                (
                    ds.WS.data.chunksize[1],
                    ds.WS.data.chunksize[2],
                    ds.WS.data.chunksize[3],
                    N_WS_BINS,
                    N_WD_SECTORS,
                )
            ),
            coords={
                "latitude": np.arange(ds.WS.data.chunksize[1]),
                "longitude": np.arange(ds.WS.data.chunksize[2]),
                "height": np.arange(ds.WS.data.chunksize[3]),
                "ws_bins": WS_BINS[1:],
                "wd_bins": WD_BINS[1:],
            },
        )

    res = xr.apply_ufunc(
        _hist,
        obj["WS"],
        obj["WD"],
        input_core_dims=[["time"], ["time"]],
        output_core_dims=[["ws_bins", "wd_bins"]],
        output_dtypes=[np.int32],
        vectorize=True,
        dask="parallelized",
        dask_gufunc_kwargs={
            "output_sizes": {"ws_bins": N_WS_BINS, "wd_bins": N_WD_SECTORS}
        },
    )
    res["ws_bins"] = WS_BINS[1:]
    res["wd_bins"] = WD_BINS[1:]

    return res

res = xr.map_blocks(block_func, ds).compute()
print(res)

I was hoping this would sum the histograms over time, but in reality it simply returns the last chunk's bins. It would be nice to have a function like map_blocks_reduce where a global object would accumulate the result.
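The accumulation asked for here is well defined because 2-D histograms are additive across chunks of the binned dimension. A minimal numpy-only sketch of that property (sample data and bin edges are arbitrary, for illustration only):

```python
import numpy as np

rng = np.random.default_rng(42)
ws = rng.normal(6, 3, 1000)      # stand-in for the WS samples
wd = rng.uniform(0, 360, 1000)   # stand-in for the WD samples
ws_bins = np.linspace(0, 25, 11)
wd_bins = np.linspace(0, 360, 13)

# Histogram of the full series...
full = np.histogram2d(ws, wd, bins=[ws_bins, wd_bins])[0]

# ...equals the sum of per-chunk histograms, which is why a
# map-then-accumulate reduction over time chunks is valid.
parts = sum(
    np.histogram2d(ws[i : i + 100], wd[i : i + 100], bins=[ws_bins, wd_bins])[0]
    for i in range(0, 1000, 100)
)
assert np.array_equal(full, parts)
```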

Thanks!

ernimd added the "needs triage" label May 15, 2024

welcome bot commented May 15, 2024

Thanks for opening your first issue here at xarray! Be sure to follow the issue template!
If you have an idea for a solution, we would really welcome a Pull Request with proposed changes.
See the Contributing Guide for more.
It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better.
Thank you!

dcherian added the "usage question" label and removed the "needs triage" label May 22, 2024
dcherian (Contributor) commented

You'll need to implement your algorithm with dask.array.reduction.

For your specific histogramming problem, see https://flox.readthedocs.io/en/latest/intro.html#histogramming-binning-by-multiple-variables or https://xhistogram.readthedocs.io/en/latest/index.html
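A minimal sketch of the suggested dask.array.reduction approach on the toy sum from above (lambdas and chunk sizes are illustrative; the real histogram case would return per-block histogram arrays from `chunk` and sum them in `aggregate`):

```python
import dask.array as dsa

x = dsa.arange(4, chunks=1)  # four one-element chunks, like da.chunk({"t": 1})

# dask.array.reduction applies `chunk` to every block independently, then
# `aggregate` combines the per-block partials -- essentially the
# "map_blocks_reduce" pattern requested above.
total = dsa.reduction(
    x,
    chunk=lambda b, axis, keepdims: b.sum(axis=axis, keepdims=keepdims),
    aggregate=lambda b, axis, keepdims: b.sum(axis=axis, keepdims=keepdims),
    dtype=x.dtype,
)
result = int(total.compute())
print(result)  # 6
```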
