Skip to content

Commit

Permalink
Merge branch 'main' into extension_arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Apr 18, 2024
2 parents c906c81 + 60f3e74 commit e6db83b
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 39 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
By `Matt Savoie <https://github.com/flamingbear>`_ `Owen Littlejohns
<https://github.com/owenlittlejohns>` and `Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
Expand Down
13 changes: 7 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
complete = ["xarray[accel,io,parallel,viz,dev]"]
dev = [
"hypothesis",
"mypy",
"pre-commit",
"pytest",
"pytest-cov",
Expand Down Expand Up @@ -86,8 +87,8 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
[tool.mypy]
enable_error_code = "redundant-self"
exclude = [
'xarray/util/generate_.*\.py',
'xarray/datatree_/.*\.py',
'xarray/util/generate_.*\.py',
'xarray/datatree_/.*\.py',
]
files = "xarray"
show_error_codes = true
Expand All @@ -98,8 +99,8 @@ warn_unused_ignores = true

# Ignore mypy errors for modules imported from datatree_.
[[tool.mypy.overrides]]
module = "xarray.datatree_.*"
ignore_errors = true
module = "xarray.datatree_.*"

# Much of the numerical computing stack doesn't have type annotations yet.
[[tool.mypy.overrides]]
Expand Down Expand Up @@ -256,6 +257,9 @@ target-version = "py39"
# E402: module level import not at top of file
# E501: line too long - let black worry about that
# E731: do not assign a lambda expression, use a def
extend-safe-fixes = [
"TID252", # absolute imports
]
ignore = [
"E402",
"E501",
Expand All @@ -269,9 +273,6 @@ select = [
"I", # isort
"UP", # Pyupgrade
]
extend-safe-fixes = [
"TID252", # absolute imports
]

[tool.ruff.lint.per-file-ignores]
# don't enforce absolute imports
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
from xarray.core.datatree_mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.core.indexes import Index, Indexes
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
Expand All @@ -36,11 +41,6 @@
from xarray.datatree_.datatree.formatting_html import (
datatree_repr as datatree_repr_html,
)
from xarray.datatree_.datatree.mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.ops import (
DataTreeArithmeticMixin,
MappedDatasetMethodsMixin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Tuple
from typing import TYPE_CHECKING, Callable

from xarray import DataArray, Dataset

from xarray.core.iterators import LevelOrderIter
from xarray.core.treenode import NodePath, TreeNode

Expand Down Expand Up @@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal:
if node_a.name != node_b.name:
diff = dedent(
f"""\
if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
Expand Down Expand Up @@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
func : callable
Function to apply to datasets with signature:
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
`func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.
(i.e. func must accept at least one Dataset and return at least one Dataset.)
Function will not be applied to any nodes without datasets.
Expand Down Expand Up @@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from xarray.core.datatree import DataTree

Expand Down Expand Up @@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
return _map_over_subtree


def _handle_errors_with_path_context(path):
def _handle_errors_with_path_context(path: str):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""

def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if sys.version_info >= (3, 11):
# Add the context information to the error message
e.add_note(
f"Raised whilst mapping function over node with path {path}"
)
# Add the context information to the error message
add_note(
e, f"Raised whilst mapping function over node with path {path}"
)
raise

return wrapper
Expand All @@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)


def _check_single_set_return_values(path_to_node, obj):
def _check_single_set_return_values(
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, (Dataset, DataArray)):
return 1
Expand Down
3 changes: 1 addition & 2 deletions xarray/core/iterators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections import abc
from collections.abc import Iterator
from typing import Callable

Expand All @@ -9,7 +8,7 @@
"""These iterators are copied from anytree.iterators, with minor modifications."""


class LevelOrderIter(abc.Iterator):
class LevelOrderIter(Iterator):
"""Iterate over tree applying level-order strategy starting at `node`.
This is the iterator used by `DataTree` to traverse nodes.
Expand Down
3 changes: 0 additions & 3 deletions xarray/datatree_/datatree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
# import public API
from .mapping import TreeIsomorphismError, map_over_subtree
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError


__all__ = (
"TreeIsomorphismError",
"InvalidTreeError",
"NotFoundInTreeError",
"map_over_subtree",
)
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xarray.core.formatting import _compat_to_str, diff_dataset_repr

from xarray.datatree_.datatree.mapping import diff_treestructure
from xarray.core.datatree_mapping import diff_treestructure
from xarray.datatree_.datatree.render import RenderTree

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xarray import Dataset

from .mapping import map_over_subtree
from xarray.core.datatree_mapping import map_over_subtree

"""
Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import numpy as np
import pytest
import xarray as xr

import xarray as xr
from xarray.core.datatree import DataTree
from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree
from xarray.core.datatree_mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.testing import assert_equal

empty = xr.Dataset()
Expand All @@ -12,7 +16,7 @@
class TestCheckTreesIsomorphic:
def test_not_a_tree(self):
with pytest.raises(TypeError, match="not a tree"):
check_isomorphic("s", 1)
check_isomorphic("s", 1) # type: ignore[arg-type]

def test_different_widths(self):
dt1 = DataTree.from_dict(d={"a": empty})
Expand Down Expand Up @@ -69,7 +73,7 @@ def test_not_isomorphic_complex_tree(self, create_test_datatree):
def test_checking_from_root(self, create_test_datatree):
dt1 = create_test_datatree()
dt2 = create_test_datatree()
real_root = DataTree(name="real root")
real_root: DataTree = DataTree(name="real root")
dt2.name = "not_real_root"
dt2.parent = real_root
with pytest.raises(TreeIsomorphismError):
Expand Down

0 comments on commit e6db83b

Please sign in to comment.