diff --git a/dvc/command/machine.py b/dvc/command/machine.py index 008104420b..a8e5526854 100644 --- a/dvc/command/machine.py +++ b/dvc/command/machine.py @@ -2,8 +2,10 @@ from dvc.command.base import CmdBase, append_doc_link, fix_subparsers from dvc.command.config import CmdConfig +from dvc.compare import TabularData from dvc.config import ConfigError from dvc.exceptions import DvcException +from dvc.types import Dict, List from dvc.ui import ui from dvc.utils import format_link @@ -71,18 +73,56 @@ def run(self): class CmdMachineList(CmdMachineConfig): - def run(self): + TABLE_COLUMNS = [ + "name", + "cloud", + "region", + "image", + "spot", + "spot_price", + "instance_hdd_size", + "instance_type", + "ssh_private", + "startup_script", + ] + + PRIVATE_COLUMNS = ["ssh_private", "startup_script"] + + def _hide_private(self, conf): + for machine in conf: + for column in self.PRIVATE_COLUMNS: + if column in conf[machine]: + conf[machine][column] = "***" + + def _show_origin(self): levels = [self.args.level] if self.args.level else self.config.LEVELS for level in levels: conf = self.config.read(level)["machine"] if self.args.name: conf = conf.get(self.args.name, {}) - prefix = self._config_file_prefix( - self.args.show_origin, self.config, level - ) + self._hide_private(conf) + prefix = self._config_file_prefix(True, self.config, level) configs = list(self._format_config(conf, prefix)) if configs: ui.write("\n".join(configs)) + + def _show_table(self): + td = TabularData(self.TABLE_COLUMNS, fill_value="-") + conf = self.config.read()["machine"] + if self.args.name: + conf = {self.args.name: conf.get(self.args.name, {})} + self._hide_private(conf) + for machine, machine_config in conf.items(): + machine_config["name"] = machine + td.row_from_dict(machine_config) + td.dropna("cols", "all") + td.render() + + def run(self): + if self.args.show_origin: + self._show_origin() + else: + self._show_table() return 0 @@ -193,8 +233,8 @@ def run(self): class CmdMachineStatus(CmdBase): + INSTANCE_FIELD = ["name", "instance", "status"] SHOWN_FIELD = [ - "name", "cloud", "instance_ip", "instance_type", @@ -202,23 +242,34 @@ class CmdMachineStatus(CmdBase): "instance_gpu", ] - def _show_machine_status(self, name: str): - ui.write(f"machine '{name}':") - all_status = list(self.repo.machine.status(name)) + def _add_row( + self, + name: str, + all_status: List[Dict], + td: TabularData, + ): + if not all_status: - ui.write("\toffline") + row = [name, None, "offline"] + td.append(row) for i, status in enumerate(all_status, start=1): - ui.write(f"\tinstance_num_{i}:") + row = [name, f"num_{i}", "running" if status else "offline"] for field in self.SHOWN_FIELD: - value = status.get(field, None) - ui.write(f"\t\t{field:20}: {value}") + value = str(status.get(field, "")) + row.append(value) + td.append(row) def run(self): if self.repo.machine is None: raise MachineDisabledError + td = TabularData( + self.INSTANCE_FIELD + self.SHOWN_FIELD, fill_value="-" + ) + if self.args.name: - self._show_machine_status(self.args.name) + all_status = list(self.repo.machine.status(self.args.name)) + self._add_row(self.args.name, all_status, td) else: name_set = set() for level in self.repo.config.LEVELS: @@ -226,8 +277,11 @@ def run(self): name_set.update(conf.keys()) name_list = list(name_set) for name in sorted(name_list): - self._show_machine_status(name) + all_status = list(self.repo.machine.status(name)) + self._add_row(name, all_status, td) + td.dropna("cols", "all") + td.render() return 0 diff --git a/dvc/compare.py b/dvc/compare.py index e9d742d3bb..0bf8c54da0 100644 --- a/dvc/compare.py +++ b/dvc/compare.py @@ -116,6 +116,8 @@ def __delitem__(self, item: Union[int, slice]) -> None: del col[item] def __len__(self) -> int: + if not self._columns: + return 0 return len(self.columns[0]) @property @@ -182,21 +184,38 @@ def as_dict( {k: self._columns[k][i] for k in keys} for i in range(len(self)) ] - def dropna(self, axis: str = "rows"): + def dropna(self, axis: str = "rows", how="any"): if axis not in ["rows", "cols"]: raise ValueError( f"Invalid 'axis' value {axis}." "Choose one of ['rows', 'cols']" ) - to_drop: Set = set() + if how not in ["any", "all"]: + raise ValueError( + f"Invalid 'how' value {how}." "Choose one of ['any', 'all']" + ) + + match_line: Set = set() + match = True + if how == "all": + match = False + for n_row, row in enumerate(self): for n_col, col in enumerate(row): - if col == self._fill_value: + if (col == self._fill_value) is match: if axis == "rows": - to_drop.add(n_row) + match_line.add(n_row) break else: - to_drop.add(self.keys()[n_col]) + match_line.add(self.keys()[n_col]) + + to_drop = match_line + if how == "all": + if axis == "rows": + to_drop = set(range(len(self))) + else: + to_drop = set(self.keys()) + to_drop -= match_line if axis == "rows": for name in self.keys(): diff --git a/dvc/config_schema.py b/dvc/config_schema.py index 7427825ca1..24bb165ee3 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -241,7 +241,6 @@ class RelPath(str): Lower, Choices("us-west", "us-east", "eu-west", "eu-north") ), "image": str, - "name": str, "spot": Bool, "spot_price": Coerce(float), "instance_hdd_size": Coerce(int), diff --git a/dvc/ui/table.py b/dvc/ui/table.py index ae17613df1..a08216f1c2 100644 --- a/dvc/ui/table.py +++ b/dvc/ui/table.py @@ -14,7 +14,7 @@ SHOW_MAX_WIDTH = 1024 -CellT = Union[str, "RichText"] # RichText is mostly compatible with str +CellT = Union[str, "RichText", None] # RichText is mostly compatible with str Row = Sequence[CellT] TableData = Sequence[Row] Headers = Sequence[str] diff --git a/tests/func/experiments/test_show.py b/tests/func/experiments/test_show.py index 5a347f7184..1aa6cdcd81 100644 --- a/tests/func/experiments/test_show.py +++ b/tests/func/experiments/test_show.py @@ -13,6 +13,7 @@ from dvc.utils.fs import makedirs from dvc.utils.serialize import YAMLFileCorruptedError from tests.func.test_repro_multistage import COPY_SCRIPT +from tests.utils import console_width def test_show_simple(tmp_dir, scm, dvc, exp_stage): @@ -270,8 +271,6 @@ def test_show_filter( included, excluded, ): - from contextlib import contextmanager - from dvc.ui import ui capsys.readouterr() @@ -317,21 +316,7 @@ def test_show_filter( if e_params is not None: command.append(f"--exclude-params={e_params}") - @contextmanager - def console_with(console, width): - console_options = console.options - original = console_options.max_width - con_width = console._width - - try: - console_options.max_width = width - console._width = width - yield - finally: - console_options.max_width = original - console._width = con_width - - with console_with(ui.rich_console, 255): + with console_width(ui.rich_console, 255): assert main(command) == 0 cap = capsys.readouterr() diff --git a/tests/func/machine/conftest.py b/tests/func/machine/conftest.py index 2795a62b53..6ea95ae713 100644 --- a/tests/func/machine/conftest.py +++ b/tests/func/machine/conftest.py @@ -40,14 +40,16 @@ def machine_config(tmp_dir): @pytest.fixture def machine_instance(tmp_dir, dvc, mocker): - from tpi.terraform import TerraformBackend - with dvc.config.edit() as conf: conf["machine"]["foo"] = {"cloud": "aws"} - mocker.patch.object( - TerraformBackend, - "instances", - autospec=True, - return_value=[TEST_INSTANCE], + + def mock_instances(name=None, **kwargs): + if name == "foo": + return iter([TEST_INSTANCE]) + return iter([]) + + mocker.patch( + "tpi.terraform.TerraformBackend.instances", + mocker.MagicMock(side_effect=mock_instances), ) yield TEST_INSTANCE diff --git a/tests/func/machine/test_machine_config.py b/tests/func/machine/test_machine_config.py index a788088d69..5426ffc83d 100644 --- a/tests/func/machine/test_machine_config.py +++ b/tests/func/machine/test_machine_config.py @@ -5,6 +5,8 @@ import tpi from dvc.main import main +from dvc.ui import ui +from tests.utils import console_width from .conftest import BASIC_CONFIG @@ -14,7 +16,6 @@ [ ("region", "us-west"), ("image", "iterative-cml"), - ("name", "iterative_test"), ("spot", "True"), ("spot_price", "1.2345"), ("spot_price", "12345"), @@ -73,7 +74,6 @@ def test_machine_modify_fail( cloud = aws region = us-west image = iterative-cml - name = iterative_test spot = True spot_price = 1.2345 instance_hdd_size = 10 @@ -88,31 +88,26 @@ def test_machine_modify_fail( def test_machine_list(tmp_dir, dvc, capsys): - (tmp_dir / ".dvc" / "config").write_text(FULL_CONFIG_TEXT) + from dvc.command.machine import CmdMachineList - assert main(["machine", "list"]) == 0 - cap = capsys.readouterr() - assert "cloud=azure" in cap.out + (tmp_dir / ".dvc" / "config").write_text(FULL_CONFIG_TEXT) - assert main(["machine", "list", "foo"]) == 0 - cap = capsys.readouterr() - assert "cloud=azure" not in cap.out - assert "cloud=aws" in cap.out - assert "region=us-west" in cap.out - assert "image=iterative-cml" in cap.out - assert "name=iterative_test" in cap.out - assert "spot=True" in cap.out - assert "spot_price=1.2345" in cap.out - assert "instance_hdd_size=10" in cap.out - assert "instance_type=l" in cap.out - assert "instance_gpu=tesla" in cap.out - assert "ssh_private=secret" in cap.out - assert ( - "startup_script={}".format( - os.path.join(tmp_dir, ".dvc", "..", "start.sh") - ) - in cap.out - ) + with console_width(ui.rich_console, 255): + assert main(["machine", "list"]) == 0 + out, _ = capsys.readouterr() + for key in CmdMachineList.TABLE_COLUMNS: + assert f"{key}" in out + assert "bar azure - -" in out + assert "foo aws us-west iterative-cml True 1.2345" in out + assert "10 l *** ***" in out + assert "tesla" in out + + with console_width(ui.rich_console, 255): + assert main(["machine", "list", "bar"]) == 0 + out, _ = capsys.readouterr() + assert "foo" not in out + assert "name cloud" in out + assert "bar azure" in out def test_machine_rename_success( diff --git a/tests/func/machine/test_machine_status.py b/tests/func/machine/test_machine_status.py index 33b474be78..c1e3056e23 100644 --- a/tests/func/machine/test_machine_status.py +++ b/tests/func/machine/test_machine_status.py @@ -1,14 +1,23 @@ -from dvc.command.machine import CmdMachineStatus from dvc.main import main +from dvc.ui import ui +from tests.utils import console_width -def test_status( - tmp_dir, scm, dvc, machine_config, machine_instance, mocker, capsys -): - status = machine_instance - assert main(["machine", "status", "foo"]) == 0 +def test_status(tmp_dir, scm, dvc, machine_config, machine_instance, capsys): + + assert main(["machine", "add", "bar", "aws"]) == 0 + with console_width(ui.rich_console, 255): + assert main(["machine", "status"]) == 0 cap = capsys.readouterr() - assert "machine 'foo':\n" in cap.out - assert "\tinstance_num_1:\n" in cap.out - for key in CmdMachineStatus.SHOWN_FIELD: - assert f"\t\t{key:20}: {status[key]}\n" in cap.out + assert ( + "name instance status cloud instance_ip " + "instance_type instance_hdd_size instance_gpu" + ) in cap.out + assert ( + "bar - offline - - " + "- - -" + ) in cap.out + assert ( + "foo num_1 running aws 123.123.123.123 " + "m 35 None" + ) in cap.out diff --git a/tests/unit/command/test_machine.py b/tests/unit/command/test_machine.py index 058d83986e..ca636b38f6 100644 --- a/tests/unit/command/test_machine.py +++ b/tests/unit/command/test_machine.py @@ -1,3 +1,5 @@ +import os + import configobj import pytest from mock import call @@ -110,16 +112,24 @@ def test_ssh(tmp_dir, dvc, mocker): m.assert_called_once_with("foo") -def test_list(tmp_dir, mocker): +@pytest.mark.parametrize("show_origin", [["--show-origin"], []]) +def test_list(tmp_dir, mocker, show_origin): + from dvc.compare import TabularData from dvc.ui import ui tmp_dir.gen(DATA) - cli_args = parse_args(["machine", "list", "foo"]) + cli_args = parse_args(["machine", "list"] + show_origin + ["foo"]) assert cli_args.func == CmdMachineList cmd = cli_args.func(cli_args) - m = mocker.patch.object(ui, "write", autospec=True) + if show_origin: + m = mocker.patch.object(ui, "write", autospec=True) + else: + m = mocker.patch.object(TabularData, "render", autospec=True) assert cmd.run() == 0 - m.assert_called_once_with("cloud=aws") + if show_origin: + m.assert_called_once_with(f".dvc{os.sep}config cloud=aws") + else: + m.assert_called_once() def test_modified(tmp_dir): diff --git a/tests/unit/test_tabular_data.py b/tests/unit/test_tabular_data.py index 43c9840b53..a8439d2815 100644 --- a/tests/unit/test_tabular_data.py +++ b/tests/unit/test_tabular_data.py @@ -180,28 +180,43 @@ def test_row_from_dict(): @pytest.mark.parametrize( - "axis,expected", + "axis,how,data,expected", [ ( "rows", + "any", + [["foo"], ["foo", "bar"], ["foo", "bar", "foobar"]], [ ["foo", "bar", "foobar"], ], ), - ("cols", [["foo"], ["foo"], ["foo"]]), + ( + "rows", + "all", + [["foo"], ["foo", "bar"], ["", "", ""]], + [ + ["foo", "", ""], + ["foo", "bar", ""], + ], + ), + ( + "cols", + "any", + [["foo"], ["foo", "bar"], ["foo", "bar", "foobar"]], + [["foo"], ["foo"], ["foo"]], + ), + ( + "cols", + "all", + [["foo"], ["foo", "bar"], ["", "", ""]], + [["foo", ""], ["foo", "bar"], ["", ""]], + ), ], ) -def test_dropna(axis, expected): +def test_dropna(axis, how, data, expected): td = TabularData(["col-1", "col-2", "col-3"]) - td.extend([["foo"], ["foo", "bar"], ["foo", "bar", "foobar"]]) - assert list(td) == [ - ["foo", "", ""], - ["foo", "bar", ""], - ["foo", "bar", "foobar"], - ] - - td.dropna(axis) - + td.extend(data) + td.dropna(axis, how) assert list(td) == expected diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index fd534c768c..579ce1fb5c 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -54,3 +54,18 @@ def clean_staging(): ) except FileNotFoundError: pass + + +@contextmanager +def console_width(console, width): + console_options = console.options + original = console_options.max_width + con_width = console._width + + try: + console_options.max_width = width + console._width = width + yield + finally: + console_options.max_width = original + console._width = con_width