Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# LICENSE file in the root directory of this source tree.

import argparse
from functools import partial

import numpy as np
import pytest
import torch
from _utils_internal import get_available_devices
from tensordict.tensordict import assert_allclose_td, TensorDictBase, TensorDict
from torchrl._utils import prod
from torchrl.data import (
PrioritizedReplayBuffer,
ReplayBuffer,
Expand All @@ -28,6 +30,7 @@
ListStorage,
)
from torchrl.data.replay_buffers.writers import RoundRobinWriter
from torchrl.envs.transforms.transforms import *


collate_fn_dict = {
Expand Down Expand Up @@ -183,6 +186,101 @@ def test_index(self, rb_type, sampler, writer, storage, size):
b = b.all()
assert b

def test_transformed_replay_buffer(self, rb_type, sampler, writer, storage, size):
torch.manual_seed(0)
rb = self._get_rb(
rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size
)
td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, [])
rb.add(td)
sampled, _ = rb.sample(1)
# assert sampled.get("observation").shape[0] == 4
flatten = FlattenObservation(
-2, -1, in_keys=["observation"], out_keys=["flattened"]
)
transformed_rb = TransformedReplayBuffer(rb, flatten)

sampled, _ = transformed_rb.sample(1)
assert sampled.get("flattened").shape[-1] == 48

def test_transformed_replay_buffer_append_transform(
self, rb_type, sampler, writer, storage, size
):
rb = self._get_rb(
rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size
)
td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, [])
rb.add(td)
transformed_rb = TransformedReplayBuffer(
rb,
FlattenObservation(-2, -1, in_keys=["observation"], out_keys=["flattened"]),
)
transformed_rb.append_transform(
FlattenObservation(-2, -1, in_keys=["flattened"])
)
sampled, _ = transformed_rb.sample(1)
assert sampled.get("flattened").shape[-1] == 192

def test_transformed_replay_buffer_insert_transform(
self, rb_type, sampler, writer, storage, size
):
rb = self._get_rb(
rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size
)
td = TensorDict({"observation": torch.randn(2, 4, 1, 3, 16)}, [])
rb.add(td)
transformed_rb = TransformedReplayBuffer(
rb,
FlattenObservation(
-2, -1, in_keys=["observation"], out_keys=["transformed"]
),
)
transformed_rb.append_transform(
FlattenObservation(-2, -1, in_keys=["transformed"])
)
sampled, _ = transformed_rb.sample(1)
assert sampled.get("transformed").shape[-1] == 48
transformed_rb.insert_transform(
0, SqueezeTransform(-3, in_keys=["observation"])
)
sampled, _ = transformed_rb.sample(1)
assert sampled.get("transformed").shape[-1] == 192


# Every transform class the smoke test instantiates inside a
# TransformedReplayBuffer; entries whose constructors need arguments are
# pre-bound with functools.partial.
transforms = [
    ToTensorImage,
    partial(RewardClipping, clamp_min=0.1, clamp_max=0.9),
    BinarizeReward,
    partial(Resize, w=2, h=2),
    partial(CenterCrop, w=1),
    FlattenObservation,
    # FIXME: the two (un)squeeze transforms assume a parent exists, so they
    # cannot be called through Compose.
    partial(UnsqueezeTransform, unsqueeze_dim=0),
    partial(SqueezeTransform, squeeze_dim=0),
    GrayScale,
    ObservationNorm,
    CatFrames,
    partial(RewardScaling, loc=1, scale=2),
    FiniteTensorDictCheck,
    DoubleToFloat,
    CatTensors,
    partial(DiscreteActionProjection, max_n=1, m=1),
    NoopResetEnv,
    TensorDictPrimer,
    PinMemoryTransform,
    gSDENoise,
    VecNorm,
]
@pytest.mark.parametrize("transform", transforms)
def test_transformed_replay_buffer_smoke(transform):
    """Smoke test: each transform can be passed at construction time,
    appended, and inserted into a TransformedReplayBuffer."""
    rb = ReplayBuffer()
    rb.add(TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []))
    transformed_rb = TransformedReplayBuffer(rb, transform())
    transformed_rb.append_transform(transform())
    transformed_rb.insert_transform(0, transform())


@pytest.mark.parametrize("priority_key", ["pk", "td_error"])
@pytest.mark.parametrize("contiguous", [True, False])
Expand Down
58 changes: 45 additions & 13 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from tensordict.tensordict import TensorDictBase, TensorDict
from torch import nn, Tensor
from torchrl.data.replay_buffers.rb_prototype import ReplayBuffer

