diff --git a/flowrep/workflow.py b/flowrep/workflow.py index 8054ba8..ee286ff 100644 --- a/flowrep/workflow.py +++ b/flowrep/workflow.py @@ -207,10 +207,12 @@ def _parse_function_call(self, value, control_flow: str | None = None) -> str: if control_flow is not None: self.function_defs[unique_func_name]["control_flow"] = control_flow + arg_list = _get_function_keywords(self.scope[func_name]) + # Parse inputs (positional + keyword) for i, arg in enumerate(value.get("args", [])): self._add_input_edge( - arg, unique_func_name, input_index=i, control_flow=control_flow + arg, unique_func_name, input_name=arg_list[i], control_flow=control_flow ) for kw in value.get("keywords", []): self._add_input_edge( @@ -416,8 +418,8 @@ def _detect_io_variables_from_control_flow( inputs = var_inp_1.intersection(var_inp_2) input_stem = [inp.rsplit("_", 1)[0] for inp in inputs] outputs = var_out_1.intersection(var_out_2) - # This is needed in order to add those outputs which are updated during - # the control flow to the outputs. For example, the variable x in the + # Lines below are needed in order to add those outputs which are updated + # during the control flow to the outputs. For example, the variable x in the # following workflow has to be in the outputs because it is updated in the # while loop, even though it is not used subsequently. # def f(x): @@ -1213,3 +1215,14 @@ def _get_function_metadata(cls: Callable | dict[str, str]) -> dict[str, str]: "qualname": qualname, "version": version, } + + +def _get_function_keywords(function: Callable) -> list[str | int]: + signature = inspect.signature(function) + items = [] + for ii, (name, param) in enumerate(signature.parameters.items()): + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + items.append(ii) + else: + items.append(name) + return items diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 855c3eb..5b8eaad 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -189,12 +189,12 @@ def test_analyzer(self): all_data = [ ("operation_0", "c_0", {"type": "output", "output_index": 0}), ("operation_0", "d_0", {"type": "output", "output_index": 1}), - ("c_0", "add_0", {"type": "input", "input_index": 0}), + ("c_0", "add_0", {"type": "input", "input_name": "x"}), ("d_0", "add_0", {"type": "input", "input_name": "y"}), - ("a_0", "operation_0", {"type": "input", "input_index": 0}), - ("b_0", "operation_0", {"type": "input", "input_index": 1}), + ("a_0", "operation_0", {"type": "input", "input_name": "x"}), + ("b_0", "operation_0", {"type": "input", "input_name": "y"}), ("add_0", "e_0", {"type": "output"}), - ("e_0", "multiply_0", {"type": "input", "input_index": 0}), + ("e_0", "multiply_0", {"type": "input", "input_name": "x"}), ("multiply_0", "f_0", {"type": "output"}), ("f_0", "output", {"type": "input"}), ("input", "a_0", {"type": "output"}), @@ -252,11 +252,11 @@ def test_get_workflow_dict(self): }, }, "edges": [ - ("inputs.a", "operation_0.inputs.0"), - ("inputs.b", "operation_0.inputs.1"), - ("operation_0.outputs.0", "add_0.inputs.0"), + ("inputs.a", "operation_0.inputs.x"), + ("inputs.b", "operation_0.inputs.y"), + ("operation_0.outputs.0", "add_0.inputs.x"), ("operation_0.outputs.1", "add_0.inputs.y"), - ("add_0.outputs.output", "multiply_0.inputs.0"), + ("add_0.outputs.output", "multiply_0.inputs.x"), ("multiply_0.outputs.output", "outputs.f"), ], "label": "example_macro", @@ -302,11 +302,11 @@ def test_get_workflow_dict_macro(self): }, }, "edges": [ - ("inputs.a", "operation_0.inputs.0"), - ("inputs.b", "operation_0.inputs.1"), - ("operation_0.outputs.0", "add_0.inputs.0"), + ("inputs.a", "operation_0.inputs.x"), + ("inputs.b", "operation_0.inputs.y"), + ("operation_0.outputs.0", "add_0.inputs.x"), ("operation_0.outputs.1", "add_0.inputs.y"), - ("add_0.outputs.output", "multiply_0.inputs.0"), + ("add_0.outputs.output", "multiply_0.inputs.x"), ("multiply_0.outputs.output", "outputs.f"), ], "label": "example_macro_0", @@ -322,10 +322,10 @@ def test_get_workflow_dict_macro(self): }, }, "edges": [ - ("inputs.a", "example_macro_0.inputs.0"), - ("inputs.b", "example_macro_0.inputs.1"), - ("inputs.b", "add_0.inputs.1"), - ("example_macro_0.outputs.output", "add_0.inputs.0"), + ("inputs.a", "example_macro_0.inputs.a"), + ("inputs.b", "example_macro_0.inputs.b"), + ("inputs.b", "add_0.inputs.y"), + ("example_macro_0.outputs.output", "add_0.inputs.x"), ("add_0.outputs.output", "outputs.z"), ], "label": "example_workflow", @@ -415,12 +415,12 @@ def test_workflow_with_while(self): sorted(wf["nodes"]["injected_While_0"]["edges"]), sorted( [ - ("inputs.x", "test.inputs.0"), - ("inputs.b", "test.inputs.1"), - ("inputs.b", "add_1.inputs.1"), - ("inputs.a", "add_1.inputs.0"), - ("inputs.a", "multiply_0.inputs.0"), - ("add_1.outputs.output", "multiply_0.inputs.1"), + ("inputs.x", "test.inputs.a"), + ("inputs.b", "test.inputs.b"), + ("inputs.b", "add_1.inputs.y"), + ("inputs.a", "add_1.inputs.x"), + ("inputs.a", "multiply_0.inputs.x"), + ("add_1.outputs.output", "multiply_0.inputs.y"), ("multiply_0.outputs.output", "outputs.z"), ("add_1.outputs.output", "outputs.x"), ] @@ -460,7 +460,7 @@ def test_workflow_with_leaf(self): self.assertIn("add_0", data["nodes"]) self.assertIn("y", data["outputs"]) self.assertIn( - ("add_0.outputs.output", "check_positive_0.inputs.0"), data["edges"] + ("add_0.outputs.output", "check_positive_0.inputs.x"), data["edges"] ) self.assertNotIn("outputs", data["nodes"]["check_positive_0"]) @@ -568,8 +568,8 @@ def test_for_loop(self): [ ("inputs.a", "injected_For_0.inputs.a"), ("inputs.b", "injected_For_0.inputs.b"), - ("inputs.a", "add_0.inputs.0"), - ("inputs.b", "add_0.inputs.1"), + ("inputs.a", "add_0.inputs.x"), + ("inputs.b", "add_0.inputs.y"), ("add_0.outputs.output", "injected_For_0.inputs.x"), ("injected_For_0.outputs.z", "outputs.z"), ] @@ -579,12 +579,12 @@ def test_for_loop(self): sorted(data["nodes"]["injected_For_0"]["edges"]), sorted( [ - ("inputs.x", "iter.inputs.0"), - ("inputs.b", "iter.inputs.1"), - ("iter.outputs.output", "add_1.inputs.1"), - ("inputs.a", "add_1.inputs.0"), - ("inputs.a", "multiply_0.inputs.0"), - ("add_1.outputs.output", "multiply_0.inputs.1"), + ("inputs.x", "iter.inputs.a"), + ("inputs.b", "iter.inputs.b"), + ("iter.outputs.output", "add_1.inputs.y"), + ("inputs.a", "add_1.inputs.x"), + ("inputs.a", "multiply_0.inputs.x"), + ("add_1.outputs.output", "multiply_0.inputs.y"), ("multiply_0.outputs.output", "outputs.z"), ("add_1.outputs.output", "outputs.x"), ] @@ -610,8 +610,8 @@ def test_if_statement(self): sorted( [ ("inputs.b", "injected_If_0.inputs.b"), - ("inputs.a", "add_0.inputs.0"), - ("inputs.b", "add_0.inputs.1"), + ("inputs.a", "add_0.inputs.x"), + ("inputs.b", "add_0.inputs.y"), ("add_0.outputs.output", "injected_If_0.inputs.x"), ("injected_If_0.outputs.x", "outputs.x"), ] @@ -621,10 +621,10 @@ def test_if_statement(self): sorted(data["nodes"]["injected_If_0"]["edges"]), sorted( [ - ("inputs.x", "multiply_0.inputs.0"), - ("inputs.b", "multiply_0.inputs.1"), - ("inputs.x", "test.inputs.0"), - ("inputs.b", "test.inputs.1"), + ("inputs.x", "multiply_0.inputs.x"), + ("inputs.b", "multiply_0.inputs.y"), + ("inputs.x", "test.inputs.a"), + ("inputs.b", "test.inputs.b"), ("multiply_0.outputs.output", "outputs.x"), ] ), @@ -639,8 +639,8 @@ def test_if_else_statement(self): [ ("inputs.b", "injected_If_0.inputs.b"), ("inputs.a", "injected_Else_0.inputs.a"), - ("inputs.a", "add_0.inputs.0"), - ("inputs.b", "add_0.inputs.1"), + ("inputs.a", "add_0.inputs.x"), + ("inputs.b", "add_0.inputs.y"), ("add_0.outputs.output", "injected_If_0.inputs.x"), ("add_0.outputs.output", "injected_Else_0.inputs.x"), ("injected_If_0.outputs.x", "outputs.x"), @@ -652,10 +652,10 @@ def test_if_else_statement(self): sorted(data["nodes"]["injected_If_0"]["edges"]), sorted( [ - ("inputs.x", "multiply_0.inputs.0"), - ("inputs.b", "multiply_0.inputs.1"), - ("inputs.x", "test.inputs.0"), - ("inputs.b", "test.inputs.1"), + ("inputs.x", "multiply_0.inputs.x"), + ("inputs.b", "multiply_0.inputs.y"), + ("inputs.x", "test.inputs.a"), + ("inputs.b", "test.inputs.b"), ("multiply_0.outputs.output", "outputs.x"), ] ), @@ -664,11 +664,11 @@ def test_if_else_statement(self): sorted(data["nodes"]["injected_Else_0"]["edges"]), sorted( [ - ("inputs.x", "multiply_1.inputs.0"), - ("inputs.x", "multiply_2.inputs.0"), # This must not be here - ("inputs.a", "multiply_1.inputs.1"), - ("multiply_1.outputs.output", "multiply_2.inputs.0"), - ("inputs.a", "multiply_2.inputs.1"), + ("inputs.x", "multiply_1.inputs.x"), + ("inputs.x", "multiply_2.inputs.x"), # This must not be here + ("inputs.a", "multiply_1.inputs.y"), + ("multiply_1.outputs.output", "multiply_2.inputs.x"), + ("inputs.a", "multiply_2.inputs.y"), ("multiply_2.outputs.output", "outputs.x"), ] ), @@ -703,7 +703,7 @@ def workflow_with_data(a=10, b=20): "version": "not_defined", "connected_inputs": [], }, - "inputs": {"0": 10, "1": 20}, + "inputs": {"x": 10, "y": 20}, "outputs": ["output"], }, ) @@ -718,9 +718,9 @@ def workflow_with_data(a=10, b=20): "module": multiply.__module__, "qualname": "multiply", "version": "not_defined", - "connected_inputs": ["0"], + "connected_inputs": ["x"], }, - "inputs": {"0": add_hashed + "@output", "1": 20}, + "inputs": {"x": add_hashed + "@output", "y": 20}, "outputs": ["output"], }, ) @@ -760,6 +760,12 @@ def test_get_function_metadata(self): fwf._get_function_metadata(operation), ) + def test_get_function_keyword(self): + def my_test_function(x, /, y, *, z): + return x + y + z + + self.assertEqual(fwf._get_function_keywords(my_test_function), [0, "y", "z"]) + if __name__ == "__main__": unittest.main()