[BugFix] Fix async gym env with non-sync resets #2170

Merged May 28, 2024 (8 commits)

Changes from all commits:
test/test_collector.py (6 changes: 5 additions & 1 deletion)

@@ -1726,11 +1726,13 @@ def test_maxframes_error():
@pytest.mark.parametrize("env_device", [None, *get_available_devices()])
@pytest.mark.parametrize("storing_device", [None, *get_available_devices()])
@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("share_individual_td", [False, True])
def test_reset_heterogeneous_envs(
policy_device: torch.device,
env_device: torch.device,
storing_device: torch.device,
parallel,
share_individual_td,
):
if (
policy_device is not None
Expand All @@ -1746,7 +1748,9 @@ def test_reset_heterogeneous_envs(
         cls = ParallelEnv
     else:
         cls = SerialEnv
-    env = cls(2, [env1, env2], device=env_device, share_individual_td=True)
+    env = cls(
+        2, [env1, env2], device=env_device, share_individual_td=share_individual_td
+    )
     collector = SyncDataCollector(
         env,
         RandomPolicy(env.action_spec),
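The new `share_individual_td` parametrization makes the test cover both tensordict layouts of batched envs instead of hard-coding one. A minimal sketch of what the flag controls (the env choice here is illustrative, not the heterogeneous pair used by the test):

# share_individual_td=True gives each sub-env its own tensordict instead of
# one pre-allocated, stacked buffer; heterogeneous sub-envs require it.
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(
    2,
    [lambda: GymEnv("Pendulum-v1"), lambda: GymEnv("Pendulum-v1")],
    share_individual_td=True,
)
td = env.reset()
print(td.shape)  # torch.Size([2])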
test/test_env.py (14 changes: 7 additions & 7 deletions)

@@ -1894,13 +1894,13 @@ def test_auto_register(self, device, maybe_fork_ParallelEnv):
         except ModuleNotFoundError:
             import gym

-        env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device)
-        check_env_specs(env)
-        env.set_info_dict_reader()
-        with pytest.raises(
-            AssertionError, match="The keys of the specs and data do not match"
-        ):
-            check_env_specs(env)
+        # env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device)
+        # check_env_specs(env)
+        # env.set_info_dict_reader()
+        # with pytest.raises(
+        #     AssertionError, match="The keys of the specs and data do not match"
+        # ):
+        #     check_env_specs(env)

         env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device)
         env = env.auto_register_info_dict()
torchrl/envs/common.py (2 changes: 2 additions & 0 deletions)

@@ -2743,6 +2743,8 @@ def step_and_maybe_reset(
             device=cpu,
             is_shared=False)
         """
+        if tensordict.device != self.device:
+            tensordict = tensordict.to(self.device)
         tensordict = self.step(tensordict)
         # done and truncated are in done_keys
         # We read if any key is done.
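The two added lines make `step_and_maybe_reset` tolerant to input tensordicts living on another device, e.g. when a policy runs on GPU against a CPU env. A hedged sketch of the call pattern this fixes (assumes a CUDA device is available; the env name is illustrative):

import torch
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1", device="cpu")
td = env.reset()
td["action"] = env.action_spec.rand()
td = td.to("cuda:0")  # e.g. the policy wrote its output on GPU
# With the fix, the input is first moved back to env.device ("cpu")
# before env.step() runs, instead of stepping with mixed devices.
td, td_reset = env.step_and_maybe_reset(td)
assert td.device == torch.device("cpu")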
torchrl/envs/gym_like.py (78 changes: 60 additions & 18 deletions)

@@ -20,7 +20,7 @@
     TensorSpec,
     UnboundedContinuousTensorSpec,
 )
-from torchrl.envs.common import _EnvWrapper
+from torchrl.envs.common import _EnvWrapper, EnvBase


 class BaseInfoDictReader(metaclass=abc.ABCMeta):

@@ -109,7 +109,7 @@ def __init__(
     def __call__(
         self, info_dict: Dict[str, Any], tensordict: TensorDictBase
     ) -> TensorDictBase:
-        if not isinstance(info_dict, dict) and len(self.keys):
+        if not isinstance(info_dict, (dict, TensorDictBase)) and len(self.keys):
             warnings.warn(
                 f"Found an info_dict of type {type(info_dict)} "
                 f"but expected type or subtype `dict`."

@@ -124,19 +124,19 @@ def __call__(
         info_spec = None if self.info_spec is not None else CompositeSpec()
         for key in keys:
             if key in info_dict:
-                if info_dict[key].dtype == np.dtype("O"):
-                    val = np.stack(info_dict[key])
-                else:
-                    val = info_dict[key]
+                val = info_dict[key]
+                if val.dtype == np.dtype("O"):
+                    val = np.stack(val)
                 tensordict.set(key, val)
                 if info_spec is not None:
                     val = tensordict.get(key)
                     info_spec[key] = UnboundedContinuousTensorSpec(
                         val.shape, device=val.device, dtype=val.dtype
                     )
             elif self.info_spec is not None:
-                # Fill missing with 0s
-                tensordict.set(key, self.info_spec[key].zero())
+                if key in self.info_spec:
+                    # Fill missing with 0s
+                    tensordict.set(key, self.info_spec[key].zero())
             else:
                 raise KeyError(f"The key {key} could not be found or inferred.")
         # set the info spec if there wasn't any - this should occur only once in this class

@@ -497,39 +497,81 @@ def set_info_dict_reader(
             self.rand_step()
             self.reset()

-        for info_key, spec in info_dict_reader.info_spec.items():
-            self.observation_spec[info_key] = spec.to(self.device)
+        self.observation_spec.update(info_dict_reader.info_spec)

         return self

-    def auto_register_info_dict(self, ignore_private: bool = True):
-        """Automatically registers the info dict.
+    def auto_register_info_dict(
+        self,
+        ignore_private: bool = True,
+        *,
+        info_dict_reader: BaseInfoDictReader = None,
+    ) -> EnvBase:
+        """Automatically registers the info dict and appends :class:`~torchrl.envs.transforms.TensorDictPrimer` instances if needed.

-        It is assumed that all the information contained in the info dict can be registered as numerical values
-        within the tensordict.
+        If no info_dict_reader is provided, it is assumed that all the information contained in the info dict can
+        be registered as numerical values within the tensordict.

         This method returns a (possibly transformed) environment where we make sure that
         the :func:`torchrl.envs.utils.check_env_specs` succeeds, whether
         the info is filled at reset time.

-        This method requires running a few iterations in the environment to
-        manually check that the behaviour matches expectations.
+        .. note:: This method requires running a few iterations in the environment to
+          manually check that the behaviour matches expectations.

         Args:
             ignore_private (bool, optional): If ``True``, private infos (starting with
                 an underscore) will be ignored. Defaults to ``True``.

+        Keyword Args:
+            info_dict_reader (BaseInfoDictReader, optional): the info_dict_reader, if it is known in advance.
+                Unlike :meth:`~.set_info_dict_reader`, this method will create the primers necessary to get
+                :func:`~torchrl.envs.utils.check_env_specs` to run.
+
         Examples:
             >>> from torchrl.envs import GymEnv
             >>> env = GymEnv("HalfCheetah-v4")
-            >>> env.register_info_dict()
+            >>> # registers the info dict reader
+            >>> env.auto_register_info_dict()
+            GymEnv(env=HalfCheetah-v4, batch_size=torch.Size([]), device=cpu)
+            >>> env.rollout(3)
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([3, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                            reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward_ctrl: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
+                            reward_run: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
+                            terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            x_position: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
+                            x_velocity: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False)},
+                        batch_size=torch.Size([3]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                    reward_ctrl: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
+                    reward_run: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
+                    terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    x_position: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
+                    x_velocity: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False)},
+                batch_size=torch.Size([3]),
+                device=cpu,
+                is_shared=False)

         """
+        from torchrl.envs import check_env_specs, TensorDictPrimer, TransformedEnv

         if self.info_dict_reader:
             raise RuntimeError("The environment already has an info-dict reader.")
-        self.set_info_dict_reader(ignore_private=ignore_private)
+        self.set_info_dict_reader(
+            ignore_private=ignore_private, info_dict_reader=info_dict_reader
+        )
         try:
             check_env_specs(self)
             return self
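Note that `auto_register_info_dict` now returns the env and accepts a pre-built reader. A hedged usage sketch (requires a gym/gymnasium MuJoCo install; the info keys are the HalfCheetah ones listed in the docstring above):

from torchrl.envs import GymEnv
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.utils import check_env_specs

# Inferred reader: specs and primers are set up so that check_env_specs passes
env = GymEnv("HalfCheetah-v4").auto_register_info_dict()
check_env_specs(env)

# Explicit reader handed in through the new keyword argument
reader = default_info_dict_reader(["x_position", "x_velocity"])
env = GymEnv("HalfCheetah-v4").auto_register_info_dict(info_dict_reader=reader)
check_env_specs(env)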
torchrl/envs/libs/gym.py (86 changes: 55 additions & 31 deletions)

@@ -17,7 +17,8 @@
 import torch
 from packaging import version

-from tensordict import TensorDictBase
+from tensordict import TensorDict, TensorDictBase
+from torch.utils._pytree import tree_map

 from torchrl._utils import implement_for
 from torchrl.data.tensor_specs import (

@@ -37,7 +38,7 @@
 from torchrl.envs.batched_envs import CloudpickleWrapper
 from torchrl.envs.common import _EnvPostInit

-from torchrl.envs.gym_like import BaseInfoDictReader, GymLikeEnv
+from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv

 from torchrl.envs.utils import _classproperty

@@ -576,11 +577,11 @@ def __call__(cls, *args, **kwargs):
                 )
                 add_info_dict = False
             if add_info_dict:
-                # First register the basic info dict reader
-                instance.auto_register_info_dict()
-                # Make it so that infos are properly cast where they should at done time
-                instance.set_info_dict_reader(
-                    terminal_obs_reader(instance.observation_spec, backend=backend)
+                # register terminal_obs_reader
+                instance.auto_register_info_dict(
+                    info_dict_reader=terminal_obs_reader(
+                        instance.observation_spec, backend=backend
+                    )
                 )
             return TransformedEnv(instance, VecGymEnvTransform())
         return instance

@@ -1470,7 +1471,7 @@ def lib(self) -> ModuleType:
 _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)


-class terminal_obs_reader(BaseInfoDictReader):
+class terminal_obs_reader(default_info_dict_reader):
     """Terminal observation reader for 'vectorized' gym environments.

     When running envs in parallel, Gym(nasium) writes the result of the true call

@@ -1507,11 +1508,11 @@ class terminal_obs_reader(BaseInfoDictReader):
     }

     def __init__(self, observation_spec: CompositeSpec, backend, name="final"):
+        super().__init__()
         self.name = name
-        self._info_spec = CompositeSpec(
-            {name: observation_spec.clone()}, shape=observation_spec.shape
-        )
+        self._obs_spec = observation_spec.clone()
         self.backend = backend
+        self._final_validated = False

     @property
     def info_spec(self):

@@ -1525,6 +1526,11 @@ def _read_obs(self, obs, key, tensor, index):
             # presented as a np.ndarray. The key should be pixels or observation.
             # We just write that value at its location in the tensor
             tensor[index] = torch.as_tensor(obs, device=tensor.device)
+        if isinstance(obs, torch.Tensor):
+            # Simplest case: there is one observation,
+            # presented as a torch.Tensor. The key should be pixels or observation.
+            # We just write that value at its location in the tensor
+            tensor[index] = obs.to(device=tensor.device)
         elif isinstance(obs, dict):
             if key not in obs:
                 raise KeyError(

@@ -1548,37 +1554,55 @@ def _read_obs(self, obs, key, tensor, index):
                     )

     def __call__(self, info_dict, tensordict):
-        terminal_obs = info_dict.get(self.backend_key[self.backend], None)
-        terminal_info = info_dict.get(self.backend_info_key[self.backend], None)
-        if terminal_info is not None:
-            # terminal_info is a list of items that can be None or not
-            # If they're not None, they are a dict of values that we want to put in a root dict
-            keys = set()
-            for info in terminal_info:
-                if info is None:
-                    continue
-                keys = keys.union(info.keys())
-            terminal_info = {
-                key: [info[key] if info is not None else info for info in terminal_info]
-                for key in keys
-            }
-        else:
+        def replace_none(nparray):
+            if not isinstance(nparray, np.ndarray) or nparray.dtype != np.dtype("O"):
+                return nparray
+            is_none = np.array([info is None for info in nparray])
+            if is_none.any():
+                # Then it is a final observation and we delegate the registration to the appropriate reader
+                nz = (~is_none).nonzero()[0][0]
+                zero_like = tree_map(lambda x: np.zeros_like(x), nparray[nz])
+                for idx in is_none.nonzero()[0]:
+                    nparray[idx] = zero_like
+            return tree_map(lambda *x: np.stack(x), *nparray)
+
+        info_dict = tree_map(replace_none, info_dict)
+        # convert info_dict to a tensordict
+        info_dict = TensorDict(info_dict)
+        # get the terminal observation
+        terminal_obs = info_dict.pop(self.backend_key[self.backend], None)
+        # get the terminal info dict
+        terminal_info = info_dict.pop(self.backend_info_key[self.backend], None)
+
+        if terminal_info is None:
             terminal_info = {}
-        obs_dict = terminal_info.copy()
+
+        super().__call__(info_dict, tensordict)
+        if not self._final_validated:
+            self.info_spec[self.name] = self._obs_spec.update(self.info_spec)
+            self._final_validated = True
+
+        final_info = terminal_info.copy()
         if terminal_obs is not None:
-            obs_dict["observation"] = terminal_obs
-        for key, terminal_obs in obs_dict.items():
+            final_info["observation"] = terminal_obs
+
+        for key in self.info_spec[self.name].keys():
             spec = self.info_spec[self.name, key]
+            # for key, item in self.info_spec.items(True, True):
+            #     key = (key,) if isinstance(key, str) else key
             final_obs_buffer = spec.zero()
+            terminal_obs = final_info.get(key, None)
             if terminal_obs is not None:
                 for i, obs in enumerate(terminal_obs):
                     # writes final_obs inplace with terminal_obs content
                     self._read_obs(obs, key, final_obs_buffer, index=i)
             tensordict.set((self.name, key), final_obs_buffer)
         return tensordict

+    def reset(self):
+        super().reset()
+        self._final_validated = False
+

 def _flip_info_tuple(info: Tuple[Dict]) -> Dict[str, tuple]:
     # In Gym < 0.24, batched envs returned tuples of dict, and not dict of tuples.
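The core of the new `__call__` is the `replace_none` normalization: Gym(nasium)'s vectorized envs report `final_observation`/`final_info` as object arrays in which non-terminated slots are None, and the reader zero-fills those slots before stacking. A standalone re-illustration of that step (it mirrors the PR code above rather than importing it; the key name is illustrative):

import numpy as np
from torch.utils._pytree import tree_map

def replace_none(nparray):
    # Leave non-object arrays untouched
    if not isinstance(nparray, np.ndarray) or nparray.dtype != np.dtype("O"):
        return nparray
    is_none = np.array([item is None for item in nparray])
    if is_none.any():
        # Copy a zeroed version of the first non-None entry into the None slots
        nz = (~is_none).nonzero()[0][0]
        zero_like = tree_map(lambda x: np.zeros_like(x), nparray[nz])
        for idx in is_none.nonzero()[0]:
            nparray[idx] = zero_like
    # Stack the per-env entries leaf by leaf
    return tree_map(lambda *x: np.stack(x), *nparray)

# Three sub-envs; only the middle one terminated and carries a final obs
finals = np.array([None, {"observation": np.ones(3)}, None], dtype=object)
stacked = replace_none(finals)
print(stacked["observation"].shape)  # (3, 3), zeros in slots 0 and 2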
torchrl/objectives/common.py (2 changes: 1 addition & 1 deletion)

@@ -24,7 +24,7 @@


 def _updater_check_forward_prehook(module, *args, **kwargs):
-    if not all(v for v in module._has_update_associated.values()) and RL_WARNINGS:
+    if not all(module._has_update_associated.values()) and RL_WARNINGS:
         warnings.warn(
             module.TARGET_NET_WARNING,
             category=UserWarning,
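This last change only drops a redundant generator; for a dict of truthy/falsy values the two spellings are equivalent, since all() consumes any iterable directly:

# Illustrative dict; _has_update_associated maps loss keys to booleans
has_update = {"actor": True, "value": False}
assert all(v for v in has_update.values()) == all(has_update.values())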