From 6a45b3082997c3b4264c5178355500a2a55ba10b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Jan=C3=9Fen?= Date: Tue, 15 Apr 2025 21:03:03 +0200 Subject: [PATCH] Add type hints --- .../src/python_workflow_definition/aiida.py | 4 +-- .../python_workflow_definition/executorlib.py | 7 ++-- .../src/python_workflow_definition/jobflow.py | 20 +++++------ .../src/python_workflow_definition/plot.py | 2 +- .../python_workflow_definition/purepython.py | 8 ++--- .../python_workflow_definition/pyiron_base.py | 33 ++++++++++--------- .../src/python_workflow_definition/shared.py | 10 +++--- 7 files changed, 43 insertions(+), 41 deletions(-) diff --git a/python_workflow_definition/src/python_workflow_definition/aiida.py b/python_workflow_definition/src/python_workflow_definition/aiida.py index fd49587..c3fbf12 100644 --- a/python_workflow_definition/src/python_workflow_definition/aiida.py +++ b/python_workflow_definition/src/python_workflow_definition/aiida.py @@ -18,7 +18,7 @@ ) -def load_workflow_json(file_name): +def load_workflow_json(file_name: str) -> WorkGraph: with open(file_name) as f: data = json.load(f) @@ -77,7 +77,7 @@ def load_workflow_json(file_name): return wg -def write_workflow_json(wg, file_name): +def write_workflow_json(wg: WorkGraph, file_name: str) -> dict: data = {NODES_LABEL: [], EDGES_LABEL: []} node_name_mapping = {} data_node_name_mapping = {} diff --git a/python_workflow_definition/src/python_workflow_definition/executorlib.py b/python_workflow_definition/src/python_workflow_definition/executorlib.py index eef7da9..8449694 100644 --- a/python_workflow_definition/src/python_workflow_definition/executorlib.py +++ b/python_workflow_definition/src/python_workflow_definition/executorlib.py @@ -1,6 +1,7 @@ -import json +from concurrent.futures import Executor from importlib import import_module from inspect import isfunction +import json from python_workflow_definition.shared import ( @@ -21,7 +22,7 @@ def get_item(obj, key): return obj[key] -def _get_value(result_dict, nodes_new_dict, link_dict, exe): +def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict, exe: Executor): source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL] if source in result_dict.keys(): result = result_dict[source] @@ -35,7 +36,7 @@ def _get_value(result_dict, nodes_new_dict, link_dict, exe): return exe.submit(fn=get_item, obj=result, key=source_handle) -def load_workflow_json(file_name, exe): +def load_workflow_json(file_name: str, exe: Executor): with open(file_name, "r") as f: content = json.load(f) diff --git a/python_workflow_definition/src/python_workflow_definition/jobflow.py b/python_workflow_definition/src/python_workflow_definition/jobflow.py index 20066ee..3daee26 100644 --- a/python_workflow_definition/src/python_workflow_definition/jobflow.py +++ b/python_workflow_definition/src/python_workflow_definition/jobflow.py @@ -20,11 +20,11 @@ ) -def _get_function_dict(flow): +def _get_function_dict(flow: Flow): return {job.uuid: job.function for job in flow.jobs} -def _get_nodes_dict(function_dict): +def _get_nodes_dict(function_dict: dict): nodes_dict, nodes_mapping_dict = {}, {} for i, [k, v] in enumerate(function_dict.items()): nodes_dict[i] = v @@ -33,7 +33,7 @@ def _get_nodes_dict(function_dict): return nodes_dict, nodes_mapping_dict -def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict): +def _get_edge_from_dict(target: str, key: str, value_dict: dict, nodes_mapping_dict: dict) -> dict: if len(value_dict["attributes"]) == 1: return { TARGET_LABEL: target, @@ -50,7 +50,7 @@ def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict): } -def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict): +def _get_edges_and_extend_nodes(flow_dict: dict, nodes_mapping_dict: dict, nodes_dict: dict): edges_lst = [] for job in flow_dict["jobs"]: for k, v in job["function_kwargs"].items(): @@ -185,7 +185,7 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict): return edges_lst, nodes_dict -def _resort_total_lst(total_dict, nodes_dict): +def _resort_total_lst(total_dict: dict, nodes_dict: dict) -> dict: nodes_with_dep_lst = list(sorted(total_dict.keys())) nodes_without_dep_lst = [ k for k in nodes_dict.keys() if k not in nodes_with_dep_lst @@ -205,7 +205,7 @@ def _resort_total_lst(total_dict, nodes_dict): return total_new_dict -def _group_edges(edges_lst): +def _group_edges(edges_lst: list) -> dict: total_dict = {} for ed_major in edges_lst: target_id = ed_major[TARGET_LABEL] @@ -218,11 +218,11 @@ def _group_edges(edges_lst): return total_dict -def _get_input_dict(nodes_dict): +def _get_input_dict(nodes_dict: dict) -> dict: return {k: v for k, v in nodes_dict.items() if not isfunction(v)} -def _get_workflow(nodes_dict, input_dict, total_dict, source_handles_dict): +def _get_workflow(nodes_dict: dict, input_dict: dict, total_dict: dict, source_handles_dict: dict) -> list: def get_attr_helper(obj, source_handle): if source_handle is None: return getattr(obj, "output") @@ -262,7 +262,7 @@ def _get_item_from_tuple(input_obj, index, index_lst): return list(input_obj)[index_lst.index(index)] -def load_workflow_json(file_name): +def load_workflow_json(file_name: str) -> Flow: with open(file_name, "r") as f: content = json.load(f) @@ -302,7 +302,7 @@ def load_workflow_json(file_name): return Flow(task_lst) -def write_workflow_json(flow, file_name="workflow.json"): +def write_workflow_json(flow: Flow, file_name: str = "workflow.json"): flow_dict = flow.as_dict() function_dict = _get_function_dict(flow=flow) nodes_dict, nodes_mapping_dict = _get_nodes_dict(function_dict=function_dict) diff --git a/python_workflow_definition/src/python_workflow_definition/plot.py b/python_workflow_definition/src/python_workflow_definition/plot.py index d5297ff..26e89a5 100644 --- a/python_workflow_definition/src/python_workflow_definition/plot.py +++ b/python_workflow_definition/src/python_workflow_definition/plot.py @@ -15,7 +15,7 @@ ) -def plot(file_name): +def plot(file_name: str): with open(file_name, "r") as f: content = json.load(f) diff --git a/python_workflow_definition/src/python_workflow_definition/purepython.py b/python_workflow_definition/src/python_workflow_definition/purepython.py index ea24945..1778c04 100644 --- a/python_workflow_definition/src/python_workflow_definition/purepython.py +++ b/python_workflow_definition/src/python_workflow_definition/purepython.py @@ -18,7 +18,7 @@ ) -def resort_total_lst(total_lst, nodes_dict): +def resort_total_lst(total_lst: list, nodes_dict: dict) -> list: nodes_with_dep_lst = list(sorted([v[0] for v in total_lst])) nodes_without_dep_lst = [ k for k in nodes_dict.keys() if k not in nodes_with_dep_lst @@ -36,7 +36,7 @@ def resort_total_lst(total_lst, nodes_dict): return total_new_lst -def group_edges(edges_lst): +def group_edges(edges_lst: list) -> list: edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True) total_lst, tmp_lst = [], [] target_id = edges_sorted_lst[0][TARGET_LABEL] @@ -51,7 +51,7 @@ def group_edges(edges_lst): return total_lst -def _get_value(result_dict, nodes_new_dict, link_dict): +def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict): source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL] if source in result_dict.keys(): result = result_dict[source] @@ -65,7 +65,7 @@ def _get_value(result_dict, nodes_new_dict, link_dict): return result[source_handle] -def load_workflow_json(file_name): +def load_workflow_json(file_name: str): with open(file_name, "r") as f: content = json.load(f) diff --git a/python_workflow_definition/src/python_workflow_definition/pyiron_base.py b/python_workflow_definition/src/python_workflow_definition/pyiron_base.py index 1fbf8ae..f3f826c 100644 --- a/python_workflow_definition/src/python_workflow_definition/pyiron_base.py +++ b/python_workflow_definition/src/python_workflow_definition/pyiron_base.py @@ -1,6 +1,7 @@ from importlib import import_module from inspect import isfunction import json +from typing import Optional import numpy as np from pyiron_base import job, Project @@ -19,7 +20,7 @@ ) -def _resort_total_lst(total_lst, nodes_dict): +def _resort_total_lst(total_lst: list, nodes_dict: dict) -> list: nodes_with_dep_lst = list(sorted([v[0] for v in total_lst])) nodes_without_dep_lst = [ k for k in nodes_dict.keys() if k not in nodes_with_dep_lst @@ -37,7 +38,7 @@ def _resort_total_lst(total_lst, nodes_dict): return total_new_lst -def _group_edges(edges_lst): +def _group_edges(edges_lst: list) -> list: edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True) total_lst, tmp_lst = [], [] target_id = edges_sorted_lst[0][TARGET_LABEL] @@ -52,10 +53,10 @@ def _group_edges(edges_lst): return total_lst -def _get_source(nodes_dict, delayed_object_dict, source, sourceHandle): - if source in delayed_object_dict.keys() and sourceHandle is not None: +def _get_source(nodes_dict: dict, delayed_object_dict: dict, source: str, source_handle: str): + if source in delayed_object_dict.keys() and source_handle is not None: return ( - delayed_object_dict[source].__getattr__("output").__getattr__(sourceHandle) + delayed_object_dict[source].__getattr__("output").__getattr__(source_handle) ) elif source in delayed_object_dict.keys(): return delayed_object_dict[source] @@ -63,7 +64,7 @@ def _get_source(nodes_dict, delayed_object_dict, source, sourceHandle): return nodes_dict[source] -def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_project): +def _get_delayed_object_dict(total_lst: list, nodes_dict: dict, source_handle_dict: dict, pyiron_project: Project) -> dict: delayed_object_dict = {} for item in total_lst: key, input_dict = item @@ -72,7 +73,7 @@ def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_p nodes_dict=nodes_dict, delayed_object_dict=delayed_object_dict, source=v[SOURCE_LABEL], - sourceHandle=v[SOURCE_PORT_LABEL], + source_handle=v[SOURCE_PORT_LABEL], ) for k, v in input_dict.items() } @@ -83,15 +84,15 @@ def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_p return delayed_object_dict -def get_dict(**kwargs): +def get_dict(**kwargs) -> dict: return {k: v for k, v in kwargs["kwargs"].items()} -def get_list(**kwargs): +def get_list(**kwargs) -> list: return list(kwargs["kwargs"].values()) -def _remove_server_obj(nodes_dict, edges_lst): +def _remove_server_obj(nodes_dict: dict, edges_lst: list): server_lst = [k for k in nodes_dict.keys() if k.startswith("server_obj_")] for s in server_lst: del nodes_dict[s] @@ -99,14 +100,14 @@ def _remove_server_obj(nodes_dict, edges_lst): return nodes_dict, edges_lst -def _get_nodes(connection_dict, delayed_object_updated_dict): +def _get_nodes(connection_dict: dict, delayed_object_updated_dict: dict): return { connection_dict[k]: v._python_function if isinstance(v, DelayedObject) else v for k, v in delayed_object_updated_dict.items() } -def _get_unique_objects(nodes_dict): +def _get_unique_objects(nodes_dict: dict): delayed_object_dict = {} for k, v in nodes_dict.items(): if isinstance(v, DelayedObject): @@ -158,7 +159,7 @@ def _get_unique_objects(nodes_dict): return delayed_object_updated_dict, match_dict -def _get_connection_dict(delayed_object_updated_dict, match_dict): +def _get_connection_dict(delayed_object_updated_dict: dict, match_dict: dict): new_obj_dict = {} connection_dict = {} lookup_dict = {} @@ -174,7 +175,7 @@ def _get_connection_dict(delayed_object_updated_dict, match_dict): return connection_dict, lookup_dict -def _get_edges_dict(edges_lst, nodes_dict, connection_dict, lookup_dict): +def _get_edges_dict(edges_lst: list, nodes_dict: dict, connection_dict: dict, lookup_dict: dict): edges_dict_lst = [] existing_connection_lst = [] for ep in edges_lst: @@ -216,7 +217,7 @@ def _get_edges_dict(edges_lst, nodes_dict, connection_dict, lookup_dict): return edges_dict_lst -def load_workflow_json(file_name, project=None): +def load_workflow_json(file_name: str, project: Optional[Project]=None): if project is None: project = Project(".") @@ -247,7 +248,7 @@ def load_workflow_json(file_name, project=None): return list(delayed_object_dict.values()) -def write_workflow_json(delayed_object, file_name="workflow.json"): +def write_workflow_json(delayed_object: DelayedObject, file_name: str="workflow.json"): nodes_dict, edges_lst = delayed_object.get_graph() nodes_dict, edges_lst = _remove_server_obj( nodes_dict=nodes_dict, edges_lst=edges_lst diff --git a/python_workflow_definition/src/python_workflow_definition/shared.py b/python_workflow_definition/src/python_workflow_definition/shared.py index 8017cd7..ebe34c7 100644 --- a/python_workflow_definition/src/python_workflow_definition/shared.py +++ b/python_workflow_definition/src/python_workflow_definition/shared.py @@ -6,17 +6,17 @@ TARGET_PORT_LABEL = "targetPort" -def get_dict(**kwargs): +def get_dict(**kwargs) -> dict: # NOTE: In WG, this will automatically be wrapped in a dict with the `result` key return {k: v for k, v in kwargs.items()} # return {'dict': {k: v for k, v in kwargs.items()}} -def get_list(**kwargs): +def get_list(**kwargs) -> list: return list(kwargs.values()) -def get_kwargs(lst): +def get_kwargs(lst: list) -> dict: return { t[TARGET_PORT_LABEL]: { SOURCE_LABEL: t[SOURCE_LABEL], @@ -26,7 +26,7 @@ def get_kwargs(lst): } -def get_source_handles(edges_lst): +def get_source_handles(edges_lst: list) -> dict: source_handle_dict = {} for ed in edges_lst: if ed[SOURCE_LABEL] not in source_handle_dict.keys(): @@ -38,7 +38,7 @@ def get_source_handles(edges_lst): } -def convert_nodes_list_to_dict(nodes_list): +def convert_nodes_list_to_dict(nodes_list: list) -> dict: return { str(el["id"]): el["value"] if "value" in el else el["function"] for el in sorted(nodes_list, key=lambda d: d["id"])