
Minimize duplication in map_blocks task graph #8412

Merged · 23 commits into pydata:main · Jan 3, 2024

Conversation

@dcherian (Contributor) commented Nov 3, 2023

Builds on #8560

cc @max-sixty

print(len(cloudpickle.dumps(da.chunk(lat=1, lon=1).map_blocks(lambda x: x))))
# graph size in bytes: 779354739 on main -> 47699827 with this PR
print(len(cloudpickle.dumps(da.chunk(lat=1, lon=1).drop_vars(da.indexes).map_blocks(lambda x: x))))
# 15981508 with the indexes dropped entirely

This is a quick attempt. I think we can generalize this to minimize duplication.

The downside is that the graphs are not totally embarrassingly parallel any more.
This PR:
(task graph visualization)

vs main:
(task graph visualization)
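
(For illustration only, a minimal sketch in raw dask-graph terms, not xarray's actual code, of the topology difference the two graphs show: on main every block task embeds its own copy of the sliced index data, while with this PR that data lives in its own tasks and is referenced by key.)

import numpy as np
from dask.threaded import get

index_shard = np.arange(3)  # stand-in for a sliced/filtered index

# main-style graph: each block task carries its own copy of the shard,
# so the graph is embarrassingly parallel but the data is duplicated
duplicated = {
    ("block", i): (lambda block, idx: block + idx.sum(), i, index_shard.copy())
    for i in range(4)
}

# PR-style graph: the shard is a task of its own and block tasks reference it
# by key, so it appears only once but the blocks now share a dependency
shared = {"index-shard": index_shard}
shared.update(
    {("block", i): (lambda block, idx: block + idx.sum(), i, "index-shard")
     for i in range(4)}
)

print(get(shared, [("block", i) for i in range(4)]))  # dask substitutes "index-shard" by key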

@max-sixty (Collaborator) commented:

Thanks a lot @dcherian !

(I don't have enough context to know how severe the change to the parallelism is. I really do appreciate that .map_blocks is simple in concept, and gets around Dask tripping over itself by just making an opaque function and running it lots of times. Possibly we could do the index filtering locally, which would need much more setup time, but retain .map_blocks' simplicity...)

@dcherian (Contributor, Author) commented Nov 4, 2023

> Possibly we could do the index filtering locally, which would need much more setup time,

We do filter the indexes. The problem is that the filtered index values are duplicated a very large number of times for the calculation. The duplication allows the graph to be embarrassingly parallel.

And then we include them a second time to enable nice error messages.
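
(A tiny demonstration of that duplication cost, independent of xarray: when every task carries its own copy of the same values, the serialized graph grows with the number of tasks.)

import cloudpickle
import numpy as np

shard = np.arange(1_000, dtype="float64")  # ~8 kB of index values

one_task = {("block", 0): (np.sum, shard.copy())}
many_tasks = {("block", i): (np.sum, shard.copy()) for i in range(100)}

# the second graph is roughly 100x larger, because each task embeds its own
# copy of the same values instead of referencing a single shared task
print(len(cloudpickle.dumps(one_task)), len(cloudpickle.dumps(many_tasks)))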

@max-sixty (Collaborator) commented:

> We do filter the indexes. The problem is that the filtered index values are duplicated a very large number of times for the calculation. The duplication allows the graph to be embarrassingly parallel.

Ah right, yes. I confirmed that — the size difference scales by n_blocks, not n_blocks*index_size, so it must be filtering:


da = xr.tutorial.load_dataset('air_temperature').isel(lat=slice(25 // 2), lon=slice(53 //2))

[ins] In [9]: len(cloudpickle.dumps(da.chunk(lat=1, lon=1).map_blocks(lambda x: x)))
Out[9]: 18688240

[ins] In [10]: len(cloudpickle.dumps(da.chunk(lat=1, lon=1).drop_vars(da.indexes).map_blocks(lambda x: x)))
Out[10]: 3766137

Defer to you on how this affects dask stability...

@dcherian (Contributor, Author) commented:

@fjetter do you think dask/distributed will handle the change in graph topology in the OP gracefully? map_blocks seems to have decent use in the wild to work around dask scheduling issues, so it would be nice to not break that. Alternatively, is there a better way to scatter out the duplicated data?

@dcherian (Contributor, Author) commented:

FWIW this graph seems to be what blockwise constructs for broadcasting:

dask.array.blockwise(
    lambda x,y: x+y,
    'ij',
    dask.array.ones((3,), chunks=(1,)),
    'i',
    dask.array.ones((5,), chunks=(1,)),
    'j',
).visualize()

(task graph visualization)
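
(For anyone following along, a quick way to confirm the sharing in that blockwise graph; the dsk / dependents names below are just local to this snippet.)

from collections import Counter

import dask.array
from dask.core import get_dependencies

out = dask.array.blockwise(
    lambda x, y: x + y, "ij",
    dask.array.ones((3,), chunks=(1,)), "i",
    dask.array.ones((5,), chunks=(1,)), "j",
)
dsk = dict(out.__dask_graph__())

# count how many tasks depend on each key in the materialized graph
dependents = Counter(
    dep for key in dsk for dep in get_dependencies(dsk, key)
)

# each block of the "i" input feeds 5 output tasks and each block of the "j"
# input feeds 3: the inputs are shared between tasks rather than copied into them
print(dependents)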

@fjetter commented Dec 18, 2023

> map_blocks seems to have decent use in the wild to work around dask scheduling issues, so it would be nice to not break that.

If the point of "working around scheduling issues" is to forcefully flatten the graph into a purely embarrassingly parallel workload, that property is now gone, but I believe you are still fine.

I am not super familiar with xarray datasets, so I am doing a bit of guesswork here. IIUC this example dataset has three coordinates / indices (lat, lon, time), which are numpy arrays (always?) that are known to the client, i.e. known at graph construction time. IIUC the issue being fixed here is that these arrays are being duplicated?

Then there is also the air data variable, which is the actual payload. In this example it is also a numpy array, but in a realistic case it would come from remote storage, e.g. a zarr store. We want to release these tasks ASAP.

If this is all correct, then yes, this is handled gracefully by dask (at least with the latest release, haven't checked older ones)

import xarray as xr
from dask.utils import key_split
from dask.order import diagnostics
from dask.base import collections_to_dsk
da = xr.tutorial.load_dataset('air_temperature')

dsk = collections_to_dsk([da.chunk(lat=1, lon=1).map_blocks(lambda x: x)])
diag, _ = diagnostics(dsk)
ages_data_tasks = [
    v.age == 1
    for k, v in diag.items()
    if key_split(k).startswith('xarray-air')
]
assert ages_data_tasks
assert all(ages_data_tasks)

Age refers to the number of "ticks" / time steps a task survives. age=1 means that once data is "produced", i.e. the task is scheduled, its consumer is scheduled right afterwards, so the data is released after one time step.
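
(Continuing the snippet above, a one-liner to eyeball the full age distribution rather than just the data tasks.)

from collections import Counter

# histogram of ages over every task in the graph; small ages mean
# intermediate results are released almost immediately
print(Counter(v.age for v in diag.values()))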

> Alternatively, is there a better way to scatter out the duplicated data?

If those indices are truly always numpy arrays, I would probably suggest just slicing them to whatever size each task needs and embedding them, keeping the embarrassingly parallel workload. I think I do not understand this problem sufficiently; it feels like I'm missing something.
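
(A sketch of that "slice and embed" idea in plain dask-graph terms, not xarray's actual machinery: each task gets only the piece of the index it needs, so the graph stays flat, at the cost of building and embedding the slices at graph-construction time.)

import numpy as np
from dask.threaded import get

lat = np.linspace(-90, 90, 8)                             # an index known on the client
block_slices = [slice(i, i + 2) for i in range(0, 8, 2)]  # one slice per block

# every task embeds just its own slice of the index; there are no shared
# dependencies, so the workload stays embarrassingly parallel
dsk = {
    ("block", i): (lambda idx: idx.mean(), lat[sl].copy())
    for i, sl in enumerate(block_slices)
}

print(get(dsk, [("block", i) for i in range(len(block_slices))]))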

@dcherian (Contributor, Author) commented:

> I think I do not understand this problem sufficiently; it feels like I'm missing something.

Broadcasting means that the tiny shards get duplicated a very large number of times in the graph. The OP was prompted by a 1GB task graph.
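
(A rough back-of-the-envelope with entirely hypothetical sizes, only to show how this kind of duplication reaches the GB range.)

# hypothetical numbers: a coordinate of 50_000 float64 values that is not
# chunked along, copied into every one of 2_500 block tasks
coordinate_bytes = 50_000 * 8
n_blocks = 2_500
print(coordinate_bytes * n_blocks / 1e9, "GB of duplicated coordinate data")  # 1.0 GB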

@dcherian added the "plan to merge (Final call for comments)" label on Dec 20, 2023
dcherian and others added 4 commits December 20, 2023 10:28
* main:
  Adapt map_blocks to use new Coordinates API (pydata#8560)
  add xeofs to ecosystem.rst (pydata#8561)
  Offer a fixture for unifying DataArray & Dataset tests (pydata#8533)
  Generalize cumulative reduction (scan) to non-dask types (pydata#8019)
@dcherian merged commit d87ba61 into pydata:main on Jan 3, 2024 (25 of 27 checks passed)
@dcherian deleted the map-blocks-indexes-fix branch on January 3, 2024
dcherian added a commit to dcherian/xarray that referenced this pull request Jan 4, 2024
* upstream/main:
  Faster encoding functions. (pydata#8565)
  ENH: vendor SerializableLock from dask and use as default backend lock, adapt tests (pydata#8571)
  Silence a bunch of CachingFileManager warnings (pydata#8584)
  Bump actions/download-artifact from 3 to 4 (pydata#8556)
  Minimize duplication in `map_blocks` task graph (pydata#8412)
  [pre-commit.ci] pre-commit autoupdate (pydata#8578)
  ignore a `DeprecationWarning` emitted by `seaborn` (pydata#8576)
  Fix mypy type ignore (pydata#8564)
  Support for the new compression arguments. (pydata#7551)
  FIX: reverse index output of bottleneck move_argmax/move_argmin functions (pydata#8552)
Labels: plan to merge (Final call for comments), topic-dask

Merging this pull request may close: Task graphs on .map_blocks with many chunks can be huge

3 participants