Cherry-pick: Fix usage of map_blocks in AreaWeighted and elsewhere (SciTools#5767)

* fix usage of map_blocks

* fix map_blocks for non-lazy data

* add benchmark

* unskip benchmark

* add benchmark

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove benchmarks

* remove unnecessary import

* What's New entry.

* map_complete_blocks docstring.

* map_complete_blocks returns.

* Typo.

* Typo.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Martin Yeo <martin.yeo@metoffice.gov.uk>
3 people committed Mar 4, 2024
1 parent b5a754e commit 5876627
Showing 4 changed files with 46 additions and 29 deletions.
47 changes: 34 additions & 13 deletions lib/iris/_lazy_data.py
@@ -450,10 +450,11 @@ def lazy_elementwise(lazy_array, elementwise_op):
     return da.map_blocks(elementwise_op, lazy_array, dtype=dtype)


-def map_complete_blocks(src, func, dims, out_sizes):
+def map_complete_blocks(src, func, dims, out_sizes, *args, **kwargs):
     """Apply a function to complete blocks.

     Complete means that the data is not chunked along the chosen dimensions.
+    Uses :func:`dask.array.map_blocks` to implement the mapping.

     Parameters
     ----------
@@ -465,27 +466,47 @@ def map_complete_blocks(src, func, dims, out_sizes):
         Dimensions that cannot be chunked.
     out_sizes : tuple of int
         Output size of dimensions that cannot be chunked.
+    *args : tuple
+        Additional arguments to pass to `func`.
+    **kwargs : dict
+        Additional keyword arguments to pass to `func`.
+
+    Returns
+    -------
+    Array-like
+
+    See Also
+    --------
+    :func:`dask.array.map_blocks` : The function used for the mapping.

     """
+    data = None
+    result = None
+
     if is_lazy_data(src):
         data = src
     elif not hasattr(src, "has_lazy_data"):
         # Not a lazy array and not a cube. So treat as ordinary numpy array.
-        return func(src)
+        result = func(src, *args, **kwargs)
     elif not src.has_lazy_data():
-        return func(src.data)
+        result = func(src.data, *args, **kwargs)
     else:
         data = src.lazy_data()

-    # Ensure dims are not chunked
-    in_chunks = list(data.chunks)
-    for dim in dims:
-        in_chunks[dim] = src.shape[dim]
-    data = data.rechunk(in_chunks)
+    if result is None and data is not None:
+        # Ensure dims are not chunked
+        in_chunks = list(data.chunks)
+        for dim in dims:
+            in_chunks[dim] = src.shape[dim]
+        data = data.rechunk(in_chunks)

-    # Determine output chunks
-    out_chunks = list(data.chunks)
-    for dim, size in zip(dims, out_sizes):
-        out_chunks[dim] = size
+        # Determine output chunks
+        out_chunks = list(data.chunks)
+        for dim, size in zip(dims, out_sizes):
+            out_chunks[dim] = size

-    return data.map_blocks(func, chunks=out_chunks, dtype=src.dtype)
+        result = data.map_blocks(
+            func, *args, chunks=out_chunks, dtype=src.dtype, **kwargs
+        )
+
+    return result
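A minimal runnable sketch of the updated helper (assuming iris and dask are installed; the scale function is illustrative, not part of iris). It shows the two behaviours this commit touches: keyword arguments are forwarded to func through dask.array.map_blocks for lazy input, and are now also forwarded for plain (non-lazy) numpy input, which the old code could not do at all (callers had to pre-bind them with functools.partial).

import dask.array as da
import numpy as np

from iris._lazy_data import map_complete_blocks


def scale(block, factor=1.0):
    # Toy block function; dimension 1 is treated as the "complete" dimension.
    return block * factor


lazy = da.ones((4, 6), chunks=(2, 6))
plain = np.ones((4, 6))

# The factor keyword is forwarded to ``scale`` in both cases.
lazy_result = map_complete_blocks(lazy, scale, (1,), (6,), factor=3.0)
plain_result = map_complete_blocks(plain, scale, (1,), (6,), factor=3.0)

print(lazy_result.compute()[0, 0], plain_result[0, 0])  # 3.0 3.0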
10 changes: 4 additions & 6 deletions lib/iris/analysis/__init__.py
@@ -1378,18 +1378,16 @@ def _percentile(data, percent, fast_percentile_method=False, **kwargs):
         percent = [percent]
     percent = np.array(percent)

-    # Perform the percentile calculation.
-    _partial_percentile = functools.partial(
+    result = iris._lazy_data.map_complete_blocks(
+        data,
         _calc_percentile,
+        (-1,),
+        percent.shape,
         percent=percent,
         fast_percentile_method=fast_percentile_method,
         **kwargs,
     )

-    result = iris._lazy_data.map_complete_blocks(
-        data, _partial_percentile, (-1,), percent.shape
-    )
-
     # Check whether to reduce to a scalar result, as per the behaviour
     # of other aggregators.
     if result.shape == (1,):
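An illustrative sketch of the new calling convention used by _percentile above: the percentile keyword is handed straight to map_complete_blocks instead of being baked in with functools.partial. percentiles_last_axis is a hypothetical stand-in for the private _calc_percentile helper; iris and dask are assumed to be installed.

import dask.array as da
import numpy as np

from iris._lazy_data import map_complete_blocks


def percentiles_last_axis(data, percent=(50.0,)):
    # Collapse the last axis to one value per requested percentile.
    return np.moveaxis(np.percentile(data, percent, axis=-1), 0, -1)


lazy = da.random.random((3, 4, 100), chunks=(1, 4, 100))
result = map_complete_blocks(
    lazy,
    percentiles_last_axis,
    (-1,),  # the last dimension must stay unchunked
    (2,),   # it collapses to two percentile values
    percent=(25.0, 75.0),
)
print(result.shape)  # (3, 4, 2)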
10 changes: 4 additions & 6 deletions lib/iris/analysis/_area_weighted.py
@@ -392,20 +392,18 @@ def _regrid_area_weighted_rectilinear_src_and_grid__perform(

     tgt_shape = (len(grid_y.points), len(grid_x.points))

-    # Calculate new data array for regridded cube.
-    regrid = functools.partial(
+    new_data = map_complete_blocks(
+        src_cube,
         _regrid_along_dims,
+        (src_y_dim, src_x_dim),
+        meshgrid_x.shape,
         x_dim=src_x_dim,
         y_dim=src_y_dim,
         weights=weights,
         tgt_shape=tgt_shape,
         mdtol=mdtol,
     )

-    new_data = map_complete_blocks(
-        src_cube, regrid, (src_y_dim, src_x_dim), meshgrid_x.shape
-    )
-
     # Wrap up the data as a Cube.

     _regrid_callback = functools.partial(
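The change to _area_weighted.py above (and to _regrid.py below) is only a change of calling convention. A small runnable check, with a toy scale_and_offset function standing in for the real _regrid_along_dims, that the old functools.partial form and the new forwarded-keyword form give identical results:

import functools

import dask.array as da
import numpy as np

from iris._lazy_data import map_complete_blocks


def scale_and_offset(block, scale=1.0, offset=0.0):
    return block * scale + offset


lazy = da.ones((4, 3), chunks=(2, 3))

# Old pattern: bind the keyword arguments before handing the function over.
old_style = map_complete_blocks(
    lazy, functools.partial(scale_and_offset, scale=2.0, offset=1.0), (1,), (3,)
)
# New pattern: let map_complete_blocks forward the keyword arguments.
new_style = map_complete_blocks(
    lazy, scale_and_offset, (1,), (3,), scale=2.0, offset=1.0
)

assert np.array_equal(old_style.compute(), new_style.compute())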
8 changes: 4 additions & 4 deletions lib/iris/analysis/_regrid.py
@@ -932,9 +932,11 @@ def __call__(self, src):
         x_dim = src.coord_dims(src_x_coord)[0]
         y_dim = src.coord_dims(src_y_coord)[0]

-        # Define regrid function
-        regrid = functools.partial(
+        data = map_complete_blocks(
+            src,
             self._regrid,
+            (y_dim, x_dim),
+            sample_grid_x.shape,
             x_dim=x_dim,
             y_dim=y_dim,
             src_x_coord=src_x_coord,
@@ -945,8 +947,6 @@ def __call__(self, src):
             extrapolation_mode=self._extrapolation_mode,
         )

-        data = map_complete_blocks(src, regrid, (y_dim, x_dim), sample_grid_x.shape)
-
         # Wrap up the data as a Cube.
         _regrid_callback = functools.partial(
             self._regrid,
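For reference, the dask behaviour the reworked map_complete_blocks relies on: dask.array.map_blocks passes extra positional and keyword arguments through to the mapped function, so no partial binding is needed at the dask level either. A tiny dask-only sketch:

import dask.array as da
import numpy as np


def add_offset(block, offset=0.0):
    return block + offset


x = da.zeros((4, 4), chunks=(2, 4))
y = x.map_blocks(add_offset, offset=5.0, dtype=x.dtype)
print(np.unique(y.compute()))  # [5.]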
