Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dvc/commands/experiments/exec_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def run(self):
queue=None,
log_level=logger.getEffectiveLevel(),
infofile=self.args.infofile,
copy_paths=self.args.copy_paths,
)
return 0

Expand All @@ -36,4 +37,14 @@ def add_parser(experiments_subparsers, parent_parser):
help="Path to executor info file",
default=None,
)
exec_run_parser.add_argument(
"-C",
"--copy-paths",
action="append",
default=[],
help=(
"List of ignored or untracked paths to copy into the temp directory."
" Only used if `--temp` or `--queue` is specified."
),
)
exec_run_parser.set_defaults(func=CmdExecutorRun)
11 changes: 11 additions & 0 deletions dvc/commands/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def run(self):
reset=self.args.reset,
tmp_dir=self.args.tmp_dir,
machine=self.args.machine,
copy_paths=self.args.copy_paths,
**self._common_kwargs,
)

Expand Down Expand Up @@ -136,3 +137,13 @@ def _add_run_common(parser):
# )
# metavar="<name>",
)
parser.add_argument(
"-C",
"--copy-paths",
action="append",
default=[],
help=(
"List of ignored or untracked paths to copy into the temp directory."
" Only used if `--temp` or `--queue` is specified."
),
)
11 changes: 7 additions & 4 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional

from funcy import chain, first

Expand Down Expand Up @@ -118,14 +118,15 @@ def stash_revs(self) -> Dict[str, "ExpStashEntry"]:
def reproduce_one(
self,
tmp_dir: bool = False,
copy_paths: Optional[List[str]] = None,
**kwargs,
):
"""Reproduce and checkout a single (standalone) experiment."""
exp_queue: "BaseStashQueue" = (
self.tempdir_queue if tmp_dir else self.workspace_queue
)
self.queue_one(exp_queue, **kwargs)
results = self._reproduce_queue(exp_queue)
results = self._reproduce_queue(exp_queue, copy_paths=copy_paths)
exp_rev = first(results)
if exp_rev is not None:
self._log_reproduced(results, tmp_dir=tmp_dir)
Expand Down Expand Up @@ -347,7 +348,9 @@ def reset_checkpoints(self):
self.scm.remove_ref(EXEC_APPLY)

@unlocked_repo
def _reproduce_queue(self, queue: "BaseStashQueue", **kwargs) -> Dict[str, str]:
def _reproduce_queue(
self, queue: "BaseStashQueue", copy_paths: Optional[List[str]] = None, **kwargs
) -> Dict[str, str]:
"""Reproduce queued experiments.

Arguments:
Expand All @@ -357,7 +360,7 @@ def _reproduce_queue(self, queue: "BaseStashQueue", **kwargs) -> Dict[str, str]:
dict mapping successfully reproduced experiment revs to their
results.
"""
exec_results = queue.reproduce()
exec_results = queue.reproduce(copy_paths=copy_paths)

results: Dict[str, str] = {}
for _, exp_result in exec_results.items():
Expand Down
22 changes: 22 additions & 0 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import pickle # nosec B403
import shutil
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -451,6 +452,7 @@ def reproduce(
infofile: Optional[str] = None,
log_errors: bool = True,
log_level: Optional[int] = None,
copy_paths: Optional[List[str]] = None,
**kwargs,
) -> "ExecutorResult":
"""Run dvc repro and return the result.
Expand Down Expand Up @@ -487,6 +489,7 @@ def filter_pipeline(stages):
info,
infofile,
log_errors=log_errors,
copy_paths=copy_paths,
**kwargs,
) as dvc:
if auto_push:
Expand Down Expand Up @@ -609,6 +612,7 @@ def _repro_dvc( # noqa: C901
info: "ExecutorInfo",
infofile: Optional[str] = None,
log_errors: bool = True,
copy_paths: Optional[List[str]] = None,
**kwargs,
):
from dvc_studio_client.post_live_metrics import post_live_metrics
Expand All @@ -623,6 +627,10 @@ def _repro_dvc( # noqa: C901
if cls.QUIET:
dvc.scm_context.quiet = cls.QUIET
old_cwd = os.getcwd()

for path in copy_paths or []:
cls._copy_path(os.path.realpath(path), os.path.join(dvc.root_dir, path))

if info.wdir:
os.chdir(os.path.join(dvc.scm.root_dir, info.wdir))
else:
Expand Down Expand Up @@ -792,6 +800,20 @@ def _set_log_level(level):
if level is not None:
dvc_logger.setLevel(level)

@staticmethod
def _copy_path(src, dst):
try:
if os.path.isfile(src):
shutil.copy(src, dst)
elif os.path.isdir(src):
shutil.copytree(src, dst)
else:
raise DvcException(
f"Unable to copy '{src}'. It is not a file or directory."
)
except OSError as exc:
raise DvcException(f"Unable to copy '{src}' to '{dst}'.") from exc

@contextmanager
def set_temp_refs(self, scm: "Git", temp_dict: Dict[str, str]):
try:
Expand Down
3 changes: 2 additions & 1 deletion dvc/repo/experiments/executor/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import posixpath
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Iterable, Optional
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional

from dvc_ssh import SSHFileSystem
from funcy import first
Expand Down Expand Up @@ -242,6 +242,7 @@ def reproduce(
infofile: Optional[str] = None,
log_errors: bool = True,
log_level: Optional[int] = None,
copy_paths: Optional[List[str]] = None, # noqa: ARG003
**kwargs,
) -> "ExecutorResult":
"""Reproduce an experiment on a remote machine over SSH.
Expand Down
4 changes: 3 additions & 1 deletion dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
"""Iterate over items which been failed."""

