Skip to content

Update concat for multi-variable indexes. #10371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 5, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions xarray/structure/concat.py
Original file line number Diff line number Diff line change
@@ -324,9 +324,15 @@ def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, c
"""
Determine which dataset variables need to be concatenated in the result,
"""
# Return values
# variables to be concatenated
concat_over = set()
# variables checked for equality
equals = {}
# skip merging these variables.
# if concatenating over a dimension 'x' that is associated with an index over 2 variables,
# 'x' and 'y', then we assert join="equals" on `y` and don't need to merge it.
# that assertion happens in the align step prior to this function being called
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering how this would behave for the corner case where the "y" coordinate (and/or the "x" coordinate) does not have any index in one or more of the objects to concatenate. We could leave it for now, though. It is quite unlikely that it will occur in practice I'd say.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean, we defer to the Index, so it can do what's sensible?

skip_merge = set()

if dim in dim_names:
concat_over_existing_dim = True
@@ -339,6 +345,9 @@ def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, c
if concat_over_existing_dim and dim not in ds.dims and dim in ds:
ds = ds.set_coords(dim)
concat_over.update(k for k, v in ds.variables.items() if dim in v.dims)
for _, idx_vars in ds.xindexes.group_by_index():
if any(dim in v.dims for v in idx_vars.values()):
skip_merge.update(idx_vars.keys())
concat_dim_lengths.append(ds.sizes.get(dim, 1))

def process_subset_opt(opt, subset):
@@ -436,7 +445,7 @@ def process_subset_opt(opt, subset):

process_subset_opt(data_vars, "data_vars")
process_subset_opt(coords, "coords")
return concat_over, equals, concat_dim_lengths
return concat_over, equals, concat_dim_lengths, skip_merge


# determine dimensional coordinate names and a dict mapping name to DataArray
@@ -540,12 +549,12 @@ def _dataset_concat(
]

# determine which variables to concatenate
concat_over, equals, concat_dim_lengths = _calc_concat_over(
concat_over, equals, concat_dim_lengths, skip_merge = _calc_concat_over(
datasets, dim_name, dim_names, data_vars, coords, compat
)

# determine which variables to merge, and then merge them according to compat
variables_to_merge = (coord_names | data_names) - concat_over
variables_to_merge = (coord_names | data_names) - concat_over - skip_merge

result_vars = {}
result_indexes = {}
73 changes: 73 additions & 0 deletions xarray/tests/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import Any

import numpy as np

from xarray import Variable
from xarray.core.indexes import Index, PandasIndex
from xarray.core.types import Self


class ScalarIndex(Index):
def __init__(self, value: int):
self.value = value

@classmethod
def from_variables(cls, variables, *, options) -> Self:
var = next(iter(variables.values()))
return cls(int(var.values))

def equals(self, other, *, exclude=None) -> bool:
return isinstance(other, ScalarIndex) and other.value == self.value


class XYIndex(Index):
def __init__(self, x: PandasIndex, y: PandasIndex):
self.x: PandasIndex = x
self.y: PandasIndex = y

@classmethod
def from_variables(cls, variables, *, options):
return cls(
x=PandasIndex.from_variables({"x": variables["x"]}, options=options),
y=PandasIndex.from_variables({"y": variables["y"]}, options=options),
)

def create_variables(
self, variables: Mapping[Any, Variable] | None = None
) -> dict[Any, Variable]:
return self.x.create_variables() | self.y.create_variables()

def equals(self, other, exclude=None):
if exclude is None:
exclude = frozenset()
x_eq = True if self.x.dim in exclude else self.x.equals(other.x)
y_eq = True if self.y.dim in exclude else self.y.equals(other.y)
return x_eq and y_eq

@classmethod
def concat(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added this

cls,
indexes: Sequence[Self],
dim: Hashable,
positions: Iterable[Iterable[int]] | None = None,
) -> Self:
first = next(iter(indexes))
if dim == "x":
newx = PandasIndex.concat(
tuple(i.x for i in indexes), dim=dim, positions=positions
)
newy = first.y
elif dim == "y":
newx = first.x
newy = PandasIndex.concat(
tuple(i.y for i in indexes), dim=dim, positions=positions
)
return cls(x=newx, y=newy)

def isel(self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]) -> Self:
newx = self.x.isel({"x": indexers.get("x", slice(None))})
newy = self.y.isel({"y": indexers.get("y", slice(None))})
assert newx is not None
assert newy is not None
return type(self)(newx, newy)
52 changes: 50 additions & 2 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
import pandas as pd
import pytest

from xarray import DataArray, Dataset, Variable, concat
from xarray.core import dtypes
from xarray import AlignmentError, DataArray, Dataset, Variable, concat
from xarray.core import dtypes, types
from xarray.core.coordinates import Coordinates
from xarray.core.indexes import PandasIndex
from xarray.structure import merge
@@ -23,6 +23,7 @@
requires_dask,
requires_pyarrow,
)
from xarray.tests.indexes import XYIndex
from xarray.tests.test_dataset import create_test_data

if TYPE_CHECKING:
@@ -1381,3 +1382,50 @@ def test_concat_index_not_same_dim() -> None:
match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*",
):
concat([ds1, ds2], dim="x")


