54 changes: 40 additions & 14 deletions test/test_transforms.py
@@ -7711,10 +7711,10 @@ def _test_vecnorm_subproc_auto(
         queue_out.put(True)
         msg = queue_in.get(timeout=TIMEOUT)
         assert msg == "all_done"
-        t = env.transform
-        obs_sum = t._td.get("observation_sum").clone()
-        obs_ssq = t._td.get("observation_ssq").clone()
-        obs_count = t._td.get("observation_count").clone()
+        t = env.transform[1]
+        obs_sum = t._td.get(("some", "obs_sum")).clone()
+        obs_ssq = t._td.get(("some", "obs_ssq")).clone()
+        obs_count = t._td.get(("some", "obs_count")).clone()
         reward_sum = t._td.get("reward_sum").clone()
         reward_ssq = t._td.get("reward_ssq").clone()
         reward_count = t._td.get("reward_count").clone()
@@ -7729,18 +7729,34 @@ def _test_vecnorm_subproc_auto(
         queue_in.close()
         del queue_in, queue_out

+    @property
+    def rename_t(self):
+        return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])
+
     @pytest.mark.parametrize("nprc", [2, 5])
     def test_vecnorm_parallel_auto(self, nprc):
         queues = []
         prcs = []
         if _has_gym:
-            make_env = EnvCreator(
-                lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm(decay=1.0))
+            maker = lambda: TransformedEnv(
+                GymEnv(PENDULUM_VERSIONED()),
+                Compose(
+                    self.rename_t,
+                    VecNorm(decay=1.0, in_keys=[("some", "obs"), "reward"]),
+                ),
             )
+            check_env_specs(maker())
+            make_env = EnvCreator(maker)
         else:
-            make_env = EnvCreator(
-                lambda: TransformedEnv(ContinuousActionVecMockEnv(), VecNorm(decay=1.0))
+            maker = lambda: TransformedEnv(
+                ContinuousActionVecMockEnv(),
+                Compose(
+                    self.rename_t,
+                    VecNorm(decay=1.0, in_keys=[("some", "obs"), "reward"]),
+                ),
             )
+            check_env_specs(maker())
+            make_env = EnvCreator(maker)

         for idx in range(nprc):
             prc_queue_in = mp.Queue(1)
@@ -7764,11 +7780,11 @@ def test_vecnorm_parallel_auto(self, nprc):
         for idx in range(nprc):
             queues[idx][1].put(msg)

-        td = make_env.state_dict()["_extra_state"]["td"]
+        td = make_env.state_dict()["transforms.1._extra_state"]["td"]

-        obs_sum = td.get("observation_sum").clone()
-        obs_ssq = td.get("observation_ssq").clone()
-        obs_count = td.get("observation_count").clone()
+        obs_sum = td.get(("some", "obs_sum")).clone()
+        obs_ssq = td.get(("some", "obs_ssq")).clone()
+        obs_count = td.get(("some", "obs_count")).clone()
         reward_sum = td.get("reward_sum").clone()
         reward_ssq = td.get("reward_ssq").clone()
         reward_count = td.get("reward_count").clone()
@@ -7836,11 +7852,21 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
     def test_parallelenv_vecnorm(self):
         if _has_gym:
             make_env = EnvCreator(
-                lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm())
+                lambda: TransformedEnv(
+                    GymEnv(PENDULUM_VERSIONED()),
+                    Compose(
+                        self.rename_t, VecNorm(in_keys=[("some", "obs"), "reward"])
+                    ),
+                )
             )
         else:
             make_env = EnvCreator(
-                lambda: TransformedEnv(ContinuousActionVecMockEnv(), VecNorm())
+                lambda: TransformedEnv(
+                    ContinuousActionVecMockEnv(),
+                    Compose(
+                        self.rename_t, VecNorm(in_keys=[("some", "obs"), "reward"])
+                    ),
+                )
             )
         parallel_env = ParallelEnv(
             2,
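Note on the updated assertions in this file: because VecNorm is now the second element of a Compose (after the RenameTransform returned by rename_t), the test reaches it with env.transform[1], and its extra state moves from the top-level "_extra_state" entry to "transforms.1._extra_state". A minimal sketch of where that state-dict key comes from, using a hypothetical stand-in module rather than the real classes (Compose registers its sub-transforms in an nn.ModuleList attribute named "transforms", and any nn.Module that implements get_extra_state() contributes an "_extra_state" entry under its own prefix):

import torch.nn as nn


class WithExtraState(nn.Module):
    # Stand-in for VecNorm, which exposes its tracker tensordict as extra state.
    def get_extra_state(self):
        return {"td": None}

    def set_extra_state(self, state):
        pass


outer = nn.Module()
outer.transforms = nn.ModuleList([nn.Identity(), WithExtraState()])
# The module at index 1 is namespaced "transforms.1.", hence the key read by the test:
assert "transforms.1._extra_state" in outer.state_dict()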
9 changes: 9 additions & 0 deletions torchrl/_utils.py
@@ -27,6 +27,7 @@
 import numpy as np
 import torch
 from packaging.version import parse
+from tensordict import unravel_key

 from tensordict.utils import NestedKey
 from torch import multiprocessing as mp
@@ -719,6 +720,14 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
     return key[:-1] + (new_ending,)


+def _append_last(key: NestedKey, new_suffix: str) -> NestedKey:
+    key = unravel_key(key)
+    if isinstance(key, str):
+        return key + new_suffix
+    else:
+        return key[:-1] + (key[-1] + new_suffix,)
+
+
 class _rng_decorator(_DecoratorContextManager):
     """Temporarily sets the seed and sets back the rng state when exiting."""

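The new _append_last helper mirrors the existing _replace_last: it canonicalizes the key with unravel_key, then appends the suffix to the last element, so plain string keys and nested tuple keys are handled uniformly. Expected behavior, as a sketch (these are private torchrl helpers, so the calls below are illustrative rather than documented API):

from torchrl._utils import _append_last

assert _append_last("reward", "_sum") == "reward_sum"
assert _append_last(("some", "obs"), "_ssq") == ("some", "obs_ssq")
# unravel_key first flattens nested tuples, so the suffix lands on the leaf:
assert _append_last(("some", ("deep", "obs")), "_count") == ("some", "deep", "obs_count")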
99 changes: 61 additions & 38 deletions torchrl/envs/transforms/transforms.py
@@ -45,7 +45,7 @@
 from torch import nn, Tensor
 from torch.utils._pytree import tree_map

-from torchrl._utils import _ends_with, _replace_last
+from torchrl._utils import _append_last, _ends_with, _replace_last

 from torchrl.data.tensor_specs import (
     BinaryDiscreteTensorSpec,
@@ -4854,9 +4854,9 @@ def __init__(
         if shared_td is not None:
             for key in in_keys:
                 if (
-                    (key + "_sum" not in shared_td.keys())
-                    or (key + "_ssq" not in shared_td.keys())
-                    or (key + "_count" not in shared_td.keys())
+                    (_append_last(key, "_sum") not in shared_td.keys())
+                    or (_append_last(key, "_ssq") not in shared_td.keys())
+                    or (_append_last(key, "_count") not in shared_td.keys())
                 ):
                     raise KeyError(
                         f"key {key} not present in the shared tensordict "
@@ -4868,16 +4868,12 @@ def __init__(
         self.shapes = shapes
         self.eps = eps

-    def _key_str(self, key):
-        if not isinstance(key, str):
-            key = "_".join(key)
-        return key
-
     def _reset(
         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
     ) -> TensorDictBase:
+        # TODO: remove this decorator when trackers are in data
         with _set_missing_tolerance(self, True):
-            tensordict_reset = self._call(tensordict_reset)
+            return self._call(tensordict_reset)
         return tensordict_reset

     def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -4886,6 +4882,9 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:

         for key in self.in_keys:
             if key not in tensordict.keys(include_nested=True):
+                # TODO: init missing rewards with this
+                # for key_suffix in [_append_last(key, suffix) for suffix in ("_sum", "_ssq", "_count")]:
+                #     tensordict.set(key_suffix, self.container.observation_spec[key_suffix].zero())
                 continue
             self._init(tensordict, key)
         # update and standardize
@@ -4903,18 +4902,17 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
     forward = _call

     def _init(self, tensordict: TensorDictBase, key: str) -> None:
-        key_str = self._key_str(key)
-        if self._td is None or key_str + "_sum" not in self._td.keys():
-            if key is not key_str and key_str in tensordict.keys():
+        if self._td is None or _append_last(key, "_sum") not in self._td.keys(True):
+            if key is not key and key in tensordict.keys():
                 raise RuntimeError(
-                    f"Conflicting key names: {key_str} from VecNorm and input tensordict keys."
+                    f"Conflicting key names: {key} from VecNorm and input tensordict keys."
                 )
             if self.shapes is None:
                 td_view = tensordict.view(-1)
                 td_select = td_view[0]
                 item = td_select.get(key)
-                d = {key_str + "_sum": torch.zeros_like(item)}
-                d.update({key_str + "_ssq": torch.zeros_like(item)})
+                d = {_append_last(key, "_sum"): torch.zeros_like(item)}
+                d.update({_append_last(key, "_ssq"): torch.zeros_like(item)})
             else:
                 idx = 0
                 for in_key in self.in_keys:
@@ -4925,22 +4923,23 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None:
                 shape = self.shapes[idx]
                 item = tensordict.get(key)
                 d = {
-                    key_str
-                    + "_sum": torch.zeros(shape, device=item.device, dtype=item.dtype)
+                    _append_last(key, "_sum"): torch.zeros(
+                        shape, device=item.device, dtype=item.dtype
+                    )
                 }
                 d.update(
                     {
-                        key_str
-                        + "_ssq": torch.zeros(
+                        _append_last(key, "_ssq"): torch.zeros(
                             shape, device=item.device, dtype=item.dtype
                         )
                     }
                 )

             d.update(
                 {
-                    key_str
-                    + "_count": torch.zeros(1, device=item.device, dtype=torch.float)
+                    _append_last(key, "_count"): torch.zeros(
+                        1, device=item.device, dtype=torch.float
+                    )
                 }
             )
             if self._td is None:
@@ -4951,34 +4950,32 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None:
            pass

     def _update(self, key, value, N) -> torch.Tensor:
-        key = self._key_str(key)
-        _sum = self._td.get(key + "_sum")
-        _ssq = self._td.get(key + "_ssq")
-        _count = self._td.get(key + "_count")
+        _sum = self._td.get(_append_last(key, "_sum"))
+        _ssq = self._td.get(_append_last(key, "_ssq"))
+        _count = self._td.get(_append_last(key, "_count"))

-        _sum = self._td.get(key + "_sum")
         value_sum = _sum_left(value, _sum)
         _sum *= self.decay
         _sum += value_sum
         self._td.set_(
-            key + "_sum",
+            _append_last(key, "_sum"),
             _sum,
         )

-        _ssq = self._td.get(key + "_ssq")
+        _ssq = self._td.get(_append_last(key, "_ssq"))
         value_ssq = _sum_left(value.pow(2), _ssq)
         _ssq *= self.decay
         _ssq += value_ssq
         self._td.set_(
-            key + "_ssq",
+            _append_last(key, "_ssq"),
             _ssq,
         )

-        _count = self._td.get(key + "_count")
+        _count = self._td.get(_append_last(key, "_count"))
         _count *= self.decay
         _count += N
         self._td.set_(
-            key + "_count",
+            _append_last(key, "_count"),
             _count,
         )

@@ -4990,9 +4987,9 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
         """Converts VecNorm into an ObservationNorm class that can be used at inference time."""
         out = []
         for key in self.in_keys:
-            _sum = self._td.get(key + "_sum")
-            _ssq = self._td.get(key + "_ssq")
-            _count = self._td.get(key + "_count")
+            _sum = self._td.get(_append_last(key, "_sum"))
+            _ssq = self._td.get(_append_last(key, "_ssq"))
+            _count = self._td.get(_append_last(key, "_count"))
             mean = _sum / _count
             std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt()

@@ -5056,17 +5053,17 @@ def build_td_for_shared_vecnorm(
             )
         keys = list(td_select.keys())
         for key in keys:
-            td_select.set(key + "_ssq", td_select.get(key).clone())
+            td_select.set(_append_last(key, "_ssq"), td_select.get(key).clone())
             td_select.set(
-                key + "_count",
+                _append_last(key, "_count"),
                 torch.zeros(
                     *td.batch_size,
                     1,
                     device=td_select.device,
                     dtype=torch.float,
                 ),
             )
-            td_select.rename_key_(key, key + "_sum")
+            td_select.rename_key_(key, _append_last(key, "_sum"))
         td_select.exclude(*keys).zero_()
         td_select = td_select.unflatten_keys(sep)
         if memmap:
@@ -5112,6 +5109,32 @@ def __setstate__(self, state: Dict[str, Any]):
             state["lock"] = _lock
         self.__dict__.update(state)

+    @_apply_to_composite
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        if isinstance(observation_spec, BoundedTensorSpec):
+            return UnboundedContinuousTensorSpec(
+                shape=observation_spec.shape,
+                dtype=observation_spec.dtype,
+                device=observation_spec.device,
+            )
+        return observation_spec
+
+    # TODO: incorporate this when trackers are part of the data
+    # def transform_output_spec(self, output_spec: TensorSpec) -> TensorSpec:
+    #     observation_spec = output_spec["full_observation_spec"]
+    #     reward_spec = output_spec["full_reward_spec"]
+    #     for key in list(observation_spec.keys(True, True)):
+    #         if key in self.in_keys:
+    #             observation_spec[_append_last(key, "_sum")] = observation_spec[key].clone()
+    #             observation_spec[_append_last(key, "_ssq")] = observation_spec[key].clone()
+    #             observation_spec[_append_last(key, "_count")] = observation_spec[key].clone()
+    #     for key in list(reward_spec.keys(True, True)):
+    #         if key in self.in_keys:
+    #             observation_spec[_append_last(key, "_sum")] = reward_spec[key].clone()
+    #             observation_spec[_append_last(key, "_ssq")] = reward_spec[key].clone()
+    #             observation_spec[_append_last(key, "_count")] = reward_spec[key].clone()
+    #     return output_spec
+
+
 class RewardSum(Transform):
     """Tracks episode cumulative rewards.
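Taken together, these changes let VecNorm keep its running statistics under nested tensordict keys instead of flat, underscore-joined strings. A minimal end-to-end sketch of the pattern the new tests exercise (assuming a Gym backend is installed; the ("some", "obs") key is purely illustrative):

from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import Compose, RenameTransform, VecNorm

env = TransformedEnv(
    GymEnv("Pendulum-v1"),
    Compose(
        RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")]),
        VecNorm(decay=1.0, in_keys=[("some", "obs"), "reward"]),
    ),
)
env.rollout(10)
# The trackers now live under nested keys in the transform's tensordict,
# e.g. ("some", "obs_sum"), ("some", "obs_ssq") and ("some", "obs_count"):
print(env.transform[1]._td)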