Skip to content

Commit 020de38

Browse files
authored
Add type hints (#68)
1 parent 39c883b commit 020de38

File tree

7 files changed

+43
-41
lines changed

7 files changed

+43
-41
lines changed

python_workflow_definition/src/python_workflow_definition/aiida.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919

2020

21-
def load_workflow_json(file_name):
21+
def load_workflow_json(file_name: str) -> WorkGraph:
2222
with open(file_name) as f:
2323
data = json.load(f)
2424

@@ -77,7 +77,7 @@ def load_workflow_json(file_name):
7777
return wg
7878

7979

80-
def write_workflow_json(wg, file_name):
80+
def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
8181
data = {NODES_LABEL: [], EDGES_LABEL: []}
8282
node_name_mapping = {}
8383
data_node_name_mapping = {}

python_workflow_definition/src/python_workflow_definition/executorlib.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import json
1+
from concurrent.futures import Executor
22
from importlib import import_module
33
from inspect import isfunction
4+
import json
45

56

67
from python_workflow_definition.shared import (
@@ -21,7 +22,7 @@ def get_item(obj, key):
2122
return obj[key]
2223

2324

24-
def _get_value(result_dict, nodes_new_dict, link_dict, exe):
25+
def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict, exe: Executor):
2526
source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
2627
if source in result_dict.keys():
2728
result = result_dict[source]
@@ -35,7 +36,7 @@ def _get_value(result_dict, nodes_new_dict, link_dict, exe):
3536
return exe.submit(fn=get_item, obj=result, key=source_handle)
3637

3738

38-
def load_workflow_json(file_name, exe):
39+
def load_workflow_json(file_name: str, exe: Executor):
3940
with open(file_name, "r") as f:
4041
content = json.load(f)
4142

python_workflow_definition/src/python_workflow_definition/jobflow.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
)
2121

2222

23-
def _get_function_dict(flow):
23+
def _get_function_dict(flow: Flow):
2424
return {job.uuid: job.function for job in flow.jobs}
2525

2626

27-
def _get_nodes_dict(function_dict):
27+
def _get_nodes_dict(function_dict: dict):
2828
nodes_dict, nodes_mapping_dict = {}, {}
2929
for i, [k, v] in enumerate(function_dict.items()):
3030
nodes_dict[i] = v
@@ -33,7 +33,7 @@ def _get_nodes_dict(function_dict):
3333
return nodes_dict, nodes_mapping_dict
3434

3535

36-
def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict):
36+
def _get_edge_from_dict(target: str, key: str, value_dict: dict, nodes_mapping_dict: dict) -> dict:
3737
if len(value_dict["attributes"]) == 1:
3838
return {
3939
TARGET_LABEL: target,
@@ -50,7 +50,7 @@ def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict):
5050
}
5151

5252

53-
def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
53+
def _get_edges_and_extend_nodes(flow_dict: dict, nodes_mapping_dict: dict, nodes_dict: dict):
5454
edges_lst = []
5555
for job in flow_dict["jobs"]:
5656
for k, v in job["function_kwargs"].items():
@@ -185,7 +185,7 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
185185
return edges_lst, nodes_dict
186186

187187

