Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def make_dvcyaml(live):
if plots:
dvcyaml["plots"] = plots

if live._artifacts:
dvcyaml["artifacts"] = live._artifacts
for artifact in dvcyaml["artifacts"].values():
abs_path = os.path.realpath(artifact["path"])
abs_dir = os.path.realpath(live.dir)
relative_path = os.path.relpath(abs_path, abs_dir)
artifact["path"] = Path(relative_path).as_posix()

dump_yaml(dvcyaml, live.dvc_file)


Expand Down Expand Up @@ -164,9 +172,10 @@ def get_dvc_stage_template(live):
"cmd": "<python my_code_file.py my_args>",
"deps": ["<my_code_file.py>"],
}
if live._outs:
if live._artifacts:
stage["outs"] = []
for o in live._outs:
for artifact in live._artifacts.values():
o = artifact["path"]
artifact_path = Path(os.getcwd()) / o
artifact_path = artifact_path.relative_to(live._dvc_repo.root_dir).as_posix()
stage["outs"].append(artifact_path)
Expand Down
29 changes: 26 additions & 3 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self._images: Dict[str, Any] = {}
self._params: Dict[str, Any] = {}
self._plots: Dict[str, Any] = {}
self._outs: Set[StrPath] = set()
self._artifacts: Dict[str, Dict] = {}
self._inside_with = False
self._dvcyaml = dvcyaml

Expand Down Expand Up @@ -321,19 +321,42 @@ def log_param(self, name: str, val: ParamLike):
"""Saves the given parameter value to yaml"""
self.log_params({name: val})

def log_artifact(self, path: StrPath):
def log_artifact(
self,
path: StrPath,
type: Optional[str] = None, # noqa: A002
name: Optional[str] = None,
desc: Optional[str] = None, # noqa: ARG002
labels: Optional[List[str]] = None, # noqa: ARG002
meta: Optional[Dict[str, Any]] = None, # noqa: ARG002
):
"""Tracks a local file or directory with DVC"""
if not isinstance(path, (str, Path)):
raise InvalidDataTypeError(path, type(path))

if self._dvc_repo is not None:
from dvc.repo.artifacts import name_is_compatible

try:
stage = self._dvc_repo.add(path)
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to dvc add {path}: {e}")
return

self._outs.add(path)
name = name or Path(path).stem
if name_is_compatible(name):
self._artifacts[name] = {
k: v
for k, v in locals().items()
if k in ("path", "type", "desc", "labels", "meta") and v is not None
}
else:
logger.warning(
"Can't use '%s' as artifact name (ID)."
" It will not be included in the `artifacts` section.",
name,
)

dvc_file = stage[0].addressing

if self._save_dvc_exp:
Expand Down
59 changes: 59 additions & 0 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dvclive import Live
from dvclive.serialize import load_yaml


def test_log_artifact(tmp_dir, dvc_repo):
Expand Down Expand Up @@ -44,3 +45,61 @@ def test_log_artifact_with_save_dvc_exp(tmp_dir, mocker, mocked_dvc_repo):
include_untracked=[live.dir, "data", ".gitignore"],
force=True,
)


def test_log_artifact_type_model(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
live.log_artifact("model.pth", type="model")

assert load_yaml(live.dvc_file) == {
"artifacts": {"model": {"path": "../model.pth", "type": "model"}}
}


def test_log_artifact_type_model_provided_name(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
live.log_artifact("model.pth", type="model", name="custom")

assert load_yaml(live.dvc_file) == {
"artifacts": {"custom": {"path": "../model.pth", "type": "model"}}
}


def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
for _ in range(3):
live.log_artifact("model.pth", type="model")
live.next_step()
live.log_artifact("model.pth", type="model", labels=["final"])
assert load_yaml(live.dvc_file) == {
"artifacts": {
"model": {"path": "../model.pth", "type": "model", "labels": ["final"]},
},
"metrics": ["metrics.json"],
}


def test_log_artifact_attrs(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

attrs = {
"type": "model",
"name": "foo",
"desc": "bar",
"labels": ["foo"],
"meta": {"foo": "bar"},
}
with Live() as live:
live.log_artifact("model.pth", **attrs)
attrs.pop("name")
assert load_yaml(live.dvc_file) == {
"artifacts": {
"foo": {"path": "../model.pth", **attrs},
}
}