Parallelize map_over_subtree #252

Open
Illviljan opened this issue Aug 1, 2023 · 5 comments · May be fixed by #253

@Illviljan

I think there are some good opportunities to run map_over_subtree in parallel using dask.delayed.

Consider this example data:

import numpy as np
import xarray as xr
from datatree import DataTree


number_of_files = 25
number_of_groups = 20
number_of_variables = 2000

datasets = {}
for f in range(number_of_files):
    for g in range(number_of_groups):
        # Create random data:
        time = np.linspace(0, 50 + f, 100 + g)
        y = f * time + g

        # Create dataset:
        ds = xr.Dataset(
            data_vars={
                f"temperature_{g}{i}": ("time", y)
                for i in range(number_of_variables // number_of_groups)
            },
            coords={"time": ("time", time)},
        )  # .chunk()

        # Prepare for Datatree:
        name = f"file_{f}/group_{g}"
        datasets[name] = ds

dt = DataTree.from_dict(datasets)

# %% Interpolate to same time coordinate
new_time = np.linspace(0, 150, 50)
dt_interp = dt.interp(time=new_time)  
# Original 10s, with dask.delayed 6s
# If datasets were chunked: Original 34s, with dask.delayed 10s
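
The core of the speedup is simply wrapping each per-node call in dask.delayed and computing all of the delayed results in one batch. Here is a minimal sketch of that pattern, using the dt and new_time defined above (illustrative names only, not the exact implementation shown below):

import dask

def interp_node(ds):
    # Per-node operation; any Dataset -> Dataset function would work here.
    return ds.interp(time=new_time)

# Build one delayed call per data-containing node, then compute them all together:
delayed_results = {
    node.path: dask.delayed(interp_node)(node.to_dataset())
    for node in dt.subtree
    if not node.is_empty
}
paths, datasets = dask.compute(
    list(delayed_results.keys()), list(delayed_results.values())
)
computed = dict(zip(paths, datasets))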

Here's my modified map_over_subtree (excerpted from datatree/mapping.py, so it relies on that module's existing imports and helpers such as functools, itertools.repeat, check_isomorphic, NodePath, and _check_all_return_values):

def map_over_subtree(func: Callable) -> Callable:
    """
    Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.

    Applies a function to every dataset in one or more subtrees, returning new trees which store the results.

    The function will be applied to any non-empty dataset stored in any of the nodes in the trees. The returned trees
    will have the same structure as the supplied trees.

    `func` needs to return Datasets, DataArrays, or None in order to be able to rebuild the subtrees after
    mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any
    returned value that is one of these types will be stacked into a separate tree before returning all of them.

    The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
    similarly, but all the output trees will have nodes named in the same way as the first tree passed.

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:

        `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.

        (i.e. func must accept at least one Dataset and return at least one Dataset.)
        Function will not be applied to any nodes without datasets.
    *args : tuple, optional
        Positional arguments passed on to `func`. Any DataTree arguments will have their data-containing nodes
        converted to Datasets via `.ds`.
    **kwargs : Any
        Keyword arguments passed on to `func`. Any DataTree arguments will have their data-containing nodes
        converted to Datasets via `.ds`.

    Returns
    -------
    mapped : callable
        Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at
        each node.

    See also
    --------
    DataTree.map_over_subtree
    DataTree.map_over_subtree_inplace
    DataTree.subtree
    """

    # TODO examples in the docstring

    # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        """Internal function which maps func over every node in tree, returning a tree of the results."""
        from .datatree import DataTree

        parallel = True
        if parallel:
            import dask

            func_ = dask.delayed(func)
        else:
            func_ = func

        all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
            a for a in kwargs.values() if isinstance(a, DataTree)
        ]

        if len(all_tree_inputs) > 0:
            first_tree, *other_trees = all_tree_inputs
        else:
            raise TypeError("Must pass at least one tree object")

        for other_tree in other_trees:
            # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
            check_isomorphic(
                first_tree,
                other_tree,
                require_names_equal=False,
                check_from_root=False,
            )

        # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
        # We don't know which arguments are DataTrees so we zip all arguments together as iterables
        # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
        out_data_objects = {}
        args_as_tree_length_iterables = [
            a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
        ]
        n_args = len(args_as_tree_length_iterables)
        kwargs_as_tree_length_iterables = {
            k: v.subtree if isinstance(v, DataTree) else repeat(v)
            for k, v in kwargs.items()
        }
        for node_of_first_tree, *all_node_args in zip(
            first_tree.subtree,
            *args_as_tree_length_iterables,
            *list(kwargs_as_tree_length_iterables.values()),
        ):
            node_args_as_datasets = [
                a.to_dataset() if isinstance(a, DataTree) else a
                for a in all_node_args[:n_args]
            ]
            node_kwargs_as_datasets = dict(
                zip(
                    [k for k in kwargs_as_tree_length_iterables.keys()],
                    [
                        v.to_dataset() if isinstance(v, DataTree) else v
                        for v in all_node_args[n_args:]
                    ],
                )
            )

            # Now we can call func on the data in this particular set of corresponding nodes
            results = (
                func_(*node_args_as_datasets, **node_kwargs_as_datasets)
                if not node_of_first_tree.is_empty
                else None
            )

            # TODO implement mapping over multiple trees in-place using if conditions from here on?
            out_data_objects[node_of_first_tree.path] = results

        if parallel:
            keys, values = dask.compute(
                [k for k in out_data_objects.keys()],
                [v for v in out_data_objects.values()],
            )
            out_data_objects = {k: v for k, v in zip(keys, values)}

        # Find out how many return values we received
        num_return_values = _check_all_return_values(out_data_objects)

        # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
        original_root_path = first_tree.path
        result_trees = []
        for i in range(num_return_values):
            out_tree_contents = {}
            for n in first_tree.subtree:
                p = n.path
                if p in out_data_objects.keys():
                    if isinstance(out_data_objects[p], tuple):
                        output_node_data = out_data_objects[p][i]
                    else:
                        output_node_data = out_data_objects[p]
                else:
                    output_node_data = None

                # Discard parentage so that new trees don't include parents of input nodes
                relative_path = str(
                    NodePath(p).relative_to(original_root_path)
                )
                relative_path = "/" if relative_path == "." else relative_path
                out_tree_contents[relative_path] = output_node_data

            new_tree = DataTree.from_dict(
                out_tree_contents,
                name=first_tree.name,
            )
            result_trees.append(new_tree)

        # If only one result then don't wrap it in a tuple
        if len(result_trees) == 1:
            return result_trees[0]
        else:
            return tuple(result_trees)

    return _map_over_subtree

I'm a little unsure how to get the parallel argument down to map_over_subtree though?

@TomNicholas
Collaborator

Good idea @Illviljan !

I'm a little unsure how to get the parallel argument down to map_over_subtree though?

Do you actually need to pass it through at all? Couldn't you just do this:

def map_over_subtree(func: Callable, parallel=False) -> Callable:
    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        from .datatree import DataTree

        if parallel:
            import dask

or ideally just do this optimization automatically (if dask is installed I guess)?
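
For the "automatic" option, something along these lines inside _map_over_subtree could enable the optimization only when dask is importable (just a sketch, untested):

try:
    import dask
except ImportError:
    # dask not installed: fall back to calling func eagerly, node by node
    parallel = False
    func_ = func
else:
    # dask available: defer each per-node call and compute them in one batch later
    parallel = True
    func_ = dask.delayed(func)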


I'm wondering how xarray normally does this optimization when you apply an operation to every data variable in a Dataset, for instance. Is it related to #196?

@Illviljan
Author

I tried a version with parallel as an argument but it isn't passed correctly via the normal methods: dt.interp(time=new_time, parallel=True) errors because it thinks parallel is a coordinate.
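
With the parallel keyword on the decorator as proposed above, one workaround would be to call map_over_subtree directly instead of going through the accessor method (hypothetical usage, assuming that proposed signature):

dt_interp = map_over_subtree(
    lambda ds: ds.interp(time=new_time), parallel=True
)(dt)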

Maybe we could always use this optimization. Dask usually adds some overhead though, and I just haven't played around enough to know where that threshold is or if it is significant.

I'm wondering how xarray normally does this optimization when you apply an operation to every data variable in a Dataset, for instance. Is it related to #196?

I think the only place this trick is used is xr.open_mfdataset. Not sure why though; maybe most xarray methods predate dask.delayed?
I also have a feeling that datasets with 2000+ variables are not the normal setup for most xarray users, so there's probably not been a need to optimize in the variable direction.
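
For reference, a rough sketch of how xr.open_mfdataset(..., parallel=True) applies the same trick (heavily simplified, not the actual xarray source; the file paths are made up):

import dask
import xarray as xr

paths = ["file_0.nc", "file_1.nc"]  # hypothetical files
open_ = dask.delayed(xr.open_dataset)
# Defer every open_dataset call, then execute them all in one dask.compute:
datasets = dask.compute([open_(p, chunks={}) for p in paths])[0]
combined = xr.combine_nested(datasets, concat_dim="time")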

I don't fully understand all the changes in #196; I see that one as being about triggering computation of all the dask arrays inside the DataArrays. My suggestion is earlier in that chain: setting up those chunked DataArrays in parallel.

@TomNicholas
Collaborator

You have real datasets with 2000+ variables?!?

Now that I understand that this is not about triggering computation of dask arrays but about building the dask arrays in parallel, I'm less sure that this is a good idea.

I guess one way to look at it is through consistency: DataTree.map_over_subtree is very much a generalization of xarray's Dataset.map, just mapping over nested dictionaries of data variables instead of a single-level dict of data variables. As such I think that we should be consistent in how we treat these two implementations - either it makes sense to apply this optimization in both Dataset.map and DataTree.map_over_subtree, or to neither of them, because it's out-of-scope/too much overhead in both cases.
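
For comparison, Dataset.map is essentially a loop over data_vars along these lines (heavily simplified, not the actual xarray implementation; ds and per_var_func are placeholders):

import xarray as xr

def dataset_map_sketch(ds: xr.Dataset, per_var_func) -> xr.Dataset:
    # Apply per_var_func to every data variable and rebuild a Dataset from the results.
    return xr.Dataset({name: per_var_func(da) for name, da in ds.data_vars.items()})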

@Illviljan
Author

Yes, the example code is quite realistic. That's the kind of dataset I work with, and there's still always something missing...

Dataset.map looks very lightweight compared to Dataset.interp, and DataTree.map_over_subtree handles both. Some functions are heavier and need to be treated differently, so it's good to have the option of parallelization.

@TomNicholas
Collaborator

Dataset.map looks very lightweight compared to Dataset.interp and DataTree.map_over_subtree handles both.

Are you saying that we already do some parallelization like this within Dataset.interp?

We discussed this briefly in the xarray dev call today. Stephan had a few comments, chiefly that he would be surprised if this gave a significant speedup in most cases because of restrictions imposed by the GIL. Possibly once Python removes the GIL we might want to revisit this question for all of xarray.

Illviljan linked a pull request (#253) on Aug 7, 2023 that will close this issue.