Skip to content

Commit

Permalink
Add typing to dask.order (dask#10553)
Browse files: browse the repository at this point in the history
  • Loading branch information
fjetter committed Oct 12, 2023
1 parent 1a9ba01 commit 9f7d557
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 87 deletions.
45 changes: 36 additions & 9 deletions dask/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Collection, Iterable
from typing import Any, cast
from collections.abc import Collection, Iterable, Mapping
from typing import Any, Literal, TypeVar, cast, overload

from dask.typing import Key, no_default
from dask.typing import Graph, Key, NoDefault, no_default


def ishashable(x):
Expand Down Expand Up @@ -223,7 +223,32 @@ def validate_key(key: object) -> None:
raise TypeError(f"Unexpected key type {type(key)} (value: {key!r})")


def get_dependencies(dsk, key=None, task=no_default, as_list=False):
@overload
def get_dependencies(
dsk: Graph,
key: Key | None = ...,
task: Key | NoDefault = ...,
as_list: Literal[False] = ...,
) -> set[Key]:
...


@overload
def get_dependencies(
dsk: Graph,
key: Key | None,
task: Key | NoDefault,
as_list: Literal[True],
) -> list[Key]:
...


def get_dependencies(
dsk: Graph,
key: Key | None = None,
task: Key | NoDefault = no_default,
as_list: bool = False,
) -> set[Key] | list[Key]:
"""Get the immediate tasks on which this task depends
Examples
Expand Down Expand Up @@ -264,7 +289,7 @@ def get_dependencies(dsk, key=None, task=no_default, as_list=False):
return keys_in_tasks(dsk, [arg], as_list=as_list)


def get_deps(dsk):
def get_deps(dsk: Graph) -> tuple[dict[Key, set[Key]], dict[Key, set[Key]]]:
"""Get dependencies and dependents from a dask graph
>>> inc = lambda x: x + 1
Expand Down Expand Up @@ -308,22 +333,24 @@ def flatten(seq, container=list):
yield item


def reverse_dict(d):
T_ = TypeVar("T_")


def reverse_dict(d: Mapping[T_, Iterable[T_]]) -> dict[T_, set[T_]]:
"""
>>> a, b, c = 'abc'
>>> d = {a: [b, c], b: [c]}
>>> reverse_dict(d) # doctest: +SKIP
{'a': set([]), 'b': set(['a']), 'c': set(['a', 'b'])}
"""
result = defaultdict(set)
result: defaultdict[T_, set[T_]] = defaultdict(set)
_add = set.add
for k, vals in d.items():
result[k]
for val in vals:
_add(result[val], k)
result.default_factory = None
return result
return dict(result)


def subs(task, key, val):
Expand Down
6 changes: 3 additions & 3 deletions dask/graph_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from __future__ import annotations

import uuid
from collections.abc import Callable, Hashable, Set
from typing import Any, Literal, TypeVar
from collections.abc import Callable, Hashable
from typing import Literal, TypeVar

from dask.base import (
clone_key,
Expand Down Expand Up @@ -321,7 +321,7 @@ def _bind_one(

dsk = child.__dask_graph__() # type: ignore
new_layers: dict[str, Layer] = {}
new_deps: dict[str, Set[Any]] = {}
new_deps: dict[str, set[str]] = {}

if isinstance(dsk, HighLevelGraph):
try:
Expand Down
16 changes: 8 additions & 8 deletions dask/highlevelgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def is_materialized(self) -> bool:
return True

@abc.abstractmethod
def get_output_keys(self) -> Set:
def get_output_keys(self) -> Set[Key]:
"""Return a set of all output keys
Output keys are all keys in the layer that might be referenced by
Expand Down Expand Up @@ -405,16 +405,16 @@ class HighLevelGraph(Graph):
"""

layers: Mapping[str, Layer]
dependencies: Mapping[str, Set[str]]
key_dependencies: dict[Key, Set[Key]]
dependencies: Mapping[str, set[str]]
key_dependencies: dict[Key, set[Key]]
_to_dict: dict
_all_external_keys: set

def __init__(
self,
layers: Mapping[str, Graph],
dependencies: Mapping[str, Set[str]],
key_dependencies: dict[Key, Set[Key]] | None = None,
dependencies: Mapping[str, set[str]],
key_dependencies: dict[Key, set[Key]] | None = None,
):
self.dependencies = dependencies
self.key_dependencies = key_dependencies or {}
Expand Down Expand Up @@ -487,7 +487,7 @@ def from_collections(
return cls._from_collection(name, layer, dependencies[0])
layers = {name: layer}
name_dep: set[str] = set()
deps: dict[str, Set[str]] = {name: name_dep}
deps: dict[str, set[str]] = {name: name_dep}
for collection in toolz.unique(dependencies, key=id):
if is_dask_collection(collection):
graph = collection.__dask_graph__()
Expand Down Expand Up @@ -583,7 +583,7 @@ def items(self) -> ItemsView[Key, Any]:
def values(self) -> ValuesView[Any]:
return self.to_dict().values()

def get_all_dependencies(self) -> dict[Key, Set[Key]]:
def get_all_dependencies(self) -> dict[Key, set[Key]]:
"""Get dependencies of all keys
This will in most cases materialize all layers, which makes
Expand Down Expand Up @@ -616,7 +616,7 @@ def copy(self) -> HighLevelGraph:
@classmethod
def merge(cls, *graphs: Graph) -> HighLevelGraph:
layers: dict[str, Graph] = {}
dependencies: dict[str, Set[str]] = {}
dependencies: dict[str, set[str]] = {}
for g in graphs:
if isinstance(g, HighLevelGraph):
layers.update(g.layers)
Expand Down
Loading

0 comments on commit 9f7d557

Please sign in to comment.