From 595ef55f4674fe5cbef8c9d1692827557574fbb4 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 22 Feb 2024 17:45:54 -0800
Subject: [PATCH 1/2] init

---
 test/test_collector.py      | 13 ++++++--
 test/test_postprocs.py      | 15 ++++++++--
 torchrl/collectors/utils.py | 60 ++++++++++++++++++++++++-------------
 3 files changed, 64 insertions(+), 24 deletions(-)

diff --git a/test/test_collector.py b/test/test_collector.py
index bebbd103bc7..f63c6fd7dd4 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import argparse
+import functools
 import gc
 import sys
 
@@ -591,7 +592,15 @@ def env_fn(seed):
 
 @pytest.mark.skipif(not _has_gym, reason="gym library is not installed")
 @pytest.mark.parametrize("parallel", [False, True])
-def test_collector_env_reset(parallel):
+@pytest.mark.parametrize(
+    "constr",
+    [
+        functools.partial(split_trajectories, prefix="collector"),
+        functools.partial(split_trajectories),
+        functools.partial(split_trajectories, trajectory_key=("collector", "traj_ids")),
+    ],
+)
+def test_collector_env_reset(constr, parallel):
     torch.manual_seed(0)
 
     def make_env():
@@ -627,7 +636,7 @@ def make_env():
         # check that if step is 1, then the env was done before
         assert (steps == 1)[done].all()
         # check that split traj has a minimum total reward of -21 (for pong only)
-        _data = split_trajectories(_data, prefix="collector")
+        _data = constr(_data)
         assert _data["next", "reward"].sum(-2).min() == -21
     finally:
         env.close()

diff --git a/test/test_postprocs.py b/test/test_postprocs.py
index 10a559d3cac..c3cba371167 100644
--- a/test/test_postprocs.py
+++ b/test/test_postprocs.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import functools
 
 import pytest
 import torch
@@ -276,12 +277,22 @@ def create_fake_trajs(
 
     @pytest.mark.parametrize("num_workers", range(3, 34, 3))
     @pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
-    def test_splits(self, num_workers, traj_len):
+    @pytest.mark.parametrize(
+        "constr",
+        [
+            functools.partial(split_trajectories, prefix="collector"),
+            functools.partial(split_trajectories),
+            functools.partial(
+                split_trajectories, trajectory_key=("collector", "traj_ids")
+            ),
+        ],
+    )
+    def test_splits(self, num_workers, traj_len, constr):
         trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
         assert trajs.shape[0] == num_workers
         assert trajs.shape[1] == traj_len
-        split_trajs = split_trajectories(trajs, prefix="collector")
+        split_trajs = constr(trajs)
         assert (
             split_trajs.shape[0] == split_trajs.get(("collector", "traj_ids")).max() + 1
         )

diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index b8db47f412d..1ba98e379c6 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -2,12 +2,13 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 
 from typing import Callable
 
 import torch
-from tensordict import pad, set_lazy_legacy, TensorDictBase
+from tensordict import NestedKey, pad, set_lazy_legacy, TensorDictBase
 
 
 def _stack_output(fun) -> Callable:
@@ -28,7 +29,11 @@ def stacked_output_fun(*args, **kwargs):
 
 @set_lazy_legacy(False)
 def split_trajectories(
-    rollout_tensordict: TensorDictBase, prefix=None
+    rollout_tensordict: TensorDictBase,
+    *,
+    prefix=None,
+    trajectory_key: NestedKey | None = None,
+    done_key: NestedKey | None = None,
 ) -> TensorDictBase:
     """A util function for trajectory separation.
 
@@ -39,28 +44,43 @@ def split_trajectories(
     Args:
         rollout_tensordict (TensorDictBase): a rollout with adjacent trajectories
             along the last dimension.
-        prefix (str or tuple of str, optional): the prefix used to read and write meta-data,
+        prefix (NestedKey, optional): the prefix used to read and write meta-data,
             such as ``"traj_ids"`` (the optional integer id of each trajectory)
             and the ``"mask"`` entry indicating which data are valid and which
-            aren't. Defaults to ``None`` (no prefix).
+            aren't. Defaults to ``"collector"`` if the input has a ``"collector"``
+            entry, ``()`` (no prefix) otherwise.
+            ``prefix`` is kept as a legacy feature and will be deprecated eventually.
+            Prefer ``trajectory_key`` or ``done_key`` whenever possible.
+        trajectory_key (NestedKey, optional): the key pointing to the trajectory
+            ids. Supersedes ``done_key`` and ``prefix``. If not provided, defaults
+            to ``(prefix, "traj_ids")``.
+        done_key (NestedKey, optional): the key pointing to the ``"done"`` signal,
+            if the trajectory could not be directly recovered. Defaults to ``"done"``.
 
     """
-    sep = ".-|-."
-
-    if isinstance(prefix, str):
-        traj_ids_key = (prefix, "traj_ids")
-        mask_key = (prefix, "mask")
-    elif isinstance(prefix, tuple):
-        traj_ids_key = (*prefix, "traj_ids")
-        mask_key = (*prefix, "mask")
-    elif prefix is None:
-        traj_ids_key = "traj_ids"
-        mask_key = "mask"
-    else:
-        raise NotImplementedError(f"Unknown key type {type(prefix)}.")
+    mask_key = None
+    if trajectory_key is not None:
+        from torchrl.envs.utils import _replace_last
+
+        traj_ids_key = trajectory_key
+        mask_key = _replace_last(trajectory_key, "mask")
+    else:
+        if prefix is None and "collector" in rollout_tensordict.keys():
+            prefix = "collector"
+        if prefix is None:
+            traj_ids_key = "traj_ids"
+            mask_key = "mask"
+        else:
+            traj_ids_key = (prefix, "traj_ids")
+            mask_key = (prefix, "mask")
+
+    rollout_tensordict = rollout_tensordict.copy()
     traj_ids = rollout_tensordict.get(traj_ids_key, None)
-    done = rollout_tensordict.get(("next", "done"))
     if traj_ids is None:
+        if done_key is None:
+            done_key = "done"
+        done_key = ("next", done_key)
+        done = rollout_tensordict.get(done_key)
         idx = (slice(None),) * (rollout_tensordict.ndim - 1) + (slice(None, -1),)
         done_sel = done[idx]
         pads = [1, 0]
@@ -91,7 +111,8 @@ def split_trajectories(
         )
         if rollout_tensordict.ndimension() == 1:
             rollout_tensordict = rollout_tensordict.unsqueeze(0)
-        return rollout_tensordict.unflatten_keys(sep)
+        return rollout_tensordict
+
 
     out_splits = rollout_tensordict.view(-1).split(splits, 0)
     for out_split in out_splits:
@@ -110,5 +131,4 @@ def split_trajectories(
     td = torch.stack(
         [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0
     ).contiguous()
-    # td = td.unflatten_keys(sep)
     return td

From 2322b5f04fd9b31bf9cc05e185bc0225083d1d00 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 22 Feb 2024 17:57:14 -0800
Subject: [PATCH 2/2] amend

---
 torchrl/collectors/utils.py | 60 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 59 insertions(+), 1 deletion(-)

diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index 1ba98e379c6..91460fa6df3 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -57,6 +57,64 @@ def split_trajectories(
         done_key (NestedKey, optional): the key pointing to the ``"done"`` signal,
             if the trajectory could not be directly recovered. Defaults to ``"done"``.
 
+    Returns:
+        A new tensordict with a leading dimension corresponding to the trajectories.
+        A boolean ``"mask"`` entry sharing the ``trajectory_key`` prefix and the
+        tensordict shape is also added; it indicates which elements of the
+        tensordict are valid. A ``"traj_ids"`` entry is added if ``trajectory_key`` could not be found.
+
+    Examples:
+        >>> from tensordict import TensorDict
+        >>> import torch
+        >>> from torchrl.collectors.utils import split_trajectories
+        >>> obs = torch.cat([torch.arange(10), torch.arange(5)])
+        >>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)])
+        >>> done = torch.zeros(15, dtype=torch.bool)
+        >>> done[9] = True
+        >>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32),
+        ...     torch.ones(5, dtype=torch.int32)])
+        >>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15])
+        >>> data_split = split_trajectories(data, done_key="done")
+        >>> print(data_split)
+        TensorDict(
+            fields={
+                mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
+                        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2, 10]),
+                    device=None,
+                    is_shared=False),
+                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
+                traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
+                trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
+            batch_size=torch.Size([2, 10]),
+            device=None,
+            is_shared=False)
+        >>> # check that split_trajectories got the trajectories right with the done signal
+        >>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
+        >>> print(data_split["mask"])
+        tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
+                [ True,  True,  True,  True,  True, False, False, False, False, False]])
+        >>> data_split = split_trajectories(data, trajectory_key="trajectory")
+        >>> print(data_split)
+        TensorDict(
+            fields={
+                mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
+                        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2, 10]),
+                    device=None,
+                    is_shared=False),
+                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
+                trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
+            batch_size=torch.Size([2, 10]),
+            device=None,
+            is_shared=False)
+
     """
     mask_key = None
     if trajectory_key is not None:
         from torchrl.envs.utils import _replace_last
@@ -130,5 +188,5 @@ def split_trajectories(
     MAX = out_splits[0].shape[0]
     td = torch.stack(
         [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0
-    ).contiguous()
+    )
     return td
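
For reviewers, a quick usage sketch of the two new keyword arguments, adapted from the
docstring examples added in PATCH 2/2 (the data layout mirrors the doctest; the variable
names are illustrative and not part of the patch itself):

    import torch
    from tensordict import TensorDict
    from torchrl.collectors.utils import split_trajectories

    # Two adjacent trajectories of lengths 10 and 5, flattened along one dimension.
    obs = torch.cat([torch.arange(10), torch.arange(5)])
    done = torch.zeros(15, dtype=torch.bool)
    done[9] = True  # the first trajectory ends at the 10th step
    trajectory_id = torch.cat(
        [torch.zeros(10, dtype=torch.int32), torch.ones(5, dtype=torch.int32)]
    )
    data = TensorDict(
        {"obs": obs, ("next", "done"): done, "trajectory": trajectory_id},
        batch_size=[15],
    )

    # 1) Recover the trajectory structure from the done signal alone.
    split = split_trajectories(data, done_key="done")
    assert split.shape == (2, 10)  # 2 trajectories, padded to the longest (10)
    assert split["mask"].sum() == 15  # only 15 of the 2 * 10 slots hold valid data

    # 2) Recover it from an explicit trajectory-id entry instead.
    split = split_trajectories(data, trajectory_key="trajectory")
    assert split.shape == (2, 10)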