Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions flowrep/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,12 @@ def _parse_function_call(self, value, control_flow: str | None = None) -> str:
if control_flow is not None:
self.function_defs[unique_func_name]["control_flow"] = control_flow

arg_list = _get_function_keywords(self.scope[func_name])

# Parse inputs (positional + keyword)
for i, arg in enumerate(value.get("args", [])):
self._add_input_edge(
arg, unique_func_name, input_index=i, control_flow=control_flow
arg, unique_func_name, input_name=arg_list[i], control_flow=control_flow
)
for kw in value.get("keywords", []):
self._add_input_edge(
Expand Down Expand Up @@ -416,8 +418,8 @@ def _detect_io_variables_from_control_flow(
inputs = var_inp_1.intersection(var_inp_2)
input_stem = [inp.rsplit("_", 1)[0] for inp in inputs]
outputs = var_out_1.intersection(var_out_2)
# This is needed in order to add those outputs which are updated during
# the control flow to the outputs. For example, the variable x in the
# Lines below are needed in order to add those outputs which are updated
# during the control flow to the outputs. For example, the variable x in the
# following workflow has to be in the outputs because it is updated in the
# while loop, even though it is not used subsequently.
# def f(x):
Expand Down Expand Up @@ -1213,3 +1215,14 @@ def _get_function_metadata(cls: Callable | dict[str, str]) -> dict[str, str]:
"qualname": qualname,
"version": version,
}


def _get_function_keywords(function: Callable) -> list[str | int]:
signature = inspect.signature(function)
items = []
for ii, (name, param) in enumerate(signature.parameters.items()):
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
items.append(ii)
else:
items.append(name)
return items
Copy link

Copilot AI Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function returns a list but the type annotation indicates dict[str, Any]. This mismatch will cause type checking issues. Change the return type to list[Union[int, str]] to match the actual return value.

Copilot uses AI. Check for mistakes.
108 changes: 57 additions & 51 deletions tests/unit/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@


def workflow_to_use_undefined_variable(a=10, b=20):
result = add(a, u)

Check failure on line 79 in tests/unit/test_workflow.py

View workflow job for this annotation

GitHub Actions / pyiron / ruff-check

Ruff (F821)

tests/unit/test_workflow.py:79:21: F821 Undefined name `u`
return result


Expand Down Expand Up @@ -189,12 +189,12 @@
all_data = [
("operation_0", "c_0", {"type": "output", "output_index": 0}),
("operation_0", "d_0", {"type": "output", "output_index": 1}),
("c_0", "add_0", {"type": "input", "input_index": 0}),
("c_0", "add_0", {"type": "input", "input_name": "x"}),
("d_0", "add_0", {"type": "input", "input_name": "y"}),
("a_0", "operation_0", {"type": "input", "input_index": 0}),
("b_0", "operation_0", {"type": "input", "input_index": 1}),
("a_0", "operation_0", {"type": "input", "input_name": "x"}),
("b_0", "operation_0", {"type": "input", "input_name": "y"}),
("add_0", "e_0", {"type": "output"}),
("e_0", "multiply_0", {"type": "input", "input_index": 0}),
("e_0", "multiply_0", {"type": "input", "input_name": "x"}),
("multiply_0", "f_0", {"type": "output"}),
("f_0", "output", {"type": "input"}),
("input", "a_0", {"type": "output"}),
Expand Down Expand Up @@ -252,11 +252,11 @@
},
},
"edges": [
("inputs.a", "operation_0.inputs.0"),
("inputs.b", "operation_0.inputs.1"),
("operation_0.outputs.0", "add_0.inputs.0"),
("inputs.a", "operation_0.inputs.x"),
("inputs.b", "operation_0.inputs.y"),
("operation_0.outputs.0", "add_0.inputs.x"),
("operation_0.outputs.1", "add_0.inputs.y"),
("add_0.outputs.output", "multiply_0.inputs.0"),
("add_0.outputs.output", "multiply_0.inputs.x"),
("multiply_0.outputs.output", "outputs.f"),
],
"label": "example_macro",
Expand Down Expand Up @@ -302,11 +302,11 @@
},
},
"edges": [
("inputs.a", "operation_0.inputs.0"),
("inputs.b", "operation_0.inputs.1"),
("operation_0.outputs.0", "add_0.inputs.0"),
("inputs.a", "operation_0.inputs.x"),
("inputs.b", "operation_0.inputs.y"),
("operation_0.outputs.0", "add_0.inputs.x"),
("operation_0.outputs.1", "add_0.inputs.y"),
("add_0.outputs.output", "multiply_0.inputs.0"),
("add_0.outputs.output", "multiply_0.inputs.x"),
("multiply_0.outputs.output", "outputs.f"),
],
"label": "example_macro_0",
Expand All @@ -322,10 +322,10 @@
},
},
"edges": [
("inputs.a", "example_macro_0.inputs.0"),
("inputs.b", "example_macro_0.inputs.1"),
("inputs.b", "add_0.inputs.1"),
("example_macro_0.outputs.output", "add_0.inputs.0"),
("inputs.a", "example_macro_0.inputs.a"),
("inputs.b", "example_macro_0.inputs.b"),
("inputs.b", "add_0.inputs.y"),
("example_macro_0.outputs.output", "add_0.inputs.x"),
("add_0.outputs.output", "outputs.z"),
],
"label": "example_workflow",
Expand Down Expand Up @@ -415,12 +415,12 @@
sorted(wf["nodes"]["injected_While_0"]["edges"]),
sorted(
[
("inputs.x", "test.inputs.0"),
("inputs.b", "test.inputs.1"),
("inputs.b", "add_1.inputs.1"),
("inputs.a", "add_1.inputs.0"),
("inputs.a", "multiply_0.inputs.0"),
("add_1.outputs.output", "multiply_0.inputs.1"),
("inputs.x", "test.inputs.a"),
("inputs.b", "test.inputs.b"),
("inputs.b", "add_1.inputs.y"),
("inputs.a", "add_1.inputs.x"),
("inputs.a", "multiply_0.inputs.x"),
("add_1.outputs.output", "multiply_0.inputs.y"),
("multiply_0.outputs.output", "outputs.z"),
("add_1.outputs.output", "outputs.x"),
]
Expand Down Expand Up @@ -460,7 +460,7 @@
self.assertIn("add_0", data["nodes"])
self.assertIn("y", data["outputs"])
self.assertIn(
("add_0.outputs.output", "check_positive_0.inputs.0"), data["edges"]
("add_0.outputs.output", "check_positive_0.inputs.x"), data["edges"]
)
self.assertNotIn("outputs", data["nodes"]["check_positive_0"])

Expand Down Expand Up @@ -568,8 +568,8 @@
[
("inputs.a", "injected_For_0.inputs.a"),
("inputs.b", "injected_For_0.inputs.b"),
("inputs.a", "add_0.inputs.0"),
("inputs.b", "add_0.inputs.1"),
("inputs.a", "add_0.inputs.x"),
("inputs.b", "add_0.inputs.y"),
("add_0.outputs.output", "injected_For_0.inputs.x"),
("injected_For_0.outputs.z", "outputs.z"),
]
Expand All @@ -579,12 +579,12 @@
sorted(data["nodes"]["injected_For_0"]["edges"]),
sorted(
[
("inputs.x", "iter.inputs.0"),
("inputs.b", "iter.inputs.1"),
("iter.outputs.output", "add_1.inputs.1"),
("inputs.a", "add_1.inputs.0"),
("inputs.a", "multiply_0.inputs.0"),
("add_1.outputs.output", "multiply_0.inputs.1"),
("inputs.x", "iter.inputs.a"),
("inputs.b", "iter.inputs.b"),
("iter.outputs.output", "add_1.inputs.y"),
("inputs.a", "add_1.inputs.x"),
("inputs.a", "multiply_0.inputs.x"),
("add_1.outputs.output", "multiply_0.inputs.y"),
("multiply_0.outputs.output", "outputs.z"),
("add_1.outputs.output", "outputs.x"),
]
Expand All @@ -610,8 +610,8 @@
sorted(
[
("inputs.b", "injected_If_0.inputs.b"),
("inputs.a", "add_0.inputs.0"),
("inputs.b", "add_0.inputs.1"),
("inputs.a", "add_0.inputs.x"),
("inputs.b", "add_0.inputs.y"),
("add_0.outputs.output", "injected_If_0.inputs.x"),
("injected_If_0.outputs.x", "outputs.x"),
]
Expand All @@ -621,10 +621,10 @@
sorted(data["nodes"]["injected_If_0"]["edges"]),
sorted(
[
("inputs.x", "multiply_0.inputs.0"),
("inputs.b", "multiply_0.inputs.1"),
("inputs.x", "test.inputs.0"),
("inputs.b", "test.inputs.1"),
("inputs.x", "multiply_0.inputs.x"),
("inputs.b", "multiply_0.inputs.y"),
("inputs.x", "test.inputs.a"),
("inputs.b", "test.inputs.b"),
("multiply_0.outputs.output", "outputs.x"),
]
),
Expand All @@ -639,8 +639,8 @@
[
("inputs.b", "injected_If_0.inputs.b"),
("inputs.a", "injected_Else_0.inputs.a"),
("inputs.a", "add_0.inputs.0"),
("inputs.b", "add_0.inputs.1"),
("inputs.a", "add_0.inputs.x"),
("inputs.b", "add_0.inputs.y"),
("add_0.outputs.output", "injected_If_0.inputs.x"),
("add_0.outputs.output", "injected_Else_0.inputs.x"),
("injected_If_0.outputs.x", "outputs.x"),
Expand All @@ -652,10 +652,10 @@
sorted(data["nodes"]["injected_If_0"]["edges"]),
sorted(
[
("inputs.x", "multiply_0.inputs.0"),
("inputs.b", "multiply_0.inputs.1"),
("inputs.x", "test.inputs.0"),
("inputs.b", "test.inputs.1"),
("inputs.x", "multiply_0.inputs.x"),
("inputs.b", "multiply_0.inputs.y"),
("inputs.x", "test.inputs.a"),
("inputs.b", "test.inputs.b"),
("multiply_0.outputs.output", "outputs.x"),
]
),
Expand All @@ -664,11 +664,11 @@
sorted(data["nodes"]["injected_Else_0"]["edges"]),
sorted(
[
("inputs.x", "multiply_1.inputs.0"),
("inputs.x", "multiply_2.inputs.0"), # This must not be here
("inputs.a", "multiply_1.inputs.1"),
("multiply_1.outputs.output", "multiply_2.inputs.0"),
("inputs.a", "multiply_2.inputs.1"),
("inputs.x", "multiply_1.inputs.x"),
("inputs.x", "multiply_2.inputs.x"), # This must not be here
("inputs.a", "multiply_1.inputs.y"),
("multiply_1.outputs.output", "multiply_2.inputs.x"),
("inputs.a", "multiply_2.inputs.y"),
("multiply_2.outputs.output", "outputs.x"),
]
),
Expand Down Expand Up @@ -703,7 +703,7 @@
"version": "not_defined",
"connected_inputs": [],
},
"inputs": {"0": 10, "1": 20},
"inputs": {"x": 10, "y": 20},
"outputs": ["output"],
},
)
Expand All @@ -718,9 +718,9 @@
"module": multiply.__module__,
"qualname": "multiply",
"version": "not_defined",
"connected_inputs": ["0"],
"connected_inputs": ["x"],
},
"inputs": {"0": add_hashed + "@output", "1": 20},
"inputs": {"x": add_hashed + "@output", "y": 20},
"outputs": ["output"],
},
)
Expand Down Expand Up @@ -760,6 +760,12 @@
fwf._get_function_metadata(operation),
)

def test_get_function_keyword(self):
def my_test_function(x, /, y, *, z):
return x + y + z

self.assertEqual(fwf._get_function_keywords(my_test_function), [0, "y", "z"])


if __name__ == "__main__":
unittest.main()
Loading