diff --git a/dvc/stage/run.py b/dvc/stage/run.py index 23a683c23e..b7660bfd05 100644 --- a/dvc/stage/run.py +++ b/dvc/stage/run.py @@ -45,61 +45,90 @@ def warn_if_fish(executable): ) -@unlocked_repo -def cmd_run(stage, *args, checkpoint_func=None, **kwargs): +def _enforce_cmd_list(cmd): + assert cmd + return cmd if isinstance(cmd, list) else [cmd] + + +def prepare_kwargs(stage, checkpoint_func=None): kwargs = {"cwd": stage.wdir, "env": fix_env(None), "close_fds": True} - cmd = stage.cmd if isinstance(stage.cmd, list) else [stage.cmd] if checkpoint_func: # indicate that checkpoint cmd is being run inside DVC kwargs["env"].update(_checkpoint_env(stage)) - if os.name == "nt": - kwargs["shell"] = True - executable = None - else: - # NOTE: when you specify `shell=True`, `Popen` [1] will default to - # `/bin/sh` on *nix and will add ["/bin/sh", "-c"] to your command. - # But we actually want to run the same shell that we are running - # from right now, which is usually determined by the `SHELL` env - # var. So instead, we compose our command on our own, making sure - # to include special flags to prevent shell from reading any - # configs and modifying env, which may change the behavior or the - # command we are running. See [2] for more info. - # - # [1] https://github.com/python/cpython/blob/3.7/Lib/subprocess.py - # #L1426 - # [2] https://github.com/iterative/dvc/issues/2506 - # #issuecomment-535396799 - kwargs["shell"] = False - executable = os.getenv("SHELL") or "/bin/sh" - warn_if_fish(executable) + # NOTE: when you specify `shell=True`, `Popen` [1] will default to + # `/bin/sh` on *nix and will add ["/bin/sh", "-c"] to your command. + # But we actually want to run the same shell that we are running + # from right now, which is usually determined by the `SHELL` env + # var. So instead, we compose our command on our own, making sure + # to include special flags to prevent shell from reading any + # configs and modifying env, which may change the behavior or the + # command we are running. See [2] for more info. + # + # [1] https://github.com/python/cpython/blob/3.7/Lib/subprocess.py + # #L1426 + # [2] https://github.com/iterative/dvc/issues/2506 + # #issuecomment-535396799 + kwargs["shell"] = True if os.name == "nt" else False + return kwargs + + +def display_command(cmd): + logger.info("%s %s", ">", cmd) + +def get_executable(): + return (os.getenv("SHELL") or "/bin/sh") if os.name != "nt" else None + + +def _run(stage, executable, cmd, checkpoint_func, **kwargs): main_thread = isinstance( threading.current_thread(), threading._MainThread, # pylint: disable=protected-access ) - for _cmd in cmd: - logger.info("$ %s", _cmd) - old_handler = None - p = None - try: - p = subprocess.Popen(_make_cmd(executable, _cmd), **kwargs) - if main_thread: - old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) + exec_cmd = _make_cmd(executable, cmd) + old_handler = None + p = None - killed = threading.Event() - with checkpoint_monitor(stage, checkpoint_func, p, killed): - p.communicate() - finally: - if old_handler: - signal.signal(signal.SIGINT, old_handler) + try: + p = subprocess.Popen(exec_cmd, **kwargs) + if main_thread: + old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) + + killed = threading.Event() + with checkpoint_monitor(stage, checkpoint_func, p, killed): + p.communicate() + finally: + if old_handler: + signal.signal(signal.SIGINT, old_handler) - retcode = None if not p else p.returncode - if retcode != 0: - if killed.is_set(): - raise CheckpointKilledError(_cmd, retcode) - raise StageCmdFailedError(_cmd, retcode) + retcode = None if not p else p.returncode + if retcode != 0: + if killed.is_set(): + raise CheckpointKilledError(cmd, retcode) + raise StageCmdFailedError(cmd, retcode) + + +def cmd_run(stage, dry=False, checkpoint_func=None): + logger.info( + "Running %s" "stage '%s':", + "callback " if stage.is_callback else "", + stage.addressing, + ) + commands = _enforce_cmd_list(stage.cmd) + kwargs = prepare_kwargs(stage, checkpoint_func=checkpoint_func) + executable = get_executable() + + if not dry: + warn_if_fish(executable) + + for cmd in commands: + display_command(cmd) + if dry: + continue + + _run(stage, executable, cmd, checkpoint_func=checkpoint_func, **kwargs) def run_stage(stage, dry=False, force=False, checkpoint_func=None, **kwargs): @@ -112,12 +141,8 @@ def run_stage(stage, dry=False, force=False, checkpoint_func=None, **kwargs): except RunCacheNotFoundError: pass - callback_str = "callback " if stage.is_callback else "" - logger.info( - "Running %s" "stage '%s':", callback_str, stage.addressing, - ) - if not dry: - cmd_run(stage, checkpoint_func=checkpoint_func) + run = cmd_run if dry else unlocked_repo(cmd_run) + run(stage, dry=dry, checkpoint_func=checkpoint_func) def _checkpoint_env(stage): diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 146aec7702..12ad2df016 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -1307,4 +1307,4 @@ def test_repro_when_cmd_changes(tmp_dir, dvc, run_copy, mocker): stage.addressing: ["changed checksum"] } assert dvc.reproduce(stage.addressing)[0] == stage - m.assert_called_once_with(stage, checkpoint_func=None) + m.assert_called_once_with(stage, checkpoint_func=None, dry=False) diff --git a/tests/func/test_repro_multistage.py b/tests/func/test_repro_multistage.py index e51e577441..743c8765aa 100644 --- a/tests/func/test_repro_multistage.py +++ b/tests/func/test_repro_multistage.py @@ -286,7 +286,7 @@ def test_repro_when_cmd_changes(tmp_dir, dvc, run_copy, mocker): assert dvc.status([target]) == {target: ["changed command"]} assert dvc.reproduce(target)[0] == stage - m.assert_called_once_with(stage, checkpoint_func=None) + m.assert_called_once_with(stage, checkpoint_func=None, dry=False) def test_repro_when_new_deps_is_added_in_dvcfile(tmp_dir, dvc, run_copy): diff --git a/tests/func/test_stage.py b/tests/func/test_stage.py index 28435c1adb..26ebc3ce45 100644 --- a/tests/func/test_stage.py +++ b/tests/func/test_stage.py @@ -6,7 +6,7 @@ from dvc.dvcfile import SingleStageFile from dvc.main import main from dvc.output.local import LocalOutput -from dvc.repo import Repo +from dvc.repo import Repo, lock_repo from dvc.stage import PipelineStage, Stage from dvc.stage.exceptions import StageFileFormatError from dvc.stage.run import run_stage @@ -306,5 +306,7 @@ def test_stage_run_checkpoint(tmp_dir, dvc, mocker, checkpoint): callback = mocker.Mock() else: callback = None - run_stage(stage, checkpoint_func=callback) - mock_cmd_run.assert_called_with(stage, checkpoint_func=callback) + + with lock_repo(dvc): + run_stage(stage, checkpoint_func=callback) + mock_cmd_run.assert_called_with(stage, checkpoint_func=callback, dry=False) diff --git a/tests/unit/stage/test_run.py b/tests/unit/stage/test_run.py index bfc607cf57..7128aab2a1 100644 --- a/tests/unit/stage/test_run.py +++ b/tests/unit/stage/test_run.py @@ -1,13 +1,24 @@ import logging +import pytest + from dvc.stage import Stage from dvc.stage.run import run_stage -def test_run_stage_dry(caplog): +@pytest.mark.parametrize( + "cmd, expected", + [ + ("mycmd arg1 arg2", ["> mycmd arg1 arg2"]), + (["mycmd1 arg1", "mycmd2 arg2"], ["> mycmd1 arg1", "> mycmd2 arg2"]), + ], +) +def test_run_stage_dry(caplog, cmd, expected): with caplog.at_level(level=logging.INFO, logger="dvc"): - stage = Stage(None, "stage.dvc", cmd="mycmd arg1 arg2") + stage = Stage(None, "stage.dvc", cmd=cmd) run_stage(stage, dry=True) - assert caplog.messages == [ - "Running callback stage 'stage.dvc':", - ] + + expected.insert( + 0, "Running callback stage 'stage.dvc':", + ) + assert caplog.messages == expected