@@ -3575,6 +3575,138 @@ def test_collector_rb_multiasync(
3575
3575
).all (), steps_counts
3576
3576
assert (idsdiff >= 0 ).all ()
3577
3577
3578
+ @staticmethod
3579
+ def _zero_postproc (td ):
3580
+ # Apply zero to all tensor values in the tensordict
3581
+ return torch .zeros_like (td )
3582
+
3583
+ @pytest .mark .parametrize (
3584
+ "collector_class" ,
3585
+ [
3586
+ SyncDataCollector ,
3587
+ functools .partial (MultiSyncDataCollector , cat_results = "stack" ),
3588
+ MultiaSyncDataCollector ,
3589
+ ],
3590
+ )
3591
+ @pytest .mark .parametrize ("use_replay_buffer" , [True , False ])
3592
+ @pytest .mark .parametrize ("extend_buffer" , [True , False ])
3593
+ def test_collector_postproc_zeros (
3594
+ self , collector_class , use_replay_buffer , extend_buffer
3595
+ ):
3596
+ """Test that postproc functionality works correctly across all collector types.
3597
+
3598
+ This test verifies that:
3599
+ 1. Postproc is applied correctly when no replay buffer is used
3600
+ 2. Postproc is applied correctly when replay buffer is used with extend_buffer=True
3601
+ 3. Postproc is not applied when replay buffer is used with extend_buffer=False
3602
+ 4. The behavior is consistent across Sync, MultiaSync, and MultiSync collectors
3603
+ """
3604
+ # Create a simple dummy environment
3605
+ def make_env ():
3606
+ env = DiscreteActionVecMockEnv ()
3607
+ env .set_seed (0 )
3608
+ return env
3609
+
3610
+ # Create a simple dummy policy
3611
+ def make_policy (env ):
3612
+ return RandomPolicy (env .action_spec )
3613
+
3614
+ # Test parameters
3615
+ total_frames = 64
3616
+ frames_per_batch = 16
3617
+
3618
+ if use_replay_buffer :
3619
+ # Create replay buffer
3620
+ rb = ReplayBuffer (
3621
+ storage = LazyTensorStorage (256 ), batch_size = 5 , compilable = False
3622
+ )
3623
+
3624
+ # Test with replay buffer
3625
+ if collector_class == SyncDataCollector :
3626
+ collector = collector_class (
3627
+ make_env (),
3628
+ make_policy (make_env ()),
3629
+ replay_buffer = rb ,
3630
+ total_frames = total_frames ,
3631
+ frames_per_batch = frames_per_batch ,
3632
+ extend_buffer = extend_buffer ,
3633
+ postproc = self ._zero_postproc if extend_buffer else None ,
3634
+ )
3635
+ else :
3636
+ # MultiSync and MultiaSync collectors
3637
+ collector = collector_class (
3638
+ [make_env , make_env ],
3639
+ make_policy (make_env ()),
3640
+ replay_buffer = rb ,
3641
+ total_frames = total_frames ,
3642
+ frames_per_batch = frames_per_batch ,
3643
+ extend_buffer = extend_buffer ,
3644
+ postproc = self ._zero_postproc if extend_buffer else None ,
3645
+ )
3646
+ try :
3647
+ # Collect data
3648
+ collected_frames = 0
3649
+ for _ in collector :
3650
+ collected_frames += frames_per_batch
3651
+ if extend_buffer :
3652
+ # With extend_buffer=True, postproc should be applied
3653
+ # Check that the replay buffer contains zeros
3654
+ sample = rb .sample (5 )
3655
+ torch .testing .assert_close (
3656
+ sample ["observation" ],
3657
+ torch .zeros_like (sample ["observation" ]),
3658
+ )
3659
+ torch .testing .assert_close (
3660
+ sample ["action" ], torch .zeros_like (sample ["action" ])
3661
+ )
3662
+ # Check next.reward instead of reward
3663
+ torch .testing .assert_close (
3664
+ sample ["next" , "reward" ],
3665
+ torch .zeros_like (sample ["next" , "reward" ]),
3666
+ )
3667
+ else :
3668
+ # With extend_buffer=False, postproc should not be applied
3669
+ # Check that the replay buffer contains non-zero values
3670
+ sample = rb .sample (5 )
3671
+ assert torch .any (sample ["observation" ] != 0.0 )
3672
+ assert torch .any (sample ["action" ] != 0.0 )
3673
+
3674
+ if collected_frames >= total_frames :
3675
+ break
3676
+ finally :
3677
+ collector .shutdown ()
3678
+
3679
+ else :
3680
+ # Test without replay buffer
3681
+ if collector_class == SyncDataCollector :
3682
+ collector = collector_class (
3683
+ make_env (),
3684
+ make_policy (make_env ()),
3685
+ total_frames = total_frames ,
3686
+ frames_per_batch = frames_per_batch ,
3687
+ postproc = self ._zero_postproc ,
3688
+ )
3689
+ else :
3690
+ # MultiSync and MultiaSync collectors
3691
+ collector = collector_class (
3692
+ [make_env , make_env ],
3693
+ make_policy (make_env ()),
3694
+ total_frames = total_frames ,
3695
+ frames_per_batch = frames_per_batch ,
3696
+ postproc = self ._zero_postproc ,
3697
+ )
3698
+ try :
3699
+ # Collect data and verify postproc is applied
3700
+ for batch in collector :
3701
+ # All values should be zero due to postproc
3702
+ assert torch .all (batch ["observation" ] == 0.0 )
3703
+ assert torch .all (batch ["action" ] == 0.0 )
3704
+ # Check next.reward instead of reward
3705
+ assert torch .all (batch ["next" , "reward" ] == 0.0 )
3706
+ break # Just check first batch
3707
+ finally :
3708
+ collector .shutdown ()
3709
+
3578
3710
3579
3711
def __deepcopy_error__ (* args , ** kwargs ):
3580
3712
raise RuntimeError ("deepcopy not allowed" )
0 commit comments