-
Notifications
You must be signed in to change notification settings - Fork 43
/
mapping.py
310 lines (250 loc) · 12 KB
/
mapping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
from __future__ import annotations
import functools
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Tuple
from xarray import DataArray, Dataset
from .iterators import LevelOrderIter
from .treenode import NodePath, TreeNode
if TYPE_CHECKING:
from .datatree import DataTree
class TreeIsomorphismError(ValueError):
    """Raised when two tree objects do not have the same node structure.

    Subclasses ``ValueError`` so callers that already catch ``ValueError`` also
    catch structural mismatches between trees.
    """
def check_isomorphic(
    a: DataTree,
    b: DataTree,
    require_names_equal: bool = False,
    check_from_root: bool = True,
):
    """
    Check that two trees have the same structure, raising an error if not.

    Only the arrangement of nodes is compared; the data stored in the nodes is ignored.

    By default (``check_from_root=True``) both arguments are first replaced by the roots of
    their respective trees, so the entire trees containing ``a`` and ``b`` are compared.
    Pass ``check_from_root=False`` to compare only the subtrees rooted at ``a`` and ``b``.

    Optionally, corresponding nodes can additionally be required to have the same name.

    Parameters
    ----------
    a : DataTree
    b : DataTree
    require_names_equal : bool
        Whether to also check that each node has the same name as its counterpart.
    check_from_root : bool
        Whether to first traverse to the root of each tree before checking for isomorphism.
        If ``a`` and ``b`` have no parents this has no effect.

    Raises
    ------
    TypeError
        If either ``a`` or ``b`` is not a tree object.
    TreeIsomorphismError
        If ``a`` and ``b`` are tree objects but are not isomorphic to one another.
        Also raised, when ``require_names_equal`` is set, if the structures match but the
        names of any two corresponding nodes differ.
    """
    if not isinstance(a, TreeNode):
        raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}")
    if not isinstance(b, TreeNode):
        raise TypeError(f"Argument `b` is not a tree, it is of type {type(b)}")

    # Either compare the whole trees (from their roots) or just the given subtrees.
    left, right = (a.root, b.root) if check_from_root else (a, b)

    summary = diff_treestructure(left, right, require_names_equal=require_names_equal)
    if not summary:
        return
    raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + summary)
def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
    """
    Describe the first reason found for two trees not being isomorphic.

    Nodes are compared pairwise in level order (breadth-first from the given nodes). The
    first pair whose names differ (only when ``require_names_equal`` is set) or whose number
    of children differs produces a two-line description of the mismatch.

    Returns an empty string if no difference is found, i.e. the trees are isomorphic.
    """
    # Pairing nodes in level order treats the trees as ordered trees: children are compared
    # in the order they are stored, which is well-defined as long as they live in an
    # ordered container rather than a set.
    paired_nodes = zip(LevelOrderIter(a), LevelOrderIter(b))
    for left, right in paired_nodes:
        left_path = left.path
        right_path = right.path

        if require_names_equal and left.name != right.name:
            return dedent(
                f"""\
                Node '{left_path}' in the left object has name '{left.name}'
                Node '{right_path}' in the right object has name '{right.name}'"""
            )

        if len(left.children) != len(right.children):
            return dedent(
                f"""\
                Number of children on node '{left_path}' of the left object: {len(left.children)}
                Number of children on node '{right_path}' of the right object: {len(right.children)}"""
            )

    return ""
