Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Top-level functions
:toctree: generated/

align
broadcast
concat
set_options

Expand Down
14 changes: 13 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Bug fixes
- Fixes for several issues found on ``DataArray`` objects with the same name
as one of their coordinates (see :ref:`v0.7.0.breaking` for more details).

- ``DataArray.to_masked_array`` always returns masked array with mask being an array
- ``DataArray.to_masked_array`` always returns masked array with mask being an array
(not a scalar value) (:issue:`684`)

v0.6.2 (unreleased)
Expand Down Expand Up @@ -96,6 +96,18 @@ Enhancements
moves both data and coordinates.
- Assigning a ``pandas`` object to a ``Dataset`` directly is now permitted. Its
index names correspond to the ``dims`` of the ``Dataset``, and its data is aligned.
- New function :py:func:`~xray.broadcast` for explicitly broadcasting
``DataArray`` and ``Dataset`` objects against each other. For example:

.. ipython:: python

a = xray.DataArray([1, 2, 3], dims='x')
b = xray.DataArray([5, 6], dims='y')
a
b
a2, b2 = xray.broadcast(a, b)
a2
b2

Bug fixes
~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion xray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .core.alignment import align, broadcast_arrays
from .core.alignment import align, broadcast, broadcast_arrays
from .core.combine import concat, auto_combine
from .core.variable import Variable, Coordinate
from .core.dataset import Dataset
Expand Down
114 changes: 97 additions & 17 deletions xray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,46 +253,126 @@ def var_indexers(var, indexers):
return reindexed


def broadcast(*args):
    """Explicitly broadcast any number of DataArray or Dataset objects against
    one another.

    xray objects automatically broadcast against each other in arithmetic
    operations, so this function should not be necessary for normal use.

    Parameters
    ----------
    *args: DataArray or Dataset objects
        Arrays to broadcast against each other.

    Returns
    -------
    broadcast: tuple of xray objects
        The same data as the input arrays, but with additional dimensions
        inserted so that all data arrays have the same dimensions and shape.
        One output is returned per input, in the same order.

    Raises
    ------
    ValueError
        If any argument is not a DataArray or Dataset, or if indexes on the
        different objects are not aligned.

    Examples
    --------

    Broadcast two data arrays against one another to fill out their dimensions:

    >>> a = xray.DataArray([1, 2, 3], dims='x')
    >>> b = xray.DataArray([5, 6], dims='y')
    >>> a
    <xray.DataArray (x: 3)>
    array([1, 2, 3])
    Coordinates:
      * x        (x) int64 0 1 2
    >>> b
    <xray.DataArray (y: 2)>
    array([5, 6])
    Coordinates:
      * y        (y) int64 0 1
    >>> a2, b2 = xray.broadcast(a, b)
    >>> a2
    <xray.DataArray (x: 3, y: 2)>
    array([[1, 1],
           [2, 2],
           [3, 3]])
    Coordinates:
      * x        (x) int64 0 1 2
      * y        (y) int64 0 1
    >>> b2
    <xray.DataArray (x: 3, y: 2)>
    array([[5, 6],
           [5, 6],
           [5, 6]])
    Coordinates:
      * y        (y) int64 0 1
      * x        (x) int64 0 1 2

    Fill out the dimensions of all data variables in a dataset:

    >>> ds = xray.Dataset({'a': a, 'b': b})
    >>> ds2, = xray.broadcast(ds)  # use tuple unpacking to extract one dataset
    >>> ds2
    <xray.Dataset>
    Dimensions:  (x: 3, y: 2)
    Coordinates:
      * x        (x) int64 0 1 2
      * y        (y) int64 0 1
    Data variables:
        a        (x, y) int64 1 1 2 2 3 3
        b        (x, y) int64 5 6 5 6 5 6
    """
    # Imported locally, as in the original function body.
    from .dataarray import DataArray
    from .dataset import Dataset

    # Validate argument types up front.  The loops below read ``arg.dims`` and
    # ``arg.coords``; checking only at the end would let unsupported inputs
    # fail on those attribute accesses before the intended ValueError.
    for arg in args:
        if not isinstance(arg, (DataArray, Dataset)):
            raise ValueError('all input must be Dataset or DataArray objects')

    all_indexes = _get_all_indexes(args)
    for k, v in all_indexes.items():
        # Every index collected under the same key must equal the first one.
        if not all(v[0].equals(vi) for vi in v[1:]):
            raise ValueError('cannot broadcast arrays: the %s index is not '
                             'aligned (use xray.align first)' % k)

    # Collect each dimension once, in first-seen order across the arguments,
    # together with the coordinate variable and the size it contributes.
    common_coords = OrderedDict()
    dims_map = OrderedDict()
    for arg in args:
        for dim in arg.dims:
            if dim not in common_coords:
                common_coords[dim] = arg.coords[dim].variable
                dims_map[dim] = common_coords[dim].size

    def _broadcast_array(array):
        # Expand the array's variable to the full set of dimensions and attach
        # the shared dimension coordinates on top of its own coordinates.
        data = array.variable.expand_dims(dims_map)
        coords = OrderedDict(array.coords)
        coords.update(common_coords)
        dims = tuple(common_coords)
        return DataArray(data, coords, dims, name=array.name,
                         attrs=array.attrs, encoding=array.encoding)

    def _broadcast_dataset(ds):
        # Only data variables are expanded; the dataset's coordinates are
        # carried over and supplemented with the shared dimension coordinates.
        data_vars = OrderedDict()
        for k in ds.data_vars:
            data_vars[k] = ds.variables[k].expand_dims(dims_map)

        coords = OrderedDict(ds.coords)
        coords.update(common_coords)

        return Dataset(data_vars, coords, ds.attrs)

    result = []
    for arg in args:
        # Types were validated above, so anything that is not a DataArray
        # here is a Dataset.
        if isinstance(arg, DataArray):
            result.append(_broadcast_array(arg))
        else:
            result.append(_broadcast_dataset(arg))

    return tuple(result)

