diff --git a/src/dvclive/report.py b/src/dvclive/report.py index 90984010..b47b266b 100644 --- a/src/dvclive/report.py +++ b/src/dvclive/report.py @@ -1,3 +1,4 @@ +import base64 import json from pathlib import Path from typing import TYPE_CHECKING @@ -64,13 +65,17 @@ def get_scalar_renderers(metrics_path): return renderers -def get_image_renderers(images_folder): +def get_image_renderers(images_folder, report_mode): plots_path = images_folder.parent.parent renderers = [] for suffix in Image.suffixes: all_images = Path(images_folder).rglob(f"*{suffix}") for file in sorted(all_images): - src = str(file.relative_to(plots_path)) + if report_mode in {"html", "notebook"}: + base64_str = base64.b64encode(file.read_bytes()).decode() + src = f"data:image;base64,{base64_str}" + else: + src = str(file.relative_to(plots_path)) name = str(file.relative_to(images_folder)) data = [ { @@ -144,7 +149,9 @@ def make_report(live: "Live"): renderers.extend(get_params_renderers(live.params_file)) renderers.extend(get_metrics_renderers(live.metrics_file)) renderers.extend(get_scalar_renderers(plots_path / Metric.subfolder)) - renderers.extend(get_image_renderers(plots_path / Image.subfolder)) + renderers.extend( + get_image_renderers(plots_path / Image.subfolder, report_mode=live._report_mode) + ) renderers.extend(get_plot_renderers(plots_path / SKLearnPlot.subfolder, live)) if live._report_mode == "html": diff --git a/tests/test_report.py b/tests/test_report.py index 616ef29f..81897421 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -19,6 +19,24 @@ ) +@pytest.mark.parametrize("mode", ["html", "md", "notebook"]) +def test_get_image_renderers(tmp_dir, mode): + with Live() as live: + img = Image.new("RGB", (10, 10), (255, 0, 0)) + live.log_image("image.png", img) + + image_renderers = get_image_renderers( + tmp_dir / live.plots_dir / LiveImage.subfolder, report_mode=mode + ) + assert len(image_renderers) == 1 + img = image_renderers[0].datapoints[0] + if mode == "md": + assert img["src"] == os.path.join("plots", LiveImage.subfolder, "image.png") + else: + assert img["src"].startswith("data:image;base64,") + assert img["rev"] == "image.png" + + def test_get_renderers(tmp_dir, mocker): live = Live() @@ -27,21 +45,8 @@ def test_get_renderers(tmp_dir, mocker): for i in range(2): live.log_metric("foo/bar", i) - img = Image.new("RGB", (10, 10), (i, i, i)) - live.log_image("image.png", img) live.next_step() - image_renderers = get_image_renderers( - tmp_dir / live.plots_dir / LiveImage.subfolder - ) - assert len(image_renderers) == 1 - assert image_renderers[0].datapoints == [ - { - "src": os.path.join("plots", LiveImage.subfolder, "image.png"), - "rev": "image.png", - } - ] - scalar_renderers = get_scalar_renderers(tmp_dir / live.plots_dir / Metric.subfolder) assert len(scalar_renderers) == 1 assert scalar_renderers[0].datapoints == [