Skip to content

Commit 91ee493

Browse files
authored
Simplify updating the format by defining the labels only once (#55) (#56)
1 parent e13b66b commit 91ee493

File tree

6 files changed

+162
-88
lines changed

6 files changed

+162
-88
lines changed

python_workflow_definition/src/python_workflow_definition/aiida.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77
from aiida_workgraph import WorkGraph, task
88
from aiida_workgraph.socket import TaskSocketNamespace
99

10-
from python_workflow_definition.shared import convert_nodes_list_to_dict
10+
from python_workflow_definition.shared import (
11+
convert_nodes_list_to_dict,
12+
NODES_LABEL,
13+
EDGES_LABEL,
14+
SOURCE_LABEL,
15+
SOURCE_PORT_LABEL,
16+
TARGET_LABEL,
17+
TARGET_PORT_LABEL,
18+
)
1119

1220

1321
def load_workflow_json(file_name):
@@ -17,7 +25,7 @@ def load_workflow_json(file_name):
1725
wg = WorkGraph()
1826
task_name_mapping = {}
1927

20-
for id, identifier in convert_nodes_list_to_dict(nodes_list=data["nodes"]).items():
28+
for id, identifier in convert_nodes_list_to_dict(nodes_list=data[NODES_LABEL]).items():
2129
if isinstance(identifier, str) and "." in identifier:
2230
p, m = identifier.rsplit(".", 1)
2331
mod = import_module(p)
@@ -32,33 +40,33 @@ def load_workflow_json(file_name):
3240
task_name_mapping[id] = data_node
3341

3442
# add links
35-
for link in data["edges"]:
36-
to_task = task_name_mapping[str(link["target"])]
43+
for link in data[EDGES_LABEL]:
44+
to_task = task_name_mapping[str(link[TARGET_LABEL])]
3745
# if the input does not exist, it means we pass the data in via the kwargs
3846
# in this case, we add the input socket
39-
if link["targetPort"] not in to_task.inputs:
40-
to_socket = to_task.add_input( "workgraph.any", name=link["targetPort"])
47+
if link[TARGET_PORT_LABEL] not in to_task.inputs:
48+
to_socket = to_task.add_input( "workgraph.any", name=link[TARGET_PORT_LABEL])
4149
else:
42-
to_socket = to_task.inputs[link["targetPort"]]
43-
from_task = task_name_mapping[str(link["source"])]
50+
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
51+
from_task = task_name_mapping[str(link[SOURCE_LABEL])]
4452
if isinstance(from_task, orm.Data):
4553
to_socket.value = from_task
4654
else:
4755
try:
48-
if link["sourcePort"] is None:
49-
link["sourcePort"] = "result"
56+
if link[SOURCE_PORT_LABEL] is None:
57+
link[SOURCE_PORT_LABEL] = "result"
5058
# because we do not define the outputs explicitly during the pythonjob creation
5159
# we add it here, and assume the output exists
52-
if link["sourcePort"] not in from_task.outputs:
60+
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
5361
# if str(link["sourcePort"]) not in from_task.outputs:
5462
from_socket = from_task.add_output(
5563
"workgraph.any",
56-
name=link["sourcePort"],
64+
name=link[SOURCE_PORT_LABEL],
5765
# name=str(link["sourcePort"]),
5866
metadata={"is_function_output": True},
5967
)
6068
else:
61-
from_socket = from_task.outputs[link["sourcePort"]]
69+
from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]
6270

6371
wg.add_link(from_socket, to_socket)
6472
except Exception as e:
@@ -68,7 +76,7 @@ def load_workflow_json(file_name):
6876

6977

7078
def write_workflow_json(wg, file_name):
71-
data = {"nodes": [], "edges": []}
79+
data = {NODES_LABEL: [], EDGES_LABEL: []}
7280
node_name_mapping = {}
7381
data_node_name_mapping = {}
7482
i = 0
@@ -78,19 +86,19 @@ def write_workflow_json(wg, file_name):
7886

7987
callable_name = executor["callable_name"]
8088
callable_name = f"{executor['module_path']}.{callable_name}"
81-
data["nodes"].append({"id": i, "function": callable_name})
89+
data[NODES_LABEL].append({"id": i, "function": callable_name})
8290
i += 1
8391

