-
Notifications
You must be signed in to change notification settings - Fork 1.3k
exp show: cache collected experiments by git revision #9069
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| import logging | ||
| import os | ||
| from typing import TYPE_CHECKING, Optional, Union | ||
|
|
||
| from dvc.fs import localfs | ||
| from dvc_objects.db import ObjectDB | ||
|
|
||
| from .serialize import DeserializeError, SerializableError, SerializableExp | ||
| from .utils import EXEC_TMP_DIR | ||
|
|
||
| if TYPE_CHECKING: | ||
| from dvc.repo import Repo | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class ExpCache: | ||
| """Serialized experiment state cache. | ||
| ODB with git SHAs as keys. Objects can be either SerializableExp or | ||
| SerializableError. | ||
| """ | ||
|
|
||
| CACHE_DIR = os.path.join(EXEC_TMP_DIR, "cache") | ||
|
|
||
| def __init__(self, repo: "Repo"): | ||
| path = os.path.join(repo.tmp_dir, self.CACHE_DIR) | ||
| self.odb = ObjectDB(localfs, path) | ||
|
|
||
| def delete(self, rev: str): | ||
| self.odb.delete(rev) | ||
|
|
||
| def put( | ||
| self, | ||
| exp: Union[SerializableExp, SerializableError], | ||
| rev: Optional[str] = None, | ||
| force: bool = False, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The PR uses |
||
| ): | ||
| rev = rev or getattr(exp, "rev", None) | ||
| assert rev | ||
| assert rev != "workspace" | ||
| if force or not self.odb.exists(rev): | ||
| try: | ||
| self.delete(rev) | ||
| except FileNotFoundError: | ||
| pass | ||
|
Comment on lines
+43
to
+46
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a need to delete this? IIRC
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it overwrites on linux/mac but not windows. We end up hitting |
||
| self.odb.add_bytes(rev, exp.as_bytes()) | ||
| logger.trace( # type: ignore[attr-defined] | ||
| "ExpCache: cache put '%s'", rev[:7] | ||
| ) | ||
|
|
||
| def get(self, rev: str) -> Optional[Union[SerializableExp, SerializableError]]: | ||
| obj = self.odb.get(rev) | ||
| try: | ||
| with obj.fs.open(obj.path, "rb") as fobj: | ||
| data = fobj.read() | ||
| except FileNotFoundError: | ||
| logger.trace( # type: ignore[attr-defined] | ||
| "ExpCache: cache miss '%s'", rev[:7] | ||
| ) | ||
| return None | ||
| for typ in (SerializableExp, SerializableError): | ||
| try: | ||
| exp = typ.from_bytes(data) # type: ignore[attr-defined] | ||
| logger.trace( # type: ignore[attr-defined] | ||
| "ExpCache: cache load '%s'", rev[:7] | ||
| ) | ||
| return exp | ||
| except DeserializeError: | ||
| continue | ||
| logger.debug("ExpCache: unknown object type for '%s'", rev) | ||
| return None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| import json | ||
| from dataclasses import asdict, dataclass, field | ||
| from datetime import datetime | ||
| from typing import TYPE_CHECKING, Any, Callable, Dict, Optional | ||
|
|
||
| from dvc.exceptions import DvcException | ||
| from dvc.repo.metrics.show import _gather_metrics | ||
| from dvc.repo.params.show import _gather_params | ||
| from dvc.utils import onerror_collect, relpath | ||
|
|
||
| if TYPE_CHECKING: | ||
| from dvc.repo import Repo | ||
|
|
||
|
|
||
| class DeserializeError(DvcException): | ||
| pass | ||
|
|
||
|
|
||
| class _ISOEncoder(json.JSONEncoder): | ||
| def default(self, o: object) -> Any: | ||
| if isinstance(o, datetime): | ||
| return o.isoformat() | ||
| return super().default(o) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class SerializableExp: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The nature of our JSON format is arbitrary. It might be easier to just save json format directly and return that directly.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (I do like structure, but our format is arbitrary)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is one of those places I'd like to move away from using json/yaml as our default serialization format (when it doesn't need to be human readable) in favor of something that's not text based and is faster. Keeping it structured makes it easier to do that, and also dealing with nested dicts of dicts everywhere in params/metrics/exp show is a lot harder to follow vs having proper dataclasses. I am wondering if we want to start moving to |
||
| """Serializable experiment state.""" | ||
|
|
||
| rev: str | ||
| timestamp: Optional[datetime] = None | ||
| params: Dict[str, Any] = field(default_factory=dict) | ||
| metrics: Dict[str, Any] = field(default_factory=dict) | ||
| deps: Dict[str, "_ExpDep"] = field(default_factory=dict) | ||
| outs: Dict[str, "_ExpOut"] = field(default_factory=dict) | ||
| status: Optional[str] = None | ||
| executor: Optional[str] = None | ||
| error: Optional["SerializableError"] = None | ||
|
|
||
| @classmethod | ||
| def from_repo( | ||
| cls, | ||
| repo: "Repo", | ||
| rev: Optional[str] = None, | ||
| onerror: Optional[Callable] = None, | ||
| **kwargs, | ||
| ) -> "SerializableExp": | ||
| """Returns a SerializableExp from the current repo state. | ||
| Params, metrics, deps, outs are filled via repo fs/index, all other fields | ||
| should be passed via kwargs. | ||
| """ | ||
| from dvc.dependency import ParamsDependency, RepoDependency | ||
|
|
||
| if not onerror: | ||
| onerror = onerror_collect | ||
|
|
||
| rev = rev or repo.get_rev() | ||
| assert rev | ||
| status: Optional[str] = kwargs.get("status") | ||
| # NOTE: _gather_params/_gather_metrics return defaultdict which is not | ||
| # supported in dataclasses.asdict() on all python releases | ||
| # see https://bugs.python.org/issue35540 | ||
| params = dict(_gather_params(repo, onerror=onerror)) | ||
| if status and status.lower() in ("queued", "failed"): | ||
| metrics: Dict[str, Any] = {} | ||
| else: | ||
| metrics = dict( | ||
| _gather_metrics( | ||
| repo, | ||
| targets=None, | ||
| rev=rev[:7], | ||
| recursive=False, | ||
| onerror=onerror_collect, | ||
| ) | ||
| ) | ||
| return cls( | ||
| rev=rev, | ||
| params=params, | ||
| metrics=metrics, | ||
| deps={ | ||
| relpath(dep.fs_path, repo.root_dir): _ExpDep( | ||
| hash=dep.hash_info.value if dep.hash_info else None, | ||
| size=dep.meta.size if dep.meta else None, | ||
| nfiles=dep.meta.nfiles if dep.meta else None, | ||
| ) | ||
| for dep in repo.index.deps | ||
| if not isinstance(dep, (ParamsDependency, RepoDependency)) | ||
| }, | ||
| outs={ | ||
| relpath(out.fs_path, repo.root_dir): _ExpOut( | ||
| hash=out.hash_info.value if out.hash_info else None, | ||
| size=out.meta.size if out.meta else None, | ||
| nfiles=out.meta.nfiles if out.meta else None, | ||
| use_cache=out.use_cache, | ||
| is_data_source=out.stage.is_data_source, | ||
| ) | ||
| for out in repo.index.outs | ||
| if not (out.is_metric or out.is_plot) | ||
| }, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| def dumpd(self) -> Dict[str, Any]: | ||
| return asdict(self) | ||
|
|
||
| def as_bytes(self) -> bytes: | ||
| return _ISOEncoder().encode(self.dumpd()).encode("utf-8") | ||
|
|
||
| @classmethod | ||
| def from_bytes(cls, data: bytes): | ||
| try: | ||
| parsed = json.loads(data) | ||
| if "timestamp" in parsed: | ||
| parsed["timestamp"] = datetime.fromisoformat(parsed["timestamp"]) | ||
| return cls(**parsed) | ||
| except (TypeError, json.JSONDecodeError) as exc: | ||
| raise DeserializeError("failed to load SerializableExp") from exc | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class _ExpDep: | ||
| hash: Optional[str] # noqa: A003 | ||
| size: Optional[int] | ||
| nfiles: Optional[int] | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class _ExpOut: | ||
| hash: Optional[str] # noqa: A003 | ||
| size: Optional[int] | ||
| nfiles: Optional[int] | ||
| use_cache: bool | ||
| is_data_source: bool | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class SerializableError: | ||
| msg: str | ||
| type: str = "" # noqa: A003 | ||
|
|
||
| def dumpd(self) -> Dict[str, Any]: | ||
| return asdict(self) | ||
|
|
||
| def as_bytes(self) -> bytes: | ||
| return json.dumps(self.dumpd()).encode("utf-8") | ||
|
|
||
| @classmethod | ||
| def from_bytes(cls, data: bytes): | ||
| try: | ||
| parsed = json.loads(data) | ||
| return cls(**parsed) | ||
| except (TypeError, json.JSONDecodeError) as exc: | ||
| raise DeserializeError("failed to load SerializableError") from exc | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we version the cache? In case we add some fields, or make backward incompatible change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is something that I considered, but the current implementation just invalidates the entry if it contains something unexpected, so in the event that the serialized fields change, the newer version of dvc will just re-collect the commit and overwrite the old cache entry.