diff --git a/src/dvclive/live.py b/src/dvclive/live.py index f34671fa..1822a3c5 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -55,6 +55,7 @@ def __init__( self._images: Dict[str, Any] = {} self._params: Dict[str, Any] = {} self._plots: Dict[str, Any] = {} + self._inside_with = False os.makedirs(self.dir, exist_ok=True) @@ -344,6 +345,9 @@ def make_report(self): open_file_in_browser(self.report_file) def end(self): + if self._inside_with: + # Prevent `live.end` calls inside context manager + return self.make_summary(update_step=False) if "done" not in self._studio_events_to_skip: response = False @@ -392,7 +396,9 @@ def read_latest(self): return json.load(fobj) def __enter__(self): + self._inside_with = True return self def __exit__(self, exc_type, exc_val, exc_tb): + self._inside_with = False self.end() diff --git a/tests/test_main.py b/tests/test_main.py index 8aa96894..ed362019 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -407,6 +407,14 @@ def test_context_manager(tmp_dir): assert report_file.exists() +def test_context_manager_skips_end_calls(tmp_dir): + with Live() as live: + live.summary["foo"] = 1.0 + live.end() + assert not (tmp_dir / live.metrics_file).exists() + assert (tmp_dir / live.metrics_file).exists() + + @pytest.mark.parametrize("dvc_root", [True, False]) @pytest.mark.parametrize("set_env", [True, False]) def test_create_checkpoint_file(tmp_dir, monkeypatch, dvc_root, set_env, mocker):