From a983306898960784247b881d7d52bb6bcd5d0295 Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:20:27 +0200 Subject: [PATCH 1/9] TrackioLogger --- torchrl/record/loggers/trackio.py | 146 ++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 torchrl/record/loggers/trackio.py diff --git a/torchrl/record/loggers/trackio.py b/torchrl/record/loggers/trackio.py new file mode 100644 index 00000000000..4f899a66cf3 --- /dev/null +++ b/torchrl/record/loggers/trackio.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib.util + +from collections.abc import Sequence + +from torch import Tensor + +from .common import Logger + +_has_trackio = importlib.util.find_spec("trackio") is not None +_has_omegaconf = importlib.util.find_spec("omegaconf") is not None + + +class TrackioLogger(Logger): + """Wrapper for the trackio logger. + + Args: + exp_name (str): The name of the experiment. + project (str): The name of the project. + + Keyword Args: + fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``. + **kwargs: Extra keyword arguments for ``trackio.init``. + + """ + + @classmethod + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + def __init__( + self, + exp_name: str, + project: str, + *, + video_fps: int = 32, + **kwargs, + ) -> None: + if not _has_trackio: + raise ImportError("trackio could not be imported") + + self.video_fps = video_fps + self._trackio_kwargs = { + "name": exp_name, + "project": project, + "resume": "allow", + **kwargs, + } + + super().__init__(exp_name=exp_name, log_dir=project) + + def _create_experiment(self): + """Creates a trackio experiment. + + Args: + exp_name (str): The name of the experiment. + + Returns: + A trackio.Experiment object. + """ + if not _has_trackio: + raise ImportError("Trackio is not installed") + import trackio + + return trackio.init(**self._trackio_kwargs) + + def log_scalar( + self, name: str, value: float, step: int | None = None + ) -> None: + """Logs a scalar value to trackio. + + Args: + name (str): The name of the scalar. + value (float): The value of the scalar. + step (int, optional): The step at which the scalar is logged. + Defaults to None. + """ + self.experiment.log({name: value}, step=step) + + def log_video(self, name: str, video: Tensor, **kwargs) -> None: + """Log videos inputs to trackio. + + Args: + name (str): The name of the video. + video (Tensor): The video to be logged. + **kwargs: Other keyword arguments. By construction, log_video + supports 'step' (integer indicating the step index), 'format' + (default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are + passed as-is to the :obj:`experiment.log` method. + """ + import trackio + + fps = kwargs.pop("fps", self.video_fps) + format = kwargs.pop("format", "mp4") + self.experiment.log( + {name: trackio.Video(video, fps=fps, format=format)}, + **kwargs, + ) + + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 + """Logs the hyperparameters of the experiment. + + Args: + cfg (DictConfig or dict): The configuration of the experiment. + + """ + if type(cfg) is not dict and _has_omegaconf: + if not _has_omegaconf: + raise ImportError( + "OmegaConf could not be imported. " + "Cannot log hydra configs without OmegaConf." + ) + from omegaconf import OmegaConf + + cfg = OmegaConf.to_container(cfg, resolve=True) + self.experiment.config.update(cfg) + + def __repr__(self) -> str: + return f"TrackioLogger(experiment={self.experiment.__repr__()})" + + def log_histogram(self, name: str, data: Sequence, **kwargs): + raise NotImplementedError("Logging histograms in trackio is not permitted.") + + def log_str(self, name: str, value: str, step: int | None = None) -> None: + """Logs a string value to trackio using a table format for better visualization. + + Args: + name (str): The name of the string data. + value (str): The string value to log. + step (int, optional): The step at which the string is logged. + Defaults to None. + """ + import trackio + + # Create a table with a single row + table = trackio.Table(columns=["text"], data=[[value]]) + + if step is not None: + self.experiment.log({name: value}, step=step) + else: + self.experiment.log({name: table}) From 9c1af687a00af6611a30b9d0edc0ae4bacfc5273 Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:20:39 +0200 Subject: [PATCH 2/9] test TrackioLogger --- test/test_loggers.py | 60 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/test/test_loggers.py b/test/test_loggers.py index 3ddcd6b5a5e..f263b0a6173 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -21,6 +21,7 @@ from torchrl.record.loggers.csv import CSVLogger from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger +from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger from torchrl.record.loggers.wandb import _has_wandb, WandbLogger from torchrl.record.recorder import PixelRenderTransform, VideoRecorder @@ -455,6 +456,65 @@ def make_env(): env.close() +@pytest.fixture(scope="function") +def trackio_logger(): + exp_name = "ramala" + logger = TrackioLogger(project="test", exp_name=exp_name) + yield logger + logger.experiment.finish() + del logger + + +@pytest.mark.skipif(not _has_trackio, reason="trackio not installed") +class TestTrackioLogger: + @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) + def test_log_scalar(self, steps, trackio_logger): + torch.manual_seed(0) + + values = torch.rand(3) + for i in range(3): + scalar_name = "foo" + scalar_value = values[i].item() + trackio_logger.log_scalar( + value=scalar_value, + name=scalar_name, + step=steps[i] if steps else None, + ) + + def test_log_video(self, trackio_logger): + torch.manual_seed(0) + + # creating a sample video (T, C, H, W), where T - number of frames, + # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. + # the first 64 frames are black and the next 64 are white + video = torch.cat( + (torch.zeros(128, 1, 32, 32), torch.full((128, 1, 32, 32), 255)) + ) + video = video[None, :] + trackio_logger.log_video( + name="foo", + video=video, + fps=4, + format="mp4", + ) + trackio_logger.log_video( + name="foo_16fps", + video=video, + fps=16, + format="mp4", + ) + + def test_log_hparams(self, trackio_logger, config): + trackio_logger.log_hparams(config) + for key, value in config.items(): + assert trackio_logger.experiment.config[key] == value + + def test_log_histogram(self, trackio_logger): + with pytest.raises(NotImplementedError): + data = torch.randn(10) + trackio_logger.log_histogram("hist", data, step=0, bins=2) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From f169669c24b0655a05175e8222678e234afe6bee Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Tue, 14 Oct 2025 17:31:46 +0200 Subject: [PATCH 3/9] format --- torchrl/record/loggers/trackio.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchrl/record/loggers/trackio.py b/torchrl/record/loggers/trackio.py index 4f899a66cf3..0d4758a50c6 100644 --- a/torchrl/record/loggers/trackio.py +++ b/torchrl/record/loggers/trackio.py @@ -69,9 +69,7 @@ def _create_experiment(self): return trackio.init(**self._trackio_kwargs) - def log_scalar( - self, name: str, value: float, step: int | None = None - ) -> None: + def log_scalar(self, name: str, value: float, step: int | None = None) -> None: """Logs a scalar value to trackio. Args: From 0aead87fcd4d4eca60e8ed910e9eb749f6ad76ac Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Tue, 14 Oct 2025 18:05:57 +0200 Subject: [PATCH 4/9] trackio option in get_logger --- torchrl/record/loggers/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchrl/record/loggers/utils.py b/torchrl/record/loggers/utils.py index 5fe443db301..08d65e3c675 100644 --- a/torchrl/record/loggers/utils.py +++ b/torchrl/record/loggers/utils.py @@ -35,7 +35,7 @@ def get_logger( If empty, ``None`` is returned. logger_name (str): Name to be used as a log_dir experiment_name (str): Name of the experiment - kwargs (dict[str]): might contain either `wandb_kwargs` or `mlflow_kwargs` + kwargs (dict[str]): might contain either `wandb_kwargs`, `mlflow_kwargs` or `trackio_kwargs` """ if logger_type == "tensorboard": from torchrl.record.loggers.tensorboard import TensorboardLogger @@ -63,6 +63,14 @@ def get_logger( exp_name=experiment_name, **mlflow_kwargs, ) + elif logger_type == "trackio": + from torchrl.record.loggers.trackio import TrackioLogger + + trackio_kwargs = kwargs.get("trackio_kwargs", {}) + project = trackio_kwargs.pop("project", "torchrl") + logger = TrackioLogger( + project=project, exp_name=experiment_name, **trackio_kwargs + ) elif logger_type in ("", None): return None else: From 52329a301597cdeb369784201c8f6adcfccc315b Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Wed, 15 Oct 2025 10:53:44 +0200 Subject: [PATCH 5/9] ensure numpy --- test/test_loggers.py | 2 +- torchrl/record/loggers/trackio.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_loggers.py b/test/test_loggers.py index f263b0a6173..45ec5bcc9fa 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -488,7 +488,7 @@ def test_log_video(self, trackio_logger): # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. # the first 64 frames are black and the next 64 are white video = torch.cat( - (torch.zeros(128, 1, 32, 32), torch.full((128, 1, 32, 32), 255)) + (torch.zeros(128, 3, 32, 32), torch.full((128, 3, 32, 32), 255)) ) video = video[None, :] trackio_logger.log_video( diff --git a/torchrl/record/loggers/trackio.py b/torchrl/record/loggers/trackio.py index 0d4758a50c6..0d1b18eff6b 100644 --- a/torchrl/record/loggers/trackio.py +++ b/torchrl/record/loggers/trackio.py @@ -8,6 +8,8 @@ from collections.abc import Sequence +import numpy as np + from torch import Tensor from .common import Logger @@ -96,7 +98,11 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: fps = kwargs.pop("fps", self.video_fps) format = kwargs.pop("format", "mp4") self.experiment.log( - {name: trackio.Video(video, fps=fps, format=format)}, + { + name: trackio.Video( + video.numpy().astype(np.uint8), fps=fps, format=format + ) + }, **kwargs, ) From 654a2d2f69fe7f792df8e34da686743e73926b2d Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Wed, 15 Oct 2025 11:34:43 +0200 Subject: [PATCH 6/9] no need the function scpe --- test/test_loggers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_loggers.py b/test/test_loggers.py index 45ec5bcc9fa..e06d765b8a2 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -456,7 +456,7 @@ def make_env(): env.close() -@pytest.fixture(scope="function") +@pytest.fixture() def trackio_logger(): exp_name = "ramala" logger = TrackioLogger(project="test", exp_name=exp_name) From 8791d9841ce2646e85c6183711b132eb5d0426f4 Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Wed, 15 Oct 2025 11:36:06 +0200 Subject: [PATCH 7/9] add TrackioLogger to docs --- docs/source/reference/trainers.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index c47436d11a8..b086bde3c07 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -402,6 +402,7 @@ Loggers csv.CSVLogger mlflow.MLFlowLogger tensorboard.TensorboardLogger + trackio.TrackioLogger wandb.WandbLogger get_logger generate_exp_name From 51969d2185da1b68a37adfee1691e0446e7f4f14 Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Wed, 15 Oct 2025 11:38:25 +0200 Subject: [PATCH 8/9] add trackio in test deps --- .github/unittest/linux/scripts/environment.yml | 1 + .github/unittest/linux_distributed/scripts/environment.yml | 1 + .github/unittest/linux_sota/scripts/environment.yml | 1 + 3 files changed, 3 insertions(+) diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index 3283867e9bc..5b82885f967 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -29,6 +29,7 @@ dependencies: - dm_control - mujoco<3.3.6 - mlflow + - trackio - av - coverage - ray diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml index 2eac1112692..432eb99020c 100644 --- a/.github/unittest/linux_distributed/scripts/environment.yml +++ b/.github/unittest/linux_distributed/scripts/environment.yml @@ -28,6 +28,7 @@ dependencies: - dm_control - mujoco<3.3.6 - mlflow + - trackio - av - coverage - ray diff --git a/.github/unittest/linux_sota/scripts/environment.yml b/.github/unittest/linux_sota/scripts/environment.yml index 848720a7bbb..a3ad87752f7 100644 --- a/.github/unittest/linux_sota/scripts/environment.yml +++ b/.github/unittest/linux_sota/scripts/environment.yml @@ -25,6 +25,7 @@ dependencies: - dm_control - mujoco<3.3.6 - mlflow + - trackio - av - coverage - vmas From d892fe38478b99c3c61676b24090ed888d77c899 Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:49:51 +0200 Subject: [PATCH 9/9] trackio histograms and str --- test/test_loggers.py | 21 +++++++++++++++++---- torchrl/record/loggers/trackio.py | 25 +++++++++++++++++++------ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/test/test_loggers.py b/test/test_loggers.py index e06d765b8a2..4991f3619db 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -481,6 +481,15 @@ def test_log_scalar(self, steps, trackio_logger): step=steps[i] if steps else None, ) + @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) + def test_log_str(self, steps, trackio_logger): + for i in range(3): + trackio_logger.log_str( + name="foo", + value="bar", + step=steps[i] if steps else None, + ) + def test_log_video(self, trackio_logger): torch.manual_seed(0) @@ -509,10 +518,14 @@ def test_log_hparams(self, trackio_logger, config): for key, value in config.items(): assert trackio_logger.experiment.config[key] == value - def test_log_histogram(self, trackio_logger): - with pytest.raises(NotImplementedError): - data = torch.randn(10) - trackio_logger.log_histogram("hist", data, step=0, bins=2) + @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) + def test_log_histogram(self, steps, trackio_logger): + torch.manual_seed(0) + for i in range(3): + data = torch.randn(100) + trackio_logger.log_histogram( + "hist", data, step=steps[i] if steps else None, bins=10 + ) if __name__ == "__main__": diff --git a/torchrl/record/loggers/trackio.py b/torchrl/record/loggers/trackio.py index 0d1b18eff6b..67c094d9609 100644 --- a/torchrl/record/loggers/trackio.py +++ b/torchrl/record/loggers/trackio.py @@ -128,7 +128,24 @@ def __repr__(self) -> str: return f"TrackioLogger(experiment={self.experiment.__repr__()})" def log_histogram(self, name: str, data: Sequence, **kwargs): - raise NotImplementedError("Logging histograms in trackio is not permitted.") + """Add histogram to log. + + Args: + name (str): Data identifier + data (torch.Tensor, numpy.ndarray): Values to build histogram + + Keyword Args: + step (int): Global step value to record + bins (int): Number of bins to use for the histogram + + """ + import trackio + + num_bins = kwargs.pop("bins", None) + step = kwargs.pop("step", None) + self.experiment.log( + {name: trackio.Histogram(data, num_bins=num_bins)}, step=step + ) def log_str(self, name: str, value: str, step: int | None = None) -> None: """Logs a string value to trackio using a table format for better visualization. @@ -143,8 +160,4 @@ def log_str(self, name: str, value: str, step: int | None = None) -> None: # Create a table with a single row table = trackio.Table(columns=["text"], data=[[value]]) - - if step is not None: - self.experiment.log({name: value}, step=step) - else: - self.experiment.log({name: table}) + self.experiment.log({name: table}, step=step)