In [14]:
# Copyright 2020 Google LLC.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pandas as pd
import xarray
import numpy_groupies

def _binned_agg(
    array: np.ndarray,
    indices: np.ndarray,
    num_bins: int,
    *,
    func,
    fill_value,
    dtype,
) -> np.ndarray:
    """NumPy helper function for aggregating over bins.

    Args:
      array: values to aggregate. Its trailing ``indices.ndim`` axes are the
        ones being binned (``xarray.apply_ufunc`` moves core dims to the end).
      indices: bin index for each element of the trailing axes of ``array``.
        Negative entries (pandas categorical codes use -1 for "no bin") and
        NaN entries are skipped.
      num_bins: number of output bins.
      func: aggregation forwarded to ``numpy_groupies.aggregate`` (e.g. 'sum').
      fill_value: value used for bins that receive no elements.
      dtype: output dtype forwarded to ``numpy_groupies.aggregate``.

    Returns:
      Array whose binned trailing axes are replaced by one axis of length
      ``num_bins``.
    """
    # numpy_groupies rejects negative group indices, so entries coded -1 must
    # be dropped here. NaN also compares False, which covers float indices.
    with np.errstate(invalid='ignore'):
        mask = indices >= 0
    int_indices = indices[mask].astype(int)
    return numpy_groupies.aggregate(
        int_indices,
        array[..., mask],
        func=func,
        size=num_bins,
        fill_value=fill_value,
        dtype=dtype,
        axis=-1,
    )

def groupby_bins_agg(
    array: xarray.DataArray,
    group: xarray.DataArray,
    bins,
    func='sum',
    fill_value=0,
    dtype=None,
    **cut_kwargs,
) -> xarray.DataArray:
    """Faster equivalent of Xarray's groupby_bins(...).sum().

    Other reductions are selected with ``func``.

    Args:
      array: values to aggregate; must contain all of ``group``'s dimensions.
      group: named DataArray whose values are binned with ``pd.cut``. Its
        dimensions are the ones reduced over.
      bins: bin specification forwarded to ``pd.cut``.
      func: name of a numpy_groupies aggregation, e.g. 'sum' or 'mean'.
      fill_value: value used for bins that receive no elements.
      dtype: optional output dtype forwarded to numpy_groupies.
      **cut_kwargs: extra keyword arguments forwarded to ``pd.cut``.

    Returns:
      DataArray where ``group``'s dimensions are replaced by a new dimension
      named ``group.name + "_bins"``. Its coordinate is the ``pd.cut``
      categories (the bin intervals by default).
    """
    # TODO: implement this upstream in xarray:
    # https://github.com/pydata/xarray/issues/4473
    binned = pd.cut(np.ravel(group), bins, **cut_kwargs)
    new_dim_name = group.name + "_bins"
    num_bins = binned.categories.size
    # Integer bin codes laid out like `group`; pandas uses -1 for "no bin".
    indices = group.copy(data=binned.codes.reshape(group.shape))
    # Advertise the dtype numpy_groupies is actually asked to produce, so that
    # dask-backed results report the same dtype they compute to.
    # NOTE(review): with dtype=None, reductions such as 'mean' on integer input
    # likely compute to a float dtype; pass `dtype` explicitly in that case.
    output_dtype = array.dtype if dtype is None else np.dtype(dtype)

    result = xarray.apply_ufunc(
        _binned_agg, array, indices,
        input_core_dims=[indices.dims, indices.dims],
        output_core_dims=[[new_dim_name]],
        output_dtypes=[output_dtype],
        dask_gufunc_kwargs=dict(
            output_sizes={new_dim_name: num_bins},
        ),
        kwargs={
            'num_bins': num_bins,
            'func': func,
            'fill_value': fill_value,
            'dtype': dtype,
        },
        dask='parallelized',
    )
    result.coords[new_dim_name] = binned.categories
    return result

def make_test_data(t, x, y, seed=0):
    """Build a random signal and its distance from the origin.

    Args:
      t: length of the 'time' dimension.
      x: length of the 'x' dimension.
      y: length of the 'y' dimension.
      seed: seed for ``np.random.RandomState`` so the data is reproducible.

    Returns:
      ``(signal, distance)``.
      - ``signal`` is a DataArray named 'signal' with dims
        ('time', 'y', 'x') and uniform [0, 1) values.
      - ``distance`` is ``sqrt(x**2 + y**2)`` over the x/y coordinate grid,
        named 'distance'.
    """
    # Data shape follows the dims order (time, y, x), so each dimension's
    # length matches the parameter of the same name.
    signal = xarray.DataArray(
        np.random.RandomState(seed).rand(t, y, x),
        dims=['time', 'y', 'x'],
        coords={
            'time': np.arange(t),
            'y': np.arange(y),
            'x': np.arange(x),
        },
        name='signal')
    distance = ((signal.x ** 2 + signal.y ** 2) ** 0.5).rename('distance')
    return signal, distance

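# Unit test

Check that `groupby_bins_agg` agrees with xarray's built-in `groupby_bins(...).mean()` on a small example.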

In [25]:
# Small example: 2 time steps on a 50 x 50 grid, split into 10 distance bins.
bins = 10
signal, distance = make_test_data(t=2, x=50, y=50)

In [29]:
actual = groupby_bins_agg(signal, distance, bins, func='mean')
# Reuse `bins` (not a literal) so both sides are guaranteed to bin identically.
expected = signal.groupby_bins(distance, bins=bins).mean()
xarray.testing.assert_allclose(actual, expected)

In [30]:
# Show the binned means computed by groupby_bins_agg.
actual

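# Speed tests

Compare xarray's built-in `groupby_bins(...).mean()` with `groupby_bins_agg` on a larger signal, first in memory and then dask-backed.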

In [57]:
# Benchmark-sized input: 20 time steps on a 1000 x 1000 grid, 50 distance bins.
bins = 50
signal, distance = make_test_data(t=20, x=1000, y=1000)

In [58]:
# Size of the benchmark signal in megabytes.
signal.nbytes / 10**6

160.0

## NumPy speed test

In [60]:
# Baseline: xarray's built-in groupby_bins(...).mean() on the in-memory signal.
%time _ = signal.groupby_bins(distance, bins).mean()

CPU times: user 8.52 s, sys: 674 ms, total: 9.19 s
Wall time: 10.3 s


In [61]:
# The same reduction computed with groupby_bins_agg (numpy_groupies-based).
%time _ = groupby_bins_agg(signal, distance, bins, func='mean')

CPU times: user 909 ms, sys: 290 ms, total: 1.2 s
Wall time: 1.3 s


## Dask speed test

In [77]:
import dask

# Set dask's `num_workers` option to 4, then split the signal into one chunk
# per time step.
dask.config.set(num_workers=4)
dask_signal = signal.chunk({'time': 1})
dask_signal

Unnamed: 0,Array,Chunk
Bytes,160.00 MB,8.00 MB
Shape,"(20, 1000, 1000)","(1, 1000, 1000)"
Count,20 Tasks,20 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 160.00 MB 8.00 MB Shape (20, 1000, 1000) (1, 1000, 1000) Count 20 Tasks 20 Chunks Type float64 numpy.ndarray",1000  1000  20,

Unnamed: 0,Array,Chunk
Bytes,160.00 MB,8.00 MB
Shape,"(20, 1000, 1000)","(1, 1000, 1000)"
Count,20 Tasks,20 Chunks
Type,float64,numpy.ndarray


In [85]:
# Built-in groupby_bins on the dask-backed signal. The first %time measures
# building the lazy result, the second measures computing it (the computed
# values are not kept).
%time result = dask_signal.groupby_bins(distance, bins).mean()
%time result.compute()
# Display the still-lazy result; its repr reports the task-graph size.
result

CPU times: user 8.13 s, sys: 365 ms, total: 8.49 s
Wall time: 8.87 s
CPU times: user 1.12 s, sys: 332 ms, total: 1.45 s
Wall time: 967 ms


Unnamed: 0,Array,Chunk
Bytes,8.00 kB,8 B
Shape,"(20, 50)","(1, 1)"
Count,5060 Tasks,1000 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 8.00 kB 8 B Shape (20, 50) (1, 1) Count 5060 Tasks 1000 Chunks Type float64 numpy.ndarray",50  20,

Unnamed: 0,Array,Chunk
Bytes,8.00 kB,8 B
Shape,"(20, 50)","(1, 1)"
Count,5060 Tasks,1000 Chunks
Type,float64,numpy.ndarray


In [84]:
# groupby_bins_agg on the dask-backed signal. The first %time measures
# building the lazy result, the second measures computing it (the computed
# values are not kept).
%time result = groupby_bins_agg(dask_signal, distance, bins, func='mean')
%time result.compute()
# Display the still-lazy result; compare its task count with the built-in one.
result

CPU times: user 54.8 ms, sys: 7.46 ms, total: 62.2 ms
Wall time: 61.3 ms
CPU times: user 884 ms, sys: 191 ms, total: 1.08 s
Wall time: 484 ms


Unnamed: 0,Array,Chunk
Bytes,8.00 kB,400 B
Shape,"(20, 50)","(1, 50)"
Count,101 Tasks,20 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 8.00 kB 400 B Shape (20, 50) (1, 50) Count 101 Tasks 20 Chunks Type float64 numpy.ndarray",50  20,

Unnamed: 0,Array,Chunk
Bytes,8.00 kB,400 B
Shape,"(20, 50)","(1, 50)"
Count,101 Tasks,20 Chunks
Type,float64,numpy.ndarray
