From 08ab9d5e0a8bf0de3c79085f96854d4e6faee596 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 13 Oct 2025 13:02:18 -0700 Subject: [PATCH 1/5] Support DataTree in xarray.concat() --- doc/whats-new.rst | 4 +- xarray/structure/concat.py | 97 ++++++++++++++++++++-- xarray/tests/test_concat.py | 160 ++++++++++++++++++++++++++++++++++++ 3 files changed, 253 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 22bdb458d6e..3f59657c0b6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,8 +13,8 @@ v2025.10.2 (unreleased) New Features ~~~~~~~~~~~~ -- :py:func:`merge` now supports merging :py:class:`DataTree` objects - (:issue:`9790`). +- :py:func:`merge` and :py:func:`concat` now support :py:class:`DataTree` + objects (:issue:`9790`, :issue:`9778`). By `Stephan Hoyer <https://github.com/shoyer>`_. Breaking Changes diff --git a/xarray/structure/concat.py b/xarray/structure/concat.py index fbc90ff6a50..2cc56740d7d 100644 --- a/xarray/structure/concat.py +++ b/xarray/structure/concat.py @@ -30,6 +30,7 @@ ) if TYPE_CHECKING: + from xarray.core.datatree import DataTree from xarray.core.types import ( CombineAttrsOptions, CompatOptions, @@ -40,6 +41,21 @@ T_DataVars = Union[ConcatOptions, Iterable[Hashable], None] +@overload +def concat( + objs: Iterable[DataTree], + dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, + data_vars: T_DataVars | CombineKwargDefault = _DATA_VARS_DEFAULT, + coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault = _COORDS_DEFAULT, + compat: CompatOptions | CombineKwargDefault = _COMPAT_CONCAT_DEFAULT, + positions: Iterable[Iterable[int]] | None = None, + fill_value: object = dtypes.NA, + join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT, + combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, +) -> DataTree: ... 
+ + # TODO: replace dim: Any by 1D array_likes @overload def concat( @@ -117,9 +133,7 @@ def concat( coords : {"minimal", "different", "all"} or list of Hashable, optional These coordinate variables will be concatenated together: * "minimal": Only coordinates in which the dimension already appears - are included. If concatenating over a dimension _not_ - present in any of the objects, then all data variables will - be concatenated along that new dimension. + are included. * "different": Coordinates which are not equal (ignoring attributes) across all datasets are also concatenated (as well as all for which dimension already appears). Beware: this option may load the data @@ -265,6 +279,7 @@ def concat( # dimension already exists from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree try: first_obj, objs = utils.peek_at(objs) @@ -278,7 +293,20 @@ def concat( f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" ) - if isinstance(first_obj, DataArray): + if isinstance(first_obj, DataTree): + return _datatree_concat( + objs, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + positions=positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, + ) + elif isinstance(first_obj, DataArray): return _dataarray_concat( objs, dim=dim, @@ -342,7 +370,7 @@ def _calc_concat_over( datasets: list[T_Dataset], dim: Hashable, all_dims: set[Hashable], - data_vars: T_DataVars | CombineKwargDefault, + data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault, coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault, compat: CompatOptions | CombineKwargDefault, ) -> tuple[set[Hashable], dict[Hashable, bool], list[int], set[Hashable]]: @@ -583,6 +611,8 @@ def _dataset_concat( join: JoinOptions | CombineKwargDefault, combine_attrs: CombineAttrsOptions, 
create_index_for_new_dim: bool, + *, + preexisting_dim: bool = False, ) -> T_Dataset: """ Concatenate a sequence of datasets along a new or existing dimension @@ -618,6 +648,8 @@ def _dataset_concat( all_dims, dim_coords, dims_sizes, coord_names, data_names, vars_order = ( _parse_datasets(datasets) ) + if preexisting_dim: + all_dims.add(dim_name) indexed_dim_names = set(dim_coords) both_data_and_coords = coord_names & data_names @@ -819,7 +851,7 @@ def get_indexes(name): def _dataarray_concat( arrays: Iterable[T_DataArray], dim: str | T_Variable | T_DataArray | pd.Index, - data_vars: T_DataVars | CombineKwargDefault, + data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault, coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault, compat: CompatOptions | CombineKwargDefault, positions: Iterable[Iterable[int]] | None, @@ -877,3 +909,56 @@ def _dataarray_concat( result.attrs = merged_attrs return result + + +def _datatree_concat( + objs: Iterable[DataTree], + dim: Hashable | Variable | T_DataArray | pd.Index | Any, + data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault, + coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault, + compat: CompatOptions | CombineKwargDefault, + positions: Iterable[Iterable[int]] | None, + fill_value: Any, + join: JoinOptions | CombineKwargDefault, + combine_attrs: CombineAttrsOptions, + create_index_for_new_dim: bool, +) -> DataTree: + """ + Concatenate a sequence of datatrees along a new or existing dimension + """ + from xarray.core.datatree import DataTree + from xarray.core.treenode import TreeIsomorphismError, group_subtrees + + dim_name, _ = _calc_concat_dim_index(dim) + + objs = list(objs) + if not all(isinstance(obj, DataTree) for obj in objs): + raise TypeError("All objects to concatenate must be DataTree objects") + + if compat == "identical": + if any(obj.name != objs[0].name for obj in objs[1:]): + raise ValueError("DataTree names not identical") + + dim_in_tree = any(dim_name in 
node.dims for node in objs[0].subtree) + + results = {} + try: + for path, nodes in group_subtrees(*objs): + datasets_to_concat = [node.to_dataset() for node in nodes] + results[path] = _dataset_concat( + datasets_to_concat, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + positions=positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, + preexisting_dim=dim_in_tree, + ) + except TreeIsomorphismError as e: + raise ValueError("All trees must be isomorphic to be concatenated") from e + + return DataTree.from_dict(results, name=objs[0].name) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 35dad48ea16..ac200fa790e 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -20,6 +20,7 @@ ) from xarray.core import dtypes, types from xarray.core.coordinates import Coordinates +from xarray.core.datatree import DataTree from xarray.core.indexes import PandasIndex from xarray.structure import merge from xarray.tests import ( @@ -1636,3 +1637,162 @@ def test_concat_multi_dim_index() -> None: joins_lr: list[types.JoinOptions] = ["left", "right"] for join in joins_lr: actual = concat([ds1, ds2], dim="x", join=join) + + +class TestConcatDataTree: + def test_concat_datatree_along_existing_dim(self): + dt1 = DataTree.from_dict(data={"/a": ("x", [1]), "/b": 3}, coords={"/x": [0]}) + dt2 = DataTree.from_dict(data={"/a": ("x", [2]), "/b": 3}, coords={"/x": [1]}) + expected = DataTree.from_dict( + data={"/a": ("x", [1, 2]), "/b": 3}, coords={"/x": [0, 1]} + ) + actual = concat([dt1, dt2], dim="x", data_vars="minimal", coords="minimal") + assert actual.identical(expected) + + def test_concat_datatree_along_existing_dim_defaults(self): + # scalar coordinate + dt1 = DataTree.from_dict(data={"/a": ("x", [1])}, coords={"/x": [0], "/b": 3}) + dt2 = DataTree.from_dict(data={"/a": ("x", [2])}, coords={"/x": [1], "/b": 3}) + expected = 
DataTree.from_dict( + data={"/a": ("x", [1, 2])}, coords={"/x": [0, 1], "b": 3} + ) + actual = concat([dt1, dt2], dim="x") + assert actual.identical(expected) + + # scalar data variable + dt1 = DataTree.from_dict(data={"/a": ("x", [1]), "/b": 3}, coords={"/x": [0]}) + dt2 = DataTree.from_dict(data={"/a": ("x", [2]), "/b": 3}, coords={"/x": [1]}) + expected = DataTree.from_dict( + data={"/a": ("x", [1, 2]), "/b": ("x", [3, 3])}, coords={"/x": [0, 1]} + ) + with pytest.warns( + FutureWarning, match="will change from data_vars='all' to data_vars=None" + ): + actual = concat([dt1, dt2], dim="x") + assert actual.identical(expected) + + def test_concat_datatree_isomorphic_error(self): + dt1 = DataTree.from_dict(data={"/data": ("x", [1]), "/a": None}) + dt2 = DataTree.from_dict(data={"/data": ("x", [2]), "/b": None}) + with pytest.raises( + ValueError, match="All trees must be isomorphic to be concatenated" + ): + concat([dt1, dt2], dim="x", data_vars="minimal", coords="minimal") + + def test_concat_datatree_datavars_all(self): + dt1 = DataTree.from_dict(data={"/a": 1, "/c/b": ("y", [10])}) + dt2 = DataTree.from_dict(data={"/a": 2, "/c/b": ("y", [20])}) + dim = pd.Index([100, 200], name="x") + actual = concat([dt1, dt2], dim=dim, data_vars="all", coords="minimal") + expected = DataTree.from_dict( + data={ + "/a": (("x",), [1, 2]), + "/c/b": (("x", "y"), [[10], [20]]), + }, + coords={"/x": dim}, + ) + assert actual.identical(expected) + + def test_concat_datatree_coords_all(self): + dt1 = DataTree.from_dict(data={"/child/d": ("y", [10])}, coords={"/c": 1}) + dt2 = DataTree.from_dict(data={"/child/d": ("y", [10])}, coords={"/c": 2}) + dim = pd.Index([0, 1], name="x") + actual = concat( + [dt1, dt2], dim=dim, data_vars="minimal", coords="all", compat="equals" + ) + expected = DataTree.from_dict( + data={"/child/d": ("y", [10])}, + coords={ + "/c": (("x",), [1, 2]), + "/x": dim, + "/child/x": dim, + }, + ) + assert actual.identical(expected) + + def 
test_concat_datatree_datavars_different(self): + dt1 = DataTree.from_dict(data={"/a": 0, "/b": 1}) + dt2 = DataTree.from_dict(data={"/a": 0, "/b": 2}) + dim = pd.Index([0, 1], name="x") + actual = concat( + [dt1, dt2], + dim=dim, + data_vars="different", + coords="minimal", + compat="equals", + ) + expected = DataTree.from_dict( + data={"/a": 0, "/b": (("x",), [1, 2])}, coords={"/x": dim} + ) + assert actual.identical(expected) + + def test_concat_datatree_nodes(self): + dt1 = DataTree.from_dict(data={"/a/d": ("x", [1])}, coords={"/x": [0]}) + dt2 = DataTree.from_dict(data={"/a/d": ("x", [2])}, coords={"/x": [1]}) + actual = concat([dt1, dt2], dim="x", data_vars="minimal", coords="minimal") + expected = DataTree.from_dict( + data={"/a/d": ("x", [1, 2])}, coords={"/x": [0, 1]} + ) + assert actual.identical(expected) + + def test_concat_datatree_names(self): + dt1 = DataTree(Dataset({"a": ("x", [1])}), name="a") + dt2 = DataTree(Dataset({"a": ("x", [2])}), name="b") + result = concat( + [dt1, dt2], dim="x", data_vars="minimal", coords="minimal", compat="equals" + ) + assert result.name == "a" + expected = DataTree(Dataset({"a": ("x", [1, 2])}), name="a") + assert result.identical(expected) + + with pytest.raises(ValueError, match="DataTree names not identical"): + concat( + [dt1, dt2], + dim="x", + data_vars="minimal", + coords="minimal", + compat="identical", + ) + + def test_concat_along_new_dim_raises_for_minimal(self): + dt1 = DataTree.from_dict({"/a/d": 1}) + dt2 = DataTree.from_dict({"/a/d": 2}) + with pytest.raises( + ValueError, match="data_vars='minimal' and coords='minimal'" + ): + concat([dt1, dt2], dim="y", data_vars="minimal", coords="minimal") + + def test_concat_data_in_child_only(self): + dt1 = DataTree.from_dict( + data={"/child/a": ("x", [1])}, coords={"/child/x": [0]} + ) + dt2 = DataTree.from_dict( + data={"/child/a": ("x", [2])}, coords={"/child/x": [1]} + ) + actual = concat([dt1, dt2], dim="x", data_vars="minimal", coords="minimal") + expected 
= DataTree.from_dict( + data={"/child/a": ("x", [1, 2])}, coords={"/child/x": [0, 1]} + ) + assert actual.identical(expected) + + def test_concat_data_in_child_only_defaults(self): + dt1 = DataTree.from_dict( + data={"/child/a": ("x", [1])}, coords={"/child/x": [0]} + ) + dt2 = DataTree.from_dict( + data={"/child/a": ("x", [2])}, coords={"/child/x": [1]} + ) + actual = concat([dt1, dt2], dim="x") + expected = DataTree.from_dict( + data={"/child/a": ("x", [1, 2])}, coords={"/child/x": [0, 1]} + ) + assert actual.identical(expected) + + def test_concat_data_in_child_new_dim(self): + dt1 = DataTree.from_dict(data={"/child/a": 1}, coords={"/child/x": 0}) + dt2 = DataTree.from_dict(data={"/child/a": 2}, coords={"/child/x": 1}) + actual = concat([dt1, dt2], dim="x") + expected = DataTree.from_dict( + data={"/child/a": ("x", [1, 2])}, coords={"/child/x": [0, 1]} + ) + assert actual.identical(expected) From 91110525005ae8567d181a456d1fd4385c572887 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 13 Oct 2025 14:20:04 -0700 Subject: [PATCH 2/5] fix docstring --- xarray/structure/concat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/structure/concat.py b/xarray/structure/concat.py index 2cc56740d7d..1450c8f11e1 100644 --- a/xarray/structure/concat.py +++ b/xarray/structure/concat.py @@ -103,7 +103,7 @@ def concat( Parameters ---------- - objs : sequence of Dataset and DataArray + objs : sequence of DataArray, Dataset or DataTree xarray objects to concatenate together. Each object is expected to consist of variables and coordinates with matching shapes except for along the concatenated dimension. @@ -194,7 +194,8 @@ def concat( If a callable, it must expect a sequence of ``attrs`` dicts and a context object as its only parameters. create_index_for_new_dim : bool, default: True - Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``. 
+ Whether to create a new ``PandasIndex`` object when the objects being + concatenated contain scalar variables named ``dim``. Returns ------- From 74f371717e8b58ec45f1490e68509733a56df9cc Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 13 Oct 2025 14:50:13 -0700 Subject: [PATCH 3/5] fix mypy errors --- xarray/structure/concat.py | 4 ++-- xarray/tests/test_concat.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/structure/concat.py b/xarray/structure/concat.py index 1450c8f11e1..f127290e30d 100644 --- a/xarray/structure/concat.py +++ b/xarray/structure/concat.py @@ -603,7 +603,7 @@ def _parse_datasets( def _dataset_concat( datasets: Iterable[T_Dataset], - dim: str | T_Variable | T_DataArray | pd.Index, + dim: Hashable | T_Variable | T_DataArray | pd.Index, data_vars: T_DataVars | CombineKwargDefault, coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault, compat: CompatOptions | CombineKwargDefault, @@ -851,7 +851,7 @@ def get_indexes(name): def _dataarray_concat( arrays: Iterable[T_DataArray], - dim: str | T_Variable | T_DataArray | pd.Index, + dim: Hashable | T_Variable | T_DataArray | pd.Index, data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault, coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault, compat: CompatOptions | CombineKwargDefault, diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index ac200fa790e..7bbcab247da 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1460,12 +1460,12 @@ def test_concat_typing_check() -> None: TypeError, match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", ): - concat([ds, da], dim="foo") # type: ignore[type-var] + concat([ds, da], dim="foo") # type: ignore[list-item] with pytest.raises( TypeError, match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", ): - concat([da, ds], dim="foo") # type: ignore[type-var] + concat([da, 
ds], dim="foo") # type: ignore[list-item] def test_concat_not_all_indexes() -> None: From f934cf4c439bfec78a8e20ae1db406874c77bef8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 13 Oct 2025 15:33:05 -0700 Subject: [PATCH 4/5] Add comment to explain preexisting_dim --- xarray/structure/concat.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/structure/concat.py b/xarray/structure/concat.py index f127290e30d..bb2d96c10be 100644 --- a/xarray/structure/concat.py +++ b/xarray/structure/concat.py @@ -650,6 +650,9 @@ def _dataset_concat( _parse_datasets(datasets) ) if preexisting_dim: + # When concatenating DataTree objects, a dimension may be pre-existing + # because it exists elsewhere on the trees, even if it does not exist + # on the dataset objects at this node. all_dims.add(dim_name) indexed_dim_names = set(dim_coords) From dc5d8273514d7b4135532ad4708946c585ceaf17 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 13 Oct 2025 15:36:25 -0700 Subject: [PATCH 5/5] Add another unit-test --- xarray/tests/test_concat.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 7bbcab247da..8c61d7bced2 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1796,3 +1796,10 @@ def test_concat_data_in_child_new_dim(self): data={"/child/a": ("x", [1, 2])}, coords={"/child/x": [0, 1]} ) assert actual.identical(expected) + + def test_concat_different_dims_in_different_child(self): + dt1 = DataTree.from_dict(coords={"/first/x": [1], "/second/x": [2]}) + dt2 = DataTree.from_dict(coords={"/first/x": [3], "/second/x": [4]}) + actual = concat([dt1, dt2], dim="x") + expected = DataTree.from_dict(coords={"/first/x": [1, 3], "/second/x": [2, 4]}) + assert actual.identical(expected)