Commit 7874e74

[BugFix] Consistent postproc when using a RB in a collector (#3144)
1 parent f596e53 commit 7874e74

File tree: 3 files changed (+149, −5 lines)

test/test_collector.py

Lines changed: 132 additions & 0 deletions
@@ -3575,6 +3575,138 @@ def test_collector_rb_multiasync(
         ).all(), steps_counts
         assert (idsdiff >= 0).all()
 
+    @staticmethod
+    def _zero_postproc(td):
+        # Apply zero to all tensor values in the tensordict
+        return torch.zeros_like(td)
+
+    @pytest.mark.parametrize(
+        "collector_class",
+        [
+            SyncDataCollector,
+            functools.partial(MultiSyncDataCollector, cat_results="stack"),
+            MultiaSyncDataCollector,
+        ],
+    )
+    @pytest.mark.parametrize("use_replay_buffer", [True, False])
+    @pytest.mark.parametrize("extend_buffer", [True, False])
+    def test_collector_postproc_zeros(
+        self, collector_class, use_replay_buffer, extend_buffer
+    ):
+        """Test that postproc functionality works correctly across all collector types.
+
+        This test verifies that:
+        1. Postproc is applied correctly when no replay buffer is used
+        2. Postproc is applied correctly when replay buffer is used with extend_buffer=True
+        3. Postproc is not applied when replay buffer is used with extend_buffer=False
+        4. The behavior is consistent across Sync, MultiaSync, and MultiSync collectors
+        """
+        # Create a simple dummy environment
+        def make_env():
+            env = DiscreteActionVecMockEnv()
+            env.set_seed(0)
+            return env
+
+        # Create a simple dummy policy
+        def make_policy(env):
+            return RandomPolicy(env.action_spec)
+
+        # Test parameters
+        total_frames = 64
+        frames_per_batch = 16
+
+        if use_replay_buffer:
+            # Create replay buffer
+            rb = ReplayBuffer(
+                storage=LazyTensorStorage(256), batch_size=5, compilable=False
+            )
+
+            # Test with replay buffer
+            if collector_class == SyncDataCollector:
+                collector = collector_class(
+                    make_env(),
+                    make_policy(make_env()),
+                    replay_buffer=rb,
+                    total_frames=total_frames,
+                    frames_per_batch=frames_per_batch,
+                    extend_buffer=extend_buffer,
+                    postproc=self._zero_postproc if extend_buffer else None,
+                )
+            else:
+                # MultiSync and MultiaSync collectors
+                collector = collector_class(
+                    [make_env, make_env],
+                    make_policy(make_env()),
+                    replay_buffer=rb,
+                    total_frames=total_frames,
+                    frames_per_batch=frames_per_batch,
+                    extend_buffer=extend_buffer,
+                    postproc=self._zero_postproc if extend_buffer else None,
+                )
+            try:
+                # Collect data
+                collected_frames = 0
+                for _ in collector:
+                    collected_frames += frames_per_batch
+                    if extend_buffer:
+                        # With extend_buffer=True, postproc should be applied
+                        # Check that the replay buffer contains zeros
+                        sample = rb.sample(5)
+                        torch.testing.assert_close(
+                            sample["observation"],
+                            torch.zeros_like(sample["observation"]),
+                        )
+                        torch.testing.assert_close(
+                            sample["action"], torch.zeros_like(sample["action"])
+                        )
+                        # Check next.reward instead of reward
+                        torch.testing.assert_close(
+                            sample["next", "reward"],
+                            torch.zeros_like(sample["next", "reward"]),
+                        )
+                    else:
+                        # With extend_buffer=False, postproc should not be applied
+                        # Check that the replay buffer contains non-zero values
+                        sample = rb.sample(5)
+                        assert torch.any(sample["observation"] != 0.0)
+                        assert torch.any(sample["action"] != 0.0)
+
+                    if collected_frames >= total_frames:
+                        break
+            finally:
+                collector.shutdown()
+
+        else:
+            # Test without replay buffer
+            if collector_class == SyncDataCollector:
+                collector = collector_class(
+                    make_env(),
+                    make_policy(make_env()),
+                    total_frames=total_frames,
+                    frames_per_batch=frames_per_batch,
+                    postproc=self._zero_postproc,
+                )
+            else:
+                # MultiSync and MultiaSync collectors
+                collector = collector_class(
+                    [make_env, make_env],
+                    make_policy(make_env()),
+                    total_frames=total_frames,
+                    frames_per_batch=frames_per_batch,
+                    postproc=self._zero_postproc,
+                )
+            try:
+                # Collect data and verify postproc is applied
+                for batch in collector:
+                    # All values should be zero due to postproc
+                    assert torch.all(batch["observation"] == 0.0)
+                    assert torch.all(batch["action"] == 0.0)
+                    # Check next.reward instead of reward
+                    assert torch.all(batch["next", "reward"] == 0.0)
+                    break  # Just check first batch
+            finally:
+                collector.shutdown()
+
 
 def __deepcopy_error__(*args, **kwargs):
     raise RuntimeError("deepcopy not allowed")
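The test above pins down the intended contract: when a replay buffer is attached to the collector, the postproc runs on each collected batch only with extend_buffer=True, i.e. before the batch is extended into the buffer. A minimal usage sketch of that recommended pattern follows; it is not part of the commit, assumes gymnasium is installed, and uses GymEnv("CartPole-v1") plus a zeroing postproc purely as placeholders.

import torch
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv

def zero_postproc(td):
    # Placeholder postproc: zero every tensor in the collected batch.
    return torch.zeros_like(td)

rb = ReplayBuffer(storage=LazyTensorStorage(1_000))
collector = SyncDataCollector(
    lambda: GymEnv("CartPole-v1"),
    policy=None,  # None falls back to a random policy over the env's action spec
    frames_per_batch=16,
    total_frames=64,
    replay_buffer=rb,
    extend_buffer=True,  # whole batches are post-processed, then written to rb
    postproc=zero_postproc,
)
for _ in collector:  # batches land directly in the replay buffer
    pass
collector.shutdown()

sample = rb.sample(8)
assert (sample["observation"] == 0).all()  # postproc ran before the write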

torchrl/collectors/collectors.py

Lines changed: 11 additions & 3 deletions
@@ -508,6 +508,10 @@ class SyncDataCollector(DataCollectorBase):
         postproc (Callable, optional): A post-processing transform, such as
             a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
             instance.
+
+            .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer
+                as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`.
+
             Defaults to ``None``.
         split_trajs (bool, optional): Boolean indicating whether the resulting
             TensorDict should be split according to the trajectories.
@@ -3021,7 +3025,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
             self._frames += n_collected
 
             if self.postprocs:
-                self.postprocs = self.postprocs.to(out.device)
+                self.postprocs = (
+                    self.postprocs.to(out.device)
+                    if hasattr(self.postprocs, "to")
+                    else self.postprocs
+                )
                 out = self.postprocs(out)
             if self._exclude_private_keys:
                 excluded_keys = [key for key in out.keys() if key.startswith("_")]
@@ -3144,7 +3152,7 @@ def __init__(self, *args, **kwargs):
         self.out_tensordicts = defaultdict(lambda: None)
         self.running = False
 
-        if self.postprocs is not None:
+        if self.postprocs is not None and self.replay_buffer is None:
             postproc = self.postprocs
             self.postprocs = {}
             for _device in self.storing_device:
@@ -3265,7 +3273,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
             worker_frames = self.frames_per_batch_worker()
             self._frames += worker_frames
             workers_frames[idx] = workers_frames[idx] + worker_frames
-            if self.postprocs:
+            if out is not None and self.postprocs:
                 out = self.postprocs[out.device](out)
 
             # the function blocks here until the next item is asked, hence we send the message to the
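For context on the hasattr guard added in iterator(): a postproc can be a Transform/MultiStep-style module that can be moved across devices, but it can also be a plain Python callable with no .to() method, as in the test above. The stand-alone sketch below mirrors the guarded expression; the helper name is a placeholder, not library code.

import torch

def move_postproc_if_possible(postproc, device):
    # Same shape as the guarded expression in iterator(): only call .to()
    # when the postproc actually exposes it.
    return postproc.to(device) if hasattr(postproc, "to") else postproc

fn_postproc = lambda td: td            # plain callable: no .to() method
module_postproc = torch.nn.Identity()  # nn.Module: movable across devices

assert move_postproc_if_possible(fn_postproc, "cpu") is fn_postproc
assert isinstance(move_postproc_if_possible(module_postproc, "cpu"), torch.nn.Identity)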

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 6 additions & 2 deletions
@@ -801,7 +801,9 @@ def update_priority(
 
     @pin_memory_output
     def _sample(self, batch_size: int) -> tuple[Any, dict]:
-        with self._replay_lock if not is_compiling() else contextlib.nullcontext():
+        is_comp = is_compiling()
+        nc = contextlib.nullcontext()
+        with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
             index, info = self._sampler.sample(self._storage, batch_size)
             info["index"] = index
             data = self._storage.get(index)
@@ -1539,7 +1541,9 @@ def sample(
 
     @pin_memory_output
     def _sample(self, batch_size: int) -> tuple[Any, dict]:
-        with self._replay_lock if not is_compiling() else contextlib.nullcontext():
+        is_comp = is_compiling()
+        nc = contextlib.nullcontext()
+        with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
             index, info = self._sampler.sample(self._storage, batch_size)
             info["index"] = index
             data = self._storage.get(index)
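The replay-buffer change keeps the existing compile-friendly pattern: the real locks are taken during eager execution, while contextlib.nullcontext() stands in when is_compiling() is true, and _sample() now also holds the write lock, presumably so that sampling does not interleave with a concurrent write from the collector. The self-contained sketch below shows the conditional-lock pattern only (standard library, placeholder names), not the library code itself.

import contextlib
import threading

replay_lock = threading.RLock()  # guards sampling state
write_lock = threading.RLock()   # guards writes/extends

def sample(compiling: bool):
    nc = contextlib.nullcontext()
    # Take both locks in eager mode; substitute no-op contexts when compiling.
    with replay_lock if not compiling else nc, write_lock if not compiling else nc:
        return "sampled data"

print(sample(compiling=False))  # locks held for the duration of the sample
print(sample(compiling=True))   # no locking under compilation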
