Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions torchrl/trainers/helpers/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
from torchrl.collectors.collectors import (
_DataCollector,
_MultiDataCollector,
SyncDataCollector,
MultiaSyncDataCollector,
MultiSyncDataCollector,
)
from torchrl.data import MultiStep
from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.envs import ParallelEnv
from torchrl.envs.common import EnvBase
from torchrl.modules import TensorDictModuleWrapper, ProbabilisticTensorDictModule

__all__ = [
"sync_sync_collector",
Expand All @@ -23,9 +26,6 @@
"make_collector_onpolicy",
]

from torchrl.envs.common import EnvBase
from torchrl.modules import TensorDictModuleWrapper, ProbabilisticTensorDictModule


def sync_async_collector(
env_fns: Union[Callable, List[Callable]],
Expand Down Expand Up @@ -95,7 +95,7 @@ def sync_sync_collector(
num_env_per_collector: Optional[int] = None,
num_collectors: Optional[int] = None,
**kwargs,
) -> MultiSyncDataCollector:
) -> Union[SyncDataCollector, MultiSyncDataCollector]:
"""
Runs synchronous collectors, each running synchronous environments.

Expand Down Expand Up @@ -149,6 +149,19 @@ def sync_sync_collector(
**kwargs: Other kwargs passed to the data collectors

"""
if num_collectors == 1:
if "devices" in kwargs:
kwargs["device"] = kwargs.pop("devices")
if "passing_devices" in kwargs:
kwargs["passing_device"] = kwargs.pop("passing_devices")
return _make_collector(
SyncDataCollector,
env_fns=env_fns,
env_kwargs=env_kwargs,
num_env_per_collector=num_env_per_collector,
num_collectors=num_collectors,
**kwargs,
)
return _make_collector(
MultiSyncDataCollector,
env_fns=env_fns,
Expand Down Expand Up @@ -224,6 +237,13 @@ def _make_collector(
for _env_fn, _env_kwargs in zip(env_fns_split, env_kwargs_split)
]
env_kwargs = None
if collector_class is SyncDataCollector:
if len(env_fns) > 1:
raise RuntimeError(
f"Something went wrong: expected a single env constructor but got {len(env_fns)}"
)
env_fns = env_fns[0]
env_kwargs = env_kwargs[0]
return collector_class(
create_env_fn=env_fns,
create_env_kwargs=env_kwargs,
Expand Down