Skip to content

Commit fa01fad

Browse files
dcherian and benbovy authored
Update concat for multi-variable indexes. (#10371)
Co-authored-by: Benoit Bovy <benbovy@gmail.com>
1 parent 68e6d35 commit fa01fad

File tree

4 files changed

+140
-71
lines changed

4 files changed

+140
-71
lines changed

xarray/structure/concat.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,15 @@ def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, c
324324
"""
325325
Determine which dataset variables need to be concatenated in the result,
326326
"""
327-
# Return values
327+
# variables to be concatenated
328328
concat_over = set()
329+
# variables checked for equality
329330
equals = {}
331+
# skip merging these variables.
332+
# if concatenating over a dimension 'x' that is associated with an index over 2 variables,
333+
# 'x' and 'y', then we assert join="equals" on `y` and don't need to merge it.
334+
# that assertion happens in the align step prior to this function being called
335+
skip_merge = set()
330336

331337
if dim in dim_names:
332338
concat_over_existing_dim = True
@@ -339,6 +345,9 @@ def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, c
339345
if concat_over_existing_dim and dim not in ds.dims and dim in ds:
340346
ds = ds.set_coords(dim)
341347
concat_over.update(k for k, v in ds.variables.items() if dim in v.dims)
348+
for _, idx_vars in ds.xindexes.group_by_index():
349+
if any(dim in v.dims for v in idx_vars.values()):
350+
skip_merge.update(idx_vars.keys())
342351
concat_dim_lengths.append(ds.sizes.get(dim, 1))
343352

344353
def process_subset_opt(opt, subset):
@@ -436,7 +445,7 @@ def process_subset_opt(opt, subset):
436445

437446
process_subset_opt(data_vars, "data_vars")
438447
process_subset_opt(coords, "coords")
439-
return concat_over, equals, concat_dim_lengths
448+
return concat_over, equals, concat_dim_lengths, skip_merge
440449

441450

442451
# determine dimensional coordinate names and a dict mapping name to DataArray
@@ -540,12 +549,12 @@ def _dataset_concat(
540549
]
541550

542551
# determine which variables to concatenate
543-
concat_over, equals, concat_dim_lengths = _calc_concat_over(
552+
concat_over, equals, concat_dim_lengths, skip_merge = _calc_concat_over(
544553
datasets, dim_name, dim_names, data_vars, coords, compat
545554
)
546555

547556
# determine which variables to merge, and then merge them according to compat
548-
variables_to_merge = (coord_names | data_names) - concat_over
557+
variables_to_merge = (coord_names | data_names) - concat_over - skip_merge
549558

550559
result_vars = {}
551560
result_indexes = {}

xarray/tests/indexes.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from collections.abc import Hashable, Iterable, Mapping, Sequence
2+
from typing import Any
3+
4+
import numpy as np
5+
6+
from xarray import Variable
7+
from xarray.core.indexes import Index, PandasIndex
8+
from xarray.core.types import Self
9+
10+
11+
class ScalarIndex(Index):
    """Minimal test index that wraps a single integer value.

    Two ``ScalarIndex`` objects compare equal when they hold the same value.
    """

    def __init__(self, value: int):
        self.value = value

    @classmethod
    def from_variables(cls, variables, *, options) -> Self:
        """Build the index from the first variable in ``variables``.

        ``options`` is accepted for API compatibility and is not used.
        """
        first_var = next(iter(variables.values()))
        scalar = int(first_var.values)
        return cls(scalar)

    def equals(self, other, *, exclude=None) -> bool:
        """Return True when ``other`` is a ``ScalarIndex`` with the same value.

        ``exclude`` is accepted for API compatibility and is not used.
        """
        if not isinstance(other, ScalarIndex):
            return False
        return other.value == self.value
22+
23+
24+
class XYIndex(Index):
    """Test index spanning two coordinates, ``x`` and ``y``.

    Each coordinate is wrapped in its own ``PandasIndex``; every operation
    is delegated to the wrapped index of the relevant dimension.
    """

    def __init__(self, x: PandasIndex, y: PandasIndex):
        self.x: PandasIndex = x
        self.y: PandasIndex = y

    @classmethod
    def from_variables(cls, variables, *, options) -> Self:
        """Build the index from a mapping holding ``"x"`` and ``"y"`` variables."""
        return cls(
            x=PandasIndex.from_variables({"x": variables["x"]}, options=options),
            y=PandasIndex.from_variables({"y": variables["y"]}, options=options),
        )

    def create_variables(
        self, variables: Mapping[Any, Variable] | None = None
    ) -> dict[Any, Variable]:
        """Return the coordinate variables of both wrapped indexes.

        ``variables`` is accepted for API compatibility and is not used.
        """
        return self.x.create_variables() | self.y.create_variables()

    def equals(self, other, exclude=None) -> bool:
        """Compare both wrapped indexes, skipping dimensions in ``exclude``.

        Returns False (instead of failing on a missing attribute) when
        ``other`` is not an ``XYIndex``.
        """
        if not isinstance(other, XYIndex):
            return False
        if exclude is None:
            exclude = frozenset()
        x_eq = True if self.x.dim in exclude else self.x.equals(other.x)
        y_eq = True if self.y.dim in exclude else self.y.equals(other.y)
        return x_eq and y_eq

    @classmethod
    def concat(
        cls,
        indexes: Sequence[Self],
        dim: Hashable,
        positions: Iterable[Iterable[int]] | None = None,
    ) -> Self:
        """Concatenate ``indexes`` along ``dim``, which must be "x" or "y".

        The wrapped index of the other dimension is taken from the first
        element of ``indexes`` unchanged.

        Raises
        ------
        ValueError
            If ``dim`` is neither ``"x"`` nor ``"y"``.
        """
        if dim not in ("x", "y"):
            # without this guard newx/newy would be unbound below
            raise ValueError(
                f"XYIndex can only be concatenated along 'x' or 'y', got {dim!r}"
            )
        first = next(iter(indexes))
        if dim == "x":
            newx = PandasIndex.concat(
                tuple(i.x for i in indexes), dim=dim, positions=positions
            )
            newy = first.y
        else:
            newx = first.x
            newy = PandasIndex.concat(
                tuple(i.y for i in indexes), dim=dim, positions=positions
            )
        return cls(x=newx, y=newy)

    def isel(self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]) -> Self:
        """Index both wrapped indexes; a missing indexer selects everything."""
        newx = self.x.isel({"x": indexers.get("x", slice(None))})
        newy = self.y.isel({"y": indexers.get("y", slice(None))})
        # the wrapped isel may return None; this index cannot represent that
        assert newx is not None
        assert newy is not None
        return type(self)(newx, newy)

xarray/tests/test_concat.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import pandas as pd
99
import pytest
1010

11-
from xarray import DataArray, Dataset, Variable, concat
12-
from xarray.core import dtypes
11+
from xarray import AlignmentError, DataArray, Dataset, Variable, concat
12+
from xarray.core import dtypes, types
1313
from xarray.core.coordinates import Coordinates
1414
from xarray.core.indexes import PandasIndex
1515
from xarray.structure import merge
@@ -23,6 +23,7 @@
2323
requires_dask,
2424
requires_pyarrow,
2525
)
26+
from xarray.tests.indexes import XYIndex
2627
from xarray.tests.test_dataset import create_test_data
2728

