Implement dask-specific methods #196

Open · wants to merge 8 commits into base: main
220 changes: 220 additions & 0 deletions datatree/datatree.py
@@ -1345,3 +1345,223 @@ def to_zarr(

def plot(self):
    raise NotImplementedError

def load(self: DataTree, **kwargs) -> DataTree:
    """Manually trigger loading of the data referenced by this collection.

    End-users generally shouldn't need to call this method directly, since
    most operations should dispatch to the underlying xarray objects which
    this collection contains. There may be use cases where a user wants to
    eagerly load data from disk into memory.

    Parameters
    ----------
    **kwargs : dict
        Additional keyword arguments passed on to ``dask.compute``.

    See Also
    --------
    dask.compute
    """
    new_datatree_dict = {node.path: node.ds.load(**kwargs) for node in self.subtree}
    return DataTree.from_dict(new_datatree_dict)
Comment on lines +1367 to +1368 (Collaborator):

I'm not sure if this will have the behavior you intend: DataTree.from_dict will construct a completely new tree object, and then you are inserting whatever you get when you call DatasetView.load(). It's not altering self in-place.

Also I think it would be worth double-checking that DatasetView.load() does what you expect too with regard to new objects / copying - I never really thought about that case when I wrote DatasetView.

If you want to return the same tree but with all the data loaded, I think you need to alter the current tree in-place instead of creating a new one, i.e. something like

for node in self.subtree:
    self[node.path] = node.ds.load()

though this might not fail gracefully...
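For concreteness, here is a minimal runnable sketch of that in-place approach (hypothetical; it assumes that assigning a Dataset to ``self[node.path]`` replaces the node's data, and that mutating nodes while iterating ``self.subtree`` is safe, neither of which is verified here):

    def load(self: DataTree, **kwargs) -> DataTree:
        # Eagerly load every node's data, mutating this tree in place,
        # then return self to match xarray.Dataset.load()'s convention.
        for node in self.subtree:
            self[node.path] = node.ds.load(**kwargs)
        return self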


def compute(self: DataTree, **kwargs) -> DataTree:
    """Manually trigger loading of the data referenced by this collection
    and return a new DataTree. The original is left unaltered.

    End-users generally shouldn't need to call this method directly, since
    most operations should dispatch to the underlying xarray objects which
    this collection contains. There may be use cases where a user needs to
    work with many file objects on disk.

    Parameters
    ----------
    **kwargs : dict
        Additional keyword arguments passed on to ``dask.compute``.

    See Also
    --------
    dask.compute
    """
    new = self.copy(deep=False)
    return new.load(**kwargs)
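A brief usage sketch of the intended semantics (hypothetical names; mirrors xarray.Dataset.compute, where the original stays lazy and the returned object holds loaded data):

    import numpy as np
    import xarray as xr
    from datatree import DataTree

    ds = xr.Dataset({"a": ("x", np.arange(4))}).chunk({"x": 2})
    dt = DataTree.from_dict({"/": ds, "/child": ds})

    computed = dt.compute()  # new tree backed by in-memory numpy arrays
    # dt should still wrap dask arrays, subject to the copying caveats
    # raised in the review comments on load() above.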

def persist(self: DataTree, **kwargs) -> DataTree:
    """Trigger computation in constituent dask arrays.

    Force any data contained in dask arrays to be loaded into memory, where
    possible, but keep the data as dask arrays. This is useful when
    operating on data with a distributed cluster; if you're using a single
    machine with a single pool of memory, consider using ``.compute()``
    instead.

    Parameters
    ----------
    **kwargs : dict
        Additional keyword arguments passed to ``dask.persist``.

    See Also
    --------
    dask.persist
    """
    new_datatree_dict = {
        node.path: node.ds.persist(**kwargs) for node in self.subtree
    }
    return DataTree.from_dict(new_datatree_dict)
Comment on lines +1411 to +1414 (Collaborator):
The comment on .load() applies to this too I think.
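
To illustrate the intended contrast with compute (a sketch using the hypothetical ``dt`` from the compute example above): persist materializes the chunks but keeps the variables dask-backed, which mainly pays off when worker memory on a cluster can hold them.

    persisted = dt.persist()  # chunks computed, variables still dask arrays
    computed = dt.compute()   # variables become in-memory numpy arrays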


def __dask_tokenize__(self):
    from dask.base import normalize_token

    # This method should return a value fully representative of the object
    # here. ``xarray.Dataset`` implements a method that accomplishes this,
    # and ``DataTree`` is just fundamentally defining relationships between
    # these ``Dataset``s. So here we re-use the ``Dataset`` tokenization and
    # incorporate the ancestry as an additional component (encoded in the
    # incorporate the ancestry as an additional component (encoded in the
    # names of the datasets).

    ds_tokens = {node.path: node.ds.__dask_tokenize__() for node in self.subtree}
    ds_tokens = {node.path: node.ds.__dask_tokenize__() for node in self.subtree}

    return normalize_token((type(self), ds_tokens))

    return normalize_token((type(self), ds_tokens))
Comment on lines +1427 to +1432 (Collaborator):
Unintentional repetition of lines? The double return shouldn't even be valid syntax should it??
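
For context on what the tokenization buys: with this method in place, dask.base.tokenize(dt) should be deterministic and should change whenever any node's contents or path changes, since node paths key the token dict. A minimal check (assumes dask is installed, using the hypothetical ``dt`` from earlier):

    from dask.base import tokenize

    assert tokenize(dt) == tokenize(dt)  # stable across repeated calls
    # Renaming a node or editing its data should yield a different token.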


def __dask_graph__(self):
    graphs = {node.path: node.ds.__dask_graph__() for node in self.subtree}
    graphs = {node.path: node.ds.__dask_graph__() for node in self.subtree}
Comment on lines +1435 to +1436 (Collaborator):
More unintentional repetition?

    graphs = {k: v for k, v in graphs.items() if v is not None}

    if not graphs:
        return None
    else:
        try:
            from dask.highlevelgraph import HighLevelGraph

            return HighLevelGraph.merge(*graphs.values())
        except ImportError:
            from dask import sharedict

            return sharedict.merge(*graphs.values())
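
A quick sanity check of the merge (a sketch; exercises the HighLevelGraph branch on any reasonably recent dask): every task key contributed by an individual node should survive in the single merged graph.

    graph = dt.__dask_graph__()
    if graph is not None:
        for node in dt.subtree:
            node_graph = node.ds.__dask_graph__()
            if node_graph is not None:
                # HighLevelGraph is a Mapping, so set() collects its keys
                assert set(node_graph) <= set(graph)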

def __dask_keys__(self):
    return [node.ds.__dask_keys__() for node in self.subtree]
    return [node.ds.__dask_keys__() for node in self.subtree]
Comment on lines +1452 to +1453 (Collaborator):
And here


def __dask_layers__(self):
    # Layer names, not task keys: delegate to each node's Dataset, which
    # returns a tuple of layer names, and flatten those tuples into one.
    # (Summing the nested key lists from __dask_keys__ with a tuple start
    # value would raise a TypeError and would name the wrong things.)
    return sum((node.ds.__dask_layers__() for node in self.subtree), ())
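
A hedged note on the distinction: under the dask collection protocol, __dask_keys__ returns the (nested) task keys the scheduler must compute, while __dask_layers__ names HighLevelGraph layers; delegating both to each node's Dataset keeps the two consistent.

    keys = dt.__dask_keys__()      # nested lists, one entry per node
    layers = dt.__dask_layers__()  # flat tuple of layer names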

@property
def __dask_optimize__(self):
    import dask.array as da

    return da.Array.__dask_optimize__

@property
def __dask_scheduler__(self):
    import dask.array as da

    return da.Array.__dask_scheduler__
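
Since both properties delegate to dask.array, the collection protocol will pick up the array scheduler and array-level graph optimizations whenever the tree is computed through dask's top-level functions, e.g. (using the hypothetical ``dt``):

    import dask

    (result,) = dask.compute(dt)  # consults __dask_scheduler__ / __dask_optimize__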

def __dask_postcompute__(self):
    return self._dask_postcompute, ()

def _dask_postcompute(self: DataTree, results: Iterable[DatasetView]) -> DataTree:
    from dask import is_dask_collection

    datatree_nodes = {}
    results_iter = iter(results)

    for node in self.subtree:
        if is_dask_collection(node.ds):
            finalize, args = node.ds.__dask_postcompute__()
            # darothen: Are we sure that results_iter is ordered the same as
            # self.subtree?
            # self.subtree?
Comment on lines +1483 to +1485 (Collaborator):
Where does the iterable of DatasetView come from? Presumably you are looping over the nodes, but where are you doing that? Like I don't understand where results is passed in from.

            ds = finalize(next(results_iter), *args)
        else:
            ds = node.ds
        datatree_nodes[node.path] = ds

    # Rebuild the tree around the finalized datasets. NOTE: ``from_dict``
    # re-runs validation at construction time; skipping that would require
    # reconstructing each child node by hand via ``_construct_direct``,
    # since the finalized values here are plain Datasets, not tree nodes.
    return DataTree.from_dict(datatree_nodes)
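
A hedged answer to the ordering question above, based on how dask.base.compute is implemented: it gathers keys = [x.__dask_keys__() for x in collections], runs the scheduler on exactly that nested structure, and then hands each collection's results to the finalize function from __dask_postcompute__(). Since both __dask_keys__() and this method iterate self.subtree, the two orders should match.

    import dask

    # Exercises _dask_postcompute via the protocol; results arrive in
    # __dask_keys__() order, i.e. self.subtree order.
    (loaded_tree,) = dask.compute(dt)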

def __dask_postpersist__(self):
    return self._dask_postpersist, ()

def _dask_postpersist(
    self: DataTree, dsk: Mapping, *, rename: Mapping[str, str] | None = None
) -> DataTree:
    from dask import is_dask_collection
    from dask.highlevelgraph import HighLevelGraph
    from dask.optimization import cull

    datatree_nodes = {}

    for node in self.subtree:
        # Check the node's Dataset, not the node itself: a DataTree node
        # would report dask-backed descendants even if its own data is not.
        if not is_dask_collection(node.ds):
            datatree_nodes[node.path] = node.ds
            continue

        if isinstance(dsk, HighLevelGraph):
            # NOTE(darothen): Implementation based on xarray.Dataset._dask_postpersist(),
            # so we preserve the implementation note for future refinement.
            # TODO: Pin the minimum dask version and check that this note
            # can then be removed.
            # dask >= 2021.3
            # __dask_postpersist__() was called by dask.highlevelgraph.
            # Don't use dsk.cull(), as we need to prevent partial layers:
            # https://github.com/dask/dask/issues/7137
            layers = node.ds.__dask_layers__()
            if rename:
                layers = [rename.get(k, k) for k in layers]
            dsk2 = dsk.cull_layers(layers)
        elif rename:  # pragma: nocover
            # NOTE(darothen): Similar to above we preserve the implementation
            # note.
            # replace_name_in_key requires dask >= 2021.3.
            from dask.base import flatten, replace_name_in_key

            keys = [
                replace_name_in_key(k, rename)
                for k in flatten(node.ds.__dask_keys__())
            ]
            dsk2, _ = cull(dsk, keys)
        else:
            # __dask_postpersist__() was called by dask.{optimize,persist}
            dsk2, _ = cull(dsk, node.ds.__dask_keys__())

        # Cull and finalize per node Dataset, again via node.ds rather
        # than the node, so each Dataset rebuilds from its own subgraph.
        finalize, args = node.ds.__dask_postpersist__()
        kwargs = {"rename": rename} if rename else {}
        datatree_nodes[node.path] = finalize(dsk2, *args, **kwargs)

    # As in ``_dask_postcompute``, rebuild the tree from the per-node
    # datasets rather than constructing only the root node directly.
    return DataTree.from_dict(datatree_nodes)
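
Finally, a usage sketch for the persist path (same hypothetical ``dt`` as above): dask.persist goes through __dask_postpersist__, culling the merged graph down to each node's keys and rebuilding the tree around still-lazy arrays.

    import dask

    (persisted_tree,) = dask.persist(dt)
    # Same structure as dt; each node's variables remain dask arrays whose
    # chunks are now held in memory.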