From 9231ec8078d30bda65e6f4ea6e32b7034c7e55ff Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Fri, 21 Apr 2023 12:44:56 +0200 Subject: [PATCH] Support `artifacts` section. - Update `make_dvcyaml` to write `artifacts` section. - Extend `log_artifact` to accept `type`, `name`, `desc`, `labels`, `meta`. --- src/dvclive/dvc.py | 13 +++++++-- src/dvclive/live.py | 29 +++++++++++++++++-- tests/test_log_artifact.py | 59 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 5 deletions(-) diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py index 27dd0f28..b9c65e78 100644 --- a/src/dvclive/dvc.py +++ b/src/dvclive/dvc.py @@ -107,6 +107,14 @@ def make_dvcyaml(live): if plots: dvcyaml["plots"] = plots + if live._artifacts: + dvcyaml["artifacts"] = live._artifacts + for artifact in dvcyaml["artifacts"].values(): + abs_path = os.path.realpath(artifact["path"]) + abs_dir = os.path.realpath(live.dir) + relative_path = os.path.relpath(abs_path, abs_dir) + artifact["path"] = Path(relative_path).as_posix() + dump_yaml(dvcyaml, live.dvc_file) @@ -164,9 +172,10 @@ def get_dvc_stage_template(live): "cmd": "", "deps": [""], } - if live._outs: + if live._artifacts: stage["outs"] = [] - for o in live._outs: + for artifact in live._artifacts.values(): + o = artifact["path"] artifact_path = Path(os.getcwd()) / o artifact_path = artifact_path.relative_to(live._dvc_repo.root_dir).as_posix() stage["outs"].append(artifact_path) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index e627dd0a..0e9008c7 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -60,7 +60,7 @@ def __init__( self._images: Dict[str, Any] = {} self._params: Dict[str, Any] = {} self._plots: Dict[str, Any] = {} - self._outs: Set[StrPath] = set() + self._artifacts: Dict[str, Dict] = {} self._inside_with = False self._dvcyaml = dvcyaml @@ -321,19 +321,42 @@ def log_param(self, name: str, val: ParamLike): """Saves the given parameter value to yaml""" self.log_params({name: val}) - def log_artifact(self, path: StrPath): + def log_artifact( + self, + path: StrPath, + type: Optional[str] = None, # noqa: A002 + name: Optional[str] = None, + desc: Optional[str] = None, # noqa: ARG002 + labels: Optional[List[str]] = None, # noqa: ARG002 + meta: Optional[Dict[str, Any]] = None, # noqa: ARG002 + ): """Tracks a local file or directory with DVC""" if not isinstance(path, (str, Path)): raise InvalidDataTypeError(path, type(path)) if self._dvc_repo is not None: + from dvc.repo.artifacts import name_is_compatible + try: stage = self._dvc_repo.add(path) except Exception as e: # noqa: BLE001 logger.warning(f"Failed to dvc add {path}: {e}") return - self._outs.add(path) + name = name or Path(path).stem + if name_is_compatible(name): + self._artifacts[name] = { + k: v + for k, v in locals().items() + if k in ("path", "type", "desc", "labels", "meta") and v is not None + } + else: + logger.warning( + "Can't use '%s' as artifact name (ID)." + " It will not be included in the `artifacts` section.", + name, + ) + dvc_file = stage[0].addressing if self._save_dvc_exp: diff --git a/tests/test_log_artifact.py b/tests/test_log_artifact.py index d3b35a05..db24271b 100644 --- a/tests/test_log_artifact.py +++ b/tests/test_log_artifact.py @@ -1,4 +1,5 @@ from dvclive import Live +from dvclive.serialize import load_yaml def test_log_artifact(tmp_dir, dvc_repo): @@ -44,3 +45,61 @@ def test_log_artifact_with_save_dvc_exp(tmp_dir, mocker, mocked_dvc_repo): include_untracked=[live.dir, "data", ".gitignore"], force=True, ) + + +def test_log_artifact_type_model(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + + with Live() as live: + live.log_artifact("model.pth", type="model") + + assert load_yaml(live.dvc_file) == { + "artifacts": {"model": {"path": "../model.pth", "type": "model"}} + } + + +def test_log_artifact_type_model_provided_name(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + + with Live() as live: + live.log_artifact("model.pth", type="model", name="custom") + + assert load_yaml(live.dvc_file) == { + "artifacts": {"custom": {"path": "../model.pth", "type": "model"}} + } + + +def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + + with Live() as live: + for _ in range(3): + live.log_artifact("model.pth", type="model") + live.next_step() + live.log_artifact("model.pth", type="model", labels=["final"]) + assert load_yaml(live.dvc_file) == { + "artifacts": { + "model": {"path": "../model.pth", "type": "model", "labels": ["final"]}, + }, + "metrics": ["metrics.json"], + } + + +def test_log_artifact_attrs(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + + attrs = { + "type": "model", + "name": "foo", + "desc": "bar", + "labels": ["foo"], + "meta": {"foo": "bar"}, + } + with Live() as live: + live.log_artifact("model.pth", **attrs) + attrs.pop("name") + assert load_yaml(live.dvc_file) == { + "artifacts": { + "foo": {"path": "../model.pth", **attrs}, + } + }