Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dvc/commands/experiments/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,12 @@ def _normalize_headers(names, count):


def _format_json(item):
    """JSON ``default=`` hook: serialize values the stock encoder rejects.

    Dates/datetimes become ISO-8601 strings, cached collection errors become
    a small ``{"type", "msg"}`` dict, and anything else is delegated to
    ``encode_exception``.
    """
    from dvc.repo.experiments.show import _CachedError

    if isinstance(item, (date, datetime)):
        return item.isoformat()
    if not isinstance(item, _CachedError):
        return encode_exception(item)
    # NOTE(review): the output key is "type" but the attribute read is "typ" —
    # presumably _CachedError stores it under `typ`; confirm against its
    # definition in dvc.repo.experiments.show.
    return {"type": getattr(item, "typ", "_CachedError"), "msg": str(item)}


Expand All @@ -472,6 +476,7 @@ def run(self):
sha_only=self.args.sha,
param_deps=self.args.param_deps,
fetch_running=self.args.fetch_running,
force=self.args.force,
)
except DvcException:
logger.exception("failed to show experiments")
Expand Down Expand Up @@ -650,4 +655,10 @@ def add_parser(experiments_subparsers, parent_parser):
action="store_false",
help=argparse.SUPPRESS,
)
experiments_show_parser.add_argument(
"-f",
"--force",
action="store_true",
help="Force re-collection of experiments instead of loading from exp cache.",
)
experiments_show_parser.set_defaults(func=CmdExperimentsShow)
5 changes: 5 additions & 0 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dvc.utils import relpath
from dvc.utils.objects import cached_property