2829
if TYPE_CHECKING:
@@ -1381,3 +1382,50 @@ def test_concat_index_not_same_dim() -> None:
13811382
match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*",
13821383
):
13831384
concat([ds1, ds2], dim="x")
1385+
1386+
1387+
def test_concat_multi_dim_index() -> None:
    """Concatenate datasets whose ``x`` and ``y`` coordinates share one index.

    Concatenating along ``y`` must succeed for every tested join because the
    ``x`` coordinates agree. Concatenating along ``x`` must fail because the
    ``y`` coordinates differ.
    """
    ds1 = (
        Dataset(
            {"foo": (("x", "y"), np.random.randn(2, 2))},
            coords={"x": [1, 2], "y": [3, 4]},
        )
        .drop_indexes(["x", "y"])
        .set_xindex(["x", "y"], XYIndex)
    )
    ds2 = (
        Dataset(
            {"foo": (("x", "y"), np.random.randn(2, 2))},
            coords={"x": [1, 2], "y": [5, 6]},
        )
        .drop_indexes(["x", "y"])
        .set_xindex(["x", "y"], XYIndex)
    )

    expected = (
        Dataset(
            {
                "foo": (
                    ("x", "y"),
                    np.concatenate([ds1.foo.data, ds2.foo.data], axis=-1),
                )
            },
            coords={"x": [1, 2], "y": [3, 4, 5, 6]},
        )
        .drop_indexes(["x", "y"])
        .set_xindex(["x", "y"], XYIndex)
    )
    # note: missing 'override'
    joins: list[types.JoinOptions] = ["inner", "outer", "exact", "left", "right"]
    for join in joins:
        actual = concat([ds1, ds2], dim="y", join=join)
        assert_identical(actual, expected, check_default_indexes=False)

    with pytest.raises(AlignmentError):
        concat([ds1, ds2], dim="x", join="exact")

    # TODO: fix these, or raise better error message
    joins_lr: list[types.JoinOptions] = ["left", "right"]
    for join in joins_lr:
        # one `raises` context per join: a single context around the whole
        # loop would exit on the first failure and never exercise "right"
        with pytest.raises(AssertionError):
            concat([ds1, ds2], dim="x", join=join)

