Merged
13 changes: 11 additions & 2 deletions test/test_collector.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
+import functools
import gc

import sys
@@ -591,7 +592,15 @@ def env_fn(seed):

@pytest.mark.skipif(not _has_gym, reason="gym library is not installed")
@pytest.mark.parametrize("parallel", [False, True])
-def test_collector_env_reset(parallel):
+@pytest.mark.parametrize(
+    "constr",
+    [
+        functools.partial(split_trajectories, prefix="collector"),
+        functools.partial(split_trajectories),
+        functools.partial(split_trajectories, trajectory_key=("collector", "traj_ids")),
+    ],
+)
+def test_collector_env_reset(constr, parallel):
    torch.manual_seed(0)

    def make_env():
@@ -627,7 +636,7 @@ def make_env():
        # check that if step is 1, then the env was done before
        assert (steps == 1)[done].all()
        # check that split traj has a minimum total reward of -21 (for pong only)
-        _data = split_trajectories(_data, prefix="collector")
+        _data = constr(_data)
        assert _data["next", "reward"].sum(-2).min() == -21
    finally:
        env.close()
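The three `constr` variants above exercise the same splitting logic through its three entry points: the legacy `prefix` argument, the new auto-detected default, and an explicit `trajectory_key`. A minimal sketch of that equivalence on a hand-built rollout (the tensors, shapes and values below are invented for illustration, not taken from the test):

```python
import functools

import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

# Fake rollout: 2 workers x 5 steps, with trajectory ids stored under the
# default collector layout at ("collector", "traj_ids").
rollout = TensorDict(
    {
        "obs": torch.arange(10).reshape(2, 5),
        ("next", "done"): torch.tensor(
            [[False, False, True, False, True], [False, True, False, False, True]]
        ),
        ("collector", "traj_ids"): torch.tensor([[0, 0, 0, 1, 1], [2, 2, 3, 3, 3]]),
    },
    batch_size=[2, 5],
)

# The same three constructors the test parametrizes over: legacy prefix,
# auto-detected "collector" prefix, and an explicit trajectory_key.
for constr in (
    functools.partial(split_trajectories, prefix="collector"),
    functools.partial(split_trajectories),
    functools.partial(split_trajectories, trajectory_key=("collector", "traj_ids")),
):
    split = constr(rollout)
    # 4 trajectories of lengths 3, 2, 2, 3, padded to the longest (3 steps).
    assert split.shape == (4, 3)
```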
15 changes: 13 additions & 2 deletions test/test_postprocs.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
+import functools

import pytest
import torch
@@ -276,12 +277,22 @@ def create_fake_trajs(

@pytest.mark.parametrize("num_workers", range(3, 34, 3))
@pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
def test_splits(self, num_workers, traj_len):
@pytest.mark.parametrize(
"constr",
[
functools.partial(split_trajectories, prefix="collector"),
functools.partial(split_trajectories),
functools.partial(
split_trajectories, trajectory_key=("collector", "traj_ids")
),
],
)
def test_splits(self, num_workers, traj_len, constr):

trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
assert trajs.shape[0] == num_workers
assert trajs.shape[1] == traj_len
split_trajs = split_trajectories(trajs, prefix="collector")
split_trajs = constr(trajs)
assert (
split_trajs.shape[0] == split_trajs.get(("collector", "traj_ids")).max() + 1
)
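`test_splits` pins down the core contract of `split_trajectories`: the leading dimension of the output equals the number of distinct trajectories, and the mask covers exactly the unpadded steps. A sketch of that contract using a hand-rolled stand-in for `create_fake_trajs` (the worker count, trajectory lengths and done flags are illustrative assumptions):

```python
import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

# Stand-in for create_fake_trajs: 3 workers x 7 steps, trajectory boundaries
# encoded only through the ("next", "done") flag. Worker 1 ends with a
# truncated trajectory, which still counts as a trajectory of its own.
done = torch.zeros(3, 7, dtype=torch.bool)
done[0, 2] = done[0, 6] = done[1, 4] = done[2, 6] = True
trajs = TensorDict(
    {"obs": torch.randn(3, 7, 4), ("next", "done"): done},
    batch_size=[3, 7],
)

# No "collector" entry and no trajectory_key: ids are recovered from the done
# signal and written at the root as "traj_ids", with a matching "mask".
split = split_trajectories(trajs)
assert split.shape[0] == split["traj_ids"].max() + 1  # one row per trajectory
assert split["mask"].sum() == 3 * 7  # padding never adds valid steps
```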
120 changes: 99 additions & 21 deletions torchrl/collectors/utils.py
@@ -2,12 +2,13 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from __future__ import annotations

from typing import Callable

import torch

-from tensordict import pad, set_lazy_legacy, TensorDictBase
+from tensordict import NestedKey, pad, set_lazy_legacy, TensorDictBase


def _stack_output(fun) -> Callable:
@@ -28,7 +29,11 @@ def stacked_output_fun(*args, **kwargs):

@set_lazy_legacy(False)
def split_trajectories(
-    rollout_tensordict: TensorDictBase, prefix=None
+    rollout_tensordict: TensorDictBase,
+    *,
+    prefix=None,
+    trajectory_key: NestedKey | None = None,
+    done_key: NestedKey | None = None,
) -> TensorDictBase:
"""A util function for trajectory separation.

@@ -39,28 +44,101 @@
    Args:
        rollout_tensordict (TensorDictBase): a rollout with adjacent trajectories
            along the last dimension.
-        prefix (str or tuple of str, optional): the prefix used to read and write meta-data,
+        prefix (NestedKey, optional): the prefix used to read and write meta-data,
            such as ``"traj_ids"`` (the optional integer id of each trajectory)
            and the ``"mask"`` entry indicating which data are valid and which
-            aren't. Defaults to ``None`` (no prefix).
+            aren't. Defaults to ``"collector"`` if the input has a ``"collector"``
+            entry, ``()`` (no prefix) otherwise.
+            ``prefix`` is kept as a legacy feature and will be deprecated eventually.
+            Prefer ``trajectory_key`` or ``done_key`` whenever possible.
+        trajectory_key (NestedKey, optional): the key pointing to the trajectory
+            ids. Supersedes ``done_key`` and ``prefix``. If not provided, defaults
+            to ``(prefix, "traj_ids")``.
+        done_key (NestedKey, optional): the key pointing to the ``"done"`` signal,
+            if the trajectory could not be directly recovered. Defaults to ``"done"``.

    Returns:
        A new tensordict with a leading dimension corresponding to the trajectory.
+        A ``"mask"`` boolean entry sharing the ``trajectory_key`` prefix and the
+        tensordict shape is also added; it indicates which elements of the
+        tensordict are valid. If ``trajectory_key`` could not be found, a
+        ``"traj_ids"`` entry is added as well.

+    Examples:
+        >>> from tensordict import TensorDict
+        >>> import torch
+        >>> from torchrl.collectors.utils import split_trajectories
+        >>> obs = torch.cat([torch.arange(10), torch.arange(5)])
+        >>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)])
+        >>> done = torch.zeros(15, dtype=torch.bool)
+        >>> done[9] = True
+        >>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32),
+        ...     torch.ones(5, dtype=torch.int32)])
+        >>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15])
+        >>> data_split = split_trajectories(data, done_key="done")
+        >>> print(data_split)
+        TensorDict(
+            fields={
+                mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
+                        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2, 10]),
+                    device=None,
+                    is_shared=False),
+                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
+                traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
+                trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
+            batch_size=torch.Size([2, 10]),
+            device=None,
+            is_shared=False)
+        >>> # check that split_trajectories got the trajectories right with the done signal
+        >>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
+        >>> print(data_split["mask"])
+        tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
+                [ True,  True,  True,  True,  True, False, False, False, False, False]])
+        >>> data_split = split_trajectories(data, trajectory_key="trajectory")
+        >>> print(data_split)
+        TensorDict(
+            fields={
+                mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
+                        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2, 10]),
+                    device=None,
+                    is_shared=False),
+                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
+                trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
+            batch_size=torch.Size([2, 10]),
+            device=None,
+            is_shared=False)

    """
sep = ".-|-."

if isinstance(prefix, str):
traj_ids_key = (prefix, "traj_ids")
mask_key = (prefix, "mask")
elif isinstance(prefix, tuple):
traj_ids_key = (*prefix, "traj_ids")
mask_key = (*prefix, "mask")
elif prefix is None:
traj_ids_key = "traj_ids"
mask_key = "mask"
else:
raise NotImplementedError(f"Unknown key type {type(prefix)}.")
mask_key = None
if trajectory_key is not None:
from torchrl.envs.utils import _replace_last

traj_ids_key = trajectory_key
mask_key = _replace_last(trajectory_key, "mask")
else:
if prefix is None and "collector" in rollout_tensordict.keys():
prefix = "collector"
if prefix is None:
traj_ids_key = "traj_ids"
mask_key = "mask"
else:
traj_ids_key = (prefix, "traj_ids")
mask_key = (prefix, "mask")

rollout_tensordict = rollout_tensordict.copy()
traj_ids = rollout_tensordict.get(traj_ids_key, None)
done = rollout_tensordict.get(("next", "done"))
if traj_ids is None:
if done_key is None:
done_key = "done"
done_key = ("next", done_key)
done = rollout_tensordict.get(done_key)
        idx = (slice(None),) * (rollout_tensordict.ndim - 1) + (slice(None, -1),)
        done_sel = done[idx]
        pads = [1, 0]
@@ -91,7 +169,8 @@ def split_trajectories(
        )
        if rollout_tensordict.ndimension() == 1:
            rollout_tensordict = rollout_tensordict.unsqueeze(0)
-        return rollout_tensordict.unflatten_keys(sep)
+        return rollout_tensordict

    out_splits = rollout_tensordict.view(-1).split(splits, 0)

    for out_split in out_splits:
@@ -109,6 +188,5 @@ def split_trajectories(
    MAX = out_splits[0].shape[0]
    td = torch.stack(
        [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0
-    ).contiguous()
-    # td = td.unflatten_keys(sep)
+    )
    return td
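Taken together, the new resolution order is: `trajectory_key` wins outright; otherwise `prefix` decides where `traj_ids` and `mask` live, with `prefix` auto-defaulting to `"collector"` when such an entry exists; `done_key` only matters when the ids must be recovered from a done flag. A short sketch of that precedence (the `episode` and `terminated` names are hypothetical, chosen purely for illustration):

```python
import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

data = TensorDict(
    {
        "obs": torch.arange(6),
        ("next", "done"): torch.tensor([False, False, True, False, False, True]),
    },
    batch_size=[6],
)

# 1. No trajectory_key and no "collector" entry: ids are recovered from
#    ("next", "done") and written at the root as "traj_ids"/"mask".
split = split_trajectories(data)
assert split.shape == (2, 3)  # two trajectories of three steps each
assert "traj_ids" in split.keys() and "mask" in split.keys()

# 2. An explicit trajectory_key supersedes prefix and done_key; the mask is
#    written next to it (_replace_last("episode", "mask") -> "mask" here).
data2 = data.clone()
data2["episode"] = torch.tensor([0, 0, 0, 1, 1, 1])
split2 = split_trajectories(data2, trajectory_key="episode")
assert "episode" in split2.keys() and "mask" in split2.keys()

# 3. done_key only kicks in when the ids must be derived from a done flag
#    stored under a non-default name.
data3 = data.clone().rename_key_(("next", "done"), ("next", "terminated"))
split3 = split_trajectories(data3, done_key="terminated")
assert split3.shape == (2, 3)
```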