Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies:
- dm_control
- mujoco<3.3.6
- mlflow
- trackio
- av
- coverage
- ray
Expand Down
1 change: 1 addition & 0 deletions .github/unittest/linux_distributed/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies:
- dm_control
- mujoco<3.3.6
- mlflow
- trackio
- av
- coverage
- ray
Expand Down
1 change: 1 addition & 0 deletions .github/unittest/linux_sota/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies:
- dm_control
- mujoco<3.3.6
- mlflow
- trackio
- av
- coverage
- vmas
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/trainers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ Loggers
csv.CSVLogger
mlflow.MLFlowLogger
tensorboard.TensorboardLogger
trackio.TrackioLogger
wandb.WandbLogger
get_logger
generate_exp_name
Expand Down
73 changes: 73 additions & 0 deletions test/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchrl.record.loggers.csv import CSVLogger
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
from torchrl.record.recorder import PixelRenderTransform, VideoRecorder

Expand Down Expand Up @@ -455,6 +456,78 @@ def make_env():
env.close()


@pytest.fixture()
def trackio_logger():
exp_name = "ramala"
logger = TrackioLogger(project="test", exp_name=exp_name)
yield logger
logger.experiment.finish()
del logger


@pytest.mark.skipif(not _has_trackio, reason="trackio not installed")
class TestTrackioLogger:
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
def test_log_scalar(self, steps, trackio_logger):
torch.manual_seed(0)

values = torch.rand(3)
for i in range(3):
scalar_name = "foo"
scalar_value = values[i].item()
trackio_logger.log_scalar(
value=scalar_value,
name=scalar_name,
step=steps[i] if steps else None,
)

@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
def test_log_str(self, steps, trackio_logger):
for i in range(3):
trackio_logger.log_str(
name="foo",
value="bar",
step=steps[i] if steps else None,
)

def test_log_video(self, trackio_logger):
torch.manual_seed(0)

# creating a sample video (T, C, H, W), where T - number of frames,
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
# the first 64 frames are black and the next 64 are white
video = torch.cat(
(torch.zeros(128, 3, 32, 32), torch.full((128, 3, 32, 32), 255))
)
video = video[None, :]
trackio_logger.log_video(
name="foo",
video=video,
fps=4,
format="mp4",
)
trackio_logger.log_video(
name="foo_16fps",
video=video,
fps=16,
format="mp4",
)

def test_log_hparams(self, trackio_logger, config):
trackio_logger.log_hparams(config)
for key, value in config.items():
assert trackio_logger.experiment.config[key] == value

@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
def test_log_histogram(self, steps, trackio_logger):
torch.manual_seed(0)
for i in range(3):
data = torch.randn(100)
trackio_logger.log_histogram(
"hist", data, step=steps[i] if steps else None, bins=10
)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
163 changes: 163 additions & 0 deletions torchrl/record/loggers/trackio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util

from collections.abc import Sequence

import numpy as np

from torch import Tensor

from .common import Logger

_has_trackio = importlib.util.find_spec("trackio") is not None
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None


class TrackioLogger(Logger):
"""Wrapper for the trackio logger.

Args:
exp_name (str): The name of the experiment.
project (str): The name of the project.

Keyword Args:
fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
**kwargs: Extra keyword arguments for ``trackio.init``.

"""

@classmethod
def __new__(cls, *args, **kwargs):
return super().__new__(cls)

def __init__(
self,
exp_name: str,
project: str,
*,
video_fps: int = 32,
**kwargs,
) -> None:
if not _has_trackio:
raise ImportError("trackio could not be imported")

self.video_fps = video_fps
self._trackio_kwargs = {
"name": exp_name,
"project": project,
"resume": "allow",
**kwargs,
}

super().__init__(exp_name=exp_name, log_dir=project)

def _create_experiment(self):
"""Creates a trackio experiment.

Args:
exp_name (str): The name of the experiment.

Returns:
A trackio.Experiment object.
"""
if not _has_trackio:
raise ImportError("Trackio is not installed")
import trackio

return trackio.init(**self._trackio_kwargs)

def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
"""Logs a scalar value to trackio.

Args:
name (str): The name of the scalar.
value (float): The value of the scalar.
step (int, optional): The step at which the scalar is logged.
Defaults to None.
"""
self.experiment.log({name: value}, step=step)

def log_video(self, name: str, video: Tensor, **kwargs) -> None:
"""Log videos inputs to trackio.

Args:
name (str): The name of the video.
video (Tensor): The video to be logged.
**kwargs: Other keyword arguments. By construction, log_video
supports 'step' (integer indicating the step index), 'format'
(default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
passed as-is to the :obj:`experiment.log` method.
"""
import trackio

fps = kwargs.pop("fps", self.video_fps)
format = kwargs.pop("format", "mp4")
self.experiment.log(
{
name: trackio.Video(
video.numpy().astype(np.uint8), fps=fps, format=format
)
},
**kwargs,
)

def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821
"""Logs the hyperparameters of the experiment.

Args:
cfg (DictConfig or dict): The configuration of the experiment.

"""
if type(cfg) is not dict and _has_omegaconf:
if not _has_omegaconf:
raise ImportError(
"OmegaConf could not be imported. "
"Cannot log hydra configs without OmegaConf."
)
from omegaconf import OmegaConf

cfg = OmegaConf.to_container(cfg, resolve=True)
self.experiment.config.update(cfg)

def __repr__(self) -> str:
return f"TrackioLogger(experiment={self.experiment.__repr__()})"

def log_histogram(self, name: str, data: Sequence, **kwargs):
"""Add histogram to log.

Args:
name (str): Data identifier
data (torch.Tensor, numpy.ndarray): Values to build histogram

Keyword Args:
step (int): Global step value to record
bins (int): Number of bins to use for the histogram

"""
import trackio

num_bins = kwargs.pop("bins", None)
step = kwargs.pop("step", None)
self.experiment.log(
{name: trackio.Histogram(data, num_bins=num_bins)}, step=step
)

def log_str(self, name: str, value: str, step: int | None = None) -> None:
"""Logs a string value to trackio using a table format for better visualization.

Args:
name (str): The name of the string data.
value (str): The string value to log.
step (int, optional): The step at which the string is logged.
Defaults to None.
"""
import trackio

# Create a table with a single row
table = trackio.Table(columns=["text"], data=[[value]])
self.experiment.log({name: table}, step=step)
10 changes: 9 additions & 1 deletion torchrl/record/loggers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_logger(
If empty, ``None`` is returned.
logger_name (str): Name to be used as a log_dir
experiment_name (str): Name of the experiment
kwargs (dict[str]): might contain either `wandb_kwargs` or `mlflow_kwargs`
kwargs (dict[str]): might contain either `wandb_kwargs`, `mlflow_kwargs` or `trackio_kwargs`
"""
if logger_type == "tensorboard":
from torchrl.record.loggers.tensorboard import TensorboardLogger
Expand Down Expand Up @@ -63,6 +63,14 @@ def get_logger(
exp_name=experiment_name,
**mlflow_kwargs,
)
elif logger_type == "trackio":
from torchrl.record.loggers.trackio import TrackioLogger

trackio_kwargs = kwargs.get("trackio_kwargs", {})
project = trackio_kwargs.pop("project", "torchrl")
logger = TrackioLogger(
project=project, exp_name=experiment_name, **trackio_kwargs
)
elif logger_type in ("", None):
return None
else:
Expand Down
Loading