@abstractmethod
def reproduce(self) -> Mapping[str, Mapping[str, str]]:
def reproduce(
self, copy_paths: Optional[List[str]] = None
) -> Mapping[str, Mapping[str, str]]:
"""Reproduce queued experiments sequentially."""

@abstractmethod
Expand Down
10 changes: 7 additions & 3 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,13 @@ def start_workers(self, count: int) -> int:

return started

def put(self, *args, **kwargs) -> QueueEntry:
def put(
self, *args, copy_paths: Optional[List[str]] = None, **kwargs
) -> QueueEntry:
"""Stash an experiment and add it to the queue."""
with get_exp_rwlock(self.repo, writes=["workspace", CELERY_STASH]):
entry = self._stash_exp(*args, **kwargs)
self.celery.signature(run_exp.s(entry.asdict())).delay()
self.celery.signature(run_exp.s(entry.asdict(), copy_paths=copy_paths)).delay()
return entry

# NOTE: Queue consumption should not be done directly. Celery worker(s)
Expand Down Expand Up @@ -250,7 +252,9 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
if exp_result is None:
yield QueueDoneResult(queue_entry, exp_result)

def reproduce(self) -> Mapping[str, Mapping[str, str]]:
def reproduce(
self, copy_paths: Optional[List[str]] = None
) -> Mapping[str, Mapping[str, str]]:
raise NotImplementedError

def _load_info(self, rev: str) -> ExecutorInfo:
Expand Down
7 changes: 5 additions & 2 deletions dvc/repo/experiments/queue/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from celery import shared_task
from celery.utils.log import get_task_logger
Expand Down Expand Up @@ -91,7 +91,7 @@ def cleanup_exp(executor: TempDirExecutor, infofile: str) -> None:


