def _format_json(item):
    """Convert ``item`` into a JSON-friendly value.

    * ``date`` / ``datetime`` values become their ISO format string.
    * ``_CachedError`` instances become a ``{"type": ..., "msg": ...}`` dict:
      ``type`` is the error's ``typ`` attribute (``"_CachedError"`` when the
      attribute is missing) and ``msg`` is ``str(item)``.
    * Anything else is handed to ``encode_exception``.
    """
    # Function-scope import, exactly as in the original.
    # NOTE(review): presumably done to avoid an import cycle or import cost
    # at module load -- confirm before hoisting it to the top of the file.
    from dvc.repo.experiments.show import _CachedError

    if isinstance(item, (date, datetime)):
        return item.isoformat()

    if not isinstance(item, _CachedError):
        return encode_exception(item)

    error_type = getattr(item, "typ", "_CachedError")
    return {"type": error_type, "msg": str(item)}
class ExpCache:
    """Serialized experiment state cache.

    ODB with git SHAs as keys. Objects can be either SerializableExp or
    SerializableError.
    """

    # Cache location, relative to the repo's tmp dir.
    CACHE_DIR = os.path.join(EXEC_TMP_DIR, "cache")

    def __init__(self, repo: "Repo"):
        """Open the object DB stored under ``<repo.tmp_dir>/<CACHE_DIR>``."""
        path = os.path.join(repo.tmp_dir, self.CACHE_DIR)
        self.odb = ObjectDB(localfs, path)

    def delete(self, rev: str):
        """Remove the entry stored under ``rev`` from the object DB."""
        self.odb.delete(rev)

    def put(
        self,
        exp: Union[SerializableExp, SerializableError],
        rev: Optional[str] = None,
        force: bool = False,
    ):
        """Store the serialized form of ``exp`` under ``rev``.

        Args:
            exp: object to store; its ``as_bytes()`` output is written.
            rev: cache key. When omitted, ``exp.rev`` is used if present.
            force: overwrite an existing entry. Without it an existing entry
                is left untouched and nothing is written.

        Raises:
            ValueError: if no key can be determined, or the key is
                ``"workspace"`` (workspace state is never cached).
        """
        rev = rev or getattr(exp, "rev", None)
        # Explicit checks instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let an empty or "workspace" key through.
        if not rev:
            raise ValueError("ExpCache: cannot cache an entry without a rev")
        if rev == "workspace":
            raise ValueError("ExpCache: workspace state must not be cached")

        if not force and self.odb.exists(rev):
            return

        # Drop any previous entry first; a missing entry is fine.
        try:
            self.delete(rev)
        except FileNotFoundError:
            pass
        self.odb.add_bytes(rev, exp.as_bytes())
        logger.trace(  # type: ignore[attr-defined]
            "ExpCache: cache put '%s'", rev[:7]
        )

    def get(self, rev: str) -> Optional[Union[SerializableExp, SerializableError]]:
        """Load the entry stored under ``rev``.

        Returns:
            The deserialized ``SerializableExp`` or ``SerializableError``, or
            ``None`` when there is no entry for ``rev`` or the stored bytes
            cannot be deserialized as either type.
        """
        obj = self.odb.get(rev)
        try:
            with obj.fs.open(obj.path, "rb") as fobj:
                data = fobj.read()
        except FileNotFoundError:
            logger.trace(  # type: ignore[attr-defined]
                "ExpCache: cache miss '%s'", rev[:7]
            )
            return None
        # The stored bytes carry no type tag, so try each known type in turn.
        for typ in (SerializableExp, SerializableError):
            try:
                exp = typ.from_bytes(data)  # type: ignore[attr-defined]
                logger.trace(  # type: ignore[attr-defined]
                    "ExpCache: cache load '%s'", rev[:7]
                )
                return exp
            except DeserializeError:
                continue
        logger.debug("ExpCache: unknown object type for '%s'", rev)
        return None
