Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)


def load_workflow_json(file_name):
def load_workflow_json(file_name: str) -> WorkGraph:
with open(file_name) as f:
data = json.load(f)

Expand Down Expand Up @@ -77,7 +77,7 @@ def load_workflow_json(file_name):
return wg


def write_workflow_json(wg, file_name):
def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
data = {NODES_LABEL: [], EDGES_LABEL: []}
node_name_mapping = {}
data_node_name_mapping = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from concurrent.futures import Executor
from importlib import import_module
from inspect import isfunction
import json


from python_workflow_definition.shared import (
Expand All @@ -21,7 +22,7 @@ def get_item(obj, key):
return obj[key]


def _get_value(result_dict, nodes_new_dict, link_dict, exe):
def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict, exe: Executor):
source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
if source in result_dict.keys():
result = result_dict[source]
Expand All @@ -35,7 +36,7 @@ def _get_value(result_dict, nodes_new_dict, link_dict, exe):
return exe.submit(fn=get_item, obj=result, key=source_handle)


def load_workflow_json(file_name, exe):
def load_workflow_json(file_name: str, exe: Executor):
with open(file_name, "r") as f:
content = json.load(f)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
)


def _get_function_dict(flow):
def _get_function_dict(flow: Flow):
return {job.uuid: job.function for job in flow.jobs}


def _get_nodes_dict(function_dict):
def _get_nodes_dict(function_dict: dict):
nodes_dict, nodes_mapping_dict = {}, {}
for i, [k, v] in enumerate(function_dict.items()):
nodes_dict[i] = v
Expand All @@ -33,7 +33,7 @@ def _get_nodes_dict(function_dict):
return nodes_dict, nodes_mapping_dict


def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict):
def _get_edge_from_dict(target: str, key: str, value_dict: dict, nodes_mapping_dict: dict) -> dict:
if len(value_dict["attributes"]) == 1:
return {
TARGET_LABEL: target,
Expand All @@ -50,7 +50,7 @@ def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict):
}


def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
def _get_edges_and_extend_nodes(flow_dict: dict, nodes_mapping_dict: dict, nodes_dict: dict):
edges_lst = []
for job in flow_dict["jobs"]:
for k, v in job["function_kwargs"].items():
Expand Down Expand Up @@ -185,7 +185,7 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
return edges_lst, nodes_dict


def _resort_total_lst(total_dict, nodes_dict):
def _resort_total_lst(total_dict: dict, nodes_dict: dict) -> dict:
nodes_with_dep_lst = list(sorted(total_dict.keys()))
nodes_without_dep_lst = [
k for k in nodes_dict.keys() if k not in nodes_with_dep_lst
Expand All @@ -205,7 +205,7 @@ def _resort_total_lst(total_dict, nodes_dict):
return total_new_dict


def _group_edges(edges_lst):
def _group_edges(edges_lst: list) -> dict:
total_dict = {}
for ed_major in edges_lst:
target_id = ed_major[TARGET_LABEL]
Expand All @@ -218,11 +218,11 @@ def _group_edges(edges_lst):
return total_dict


def _get_input_dict(nodes_dict):
def _get_input_dict(nodes_dict: dict) -> dict:
return {k: v for k, v in nodes_dict.items() if not isfunction(v)}


def _get_workflow(nodes_dict, input_dict, total_dict, source_handles_dict):
def _get_workflow(nodes_dict: dict, input_dict: dict, total_dict: dict, source_handles_dict: dict) -> list:
def get_attr_helper(obj, source_handle):
if source_handle is None:
return getattr(obj, "output")
Expand Down Expand Up @@ -262,7 +262,7 @@ def _get_item_from_tuple(input_obj, index, index_lst):
return list(input_obj)[index_lst.index(index)]


def load_workflow_json(file_name):
def load_workflow_json(file_name: str) -> Flow:
with open(file_name, "r") as f:
content = json.load(f)

Expand Down Expand Up @@ -302,7 +302,7 @@ def load_workflow_json(file_name):
return Flow(task_lst)


def write_workflow_json(flow, file_name="workflow.json"):
def write_workflow_json(flow: Flow, file_name: str = "workflow.json"):
flow_dict = flow.as_dict()
function_dict = _get_function_dict(flow=flow)
nodes_dict, nodes_mapping_dict = _get_nodes_dict(function_dict=function_dict)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)


