Fix native resampler not working for some chunk sizes #2291

Merged · 7 commits · Nov 21, 2022
satpy/resample.py: 123 changes (69 additions and 54 deletions)
@@ -155,6 +155,16 @@
from pyresample.ewa import fornav, ll2cr
from pyresample.geometry import SwathDefinition

from satpy.utils import PerformanceWarning

try:
from math import lcm # type: ignore
except ImportError:
def lcm(a, b):
"""Get 'Least Common Multiple' with Python 3.8 compatibility."""
from math import gcd
return abs(a * b) // gcd(a, b)
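
For context (not part of the diff): lcm is used below in _rechunk_if_nonfactor_chunks to pick a replacement chunk size that the aggregation factor divides evenly. A minimal illustration with made-up numbers:

from math import lcm  # Python 3.9+; the fallback defined above behaves the same

lcm(1000, 3)  # 3000: a 1000-row chunk has to grow before a factor of 3 divides it
lcm(1000, 2)  # 1000: already divisible, so the chunk size can stay as it is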

try:
from pyresample.resampler import BaseResampler as PRBaseResampler
except ImportError:
@@ -983,58 +993,6 @@ def resample(self, data, cache_dir=None, mask_area=False, **kwargs):
mask_area=mask_area,
**kwargs)

@staticmethod
def _aggregate(d, y_size, x_size):
"""Average every 4 elements (2x2) in a 2D array."""
if d.ndim != 2:
# we can't guarantee what blocks we are getting and how
# it should be reshaped to do the averaging.
raise ValueError("Can't aggregrate (reduce) data arrays with "
"more than 2 dimensions.")
if not (x_size.is_integer() and y_size.is_integer()):
raise ValueError("Aggregation factors are not integers")
for agg_size, chunks in zip([y_size, x_size], d.chunks):
for chunk_size in chunks:
if chunk_size % agg_size != 0:
raise ValueError("Aggregation requires arrays with "
"shapes and chunks divisible by the "
"factor")

new_chunks = (tuple(int(x / y_size) for x in d.chunks[0]),
tuple(int(x / x_size) for x in d.chunks[1]))
return da.core.map_blocks(_mean, d, y_size, x_size,
meta=np.array((), dtype=d.dtype),
dtype=d.dtype, chunks=new_chunks)

@staticmethod
def _replicate(d_arr, repeats):
"""Repeat data pixels by the per-axis factors specified."""
# rechunk so new chunks are the same size as old chunks
c_size = max(x[0] for x in d_arr.chunks)

def _calc_chunks(c, c_size):
whole_chunks = [c_size] * int(sum(c) // c_size)
remaining = sum(c) - sum(whole_chunks)
if remaining:
whole_chunks += [remaining]
return tuple(whole_chunks)
new_chunks = [_calc_chunks(x, int(c_size // repeats[axis]))
for axis, x in enumerate(d_arr.chunks)]
d_arr = d_arr.rechunk(new_chunks)

repeated_chunks = []
for axis, axis_chunks in enumerate(d_arr.chunks):
factor = repeats[axis]
if not factor.is_integer():
raise ValueError("Expand factor must be a whole number")
repeated_chunks.append(tuple(x * int(factor) for x in axis_chunks))
repeated_chunks = tuple(repeated_chunks)
d_arr = d_arr.map_blocks(_repeat_by_factor,
meta=np.array((), dtype=d_arr.dtype),
dtype=d_arr.dtype,
chunks=repeated_chunks)
return d_arr

@classmethod
def _expand_reduce(cls, d_arr, repeats):
"""Expand reduce."""
@@ -1043,12 +1001,12 @@ def _expand_reduce(cls, d_arr, repeats):
if all(x == 1 for x in repeats.values()):
return d_arr
if all(x >= 1 for x in repeats.values()):
return cls._replicate(d_arr, repeats)
return _replicate(d_arr, repeats)
if all(x <= 1 for x in repeats.values()):
# reduce
y_size = 1. / repeats[0]
x_size = 1. / repeats[1]
return cls._aggregate(d_arr, y_size, x_size)
return _aggregate(d_arr, y_size, x_size)
raise ValueError("Must either expand or reduce in both "
"directions")

@@ -1086,6 +1044,63 @@ def compute(self, data, expand=True, **kwargs):
return update_resampled_coords(data, new_data, target_geo_def)


def _aggregate(d, y_size, x_size):
"""Average every 4 elements (2x2) in a 2D array."""
if d.ndim != 2:
# we can't guarantee what blocks we are getting and how
# it should be reshaped to do the averaging.
raise ValueError("Can't aggregrate (reduce) data arrays with "
"more than 2 dimensions.")
if not (x_size.is_integer() and y_size.is_integer()):
raise ValueError("Aggregation factors are not integers")
y_size = int(y_size)
x_size = int(x_size)
d = _rechunk_if_nonfactor_chunks(d, y_size, x_size)
new_chunks = (tuple(int(x / y_size) for x in d.chunks[0]),
tuple(int(x / x_size) for x in d.chunks[1]))
return da.core.map_blocks(_mean, d, y_size, x_size,
meta=np.array((), dtype=d.dtype),
dtype=d.dtype, chunks=new_chunks)
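
A hypothetical usage sketch of the aggregation path (the array, chunk sizes, and expected output are illustrative assumptions; the module-level _mean helper referenced above is expected to average each y_size by x_size block):

import dask.array as da
import numpy as np

arr = da.from_array(np.arange(16, dtype=float).reshape(4, 4), chunks=2)
reduced = _aggregate(arr, 2.0, 2.0)  # factors arrive as floats from _expand_reduce
reduced.compute()
# array([[ 2.5,  4.5],
#        [10.5, 12.5]])  -- each 2x2 block replaced by its mean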


def _rechunk_if_nonfactor_chunks(dask_arr, y_size, x_size):
need_rechunk = False
new_chunks = list(dask_arr.chunks)
for dim_idx, agg_size in enumerate([y_size, x_size]):
if dask_arr.shape[dim_idx] % agg_size != 0:
raise ValueError("Aggregation requires arrays with shapes divisible by the factor.")
for chunk_size in dask_arr.chunks[dim_idx]:
if chunk_size % agg_size != 0:
need_rechunk = True
new_dim_chunk = lcm(chunk_size, agg_size)
new_chunks[dim_idx] = new_dim_chunk
if need_rechunk:
warnings.warn("Array chunk size is not divisible by aggregation factor. "
"Re-chunking to continue native resampling.", PerformanceWarning)
dask_arr = dask_arr.rechunk(tuple(new_chunks))
return dask_arr
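
A hypothetical illustration of the new re-chunking behaviour (shapes and chunk sizes are made up): a chunk size of 3 is not divisible by an aggregation factor of 2, so the helper re-chunks the array and emits a PerformanceWarning instead of failing as the old code did.

import dask.array as da

arr = da.zeros((4, 4), chunks=3)
arr.chunks                                     # ((3, 1), (3, 1)): 3 is not divisible by 2
fixed = _rechunk_if_nonfactor_chunks(arr, 2, 2)
fixed.chunks                                   # ((2, 2), (2, 2)): every chunk now divisible by 2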


def _replicate(d_arr, repeats):
"""Repeat data pixels by the per-axis factors specified."""
repeated_chunks = _get_replicated_chunk_sizes(d_arr, repeats)
d_arr = d_arr.map_blocks(_repeat_by_factor,
meta=np.array((), dtype=d_arr.dtype),
dtype=d_arr.dtype,
chunks=repeated_chunks)
return d_arr


def _get_replicated_chunk_sizes(d_arr, repeats):
repeated_chunks = []
for axis, axis_chunks in enumerate(d_arr.chunks):
factor = repeats[axis]
if not factor.is_integer():
raise ValueError("Expand factor must be a whole number")
repeated_chunks.append(tuple(x * int(factor) for x in axis_chunks))
return tuple(repeated_chunks)
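
And a small sketch of the chunk bookkeeping on the expansion path (values assumed; _repeat_by_factor, referenced above but not shown in this diff, is expected to repeat each pixel by the per-axis factor):

import dask.array as da

arr = da.zeros((4, 4), chunks=2)
_get_replicated_chunk_sizes(arr, {0: 2.0, 1: 2.0})  # ((4, 4), (4, 4)): every chunk doubled
_replicate(arr, {0: 2.0, 1: 2.0}).shape             # (8, 8): each pixel repeated 2x2
_get_replicated_chunk_sizes(arr, {0: 1.5, 1: 1.5})  # raises ValueError: factor must be a whole number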


def _get_arg_to_pass_for_skipna_handling(**kwargs):
"""Determine if skipna can be passed to the compute functions for the average and sum bucket resampler."""
# FIXME this can be removed once Pyresample 1.18.0 is a Satpy requirement