Skip to content

Commit

Permalink
Merge pull request #61 from gjoseph92/core-broadcasting
Browse files Browse the repository at this point in the history
Fix regression: input arrays are not broadcast
  • Loading branch information
TomNicholas committed Jun 14, 2021
2 parents fe1b3fa + b20c3f9 commit 8a6765a
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 30 deletions.
2 changes: 1 addition & 1 deletion ci/environment-3.7.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dependencies:
- python=3.7
- xarray
- dask
- numpy=1.16
- numpy=1.17
- pytest
- pip
- pip:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"Topic :: Scientific/Engineering",
]

INSTALL_REQUIRES = ["xarray>=0.12.0", "dask", "numpy>=1.16"]
INSTALL_REQUIRES = ["xarray>=0.12.0", "dask", "numpy>=1.17"]
PYTHON_REQUIRES = ">=3.7"

DESCRIPTION = "Fast, flexible, label-aware histograms for numpy and xarray"
Expand Down
43 changes: 18 additions & 25 deletions xhistogram/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,7 @@ def _bincount_2d_vectorized(


def _bincount(*all_arrays, weights=False, axis=None, bins=None, density=None):

# is this necessary?
all_arrays_broadcast = broadcast_arrays(*all_arrays)

a0 = all_arrays_broadcast[0]
a0 = all_arrays[0]

do_full_array = (axis is None) or (set(axis) == set(_range(a0.ndim)))

Expand Down Expand Up @@ -226,7 +222,7 @@ def reshape_input(a):
d = reshape(c, (new_dim_0, new_dim_1))
return d

all_arrays_reshaped = [reshape_input(a) for a in all_arrays_broadcast]
all_arrays_reshaped = [reshape_input(a) for a in all_arrays]

if weights:
weights_array = all_arrays_reshaped.pop()
Expand Down Expand Up @@ -257,8 +253,8 @@ def histogram(
Parameters
----------
args : array_like
Input data. The number of input arguments determines the dimensonality
of the histogram. For example, two arguments prodocue a 2D histogram.
Input data. The number of input arguments determines the dimensionality
of the histogram. For example, two arguments produce a 2D histogram.
All args must have the same size.
bins : int, str or numpy array or a list of ints, strs and/or arrays, optional
If a list, there should be one entry for each item in ``args``.
Expand Down Expand Up @@ -358,21 +354,20 @@ def histogram(

dtype = "i8" if not has_weights else weights.dtype

# here I am assuming all the arrays have the same shape
# probably needs to be generalized
input_indexes = [tuple(_range(a.ndim)) for a in all_arrays]
input_index = input_indexes[0]
assert all([ii == input_index for ii in input_indexes])
# Broadcast input arrays. Note that this dispatches to `dsa.broadcast_arrays` as necessary.
all_arrays = broadcast_arrays(*all_arrays)
# Since all arrays now have the same shape, just get the axes of the first.
input_axes = tuple(_range(all_arrays[0].ndim))

# Some sanity checks and format bins and range correctly
bins = _ensure_correctly_formatted_bins(bins, n_inputs)
range = _ensure_correctly_formatted_range(range, n_inputs)

# histogram_bin_edges trigges computation on dask arrays. It would be possible
# histogram_bin_edges triggers computation on dask arrays. It would be possible
# to write a version of this that doesn't trigger when `range` is provided, but
# for now let's just use np.histogram_bin_edges
if is_dask_array:
if not all([isinstance(b, np.ndarray) for b in bins]):
if not all(isinstance(b, np.ndarray) for b in bins):
raise TypeError(
"When using dask arrays, bins must be provided as numpy array(s) of edges"
)
Expand All @@ -382,11 +377,11 @@ def histogram(
]
bincount_kwargs = dict(weights=has_weights, axis=axis, bins=bins, density=density)

# keep these axes in the inputs
# remove these axes from the inputs
if axis is not None:
drop_axes = tuple([ii for ii in input_index if ii in axis])
drop_axes = tuple(axis)
else:
drop_axes = input_index
drop_axes = input_axes

if _any_dask_array(weights, *all_arrays):
# We should be able to just apply the bin_count function to every
Expand All @@ -405,16 +400,14 @@ def histogram(

adjust_chunks = {i: (lambda x: 1) for i in drop_axes}

new_axes = {
max(input_index) + 1 + i: axis_len
for i, axis_len in enumerate([len(bin) - 1 for bin in bins])
}
out_index = input_index + tuple(new_axes)
new_axes_start = max(input_axes) + 1
new_axes = {new_axes_start + i: len(bin) - 1 for i, bin in enumerate(bins)}
out_index = input_axes + tuple(new_axes)

blockwise_args = []
for arg in all_arrays:
blockwise_args.append(arg)
blockwise_args.append(input_index)
blockwise_args.append(input_axes)

bin_counts = dsa.blockwise(
_bincount,
Expand All @@ -432,7 +425,7 @@ def histogram(
bin_counts = _bincount(*all_arrays, **bincount_kwargs).squeeze(drop_axes)

if density:
# Normalise by dividing by bin counts and areas such that all the
# Normalize by dividing by bin counts and areas such that all the
# histogram data integrated over all dimensions = 1
bin_widths = [np.diff(b) for b in bins]
if n_inputs == 1:
Expand Down
34 changes: 31 additions & 3 deletions xhistogram/test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,34 @@ def test_histogram_results_2d():
np.testing.assert_array_equal(hist, h)


@pytest.mark.parametrize("dask", [False, True])
def test_histogram_results_2d_broadcasting(dask):
nrows, ncols = 5, 20
data_a = np.random.randn(ncols)
data_b = np.random.randn(nrows, ncols)
nbins_a = 9
bins_a = np.linspace(-4, 4, nbins_a + 1)
nbins_b = 10
bins_b = np.linspace(-4, 4, nbins_b + 1)

if dask:
test_data_a = dsa.from_array(data_a, chunks=3)
test_data_b = dsa.from_array(data_b, chunks=(2, 7))
else:
test_data_a = data_a
test_data_b = data_b

h, _ = histogram(test_data_a, test_data_b, bins=[bins_a, bins_b])
assert h.shape == (nbins_a, nbins_b)

hist, _, _ = np.histogram2d(
np.broadcast_to(data_a, data_b.shape).ravel(),
data_b.ravel(),
bins=[bins_a, bins_b],
)
np.testing.assert_array_equal(hist, h)


def test_histogram_results_2d_density():
nrows, ncols = 5, 20
data_a = np.random.randn(nrows, ncols)
Expand Down Expand Up @@ -228,7 +256,7 @@ def test_histogram_shape(use_dask, block_size):


def test_histogram_dask():
""" Test that fails with dask arrays and inappropriate bins"""
"""Test that fails with dask arrays and inappropriate bins"""
shape = 10, 15, 12, 20
b = empty_dask_array(shape, chunks=(1,) + shape[1:])
histogram(b, bins=bins_arr) # Should work when bins is all numpy arrays
Expand All @@ -255,7 +283,7 @@ def test_histogram_dask():
],
)
def test_ensure_correctly_formatted_bins(in_out):
""" Test the helper function _ensure_correctly_formatted_bins"""
"""Test the helper function _ensure_correctly_formatted_bins"""
bins_in, n, bins_expected = in_out
if bins_expected is not None:
bins = _ensure_correctly_formatted_bins(bins_in, n)
Expand All @@ -277,7 +305,7 @@ def test_ensure_correctly_formatted_bins(in_out):
],
)
def test_ensure_correctly_formatted_range(in_out):
""" Test the helper function _ensure_correctly_formatted_range"""
"""Test the helper function _ensure_correctly_formatted_range"""
range_in, n, range_expected = in_out
if range_expected is not None:
range_ = _ensure_correctly_formatted_range(range_in, n)
Expand Down

0 comments on commit 8a6765a

Please sign in to comment.