@dataclass(frozen=True)
class SerializableExp:
    """Serializable experiment state.

    Instances are immutable. ``dumpd()`` / ``as_bytes()`` turn one into a
    plain dict / UTF-8 JSON, and ``from_bytes()`` loads it back.
    """

    rev: str
    timestamp: Optional[datetime] = None
    params: Dict[str, Any] = field(default_factory=dict)
    metrics: Dict[str, Any] = field(default_factory=dict)
    deps: Dict[str, "_ExpDep"] = field(default_factory=dict)
    outs: Dict[str, "_ExpOut"] = field(default_factory=dict)
    status: Optional[str] = None
    executor: Optional[str] = None
    error: Optional["SerializableError"] = None

    @classmethod
    def from_repo(
        cls,
        repo: "Repo",
        rev: Optional[str] = None,
        onerror: Optional[Callable] = None,
        **kwargs,
    ) -> "SerializableExp":
        """Returns a SerializableExp from the current repo state.

        Params, metrics, deps, outs are filled via repo fs/index, all other fields
        should be passed via kwargs.

        Args:
            repo: repo to read from.
            rev: revision recorded on the result; defaults to ``repo.get_rev()``.
            onerror: error callback passed to params collection; defaults to
                ``onerror_collect``.
            **kwargs: remaining dataclass fields. ``status`` is also inspected
                here: metrics are left empty for "queued"/"failed" (compared
                case-insensitively).
        """
        from dvc.dependency import ParamsDependency, RepoDependency

        if not onerror:
            onerror = onerror_collect

        rev = rev or repo.get_rev()
        assert rev
        status: Optional[str] = kwargs.get("status")
        # NOTE: _gather_params/_gather_metrics return defaultdict which is not
        # supported in dataclasses.asdict() on all python releases
        # see https://bugs.python.org/issue35540
        params = dict(_gather_params(repo, onerror=onerror))
        if status and status.lower() in ("queued", "failed"):
            metrics: Dict[str, Any] = {}
        else:
            # NOTE(review): metrics always use onerror_collect, so a
            # caller-supplied ``onerror`` only affects params -- confirm
            # whether that asymmetry is intended.
            metrics = dict(
                _gather_metrics(
                    repo,
                    targets=None,
                    rev=rev[:7],
                    recursive=False,
                    onerror=onerror_collect,
                )
            )
        return cls(
            rev=rev,
            params=params,
            metrics=metrics,
            deps={
                relpath(dep.fs_path, repo.root_dir): _ExpDep(
                    hash=dep.hash_info.value if dep.hash_info else None,
                    size=dep.meta.size if dep.meta else None,
                    nfiles=dep.meta.nfiles if dep.meta else None,
                )
                for dep in repo.index.deps
                if not isinstance(dep, (ParamsDependency, RepoDependency))
            },
            outs={
                relpath(out.fs_path, repo.root_dir): _ExpOut(
                    hash=out.hash_info.value if out.hash_info else None,
                    size=out.meta.size if out.meta else None,
                    nfiles=out.meta.nfiles if out.meta else None,
                    use_cache=out.use_cache,
                    is_data_source=out.stage.is_data_source,
                )
                for out in repo.index.outs
                if not (out.is_metric or out.is_plot)
            },
            **kwargs,
        )

    def dumpd(self) -> Dict[str, Any]:
        """Return this experiment as a plain dict (``dataclasses.asdict``)."""
        return asdict(self)

    def as_bytes(self) -> bytes:
        """Return ``dumpd()`` encoded as UTF-8 JSON (datetimes as ISO strings)."""
        return _ISOEncoder().encode(self.dumpd()).encode("utf-8")

    @classmethod
    def from_bytes(cls, data: bytes):
        """Load an instance from JSON bytes produced by ``as_bytes()``.

        A non-null ``timestamp`` is parsed back into a ``datetime``. Other
        values are passed to the constructor as parsed, so nested fields
        (``deps``, ``outs``, ``error``) come back as plain dicts.

        Raises:
            DeserializeError: if ``data`` is not valid JSON, does not match
                this class's fields, or carries an unparsable timestamp.
        """
        try:
            parsed = json.loads(data)
            # ``timestamp`` is optional: a serialized null must stay None
            # rather than be fed to fromisoformat() (which rejects None).
            if isinstance(parsed, dict) and parsed.get("timestamp") is not None:
                parsed["timestamp"] = datetime.fromisoformat(parsed["timestamp"])
            return cls(**parsed)
        except (TypeError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError, undecodable bytes and
            # malformed ISO timestamps.
            raise DeserializeError("failed to load SerializableExp") from exc
@dataclass(frozen=True)
class SerializableError:
    """Serializable record of an error: a message plus an error type name."""

    # Error message text.
    msg: str
    # Name of the error type; empty string when not given.
    type: str = ""  # noqa: A003

    def dumpd(self) -> Dict[str, Any]:
        """Return this error as a plain dict (``dataclasses.asdict``)."""
        return asdict(self)

    def as_bytes(self) -> bytes:
        """Return ``dumpd()`` encoded as UTF-8 JSON."""
        return json.dumps(self.dumpd()).encode("utf-8")

    @classmethod
    def from_bytes(cls, data: bytes):
        """Load an instance from JSON bytes produced by ``as_bytes()``.

        Raises:
            DeserializeError: if ``data`` is not valid JSON or does not match
                this class's fields.
        """
        try:
            parsed = json.loads(data)
            return cls(**parsed)
        except (TypeError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and also the
            # UnicodeDecodeError raised for bytes that are not valid UTF-8.
            raise DeserializeError("failed to load SerializableError") from exc
def collect_experiment_commit(
    repo: "Repo",
    exp_rev: str,
    status: ExpStatus = ExpStatus.Success,
    param_deps: bool = False,
    force: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """Return the show-formatted state of one experiment revision.

    The experiment cache is consulted first unless ``exp_rev`` is
    "workspace", or ``force`` / ``param_deps`` is set. A cached entry is
    returned as-is, except when ``status`` is Running: the entry is then
    deleted and the experiment is collected again.

    Otherwise the state is collected via ``_collect_from_repo`` (extra
    ``kwargs`` are forwarded to it). Any exception raised while collecting is
    logged at debug level and turned into an error result instead of
    propagating. Results are written back to the cache except for the
    workspace or when ``param_deps`` is set; error results are additionally
    not cached while the experiment is Running.

    Returns:
        The dict built by ``_format_exp`` on success, or by ``_format_error``
        on failure.
    """
    exp_cache = repo.experiments.cache
    is_workspace = exp_rev == "workspace"
    is_running = status == ExpStatus.Running

    # TODO: support filtering serialized exp when param_deps is set
    if not is_workspace and not force and not param_deps:
        hit = exp_cache.get(exp_rev)
        if hit:
            if is_running:
                # expire cached queued exp entry once we start running it
                exp_cache.delete(exp_rev)
            elif isinstance(hit, SerializableError):
                return _format_error(hit)
            else:
                return _format_exp(hit)

    try:
        collected = _collect_from_repo(
            repo,
            exp_rev,
            status=status,
            param_deps=param_deps,
            force=force,
            **kwargs,
        )
        # NOTE(review): unlike the error path below, a Running experiment is
        # still cached here -- confirm that is intended.
        if not is_workspace and not param_deps:
            exp_cache.put(collected, force=True)
        return _format_exp(collected)
    except Exception as exc:  # noqa: BLE001, pylint: disable=broad-except
        logger.debug("", exc_info=True)
        failure = SerializableError(str(exc), type(exc).__name__)
        if not is_workspace and not param_deps and not is_running:
            exp_cache.put(failure, rev=exp_rev, force=True)
        return _format_error(failure)
"Repo", + exp_rev: str, + status: ExpStatus = ExpStatus.Success, + running: Optional[Dict[str, Any]] = None, + **kwargs, +) -> "SerializableExp": running = running or {} for rev in repo.brancher(revs=[exp_rev]): if rev == "workspace": if exp_rev != "workspace": continue - result["timestamp"] = None + timestamp: Optional[datetime] = None else: commit = repo.scm.resolve_commit(rev) - result["timestamp"] = datetime.fromtimestamp(commit.commit_time) - - params = _gather_params(repo, targets=None, deps=param_deps, onerror=onerror) - if params: - result["params"] = params - - result["deps"] = { - relpath(dep.fs_path, repo.root_dir): { - "hash": dep.hash_info.value, - "size": dep.meta.size, - "nfiles": dep.meta.nfiles, - } - for dep in repo.index.deps - if not isinstance(dep, (ParamsDependency, RepoDependency)) - } - result["outs"] = { - relpath(out.fs_path, repo.root_dir): { - "hash": out.hash_info.value, - "size": out.meta.size, - "nfiles": out.meta.nfiles, - "use_cache": out.use_cache, - "is_data_source": out.stage.is_data_source, - } - for out in repo.index.outs - if not (out.is_metric or out.is_plot) - } - - result["status"] = status.name + timestamp = datetime.fromtimestamp(commit.commit_time) + if status == ExpStatus.Running: - result["executor"] = running.get(exp_rev, {}).get("location", None) + executor: Optional[str] = running.get(exp_rev, {}).get("location", None) else: - result["executor"] = None + executor = None if status == ExpStatus.Failed: - result["error"] = { - "msg": "Experiment run failed.", - "type": "", - } - - if status not in {ExpStatus.Queued, ExpStatus.Failed}: - vals = _gather_metrics( - repo, targets=None, rev=rev, recursive=False, onerror=onerror + error: Optional["SerializableError"] = SerializableError( + "Experiment run failed." 
) - result["metrics"] = vals + else: + error = None - return result + return SerializableExp.from_repo( + repo, + rev=exp_rev, + timestamp=timestamp, + status=status.name, + executor=executor, + error=error, + ) + raise RevError(f"nonexistent exp rev: '{exp_rev}'") def _collect_complete_experiment( @@ -396,12 +423,11 @@ def show( # noqa: PLR0913 param_deps=False, onerror: Optional[Callable] = None, fetch_running: bool = True, + force: bool = False, ): if repo.scm.no_commits: return {} - onerror = onerror or _show_onerror_collect - res: Dict[str, Dict] = defaultdict(OrderedDict) if not any([revs, all_branches, all_tags, all_commits]): @@ -427,7 +453,7 @@ def show( # noqa: PLR0913 found_revs, running, param_deps=param_deps, - onerror=onerror, + force=force, ) if not hide_queued else {} @@ -439,6 +465,7 @@ def show( # noqa: PLR0913 running, param_deps=param_deps, onerror=onerror, + force=force, ) failed_experiments = ( @@ -448,6 +475,7 @@ def show( # noqa: PLR0913 running, param_deps=param_deps, onerror=onerror, + force=force, ) if not hide_failed else {} @@ -460,6 +488,7 @@ def show( # noqa: PLR0913 running=running, param_deps=param_deps, onerror=onerror, + force=force, ) update_new(res, failed_experiments) diff --git a/tests/func/experiments/test_show.py b/tests/func/experiments/test_show.py index 12057878eb..7949001be9 100644 --- a/tests/func/experiments/test_show.py +++ b/tests/func/experiments/test_show.py @@ -12,6 +12,7 @@ from dvc.repo.experiments.executor.base import BaseExecutor, ExecutorInfo, TaskStatus from dvc.repo.experiments.queue.base import QueueEntry from dvc.repo.experiments.refs import CELERY_STASH, ExpRefInfo +from dvc.repo.experiments.show import _CachedError from dvc.repo.experiments.utils import EXEC_PID_DIR, EXEC_TMP_DIR, exp_refs_by_rev from dvc.utils import relpath from dvc.utils.serialize import YAMLFileCorruptedError @@ -71,6 +72,7 @@ def test_show_simple(tmp_dir, scm, dvc, exp_stage): assert dvc.experiments.show()["workspace"] == { 
"baseline": { "data": { + "rev": "workspace", "deps": { "copy.py": { "hash": ANY, @@ -102,6 +104,7 @@ def test_show_experiment(tmp_dir, scm, dvc, exp_stage, workspace): expected_baseline = { "data": { + "rev": ANY, "deps": { "copy.py": { "hash": ANY, @@ -177,6 +180,7 @@ def test_show_failed_experiment(tmp_dir, scm, dvc, failed_exp_stage): expected_baseline = { "data": { + "rev": ANY, "deps": { "copy.py": { "hash": ANY, @@ -196,9 +200,11 @@ def test_show_failed_experiment(tmp_dir, scm, dvc, failed_exp_stage): expected_failed = { "data": { + "rev": ANY, "timestamp": ANY, "params": {"params.yaml": {"data": {"foo": 2}}}, "deps": {"copy.py": {"hash": None, "size": None, "nfiles": None}}, + "metrics": {}, "outs": {}, "status": "Failed", "executor": None, @@ -432,6 +438,7 @@ def test_show_running_workspace( assert dvc.experiments.show().get("workspace") == { "baseline": { "data": { + "rev": "workspace", "deps": { "copy.py": { "hash": ANY, @@ -608,7 +615,8 @@ def test_show_with_broken_repo(tmp_dir, scm, dvc, exp_stage, caplog): assert get_in(baseline[rev2], paths) == {"data": {"foo": 3}} paths = ["workspace", "baseline", "error"] - assert isinstance(get_in(result, paths), YAMLFileCorruptedError) + assert isinstance(get_in(result, paths), (_CachedError, YAMLFileCorruptedError)) + assert "YAML file structure is corrupted" in str(get_in(result, paths)) def test_show_csv(tmp_dir, scm, dvc, exp_stage, capsys): @@ -1040,7 +1048,7 @@ def resolve_commit(rev): "resolve_commit", side_effect=mocker.MagicMock(side_effect=resolve_commit), ) - results = dvc.experiments.show()[baseline_rev] + results = dvc.experiments.show(force=True)[baseline_rev] assert len(results) == 1 @@ -1066,4 +1074,4 @@ def resolve_commit(rev): experiments = dvc.experiments.show()[baseline_rev] assert len(experiments) == 1 assert experiments["baseline"]["data"] == {"name": branch} - assert isinstance(experiments["baseline"]["error"], SCMError) + assert isinstance(experiments["baseline"]["error"], (SCMError, 
_CachedError)) diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 3233920d0a..be667c5fd5 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -100,6 +100,7 @@ def test_experiments_show(dvc, scm, mocker): "1", "--rev", "foo", + "--force", ] ) assert cli_args.func == CmdExperimentsShow @@ -121,6 +122,7 @@ def test_experiments_show(dvc, scm, mocker): sha_only=True, param_deps=True, fetch_running=True, + force=True, )