Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions python_workflow_definition/src/python_workflow_definition/aiida.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def load_workflow_json(file_name):
wg = WorkGraph()
task_name_mapping = {}

for id, identifier in convert_nodes_list_to_dict(nodes_list=data[NODES_LABEL]).items():
for id, identifier in convert_nodes_list_to_dict(
nodes_list=data[NODES_LABEL]
).items():
if isinstance(identifier, str) and "." in identifier:
p, m = identifier.rsplit(".", 1)
mod = import_module(p)
Expand All @@ -45,7 +47,7 @@ def load_workflow_json(file_name):
# if the input does not exist, it means we pass the data into the kwargs
# in this case, we add the input socket
if link[TARGET_PORT_LABEL] not in to_task.inputs:
to_socket = to_task.add_input( "workgraph.any", name=link[TARGET_PORT_LABEL])
to_socket = to_task.add_input("workgraph.any", name=link[TARGET_PORT_LABEL])
else:
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
from_task = task_name_mapping[str(link[SOURCE_LABEL])]
Expand All @@ -58,7 +60,7 @@ def load_workflow_json(file_name):
# because we do not define the outputs explicitly during the pythonjob creation,
# we add it here and assume the output exists
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
# if str(link["sourcePort"]) not in from_task.outputs:
# if str(link["sourcePort"]) not in from_task.outputs:
from_socket = from_task.add_output(
"workgraph.any",
name=link[SOURCE_PORT_LABEL],
Expand Down Expand Up @@ -99,7 +101,7 @@ def write_workflow_json(wg, file_name):
link_data[SOURCE_LABEL] = node_name_mapping[link_data.pop("from_node")]
link_data[SOURCE_PORT_LABEL] = link_data.pop("from_socket")
data[EDGES_LABEL].append(link_data)

for node in wg.tasks:
for input in node.inputs:
# assume namespace is not used as input
Expand All @@ -121,12 +123,14 @@ def write_workflow_json(wg, file_name):
i += 1
else:
input_node_name = data_node_name_mapping[input.value.uuid]
data[EDGES_LABEL].append({
TARGET_LABEL: node_name_mapping[node.name],
TARGET_PORT_LABEL: input._name,
SOURCE_LABEL: input_node_name,
SOURCE_PORT_LABEL: None
})
data[EDGES_LABEL].append(
{
TARGET_LABEL: node_name_mapping[node.name],
TARGET_PORT_LABEL: input._name,
SOURCE_LABEL: input_node_name,
SOURCE_PORT_LABEL: None,
}
)
with open(file_name, "w") as f:
# json.dump({"nodes": data[], "edges": edges_new_lst}, f)
json.dump(data, f, indent=2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_workflow_json(file_name, exe):

for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
if isinstance(v, str) and "." in v:
p, m = v.rsplit('.', 1)
p, m = v.rsplit(".", 1)
mod = import_module(p)
nodes_new_dict[int(k)] = getattr(mod, m)
else:
Expand All @@ -59,7 +59,12 @@ def load_workflow_json(file_name, exe):
node = nodes_new_dict[lst[0]]
if isfunction(node):
kwargs = {
k: _get_value(result_dict=result_dict, nodes_new_dict=nodes_new_dict, link_dict=v, exe=exe)
k: _get_value(
result_dict=result_dict,
nodes_new_dict=nodes_new_dict,
link_dict=v,
exe=exe,
)
for k, v in lst[1].items()
}
result_dict[lst[0]] = exe.submit(node, **kwargs)
Expand Down
174 changes: 131 additions & 43 deletions python_workflow_definition/src/python_workflow_definition/jobflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@


def _get_function_dict(flow):
return {
job.uuid: job.function
for job in flow.jobs
}
return {job.uuid: job.function for job in flow.jobs}


def _get_nodes_dict(function_dict):
Expand All @@ -37,7 +34,7 @@ def _get_nodes_dict(function_dict):


def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict):
if len(value_dict['attributes']) == 1:
if len(value_dict["attributes"]) == 1:
return {
TARGET_LABEL: target,
TARGET_PORT_LABEL: key,
Expand All @@ -57,72 +54,152 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
edges_lst = []
for job in flow_dict["jobs"]:
for k, v in job["function_kwargs"].items():
if isinstance(v, dict) and "@module" in v and "@class" in v and "@version" in v:
edges_lst.append(_get_edge_from_dict(
target=nodes_mapping_dict[job["uuid"]],
key=k,
value_dict=v,
nodes_mapping_dict=nodes_mapping_dict,
))
elif isinstance(v, dict) and any([isinstance(el, dict) and "@module" in el and "@class" in el and "@version" in el for el in v.values()]):
if (
isinstance(v, dict)
and "@module" in v
and "@class" in v
and "@version" in v
):
edges_lst.append(
_get_edge_from_dict(
target=nodes_mapping_dict[job["uuid"]],
key=k,
value_dict=v,
nodes_mapping_dict=nodes_mapping_dict,
)
)
elif isinstance(v, dict) and any(
[
isinstance(el, dict)
and "@module" in el
and "@class" in el
and "@version" in el
for el in v.values()
]
):
node_dict_index = len(nodes_dict)
nodes_dict[node_dict_index] = get_dict
for kt, vt in v.items():
if isinstance(vt, dict) and "@module" in vt and "@class" in vt and "@version" in vt:
edges_lst.append(_get_edge_from_dict(
target=node_dict_index,
key=kt,
value_dict=vt,
nodes_mapping_dict=nodes_mapping_dict,
))
if (
isinstance(vt, dict)
and "@module" in vt
and "@class" in vt
and "@version" in vt
):
edges_lst.append(
_get_edge_from_dict(
target=node_dict_index,
key=kt,
value_dict=vt,
nodes_mapping_dict=nodes_mapping_dict,
)
)
else:
if vt not in nodes_dict.values():
node_index = len(nodes_dict)
nodes_dict[node_index] = vt
else:
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[str(vt)]
edges_lst.append({TARGET_LABEL: node_dict_index, TARGET_PORT_LABEL: kt, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_dict_index, SOURCE_PORT_LABEL: None})
elif isinstance(v, list) and any([isinstance(el, dict) and "@module" in el and "@class" in el and "@version" in el for el in v]):
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[
str(vt)
]
edges_lst.append(
{
TARGET_LABEL: node_dict_index,
TARGET_PORT_LABEL: kt,
SOURCE_LABEL: node_index,
SOURCE_PORT_LABEL: None,
}
)
edges_lst.append(
{
TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
TARGET_PORT_LABEL: k,
SOURCE_LABEL: node_dict_index,
SOURCE_PORT_LABEL: None,
}
)
elif isinstance(v, list) and any(
[
isinstance(el, dict)
and "@module" in el
and "@class" in el
and "@version" in el
for el in v
]
):
node_list_index = len(nodes_dict)
nodes_dict[node_list_index] = get_list
for kt, vt in enumerate(v):
if isinstance(vt, dict) and "@module" in vt and "@class" in vt and "@version" in vt:
edges_lst.append(_get_edge_from_dict(
target=node_list_index,
key=str(kt),
value_dict=vt,
nodes_mapping_dict=nodes_mapping_dict,
))
if (
isinstance(vt, dict)
and "@module" in vt
and "@class" in vt
and "@version" in vt
):
edges_lst.append(
_get_edge_from_dict(
target=node_list_index,
key=str(kt),
value_dict=vt,
nodes_mapping_dict=nodes_mapping_dict,
)
)
else:
if vt not in nodes_dict.values():
node_index = len(nodes_dict)
nodes_dict[node_index] = vt
else:
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[str(vt)]
edges_lst.append({TARGET_LABEL: node_list_index, TARGET_PORT_LABEL: kt, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_list_index, SOURCE_PORT_LABEL: None})
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[
str(vt)
]
edges_lst.append(
{
TARGET_LABEL: node_list_index,
TARGET_PORT_LABEL: kt,
SOURCE_LABEL: node_index,
SOURCE_PORT_LABEL: None,
}
)
edges_lst.append(
{
TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
TARGET_PORT_LABEL: k,
SOURCE_LABEL: node_list_index,
SOURCE_PORT_LABEL: None,
}
)
else:
if v not in nodes_dict.values():
node_index = len(nodes_dict)
nodes_dict[node_index] = v
else:
node_index = {tv: tk for tk, tv in nodes_dict.items()}[v]
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
edges_lst.append(
{
TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
TARGET_PORT_LABEL: k,
SOURCE_LABEL: node_index,
SOURCE_PORT_LABEL: None,
}
)
return edges_lst, nodes_dict


def _resort_total_lst(total_dict, nodes_dict):
    """Reorder ``total_dict`` so each node follows the nodes it depends on.

    Args:
        total_dict: Mapping of node id to its incoming connections. Each
            connection value is a mapping that contains ``SOURCE_LABEL``,
            the id of the node the connection originates from.
        nodes_dict: Mapping whose keys are all known node ids. Ids that are
            not keys of ``total_dict`` are treated as available from the
            start.

    Returns:
        dict: The entries of ``total_dict``, inserted in an order where all
        sources of a node are either dependency-free or inserted earlier.
        Ties within one pass are resolved by ascending node id.

    Raises:
        ValueError: If a full pass cannot place any remaining node, which
            happens for cyclic dependencies or sources that are unknown.
    """
    # Everything that is already available: nodes without incoming edges
    # first, then each node as soon as it has been placed.
    resolved = {k for k in nodes_dict if k not in total_dict}
    pending = sorted(total_dict)
    total_new_dict = {}
    while pending:
        still_pending = []
        for ind in pending:
            connect = total_dict[ind]
            if all(sd[SOURCE_LABEL] in resolved for sd in connect.values()):
                resolved.add(ind)
                total_new_dict[ind] = connect
            else:
                still_pending.append(ind)
        # Without this guard an unsatisfiable dependency loops forever.
        if len(still_pending) == len(pending):
            raise ValueError(
                "Cannot order nodes with unresolved or cyclic dependencies: "
                + ", ".join(str(ind) for ind in still_pending)
            )
        pending = still_pending
    return total_new_dict
Expand All @@ -142,7 +219,7 @@ def _group_edges(edges_lst):


def _get_input_dict(nodes_dict):
return {k:v for k, v in nodes_dict.items() if not isfunction(v)}
return {k: v for k, v in nodes_dict.items() if not isfunction(v)}


def _get_workflow(nodes_dict, input_dict, total_dict, source_handles_dict):
Expand All @@ -157,12 +234,21 @@ def get_attr_helper(obj, source_handle):
v = nodes_dict[k]
if isfunction(v):
if k in source_handles_dict.keys():
fn = job(method=v, data=[el for el in source_handles_dict[k] if el is not None])
fn = job(
method=v,
data=[el for el in source_handles_dict[k] if el is not None],
)
else:
fn = job(method=v)
kwargs = {
kw: input_dict[vw[SOURCE_LABEL]] if vw[SOURCE_LABEL] in input_dict else get_attr_helper(
obj=memory_dict[vw[SOURCE_LABEL]], source_handle=vw[SOURCE_PORT_LABEL])
kw: (
input_dict[vw[SOURCE_LABEL]]
if vw[SOURCE_LABEL] in input_dict
else get_attr_helper(
obj=memory_dict[vw[SOURCE_LABEL]],
source_handle=vw[SOURCE_PORT_LABEL],
)
)
for kw, vw in total_dict[k].items()
}
memory_dict[k] = fn(**kwargs)
Expand Down Expand Up @@ -197,7 +283,7 @@ def load_workflow_json(file_name):
nodes_new_dict = {}
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
if isinstance(v, str) and "." in v:
p, m = v.rsplit('.', 1)
p, m = v.rsplit(".", 1)
mod = import_module(p)
nodes_new_dict[int(k)] = getattr(mod, m)
else:
Expand Down Expand Up @@ -229,7 +315,9 @@ def write_workflow_json(flow, file_name="workflow.json"):
nodes_store_lst = []
for k, v in nodes_dict.items():
if isfunction(v):
nodes_store_lst.append({"id": k, "function": v.__module__ + "." + v.__name__})
nodes_store_lst.append(
{"id": k, "function": v.__module__ + "." + v.__name__}
)
elif isinstance(v, np.ndarray):
nodes_store_lst.append({"id": k, "value": v.tolist()})
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def plot(file_name):
if v[SOURCE_PORT_LABEL] is None:
edge_label_dict[v[SOURCE_LABEL]].append(k)
else:
edge_label_dict[v[SOURCE_LABEL]].append(k + "=result[" + v[SOURCE_PORT_LABEL] + "]")
edge_label_dict[v[SOURCE_LABEL]].append(
k + "=result[" + v[SOURCE_PORT_LABEL] + "]"
)
for k, v in edge_label_dict.items():
graph.add_edge(str(k), str(target_node), label=", ".join(v))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@

def resort_total_lst(total_lst, nodes_dict):
    """Reorder ``total_lst`` so each node follows the nodes it depends on.

    Args:
        total_lst: Sequence of ``(node id, connections)`` pairs. Each value
            of ``connections`` is a mapping that contains ``SOURCE_LABEL``,
            the id of the node the connection originates from.
        nodes_dict: Mapping whose keys are all known node ids. Ids that do
            not appear as first element in ``total_lst`` are treated as
            available from the start.

    Returns:
        list: ``[node id, connections]`` lists in an order where all sources
        of a node are either dependency-free or listed earlier. Within one
        pass the relative order of ``total_lst`` is kept.

    Raises:
        ValueError: If a full pass cannot place any remaining entry, which
            happens for cyclic dependencies, unknown sources or repeated
            node ids.
    """
    dependent_ids = {entry[0] for entry in total_lst}
    independent_ids = {k for k in nodes_dict if k not in dependent_ids}
    ordered_ids = set()
    total_new_lst = []
    while len(total_new_lst) < len(total_lst):
        progressed = False
        for ind, connect in total_lst:
            if ind in ordered_ids:
                continue
            if all(
                sd[SOURCE_LABEL] in ordered_ids
                or sd[SOURCE_LABEL] in independent_ids
                for sd in connect.values()
            ):
                ordered_ids.add(ind)
                total_new_lst.append([ind, connect])
                progressed = True
        # Without this guard an unsatisfiable dependency loops forever.
        if not progressed:
            raise ValueError(
                "Cannot order nodes with unresolved or cyclic dependencies"
            )
    return total_new_lst
Expand Down Expand Up @@ -69,7 +73,7 @@ def load_workflow_json(file_name):
nodes_new_dict = {}
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
if isinstance(v, str) and "." in v:
p, m = v.rsplit('.', 1)
p, m = v.rsplit(".", 1)
mod = import_module(p)
nodes_new_dict[int(k)] = getattr(mod, m)
else:
Expand All @@ -84,7 +88,9 @@ def load_workflow_json(file_name):
node = nodes_new_dict[lst[0]]
if isfunction(node):
kwargs = {
k: _get_value(result_dict=result_dict, nodes_new_dict=nodes_new_dict, link_dict=v)
k: _get_value(
result_dict=result_dict, nodes_new_dict=nodes_new_dict, link_dict=v
)
for k, v in lst[1].items()
}
result_dict[lst[0]] = node(**kwargs)
Expand Down
Loading
Loading