From c5146e8f82ae8387a24b99b27fb8b7e623b38778 Mon Sep 17 00:00:00 2001
From: chunweiyuan
Date: Mon, 23 Jan 2017 14:39:57 -0800
Subject: [PATCH] combine_first by using apply_ufunc in ops.fillna (#1204)

* First commit of ops.fillna() using apply_ufunc().
* Saving preliminary changes to fillna using apply_ufunc.aset attr test.
* _yield_applied for GroupBy.fillna(). Inspired by _binary_op().
* Finished introducing combine_first to DataArray/Dataset objects. Remove _fillna from injection.
* Minor spacing changes to doc.
* Use np.where() in ops.fillna
* Separating data_vars_join from join in apply_ufunc, and accessing it through ops.fillna. Also added docstring.
* Rewrote GroupBy's fillna, among other review comments.
* Initial attempt to introduce keep_attrs to apply_ufunc.
* Adding test for data_vars_join, also removing join kwarg in fillna of dataarray.py and dataset.py.
* Bringing keep_attrs from apply_ufunc to apply_dataset_ufunc and apply_dataarray_ufunc.
* Forcing explicit specification for dataset_fill_value.
* Added test for keep_attrs. Changed kwarg name from data_vars_join to dataset_join.
* Updated docstrings. Moving changes to v0.9.0
* Adjust dataset_fill_value default in apply_ufunc signature
---
 doc/combining.rst | 35 +++++++++++++
 doc/whats-new.rst | 7 +++
 xarray/core/computation.py | 87 ++++++++++++++++++++++++++-------
 xarray/core/dataarray.py | 21 +++++++-
 xarray/core/dataset.py | 48 ++++++++++++------
 xarray/core/groupby.py | 3 +-
 xarray/core/ops.py | 33 ++++++++++---
 xarray/core/variable.py | 2 +-
 xarray/test/test_computation.py | 66 ++++++++++++++++++++++++-
 xarray/test/test_dataarray.py | 20 ++++++++
 xarray/test/test_dataset.py | 28 +++++++++--
 11 files changed, 298 insertions(+), 52 deletions(-)

diff --git a/doc/combining.rst b/doc/combining.rst
index e16ff045b4c..f08cdee2527 100644
--- a/doc/combining.rst
+++ b/doc/combining.rst
@@ -13,6 +13,7 @@ Combining data
 * For combining datasets or data arrays along a dimension, see concatenate_.
 * For combining datasets with different variables, see merge_.
+* For combining datasets or data arrays with different indexes or missing values, see combine_.
 
 .. _concatenate:
 
@@ -116,6 +117,40 @@ used in the :py:class:`~xarray.Dataset` constructor:
 
     xr.Dataset({'a': arr[:-1], 'b': arr[1:]})
 
+.. _combine:
+
+Combine
+~~~~~~~
+
+The instance method ``combine_first`` combines two datasets or data arrays,
+preferring the non-null values of the calling object and using values from the
+supplied object to fill holes. The resulting coordinates are the union of coordinate labels.
+Cells left vacant by the outer join are filled with ``NaN``.
+
+This mimics the behavior of ``pandas.DataFrame.combine_first``.
+
+For data arrays,
+
+.. ipython:: python
+
+    ar0 = DataArray([[0, 0], [0, 0]], [('x', ['a', 'b']), ('y', [-1, 0])])
+    ar1 = DataArray([[1, 1], [1, 1]], [('x', ['b', 'c']), ('y', [0, 1])])
+    ar2 = DataArray([2], [('x', ['d'])])
+    ar0.combine_first(ar1)
+    ar1.combine_first(ar0)
+    ar0.combine_first(ar2)
+
+For datasets, ``ds0.combine_first(ds1)`` works similarly to ``xr.merge([ds0, ds1])``,
+except that ``xr.merge`` raises a ``MergeError`` when the data variables being merged
+have conflicting values, whereas ``.combine_first`` defaults to the calling object's values.
+
+.. ipython:: python
+
+    ds0 = Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
+    ds1 = Dataset({'a': ('x', [99, 3]), 'x': [1, 2]})
+    ds0.combine_first(ds1)
+    xr.merge([ds0, ds1])
+
 ..
_update: Update diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e672e8369a9..175af5cf4ab 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -128,6 +128,13 @@ Deprecations Enhancements ~~~~~~~~~~~~ + +- Added the xarray equivalent of `pandas.Dataframe.combine_first` as an instance + method to DataArray/Dataset objects, facilitated by the new `ops.fillna` with + `join` and `data_vars_join` options. + (see :ref:`combine`) + By `Chun-Wei Yuan `_. + - Added the ability to change default automatic alignment (arithmetic_join="inner") for binary operations via :py:func:`~xarray.set_options()` (see :ref:`automatic alignment`). diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 675c5cbe9f6..69c69654cca 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -8,6 +8,8 @@ import operator import re +import numpy as np + from . import ops from .alignment import deep_align from .merge import expand_and_merge_variables @@ -16,6 +18,8 @@ _DEFAULT_FROZEN_SET = frozenset() +_DEFAULT_FILL_VALUE = object() + # see http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html DIMENSION_NAME = r'\w+' @@ -202,6 +206,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs): signature = kwargs.pop('signature') join = kwargs.pop('join', 'inner') exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) + keep_attrs = kwargs.pop('keep_attrs', False) if kwargs: raise TypeError('apply_dataarray_ufunc() got unexpected keyword ' 'arguments: %s' % list(kwargs)) @@ -217,12 +222,18 @@ def apply_dataarray_ufunc(func, *args, **kwargs): result_var = func(*data_vars) if signature.n_outputs > 1: - return tuple(DataArray(variable, coords, name=name, fastpath=True) - for variable, coords in zip(result_var, result_coords)) + out = tuple(DataArray(variable, coords, name=name, fastpath=True) + for variable, coords in zip(result_var, result_coords)) else: coords, = result_coords - return DataArray(result_var, coords, name=name, fastpath=True) + out = DataArray(result_var, coords, name=name, fastpath=True) + if keep_attrs and isinstance(args[0], DataArray): + if isinstance(out, tuple): + out = tuple(ds._copy_attrs_from(args[0]) for ds in out) + else: + out._copy_attrs_from(args[0]) + return out def ordered_set_union(all_keys): # type: List[Iterable] -> Iterable @@ -326,32 +337,53 @@ def _fast_dataset(variables, coord_variables): def apply_dataset_ufunc(func, *args, **kwargs): """apply_dataset_ufunc(func, *args, signature, join='inner', - fill_value=None, exclude_dims=frozenset()): + dataset_join='inner', fill_value=None, + exclude_dims=frozenset(), keep_attrs=False): + + If dataset_join != 'inner', a non-default fill_value must be supplied + by the user. Otherwise a TypeError is raised. 
""" + from .dataset import Dataset signature = kwargs.pop('signature') join = kwargs.pop('join', 'inner') + dataset_join = kwargs.pop('dataset_join', 'inner') fill_value = kwargs.pop('fill_value', None) exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) + keep_attrs = kwargs.pop('keep_attrs', False) + first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True + + if dataset_join != 'inner' and fill_value is _DEFAULT_FILL_VALUE: + raise TypeError('To apply an operation to datasets with different ', + 'data variables, you must supply the ', + 'dataset_fill_value argument.') + if kwargs: raise TypeError('apply_dataset_ufunc() got unexpected keyword ' 'arguments: %s' % list(kwargs)) - if len(args) > 1: args = deep_align(args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False) list_of_coords = build_output_coords(args, signature, exclude_dims) - args = [getattr(arg, 'data_vars', arg) for arg in args] + result_vars = apply_dict_of_variables_ufunc( - func, *args, signature=signature, join=join, fill_value=fill_value) + func, *args, signature=signature, join=dataset_join, + fill_value=fill_value) if signature.n_outputs > 1: - return tuple(_fast_dataset(*args) - for args in zip(result_vars, list_of_coords)) + out = tuple(_fast_dataset(*args) + for args in zip(result_vars, list_of_coords)) else: coord_vars, = list_of_coords - return _fast_dataset(result_vars, coord_vars) + out = _fast_dataset(result_vars, coord_vars) + + if keep_attrs and isinstance(first_obj, Dataset): + if isinstance(out, tuple): + out = tuple(ds._copy_attrs_from(first_obj) for ds in out) + else: + out._copy_attrs_from(first_obj) + return out def _iter_over_selections(obj, dim, values): @@ -530,7 +562,8 @@ def apply_array_ufunc(func, *args, **kwargs): def apply_ufunc(func, *args, **kwargs): """apply_ufunc(func, *args, signature=None, join='inner', - exclude_dims=frozenset(), dataset_fill_value=None, + exclude_dims=frozenset(), dataset_join='inner', + dataset_fill_value=_DEFAULT_FILL_VALUE, keep_attrs=False, kwargs=None, dask_array='forbidden') Apply a vectorized function for unlabeled arrays to xarray objects. @@ -581,14 +614,23 @@ def apply_ufunc(func, *args, **kwargs): - 'inner': use the intersection of object indexes - 'left': use indexes from the first object with each dimension - 'right': use indexes from the last object with each dimension + dataset_join : {'outer', 'inner', 'left', 'right'}, optional + Method for joining variables of Dataset objects with mismatched + data variables. + - 'outer': take variables from both Dataset objects + - 'inner': take only overlapped variables + - 'left': take only variables from the first object + - 'right': take only variables from the last object + dataset_fill_value : optional + Value used in place of missing variables on Dataset inputs when the + datasets do not share the exact same ``data_vars``. Required if + ``dataset_join != 'inner'``, otherwise ignored. + keep_attrs: boolean, Optional + Whether to copy attributes from the first argument to the output. exclude_dims : set, optional Dimensions to exclude from alignment and broadcasting. Any inputs coordinates along these dimensions will be dropped. Each excluded dimension must be a core dimension in the function signature. - dataset_fill_value : optional - Value used in place of missing variables on Dataset inputs when the - datasets do not share the exact same ``data_vars``. Only relevant if - ``join != 'inner'``. 
kwargs: dict, optional Optional keyword arguments passed directly on to call ``func``. dask_array: 'forbidden' or 'allowed', optional @@ -664,8 +706,10 @@ def stack(objects, dim, new_coord): signature = kwargs.pop('signature', None) join = kwargs.pop('join', 'inner') + dataset_join = kwargs.pop('dataset_join', 'inner') + keep_attrs = kwargs.pop('keep_attrs', False) exclude_dims = kwargs.pop('exclude_dims', frozenset()) - dataset_fill_value = kwargs.pop('dataset_fill_value', None) + dataset_fill_value = kwargs.pop('dataset_fill_value', _DEFAULT_FILL_VALUE) kwargs_ = kwargs.pop('kwargs', None) dask_array = kwargs.pop('dask_array', 'forbidden') if kwargs: @@ -697,16 +741,21 @@ def stack(objects, dim, new_coord): this_apply = functools.partial( apply_ufunc, func, signature=signature, join=join, dask_array=dask_array, exclude_dims=exclude_dims, - dataset_fill_value=dataset_fill_value) + dataset_fill_value=dataset_fill_value, + dataset_join=dataset_join, + keep_attrs=keep_attrs) return apply_groupby_ufunc(this_apply, *args) elif any(is_dict_like(a) for a in args): return apply_dataset_ufunc(variables_ufunc, *args, signature=signature, join=join, exclude_dims=exclude_dims, - fill_value=dataset_fill_value) + fill_value=dataset_fill_value, + dataset_join=dataset_join, + keep_attrs=keep_attrs) elif any(isinstance(a, DataArray) for a in args): return apply_dataarray_ufunc(variables_ufunc, *args, signature=signature, - join=join, exclude_dims=exclude_dims) + join=join, exclude_dims=exclude_dims, + keep_attrs=keep_attrs) elif any(isinstance(a, Variable) for a in args): return variables_ufunc(*args) else: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ebc1c149423..3f9313354a5 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1097,10 +1097,27 @@ def fillna(self, value): if utils.is_dict_like(value): raise TypeError('cannot provide fill value as a dictionary with ' 'fillna on a DataArray') - out = self._fillna(value) - out.attrs = self.attrs + out = ops.fillna(self, value) return out + def combine_first(self, other): + """Combine two DataArray objects, with union of coordinates. + + This operation follows the normal broadcasting and alignment rules of + ``join='outer'``. Default to non-null values of array calling the + method. Use np.nan to fill in vacant cells after alignment. + + Parameters + ---------- + other : DataArray + Used to fill all matching missing values in this array. + + Returns + ------- + DataArray + """ + return ops.fillna(self, other, join="outer") + def reduce(self, func, dim=None, axis=None, keep_attrs=False, **kwargs): """Reduce this array by applying `func` along some dimension(s). diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 53b0c053c41..0081fb1c8fd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2001,8 +2001,31 @@ def fillna(self, value): ------- Dataset """ - out = self._fillna(value) - out._copy_attrs_from(self) + if utils.is_dict_like(value): + value_keys = getattr(value, 'data_vars', value).keys() + if not set(value_keys) <= set(self.data_vars.keys()): + raise ValueError('all variables in the argument to `fillna` ' + 'must be contained in the original dataset') + out = ops.fillna(self, value) + return out + + def combine_first(self, other): + """Combine two Datasets, default to data_vars of self. + + The new coordinates follow the normal broadcasting and alignment rules + of ``join='outer'``. Vacant cells in the expanded coordinates are + filled with np.nan. 
+ + Parameters + ---------- + other : DataArray + Used to fill all matching missing values in this array. + + Returns + ------- + DataArray + """ + out = ops.fillna(self, other, join="outer", dataset_join="outer") return out def reduce(self, func, dim=None, keep_attrs=False, numeric_only=False, @@ -2335,7 +2358,7 @@ def func(self, *args, **kwargs): return func @staticmethod - def _binary_op(f, reflexive=False, join=None, fillna=False): + def _binary_op(f, reflexive=False, join=None): @functools.wraps(f) def func(self, other): if isinstance(other, groupby.GroupBy): @@ -2344,8 +2367,7 @@ def func(self, other): if hasattr(other, 'indexes'): self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) - ds = self._calculate_binary_op(g, other, join=align_type, - fillna=fillna) + ds = self._calculate_binary_op(g, other, join=align_type) return ds return func @@ -2370,14 +2392,9 @@ def func(self, other): return func def _calculate_binary_op(self, f, other, join='inner', - inplace=False, fillna=False): + inplace=False): def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): - if fillna and join != 'left': - raise ValueError('`fillna` must be accompanied by left join') - if fillna and not set(rhs_data_vars) <= set(lhs_data_vars): - raise ValueError('all variables in the argument to `fillna` ' - 'must be contained in the original dataset') if inplace and set(lhs_data_vars) != set(rhs_data_vars): raise ValueError('datasets must have the same data variables ' 'for in-place arithmetic operations: %s, %s' @@ -2389,12 +2406,10 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): if k in rhs_data_vars: dest_vars[k] = f(lhs_vars[k], rhs_vars[k]) elif join in ["left", "outer"]: - dest_vars[k] = (lhs_vars[k] if fillna else - f(lhs_vars[k], np.nan)) + dest_vars[k] = f(lhs_vars[k], np.nan) for k in rhs_data_vars: if k not in dest_vars and join in ["right", "outer"]: - dest_vars[k] = (rhs_vars[k] if fillna else - f(rhs_vars[k], np.nan)) + dest_vars[k] = f(rhs_vars[k], np.nan) return dest_vars if utils.is_dict_like(other) and not isinstance(other, Dataset): @@ -2421,7 +2436,8 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): def _copy_attrs_from(self, other): self.attrs = other.attrs for v in other.variables: - self.variables[v].attrs = other.variables[v].attrs + if v in self.variables: + self.variables[v].attrs = other.variables[v].attrs def diff(self, dim, n=1, label='upper'): """Calculate the n-th order discrete difference along given axis. diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3fe91c27108..24e963862f3 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -382,7 +382,8 @@ def fillna(self, value): Dataset.fillna DataArray.fillna """ - return self._fillna(value) + out = ops.fillna(self, value) + return out def where(self, cond): """Return an object of the same shape with all entries where cond is diff --git a/xarray/core/ops.py b/xarray/core/ops.py index e9da926b709..3516efe7321 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -279,11 +279,35 @@ def count(data, axis=None): return sum(~isnull(data), axis=axis) -def fillna(data, other): +def fillna(data, other, join="left", dataset_join="left"): """Fill missing values in this object with data from the other object. Follows normal broadcasting and alignment rules. 
+ + Parameters + ---------- + join : {'outer', 'inner', 'left', 'right'}, optional + Method for joining the indexes of the passed objects along each + dimension + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + dataset_join : {'outer', 'inner', 'left', 'right'}, optional + Method for joining variables of Dataset objects with mismatched + data variables. + - 'outer': take variables from both Dataset objects + - 'inner': take only overlapped variables + - 'left': take only variables from the first object + - 'right': take only variables from the last object """ - return where(isnull(data), other, data) + from .computation import apply_ufunc + + def _fillna(data, other): + return where(isnull(data), other, data) + return apply_ufunc(_fillna, data, other, join=join, dask_array="allowed", + dataset_join=dataset_join, + dataset_fill_value=np.nan, + keep_attrs=True) def where_method(data, cond, other=np.nan): @@ -446,11 +470,6 @@ def inject_binary_ops(cls, inplace=False): for name, f in [('eq', array_eq), ('ne', array_ne)]: setattr(cls, op_str(name), cls._binary_op(f)) - # patch in fillna - f = _func_slash_method_wrapper(fillna) - method = cls._binary_op(f, join='left', fillna=True) - setattr(cls, '_fillna', method) - # patch in where f = _func_slash_method_wrapper(where_method, 'where') setattr(cls, '_where', cls._binary_op(f)) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 80effeb64e2..bd08acf4094 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -856,7 +856,7 @@ def unstack(self, **dimensions): return result def fillna(self, value): - return self._fillna(value) + return ops.fillna(self, value) def where(self, cond): return self._where(cond) diff --git a/xarray/test/test_computation.py b/xarray/test/test_computation.py index db9e51caac8..5ba8abb0259 100644 --- a/xarray/test/test_computation.py +++ b/xarray/test/test_computation.py @@ -335,7 +335,8 @@ def stack_invalid(obj): func = lambda x: xr.core.npcompat.stack([x, -x], axis=-1) sig = ([()], [('sign',)]) # no new_coords - return apply_ufunc(func, obj, signature=sig) + return apply_ufunc(func, obj, signature=sig, + dataset_fill_value=np.nan) def test_apply_exclude(): @@ -473,6 +474,69 @@ def test_broadcast_compat_data_2d(): broadcast_compat_data(var, ('w', 'y', 'x', 'z'), ())) +def test_keep_attrs(): + + def add(a, b, keep_attrs): + if keep_attrs: + return apply_ufunc(operator.add, a, b, keep_attrs=keep_attrs) + else: + return apply_ufunc(operator.add, a, b) + + a = xr.DataArray([0, 1], [('x', [0, 1])]) + a.attrs['attr'] = 'da' + b = xr.DataArray([1, 2], [('x', [0, 1])]) + + actual = add(a, b, keep_attrs=False) + assert not actual.attrs + actual = add(a, b, keep_attrs=True) + assert_identical(actual.attrs, a.attrs) + + a = xr.Dataset({'x': ('x', [1, 2]), 'x': [0, 1]}) + a.attrs['attr'] = 'ds' + a.x.attrs['attr'] = 'da' + b = xr.Dataset({'x': ('x', [1, 1]), 'x': [0, 1]}) + + actual = add(a, b, keep_attrs=False) + assert not actual.attrs + actual = add(a, b, keep_attrs=True) + assert_identical(actual.attrs, a.attrs) + assert_identical(actual.x.attrs, a.x.attrs) + + +def test_dataset_join(): + import numpy as np + ds0 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) + ds1 = xr.Dataset({'a': ('x', [99, 3]), 'x': [1, 2]}) + + with pytest.raises(TypeError): + apply_ufunc(operator.add, ds0, ds1, dataset_join='outer') + + def add(a, 
b, join, dataset_join): + return apply_ufunc(operator.add, a, b, join=join, + dataset_join=dataset_join, + dataset_fill_value=np.nan) + + actual = add(ds0, ds1, 'outer', 'inner') + expected = xr.Dataset({'a': ('x', [np.nan, 101, np.nan]), + 'x': [0, 1, 2]}) + assert_identical(actual, expected) + + actual = add(ds0, ds1, 'outer', 'outer') + assert_identical(actual, expected) + + # if variables don't match, join will perform add with np.nan + ds2 = xr.Dataset({'b': ('x', [99, 3]), 'x': [1, 2]}) + actual = add(ds0, ds2, 'outer', 'inner') + expected = xr.Dataset({'x': [0, 1, 2]}) + assert_identical(actual, expected) + + actual = add(ds0, ds2, 'outer', 'outer') + expected = xr.Dataset({'a': ('x', [np.nan, np.nan, np.nan]), + 'b': ('x', [np.nan, np.nan, np.nan]), + 'x': [0, 1, 2]}) + assert_identical(actual, expected) + + class _NoCacheVariable(xr.Variable): """Subclass of Variable for testing that does not cache values.""" # TODO: remove this class when we change the default behavior for caching diff --git a/xarray/test/test_dataarray.py b/xarray/test/test_dataarray.py index bb2ad833100..7b5a7b8d481 100644 --- a/xarray/test/test_dataarray.py +++ b/xarray/test/test_dataarray.py @@ -2384,6 +2384,26 @@ def test_binary_op_join_setting(self): expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])]) self.assertDataArrayEqual(actual, expected) + def test_combine_first(self): + ar0 = DataArray([[0, 0], [0, 0]], [('x', ['a', 'b']), ('y', [-1, 0])]) + ar1 = DataArray([[1, 1], [1, 1]], [('x', ['b', 'c']), ('y', [0, 1])]) + ar2 = DataArray([2], [('x', ['d'])]) + + actual = ar0.combine_first(ar1) + expected = DataArray([[0, 0, np.nan], [0, 0, 1], [np.nan, 1, 1]], + [('x', ['a', 'b', 'c']), ('y', [-1, 0, 1])]) + self.assertDataArrayEqual(actual, expected) + + actual = ar1.combine_first(ar0) + expected = DataArray([[0, 0, np.nan], [0, 1, 1], [np.nan, 1, 1]], + [('x', ['a', 'b', 'c']), ('y', [-1, 0, 1])]) + self.assertDataArrayEqual(actual, expected) + + actual = ar0.combine_first(ar2) + expected = DataArray([[0, 0], [0, 0], [2, 2]], + [('x', ['a', 'b', 'd']), ('y', [-1, 0])]) + self.assertDataArrayEqual(actual, expected) + @pytest.fixture(params=[1]) def da(request): diff --git a/xarray/test/test_dataset.py b/xarray/test/test_dataset.py index b7050ff0acb..916f13d72b2 100644 --- a/xarray/test/test_dataset.py +++ b/xarray/test/test_dataset.py @@ -3169,17 +3169,17 @@ def test_filter_by_attrs(self): def test_binary_op_join_setting(self): # arithmetic_join applies to data array coordinates - missing_2 = xr.Dataset({'x':[0, 1]}) - missing_0 = xr.Dataset({'x':[1, 2]}) + missing_2 = xr.Dataset({'x': [0, 1]}) + missing_0 = xr.Dataset({'x': [1, 2]}) with xr.set_options(arithmetic_join='outer'): actual = missing_2 + missing_0 - expected = xr.Dataset({'x':[0, 1, 2]}) + expected = xr.Dataset({'x': [0, 1, 2]}) self.assertDatasetEqual(actual, expected) # arithmetic join also applies to data_vars ds1 = xr.Dataset({'foo': 1, 'bar': 2}) ds2 = xr.Dataset({'bar': 2, 'baz': 3}) - expected = xr.Dataset({'bar': 4}) # default is inner joining + expected = xr.Dataset({'bar': 4}) # default is inner joining actual = ds1 + ds2 self.assertDatasetEqual(actual, expected) @@ -3202,7 +3202,7 @@ def test_full_like(self): # For more thorough tests, see test_variable.py # Note: testing data_vars with mismatched dtypes ds = Dataset({ - 'd1': DataArray([1,2,3], dims=['x'], coords={'x': [10,20,30]}), + 'd1': DataArray([1,2,3], dims=['x'], coords={'x': [10, 20, 30]}), 'd2': DataArray([1.1, 2.2, 3.3], dims=['y']) }, attrs={'foo': 
'bar'}) actual = full_like(ds, 2) @@ -3223,6 +3223,24 @@ def test_full_like(self): self.assertEqual(expect['d2'].dtype, bool) self.assertDatasetIdentical(expect, actual) + def test_combine_first(self): + dsx0 = DataArray([0, 0], [('x', ['a', 'b'])]).to_dataset(name='dsx0') + dsx1 = DataArray([1, 1], [('x', ['b', 'c'])]).to_dataset(name='dsx1') + + actual = dsx0.combine_first(dsx1) + expected = Dataset({'dsx0': ('x', [0, 0, np.nan]), + 'dsx1': ('x', [np.nan, 1, 1])}, + coords={'x': ['a', 'b', 'c']}) + self.assertDatasetEqual(actual, expected) + self.assertDatasetEqual(actual, xr.merge([dsx0, dsx1])) + + # works just like xr.merge([self, other]) + dsy2 = DataArray([2, 2, 2], + [('x', ['b', 'c', 'd'])]).to_dataset(name='dsy2') + actual = dsx0.combine_first(dsy2) + expected = xr.merge([dsy2, dsx0]) + self.assertDatasetEqual(actual, expected) + ### Py.test tests
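For a concrete sense of the behaviour introduced above, the following sketch replays the data-array case from ``test_combine_first`` in ``test_dataarray.py``; it assumes an xarray build that includes this patch (0.9.0 or later)::

    # combine_first on DataArrays: outer-join the coordinates, prefer the
    # calling object's non-null values, and leave NaN where neither input
    # has data. Mirrors the expected values asserted in the test above.
    import numpy as np
    import xarray as xr

    ar0 = xr.DataArray([[0, 0], [0, 0]], [('x', ['a', 'b']), ('y', [-1, 0])])
    ar1 = xr.DataArray([[1, 1], [1, 1]], [('x', ['b', 'c']), ('y', [0, 1])])

    result = ar0.combine_first(ar1)
    expected = xr.DataArray([[0, 0, np.nan], [0, 0, 1], [np.nan, 1, 1]],
                            [('x', ['a', 'b', 'c']), ('y', [-1, 0, 1])])
    assert result.equals(expected)  # NaNs in matching positions compare equal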
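The Dataset method differs from ``xr.merge`` only where data variables conflict: ``combine_first`` keeps the calling object's values, while (per the documentation added above) ``xr.merge`` raises a ``MergeError`` for the same inputs. A sketch under the same assumption::

    import numpy as np
    import xarray as xr

    ds0 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
    ds1 = xr.Dataset({'a': ('x', [99, 3]), 'x': [1, 2]})

    # The conflicting value at x=1 (2 vs. 99) is resolved in favour of ds0.
    combined = ds0.combine_first(ds1)
    assert combined.equals(xr.Dataset({'a': ('x', [1.0, 2.0, 3.0]),
                                       'x': [0, 1, 2]}))

    # With disjoint data variables the result matches xr.merge, as the
    # Dataset test above asserts.
    dsx0 = xr.DataArray([0, 0], [('x', ['a', 'b'])]).to_dataset(name='dsx0')
    dsx1 = xr.DataArray([1, 1], [('x', ['b', 'c'])]).to_dataset(name='dsx1')
    assert dsx0.combine_first(dsx1).equals(xr.merge([dsx0, dsx1]))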
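Both methods are implemented through ``ops.fillna``, which in turn relies on the new ``dataset_join``, ``dataset_fill_value`` and ``keep_attrs`` keywords of ``apply_ufunc``. The sketch below mirrors ``test_dataset_join`` and ``test_keep_attrs``; note that ``apply_ufunc`` is still an internal helper at this point, so the import path follows the patched source tree rather than any public API, and the ``source`` attribute is a made-up name used only for illustration::

    import operator
    import numpy as np
    import xarray as xr
    from xarray.core.computation import apply_ufunc  # internal at this stage

    ds0 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
    ds1 = xr.Dataset({'a': ('x', [99, 3]), 'x': [1, 2]})

    # Any dataset_join other than 'inner' requires an explicit fill value.
    try:
        apply_ufunc(operator.add, ds0, ds1, dataset_join='outer')
    except TypeError:
        pass  # expected, as exercised in test_dataset_join above

    # With a fill value, indexes are outer-joined; only x=1 overlaps, so the
    # resulting variable 'a' is [nan, 101, nan] on x = [0, 1, 2].
    summed = apply_ufunc(operator.add, ds0, ds1, join='outer',
                         dataset_join='outer', dataset_fill_value=np.nan)

    # keep_attrs=True copies attributes from the first argument to the output.
    ds0.attrs['source'] = 'example'  # hypothetical attribute for illustration
    with_attrs = apply_ufunc(operator.add, ds0, ds0, keep_attrs=True)
    assert with_attrs.attrs == ds0.attrs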
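Underneath the joins, ``ops.fillna`` reduces to the single element-wise rule ``where(isnull(data), other, data)``. A plain-NumPy stand-in for float data (``np.isnan`` standing in for the more general ``ops.isnull``)::

    import numpy as np

    data = np.array([1.0, np.nan, 3.0, np.nan])
    other = np.array([9.0, 2.0, 9.0, np.nan])

    # Take `data` where it is non-null, otherwise fall back to `other`.
    filled = np.where(np.isnan(data), other, data)
    # -> array([ 1.,  2.,  3., nan])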