diff --git a/src/dvclive/plots/image.py b/src/dvclive/plots/image.py index f914d921..49204cab 100644 --- a/src/dvclive/plots/image.py +++ b/src/dvclive/plots/image.py @@ -15,10 +15,13 @@ def output_path(self) -> Path: @staticmethod def could_log(val: object) -> bool: - if val.__class__.__module__ == "PIL.Image": - return True - if val.__class__.__module__ == "numpy": - return True + acceptable = { + ("numpy", "ndarray"), + ("PIL.Image", "Image"), + } + for cls in type(val).mro(): + if (cls.__module__, cls.__name__) in acceptable: + return True if isinstance(val, (PurePath, str)): return True return False diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py index 22eb97ef..8a277dc3 100644 --- a/tests/plots/test_image.py +++ b/tests/plots/test_image.py @@ -6,6 +6,15 @@ from dvclive.plots import Image as LiveImage +# From https://stackoverflow.com/questions/5165317/how-can-i-extend-image-class +class ExtendedImage(Image.Image): + def __init__(self, img): + self._img = img + + def __getattr__(self, key): + return getattr(self._img, key) + + def test_pil(tmp_dir): live = Live() img = Image.new("RGB", (10, 10), (250, 250, 250)) @@ -82,3 +91,12 @@ def test_cleanup(tmp_dir): Live() assert not (tmp_dir / live.plots_dir / LiveImage.subfolder).exists() + + +def test_custom_class(tmp_dir): + live = Live() + img = Image.new("RGB", (10, 10), (250, 250, 250)) + extended_img = ExtendedImage(img) + live.log_image("image.png", extended_img) + + assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()