def map_over_subtree(func: Callable) -> Callable:
    """
    Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.

    Applies a function to every dataset in one or more subtrees, returning new trees which store the results.

    The function is applied at every node position whose node in the first tree passed is non-empty; at positions
    where that node is empty, `func` is not called and the corresponding output node holds no data. The returned
    trees have the same structure as the first tree passed.

    `func` must return a Dataset, a DataArray, a tuple of Datasets/DataArrays, or None (None is allowed for some
    nodes, but not for all of them), and must return the same number of values at every node where it returns
    something. The per-node results are assembled into new trees via `DataTree.from_dict`. If `func` returns a
    tuple, the i-th element from every node is collected into the i-th output tree, and all output trees are
    returned together as a tuple.

    The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
    similarly, but all the output trees will have nodes named in the same way as the first tree passed (positional
    tree arguments are considered before keyword tree arguments when determining which tree is "first").

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:

        `func(*args, **kwargs) -> Union[Dataset, DataArray, Tuple[Union[Dataset, DataArray], ...]]`.

        (i.e. func must accept at least one Dataset and return at least one Dataset or DataArray.)
        Function will not be applied to any nodes without datasets.
    *args : tuple, optional
        Positional arguments passed on to `func`. Any DataTree arguments are replaced, node by node, by
        the Dataset obtained from `.to_dataset()` on the corresponding node; other arguments are passed
        through unchanged to every call.
    **kwargs : Any
        Keyword arguments passed on to `func`. Any DataTree arguments are replaced, node by node, by
        the Dataset obtained from `.to_dataset()` on the corresponding node; other arguments are passed
        through unchanged to every call.

    Returns
    -------
    mapped : callable
        Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at
        each node.

    Raises
    ------
    The wrapped function raises:

    TypeError
        If no DataTree is passed among its arguments, or if the values returned by ``func`` are
        all None, of an unsupported type, or inconsistent in number between nodes.
    TreeIsomorphismError
        If the DataTree arguments are not isomorphic to one another.

    See also
    --------
    DataTree.map_over_subtree
    DataTree.map_over_subtree_inplace
    DataTree.subtree
    """

    # TODO examples in the docstring
    # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        """Internal function which maps func over every node in tree, returning a tree of the results."""
        # Function-level import; presumably avoids a circular import with .datatree (which is only
        # imported under TYPE_CHECKING at module level) - TODO confirm.
        from .datatree import DataTree

        # Positional tree arguments come first, then keyword tree arguments in kwargs order;
        # this ordering decides which tree is treated as the "first" (template) tree.
        all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
            a for a in kwargs.values() if isinstance(a, DataTree)
        ]

        if len(all_tree_inputs) > 0:
            first_tree, *other_trees = all_tree_inputs
        else:
            raise TypeError("Must pass at least one tree object")

        for other_tree in other_trees:
            # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
            check_isomorphic(
                first_tree, other_tree, require_names_equal=False, check_from_root=False
            )

        # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
        # We don't know which arguments are DataTrees so we zip all arguments together as iterables
        # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
        out_data_objects = {}
        # Non-tree arguments are wrapped in an endless repeat() so zip() hands the same value to every node;
        # tree arguments contribute one node per step via .subtree.
        args_as_tree_length_iterables = [
            a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
        ]
        # Number of positional slots; used below to split each zipped row back into args vs kwargs.
        n_args = len(args_as_tree_length_iterables)
        kwargs_as_tree_length_iterables = {
            k: v.subtree if isinstance(v, DataTree) else repeat(v)
            for k, v in kwargs.items()
        }
        # Each zipped row is: (node of first tree, *positional values, *keyword values).
        # Keyword values keep the dict's insertion order, which matches the keys used to rebuild them below.
        # NOTE(review): relies on .subtree yielding corresponding nodes in the same order for isomorphic trees.
        for node_of_first_tree, *all_node_args in zip(
            first_tree.subtree,
            *args_as_tree_length_iterables,
            *list(kwargs_as_tree_length_iterables.values()),
        ):
            node_args_as_datasets = [
                a.to_dataset() if isinstance(a, DataTree) else a
                for a in all_node_args[:n_args]
            ]
            node_kwargs_as_datasets = dict(
                zip(
                    [k for k in kwargs_as_tree_length_iterables.keys()],
                    [
                        v.to_dataset() if isinstance(v, DataTree) else v
                        for v in all_node_args[n_args:]
                    ],
                )
            )

            # Now we can call func on the data in this particular set of corresponding nodes
            # Only the first tree's node decides whether func runs here; if it is empty the result is None.
            results = (
                func(*node_args_as_datasets, **node_kwargs_as_datasets)
                if not node_of_first_tree.is_empty
                else None
            )

            # TODO implement mapping over multiple trees in-place using if conditions from here on?
            out_data_objects[node_of_first_tree.path] = results

        # Find out how many return values we received
        # (also validates the result types and that every non-None result has the same count)
        num_return_values = _check_all_return_values(out_data_objects)

        # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
        original_root_path = first_tree.path
        result_trees = []
        for i in range(num_return_values):
            out_tree_contents = {}
            for n in first_tree.subtree:
                p = n.path
                # Every node path of first_tree.subtree was recorded in the loop above, so the
                # else branch appears to be a defensive fallback only.
                if p in out_data_objects.keys():
                    if isinstance(out_data_objects[p], tuple):
                        # Multiple return values: the i-th output tree takes the i-th element.
                        output_node_data = out_data_objects[p][i]
                    else:
                        # Single return value (or None for skipped nodes): used as-is.
                        output_node_data = out_data_objects[p]
                else:
                    output_node_data = None

                # Discard parentage so that new trees don't include parents of input nodes
                relative_path = str(NodePath(p).relative_to(original_root_path))
                # The first tree's own node is relative path "."; store it under the root key "/".
                relative_path = "/" if relative_path == "." else relative_path
                out_tree_contents[relative_path] = output_node_data

            new_tree = DataTree.from_dict(
                out_tree_contents,
                name=first_tree.name,
            )
            result_trees.append(new_tree)

        # If only one result then don't wrap it in a tuple
        if len(result_trees) == 1:
            return result_trees[0]
        else:
            return tuple(result_trees)

    return _map_over_subtree
def _check_single_set_return_values(path_to_node, obj):
    """Validate the value(s) returned by one call of func and count them.

    A bare Dataset or DataArray counts as a single return value. A tuple counts as
    ``len(obj)`` return values, provided every element is a Dataset or DataArray.
    Anything else raises ``TypeError`` mentioning ``path_to_node``.
    """
    accepted_types = (Dataset, DataArray)

    if isinstance(obj, accepted_types):
        return 1

    if not isinstance(obj, tuple):
        raise TypeError(
            f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
            f"Dataset or DataArray, nor a tuple of such types."
        )

    # Reject the first element that is not an xarray object.
    for element in obj:
        if isinstance(element, accepted_types):
            continue
        raise TypeError(
            f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
            f"of type {type(element)}, not Dataset or DataArray."
        )

    return len(obj)
def _check_all_return_values(returned_objects):
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""
if all(r is None for r in returned_objects.values()):
raise TypeError(
"Called supplied function on all nodes but found a return value of None for"
"all of them."
)
result_data_objects = [
(path_to_node, r)
for path_to_node, r in returned_objects.items()
if r is not None
]
if len(result_data_objects) == 1:
# Only one node in the tree: no need to check consistency of results between nodes
path_to_node, result = result_data_objects[0]
num_return_values = _check_single_set_return_values(path_to_node, result)
else:
prev_path, _ = result_data_objects[0]
prev_num_return_values, num_return_values = None, None
for path_to_node, obj in result_data_objects[1:]:
num_return_values = _check_single_set_return_values(path_to_node, obj)
if (
num_return_values != prev_num_return_values
and prev_num_return_values is not None
):
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
f"values, whereas calling func on the nodes at position {prev_path} instead returns "
f"{prev_num_return_values} separate return values."
)
prev_path, prev_num_return_values = path_to_node, num_return_values
return num_return_values