Describe the bug
If a replay_buffer is passed to the collector (so that .start() can be used) together with a ParallelEnv instance, collection doesn't work (is flattening needed?): the collector freezes on the first collected batch, during the .extend call. It works if extend is called manually outside of the collector, without passing the replay buffer to it.
To Reproduce
Steps to reproduce the behavior.
```python
import torch, time
import torch.nn as nn
from gymnasium.envs.classic_control.pendulum import PendulumEnv
from torchrl.envs import EnvCreator, ParallelEnv, GymWrapper, Transform, TransformedEnv, Compose, DTypeCastTransform, RewardScaling, RewardSum, StepCounter
from torchrl.collectors import aSyncDataCollector, MultiSyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, RandomSampler
from tensordict.nn import TensorDictModule

# Environment factory
def create_env(render=None):
    env = PendulumEnv("human")
    env = GymWrapper(env)
    env = TransformedEnv(
        env,
        Compose(
            DTypeCastTransform(torch.float64, torch.float32),
            RewardScaling(-8., 8., "reward", "reward", True),
            RewardSum(in_keys=[("reward",)], out_keys=[("episode_reward",)]),
            StepCounter(256),
        ),
    )
    return env

create_env_new = EnvCreator(create_env)

def parallel_env():
    return ParallelEnv(4, create_env_new)

if __name__ == "__main__":
    policy_net = nn.Linear(3, 1)
    policy = TensorDictModule(policy_net, in_keys=["observation"], out_keys=["action"]).to("cuda")
    replay_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(2e6, ndim=1),  # tried with ndim=1, ndim=2 and ndim=3
        sampler=RandomSampler(),
        batch_size=128,
    )
    # Create async data collector
    collector = aSyncDataCollector(  # tried both aSyncDataCollector and MultiSyncDataCollector
        parallel_env,  # works with the create_env function but not with this one
        policy,
        num_workers=1,
        frames_per_batch=64,
        total_frames=-1,
        extend_buffer=True,
        replay_buffer=replay_buffer,
        device=torch.device("cpu"),
        storing_device=torch.device("cpu"),
        env_device=torch.device("cpu"),
        policy_device=torch.device("cuda"),
    )

    # Doesn't work: freezes on the first collected batch
    collector.start()
    while True:
        print(len(replay_buffer))
        time.sleep(2.)

    # Iterating also doesn't work if replay_buffer is given to the collector.
    # It works if no replay buffer is given to the collector and extend is called here instead:
    # for batch in collector:
    #     print(batch.shape)
    #     replay_buffer.extend(batch)
    #     print(len(replay_buffer))
    #     time.sleep(2.)
```

The error is either

```
RuntimeError: indexed destination TensorDict batch size is torch.Size([4, 4]) (batch_size = torch.Size([2000000, 4]), index=tensor([0, 1, 2, 3])), which differs from the source batch size torch.Size([4, 16])
```

or

```
RuntimeError: expand_as_right requires the destination tensor to have less dimensions than the input tensor, got tensor.ndimension()=2 and dest.ndimension()=1
```

Expected behavior
The same behavior as when the replay buffer isn't passed to the collector and extend is called manually.
Additional context
Found this in the docs for single-node data collectors:

> Using replay buffers that sample trajectories with MultiSyncDataCollector isn't currently fully supported as the data batches can come from any worker and in most cases consecutive batches written in the buffer won't come from the same source (thereby interrupting the trajectories).

But I guess this only applies to MultiaSyncDataCollector and there is a typo in the doc? It also doesn't explain why SyncDataCollector and aSyncDataCollector (without Multi) wouldn't work. I may be wrong and have misunderstood.
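If flattening is indeed what is missing, below is an untested sketch (my assumption, not something from the docs) of the kind of reshape I mean, written as a manual loop with a collector built from the ParallelEnv factory but without passing it the replay_buffer:

```python
# Untested sketch: flatten the extra ParallelEnv batch dimension before
# extending a 1D storage. With 4 envs and frames_per_batch=64, the collected
# batch has batch_size [4, 16]; reshape(-1) turns it into [64].
for batch in collector:
    flat = batch.reshape(-1)
    replay_buffer.extend(flat)
    print(len(replay_buffer))
```

If that is the right fix, maybe the collector could do this reshape itself when extend_buffer=True and the env is batched.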
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)