Skip to content

Commit

Permalink
[Feature] Exclude and select transforms (#832)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 13, 2023
1 parent 4ebc764 commit 9003a56
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 20 deletions.
7 changes: 5 additions & 2 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ in the environment. The keys to be included in this inverse transform are passed
CatTensors
CenterCrop
Compose
DiscreteActionProjection
DoubleToFloat
ExcludeTransform
FiniteTensorDictCheck
FlattenObservation
FrameSkipTransform
Expand All @@ -218,19 +220,20 @@ in the environment. The keys to be included in this inverse transform are passed
ObservationNorm
ObservationTransform
PinMemoryTransform
R3MTransform
Resize
RewardClipping
RewardScaling
RewardSum
SelectTransform
SqueezeTransform
StepCounter
TensorDictPrimer
ToTensorImage
UnsqueezeTransform
VecNorm
R3MTransform
VIPTransform
VIPRewardTransform
VIPTransform

Recorders
---------
Expand Down
78 changes: 64 additions & 14 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,47 +22,49 @@
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
)
from tensordict import TensorDict
from tensordict.tensordict import TensorDict, TensorDictBase
from torch import multiprocessing as mp, Tensor
from torchrl._utils import prod
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import (
BinarizeReward,
CatFrames,
CatTensors,
CenterCrop,
Compose,
DiscreteActionProjection,
DoubleToFloat,
EnvBase,
EnvCreator,
ExcludeTransform,
FiniteTensorDictCheck,
FlattenObservation,
FrameSkipTransform,
GrayScale,
gSDENoise,
NoopResetEnv,
ObservationNorm,
ParallelEnv,
PinMemoryTransform,
R3MTransform,
Resize,
RewardClipping,
RewardScaling,
RewardSum,
SelectTransform,
SerialEnv,
SqueezeTransform,
StepCounter,
TensorDictPrimer,
ToTensorImage,
TransformedEnv,
UnsqueezeTransform,
VIPTransform,
)
from torchrl.envs.libs.gym import _has_gym, GymEnv
from torchrl.envs.transforms import TransformedEnv, VecNorm
from torchrl.envs.transforms import VecNorm
from torchrl.envs.transforms.r3m import _R3MNet
from torchrl.envs.transforms.transforms import (
_has_tv,
CenterCrop,
DiscreteActionProjection,
FrameSkipTransform,
gSDENoise,
NoopResetEnv,
PinMemoryTransform,
SqueezeTransform,
TensorDictPrimer,
UnsqueezeTransform,
)
from torchrl.envs.transforms.transforms import _has_tv
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
from torchrl.envs.utils import check_env_specs

Expand Down Expand Up @@ -2268,6 +2270,54 @@ def test_batch_unlocked_with_batch_size_transformed(device):
env.step(td_expanded)


class TestExcludeSelect:
class EnvWithManyKeys(EnvBase):
def __init__(self):
super().__init__()
self.observation_spec = CompositeSpec(
a=UnboundedContinuousTensorSpec(3),
b=UnboundedContinuousTensorSpec(3),
c=UnboundedContinuousTensorSpec(3),
)
self.reward_spec = UnboundedContinuousTensorSpec(1)
self.input_spec = CompositeSpec(action=UnboundedContinuousTensorSpec(2))

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
return self.observation_spec.rand().update(
{
"reward": self.reward_spec.rand(),
"done": torch.zeros(1, dtype=torch.bool),
}
)

def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
return self.observation_spec.rand().update(
{"done": torch.zeros(1, dtype=torch.bool)}
)

def _set_seed(self, seed):
return seed + 1

def test_exclude(self):
base_env = TestExcludeSelect.EnvWithManyKeys()
env = TransformedEnv(base_env, ExcludeTransform("a"))
check_env_specs(env)
assert "a" not in env.reset().keys()
assert "b" in env.reset().keys()
assert "c" in env.reset().keys()

def test_select(self):
base_env = TestExcludeSelect.EnvWithManyKeys()
env = TransformedEnv(base_env, SelectTransform("b", "c"))
check_env_specs(env)
assert "a" not in env.reset().keys()
assert "b" in env.reset().keys()
assert "c" in env.reset().keys()