from .cache import ExpCache
from .exceptions import (
BaselineMismatchError,
ExperimentExistsError,
Expand Down Expand Up @@ -92,6 +93,10 @@ def celery_queue(self) -> LocalCeleryQueue:
def apply_stash(self) -> ApplyStash:
return ApplyStash(self.scm, APPLY_STASH)

@cached_property
def cache(self) -> ExpCache:
    """Serialized experiment state cache for this repo (created lazily)."""
    return ExpCache(repo=self.repo)

@property
def stash_revs(self) -> Dict[str, "ExpStashEntry"]:
revs = {}
Expand Down
72 changes: 72 additions & 0 deletions dvc/repo/experiments/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import logging
import os
from typing import TYPE_CHECKING, Optional, Union

from dvc.fs import localfs
from dvc_objects.db import ObjectDB

from .serialize import DeserializeError, SerializableError, SerializableExp
from .utils import EXEC_TMP_DIR

if TYPE_CHECKING:
from dvc.repo import Repo

logger = logging.getLogger(__name__)


class ExpCache:
    """Serialized experiment state cache.

    ODB with git SHAs as keys. Objects can be either SerializableExp or
    SerializableError.

    NOTE(review): entries are not explicitly versioned; ``get()`` treats
    anything that fails to deserialize as a miss, so a schema change simply
    triggers re-collection and overwrite — confirm this stays true when
    fields are added.
    """

    # Cache directory, relative to repo.tmp_dir (joined in __init__).
    CACHE_DIR = os.path.join(EXEC_TMP_DIR, "cache")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we version the cache? In case we add some fields, or make backward incompatible change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something that I considered, but the current implementation just invalidates the entry if it contains something unexpected, so in the event that the serialized fields change, the newer version of dvc will just re-collect the commit and overwrite the old cache entry.


def __init__(self, repo: "Repo"):
    """Create the on-disk ODB under the repo's tmp directory."""
    cache_path = os.path.join(repo.tmp_dir, self.CACHE_DIR)
    self.odb = ObjectDB(localfs, cache_path)

def delete(self, rev: str) -> None:
    """Drop the cache entry stored under Git SHA ``rev``.

    NOTE(review): the underlying ODB delete appears to raise
    FileNotFoundError when no entry exists (``put`` guards for it) —
    callers should handle that case.
    """
    self.odb.delete(rev)

def put(
    self,
    exp: Union[SerializableExp, SerializableError],
    rev: Optional[str] = None,
    force: bool = False,
):
    """Store ``exp`` in the cache keyed by Git SHA.

    Args:
        exp: serialized experiment state or error to cache.
        rev: cache key; defaults to ``exp.rev`` when not given. Must be a
            real commit SHA (never "workspace").
        force: overwrite an existing entry for the same revision. When
            False, an existing entry is left untouched.
    """
    key = rev or getattr(exp, "rev", None)
    assert key
    assert key != "workspace"
    if not force and self.odb.exists(key):
        return
    # Drop any stale entry first: add_bytes() relies on os.rename, which
    # does not overwrite an existing file on Windows.
    try:
        self.delete(key)
    except FileNotFoundError:
        pass
    self.odb.add_bytes(key, exp.as_bytes())
    logger.trace(  # type: ignore[attr-defined]
        "ExpCache: cache put '%s'", key[:7]
    )

def get(self, rev: str) -> Optional[Union[SerializableExp, SerializableError]]:
    """Load the cached object stored under Git SHA ``rev``.

    Returns None on a cache miss, or when the stored bytes deserialize as
    neither known type (e.g. an entry written by an incompatible version —
    the caller then re-collects and overwrites it).
    """
    obj = self.odb.get(rev)
    try:
        with obj.fs.open(obj.path, "rb") as fobj:
            raw = fobj.read()
    except FileNotFoundError:
        logger.trace(  # type: ignore[attr-defined]
            "ExpCache: cache miss '%s'", rev[:7]
        )
        return None
    # Try each known payload type in turn; order matters only for speed,
    # since each type rejects the other's fields with DeserializeError.
    for klass in (SerializableExp, SerializableError):
        try:
            loaded = klass.from_bytes(raw)  # type: ignore[attr-defined]
        except DeserializeError:
            continue
        logger.trace(  # type: ignore[attr-defined]
            "ExpCache: cache load '%s'", rev[:7]
        )
        return loaded
    logger.debug("ExpCache: unknown object type for '%s'", rev)
    return None
154 changes: 154 additions & 0 deletions dvc/repo/experiments/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import json
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional

from dvc.exceptions import DvcException
from dvc.repo.metrics.show import _gather_metrics
from dvc.repo.params.show import _gather_params
from dvc.utils import onerror_collect, relpath

if TYPE_CHECKING:
from dvc.repo import Repo


class DeserializeError(DvcException):
    """Raised when serialized experiment data cannot be parsed back."""

    pass


class _ISOEncoder(json.JSONEncoder):
def default(self, o: object) -> Any:
if isinstance(o, datetime):
return o.isoformat()
return super().default(o)


@dataclass(frozen=True)
class SerializableExp:
    """Serializable experiment state."""

    # Git SHA this state was collected from.
    rev: str
    # Commit timestamp; None when unknown.
    timestamp: Optional[datetime] = None
    # Collected params/metrics, keyed by file path.
    params: Dict[str, Any] = field(default_factory=dict)
    metrics: Dict[str, Any] = field(default_factory=dict)
    # Non-param dependencies and non-plot/metric outputs, keyed by relpath.
    deps: Dict[str, "_ExpDep"] = field(default_factory=dict)
    outs: Dict[str, "_ExpOut"] = field(default_factory=dict)
    # Run status / executor name as reported by the queue, if any.
    status: Optional[str] = None
    executor: Optional[str] = None
    # Collection error, when the experiment could not be gathered cleanly.
    error: Optional["SerializableError"] = None

@classmethod
def from_repo(
    cls,
    repo: "Repo",
    rev: Optional[str] = None,
    onerror: Optional[Callable] = None,
    **kwargs,
) -> "SerializableExp":
    """Returns a SerializableExp from the current repo state.

    Params, metrics, deps, outs are filled via repo fs/index, all other
    fields should be passed via kwargs.

    Args:
        repo: repo instance whose current state should be captured.
        rev: Git revision; defaults to the repo's current rev.
        onerror: error callback for params/metrics collection
            (defaults to ``onerror_collect``).
        **kwargs: forwarded to the constructor (e.g. ``status``,
            ``executor``, ``timestamp``).
    """
    from dvc.dependency import ParamsDependency, RepoDependency

    if not onerror:
        onerror = onerror_collect

    rev = rev or repo.get_rev()
    assert rev
    status: Optional[str] = kwargs.get("status")
    # NOTE: _gather_params/_gather_metrics return defaultdict which is not
    # supported in dataclasses.asdict() on all python releases
    # see https://bugs.python.org/issue35540
    params = dict(_gather_params(repo, onerror=onerror))
    if status and status.lower() in ("queued", "failed"):
        # runs that never completed have no metrics to collect
        metrics: Dict[str, Any] = {}
    else:
        metrics = dict(
            _gather_metrics(
                repo,
                targets=None,
                rev=rev[:7],
                recursive=False,
                # use the caller-provided handler, consistent with the
                # params collection above (previously hard-coded to
                # onerror_collect, silently ignoring `onerror`)
                onerror=onerror,
            )
        )
    return cls(
        rev=rev,
        params=params,
        metrics=metrics,
        deps={
            relpath(dep.fs_path, repo.root_dir): _ExpDep(
                hash=dep.hash_info.value if dep.hash_info else None,
                size=dep.meta.size if dep.meta else None,
                nfiles=dep.meta.nfiles if dep.meta else None,
            )
            for dep in repo.index.deps
            if not isinstance(dep, (ParamsDependency, RepoDependency))
        },
        outs={
            relpath(out.fs_path, repo.root_dir): _ExpOut(
                hash=out.hash_info.value if out.hash_info else None,
                size=out.meta.size if out.meta else None,
                nfiles=out.meta.nfiles if out.meta else None,
                use_cache=out.use_cache,
                is_data_source=out.stage.is_data_source,
            )
            for out in repo.index.outs
            if not (out.is_metric or out.is_plot)
        },
        **kwargs,
    )

def dumpd(self) -> Dict[str, Any]:
    """Return this state as a nested plain dict (via dataclasses.asdict);
    datetime values are left as-is (handled by _ISOEncoder in as_bytes)."""
    return asdict(self)

def as_bytes(self) -> bytes:
    """Serialize this experiment state to UTF-8 encoded JSON bytes."""
    payload = self.dumpd()
    return _ISOEncoder().encode(payload).encode("utf-8")

@classmethod
def from_bytes(cls, data: bytes) -> "SerializableExp":
    """Deserialize a SerializableExp from JSON bytes.

    Raises:
        DeserializeError: if ``data`` is not a valid serialized experiment.
    """
    try:
        parsed = json.loads(data)
        # asdict() serializes a missing timestamp as JSON null; the old
        # `"timestamp" in parsed` check then passed None to fromisoformat,
        # raising TypeError and wrongly invalidating a valid cache entry.
        ts = parsed.get("timestamp")
        if ts is not None:
            parsed["timestamp"] = datetime.fromisoformat(ts)
        return cls(**parsed)
    # ValueError also covers malformed ISO strings from fromisoformat
    # (json.JSONDecodeError is a ValueError subclass, kept for clarity).
    except (TypeError, ValueError, json.JSONDecodeError) as exc:
        raise DeserializeError("failed to load SerializableExp") from exc


@dataclass(frozen=True)
class _ExpDep:
    """Serializable snapshot of a stage dependency (hash + size info)."""

    # Dependency hash value; None when not computed.
    hash: Optional[str]  # noqa: A003
    # Size in bytes / file count from the dep's meta; None when unknown.
    size: Optional[int]
    nfiles: Optional[int]


@dataclass(frozen=True)
class _ExpOut:
    """Serializable snapshot of a stage output (hash, size, cache flags)."""

    # Output hash value; None when not computed.
    hash: Optional[str]  # noqa: A003
    # Size in bytes / file count from the out's meta; None when unknown.
    size: Optional[int]
    nfiles: Optional[int]
    # Whether the output is tracked in the DVC cache.
    use_cache: bool
    # Whether the producing stage is a data source (import/add).
    is_data_source: bool


@dataclass(frozen=True)
class SerializableError:
    """Serializable representation of an experiment collection error."""

    # Human-readable error message.
    msg: str
    # Exception type name; empty when unknown.
    type: str = ""  # noqa: A003

    def dumpd(self) -> Dict[str, Any]:
        """Return this error as a plain dict."""
        return asdict(self)

    def as_bytes(self) -> bytes:
        """Serialize this error to UTF-8 encoded JSON bytes."""
        return json.dumps(self.dumpd()).encode("utf-8")

    @classmethod
    def from_bytes(cls, data: bytes) -> "SerializableError":
        """Deserialize a SerializableError from JSON bytes.

        Raises:
            DeserializeError: if ``data`` is not a valid serialized error.
        """
        try:
            return cls(**json.loads(data))
        except (TypeError, json.JSONDecodeError) as exc:
            raise DeserializeError("failed to load SerializableError") from exc
Loading