def test_concat_multi_dim_index() -> None:
ds1 = (
Dataset(
{"foo": (("x", "y"), np.random.randn(2, 2))},
coords={"x": [1, 2], "y": [3, 4]},
)
.drop_indexes(["x", "y"])
.set_xindex(["x", "y"], XYIndex)
)
ds2 = (
Dataset(
{"foo": (("x", "y"), np.random.randn(2, 2))},
coords={"x": [1, 2], "y": [5, 6]},
)
.drop_indexes(["x", "y"])
.set_xindex(["x", "y"], XYIndex)
)

expected = (
Dataset(
{
"foo": (
("x", "y"),
np.concatenate([ds1.foo.data, ds2.foo.data], axis=-1),
)
},
coords={"x": [1, 2], "y": [3, 4, 5, 6]},
)
.drop_indexes(["x", "y"])
.set_xindex(["x", "y"], XYIndex)
)
# note: missing 'override'
joins: list[types.JoinOptions] = ["inner", "outer", "exact", "left", "right"]
for join in joins:
actual = concat([ds1, ds2], dim="y", join=join)
assert_identical(actual, expected, check_default_indexes=False)

with pytest.raises(AlignmentError):
actual = concat([ds1, ds2], dim="x", join="exact")

# TODO: fix these, or raise better error message
with pytest.raises(AssertionError):
joins_lr: list[types.JoinOptions] = ["left", "right"]
for join in joins_lr:
actual = concat([ds1, ds2], dim="x", join=join)
69 changes: 4 additions & 65 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,8 @@

import contextlib

from pandas.errors import UndefinedVariableError

import xarray as xr
from xarray import (
AlignmentError,
@@ -72,13 +74,7 @@
requires_sparse,
source_ndarray,
)

try:
from pandas.errors import UndefinedVariableError
except ImportError:
# TODO: remove once we stop supporting pandas<1.4.3
from pandas.core.computation.ops import UndefinedVariableError

from xarray.tests.indexes import ScalarIndex, XYIndex

with contextlib.suppress(ImportError):
import dask.array as da
@@ -1602,32 +1598,8 @@ def test_isel_multicoord_index(self) -> None:
# regression test https://github.com/pydata/xarray/issues/10063
# isel on a multi-coordinate index should return a unique index associated
# to each coordinate
class MultiCoordIndex(xr.Index):
def __init__(self, idx1, idx2):
self.idx1 = idx1
self.idx2 = idx2

@classmethod
def from_variables(cls, variables, *, options=None):
idx1 = PandasIndex.from_variables(
{"x": variables["x"]}, options=options
)
idx2 = PandasIndex.from_variables(
{"y": variables["y"]}, options=options
)

return cls(idx1, idx2)

def create_variables(self, variables=None):
return {**self.idx1.create_variables(), **self.idx2.create_variables()}

def isel(self, indexers):
idx1 = self.idx1.isel({"x": indexers.get("x", slice(None))})
idx2 = self.idx2.isel({"y": indexers.get("y", slice(None))})
return MultiCoordIndex(idx1, idx2)

coords = xr.Coordinates(coords={"x": [0, 1], "y": [1, 2]}, indexes={})
ds = xr.Dataset(coords=coords).set_xindex(["x", "y"], MultiCoordIndex)
ds = xr.Dataset(coords=coords).set_xindex(["x", "y"], XYIndex)

ds2 = ds.isel(x=slice(None), y=slice(None))
assert ds2.xindexes["x"] is ds2.xindexes["y"]
@@ -2642,18 +2614,6 @@ def test_align_index_var_attrs(self, join) -> None:
def test_align_scalar_index(self) -> None:
# ensure that indexes associated with scalar coordinates are not ignored
# during alignment
class ScalarIndex(Index):
def __init__(self, value: int):
self.value = value

@classmethod
def from_variables(cls, variables, *, options):
var = next(iter(variables.values()))
return cls(int(var.values))

def equals(self, other, *, exclude=None):
return isinstance(other, ScalarIndex) and other.value == self.value

ds1 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex)
ds2 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex)

@@ -2667,27 +2627,6 @@ def equals(self, other, *, exclude=None):
xr.align(ds1, ds3, join="exact")

def test_align_multi_dim_index_exclude_dims(self) -> None:
class XYIndex(Index):
def __init__(self, x: PandasIndex, y: PandasIndex):
self.x: PandasIndex = x
self.y: PandasIndex = y

@classmethod
def from_variables(cls, variables, *, options):
return cls(
x=PandasIndex.from_variables(
{"x": variables["x"]}, options=options
),
y=PandasIndex.from_variables(
{"y": variables["y"]}, options=options
),
)

def equals(self, other, exclude=None):
x_eq = True if self.x.dim in exclude else self.x.equals(other.x)
y_eq = True if self.y.dim in exclude else self.y.equals(other.y)
return x_eq and y_eq

ds1 = (
Dataset(coords={"x": [1, 2], "y": [3, 4]})
.drop_indexes(["x", "y"])
Loading
Oops, something went wrong.