Non lazy behavior for weighted average when using resampled data #4625

Closed
jbusecke opened this issue Nov 30, 2020 · 13 comments · Fixed by #4668

@jbusecke
Contributor

I am trying to apply an averaging function to multi-year chunks of monthly model data. At its core the function performs a weighted average (and then some coordinate manipulations). I am using resample(time='1AS') and then try to map my custom function onto the data (see example below). Even without actually loading the data, this step takes prohibitively long in my workflow (20-30 minutes depending on the model).
Is there a way to apply this step completely lazily, as in the case where a simple non-weighted .mean() is used?

from dask.diagnostics import ProgressBar
import xarray as xr
import numpy as np

# simple customized weighted mean function
def mean_func(ds):
    return ds.weighted(ds.weights).mean('time')

# example dataset
t = xr.cftime_range(start='2000', periods=1000, freq='1AS')
weights = xr.DataArray(np.random.rand(len(t)), dims=['time'], coords={'time': t})
data = xr.DataArray(np.random.rand(len(t)), dims=['time'], coords={'time': t, 'weights': weights})
ds = xr.Dataset({'data': data}).chunk({'time': 1})
ds

[image: HTML repr of the chunked example dataset]

Using resample with a simple mean works without any computation being triggered:

with ProgressBar():
    ds.resample(time='3AS').mean('time')

But when I do the same step with my custom function, some computations show up:

with ProgressBar():
    ds.resample(time='3AS').map(mean_func)
[########################################] | 100% Completed |  0.1s
[########################################] | 100% Completed |  0.1s
[########################################] | 100% Completed |  0.1s
[########################################] | 100% Completed |  0.1s

I am quite sure these are the same kind of computations that make my real-world workflow so slow.

I also confirmed that this does not happen when I do not use resample first:

with ProgressBar():
    mean_func(ds)

This does not trigger a computation either, so the problem must somehow be related to resample. I would be happy to dig deeper into this if somebody with more knowledge could point me to the right place.
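
For reference, one way to count unexpected computes is a custom dask scheduler, modeled on the CountingScheduler that xarray's own test suite uses (visible in the traceback further down this thread); the class here is an illustrative sketch, not xarray API:

import dask

class CountingScheduler:
    # counts how often dask is asked to compute a graph
    def __init__(self):
        self.total_computes = 0

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        return dask.get(dsk, keys, **kwargs)

scheduler = CountingScheduler()
with dask.config.set(scheduler=scheduler):
    ds.resample(time='3AS').map(mean_func)
print('computes triggered:', scheduler.total_computes)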

Environment:

Output of xr.show_versions()

INSTALLED VERSIONS

commit: None
python: 3.8.6 | packaged by conda-forge | (default, Oct 7 2020, 19:08:05)
[GCC 7.5.0]
python-bits: 64
OS: Linux
OS-release: 3.10.0-1160.2.2.el7.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.6
libnetcdf: 4.7.4

xarray: 0.16.2.dev77+g1a4f7bd
pandas: 1.1.3
numpy: 1.19.2
scipy: 1.5.2
netCDF4: 1.5.4
pydap: None
h5netcdf: 0.8.1
h5py: 2.10.0
Nio: None
zarr: 2.4.0
cftime: 1.2.1
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: 1.1.3
cfgrib: None
iris: None
bottleneck: None
dask: 2.30.0
distributed: 2.30.0
matplotlib: 3.3.2
cartopy: 0.18.0
seaborn: None
numbagg: None
pint: 0.16.1
setuptools: 49.6.0.post20201009
pip: 20.2.4
conda: None
pytest: 6.1.2
IPython: 7.18.1
sphinx: None

@mathause
Collaborator

I fear it's the weight check 🤦; try commenting out lines 105 to 121 of weighted.py:

def _weight_check(w):
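
For context, a paraphrased sketch of that check as it looked after #4559 (the exact wording in weighted.py may differ); for dask-backed weights it is mapped over every block, which rewrites the dask graph:

def _weight_check(w):
    # raise if any weight is missing; NaN-weighted means are ill-defined
    if duck_array_ops.isnull(w).any():
        raise ValueError(
            "`weights` cannot contain missing values. "
            "Missing values can be replaced by `weights.fillna(0)`."
        )
    return w

if is_duck_dask_array(weights.data):
    # assign to a copy so the check runs when the blocks are computed
    weights = weights.copy(
        data=weights.data.map_blocks(_weight_check, dtype=weights.dtype)
    )
else:
    _weight_check(weights.data)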

@jbusecke
Contributor Author

Oh nooo. So would you suggest that in addition to #4559, we should have a kwarg to completely skip this?

@keewis
Collaborator

keewis commented Nov 30, 2020

the issue seems to be just this:

weights = weights.copy(
data=weights.data.map_blocks(_weight_check, dtype=weights.dtype)
)

Also, the computation is still triggered, even if we remove the map_blocks call:

    weights = weights.copy(data=weights.data)

not sure why, though.

@dcherian
Contributor

dcherian commented Nov 30, 2020

The weighted fix in #4559 is correct; that's why

with ProgressBar():
    mean_func(ds)

does not compute.

This is more instructive:

from xarray.tests import raise_if_dask_computes

with raise_if_dask_computes():
    ds.resample(time='3AS').map(mean_func)
....
    150 
    151     def _sum_of_weights(

~/work/python/xarray/xarray/core/computation.py in dot(dims, *arrays, **kwargs)
   1483         output_core_dims=output_core_dims,
   1484         join=join,
-> 1485         dask="allowed",
   1486     )
   1487     return result.transpose(*[d for d in all_dims if d in result.dims])

~/work/python/xarray/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1132             join=join,
   1133             exclude_dims=exclude_dims,
-> 1134             keep_attrs=keep_attrs,
   1135         )
   1136     # feed Variables directly through apply_variable_ufunc

~/work/python/xarray/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    266     else:
    267         name = result_name(args)
--> 268     result_coords = build_output_coords(args, signature, exclude_dims)
    269 
    270     data_vars = [getattr(a, "variable", a) for a in args]

~/work/python/xarray/xarray/core/computation.py in build_output_coords(args, signature, exclude_dims)
    231         # TODO: save these merged indexes, instead of re-computing them later
    232         merged_vars, unused_indexes = merge_coordinates_without_align(
--> 233             coords_list, exclude_dims=exclude_dims
    234         )
    235 

~/work/python/xarray/xarray/core/merge.py in merge_coordinates_without_align(objects, prioritized, exclude_dims)
    327         filtered = collected
    328 
--> 329     return merge_collected(filtered, prioritized)
    330 
    331 

~/work/python/xarray/xarray/core/merge.py in merge_collected(grouped, prioritized, compat)
    227                 variables = [variable for variable, _ in elements_list]
    228                 try:
--> 229                     merged_vars[name] = unique_variable(name, variables, compat)
    230                 except MergeError:
    231                     if compat != "minimal":

~/work/python/xarray/xarray/core/merge.py in unique_variable(name, variables, compat, equals)
    132         if equals is None:
    133             # now compare values with minimum number of computes
--> 134             out = out.compute()
    135             for var in variables[1:]:
    136                 equals = getattr(out, compat)(var)

~/work/python/xarray/xarray/core/variable.py in compute(self, **kwargs)
    459         """
    460         new = self.copy(deep=False)
--> 461         return new.load(**kwargs)
    462 
    463     def __dask_tokenize__(self):

~/work/python/xarray/xarray/core/variable.py in load(self, **kwargs)
    435         """
    436         if is_duck_dask_array(self._data):
--> 437             self._data = as_compatible_data(self._data.compute(**kwargs))
    438         elif not is_duck_array(self._data):
    439             self._data = np.asarray(self._data)

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    165         dask.base.compute
    166         """
--> 167         (result,) = compute(self, traverse=False, **kwargs)
    168         return result
    169 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    450         postcomputes.append(x.__dask_postcompute__())
    451 
--> 452     results = schedule(dsk, keys, **kwargs)
    453     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    454 

~/work/python/xarray/xarray/tests/__init__.py in __call__(self, dsk, keys, **kwargs)
    112             raise RuntimeError(
    113                 "Too many computes. Total: %d > max: %d."
--> 114                 % (self.total_computes, self.max_computes)
    115             )
    116         return dask.get(dsk, keys, **kwargs)

RuntimeError: Too many computes. Total: 1 > max: 0.

It looks like we're repeatedly checking weights for equality (if you navigate to merge_collected in the stack, name = "weights"). The lazy_array_equiv check is failing because a copy is made somewhere.

ipdb>  up

> /home/deepak/work/python/xarray/xarray/core/merge.py(229)merge_collected()
    227                 variables = [variable for variable, _ in elements_list]
    228                 try:
--> 229                     merged_vars[name] = unique_variable(name, variables, compat)
    230                 except MergeError:
    231                     if compat != "minimal":

ipdb>  name

'weights'

ipdb>  variables

[<xarray.Variable (time: 1)>
dask.array<getitem, shape=(1,), dtype=float64, chunksize=(1,), chunktype=numpy.ndarray>, <xarray.Variable (time: 1)>
dask.array<copy, shape=(1,), dtype=float64, chunksize=(1,), chunktype=numpy.ndarray>]

ipdb>  variables[0].data.name

'getitem-2a74b8ca20ae20100597e397404ba17b'

ipdb>  variables[1].data.name

'copy-fff901a87f4a2293c750766c554aa68d'
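
A small demonstration of why differing graph keys force a compute; this assumes xarray's internal lazy_array_equiv helper (location and semantics taken from the codebase around this version):

import dask.array as da
from xarray.core.duck_array_ops import lazy_array_equiv

a = da.random.random((4,), chunks=1)
b = a.map_blocks(lambda x: x)  # same values, but a new graph key

# identical objects/names compare cheaply; differing names return None
# ("cannot tell without computing"), so merge falls back to comparing
# actual values, which is the compute caught above
print(lazy_array_equiv(a, a))  # True
print(lazy_array_equiv(a, b))  # None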

@dcherian
Contributor

dcherian commented Dec 1, 2020

Ah this works (but we lose weights as a coord var).

# simple customized weighted mean function
def mean_func(ds):
    return ds.weighted(ds.weights.reset_coords(drop=True)).mean('time')

Adding reset_coords fixes this because it gets rid of the non-dim coord weights.

return dot(da, weights, dims=dim)

dot compares the weights coord var on ds and weights to decide if it should keep it.

The new call to .copy ends up making a copy of the weights coord on the weights DataArray, so the lazy equality check fails. One solution is to avoid the call to copy and create the DataArray directly:

enc = weights.encoding
weights = DataArray(
   weights.data.map_blocks(_weight_check, dtype=weights.dtype),
   dims=weights.dims,
   coords=weights.coords,
   attrs=weights.attrs
)
weights.encoding = enc

This works locally.

@jbusecke
Contributor Author

jbusecke commented Dec 1, 2020

Sweet. I'll try to apply this fix to my workflow now. Happy to submit a PR with the suggested changes to weighted.py too.

@dcherian
Contributor

dcherian commented Dec 1, 2020

PRs are always welcome!

@dcherian
Contributor

dcherian commented Dec 1, 2020

Untested, but specifying deep=False in the call to copy should fix it:
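
That is, combining the snippet quoted above with deep=False (a sketch; see #4668 for the actual diff):

# a shallow copy avoids copying the non-dim coords,
# which is what broke the lazy equality check
weights = weights.copy(
    data=weights.data.map_blocks(_weight_check, dtype=weights.dtype),
    deep=False,
)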

@jbusecke
Contributor Author

jbusecke commented Dec 1, 2020

Do you have a suggestion for how to test this? Should I write a test involving resample + weighted?

@dcherian
Contributor

dcherian commented Dec 1, 2020

Yes, something like what you have, with

with raise_if_dask_computes():
    ds.resample(time='3AS').map(mean_func)
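
Expanded into a self-contained test it might look like this (names and structure are illustrative; the actual test landed in #4668):

import numpy as np
import xarray as xr
from xarray.tests import raise_if_dask_computes

def test_weighted_resample_is_lazy():
    t = xr.cftime_range(start='2000', periods=4, freq='1AS')
    weights = xr.DataArray(np.random.rand(len(t)), dims=['time'], coords={'time': t})
    data = xr.DataArray(
        np.random.rand(len(t)), dims=['time'], coords={'time': t, 'weights': weights}
    )
    ds = xr.Dataset({'data': data}).chunk({'time': 1})

    # building the resampled weighted mean should not compute anything
    with raise_if_dask_computes():
        ds.resample(time='3AS').map(lambda d: d.weighted(d.weights).mean('time'))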

BUT something is wrong with my explanation above. The error is only triggered when the number of timesteps is not divisible by the resampling frequency. If you set periods=3 when creating t, the old version works fine; if you change it to 4, it computes. But setting deep=False fixes it in all cases. I am v. confused!

@jbusecke
Contributor Author

jbusecke commented Dec 1, 2020

Oh I remember that too, and I didn't understand it at all...

@jbusecke
Contributor Author

jbusecke commented Dec 9, 2020

So I have added a test in #4668, and it confirms that this behavior only occurs if the resample interval is smaller than or equal to the chunk size.
If the resample interval is larger than the chunks, it stays completely lazy... not sure if this is a general limitation. Does anyone have more insight into how resample handles this kind of workflow?

@jbusecke
Contributor Author

jbusecke commented Dec 9, 2020

As @dcherian pointed out above, copy(..., deep=False) does fix this in all cases I am testing.