188-
def _resort_total_lst(total_dict, nodes_dict):
188+
def _resort_total_lst(total_dict: dict, nodes_dict: dict) -> dict:
189189
nodes_with_dep_lst = list(sorted(total_dict.keys()))
190190
nodes_without_dep_lst = [
191191
k for k in nodes_dict.keys() if k not in nodes_with_dep_lst
@@ -205,7 +205,7 @@ def _resort_total_lst(total_dict, nodes_dict):
205205
return total_new_dict
206206

207207

208-
def _group_edges(edges_lst):
208+
def _group_edges(edges_lst: list) -> dict:
209209
total_dict = {}
210210
for ed_major in edges_lst:
211211
target_id = ed_major[TARGET_LABEL]
@@ -218,11 +218,11 @@ def _group_edges(edges_lst):
218218
return total_dict
219219

220220

221-
def _get_input_dict(nodes_dict):
221+
def _get_input_dict(nodes_dict: dict) -> dict:
222222
return {k: v for k, v in nodes_dict.items() if not isfunction(v)}
223223

224224

225-
def _get_workflow(nodes_dict, input_dict, total_dict, source_handles_dict):
225+
def _get_workflow(nodes_dict: dict, input_dict: dict, total_dict: dict, source_handles_dict: dict) -> list:
226226
def get_attr_helper(obj, source_handle):
227227
if source_handle is None:
228228
return getattr(obj, "output")
@@ -262,7 +262,7 @@ def _get_item_from_tuple(input_obj, index, index_lst):
262262
return list(input_obj)[index_lst.index(index)]
263263

264264

265-
def load_workflow_json(file_name):
265+
def load_workflow_json(file_name: str) -> Flow:
266266
with open(file_name, "r") as f:
267267
content = json.load(f)
268268

@@ -302,7 +302,7 @@ def load_workflow_json(file_name):
302302
return Flow(task_lst)
303303

304304

305-
def write_workflow_json(flow, file_name="workflow.json"):
305+
def write_workflow_json(flow: Flow, file_name: str = "workflow.json"):
306306
flow_dict = flow.as_dict()
307307
function_dict = _get_function_dict(flow=flow)
308308
nodes_dict, nodes_mapping_dict = _get_nodes_dict(function_dict=function_dict)

python_workflow_definition/src/python_workflow_definition/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616

1717

18-
def plot(file_name):
18+
def plot(file_name: str):
1919
with open(file_name, "r") as f:
2020
content = json.load(f)
2121

python_workflow_definition/src/python_workflow_definition/purepython.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919

2020

21-
def resort_total_lst(total_lst, nodes_dict):
21+
def resort_total_lst(total_lst: list, nodes_dict: dict) -> list:
2222
nodes_with_dep_lst = list(sorted([v[0] for v in total_lst]))
2323
nodes_without_dep_lst = [
2424
k for k in nodes_dict.keys() if k not in nodes_with_dep_lst
@@ -36,7 +36,7 @@ def resort_total_lst(total_lst, nodes_dict):
3636
return total_new_lst
3737

3838

39-
def group_edges(edges_lst):
39+
def group_edges(edges_lst: list) -> list:
4040
edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
4141
total_lst, tmp_lst = [], []
4242
target_id = edges_sorted_lst[0][TARGET_LABEL]
@@ -51,7 +51,7 @@ def group_edges(edges_lst):
5151
return total_lst
5252

5353

54-
def _get_value(result_dict, nodes_new_dict, link_dict):
54+
def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict):
5555
source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
5656
if source in result_dict.keys():
5757
result = result_dict[source]
@@ -65,7 +65,7 @@ def _get_value(result_dict, nodes_new_dict, link_dict):
6565
return result[source_handle]
6666

6767

68-
def load_workflow_json(file_name):
68+
def load_workflow_json(file_name: str):
6969
with open(file_name, "r") as f:
7070
content = json.load(f)
7171

python_workflow_definition/src/python_workflow_definition/pyiron_base.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from importlib import import_module
22
from inspect import isfunction
33
import json
4+
from typing import Optional
45

56
import numpy as np
67
from pyiron_base import job, Project
@@ -19,7 +20,7 @@
1920
)
2021

2122

22-
def _resort_total_lst(total_lst, nodes_dict):
23+
def _resort_total_lst(total_lst: list, nodes_dict: dict) -> list:
2324
nodes_with_dep_lst = list(sorted([v[0] for v in total_lst]))
2425
nodes_without_dep_lst = [
2526
k for k in nodes_dict.keys() if k not in nodes_with_dep_lst
@@ -37,7 +38,7 @@ def _resort_total_lst(total_lst, nodes_dict):
3738
return total_new_lst
3839

3940

40-
def _group_edges(edges_lst):
41+
def _group_edges(edges_lst: list) -> list:
4142
edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
4243
total_lst, tmp_lst = [], []
4344
target_id = edges_sorted_lst[0][TARGET_LABEL]
@@ -52,18 +53,18 @@ def _group_edges(edges_lst):
5253
return total_lst
5354

5455

55-
def _get_source(nodes_dict, delayed_object_dict, source, sourceHandle):
56-
if source in delayed_object_dict.keys() and sourceHandle is not None:
56+
def _get_source(nodes_dict: dict, delayed_object_dict: dict, source: str, source_handle: str):
57+
if source in delayed_object_dict.keys() and source_handle is not None:
5758
return (
58-
delayed_object_dict[source].__getattr__("output").__getattr__(sourceHandle)
59+
delayed_object_dict[source].__getattr__("output").__getattr__(source_handle)
5960
)
6061
elif source in delayed_object_dict.keys():
6162
return delayed_object_dict[source]
6263
else:
6364
return nodes_dict[source]
6465

6566

66-
def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_project):
67+
def _get_delayed_object_dict(total_lst: list, nodes_dict: dict, source_handle_dict: dict, pyiron_project: Project) -> dict:
6768
delayed_object_dict = {}
6869
for item in total_lst:
6970
key, input_dict = item
@@ -72,7 +73,7 @@ def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_p
7273
nodes_dict=nodes_dict,
7374
delayed_object_dict=delayed_object_dict,
7475
source=v[SOURCE_LABEL],
75-
sourceHandle=v[SOURCE_PORT_LABEL],
76+
source_handle=v[SOURCE_PORT_LABEL],
7677
)
7778
for k, v in input_dict.items()
7879
}
@@ -83,30 +84,30 @@ def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_p
8384
return delayed_object_dict
8485