from torchrl.data.tensor_specs import (
BoundedTensorSpec,
Expand Down Expand Up @@ -213,7 +214,7 @@ def dump(self, **kwargs) -> None:
def __repr__(self) -> str:
return f"{self.__class__.__name__}(keys={self.in_keys})"

def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
def set_parent(self, parent) -> None:
if self.__dict__["_parent"] is not None:
raise AttributeError("parent of transform already set")
self.__dict__["_parent"] = parent
Expand All @@ -228,7 +229,7 @@ def parent(self) -> EnvBase:
parent = self._parent
if parent is None:
return parent
if not isinstance(parent, EnvBase):
if not isinstance(parent, EnvBase) and not isinstance(parent, TransformedReplayBuffer):
# if it's not an env, it should be a Compose transform
if not isinstance(parent, Compose):
raise ValueError(
Expand Down Expand Up @@ -256,6 +257,8 @@ def parent(self) -> EnvBase:
transform = copy(orig_trans)
transform.reset_parent()
out.append_transform(transform)
elif isinstance(parent, TransformedReplayBuffer):
return parent
elif isinstance(parent, TransformedEnv):
out = TransformedEnv(parent.base_env)
else:
Expand All @@ -267,6 +270,32 @@ def empty_cache(self):
self.parent.empty_cache()


class TransformedReplayBuffer:
    """A replay buffer that applies a ``Transform`` pipeline to sampled data.

    Args:
        rb: the wrapped replay buffer providing storage and sampling.
        transform: the transform (or ``Compose`` of transforms) applied to
            every sampled batch. ``None`` starts with an empty pipeline.
    """

    def __init__(
        self,
        rb: ReplayBuffer,
        transform: Optional[Transform] = None,
    ):
        self.rb = rb
        if transform is None:
            transform = Compose()
        elif not isinstance(transform, Compose):
            transform = Compose(transform)
        self.transform = transform
        # Set the parent on the stored (possibly wrapping) Compose, not only on
        # the inner transform as before: otherwise parent lookups behaved
        # differently depending on whether a Compose or a bare transform was
        # passed in.
        self.transform._parent = self

    def sample(self, batch_size: int) -> Tuple[Any, dict]:
        """Sample ``batch_size`` elements from the wrapped buffer and return
        the transformed data together with the sampler's info dict."""
        data, info = self.rb.sample(batch_size)
        # Use the transform's return value: transforms are not guaranteed to
        # operate purely in place.
        data = self.transform(data)
        return data, info

    def append_transform(self, transform: Transform):
        """Append ``transform`` at the end of the pipeline."""
        self.transform.append(transform)

    def insert_transform(self, index: int, transform: Transform):
        """Insert ``transform`` at position ``index`` of the pipeline."""
        self.transform.insert(index, transform)


class TransformedEnv(EnvBase):
"""A transformed_in environment.

Expand Down Expand Up @@ -1029,10 +1058,11 @@ def __init__(
first_dim: int = 0,
last_dim: int = -3,
in_keys: Optional[Sequence[str]] = None,
out_keys: Optional[Sequence[str]] = None,
):
if in_keys is None:
in_keys = IMAGE_KEYS # default
super().__init__(in_keys=in_keys)
super().__init__(in_keys=in_keys, out_keys=out_keys)
self.first_dim = first_dim
self.last_dim = last_dim

Expand All @@ -1042,15 +1072,17 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:

def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
out = super().set_parent(parent)
observation_spec = self.parent.observation_spec
for key in self.in_keys:
if key in observation_spec:
observation_spec = observation_spec[key]
if self.first_dim >= 0:
self.first_dim = self.first_dim - len(observation_spec.shape)
if self.last_dim >= 0:
self.last_dim = self.last_dim - len(observation_spec.shape)
break
# FIXME: Cannot run Compose(FlattenObservation)
if isinstance(parent, EnvBase):
observation_spec = self.parent.observation_spec
for key in self.in_keys:
if key in observation_spec:
observation_spec = observation_spec[key]
if self.first_dim >= 0:
self.first_dim = self.first_dim - len(observation_spec.shape)
if self.last_dim >= 0:
self.last_dim = self.last_dim - len(observation_spec.shape)
break
return out

@_apply_to_composite
Expand Down Expand Up @@ -1111,7 +1143,7 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
if self._unsqueeze_dim_orig < 0:
self._unsqueeze_dim = self._unsqueeze_dim_orig
else:
parent = self.parent
parent = self.parent # FIXME: Assumes parent exists and has batch_size
batch_size = parent.batch_size
self._unsqueeze_dim = self._unsqueeze_dim_orig + len(batch_size)
return super().set_parent(parent)
Expand Down