From 07668882ceffc2dc631e61d37195f0ed63ac5e1e Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Mon, 30 Jan 2023 18:50:24 +0100 Subject: [PATCH] live: Skip `live.end` calls inside context manager. Add `live._inside_with` boolean to handle it. The idea is to include additional data in a experiment when using frameworks: ```python from dvclive import Live from dvclive.keras import DVCLiveCallback with Live(save_dvc_exp=True) as live: live.log_param("foo", 2) model.fit( x, y, # Don't call live.end so additional stuff can be logged callbacks=[DVCLiveCallback(live=live)]) # So it is possible to include additional data in the experiment live.summary["out-of-loop-metric"] = 1 ``` --- src/dvclive/live.py | 6 ++++++ tests/test_main.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index f34671fa..1822a3c5 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -55,6 +55,7 @@ def __init__( self._images: Dict[str, Any] = {} self._params: Dict[str, Any] = {} self._plots: Dict[str, Any] = {} + self._inside_with = False os.makedirs(self.dir, exist_ok=True) @@ -344,6 +345,9 @@ def make_report(self): open_file_in_browser(self.report_file) def end(self): + if self._inside_with: + # Prevent `live.end` calls inside context manager + return self.make_summary(update_step=False) if "done" not in self._studio_events_to_skip: response = False @@ -392,7 +396,9 @@ def read_latest(self): return json.load(fobj) def __enter__(self): + self._inside_with = True return self def __exit__(self, exc_type, exc_val, exc_tb): + self._inside_with = False self.end() diff --git a/tests/test_main.py b/tests/test_main.py index 8aa96894..ed362019 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -407,6 +407,14 @@ def test_context_manager(tmp_dir): assert report_file.exists() +def test_context_manager_skips_end_calls(tmp_dir): + with Live() as live: + live.summary["foo"] = 1.0 + live.end() + assert not (tmp_dir / live.metrics_file).exists() + assert (tmp_dir / live.metrics_file).exists() + + @pytest.mark.parametrize("dvc_root", [True, False]) @pytest.mark.parametrize("set_env", [True, False]) def test_create_checkpoint_file(tmp_dir, monkeypatch, dvc_root, set_env, mocker):