8492
for link in wg.links:
8593
link_data = link.to_dict()
8694
# if the from socket is the default result, we set it to None
8795
if link_data["from_socket"] == "result":
8896
link_data["from_socket"] = None
89-
link_data["target"] = node_name_mapping[link_data.pop("to_node")]
90-
link_data["targetPort"] = link_data.pop("to_socket")
91-
link_data["source"] = node_name_mapping[link_data.pop("from_node")]
92-
link_data["sourcePort"] = link_data.pop("from_socket")
93-
data["edges"].append(link_data)
97+
link_data[TARGET_LABEL] = node_name_mapping[link_data.pop("to_node")]
98+
link_data[TARGET_PORT_LABEL] = link_data.pop("to_socket")
99+
link_data[SOURCE_LABEL] = node_name_mapping[link_data.pop("from_node")]
100+
link_data[SOURCE_PORT_LABEL] = link_data.pop("from_socket")
101+
data[EDGES_LABEL].append(link_data)
94102

95103
for node in wg.tasks:
96104
for input in node.inputs:
@@ -107,17 +115,17 @@ def write_workflow_json(wg, file_name):
107115
raw_value.pop("node_type", None)
108116
else:
109117
raw_value = input.value.value
110-
data["nodes"].append({"id": i, "value": raw_value})
118+
data[NODES_LABEL].append({"id": i, "value": raw_value})
111119
input_node_name = i
112120
data_node_name_mapping[input.value.uuid] = input_node_name
113121
i += 1
114122
else:
115123
input_node_name = data_node_name_mapping[input.value.uuid]
116-
data["edges"].append({
117-
"target": node_name_mapping[node.name],
118-
"targetPort": input._name,
119-
"source": input_node_name,
120-
"sourcePort": None
124+
data[EDGES_LABEL].append({
125+
TARGET_LABEL: node_name_mapping[node.name],
126+
TARGET_PORT_LABEL: input._name,
127+
SOURCE_LABEL: input_node_name,
128+
SOURCE_PORT_LABEL: None
121129
})
122130
with open(file_name, "w") as f:
123131
# json.dump({"nodes": data[], "edges": edges_new_lst}, f)

python_workflow_definition/src/python_workflow_definition/executorlib.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,17 @@
33
from inspect import isfunction
44

55

6-
from python_workflow_definition.shared import get_dict, get_list, get_kwargs, get_source_handles, convert_nodes_list_to_dict
6+
from python_workflow_definition.shared import (
7+
get_dict,
8+
get_list,
9+
get_kwargs,
10+
get_source_handles,
11+
convert_nodes_list_to_dict,
12+
NODES_LABEL,
13+
EDGES_LABEL,
14+
SOURCE_LABEL,
15+
SOURCE_PORT_LABEL,
16+
)
717
from python_workflow_definition.purepython import resort_total_lst, group_edges
818

919

@@ -12,7 +22,7 @@ def get_item(obj, key):
1222

1323

1424
def _get_value(result_dict, nodes_new_dict, link_dict, exe):
15-
source, source_handle = link_dict["source"], link_dict["sourcePort"]
25+
source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
1626
if source in result_dict.keys():
1727
result = result_dict[source]
1828
elif source in nodes_new_dict.keys():
@@ -29,10 +39,10 @@ def load_workflow_json(file_name, exe):
2939
with open(file_name, "r") as f:
3040
content = json.load(f)
3141

32-
edges_new_lst = content["edges"]
42+
edges_new_lst = content[EDGES_LABEL]
3343
nodes_new_dict = {}
3444

35-
for k, v in convert_nodes_list_to_dict(nodes_list=content["nodes"]).items():
45+
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
3646
if isinstance(v, str) and "." in v:
3747
p, m = v.rsplit('.', 1)
3848
mod = import_module(p)

python_workflow_definition/src/python_workflow_definition/jobflow.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,19 @@
55
import numpy as np
66
from jobflow import job, Flow
77

8-
from python_workflow_definition.shared import get_dict, get_list, get_kwargs, get_source_handles, convert_nodes_list_to_dict
8+
from python_workflow_definition.shared import (
9+
get_dict,
10+
get_list,
11+
get_kwargs,
12+
get_source_handles,
13+
convert_nodes_list_to_dict,
14+
NODES_LABEL,
15+
EDGES_LABEL,
16+
SOURCE_LABEL,
17+
SOURCE_PORT_LABEL,
18+
TARGET_LABEL,
19+
TARGET_PORT_LABEL,
20+
)
921

1022

1123
def _get_function_dict(flow):
@@ -26,9 +38,19 @@ def _get_nodes_dict(function_dict):
2638

2739
def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict):
2840
if len(value_dict['attributes']) == 1:
29-
return {"target": target, "targetPort": key, "source": nodes_mapping_dict[value_dict["uuid"]], "sourcePort": value_dict["attributes"][0][1]}
41+
return {
42+
TARGET_LABEL: target,
43+
TARGET_PORT_LABEL: key,
44+
SOURCE_LABEL: nodes_mapping_dict[value_dict["uuid"]],
45+
SOURCE_PORT_LABEL: value_dict["attributes"][0][1],
46+
}
3047
else:
31-
return {"target": target, "targetPort": key, "source": nodes_mapping_dict[value_dict["uuid"]], "sourcePort": None}
48+
return {
49+
TARGET_LABEL: target,
50+
TARGET_PORT_LABEL: key,
51+
SOURCE_LABEL: nodes_mapping_dict[value_dict["uuid"]],
52+
SOURCE_PORT_LABEL: None,
53+
}
3254

3355

3456
def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
@@ -59,8 +81,8 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
5981
nodes_dict[node_index] = vt
6082
else:
6183
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[str(vt)]
62-
edges_lst.append({"target": node_dict_index, "targetPort": kt, "source": node_index, "sourcePort": None})
63-
edges_lst.append({"target": nodes_mapping_dict[job["uuid"]], "targetPort": k, "source": node_dict_index, "sourcePort": None})
84+
edges_lst.append({TARGET_LABEL: node_dict_index, TARGET_PORT_LABEL: kt, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
85+
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_dict_index, SOURCE_PORT_LABEL: None})
6486
elif isinstance(v, list) and any([isinstance(el, dict) and "@module" in el and "@class" in el and "@version" in el for el in v]):
6587
node_list_index = len(nodes_dict)
6688
nodes_dict[node_list_index] = get_list
@@ -78,15 +100,15 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
78100
nodes_dict[node_index] = vt
79101
else:
80102
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[str(vt)]
81-
edges_lst.append({"target": node_list_index, "targetPort": kt, "source": node_index, "sourcePort": None})
82-
edges_lst.append({"target": nodes_mapping_dict[job["uuid"]], "targetPort": k, "source": node_list_index, "sourcePort": None})
103+
edges_lst.append({TARGET_LABEL: node_list_index, TARGET_PORT_LABEL: kt, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
104+
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_list_index, SOURCE_PORT_LABEL: None})
83105
else:
84106
if v not in nodes_dict.values():
85107
node_index = len(nodes_dict)
86108
nodes_dict[node_index] = v
87109
else:
88110
node_index = {tv: tk for tk, tv in nodes_dict.items()}[v]
89-
edges_lst.append({"target": nodes_mapping_dict[job["uuid"]], "targetPort": k, "source": node_index, "sourcePort": None})
111+
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
90112
return edges_lst, nodes_dict
91113

92114

@@ -99,7 +121,7 @@ def _resort_total_lst(total_dict, nodes_dict):
99121
for ind in sorted(total_dict.keys()):
100122
connect = total_dict[ind]
101123
if ind not in ordered_lst:
102-
source_lst = [sd["source"] for sd in connect.values()]
124+
source_lst = [sd[SOURCE_LABEL] for sd in connect.values()]
103125
if all([s in ordered_lst or s in nodes_without_dep_lst for s in source_lst]):
104126
ordered_lst.append(ind)
105127
total_new_dict[ind] = connect
@@ -109,11 +131,11 @@ def _resort_total_lst(total_dict, nodes_dict):
109131
def _group_edges(edges_lst):
110132
total_dict = {}
111133
for ed_major in edges_lst:
112-
target_id = ed_major["target"]
134+
target_id = ed_major[TARGET_LABEL]
113135
tmp_lst = []
114136
if target_id not in total_dict.keys():
115137
for ed in edges_lst:
116-
if target_id == ed["target"]:
138+
if target_id == ed[TARGET_LABEL]:
117139
tmp_lst.append(ed)
118140
total_dict[target_id] = get_kwargs(lst=tmp_lst)
119141
return total_dict
@@ -139,8 +161,8 @@ def get_attr_helper(obj, source_handle):
139161
else:
140162
fn = job(method=v)
141163
kwargs = {
142-
kw: input_dict[vw["source"]] if vw["source"] in input_dict else get_attr_helper(
143-
obj=memory_dict[vw["source"]], source_handle=vw["sourcePort"])
164+
kw: input_dict[vw[SOURCE_LABEL]] if vw[SOURCE_LABEL] in input_dict else get_attr_helper(
165+
obj=memory_dict[vw[SOURCE_LABEL]], source_handle=vw[SOURCE_PORT_LABEL])
144166
for kw, vw in total_dict[k].items()
145167
}
146168
memory_dict[k] = fn(**kwargs)
@@ -159,21 +181,21 @@ def load_workflow_json(file_name):
159181
content = json.load(f)
160182

161183
edges_new_lst = []
162-
for edge in content["edges"]:
163-
if edge["sourcePort"] is None:
184+
for edge in content[EDGES_LABEL]:
185+
if edge[SOURCE_PORT_LABEL] is None:
164186
edges_new_lst.append(edge)
165187
else:
166188
edges_new_lst.append(
167189
{
168-
"target": edge["target"],
169-
"targetPort": edge["targetPort"],
170-
"source": edge["source"],
171-
"sourcePort": str(edge["sourcePort"]),
190+
TARGET_LABEL: edge[TARGET_LABEL],
191+
TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
192+
SOURCE_LABEL: edge[SOURCE_LABEL],
193+
SOURCE_PORT_LABEL: str(edge[SOURCE_PORT_LABEL]),
172194
}
173195
)
174196

175197
nodes_new_dict = {}
176-
for k, v in convert_nodes_list_to_dict(nodes_list=content["nodes"]).items():
198+
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
177199
if isinstance(v, str) and "." in v:
178200
p, m = v.rsplit('.', 1)
179201
mod = import_module(p)
@@ -214,4 +236,4 @@ def write_workflow_json(flow, file_name="workflow.json"):
214236
nodes_store_lst.append({"id": k, "value": v})
215237

216238
with open(file_name, "w") as f:
217-
json.dump({"nodes": nodes_store_lst, "edges": edges_lst}, f)
239+
json.dump({NODES_LABEL: nodes_store_lst, EDGES_LABEL: edges_lst}, f)

python_workflow_definition/src/python_workflow_definition/purepython.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,19 @@
33
from inspect import isfunction
44

55

6-
from python_workflow_definition.shared import get_dict, get_list, get_kwargs, get_source_handles, convert_nodes_list_to_dict
6+
from python_workflow_definition.shared import (
7+
get_dict,
8+
get_list,
9+
get_kwargs,
10+
get_source_handles,
11+
convert_nodes_list_to_dict,
12+
NODES_LABEL,
13+
EDGES_LABEL,
14+
SOURCE_LABEL,
15+
SOURCE_PORT_LABEL,
16+
TARGET_LABEL,
17+
TARGET_PORT_LABEL,
18+
)
719

820

921
def resort_total_lst(total_lst, nodes_dict):
@@ -13,30 +25,30 @@ def resort_total_lst(total_lst, nodes_dict):
1325
while len(total_new_lst) < len(total_lst):
1426
for ind, connect in total_lst:
1527
if ind not in ordered_lst:
16-
source_lst = [sd["source"] for sd in connect.values()]
28+
source_lst = [sd[SOURCE_LABEL] for sd in connect.values()]
1729
if all([s in ordered_lst or s in nodes_without_dep_lst for s in source_lst]):
1830
ordered_lst.append(ind)
1931
total_new_lst.append([ind, connect])
2032
return total_new_lst
2133

2234

2335
def group_edges(edges_lst):
24-
edges_sorted_lst = sorted(edges_lst, key=lambda x: x["target"], reverse=True)
36+
edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
2537
total_lst, tmp_lst = [], []
26-
target_id = edges_sorted_lst[0]["target"]
38+
target_id = edges_sorted_lst[0][TARGET_LABEL]
2739
for ed in edges_sorted_lst:
28-
if target_id == ed["target"]:
40+
if target_id == ed[TARGET_LABEL]:
2941
tmp_lst.append(ed)
3042
else:
3143
total_lst.append((target_id, get_kwargs(lst=tmp_lst)))
32-
target_id = ed["target"]
44+
target_id = ed[TARGET_LABEL]
3345
tmp_lst = [ed]
3446
total_lst.append((target_id, get_kwargs(lst=tmp_lst)))
3547
return total_lst
3648

3749

3850
def _get_value(result_dict, nodes_new_dict, link_dict):
39-
source, source_handle = link_dict["source"], link_dict["sourcePort"]
51+
source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
4052
if source in result_dict.keys():
4153
result = result_dict[source]
4254
elif source in nodes_new_dict.keys():
@@ -53,9 +65,9 @@ def load_workflow_json(file_name):
5365
with open(file_name, "r") as f:
5466
content = json.load(f)
5567

56-
edges_new_lst = content["edges"]
68+
edges_new_lst = content[EDGES_LABEL]
5769
nodes_new_dict = {}
58-
for k, v in convert_nodes_list_to_dict(nodes_list=content["nodes"]).items():
70+
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
5971
if isinstance(v, str) and "." in v:
6072
p, m = v.rsplit('.', 1)
6173
mod = import_module(p)

0 commit comments

Comments
 (0)