
[BUG] Resetting environments with non-empty batch_size #790

@matteobettini

Description

Describe the bug

When resetting an environment with a non-empty batch size, there is currently no unified way to specify which dimensions to reset.

ParallelEnv has a reset key called "reset_workers", used to choose which workers to reset. Because this key is one-dimensional, ParallelEnv crashes whenever env.batch_size is not empty.

This is what happens:

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
    cmd_out = "reset"
    if tensordict is not None and "reset_workers" in tensordict.keys():
        self._assert_tensordict_shape(tensordict)  # asserts the key has the same batch size as the env, say [10, 4, 5, 2]
        reset_workers = tensordict.get("reset_workers")
    else:
        reset_workers = torch.ones(self.num_workers, dtype=torch.bool)  # otherwise, build one that ignores the batch size

    for i, channel in enumerate(self.parent_channels):
        if not reset_workers[i]:  # indexes dimension 0 only: works only if the else branch was taken or env.batch_size is empty
            continue
        channel.send((cmd_out, kwargs))  # the tensordict is never passed to the env

The problems in this snippet alone are:

  1. Indexing reset_workers[i] along dimension 0 only works if the else branch was taken or env.batch_size is empty (see the sketch below)
  2. The reset tensordict is never passed to the env
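To make problem 1 concrete, here is a minimal standalone sketch (shapes invented for the example) of what happens when the if branch is taken and env.batch_size is not empty:

import torch

num_workers, env_batch_size = 3, (4, 5)
# reset_workers then has shape (num_workers, *env.batch_size)
reset_workers = torch.ones(num_workers, *env_batch_size, dtype=torch.bool)

try:
    if not reset_workers[0]:  # reset_workers[0] is a (4, 5) tensor, not a scalar
        pass
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one element is ambiguous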

This is compounded by a series of problems in the various reset functions (ParallelEnv and EnvBase), where done.any() is called after a reset. This is highly problematic, as any() spans all dimensions, so a single done entry anywhere in the batch trips the check even when that entry was never asked to reset.
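For illustration (shapes again invented), compare any() over the full batch with a check restricted to the dimensions that were actually asked to reset:

import torch

# 2 workers, each with env batch_size [3]; only worker 0 was asked to reset.
done = torch.tensor([[False, False, False],   # worker 0: freshly reset
                     [True,  False, False]])  # worker 1: untouched, legitimately done
reset = torch.tensor([[True,  True,  True],
                      [False, False, False]])

print(done.any())         # tensor(True)  -> flags a failure that is not one
print(done[reset].any())  # tensor(False) -> checks only the requested dimensions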

Proposed changes to the API

  1. Remove "reset_workers" (which already doesn't work)
  2. Introduce the possibility of having a "reset" key in the tensordict given as parameter to the reset functions. In the most general case (ParallelEnv), this boolean key has shape (n_parallel_envs, *env.batch_size) and states precisely which dimensions to reset. ParallelEnv can then call reset[worker_id].any() to decide whether to pass the reset command and key to that worker (a sketch follows this list)
  3. In the reset functions, do not check done.any(); instead, check that at least the requested dimensions have been reset (done=False). The absence of the "reset" key means resetting all dimensions
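A minimal sketch of how points 2 and 3 could look inside ParallelEnv._reset (hypothetical: it assumes self.batch_size is (num_workers, *env.batch_size) and that workers accept an indexed tensordict alongside the command; not a final implementation):

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
    cmd_out = "reset"
    if tensordict is not None and "reset" in tensordict.keys():
        self._assert_tensordict_shape(tensordict)
        reset = tensordict.get("reset")  # bool, shape (num_workers, *env.batch_size)
    else:
        # Absence of the "reset" key means resetting all dimensions.
        reset = torch.ones(self.batch_size, dtype=torch.bool)

    for i, channel in enumerate(self.parent_channels):
        if not reset[i].any():
            continue  # nothing to reset in this worker's slice
        # Forward the worker's slice so it knows which of its own dimensions to reset.
        worker_td = tensordict[i] if tensordict is not None else None
        channel.send((cmd_out, (worker_td, kwargs)))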

To Reproduce

import torch
from tensordict import TensorDict  # import path may vary across torchrl versions
from torchrl.envs import ParallelEnv

# MockBatchedLockedEnv comes from torchrl's test suite (test/mocking_classes.py)
from mocking_classes import MockBatchedLockedEnv

env_batch_size = [2, 3]  # any non-empty batch size triggers the bug
num_parallel_env = 2

env = MockBatchedLockedEnv(device="cpu", batch_size=torch.Size(env_batch_size))
env.set_seed(1)
parallel_env = ParallelEnv(num_parallel_env, lambda: env)
parallel_env.start()
reset_td = TensorDict(
    {"reset_workers": torch.full(parallel_env.batch_size, True, device=parallel_env.device)},
    batch_size=parallel_env.batch_size,
    device=parallel_env.device,
)
parallel_env.reset(reset_td)  # crashes: reset_workers[i] is multi-dimensional