def broadcast_arrays(*args):
    """Deprecated alias of ``broadcast``.

    Issues a ``DeprecationWarning`` (with ``stacklevel=2``) and then forwards
    all positional arguments to ``broadcast``, returning its result unchanged.
    """
    message = ('xray.broadcast_arrays is deprecated: use xray.broadcast '
               'instead')
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    broadcasted = broadcast(*args)
    return broadcasted
22 changes: 17 additions & 5 deletions xray/test/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import deepcopy
from textwrap import dedent

from xray import (align, broadcast_arrays, Dataset, DataArray,
from xray import (align, broadcast, Dataset, DataArray,
Coordinate, Variable)
from xray.core.pycompat import iteritems, OrderedDict
from . import TestCase, ReturnItem, source_ndarray, unittest, requires_dask
Expand Down Expand Up @@ -1267,7 +1267,7 @@ def test_align_dtype(self):
def test_broadcast_arrays(self):
x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x')
y = DataArray([1, 2], coords=[('b', [3, 4])], name='y')
x2, y2 = broadcast_arrays(x, y)
x2, y2 = broadcast(x, y)
expected_coords = [('a', [-1, -2]), ('b', [3, 4])]
expected_x2 = DataArray([[1, 1], [2, 2]], expected_coords, name='x')
expected_y2 = DataArray([[1, 2], [1, 2]], expected_coords, name='y')
Expand All @@ -1276,15 +1276,27 @@ def test_broadcast_arrays(self):

x = DataArray(np.random.randn(2, 3), dims=['a', 'b'])
y = DataArray(np.random.randn(3, 2), dims=['b', 'a'])
x2, y2 = broadcast_arrays(x, y)
x2, y2 = broadcast(x, y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can there be a test for broadcasting coordinates, not just data variables? That is something I use frequently, and I currently have to hack it (as in #649).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add some more explicit tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

expected_x2 = x
expected_y2 = y.T
self.assertDataArrayIdentical(expected_x2, x2)
self.assertDataArrayIdentical(expected_y2, y2)

z = DataArray([1, 2], coords=[('a', [-10, 20])])
with self.assertRaisesRegexp(ValueError, 'cannot broadcast'):
z = DataArray([1, 2], coords=[('a', [-10, 20])])
broadcast_arrays(x, z)
broadcast(x, z)

def test_broadcast_coordinates(self):
    """Check that coordinate arrays can be broadcast against data arrays."""
    # regression test for GH649
    ds = Dataset({'a': (['x', 'y'], np.ones((5, 6)))})
    x_bc, y_bc, a_bc = broadcast(ds.x, ds.y, ds.a)
    # 'a' already has both dimensions, so it is expected back unchanged
    self.assertDataArrayIdentical(ds.a, a_bc)

    # the broadcast coordinates are expected to match 'ij'-indexed mesh grids
    grid_x, grid_y = np.meshgrid(np.arange(5), np.arange(6), indexing='ij')
    both_dims = ['x', 'y']
    expected_and_actual = [
        (DataArray(grid_x, dims=both_dims, name='x'), x_bc),
        (DataArray(grid_y, dims=both_dims, name='y'), y_bc),
    ]
    for expected, actual in expected_and_actual:
        self.assertDataArrayIdentical(expected, actual)

def test_to_pandas(self):
# 0d
Expand Down
28 changes: 26 additions & 2 deletions xray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import numpy as np
import pandas as pd

from xray import (align, concat, conventions, backends, Dataset, DataArray,
Variable, Coordinate, auto_combine, open_dataset,
from xray import (align, broadcast, concat, conventions, backends, Dataset,
DataArray, Variable, Coordinate, auto_combine, open_dataset,
set_options)
from xray.core import indexing, utils
from xray.core.pycompat import iteritems, OrderedDict
Expand Down Expand Up @@ -953,6 +953,30 @@ def test_align(self):
with self.assertRaises(TypeError):
align(left, right, foo='bar')

def test_broadcast(self):
ds = Dataset({'foo': 0, 'bar': ('x', [1]), 'baz': ('y', [2, 3])},
{'c': ('x', [4])})
expected = Dataset({'foo': (('x', 'y'), [[0, 0]]),
'bar': (('x', 'y'), [[1, 1]]),
'baz': (('x', 'y'), [[2, 3]])},
{'c': ('x', [4])})
actual, = broadcast(ds)
self.assertDatasetIdentical(expected, actual)

ds_x = Dataset({'foo': ('x', [1])})
ds_y = Dataset({'bar': ('y', [2, 3])})
expected_x = Dataset({'foo': (('x', 'y'), [[1, 1]])})
expected_y = Dataset({'bar': (('x', 'y'), [[2, 3]])})
actual_x, actual_y = broadcast(ds_x, ds_y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind the previous comment. It looks like you test coordinate broadcasting here.

self.assertDatasetIdentical(expected_x, actual_x)
self.assertDatasetIdentical(expected_y, actual_y)

array_y = ds_y['bar']
expected_y = expected_y['bar']
actual_x, actual_y = broadcast(ds_x, array_y)
self.assertDatasetIdentical(expected_x, actual_x)
self.assertDataArrayIdentical(expected_y, actual_y)

def test_variable_indexing(self):
data = create_test_data()
v = data['var1']
Expand Down