xarray/tests/test_dataset.py

Lines changed: 4 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
import contextlib
2525

26+
from pandas.errors import UndefinedVariableError
27+
2628
import xarray as xr
2729
from xarray import (
2830
AlignmentError,
@@ -72,13 +74,7 @@
7274
requires_sparse,
7375
source_ndarray,
7476
)
75-
76-
try:
77-
from pandas.errors import UndefinedVariableError
78-
except ImportError:
79-
# TODO: remove once we stop supporting pandas<1.4.3
80-
from pandas.core.computation.ops import UndefinedVariableError
81-
77+
from xarray.tests.indexes import ScalarIndex, XYIndex
8278

8379
with contextlib.suppress(ImportError):
8480
import dask.array as da
@@ -1602,32 +1598,8 @@ def test_isel_multicoord_index(self) -> None:
16021598
# regression test https://github.com/pydata/xarray/issues/10063
16031599
# isel on a multi-coordinate index should return a unique index associated
16041600
# to each coordinate
1605-
class MultiCoordIndex(xr.Index):
1606-
def __init__(self, idx1, idx2):
1607-
self.idx1 = idx1
1608-
self.idx2 = idx2
1609-
1610-
@classmethod
1611-
def from_variables(cls, variables, *, options=None):
1612-
idx1 = PandasIndex.from_variables(
1613-
{"x": variables["x"]}, options=options
1614-
)
1615-
idx2 = PandasIndex.from_variables(
1616-
{"y": variables["y"]}, options=options
1617-
)
1618-
1619-
return cls(idx1, idx2)
1620-
1621-
def create_variables(self, variables=None):
1622-
return {**self.idx1.create_variables(), **self.idx2.create_variables()}
1623-
1624-
def isel(self, indexers):
1625-
idx1 = self.idx1.isel({"x": indexers.get("x", slice(None))})
1626-
idx2 = self.idx2.isel({"y": indexers.get("y", slice(None))})
1627-
return MultiCoordIndex(idx1, idx2)
1628-
16291601
coords = xr.Coordinates(coords={"x": [0, 1], "y": [1, 2]}, indexes={})
1630-
ds = xr.Dataset(coords=coords).set_xindex(["x", "y"], MultiCoordIndex)
1602+
ds = xr.Dataset(coords=coords).set_xindex(["x", "y"], XYIndex)
16311603

16321604
ds2 = ds.isel(x=slice(None), y=slice(None))
16331605
assert ds2.xindexes["x"] is ds2.xindexes["y"]
@@ -2642,18 +2614,6 @@ def test_align_index_var_attrs(self, join) -> None:
26422614
def test_align_scalar_index(self) -> None:
26432615
# ensure that indexes associated with scalar coordinates are not ignored
26442616
# during alignment
2645-
class ScalarIndex(Index):
2646-
def __init__(self, value: int):
2647-
self.value = value
2648-
2649-
@classmethod
2650-
def from_variables(cls, variables, *, options):
2651-
var = next(iter(variables.values()))
2652-
return cls(int(var.values))
2653-
2654-
def equals(self, other, *, exclude=None):
2655-
return isinstance(other, ScalarIndex) and other.value == self.value
2656-
26572617
ds1 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex)
26582618
ds2 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex)
26592619

@@ -2667,27 +2627,6 @@ def equals(self, other, *, exclude=None):
26672627
xr.align(ds1, ds3, join="exact")
26682628

26692629
def test_align_multi_dim_index_exclude_dims(self) -> None:
2670-
class XYIndex(Index):
2671-
def __init__(self, x: PandasIndex, y: PandasIndex):
2672-
self.x: PandasIndex = x
2673-
self.y: PandasIndex = y
2674-
2675-
@classmethod
2676-
def from_variables(cls, variables, *, options):
2677-
return cls(
2678-
x=PandasIndex.from_variables(
2679-
{"x": variables["x"]}, options=options
2680-
),
2681-
y=PandasIndex.from_variables(
2682-
{"y": variables["y"]}, options=options
2683-
),
2684-
)
2685-
2686-
def equals(self, other, exclude=None):
2687-
x_eq = True if self.x.dim in exclude else self.x.equals(other.x)
2688-
y_eq = True if self.y.dim in exclude else self.y.equals(other.y)
2689-
return x_eq and y_eq
2690-
26912630
ds1 = (
26922631
Dataset(coords={"x": [1, 2], "y": [3, 4]})
26932632
.drop_indexes(["x", "y"])

0 commit comments

Comments
 (0)