From 52ee95f8ae6b9631ac381b5b889de47e41f2440e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 1 Jan 2016 13:38:11 -0800 Subject: [PATCH] Add broadcast function to the API This is a renaming and update of the existing `xray.broadcast_arrays` function, which now works properly in the light of GH648. Examples -------- Broadcast two data arrays against one another to fill out their dimensions: >>> a = xray.DataArray([1, 2, 3], dims='x') >>> b = xray.DataArray([5, 6], dims='y') >>> a array([1, 2, 3]) Coordinates: * x (x) int64 0 1 2 >>> b array([5, 6]) Coordinates: * y (y) int64 0 1 >>> a2, b2 = xray.broadcast(a, b) >>> a2 array([[1, 1], [2, 2], [3, 3]]) Coordinates: * x (x) int64 0 1 2 * y (y) int64 0 1 >>> b2 array([[5, 6], [5, 6], [5, 6]]) Coordinates: * y (y) int64 0 1 * x (x) int64 0 1 2 Fill out the dimensions of all data variables in a dataset: >>> ds = xray.Dataset({'a': a, 'b': b}) >>> ds2, = xray.broadcast(ds) # use tuple unpacking to extract one dataset >>> ds2 Dimensions: (x: 3, y: 2) Coordinates: * x (x) int64 0 1 2 * y (y) int64 0 1 Data variables: a (x, y) int64 1 1 2 2 3 3 b (x, y) int64 5 6 5 6 5 6 --- doc/api.rst | 1 + doc/whats-new.rst | 14 ++++- xray/__init__.py | 2 +- xray/core/alignment.py | 114 ++++++++++++++++++++++++++++++------ xray/test/test_dataarray.py | 22 +++++-- xray/test/test_dataset.py | 28 ++++++++- 6 files changed, 155 insertions(+), 26 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 7619010a865..f0125792a2c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -15,6 +15,7 @@ Top-level functions :toctree: generated/ align + broadcast concat set_options diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 497edcc9b65..951d9c6798f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -67,7 +67,7 @@ Bug fixes - Fixes for several issues found on ``DataArray`` objects with the same name as one of their coordinates (see :ref:`v0.7.0.breaking` for more details). -- ``DataArray.to_masked_array`` always returns masked array with mask being an array +- ``DataArray.to_masked_array`` always returns masked array with mask being an array (not a scalar value) (:issue:`684`) v0.6.2 (unreleased) @@ -96,6 +96,18 @@ Enhancements moves both data and coordinates. - Assigning a ``pandas`` object to a ``Dataset`` directly is now permitted. Its index names correspond to the `dims`` of the ``Dataset``, and its data is aligned +- New function :py:func:`~xray.broadcast` for explicitly broadcasting + ``DataArray`` and ``Dataset`` objects against each other. For example: + + .. ipython:: python + + a = xray.DataArray([1, 2, 3], dims='x') + b = xray.DataArray([5, 6], dims='y') + a + b + a2, b2 = xray.broadcast(a, b) + a2 + b2 Bug fixes ~~~~~~~~~ diff --git a/xray/__init__.py b/xray/__init__.py index e48935015d7..5f7e69be8dc 100644 --- a/xray/__init__.py +++ b/xray/__init__.py @@ -1,4 +1,4 @@ -from .core.alignment import align, broadcast_arrays +from .core.alignment import align, broadcast, broadcast_arrays from .core.combine import concat, auto_combine from .core.variable import Variable, Coordinate from .core.dataset import Dataset diff --git a/xray/core/alignment.py b/xray/core/alignment.py index 32d093747cb..bc4010c2d87 100644 --- a/xray/core/alignment.py +++ b/xray/core/alignment.py @@ -253,31 +253,80 @@ def var_indexers(var, indexers): return reindexed -def broadcast_arrays(*args): - """Explicitly broadcast any number of DataArrays against one another. +def broadcast(*args): + """Explicitly broadcast any number of DataArray or Dataset objects against + one another. xray objects automatically broadcast against each other in arithmetic operations, so this function should not be necessary for normal use. Parameters ---------- - *args: DataArray + *args: DataArray or Dataset objects Arrays to broadcast against each other. Returns ------- - broadcast: tuple of DataArray + broadcast: tuple of xray objects The same data as the input arrays, but with additional dimensions - inserted so that all arrays have the same dimensions and shape. + inserted so that all data arrays have the same dimensions and shape. Raises ------ ValueError - If indexes on the different arrays are not aligned. + If indexes on the different objects are not aligned. + + Examples + -------- + + Broadcast two data arrays against one another to fill out their dimensions: + + >>> a = xray.DataArray([1, 2, 3], dims='x') + >>> b = xray.DataArray([5, 6], dims='y') + >>> a + + array([1, 2, 3]) + Coordinates: + * x (x) int64 0 1 2 + >>> b + + array([5, 6]) + Coordinates: + * y (y) int64 0 1 + >>> a2, b2 = xray.broadcast(a, b) + >>> a2 + + array([[1, 1], + [2, 2], + [3, 3]]) + Coordinates: + * x (x) int64 0 1 2 + * y (y) int64 0 1 + >>> b2 + + array([[5, 6], + [5, 6], + [5, 6]]) + Coordinates: + * y (y) int64 0 1 + * x (x) int64 0 1 2 + + Fill out the dimensions of all data variables in a dataset: + + >>> ds = xray.Dataset({'a': a, 'b': b}) + >>> ds2, = xray.broadcast(ds) # use tuple unpacking to extract one dataset + >>> ds2 + + Dimensions: (x: 3, y: 2) + Coordinates: + * x (x) int64 0 1 2 + * y (y) int64 0 1 + Data variables: + a (x, y) int64 1 1 2 2 3 3 + b (x, y) int64 5 6 5 6 5 6 """ - # TODO: fixme for coordinate arrays - from .dataarray import DataArray + from .dataset import Dataset all_indexes = _get_all_indexes(args) for k, v in all_indexes.items(): @@ -285,14 +334,45 @@ def broadcast_arrays(*args): raise ValueError('cannot broadcast arrays: the %s index is not ' 'aligned (use xray.align first)' % k) - vars = broadcast_variables(*[a.variable for a in args]) - indexes = dict((k, all_indexes[k][0]) for k in vars[0].dims) + common_coords = OrderedDict() + dims_map = OrderedDict() + for arg in args: + for dim in arg.dims: + if dim not in common_coords: + common_coords[dim] = arg.coords[dim].variable + dims_map[dim] = common_coords[dim].size + + def _broadcast_array(array): + data = array.variable.expand_dims(dims_map) + coords = OrderedDict(array.coords) + coords.update(common_coords) + dims = tuple(common_coords) + return DataArray(data, coords, dims, name=array.name, + attrs=array.attrs, encoding=array.encoding) + + def _broadcast_dataset(ds): + data_vars = OrderedDict() + for k in ds.data_vars: + data_vars[k] = ds.variables[k].expand_dims(dims_map) + + coords = OrderedDict(ds.coords) + coords.update(common_coords) + + return Dataset(data_vars, coords, ds.attrs) + + result = [] + for arg in args: + if isinstance(arg, DataArray): + result.append(_broadcast_array(arg)) + elif isinstance(arg, Dataset): + result.append(_broadcast_dataset(arg)) + else: + raise ValueError('all input must be Dataset or DataArray objects') - arrays = [] - for a, v in zip(args, vars): - arr = DataArray(v.values, indexes, v.dims, a.name, a.attrs, a.encoding) - for k, v in a.coords.items(): - arr.coords[k] = v - arrays.append(arr) + return tuple(result) - return tuple(arrays) + +def broadcast_arrays(*args): + warnings.warn('xray.broadcast_arrays is deprecated: use xray.broadcast ' + 'instead', DeprecationWarning, stacklevel=2) + return broadcast(*args) diff --git a/xray/test/test_dataarray.py b/xray/test/test_dataarray.py index 1a1b5e883e0..a244dd48dbd 100644 --- a/xray/test/test_dataarray.py +++ b/xray/test/test_dataarray.py @@ -4,7 +4,7 @@ from copy import deepcopy from textwrap import dedent -from xray import (align, broadcast_arrays, Dataset, DataArray, +from xray import (align, broadcast, Dataset, DataArray, Coordinate, Variable) from xray.core.pycompat import iteritems, OrderedDict from . import TestCase, ReturnItem, source_ndarray, unittest, requires_dask @@ -1267,7 +1267,7 @@ def test_align_dtype(self): def test_broadcast_arrays(self): x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x') y = DataArray([1, 2], coords=[('b', [3, 4])], name='y') - x2, y2 = broadcast_arrays(x, y) + x2, y2 = broadcast(x, y) expected_coords = [('a', [-1, -2]), ('b', [3, 4])] expected_x2 = DataArray([[1, 1], [2, 2]], expected_coords, name='x') expected_y2 = DataArray([[1, 2], [1, 2]], expected_coords, name='y') @@ -1276,15 +1276,27 @@ def test_broadcast_arrays(self): x = DataArray(np.random.randn(2, 3), dims=['a', 'b']) y = DataArray(np.random.randn(3, 2), dims=['b', 'a']) - x2, y2 = broadcast_arrays(x, y) + x2, y2 = broadcast(x, y) expected_x2 = x expected_y2 = y.T self.assertDataArrayIdentical(expected_x2, x2) self.assertDataArrayIdentical(expected_y2, y2) + z = DataArray([1, 2], coords=[('a', [-10, 20])]) with self.assertRaisesRegexp(ValueError, 'cannot broadcast'): - z = DataArray([1, 2], coords=[('a', [-10, 20])]) - broadcast_arrays(x, z) + broadcast(x, z) + + def test_broadcast_coordinates(self): + # regression test for GH649 + ds = Dataset({'a': (['x', 'y'], np.ones((5, 6)))}) + x_bc, y_bc, a_bc = broadcast(ds.x, ds.y, ds.a) + self.assertDataArrayIdentical(ds.a, a_bc) + + X, Y = np.meshgrid(np.arange(5), np.arange(6), indexing='ij') + exp_x = DataArray(X, dims=['x', 'y'], name='x') + exp_y = DataArray(Y, dims=['x', 'y'], name='y') + self.assertDataArrayIdentical(exp_x, x_bc) + self.assertDataArrayIdentical(exp_y, y_bc) def test_to_pandas(self): # 0d diff --git a/xray/test/test_dataset.py b/xray/test/test_dataset.py index 0f7abb21579..7213d3293d9 100644 --- a/xray/test/test_dataset.py +++ b/xray/test/test_dataset.py @@ -12,8 +12,8 @@ import numpy as np import pandas as pd -from xray import (align, concat, conventions, backends, Dataset, DataArray, - Variable, Coordinate, auto_combine, open_dataset, +from xray import (align, broadcast, concat, conventions, backends, Dataset, + DataArray, Variable, Coordinate, auto_combine, open_dataset, set_options) from xray.core import indexing, utils from xray.core.pycompat import iteritems, OrderedDict @@ -953,6 +953,30 @@ def test_align(self): with self.assertRaises(TypeError): align(left, right, foo='bar') + def test_broadcast(self): + ds = Dataset({'foo': 0, 'bar': ('x', [1]), 'baz': ('y', [2, 3])}, + {'c': ('x', [4])}) + expected = Dataset({'foo': (('x', 'y'), [[0, 0]]), + 'bar': (('x', 'y'), [[1, 1]]), + 'baz': (('x', 'y'), [[2, 3]])}, + {'c': ('x', [4])}) + actual, = broadcast(ds) + self.assertDatasetIdentical(expected, actual) + + ds_x = Dataset({'foo': ('x', [1])}) + ds_y = Dataset({'bar': ('y', [2, 3])}) + expected_x = Dataset({'foo': (('x', 'y'), [[1, 1]])}) + expected_y = Dataset({'bar': (('x', 'y'), [[2, 3]])}) + actual_x, actual_y = broadcast(ds_x, ds_y) + self.assertDatasetIdentical(expected_x, actual_x) + self.assertDatasetIdentical(expected_y, actual_y) + + array_y = ds_y['bar'] + expected_y = expected_y['bar'] + actual_x, actual_y = broadcast(ds_x, array_y) + self.assertDatasetIdentical(expected_x, actual_x) + self.assertDataArrayIdentical(expected_y, actual_y) + def test_variable_indexing(self): data = create_test_data() v = data['var1']