Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions dvc/commands/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,45 @@

from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import InvalidArgumentError
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdExperimentsPull(CmdBase):
def raise_error_if_all_disabled(self):
if not any(
[self.args.experiment, self.args.all_commits, self.args.rev]
):
raise InvalidArgumentError(
"Either provide an `experiment` argument, or use the "
"`--rev` or `--all-commits` flag."
)

def run(self):
self.raise_error_if_all_disabled()

self.repo.experiments.pull(
pulled_exps = self.repo.experiments.pull(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
pull_cache=self.args.pull_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)

ui.write(
f"Pulled experiment '{self.args.experiment}'",
f"from Git remote '{self.args.git_remote}'.",
)
if pulled_exps:
ui.write(
f"Pulled experiment '{pulled_exps}'",
f"from Git remote '{self.args.git_remote}'.",
)
else:
ui.write("No experiments to pull.")
if not self.args.pull_cache:
ui.write(
"To pull cached outputs for this experiment"
Expand All @@ -36,6 +53,8 @@ def run(self):


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags

EXPERIMENTS_PULL_HELP = "Pull an experiment from a Git remote."
experiments_pull_parser = experiments_subparsers.add_parser(
"pull",
Expand All @@ -44,6 +63,7 @@ def add_parser(experiments_subparsers, parent_parser):
help=EXPERIMENTS_PULL_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
add_rev_selection_flags(experiments_pull_parser, "Pull", False)
experiments_pull_parser.add_argument(
"-f",
"--force",
Expand Down Expand Up @@ -89,7 +109,8 @@ def add_parser(experiments_subparsers, parent_parser):
)
experiments_pull_parser.add_argument(
"experiment",
nargs="+",
nargs="*",
default=None,
help="Experiments to pull.",
metavar="<experiment>",
)
Expand Down
34 changes: 28 additions & 6 deletions dvc/commands/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,46 @@
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.commands import completion
from dvc.exceptions import InvalidArgumentError
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdExperimentsPush(CmdBase):
def raise_error_if_all_disabled(self):
if not any(
[self.args.experiment, self.args.all_commits, self.args.rev]
):
raise InvalidArgumentError(
"Either provide an `experiment` argument, or use the "
"`--rev` or `--all-commits` flag."
)

def run(self):

self.repo.experiments.push(
self.raise_error_if_all_disabled()

pushed_exps = self.repo.experiments.push(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
push_cache=self.args.push_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)

ui.write(
f"Pushed experiment '{self.args.experiment}'"
f"to Git remote '{self.args.git_remote}'."
)
if pushed_exps:
ui.write(
f"Pushed experiment '{pushed_exps}'"
f"to Git remote '{self.args.git_remote}'."
)
else:
ui.write("No experiments to pull.")
if not self.args.push_cache:
ui.write(
"To push cached outputs",
Expand All @@ -37,6 +55,8 @@ def run(self):


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags

EXPERIMENTS_PUSH_HELP = "Push a local experiment to a Git remote."
experiments_push_parser = experiments_subparsers.add_parser(
"push",
Expand All @@ -45,6 +65,7 @@ def add_parser(experiments_subparsers, parent_parser):
help=EXPERIMENTS_PUSH_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
add_rev_selection_flags(experiments_push_parser, "Push", False)
experiments_push_parser.add_argument(
"-f",
"--force",
Expand Down Expand Up @@ -90,7 +111,8 @@ def add_parser(experiments_subparsers, parent_parser):
)
experiments_push_parser.add_argument(
"experiment",
nargs="+",
nargs="*",
default=None,
help="Experiments to push.",
metavar="<experiment>",
).complete = completion.EXPERIMENT
Expand Down
71 changes: 44 additions & 27 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from typing import Iterable, Union
from typing import Iterable, Optional, Set, Union

from dvc.exceptions import DvcException
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import TqdmGit
from dvc.scm import TqdmGit, iter_revs

from .base import ExpRefInfo
from .exceptions import UnresolvedExpNamesError
from .utils import exp_commits, resolve_name
from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name

logger = logging.getLogger(__name__)

Expand All @@ -18,33 +19,50 @@ def pull(
repo,
git_remote: str,
exp_names: Union[Iterable[str], str],
*args,
all_commits=False,
rev: Optional[str] = None,
num=1,
force: bool = False,
pull_cache: bool = False,
**kwargs,
):
if isinstance(exp_names, str):
exp_names = [exp_names]
exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote)
unresolved_exp_names = [
exp_name
for exp_name, exp_ref in exp_ref_dict.items()
if exp_ref is None
]
if unresolved_exp_names:
raise UnresolvedExpNamesError(unresolved_exp_names)
) -> Iterable[str]:
exp_ref_set: Set["ExpRefInfo"] = set()
if all_commits:
exp_ref_set.update(exp_refs(repo.scm, git_remote))
else:
if exp_names:
if isinstance(exp_names, str):
exp_names = [exp_names]
exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote)

unresolved_exp_names = []
for exp_name, exp_ref in exp_ref_dict.items():
if exp_ref is None:
unresolved_exp_names.append(exp_name)
else:
exp_ref_set.add(exp_ref)

if unresolved_exp_names:
raise UnresolvedExpNamesError(unresolved_exp_names)

if rev:
rev_dict = iter_revs(repo.scm, [rev], num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for _, ref_info_list in ref_info_dict.items():
exp_ref_set.update(ref_info_list)

exp_ref_set = exp_ref_dict.values()
_pull(repo, git_remote, exp_ref_set, force, pull_cache, **kwargs)
_pull(repo, git_remote, exp_ref_set, force)
if pull_cache:
_pull_cache(repo, exp_ref_set, **kwargs)
return [ref.name for ref in exp_ref_set]


def _pull(
repo,
git_remote: str,
exp_refs,
refs,
force: bool,
pull_cache: bool,
**kwargs,
):
def on_diverged(refname: str, rev: str) -> bool:
if repo.scm.get_ref(refname) == rev:
Expand All @@ -56,7 +74,7 @@ def on_diverged(refname: str, rev: str) -> bool:
"re-run with '--force'."
)

refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in exp_refs]
refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs]
logger.debug(f"git pull experiment '{git_remote}' -> '{refspec_list}'")

with TqdmGit(desc="Fetching git refs") as pbar:
Expand All @@ -68,20 +86,19 @@ def on_diverged(refname: str, rev: str) -> bool:
progress=pbar.update_git,
)

if pull_cache:
_pull_cache(repo, exp_refs, **kwargs)


def _pull_cache(
repo,
exp_ref,
refs: Union[ExpRefInfo, Iterable["ExpRefInfo"]],
dvc_remote=None,
jobs=None,
run_cache=False,
odb=None,
):
revs = list(exp_commits(repo.scm, exp_ref))
logger.debug(f"dvc fetch experiment '{exp_ref}'")
if isinstance(refs, ExpRefInfo):
refs = [refs]
revs = list(exp_commits(repo.scm, refs))
logger.debug(f"dvc fetch experiment '{refs}'")
repo.fetch(
jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, odb=odb
)
Loading