Merged
4 changes: 2 additions & 2 deletions doc/whats-new.rst
@@ -13,8 +13,8 @@ v2025.10.2 (unreleased)
New Features
~~~~~~~~~~~~

- :py:func:`merge` now supports merging :py:class:`DataTree` objects
(:issue:`9790`).
- :py:func:`merge` and :py:func:`concat` now support :py:class:`DataTree`
objects (:issue:`9790`, :issue:`9778`).
By `Stephan Hoyer <https://github.com/shoyer>`_.

Breaking Changes
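To illustrate the new behaviour described in the entry above, here is a minimal sketch; the tree contents and variable names are invented for the example, and only the public xr.concat / xr.DataTree.from_dict entry points are assumed:

import xarray as xr

# Two isomorphic trees with the same node structure ("/" and "/child").
tree1 = xr.DataTree.from_dict(
    {
        "/": xr.Dataset({"a": ("x", [1, 2])}),
        "/child": xr.Dataset({"b": ("x", [10, 20])}),
    }
)
tree2 = xr.DataTree.from_dict(
    {
        "/": xr.Dataset({"a": ("x", [3, 4])}),
        "/child": xr.Dataset({"b": ("x", [30, 40])}),
    }
)

# concat is applied node by node and returns a DataTree.
combined = xr.concat([tree1, tree2], dim="x")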
109 changes: 99 additions & 10 deletions xarray/structure/concat.py
@@ -30,6 +30,7 @@
)

if TYPE_CHECKING:
from xarray.core.datatree import DataTree
from xarray.core.types import (
CombineAttrsOptions,
CompatOptions,
@@ -40,6 +41,21 @@
T_DataVars = Union[ConcatOptions, Iterable[Hashable], None]


@overload
def concat(
objs: Iterable[DataTree],
dim: Hashable | T_Variable | T_DataArray | pd.Index | Any,
data_vars: T_DataVars | CombineKwargDefault = _DATA_VARS_DEFAULT,
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault = _COORDS_DEFAULT,
compat: CompatOptions | CombineKwargDefault = _COMPAT_CONCAT_DEFAULT,
positions: Iterable[Iterable[int]] | None = None,
fill_value: object = dtypes.NA,
join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
combine_attrs: CombineAttrsOptions = "override",
create_index_for_new_dim: bool = True,
) -> DataTree: ...


# TODO: replace dim: Any by 1D array_likes
@overload
def concat(
@@ -87,7 +103,7 @@ def concat(

Parameters
----------
objs : sequence of Dataset and DataArray
objs : sequence of DataArray, Dataset or DataTree
xarray objects to concatenate together. Each object is expected to
consist of variables and coordinates with matching shapes except for
along the concatenated dimension.
@@ -117,9 +133,7 @@
coords : {"minimal", "different", "all"} or list of Hashable, optional
These coordinate variables will be concatenated together:
* "minimal": Only coordinates in which the dimension already appears
are included. If concatenating over a dimension _not_
present in any of the objects, then all data variables will
be concatenated along that new dimension.
are included.
* "different": Coordinates which are not equal (ignoring attributes)
across all datasets are also concatenated (as well as all for which
dimension already appears). Beware: this option may load the data
@@ -180,7 +194,8 @@ def concat(
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
create_index_for_new_dim : bool, default: True
Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``.
Whether to create a new ``PandasIndex`` object when the objects being
concatenated contain scalar variables named ``dim``.

Returns
-------
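As a small sketch of the create_index_for_new_dim option documented above (it mirrors xarray's existing docstring example for this parameter; the coordinate name is invented):

import xarray as xr

ds_a = xr.Dataset(coords={"run": 0})
ds_b = xr.Dataset(coords={"run": 1})

# Default: the scalar "run" coordinates are stacked and a PandasIndex is built.
with_index = xr.concat([ds_a, ds_b], dim="run")
assert "run" in with_index.indexes

# Opting out keeps the concatenated "run" coordinate but skips index creation.
without_index = xr.concat([ds_a, ds_b], dim="run", create_index_for_new_dim=False)
assert "run" not in without_index.indexes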
Expand Down Expand Up @@ -265,6 +280,7 @@ def concat(
# dimension already exists
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree

try:
first_obj, objs = utils.peek_at(objs)
@@ -278,7 +294,20 @@
f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'"
)

if isinstance(first_obj, DataArray):
if isinstance(first_obj, DataTree):
return _datatree_concat(
objs,
dim=dim,
data_vars=data_vars,
coords=coords,
compat=compat,
positions=positions,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
create_index_for_new_dim=create_index_for_new_dim,
)
elif isinstance(first_obj, DataArray):
return _dataarray_concat(
objs,
dim=dim,
@@ -342,7 +371,7 @@ def _calc_concat_over(
datasets: list[T_Dataset],
dim: Hashable,
all_dims: set[Hashable],
data_vars: T_DataVars | CombineKwargDefault,
data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault,
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
compat: CompatOptions | CombineKwargDefault,
) -> tuple[set[Hashable], dict[Hashable, bool], list[int], set[Hashable]]:
@@ -574,7 +603,7 @@ def _parse_datasets(

def _dataset_concat(
datasets: Iterable[T_Dataset],
dim: str | T_Variable | T_DataArray | pd.Index,
dim: Hashable | T_Variable | T_DataArray | pd.Index,
data_vars: T_DataVars | CombineKwargDefault,
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
compat: CompatOptions | CombineKwargDefault,
@@ -583,6 +612,8 @@ def _dataset_concat(
join: JoinOptions | CombineKwargDefault,
combine_attrs: CombineAttrsOptions,
create_index_for_new_dim: bool,
*,
preexisting_dim: bool = False,
) -> T_Dataset:
"""
Concatenate a sequence of datasets along a new or existing dimension
@@ -618,6 +649,11 @@ def _dataset_concat(
all_dims, dim_coords, dims_sizes, coord_names, data_names, vars_order = (
_parse_datasets(datasets)
)
if preexisting_dim:
# When concatenating DataTree objects, a dimension may be pre-existing
# because it exists elsewhere on the trees, even if it does not exist
# on the dataset objects at this node.
all_dims.add(dim_name)
indexed_dim_names = set(dim_coords)

both_data_and_coords = coord_names & data_names
@@ -818,8 +854,8 @@ def get_indexes(name):

def _dataarray_concat(
arrays: Iterable[T_DataArray],
dim: str | T_Variable | T_DataArray | pd.Index,
data_vars: T_DataVars | CombineKwargDefault,
dim: Hashable | T_Variable | T_DataArray | pd.Index,
data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault,
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
compat: CompatOptions | CombineKwargDefault,
positions: Iterable[Iterable[int]] | None,
@@ -877,3 +913,56 @@ def _dataarray_concat(
result.attrs = merged_attrs

return result


def _datatree_concat(
objs: Iterable[DataTree],
dim: Hashable | Variable | T_DataArray | pd.Index | Any,
data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault,
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
compat: CompatOptions | CombineKwargDefault,
positions: Iterable[Iterable[int]] | None,
fill_value: Any,
join: JoinOptions | CombineKwargDefault,
combine_attrs: CombineAttrsOptions,
create_index_for_new_dim: bool,
) -> DataTree:
"""
Concatenate a sequence of datatrees along a new or existing dimension
"""
from xarray.core.datatree import DataTree
from xarray.core.treenode import TreeIsomorphismError, group_subtrees

dim_name, _ = _calc_concat_dim_index(dim)

objs = list(objs)
if not all(isinstance(obj, DataTree) for obj in objs):
raise TypeError("All objects to concatenate must be DataTree objects")

if compat == "identical":
if any(obj.name != objs[0].name for obj in objs[1:]):
raise ValueError("DataTree names not identical")

dim_in_tree = any(dim_name in node.dims for node in objs[0].subtree)

results = {}
try:
for path, nodes in group_subtrees(*objs):
datasets_to_concat = [node.to_dataset() for node in nodes]
results[path] = _dataset_concat(
datasets_to_concat,
dim=dim,
data_vars=data_vars,
coords=coords,
compat=compat,
positions=positions,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
create_index_for_new_dim=create_index_for_new_dim,
preexisting_dim=dim_in_tree,
)
except TreeIsomorphismError as e:
raise ValueError("All trees must be isomorphic to be concatenated") from e

return DataTree.from_dict(results, name=objs[0].name)
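A rough illustration of the pre-existing-dimension handling above; node names and values are invented, and the expected result is inferred from the preexisting_dim logic rather than stated in this diff:

import xarray as xr

# "x" exists only on the child node, so dim_in_tree is True and every
# node-level _dataset_concat call treats "x" as pre-existing.
t1 = xr.DataTree.from_dict(
    {"/": xr.Dataset(), "/child": xr.Dataset({"v": ("x", [1, 2])})}
)
t2 = xr.DataTree.from_dict(
    {"/": xr.Dataset(), "/child": xr.Dataset({"v": ("x", [3, 4])})}
)

combined = xr.concat([t1, t2], dim="x")
# The child node is expected to end up with x of length 4, while the empty
# root dataset is left without an "x" dimension.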