From 2853e82f8a1576c2beac0f2a85120a192164bb04 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 13 Sep 2021 22:00:54 +0800 Subject: [PATCH 1/4] machine: add `machine list`&`machine modify` 1. add new command `dvc machine list` 2. complete `dvc machine modify` 3. add new unit tests for this two commands to ensure the call. 4. add functionality tests for two commands --- dvc/command/machine.py | 37 +++++++ dvc/config_schema.py | 47 +++++++-- tests/func/machine/__init__.py | 0 tests/func/machine/test_machine_config.py | 115 ++++++++++++++++++++++ tests/unit/command/test_machine.py | 48 ++++++--- 5 files changed, 225 insertions(+), 22 deletions(-) create mode 100644 tests/func/machine/__init__.py create mode 100644 tests/func/machine/test_machine_config.py diff --git a/dvc/command/machine.py b/dvc/command/machine.py index 9d7c72b65e..3554ca9cd7 100644 --- a/dvc/command/machine.py +++ b/dvc/command/machine.py @@ -69,6 +69,22 @@ def run(self): return 0 +class CmdMachineList(CmdMachineConfig): + def run(self): + levels = [self.args.level] if self.args.level else self.config.LEVELS + for level in levels: + conf = self.config.read(level)["machine"] + if self.args.name: + conf = conf.get(self.args.name, {}) + prefix = self._config_file_prefix( + self.args.show_origin, self.config, level + ) + configs = list(self._format_config(conf, prefix)) + if configs: + ui.write("\n".join(configs)) + return 0 + + class CmdMachineModify(CmdMachineConfig): def run(self): from dvc.config import merge @@ -219,6 +235,27 @@ def add_parser(subparsers, parent_parser): ) machine_default_parser.set_defaults(func=CmdMachineDefault) + machine_LIST_HELP = "List the configuration of one/all machines." + machine_list_parser = machine_subparsers.add_parser( + "list", + parents=[parent_config_parser, parent_parser], + description=append_doc_link(machine_LIST_HELP, "machine/list"), + help=machine_LIST_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + machine_list_parser.add_argument( + "--show-origin", + default=False, + action="store_true", + help="Show the source file containing each config value.", + ) + machine_list_parser.add_argument( + "name", + nargs="?", + type=str, + help="name of machine to specify", + ) + machine_list_parser.set_defaults(func=CmdMachineList) machine_MODIFY_HELP = "Modify the configuration of an machine." machine_modify_parser = machine_subparsers.add_parser( "modify", diff --git a/dvc/config_schema.py b/dvc/config_schema.py index 1b55d45dc8..0d369cbb98 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -2,16 +2,9 @@ from urllib.parse import urlparse from funcy import walk_values -from voluptuous import ( - All, - Any, - Coerce, - Invalid, - Lower, - Optional, - Range, - Schema, -) +from voluptuous import All, Any, Coerce, Invalid, Lower +from voluptuous import Number as Number_ +from voluptuous import Optional, Range, Schema Bool = All( Lower, @@ -41,6 +34,30 @@ def supported_cache_type(types): return types +class Number(Number_): + """`<=` version of the official Number class""" + + def __call__(self, v): + precision, scale, decimal_num = self._get_precision_scale(v) + + error_msg = "" + if self.precision is not None and precision > self.precision: + error_msg += ( + "Precision must be not bigger than %s" % self.precision + ) + if self.scale is not None and self.scale != scale: + if error_msg: + error_msg += ", and " + error_msg += "Scale must be equal to %s" % self.scale + + if error_msg: + raise Invalid(self.msg or error_msg) + + if self.yield_decimal: + return decimal_num + return v + + def Choices(*choices): """Checks that value belongs to the specified set of values @@ -229,7 +246,17 @@ class RelPath(str): "machine": { str: { "cloud": All(Lower, Choices("aws", "azure")), + "region": All( + Lower, Choices("us-west", "us-east", "eu-west", "eu-north") + ), + "image": str, + "name": str, + "spot": Bool, + "spot_price": Number(precision=5), + "instance_hdd_size": Coerce(int), "instance_type": Lower, + "instance_gpu": Lower, + "ssh_private": str, "startup_script": str, }, }, diff --git a/tests/func/machine/__init__.py b/tests/func/machine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/func/machine/test_machine_config.py b/tests/func/machine/test_machine_config.py new file mode 100644 index 0000000000..5f85ac858e --- /dev/null +++ b/tests/func/machine/test_machine_config.py @@ -0,0 +1,115 @@ +import textwrap + +import pytest + +from dvc.main import main + +config_text = textwrap.dedent( + """\ + [feature] + machine = true + ['machine \"foo\"'] + cloud = aws + """ +) + + +@pytest.mark.parametrize( + "slot,value", + [ + ("region", "us-west"), + ("image", "iterative-cml"), + ("name", "iterative_test"), + ("spot", "True"), + ("spot_price", "1.2345"), + ("spot_price", "123"), + ("instance_hdd_size", "10"), + ("instance_type", "l"), + ("instance_gpu", "tesla"), + ("ssh_private", "secret"), + ], +) +def test_machine_modify_susccess(tmp_dir, dvc, slot, value): + (tmp_dir / ".dvc" / "config").write_text(config_text) + assert main(["machine", "modify", "foo", slot, value]) == 0 + assert ( + tmp_dir / ".dvc" / "config" + ).read_text() == config_text + f" {slot} = {value}\n" + assert main(["machine", "modify", "--unset", "foo", slot]) == 0 + assert (tmp_dir / ".dvc" / "config").read_text() == config_text + + +def test_machine_modify_startup_script(tmp_dir, dvc): + slot, value = "startup_script", "start.sh" + (tmp_dir / ".dvc" / "config").write_text(config_text) + assert main(["machine", "modify", "foo", slot, value]) == 0 + assert ( + tmp_dir / ".dvc" / "config" + ).read_text() == config_text + f" {slot} = ../{value}\n" + assert main(["machine", "modify", "--unset", "foo", slot]) == 0 + assert (tmp_dir / ".dvc" / "config").read_text() == config_text + + +@pytest.mark.parametrize( + "slot,value,msg", + [ + ( + "region", + "other-west", + "expected one of us-west, us-east, eu-west, eu-north", + ), + ("spot_price", "123.4567", "Precision must be not bigger than 5"), + ("instance_hdd_size", "BIG", "expected int"), + ], +) +def test_machine_modify_fail(tmp_dir, dvc, caplog, slot, value, msg): + (tmp_dir / ".dvc" / "config").write_text(config_text) + + assert main(["machine", "modify", "foo", slot, value]) == 251 + assert (tmp_dir / ".dvc" / "config").read_text() == config_text + assert msg in caplog.text + + +full_config_text = textwrap.dedent( + """\ + [feature] + machine = true + ['machine \"bar\"'] + cloud = azure + ['machine \"foo\"'] + cloud = aws + region = us-west + image = iterative-cml + name = iterative_test + spot = True + spot_price = 1.2345 + instance_hdd_size = 10 + instance_type = l + instance_gpu = tesla + ssh_private = secret + startup_script = ../start.sh + """ +) + + +def test_machine_list(tmp_dir, dvc, capsys): + (tmp_dir / ".dvc" / "config").write_text(full_config_text) + + assert main(["machine", "list"]) == 0 + cap = capsys.readouterr() + assert "cloud=azure" in cap.out + + assert main(["machine", "list", "foo"]) == 0 + cap = capsys.readouterr() + assert "cloud=azure" not in cap.out + assert "cloud=aws" in cap.out + assert "region=us-west" in cap.out + assert "image=iterative-cml" in cap.out + assert "name=iterative_test" in cap.out + assert "spot=True" in cap.out + assert "spot_price=1.2345" in cap.out + assert "instance_hdd_size=10" in cap.out + assert "instance_type=l" in cap.out + assert "instance_gpu=tesla" in cap.out + assert "ssh_private=secret" in cap.out + assert f"startup_script={tmp_dir}/.dvc/../start.sh" in cap.out diff --git a/tests/unit/command/test_machine.py b/tests/unit/command/test_machine.py index 02af8feeb6..d5504b2a46 100644 --- a/tests/unit/command/test_machine.py +++ b/tests/unit/command/test_machine.py @@ -5,10 +5,23 @@ CmdMachineAdd, CmdMachineCreate, CmdMachineDestroy, + CmdMachineList, + CmdMachineModify, CmdMachineRemove, CmdMachineSsh, ) +data = { + ".dvc": { + "config": ( + "[feature]\n" + " machine = true\n" + "['machine \"foo\"']\n" + " cloud = aws" + ) + } +} + def test_add(tmp_dir): tmp_dir.gen({".dvc": {"config": "[feature]\n machine = true"}}) @@ -21,18 +34,7 @@ def test_add(tmp_dir): def test_remove(tmp_dir): - tmp_dir.gen( - { - ".dvc": { - "config": ( - "[feature]\n" - " machine = true\n" - "['machine \"foo\"']\n" - " cloud = aws" - ) - } - } - ) + tmp_dir.gen(data) cli_args = parse_args(["machine", "remove", "foo"]) assert cli_args.func == CmdMachineRemove cmd = cli_args.func(cli_args) @@ -78,3 +80,25 @@ def test_ssh(tmp_dir, dvc, mocker): assert cmd.run() == 0 m.assert_called_once_with("foo") + + +def test_list(tmp_dir, mocker): + from dvc.ui import ui + + tmp_dir.gen(data) + cli_args = parse_args(["machine", "list", "foo"]) + assert cli_args.func == CmdMachineList + cmd = cli_args.func(cli_args) + m = mocker.patch.object(ui, "write", autospec=True) + assert cmd.run() == 0 + m.assert_called_once_with("cloud=aws") + + +def test_modified(tmp_dir): + tmp_dir.gen(data) + cli_args = parse_args(["machine", "modify", "foo", "cloud", "azure"]) + assert cli_args.func == CmdMachineModify + cmd = cli_args.func(cli_args) + assert cmd.run() == 0 + config = configobj.ConfigObj(str(tmp_dir / ".dvc" / "config")) + assert config['machine "foo"']["cloud"] == "azure" From 6bda02ac932e78e3e67574f0c19f3dd59ee603f5 Mon Sep 17 00:00:00 2001 From: Gao Date: Tue, 14 Sep 2021 15:39:55 +0800 Subject: [PATCH 2/4] Update dvc/config_schema.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Peter Rowlands (변기호) --- dvc/config_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/config_schema.py b/dvc/config_schema.py index 0d369cbb98..24d8cfe847 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -252,7 +252,7 @@ class RelPath(str): "image": str, "name": str, "spot": Bool, - "spot_price": Number(precision=5), + "spot_price": Coerce(float), "instance_hdd_size": Coerce(int), "instance_type": Lower, "instance_gpu": Lower, From f615cd295a93206b5f4c78a7ac7d13d0b4397e51 Mon Sep 17 00:00:00 2001 From: Gao Date: Tue, 14 Sep 2021 15:40:03 +0800 Subject: [PATCH 3/4] Update tests/func/machine/test_machine_config.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Peter Rowlands (변기호) --- tests/func/machine/test_machine_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/func/machine/test_machine_config.py b/tests/func/machine/test_machine_config.py index 5f85ac858e..cd93d5d88e 100644 --- a/tests/func/machine/test_machine_config.py +++ b/tests/func/machine/test_machine_config.py @@ -4,7 +4,7 @@ from dvc.main import main -config_text = textwrap.dedent( +CONFIG_TEXT = textwrap.dedent( """\ [feature] machine = true From 9f7cf7c9cae6602e27c59ae9d3d8b34217623c87 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Tue, 14 Sep 2021 16:46:02 +0800 Subject: [PATCH 4/4] Some reviewed problems 1. capitalize global variables. 2. remove precision restriction. 3. solve windows path seperator problem. --- dvc/config_schema.py | 37 ++++++---------------- tests/func/machine/test_machine_config.py | 38 ++++++++++++++--------- tests/unit/command/test_machine.py | 8 ++--- 3 files changed, 37 insertions(+), 46 deletions(-) diff --git a/dvc/config_schema.py b/dvc/config_schema.py index 24d8cfe847..cd445d3ff6 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -2,9 +2,16 @@ from urllib.parse import urlparse from funcy import walk_values -from voluptuous import All, Any, Coerce, Invalid, Lower -from voluptuous import Number as Number_ -from voluptuous import Optional, Range, Schema +from voluptuous import ( + All, + Any, + Coerce, + Invalid, + Lower, + Optional, + Range, + Schema, +) Bool = All( Lower, @@ -34,30 +41,6 @@ def supported_cache_type(types): return types -class Number(Number_): - """`<=` version of the official Number class""" - - def __call__(self, v): - precision, scale, decimal_num = self._get_precision_scale(v) - - error_msg = "" - if self.precision is not None and precision > self.precision: - error_msg += ( - "Precision must be not bigger than %s" % self.precision - ) - if self.scale is not None and self.scale != scale: - if error_msg: - error_msg += ", and " - error_msg += "Scale must be equal to %s" % self.scale - - if error_msg: - raise Invalid(self.msg or error_msg) - - if self.yield_decimal: - return decimal_num - return v - - def Choices(*choices): """Checks that value belongs to the specified set of values diff --git a/tests/func/machine/test_machine_config.py b/tests/func/machine/test_machine_config.py index cd93d5d88e..6b89d9d99b 100644 --- a/tests/func/machine/test_machine_config.py +++ b/tests/func/machine/test_machine_config.py @@ -1,3 +1,4 @@ +import os import textwrap import pytest @@ -22,7 +23,7 @@ ("name", "iterative_test"), ("spot", "True"), ("spot_price", "1.2345"), - ("spot_price", "123"), + ("spot_price", "12345"), ("instance_hdd_size", "10"), ("instance_type", "l"), ("instance_gpu", "tesla"), @@ -30,24 +31,24 @@ ], ) def test_machine_modify_susccess(tmp_dir, dvc, slot, value): - (tmp_dir / ".dvc" / "config").write_text(config_text) + (tmp_dir / ".dvc" / "config").write_text(CONFIG_TEXT) assert main(["machine", "modify", "foo", slot, value]) == 0 assert ( tmp_dir / ".dvc" / "config" - ).read_text() == config_text + f" {slot} = {value}\n" + ).read_text() == CONFIG_TEXT + f" {slot} = {value}\n" assert main(["machine", "modify", "--unset", "foo", slot]) == 0 - assert (tmp_dir / ".dvc" / "config").read_text() == config_text + assert (tmp_dir / ".dvc" / "config").read_text() == CONFIG_TEXT def test_machine_modify_startup_script(tmp_dir, dvc): slot, value = "startup_script", "start.sh" - (tmp_dir / ".dvc" / "config").write_text(config_text) + (tmp_dir / ".dvc" / "config").write_text(CONFIG_TEXT) assert main(["machine", "modify", "foo", slot, value]) == 0 assert ( tmp_dir / ".dvc" / "config" - ).read_text() == config_text + f" {slot} = ../{value}\n" + ).read_text() == CONFIG_TEXT + f" {slot} = ../{value}\n" assert main(["machine", "modify", "--unset", "foo", slot]) == 0 - assert (tmp_dir / ".dvc" / "config").read_text() == config_text + assert (tmp_dir / ".dvc" / "config").read_text() == CONFIG_TEXT @pytest.mark.parametrize( @@ -58,19 +59,19 @@ def test_machine_modify_startup_script(tmp_dir, dvc): "other-west", "expected one of us-west, us-east, eu-west, eu-north", ), - ("spot_price", "123.4567", "Precision must be not bigger than 5"), + ("spot_price", "NUM", "expected float"), ("instance_hdd_size", "BIG", "expected int"), ], ) def test_machine_modify_fail(tmp_dir, dvc, caplog, slot, value, msg): - (tmp_dir / ".dvc" / "config").write_text(config_text) + (tmp_dir / ".dvc" / "config").write_text(CONFIG_TEXT) assert main(["machine", "modify", "foo", slot, value]) == 251 - assert (tmp_dir / ".dvc" / "config").read_text() == config_text + assert (tmp_dir / ".dvc" / "config").read_text() == CONFIG_TEXT assert msg in caplog.text -full_config_text = textwrap.dedent( +FULL_CONFIG_TEXT = textwrap.dedent( """\ [feature] machine = true @@ -87,13 +88,15 @@ def test_machine_modify_fail(tmp_dir, dvc, caplog, slot, value, msg): instance_type = l instance_gpu = tesla ssh_private = secret - startup_script = ../start.sh - """ + startup_script = {} + """.format( + os.path.join("..", "start.sh") + ) ) def test_machine_list(tmp_dir, dvc, capsys): - (tmp_dir / ".dvc" / "config").write_text(full_config_text) + (tmp_dir / ".dvc" / "config").write_text(FULL_CONFIG_TEXT) assert main(["machine", "list"]) == 0 cap = capsys.readouterr() @@ -112,4 +115,9 @@ def test_machine_list(tmp_dir, dvc, capsys): assert "instance_type=l" in cap.out assert "instance_gpu=tesla" in cap.out assert "ssh_private=secret" in cap.out - assert f"startup_script={tmp_dir}/.dvc/../start.sh" in cap.out + assert ( + "startup_script={}".format( + os.path.join(tmp_dir, ".dvc", "..", "start.sh") + ) + in cap.out + ) diff --git a/tests/unit/command/test_machine.py b/tests/unit/command/test_machine.py index d5504b2a46..dd77b9bc5d 100644 --- a/tests/unit/command/test_machine.py +++ b/tests/unit/command/test_machine.py @@ -11,7 +11,7 @@ CmdMachineSsh, ) -data = { +DATA = { ".dvc": { "config": ( "[feature]\n" @@ -34,7 +34,7 @@ def test_add(tmp_dir): def test_remove(tmp_dir): - tmp_dir.gen(data) + tmp_dir.gen(DATA) cli_args = parse_args(["machine", "remove", "foo"]) assert cli_args.func == CmdMachineRemove cmd = cli_args.func(cli_args) @@ -85,7 +85,7 @@ def test_ssh(tmp_dir, dvc, mocker): def test_list(tmp_dir, mocker): from dvc.ui import ui - tmp_dir.gen(data) + tmp_dir.gen(DATA) cli_args = parse_args(["machine", "list", "foo"]) assert cli_args.func == CmdMachineList cmd = cli_args.func(cli_args) @@ -95,7 +95,7 @@ def test_list(tmp_dir, mocker): def test_modified(tmp_dir): - tmp_dir.gen(data) + tmp_dir.gen(DATA) cli_args = parse_args(["machine", "modify", "foo", "cloud", "azure"]) assert cli_args.func == CmdMachineModify cmd = cli_args.func(cli_args)