def plot(file_name):
def plot(file_name: str):
with open(file_name, "r") as f:
content = json.load(f)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)


def resort_total_lst(total_lst, nodes_dict):
def resort_total_lst(total_lst: list, nodes_dict: dict) -> list:
nodes_with_dep_lst = list(sorted([v[0] for v in total_lst]))
nodes_without_dep_lst = [
k for k in nodes_dict.keys() if k not in nodes_with_dep_lst
Expand All @@ -36,7 +36,7 @@ def resort_total_lst(total_lst, nodes_dict):
return total_new_lst


def group_edges(edges_lst):
def group_edges(edges_lst: list) -> list:
edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
total_lst, tmp_lst = [], []
target_id = edges_sorted_lst[0][TARGET_LABEL]
Expand All @@ -51,7 +51,7 @@ def group_edges(edges_lst):
return total_lst


def _get_value(result_dict, nodes_new_dict, link_dict):
def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict):
source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
if source in result_dict.keys():
result = result_dict[source]
Expand All @@ -65,7 +65,7 @@ def _get_value(result_dict, nodes_new_dict, link_dict):
return result[source_handle]


def load_workflow_json(file_name):
def load_workflow_json(file_name: str):
with open(file_name, "r") as f:
content = json.load(f)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from importlib import import_module
from inspect import isfunction
import json
from typing import Optional

import numpy as np
from pyiron_base import job, Project
Expand All @@ -19,7 +20,7 @@
)