8586

86-
def get_dict(**kwargs):
87+
def get_dict(**kwargs) -> dict:
8788
return {k: v for k, v in kwargs["kwargs"].items()}
8889

8990

90-
def get_list(**kwargs):
91+
def get_list(**kwargs) -> list:
9192
return list(kwargs["kwargs"].values())
9293

9394

94-
def _remove_server_obj(nodes_dict, edges_lst):
95+
def _remove_server_obj(nodes_dict: dict, edges_lst: list):
9596
server_lst = [k for k in nodes_dict.keys() if k.startswith("server_obj_")]
9697
for s in server_lst:
9798
del nodes_dict[s]
9899
edges_lst = [ep for ep in edges_lst if s not in ep]
99100
return nodes_dict, edges_lst
100101

101102

102-
def _get_nodes(connection_dict, delayed_object_updated_dict):
103+
def _get_nodes(connection_dict: dict, delayed_object_updated_dict: dict):
103104
return {
104105
connection_dict[k]: v._python_function if isinstance(v, DelayedObject) else v
105106
for k, v in delayed_object_updated_dict.items()
106107
}
107108

108109

109-
def _get_unique_objects(nodes_dict):
110+
def _get_unique_objects(nodes_dict: dict):
110111
delayed_object_dict = {}
111112
for k, v in nodes_dict.items():
112113
if isinstance(v, DelayedObject):
@@ -158,7 +159,7 @@ def _get_unique_objects(nodes_dict):
158159
return delayed_object_updated_dict, match_dict
159160

160161

161-
def _get_connection_dict(delayed_object_updated_dict, match_dict):
162+
def _get_connection_dict(delayed_object_updated_dict: dict, match_dict: dict):
162163
new_obj_dict = {}
163164
connection_dict = {}
164165
lookup_dict = {}
@@ -174,7 +175,7 @@ def _get_connection_dict(delayed_object_updated_dict, match_dict):
174175
return connection_dict, lookup_dict
175176

176177

177-
def _get_edges_dict(edges_lst, nodes_dict, connection_dict, lookup_dict):
178+
def _get_edges_dict(edges_lst: list, nodes_dict: dict, connection_dict: dict, lookup_dict: dict):
178179
edges_dict_lst = []
179180
existing_connection_lst = []
180181
for ep in edges_lst:
@@ -216,7 +217,7 @@ def _get_edges_dict(edges_lst, nodes_dict, connection_dict, lookup_dict):
216217
return edges_dict_lst
217218

218219

219-
def load_workflow_json(file_name, project=None):
220+
def load_workflow_json(file_name: str, project: Optional[Project]=None):
220221
if project is None:
221222
project = Project(".")
222223

@@ -247,7 +248,7 @@ def load_workflow_json(file_name, project=None):
247248
return list(delayed_object_dict.values())
248249

249250

250-
def write_workflow_json(delayed_object, file_name="workflow.json"):
251+
def write_workflow_json(delayed_object: DelayedObject, file_name: str="workflow.json"):
251252
nodes_dict, edges_lst = delayed_object.get_graph()
252253
nodes_dict, edges_lst = _remove_server_obj(
253254
nodes_dict=nodes_dict, edges_lst=edges_lst

python_workflow_definition/src/python_workflow_definition/shared.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66
TARGET_PORT_LABEL = "targetPort"
77

88

9-
def get_dict(**kwargs):
9+
def get_dict(**kwargs) -> dict:
1010
# NOTE: In WG, this will automatically be wrapped in a dict with the `result` key
1111
return {k: v for k, v in kwargs.items()}
1212
# return {'dict': {k: v for k, v in kwargs.items()}}
1313

1414

15-
def get_list(**kwargs):
15+
def get_list(**kwargs) -> list:
1616
return list(kwargs.values())
1717

1818

19-
def get_kwargs(lst):
19+
def get_kwargs(lst: list) -> dict:
2020
return {
2121
t[TARGET_PORT_LABEL]: {
2222
SOURCE_LABEL: t[SOURCE_LABEL],
@@ -26,7 +26,7 @@ def get_kwargs(lst):
2626
}
2727

2828

29-
def get_source_handles(edges_lst):
29+
def get_source_handles(edges_lst: list) -> dict:
3030
source_handle_dict = {}
3131
for ed in edges_lst:
3232
if ed[SOURCE_LABEL] not in source_handle_dict.keys():
@@ -38,7 +38,7 @@ def get_source_handles(edges_lst):
3838
}
3939

4040

41-
def convert_nodes_list_to_dict(nodes_list):
41+
def convert_nodes_list_to_dict(nodes_list: list) -> dict:
4242
return {
4343
str(el["id"]): el["value"] if "value" in el else el["function"]
4444
for el in sorted(nodes_list, key=lambda d: d["id"])

0 commit comments

Comments
 (0)