2 changes: 1 addition & 1 deletion examples/dreamer/dreamer_utils.py
@@ -92,7 +92,7 @@ def make_env_transforms(
env.append_transform(Resize(cfg.image_size, cfg.image_size))
if cfg.grayscale:
env.append_transform(GrayScale())
env.append_transform(FlattenObservation())
env.append_transform(FlattenObservation(0))
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"]))
if stats is None:
obs_stats = {"loc": 0.0, "scale": 1.0}
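Note on the one-line fix above: FlattenObservation's first positional argument is `first_dim` (the tests below pass it both positionally and by keyword), and the dreamer pipeline now pins it to 0. A hedged sketch of the semantics this relies on, assuming the transform flattens the span [first_dim, last_dim] the way `torch.flatten` does:

import torch

x = torch.randn(2, 4, 3, 16)
# flattening from dim 0 to dim 1 merges the two leading dims:
torch.flatten(x, 0, 1).shape  # torch.Size([8, 3, 16])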
153 changes: 146 additions & 7 deletions test/test_rb.py
@@ -4,6 +4,9 @@
# LICENSE file in the root directory of this source tree.

import argparse
import importlib
from functools import partial
from unittest import mock

import numpy as np
import pytest
@@ -28,14 +31,29 @@
ListStorage,
)
from torchrl.data.replay_buffers.writers import RoundRobinWriter
from torchrl.envs.transforms.transforms import (
CatTensors,
FlattenObservation,
SqueezeTransform,
ToTensorImage,
RewardClipping,
BinarizeReward,
Resize,
CenterCrop,
UnsqueezeTransform,
GrayScale,
ObservationNorm,
CatFrames,
RewardScaling,
DoubleToFloat,
VecNorm,
DiscreteActionProjection,
FiniteTensorDictCheck,
gSDENoise,
PinMemoryTransform,
)


# collate_fn_dict = {
# ListStorage: lambda x: torch.stack(x, 0),
# LazyTensorStorage: lambda x: x,
# LazyMemmapStorage: lambda x: x,
# None: lambda x: torch.stack(x, 0),
# }
_has_tv = importlib.util.find_spec("torchvision") is not None


@pytest.mark.parametrize(
@@ -594,6 +612,127 @@ def test_legacy_rb_does_not_attach():
assert rb not in storage._attached_entities


def test_append_transform():
rb = rb_prototype.ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0))
td = TensorDict(
{
"observation": torch.randn(2, 4, 3, 16),
"observation2": torch.randn(2, 4, 3, 16),
},
[],
)
rb.add(td)
flatten = CatTensors(
in_keys=["observation", "observation2"], out_key="observation_cat"
)

rb.append_transform(flatten)

sampled, _ = rb.sample(1)
assert sampled.get("observation_cat").shape[-1] == 32


def test_init_transform():
flatten = FlattenObservation(
-2, -1, in_keys=["observation"], out_keys=["flattened"]
)

rb = rb_prototype.ReplayBuffer(
collate_fn=lambda x: torch.stack(x, 0), transform=flatten
)

td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, [])
rb.add(td)
sampled, _ = rb.sample(1)
assert sampled.get("flattened").shape[-1] == 48


def test_insert_transform():
flatten = FlattenObservation(
-2, -1, in_keys=["observation"], out_keys=["flattened"]
)
rb = rb_prototype.ReplayBuffer(
collate_fn=lambda x: torch.stack(x, 0), transform=flatten
)
td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, [])
rb.add(td)

rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"]))

sampled, _ = rb.sample(1)
assert sampled.get("flattened").shape[-1] == 48

with pytest.raises(ValueError):
rb.insert_transform(10, SqueezeTransform(-1, in_keys=["observation"]))


transforms = [
ToTensorImage,
pytest.param(
partial(RewardClipping, clamp_min=0.1, clamp_max=0.9), id="RewardClipping"
),
BinarizeReward,
pytest.param(
partial(Resize, w=2, h=2),
id="Resize",
marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"),
),
pytest.param(
partial(CenterCrop, w=1),
id="CenterCrop",
marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"),
),
pytest.param(
partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform"
),
pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"),
GrayScale,
pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"),
CatFrames,
pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"),
DoubleToFloat,
VecNorm,
]


@pytest.mark.parametrize("transform", transforms)
def test_smoke_replay_buffer_transform(transform):
rb = rb_prototype.ReplayBuffer(
transform=transform(in_keys="observation"),
)

td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
rb.add(td)
rb.sample(1)

rb._transform = mock.MagicMock()
rb.sample(1)
assert rb._transform.called


transforms = [
partial(DiscreteActionProjection, max_n=1, m=1),
FiniteTensorDictCheck,
gSDENoise,
PinMemoryTransform,
]


@pytest.mark.parametrize("transform", transforms)
def test_smoke_replay_buffer_transform_no_inkeys(transform):
rb = rb_prototype.ReplayBuffer(
collate_fn=lambda x: torch.stack(x, 0), transform=transform()
)

td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
rb.add(td)
rb.sample(1)

rb._transform = mock.MagicMock()
rb.sample(1)
assert rb._transform.called


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
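The new tests above all drive the same mechanism: transforms attached to the buffer run at sampling time, on the collated batch, not at insertion time. A condensed sketch of that flow outside pytest (the import paths for TensorDict and rb_prototype are assumptions about this revision):

import torch
from tensordict.tensordict import TensorDict  # assumed import path
from torchrl.data.replay_buffers import rb_prototype  # assumed import path
from torchrl.envs.transforms.transforms import FlattenObservation, SqueezeTransform

rb = rb_prototype.ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0))
rb.add(TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, []))
rb.append_transform(
    FlattenObservation(-2, -1, in_keys=["observation"], out_keys=["flattened"])
)
rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"]))
# squeeze runs first (index 0), then flatten: (..., 3, 16, 1) -> (..., 3, 16) -> (..., 48)
sampled, info = rb.sample(1)
assert sampled.get("flattened").shape[-1] == 48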
47 changes: 47 additions & 0 deletions test/test_transforms.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
from copy import copy, deepcopy
from functools import partial

import numpy as np
import pytest
@@ -54,13 +55,15 @@
from torchrl.envs.transforms import TransformedEnv, VecNorm
from torchrl.envs.transforms.r3m import _R3MNet
from torchrl.envs.transforms.transforms import (
DiscreteActionProjection,
_has_tv,
CenterCrop,
NoopResetEnv,
PinMemoryTransform,
SqueezeTransform,
TensorDictPrimer,
UnsqueezeTransform,
gSDENoise,
)
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform

@@ -1975,6 +1978,50 @@ def test_batch_unlocked_with_batch_size_transformed(device):
env.step(td_expanded)


transforms = [
ToTensorImage,
pytest.param(
partial(RewardClipping, clamp_min=0.1, clamp_max=0.9), id="RewardClipping"
),
BinarizeReward,
pytest.param(
partial(Resize, w=2, h=2),
id="Resize",
marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"),
),
pytest.param(
partial(CenterCrop, w=1),
id="CenterCrop",
marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"),
),
pytest.param(partial(FlattenObservation, first_dim=-3), id="FlattenObservation"),
pytest.param(
partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform"
),
pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"),
GrayScale,
ObservationNorm,
CatFrames,
pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"),
FiniteTensorDictCheck,
DoubleToFloat,
CatTensors,
pytest.param(
partial(DiscreteActionProjection, max_n=1, m=1), id="DiscreteActionProjection"
),
NoopResetEnv,
TensorDictPrimer,
PinMemoryTransform,
gSDENoise,
VecNorm,
]


@pytest.mark.parametrize("transform", transforms)
def test_smoke_compose_transform(transform):
Compose(transform())


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
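Worth noting: test_smoke_compose_transform only checks that each transform can be constructed inside a Compose; no data passes through it. A hedged one-step extension, assuming a Compose can be called on a TensorDict directly, the way the replay buffer invokes its transform below:

t = Compose(ToTensorImage(in_keys=["pixels"]), GrayScale(in_keys=["pixels"]))
td = TensorDict({"pixels": torch.randint(255, (16, 16, 3), dtype=torch.uint8)}, [])
td = t(td)  # pixels becomes a float tensor of shape (1, 16, 16) after grayscale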
34 changes: 34 additions & 0 deletions torchrl/data/replay_buffers/rb_prototype.py
@@ -6,6 +6,7 @@
import torch
from tensordict.tensordict import TensorDictBase, LazyStackedTensorDict

from torchrl.envs.transforms.transforms import Compose, Transform
from .replay_buffers import pin_memory_output
from .samplers import Sampler, RandomSampler
from .storages import Storage, ListStorage, _get_default_collate
@@ -30,6 +31,8 @@ class ReplayBuffer:
samples.
prefetch (int, optional): number of next batches to be prefetched
using multithreading.
transform (Transform, optional): Transform to be executed when sample() is called.
To chain transforms use the :obj:`Compose` class.
"""

def __init__(
@@ -40,6 +43,7 @@ def __init__(
collate_fn: Optional[Callable] = None,
pin_memory: bool = False,
prefetch: Optional[int] = None,
transform: Optional[Transform] = None,
) -> None:
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
self._storage.attach(self)
@@ -62,6 +66,12 @@ def __init__(

self._replay_lock = threading.RLock()
self._futures_lock = threading.RLock()
if transform is None:
transform = Compose()
elif not isinstance(transform, Compose):
transform = Compose(transform)
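# note: eval() presumably freezes train-time behavior of the transforms
# (e.g. running-stat updates in VecNorm) so that sampling has no side effects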
transform.eval()
self._transform = transform

def __len__(self) -> int:
with self._replay_lock:
@@ -131,6 +141,7 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]:
data = self._storage[index]
if not isinstance(index, INT_CLASSES):
data = self._collate_fn(data)
data = self._transform(data)
return data, info

def sample(self, batch_size: int) -> Tuple[Any, dict]:
@@ -163,6 +174,29 @@ def sample(self, batch_size: int) -> Tuple[Any, dict]:
def mark_update(self, index: Union[int, torch.Tensor]) -> None:
self._sampler.mark_update(index)

def append_transform(self, transform: Transform) -> None:
"""Appends transform at the end.

Transforms are applied in order when `sample` is called.

Args:
transform (Transform): The transform to be appended.
"""
transform.eval()
self._transform.append(transform)

def insert_transform(self, index: int, transform: Transform) -> None:
"""Inserts transform.

Transforms are executed in order when `sample` is called.

Args:
index (int): Position to insert the transform.
transform (Transform): The transform to be inserted.
"""
transform.eval()
self._transform.insert(index, transform)


class TensorDictReplayBuffer(ReplayBuffer):
"""TensorDict-specific wrapper around the ReplayBuffer class.
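Usage note on the transform argument added to ReplayBuffer above: a single transform is wrapped into a Compose internally (see __init__), so both a bare transform and an explicit chain are accepted. A hedged construction-time sketch, reusing the prototype buffer and transforms shown in this diff (the rb_prototype import path is an assumption):

from torchrl.data.replay_buffers.rb_prototype import ReplayBuffer  # assumed import path
from torchrl.envs.transforms.transforms import Compose, GrayScale, ToTensorImage

rb_single = ReplayBuffer(transform=ToTensorImage(in_keys=["pixels"]))
rb_chained = ReplayBuffer(
    transform=Compose(
        ToTensorImage(in_keys=["pixels"]),
        GrayScale(in_keys=["pixels"]),
    )
)
# either buffer can still grow its pipeline later:
rb_single.append_transform(GrayScale(in_keys=["pixels"]))  # runs after ToTensorImage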