diff --git a/src/firewheel/cli/configure_firewheel.py b/src/firewheel/cli/configure_firewheel.py index 57732a7f..7101f7cb 100644 --- a/src/firewheel/cli/configure_firewheel.py +++ b/src/firewheel/cli/configure_firewheel.py @@ -1,5 +1,6 @@ import os import cmd +import shlex import pprint import argparse import operator @@ -43,7 +44,7 @@ def _handle_parsing( # Print the full help message on an error # (see: https://stackoverflow.com/a/29293080) try: - cmd_args = parser.parse_args(args.split()) + cmd_args = parser.parse_args(shlex.split(args)) except SystemExit as err: if err.code == 2: parser.print_help() @@ -202,7 +203,7 @@ def do_set(self, args: str) -> None: # noqa: DOC502 if cmd_args.single is not None: key = cmd_args.single[0] - value = " ".join(cmd_args.single[1:]) + value = shlex.join(cmd_args.single[1:]) self.log.debug( "Setting the FIREWHEEL config value for `%s` to `%s`.", key, value ) diff --git a/src/firewheel/cli/executors/shell.py b/src/firewheel/cli/executors/shell.py index d6ae2004..d03ccb49 100644 --- a/src/firewheel/cli/executors/shell.py +++ b/src/firewheel/cli/executors/shell.py @@ -65,12 +65,16 @@ def execute( "MM_FORCE", "MM_RECOVER", "MM_CGROUP", - "MM_APPEND" + "MM_APPEND", } # Concatenate minimega environment variables env_vars = [ - *(f"{env}={os.environ[env]}" for env in minimega_vars if env in os.environ), + *( + f"{env}={os.environ[env]}" + for env in minimega_vars + if env in os.environ + ), f"FIREWHEEL={fw_path}", f"FIREWHEEL_PYTHON={sys.executable}", f"FIREWHEEL_GRPC_SERVER={grpc_path}", diff --git a/src/firewheel/config/_config.py b/src/firewheel/config/_config.py index 9fc328e5..9d23774d 100644 --- a/src/firewheel/config/_config.py +++ b/src/firewheel/config/_config.py @@ -20,6 +20,7 @@ firewheel.log """ +import shlex import shutil from typing import Any, Set, Dict, List, Final, Tuple, Union from pathlib import Path @@ -315,8 +316,8 @@ def resolve_get( This helper method enables getting the value for a specific configuration key. If a nested key is requested it should be represented using periods to indicate the nesting. This function will return the Python object - of the key. Alternatively, if the value if a list, the user can return - a space separated string. + of the key. Alternatively, if the value is a list, the user can return + a space separated string. Args: key (str): The input key to get in *dot* notation. This means that @@ -389,7 +390,7 @@ def resolve_set( # Set the correct type of the input data try: if isinstance(cur_value, list): - value = value.split() + value = shlex.split(value) elif isinstance(cur_value, bool): value = bool(strtobool(value)) elif cur_value is not None: diff --git a/src/firewheel/tests/unit/cli/test_cli_configure.py b/src/firewheel/tests/unit/cli/test_cli_configure.py index c478daec..a1c28961 100644 --- a/src/firewheel/tests/unit/cli/test_cli_configure.py +++ b/src/firewheel/tests/unit/cli/test_cli_configure.py @@ -45,7 +45,7 @@ def test_do_reset(self): new_config = Config().get_config() self.assertEqual(new_config["logging"]["cli_log"], default_setting) - def test_do_set_single(self): + def test_do_set_single_string(self): new_log_name = "new_cli.log" old_setting = self.old_config["logging"]["cli_log"] self.assertNotEqual(old_setting, new_log_name) @@ -56,6 +56,50 @@ def test_do_set_single(self): self.assertEqual(new_config["logging"]["cli_log"], new_log_name) + def test_do_set_single_list_one_element(self): + new_nodes_string = "test_node" + old_setting = self.old_config["cluster"]["compute"] + self.assertNotEqual(old_setting, new_nodes_string) + args = f"-s cluster.compute {new_nodes_string}" + self.cli.do_set(args) + + new_config = Config().get_config() + + self.assertEqual(new_config["cluster"]["compute"], [new_nodes_string]) + + def test_do_set_single_list_one_element_with_space(self): + new_nodes_string = "test node" + old_setting = self.old_config["cluster"]["compute"] + self.assertNotEqual(old_setting, new_nodes_string) + args = f"-s cluster.compute '{new_nodes_string}'" + self.cli.do_set(args) + + new_config = Config().get_config() + + self.assertEqual(new_config["cluster"]["compute"], [new_nodes_string]) + + def test_do_set_single_list_multiple_elements(self): + new_nodes_string = "test_node0,test_node1" + old_setting = self.old_config["cluster"]["compute"] + self.assertNotEqual(old_setting, new_nodes_string) + args = f"-s cluster.compute {new_nodes_string}" + self.cli.do_set(args) + + new_config = Config().get_config() + + self.assertEqual(new_config["cluster"]["compute"], [new_nodes_string]) + + def test_do_set_single_list_multiple_elements_space(self): + new_nodes_string = "test_node0 test_node1" + old_setting = self.old_config["cluster"]["compute"] + self.assertNotEqual(old_setting, new_nodes_string) + args = f"-s cluster.compute {new_nodes_string}" + self.cli.do_set(args) + + new_config = Config().get_config() + + self.assertEqual(new_config["cluster"]["compute"], new_nodes_string.split(" ")) + @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) def test_do_set_incorrect(self, mock_stdout, mock_stderr): diff --git a/tox.ini b/tox.ini index dff8b683..88c1f7f7 100644 --- a/tox.ini +++ b/tox.ini @@ -27,6 +27,7 @@ commands = coverage erase basepython = python3 extras = format commands = + ruff check --select I --fix {toxinidir}/src/firewheel ruff format {toxinidir}/src/firewheel # Linters