transforms = [
ToTensorImage,
pytest.param(
Expand Down
6 changes: 6 additions & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
CatTensors,
CenterCrop,
Compose,
DiscreteActionProjection,
DoubleToFloat,
ExcludeTransform,
FiniteTensorDictCheck,
FlattenObservation,
FrameSkipTransform,
GrayScale,
gSDENoise,
NoopResetEnv,
Expand All @@ -27,13 +30,16 @@
RewardClipping,
RewardScaling,
RewardSum,
SelectTransform,
SqueezeTransform,
StepCounter,
TensorDictPrimer,
ToTensorImage,
Transform,
TransformedEnv,
UnsqueezeTransform,
VecNorm,
VIPRewardTransform,
VIPTransform,
)
from .vec_env import ParallelEnv, SerialEnv
3 changes: 3 additions & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
CatTensors,
CenterCrop,
Compose,
DiscreteActionProjection,
DoubleToFloat,
ExcludeTransform,
FiniteTensorDictCheck,
FlattenObservation,
FrameSkipTransform,
Expand All @@ -24,6 +26,7 @@
RewardClipping,
RewardScaling,
RewardSum,
SelectTransform,
SqueezeTransform,
StepCounter,
TensorDictPrimer,
Expand Down
98 changes: 94 additions & 4 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,13 +448,17 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
tensordict_in = self.transform.inv(tensordict)
tensordict_out = self.base_env._step(tensordict_in)
tensordict_out = tensordict_out.update(
tensordict.exclude(*tensordict_out.keys())
tensordict_out = (
tensordict_out.update( # update the output with the original tensordict
tensordict.exclude(
*tensordict_out.keys()
) # exclude the newly written keys
)
)
next_tensordict = self.transform._step(tensordict_out)
tensordict_out.update(next_tensordict, inplace=False)
# tensordict_out.update(next_tensordict, inplace=False)

return tensordict_out
return next_tensordict

def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
Expand Down Expand Up @@ -2671,3 +2675,89 @@ def transform_observation_spec(
)
observation_spec["step_count"].space.minimum = 0
return observation_spec


class ExcludeTransform(Transform):
"""Excludes keys from the input tensordict.
Args:
*excluded_keys (iterable of str): The name of the keys to exclude. If the key is
not present, it is simply ignored.
"""

inplace = False

def __init__(self, *excluded_keys):
super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
if not all(isinstance(item, str) for item in excluded_keys):
raise ValueError("excluded_keys must be a list or tuple of strings.")
self.excluded_keys = excluded_keys
if "reward" in excluded_keys:
raise RuntimeError("'reward' cannot be excluded from the keys.")

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict.exclude(*self.excluded_keys)

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict.exclude(*self.excluded_keys)

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
if any(key in observation_spec.keys() for key in self.excluded_keys):
return CompositeSpec(
**{
key: value
for key, value in observation_spec.items()
if key not in self.excluded_keys
}
)
return observation_spec


class SelectTransform(Transform):
"""Select keys from the input tensordict.
In general, the :obj:`ExcludeTransform` should be preferred: this transforms also
selects the "action" (or other keys from input_spec), "done" and "reward"
keys but other may be necessary.
Args:
*selected_keys (iterable of str): The name of the keys to select. If the key is
not present, it is simply ignored.
"""

inplace = False

def __init__(self, *selected_keys):
super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
if not all(isinstance(item, str) for item in selected_keys):
raise ValueError("excluded_keys must be a list or tuple of strings.")
self.selected_keys = selected_keys

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.parent:
input_keys = self.parent.input_spec.keys()
else:
input_keys = []
return tensordict.select(
*self.selected_keys, "reward", "done", *input_keys, strict=False
)

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.parent:
input_keys = self.parent.input_spec.keys()
else:
input_keys = []
return tensordict.select(
*self.selected_keys, "reward", "done", *input_keys, strict=False
)

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
return CompositeSpec(
**{
key: value
for key, value in observation_spec.items()
if key in self.selected_keys
}
)

0 comments on commit 9003a56

Please sign in to comment.