@shared_task
def run_exp(entry_dict: Dict[str, Any]) -> None:
def run_exp(entry_dict: Dict[str, Any], copy_paths: Optional[List[str]] = None) -> None:
"""Run a full experiment.

Experiment subtasks are executed inline as one atomic operation.
Expand All @@ -108,6 +108,9 @@ def run_exp(entry_dict: Dict[str, Any]) -> None:
executor = setup_exp.s(entry_dict)()
try:
cmd = ["dvc", "exp", "exec-run", "--infofile", infofile]
if copy_paths:
for path in copy_paths:
cmd.extend(["--copy-paths", path])
proc_dict = queue.proc.run_signature(cmd, name=entry.stash_rev)()
collect_exp.s(proc_dict, entry_dict)()
finally:
Expand Down
9 changes: 7 additions & 2 deletions dvc/repo/experiments/queue/tempdir.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Generator, Optional
from typing import TYPE_CHECKING, Dict, Generator, List, Optional

from funcy import first

Expand Down Expand Up @@ -92,7 +92,11 @@ def iter_active(self) -> Generator[QueueEntry, None, None]:
)

def _reproduce_entry(
self, entry: QueueEntry, executor: "BaseExecutor"
self,
entry: QueueEntry,
executor: "BaseExecutor",
copy_paths: Optional[List[str]] = None,
**kwargs,
) -> Dict[str, Dict[str, str]]:
from dvc.stage.monitor import CheckpointKilledError

Expand All @@ -107,6 +111,7 @@ def _reproduce_entry(
infofile=infofile,
log_level=logger.getEffectiveLevel(),
log_errors=True,
copy_paths=copy_paths,
)
if not exec_result.exp_hash:
raise DvcException(f"Failed to reproduce experiment '{rev[:7]}'")
Expand Down
12 changes: 9 additions & 3 deletions dvc/repo/experiments/queue/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class WorkspaceQueue(BaseStashQueue):
_EXEC_NAME: Optional[str] = "workspace"

def put(self, *args, **kwargs) -> QueueEntry:
kwargs.pop("copy_paths", None)
with get_exp_rwlock(self.repo, writes=["workspace", WORKSPACE_STASH]):
return self._stash_exp(*args, **kwargs)

Expand Down Expand Up @@ -81,19 +82,24 @@ def iter_failed(self) -> Generator["QueueDoneResult", None, None]:
def iter_success(self) -> Generator["QueueDoneResult", None, None]:
raise NotImplementedError

def reproduce(self) -> Dict[str, Dict[str, str]]:
def reproduce(
self, copy_paths: Optional[List[str]] = None
) -> Dict[str, Dict[str, str]]:
results: Dict[str, Dict[str, str]] = defaultdict(dict)
try:
while True:
entry, executor = self.get()
results.update(self._reproduce_entry(entry, executor))
results.update(
self._reproduce_entry(entry, executor, copy_paths=copy_paths)
)
except ExpQueueEmptyError:
pass
return results

def _reproduce_entry(
self, entry: QueueEntry, executor: "BaseExecutor"
self, entry: QueueEntry, executor: "BaseExecutor", **kwargs
) -> Dict[str, Dict[str, str]]:
kwargs.pop("copy_paths", None)
from dvc.stage.monitor import CheckpointKilledError

results: Dict[str, Dict[str, str]] = defaultdict(dict)
Expand Down
8 changes: 7 additions & 1 deletion dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def run( # noqa: C901, PLR0912
jobs: int = 1,
tmp_dir: bool = False,
queue: bool = False,
copy_paths: Optional[Iterable[str]] = None,
**kwargs,
) -> Dict[str, str]:
"""Reproduce the specified targets as an experiment.
Expand Down Expand Up @@ -69,7 +70,11 @@ def run( # noqa: C901, PLR0912

if not queue:
return repo.experiments.reproduce_one(
targets=targets, params=path_overrides, tmp_dir=tmp_dir, **kwargs
targets=targets,
params=path_overrides,
tmp_dir=tmp_dir,
copy_paths=copy_paths,
**kwargs,
)

if hydra_sweep:
Expand All @@ -90,6 +95,7 @@ def run( # noqa: C901, PLR0912
repo.experiments.celery_queue,
targets=targets,
params=sweep_overrides,
copy_paths=copy_paths,
**kwargs,
)
if sweep_overrides:
Expand Down
42 changes: 41 additions & 1 deletion tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from funcy import first

from dvc.dvcfile import PROJECT_FILE
from dvc.exceptions import ReproductionError
from dvc.exceptions import DvcException, ReproductionError
from dvc.repo.experiments.exceptions import ExperimentExistsError
from dvc.repo.experiments.queue.base import BaseStashQueue
from dvc.repo.experiments.refs import CELERY_STASH
Expand Down Expand Up @@ -704,3 +704,43 @@ def test_untracked_top_level_files_are_included_in_exp(tmp_dir, scm, dvc, tmp):
fs = scm.get_fs(exp)
for file in ["metrics.json", "params.yaml", "plots.csv"]:
assert fs.exists(file)


@pytest.mark.parametrize("tmp", [True, False])
def test_copy_paths(tmp_dir, scm, dvc, tmp):
stage = dvc.stage.add(
cmd="cat file && ls dir",
name="foo",
)
scm.add_commit(["dvc.yaml"], message="add dvc.yaml")

(tmp_dir / "dir").mkdir()
(tmp_dir / "dir" / "file").write_text("dir/file")
scm.ignore(tmp_dir / "dir")
(tmp_dir / "file").write_text("file")
scm.ignore(tmp_dir / "file")

results = dvc.experiments.run(
stage.addressing, tmp_dir=tmp, copy_paths=["dir", "file"]
)
exp = first(results)
fs = scm.get_fs(exp)
assert not fs.exists("dir")
assert not fs.exists("file")


def test_copy_paths_errors(tmp_dir, scm, dvc, mocker):
stage = dvc.stage.add(
cmd="echo foo",
name="foo",
)
scm.add_commit(["dvc.yaml"], message="add dvc.yaml")

with pytest.raises(DvcException, match="Unable to copy"):
dvc.experiments.run(stage.addressing, tmp_dir=True, copy_paths=["foo"])

(tmp_dir / "foo").write_text("foo")
mocker.patch("shutil.copy", side_effect=OSError)

with pytest.raises(DvcException, match="Unable to copy"):
dvc.experiments.run(stage.addressing, tmp_dir=True, copy_paths=["foo"])
Loading