def _resort_total_lst(total_lst, nodes_dict):
def _resort_total_lst(total_lst: list, nodes_dict: dict) -> list:
nodes_with_dep_lst = list(sorted([v[0] for v in total_lst]))
nodes_without_dep_lst = [
k for k in nodes_dict.keys() if k not in nodes_with_dep_lst
Expand All @@ -37,7 +38,7 @@ def _resort_total_lst(total_lst, nodes_dict):
return total_new_lst


def _group_edges(edges_lst):
def _group_edges(edges_lst: list) -> list:
edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
total_lst, tmp_lst = [], []
target_id = edges_sorted_lst[0][TARGET_LABEL]
Expand All @@ -52,18 +53,18 @@ def _group_edges(edges_lst):
return total_lst


def _get_source(nodes_dict, delayed_object_dict, source, sourceHandle):
if source in delayed_object_dict.keys() and sourceHandle is not None:
def _get_source(nodes_dict: dict, delayed_object_dict: dict, source: str, source_handle: str):
if source in delayed_object_dict.keys() and source_handle is not None:
return (
delayed_object_dict[source].__getattr__("output").__getattr__(sourceHandle)
delayed_object_dict[source].__getattr__("output").__getattr__(source_handle)
)
elif source in delayed_object_dict.keys():
return delayed_object_dict[source]
else:
return nodes_dict[source]


def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_project):
def _get_delayed_object_dict(total_lst: list, nodes_dict: dict, source_handle_dict: dict, pyiron_project: Project) -> dict:
delayed_object_dict = {}
for item in total_lst:
key, input_dict = item
Expand All @@ -72,7 +73,7 @@ def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_p
nodes_dict=nodes_dict,
delayed_object_dict=delayed_object_dict,
source=v[SOURCE_LABEL],
sourceHandle=v[SOURCE_PORT_LABEL],
source_handle=v[SOURCE_PORT_LABEL],
)
for k, v in input_dict.items()
}
Expand All @@ -83,30 +84,30 @@ def _get_delayed_object_dict(total_lst, nodes_dict, source_handle_dict, pyiron_p
return delayed_object_dict


def get_dict(**kwargs):
def get_dict(**kwargs) -> dict:
return {k: v for k, v in kwargs["kwargs"].items()}


def get_list(**kwargs):
def get_list(**kwargs) -> list:
return list(kwargs["kwargs"].values())


def _remove_server_obj(nodes_dict, edges_lst):
def _remove_server_obj(nodes_dict: dict, edges_lst: list):
server_lst = [k for k in nodes_dict.keys() if k.startswith("server_obj_")]
for s in server_lst:
del nodes_dict[s]
edges_lst = [ep for ep in edges_lst if s not in ep]
return nodes_dict, edges_lst


def _get_nodes(connection_dict, delayed_object_updated_dict):
def _get_nodes(connection_dict: dict, delayed_object_updated_dict: dict):
return {
connection_dict[k]: v._python_function if isinstance(v, DelayedObject) else v
for k, v in delayed_object_updated_dict.items()
}


def _get_unique_objects(nodes_dict):
def _get_unique_objects(nodes_dict: dict):
delayed_object_dict = {}
for k, v in nodes_dict.items():
if isinstance(v, DelayedObject):
Expand Down Expand Up @@ -158,7 +159,7 @@ def _get_unique_objects(nodes_dict):
return delayed_object_updated_dict, match_dict


def _get_connection_dict(delayed_object_updated_dict, match_dict):
def _get_connection_dict(delayed_object_updated_dict: dict, match_dict: dict):
new_obj_dict = {}
connection_dict = {}
lookup_dict = {}
Expand All @@ -174,7 +175,7 @@ def _get_connection_dict(delayed_object_updated_dict, match_dict):
return connection_dict, lookup_dict


def _get_edges_dict(edges_lst, nodes_dict, connection_dict, lookup_dict):
def _get_edges_dict(edges_lst: list, nodes_dict: dict, connection_dict: dict, lookup_dict: dict):
edges_dict_lst = []
existing_connection_lst = []
for ep in edges_lst:
Expand Down Expand Up @@ -216,7 +217,7 @@ def _get_edges_dict(edges_lst, nodes_dict, connection_dict, lookup_dict):
return edges_dict_lst


def load_workflow_json(file_name, project=None):
def load_workflow_json(file_name: str, project: Optional[Project]=None):
if project is None:
project = Project(".")

Expand Down Expand Up @@ -247,7 +248,7 @@ def load_workflow_json(file_name, project=None):
return list(delayed_object_dict.values())


def write_workflow_json(delayed_object, file_name="workflow.json"):
def write_workflow_json(delayed_object: DelayedObject, file_name: str="workflow.json"):
nodes_dict, edges_lst = delayed_object.get_graph()
nodes_dict, edges_lst = _remove_server_obj(
nodes_dict=nodes_dict, edges_lst=edges_lst
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
TARGET_PORT_LABEL = "targetPort"


def get_dict(**kwargs):
def get_dict(**kwargs) -> dict:
# NOTE: In WG, this will automatically be wrapped in a dict with the `result` key
return {k: v for k, v in kwargs.items()}
# return {'dict': {k: v for k, v in kwargs.items()}}


def get_list(**kwargs):
def get_list(**kwargs) -> list:
return list(kwargs.values())


def get_kwargs(lst):
def get_kwargs(lst: list) -> dict:
return {
t[TARGET_PORT_LABEL]: {
SOURCE_LABEL: t[SOURCE_LABEL],
Expand All @@ -26,7 +26,7 @@ def get_kwargs(lst):
}


def get_source_handles(edges_lst):
def get_source_handles(edges_lst: list) -> dict:
source_handle_dict = {}
for ed in edges_lst:
if ed[SOURCE_LABEL] not in source_handle_dict.keys():
Expand All @@ -38,7 +38,7 @@ def get_source_handles(edges_lst):
}


def convert_nodes_list_to_dict(nodes_list):
def convert_nodes_list_to_dict(nodes_list: list) -> dict:
return {
str(el["id"]): el["value"] if "value" in el else el["function"]
for el in sorted(nodes_list, key=lambda d: d["id"])
Expand Down
Loading