Describe the bug
When resetting an environment with non-empty batch size, there is currently no unified way to specify which dimensions to reset.
ParallelEnv has a reset key called "reset_workers" which is used to choose which workers to reset. Because this key is one-dimensional (one entry per worker), ParallelEnv crashes when env.batch_size is not empty.
This is what happens:
```python
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
    cmd_out = "reset"
    if tensordict is not None and "reset_workers" in tensordict.keys():
        # First assert that the key has the same batch size as the env, let's say [10, 4, 5, 2]
        self._assert_tensordict_shape(tensordict)
        reset_workers = tensordict.get("reset_workers")
    else:
        # If absent, create one (without respecting the batch size)
        reset_workers = torch.ones(self.num_workers, dtype=torch.bool)
    for i, channel in enumerate(self.parent_channels):
        # Indexed on dimension 0 !! This only works if the else branch
        # was taken or env.batch_size is empty
        if not reset_workers[i]:
            continue
        # Does not even pass the tensordict to the env
        channel.send((cmd_out, kwargs))
```
The problems in just this snippet are:
- Indexing reset_workers[i] only works if the else branch is taken or env.batch_size is empty.
- It does not pass the reset tensordict to the env.
This is paired with a series of problems in the various reset functions (ParallelEnv and EnvBase) where done.any() is called after the reset. This is highly problematic because any() reduces over all dimensions.
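A minimal sketch (with hypothetical shapes) of why the global reduction is wrong: a stale done flag in a dimension that was not asked to be reset trips the whole-batch check, while a check restricted to the requested dimensions passes.

```python
import torch

batch_size = (10, 4, 5, 2)      # hypothetical env.batch_size
done = torch.zeros(batch_size, dtype=torch.bool)
done[3, 0, 0, 0] = True         # a dim we did NOT ask to reset is still done

reset = torch.zeros(batch_size, dtype=torch.bool)
reset[0] = True                 # we only asked to reset the entries at index 0

print(done.any())               # True: the global check trips on the stale flag
print(done[reset].any())        # False: the requested dims were properly reset
```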
Proposed changes to the API
- Remove "reset_workers" (which already doesn't work)
- Introduce the possibility of having a "reset" key in the tensordict given as parameter to the reset functions. This reset key has shape (in the more general case of
ParallelEnv)(n_parallel_envs, *env.batch_size)and is a boolean telling precisely which dimensions to reset. ParallelEnv then can callreset[worker_id].any()to know if to pass the reset command and key to the worker - In the reset function do not check
done.any()but chacke that at least the requested dimensions to reset have done=False. The absence of the "reset" key means resetting all dimensions
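A minimal sketch of how ParallelEnv._reset could consume the proposed mask, reusing the attribute names from the snippet above (parent_channels, num_workers); this is illustrative, not an implementation:

```python
import torch
from tensordict import TensorDictBase


def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
    cmd_out = "reset"
    if tensordict is not None and "reset" in tensordict.keys():
        self._assert_tensordict_shape(tensordict)
        # Boolean mask of shape (num_workers, *env.batch_size)
        reset = tensordict.get("reset")
    else:
        # Absence of the "reset" key means resetting all dimensions
        reset = torch.ones(self.batch_size, dtype=torch.bool)
    for i, channel in enumerate(self.parent_channels):
        if not reset[i].any():
            # No dimension of this worker was requested: skip it
            continue
        # Forward both the command and the per-worker slice of the mask
        channel.send((cmd_out, {"reset": reset[i], **kwargs}))
    ...
```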
To Reproduce
```python
import torch
from tensordict import TensorDict
from torchrl.envs import ParallelEnv
# MockBatchedLockedEnv comes from torchrl's test utilities (mocking_classes)
from mocking_classes import MockBatchedLockedEnv

env_batch_size = [10, 4, 5, 2]  # any non-empty batch size triggers the bug
num_parallel_env = 2

env = MockBatchedLockedEnv(device="cpu", batch_size=torch.Size(env_batch_size))
env.set_seed(1)
parallel_env = ParallelEnv(num_parallel_env, lambda: env)
parallel_env.start()
reset_td = TensorDict(
    {"reset_workers": torch.full(parallel_env.batch_size, True, device=parallel_env.device)},
    batch_size=parallel_env.batch_size,
    device=parallel_env.device,
)
parallel_env.reset(reset_td)  # crashes: reset_workers is indexed on dim 0 only
```
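For contrast, a hypothetical invocation under the proposed API (assuming the "reset" key is adopted as described above), resetting only the first worker:

```python
# Hypothetical: boolean mask over (num_parallel_env, *env_batch_size)
reset_mask = torch.zeros(num_parallel_env, *env_batch_size, dtype=torch.bool)
reset_mask[0] = True  # reset only the first worker, across all its batch dims
reset_td = TensorDict(
    {"reset": reset_mask},
    batch_size=parallel_env.batch_size,
    device=parallel_env.device,
)
parallel_env.reset(reset_td)
```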