diff --git a/test/test_collector.py b/test/test_collector.py
index b6bfdaba7b1..51679248c9a 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -79,7 +79,10 @@
     RandomPolicy,
 )
 from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
-from torchrl.weight_update import SharedMemWeightSyncScheme
+from torchrl.weight_update import (
+    MultiProcessWeightSyncScheme,
+    SharedMemWeightSyncScheme,
+)
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     IS_FB = True
@@ -1485,12 +1488,12 @@ def env_fn(seed):
 
     @pytest.mark.parametrize("use_async", [False, True])
     @pytest.mark.parametrize("cudagraph", [False, True])
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
-    def test_update_weights(self, use_async, cudagraph):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
+    def test_update_weights(self, use_async, cudagraph, weight_sync_scheme):
         def create_env():
             return ContinuousActionVecMockEnv()
 
@@ -1503,6 +1506,9 @@ def create_env():
         collector_class = (
             MultiSyncDataCollector if not use_async else MultiaSyncDataCollector
         )
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
         collector = collector_class(
             [create_env] * 3,
             policy=policy,
@@ -1511,7 +1517,7 @@ def create_env():
             frames_per_batch=20,
             cat_results="stack",
             cudagraph_policy=cudagraph,
-            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
+            **kwargs,
         )
         assert "policy" in collector._weight_senders, collector._weight_senders.keys()
         try:
@@ -2857,23 +2863,28 @@ def forward(self, td):
             # ["cuda:0", "cuda"],
         ],
     )
-    def test_param_sync(self, give_weights, collector, policy_device, env_device):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
+    def test_param_sync(
+        self, give_weights, collector, policy_device, env_device, weight_sync_scheme
+    ):
         policy = TestUpdateParams.Policy().to(policy_device)
         env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device))
         device = env().device
         env = [env]
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
         col = collector(
             env,
             policy,
             device=device,
             total_frames=200,
             frames_per_batch=10,
-            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
+            **kwargs,
         )
         try:
             for i, data in enumerate(col):
@@ -2918,13 +2929,13 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
             # ["cuda:0", "cuda"],
         ],
     )
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
     def test_param_sync_mixed_device(
-        self, give_weights, collector, policy_device, env_device
+        self, give_weights, collector, policy_device, env_device, weight_sync_scheme
     ):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
         with torch.device("cpu"):
             policy = TestUpdateParams.Policy()
             policy.param = nn.Parameter(policy.param.data.to(policy_device))
@@ -2933,13 +2944,16 @@ def test_param_sync_mixed_device(
         env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device))
         device = env().device
         env = [env]
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
{"policy": weight_sync_scheme()} col = collector( env, policy, device=device, total_frames=200, frames_per_batch=10, - weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, + **kwargs, ) try: for i, data in enumerate(col): @@ -3851,7 +3865,7 @@ def test_weight_update(self, weight_updater): if weight_updater == "scheme_shared": kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}} elif weight_updater == "scheme_pipe": - kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}} + kwargs = {"weight_sync_schemes": {"policy": MultiProcessWeightSyncScheme()}} elif weight_updater == "weight_updater": kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)} else: @@ -3870,14 +3884,16 @@ def test_weight_update(self, weight_updater): **kwargs, ) - collector.update_policy_weights_() + # When using policy_factory, must pass weights explicitly + collector.update_policy_weights_(policy_weights) try: for i, data in enumerate(collector): if i == 2: assert (data["action"] != 0).any() # zero the policy policy_weights.data.zero_() - collector.update_policy_weights_() + # When using policy_factory, must pass weights explicitly + collector.update_policy_weights_(policy_weights) elif i == 3: assert (data["action"] == 0).all(), data["action"] break @@ -3973,11 +3989,11 @@ def test_start_multi(self, total_frames, cls): @pytest.mark.parametrize( "cls", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] ) - def test_start_update_policy(self, total_frames, cls): - from torchrl.weight_update.weight_sync_schemes import ( - MultiProcessWeightSyncScheme, - ) - + @pytest.mark.parametrize( + "weight_sync_scheme", + [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], + ) + def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) env = CountingEnv() m = nn.Linear(env.observation_spec["observation"].shape[-1], 1) @@ -3998,8 +4014,8 @@ def test_start_update_policy(self, total_frames, cls): # Add weight sync schemes for multi-process collectors kwargs = {} - if cls != SyncDataCollector: - kwargs["weight_sync_schemes"] = {"policy": MultiProcessWeightSyncScheme()} + if cls != SyncDataCollector and weight_sync_scheme is not None: + kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()} collector = cls( env, diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 9affa08bb8a..5f407bedbab 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -318,8 +318,23 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: ) strategy = WeightStrategy(extract_as=scheme.strategy) - model = _resolve_model(self, model_id) - return strategy.extract_weights(model) + try: + model = _resolve_model(self, model_id) + except AttributeError: + # Model not found, try fallback policy + model = None + + # If model is None, try to use the fallback policy + if model is None and model_id == "policy": + if ( + hasattr(self, "_fallback_policy") + and self._fallback_policy is not None + ): + model = self._fallback_policy + + if model is not None: + return strategy.extract_weights(model) + # If still None, fall through to legacy code below if weights is None: if model_id == "policy" and hasattr(self, "policy_weights"): @@ -351,6 +366,10 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: else: return weights + @property + def _legacy_weight_updater(self) -> bool: + return 
+
     def update_policy_weights_(
         self,
         policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
         *,
         worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
         model_id: str | None = None,
         weights_dict: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
@@ -393,6 +412,50 @@ def update_policy_weights_(
             :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`.
 
         """
+        if self._legacy_weight_updater:
+            return self._legacy_weight_update_impl(
+                policy_or_weights=policy_or_weights,
+                worker_ids=worker_ids,
+                model_id=model_id,
+                weights_dict=weights_dict,
+                **kwargs,
+            )
+        else:
+            return self._weight_update_impl(
+                policy_or_weights=policy_or_weights,
+                worker_ids=worker_ids,
+                model_id=model_id,
+                weights_dict=weights_dict,
+                **kwargs,
+            )
+
+    def _legacy_weight_update_impl(
+        self,
+        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
+        *,
+        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
+        model_id: str | None = None,
+        weights_dict: dict[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        if weights_dict is not None:
+            raise ValueError("weights_dict is not supported with legacy weight updater")
+        if model_id is not None:
+            raise ValueError("model_id is not supported with legacy weight updater")
+        # Fall back to old weight updater system
+        self.weight_updater(
+            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
+        )
+
+    def _weight_update_impl(
+        self,
+        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
+        *,
+        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
+        model_id: str | None = None,
+        weights_dict: dict[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
         if "policy_weights" in kwargs:
             warnings.warn(
                 "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
@@ -408,56 +471,37 @@ def update_policy_weights_(
                 "Cannot specify both 'weights_dict' and 'policy_or_weights'"
             )
 
+        if policy_or_weights is not None:
+            weights_dict = {"policy": policy_or_weights}
+
         # Priority: new weight sync schemes > old weight updater system
         if self._weight_senders:
-            if weights_dict is not None:
-                for target_model_id, weights in weights_dict.items():
-                    if target_model_id not in self._weight_senders:
-                        raise KeyError(
-                            f"Model '{target_model_id}' not found in registered weight senders. "
-                            f"Available models: {list(self._weight_senders.keys())}"
-                        )
-                    processed_weights = self._extract_weights_if_needed(
-                        weights, target_model_id
-                    )
-                    self._weight_senders[target_model_id].update_weights(
-                        processed_weights
+            if model_id is not None:
+                # Compose weight_dict
+                weights_dict = {model_id: policy_or_weights}
+            if weights_dict is None:
+                if "policy" in self._weight_senders:
+                    weights_dict = {"policy": policy_or_weights}
+                elif len(self._weight_senders) == 1:
+                    single_model_id = next(iter(self._weight_senders.keys()))
+                    weights_dict = {single_model_id: policy_or_weights}
+                else:
+                    raise ValueError(
+                        "Cannot determine the model to update. Please provide a weights_dict."
                     )
-            elif model_id is not None:
-                if model_id not in self._weight_senders:
+            for target_model_id, weights in weights_dict.items():
+                if target_model_id not in self._weight_senders:
                     raise KeyError(
-                        f"Model '{model_id}' not found in registered weight senders. "
+                        f"Model '{target_model_id}' not found in registered weight senders. "
" f"Available models: {list(self._weight_senders.keys())}" ) processed_weights = self._extract_weights_if_needed( - policy_or_weights, model_id + weights, target_model_id ) - self._weight_senders[model_id].update_weights(processed_weights) - else: - if "policy" in self._weight_senders: - processed_weights = self._extract_weights_if_needed( - policy_or_weights, "policy" - ) - self._weight_senders["policy"].update_weights(processed_weights) - elif len(self._weight_senders) == 1: - single_model_id = next(iter(self._weight_senders.keys())) - single_sender = self._weight_senders[single_model_id] - processed_weights = self._extract_weights_if_needed( - policy_or_weights, single_model_id - ) - single_sender.update_weights(processed_weights) - else: - for target_model_id, sender in self._weight_senders.items(): - processed_weights = self._extract_weights_if_needed( - policy_or_weights, target_model_id - ) - sender.update_weights(processed_weights) - + self._weight_senders[target_model_id].update_weights(processed_weights) elif self._weight_updater is not None: - # Fall back to old weight updater system - self.weight_updater( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) + # unreachable + raise RuntimeError else: # No weight updater configured # For single-process collectors, apply weights locally if explicitly provided @@ -838,11 +882,121 @@ def __init__( track_policy_version: bool = False, **kwargs, ): + self.closed = True + + # Initialize environment + env = self._init_env(create_env_fn, create_env_kwargs) + + # Initialize policy + policy = self._init_policy(policy, policy_factory, env, trust_policy) + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Handle trajectory pool and validate kwargs + self._traj_pool_val = kwargs.pop("traj_pool", None) + if kwargs: + raise TypeError( + f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
+            )
+
+        # Set up devices and synchronization
+        self._setup_devices(
+            device=device,
+            storing_device=storing_device,
+            policy_device=policy_device,
+            env_device=env_device,
+            no_cuda_sync=no_cuda_sync,
+        )
+
+        self.env: EnvBase = env
+        del env
+
+        # Set up policy version tracking
+        self._setup_policy_version_tracking(track_policy_version)
+
+        # Set up replay buffer
+        self._setup_replay_buffer(
+            replay_buffer=replay_buffer,
+            extend_buffer=extend_buffer,
+            local_init_rb=local_init_rb,
+            postproc=postproc,
+            split_trajs=split_trajs,
+            return_same_td=return_same_td,
+            use_buffers=use_buffers,
+        )
+
+        self.closed = False
+
+        # Validate reset_when_done
+        if not reset_when_done:
+            raise ValueError("reset_when_done is deprecated.")
+        self.reset_when_done = reset_when_done
+        self.n_env = self.env.batch_size.numel()
+
+        # Register collector with policy and env
+        if hasattr(policy, "register_collector"):
+            policy.register_collector(self)
+        if hasattr(self.env, "register_collector"):
+            self.env.register_collector(self)
+
+        # Set up policy and weights
+        self._setup_policy_and_weights(policy)
+
+        # Apply environment device
+        self._apply_env_device()
+
+        # Set up max frames per trajectory
+        self._setup_max_frames_per_traj(max_frames_per_traj)
+
+        # Validate and set total frames
+        self.reset_at_each_iter = reset_at_each_iter
+        self._setup_total_frames(total_frames, frames_per_batch)
+
+        # Set up init random frames
+        self._setup_init_random_frames(init_random_frames, frames_per_batch)
+
+        # Set up postproc
+        self._setup_postproc(postproc)
+
+        # Calculate frames per batch
+        self._setup_frames_per_batch(frames_per_batch)
+
+        # Set exploration and other options
+        self.exploration_type = (
+            exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
+        )
+        self.return_same_td = return_same_td
+        self.set_truncated = set_truncated
+
+        # Create shuttle and rollout buffers
+        self._make_shuttle()
+        self._maybe_make_final_rollout(make_rollout=self._use_buffers)
+        self._set_truncated_keys()
+
+        # Set split trajectories option
+        if split_trajs is None:
+            split_trajs = False
+        self.split_trajs = split_trajs
+        self._exclude_private_keys = True
+
+        # Set up interruptor and frame tracking
+        self.interruptor = interruptor
+        self._frames = 0
+        self._iter = -1
+
+        # Set up weight synchronization
+        self._setup_weight_sync(weight_updater, weight_sync_schemes)
+
+    def _init_env(
+        self,
+        create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],
+        create_env_kwargs: dict[str, Any] | None,
+    ) -> EnvBase:
+        """Initialize and configure the environment."""
         from torchrl.envs.batched_envs import BatchedEnvBase
 
-        self.closed = True
         if create_env_kwargs is None:
             create_env_kwargs = {}
+
         if not isinstance(create_env_fn, EnvBase):
             env = create_env_fn(**create_env_kwargs)
         else:
@@ -854,7 +1008,16 @@ def __init__(
                         f"on environment of type {type(create_env_fn)}."
                     )
             env.update_kwargs(create_env_kwargs)
+        return env
 
+    def _init_policy(
+        self,
+        policy: TensorDictModule | Callable | None,
+        policy_factory: Callable[[], Callable] | None,
+        env: EnvBase,
+        trust_policy: bool | None,
+    ) -> TensorDictModule | Callable:
+        """Initialize and configure the policy."""
         if policy is None:
             if policy_factory is not None:
                 policy = policy_factory()
@@ -862,33 +1025,26 @@ def __init__(
                 policy = RandomPolicy(env.full_action_spec)
         elif policy_factory is not None:
             raise TypeError("policy_factory cannot be used with policy argument.")
-        # If the underlying policy has a state_dict, we keep a reference to the policy and
-        # do all policy weight saving/loading through it
+
+        # If the underlying policy has a state_dict, keep a reference to it
         if hasattr(policy, "state_dict"):
            self._policy_w_state_dict = policy
 
         if trust_policy is None:
             trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
         self.trust_policy = trust_policy
-        self._read_compile_kwargs(compile_policy, cudagraph_policy)
 
-        ##########################
-        # Trajectory pool
-        self._traj_pool_val = kwargs.pop("traj_pool", None)
-        if kwargs:
-            raise TypeError(
-                f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}."
-            )
+        return policy
 
-        ##########################
-        # Setting devices:
-        # The rule is the following:
-        # - If no device is passed, all devices are assumed to work OOB.
-        #   The tensordict used for output is not on any device (ie, actions and observations
-        #   can be on a different device).
-        # - If the ``device`` is passed, it is used for all devices (storing, env and policy)
-        #   unless overridden by another kwarg.
-        # - The rest of the kwargs control the respective device.
+    def _setup_devices(
+        self,
+        device: DEVICE_TYPING | None,
+        storing_device: DEVICE_TYPING | None,
+        policy_device: DEVICE_TYPING | None,
+        env_device: DEVICE_TYPING | None,
+        no_cuda_sync: bool,
+    ) -> None:
+        """Set up devices and synchronization functions."""
         storing_device, policy_device, env_device = self._get_devices(
             storing_device=storing_device,
             policy_device=policy_device,
@@ -897,65 +1053,39 @@ def __init__(
         )
 
         self.storing_device = storing_device
-        if self.storing_device is not None and self.storing_device.type != "cuda":
-            # Cuda handles sync
-            if torch.cuda.is_available():
-                self._sync_storage = torch.cuda.synchronize
-            elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
-                # Will break for older PT versions which don't have torch.mps
-                self._sync_storage = torch.mps.synchronize
-            elif hasattr(torch, "npu") and torch.npu.is_available():
-                self._sync_storage = torch.npu.synchronize
-            elif self.storing_device.type == "cpu":
-                self._sync_storage = _do_nothing
-            else:
-                raise RuntimeError("Non supported device")
-        else:
-            self._sync_storage = _do_nothing
+        self._sync_storage = self._get_sync_fn(storing_device)
 
         self.env_device = env_device
-        if self.env_device is not None and self.env_device.type != "cuda":
-            # Cuda handles sync
-            if torch.cuda.is_available():
-                self._sync_env = torch.cuda.synchronize
-            elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
-                self._sync_env = torch.mps.synchronize
-            elif hasattr(torch, "npu") and torch.npu.is_available():
-                self._sync_env = torch.npu.synchronize
-            elif self.env_device.type == "cpu":
-                self._sync_env = _do_nothing
-            else:
-                raise RuntimeError("Non supported device")
-        else:
-            self._sync_env = _do_nothing
+        self._sync_env = self._get_sync_fn(env_device)
+
         self.policy_device = policy_device
-        if self.policy_device is not None and self.policy_device.type != "cuda":
self.policy_device.type != "cuda": + self._sync_policy = self._get_sync_fn(policy_device) + + self.device = device + self.no_cuda_sync = no_cuda_sync + self._cast_to_policy_device = self.policy_device != self.env_device + + def _get_sync_fn(self, device: torch.device | None) -> Callable: + """Get the appropriate synchronization function for a device.""" + if device is not None and device.type != "cuda": # Cuda handles sync if torch.cuda.is_available(): - self._sync_policy = torch.cuda.synchronize + return torch.cuda.synchronize elif torch.backends.mps.is_available() and hasattr(torch, "mps"): - self._sync_policy = torch.mps.synchronize + return torch.mps.synchronize elif hasattr(torch, "npu") and torch.npu.is_available(): - self._sync_policy = torch.npu.synchronize - elif self.policy_device.type == "cpu": - self._sync_policy = _do_nothing + return torch.npu.synchronize + elif device.type == "cpu": + return _do_nothing else: raise RuntimeError("Non supported device") else: - self._sync_policy = _do_nothing - self.device = device - self.no_cuda_sync = no_cuda_sync - # Check if we need to cast things from device to device - # If the policy has a None device and the env too, no need to cast (we don't know - # and assume the user knows what she's doing). - # If the devices match we're happy too. - # Only if the values differ we need to cast - self._cast_to_policy_device = self.policy_device != self.env_device - - self.env: EnvBase = env - del env + return _do_nothing - # Policy version tracking setup + def _setup_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking if requested.""" self.policy_version_tracker = track_policy_version if isinstance(track_policy_version, bool) and track_policy_version: from torchrl.envs.batched_envs import BatchedEnvBase @@ -967,19 +1097,29 @@ def __init__( ) self.policy_version_tracker = PolicyVersion() self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - elif hasattr( - track_policy_version, "increment_version" - ): # Check if it's a PolicyVersion instance + elif hasattr(track_policy_version, "increment_version"): self.policy_version_tracker = track_policy_version self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore else: self.policy_version_tracker = None + + def _setup_replay_buffer( + self, + replay_buffer: ReplayBuffer | None, + extend_buffer: bool, + local_init_rb: bool | None, + postproc: Callable | None, + split_trajs: bool | None, + return_same_td: bool, + use_buffers: bool | None, + ) -> None: + """Set up replay buffer configuration and validate compatibility.""" self.replay_buffer = replay_buffer self.extend_buffer = extend_buffer - # Handle local_init_rb deprecation for SyncDataCollector + # Handle local_init_rb deprecation if local_init_rb is None: - local_init_rb = False # Default for SyncDataCollector + local_init_rb = False if replay_buffer is not None and not local_init_rb: warnings.warn( "local_init_rb=False is deprecated and will be removed in v0.12. 
" @@ -988,6 +1128,7 @@ def __init__( ) self.local_init_rb = local_init_rb + # Validate replay buffer compatibility if self.replay_buffer is not None and not self._ignore_rb: if postproc is not None and not self.extend_buffer: raise TypeError( @@ -1003,27 +1144,15 @@ def __init__( ) if use_buffers: raise TypeError("replay_buffer is exclusive with use_buffers.") + if use_buffers is None: use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None self._use_buffers = use_buffers - self.replay_buffer = replay_buffer - - self.closed = False - - if not reset_when_done: - raise ValueError("reset_when_done is deprecated.") - self.reset_when_done = reset_when_done - self.n_env = self.env.batch_size.numel() - - if hasattr(policy, "register_collector"): - policy.register_collector(self) - if hasattr(self.env, "register_collector"): - self.env.register_collector(self) + def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: + """Set up policy, wrapped policy, and extract weights.""" self._original_policy = policy - (policy, self.get_weights_fn,) = self._get_policy_and_device( - policy=policy, - ) + policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) if not self.trust_policy: self.policy = policy @@ -1037,6 +1166,7 @@ def __init__( else: self.policy = self._wrapped_policy = policy + # Extract policy weights if isinstance(self._wrapped_policy, nn.Module): self.policy_weights = TensorDict.from_module( self._wrapped_policy, as_module=True @@ -1044,6 +1174,7 @@ def __init__( else: self.policy_weights = TensorDict() + # Apply compilation/cudagraph if self.compiled_policy: self._wrapped_policy = compile_with_warmup( self._wrapped_policy, **self.compiled_policy_kwargs @@ -1057,24 +1188,26 @@ def __init__( **self.cudagraphed_policy_kwargs, ) + def _apply_env_device(self) -> None: + """Apply device to environment if specified.""" if self.env_device: self.env: EnvBase = self.env.to(self.env_device) elif self.env.device is not None: - # we did not receive an env device, we use the device of the env + # Use the device of the env if none was provided self.env_device = self.env.device - # If the storing device is not the same as the policy device, we have - # no guarantee that the "next" entry from the policy will be on the - # same device as the collector metadata. 
+        # Check if we need to cast to env device
         self._cast_to_env_device = self._cast_to_policy_device or (
             self.env.device != self.storing_device
         )
 
+    def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None:
+        """Set up maximum frames per trajectory and add StepCounter if needed."""
         self.max_frames_per_traj = (
             int(max_frames_per_traj) if max_frames_per_traj is not None else 0
         )
         if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0:
-            # let's check that there is no StepCounter yet
+            # Check that there is no StepCounter yet
             for key in self.env.output_spec.keys(True, True):
                 if isinstance(key, str):
                     key = (key,)
@@ -1090,6 +1223,8 @@ def __init__(
                 self.env, StepCounter(max_steps=self.max_frames_per_traj)
             )
 
+    def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None:
+        """Validate and set total frames."""
         if total_frames is None or total_frames < 0:
             total_frames = float("inf")
         else:
@@ -1103,7 +1238,11 @@ def __init__(
         self.total_frames = (
             int(total_frames) if total_frames != float("inf") else total_frames
         )
-        self.reset_at_each_iter = reset_at_each_iter
+
+    def _setup_init_random_frames(
+        self, init_random_frames: int | None, frames_per_batch: int
+    ) -> None:
+        """Set up initial random frames."""
         self.init_random_frames = (
             int(init_random_frames) if init_random_frames not in (None, -1) else 0
         )
@@ -1119,6 +1258,8 @@ def __init__(
                 "To silence this message, set the environment variable RL_WARNINGS to False."
             )
 
+    def _setup_postproc(self, postproc: Callable | None) -> None:
+        """Set up post-processing transform."""
         self.postproc = postproc
         if (
             self.postproc is not None
@@ -1129,6 +1270,8 @@ def __init__(
         if postproc is not self.postproc and postproc is not None:
             self.postproc = postproc
 
+    def _setup_frames_per_batch(self, frames_per_batch: int) -> None:
+        """Calculate and validate frames per batch."""
         if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
             warnings.warn(
                 f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
@@ -1138,37 +1281,20 @@ def __init__(
             )
         self.frames_per_batch = -(-frames_per_batch // self.n_env)
         self.requested_frames_per_batch = self.frames_per_batch * self.n_env
-        self.exploration_type = (
-            exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
-        )
-        self.return_same_td = return_same_td
-        self.set_truncated = set_truncated
-        self._make_shuttle()
-        self._maybe_make_final_rollout(make_rollout=self._use_buffers)
-        self._set_truncated_keys()
-
-        if split_trajs is None:
-            split_trajs = False
-        self.split_trajs = split_trajs
-        self._exclude_private_keys = True
-
-        self.interruptor = interruptor
-        self._frames = 0
-        self._iter = -1
-
-        # Set up weight synchronization - prefer new schemes over legacy updater
-        # For single-process SyncDataCollector, no weight sync is needed (policy is local)
-        # Weight sync schemes are only needed for multi-process/distributed collectors
+    def _setup_weight_sync(
+        self,
+        weight_updater: WeightUpdaterBase | Callable | None,
+        weight_sync_schemes: dict[str, WeightSyncScheme] | None,
+    ) -> None:
+        """Set up weight synchronization system."""
         if weight_sync_schemes is not None:
             # Use new simplified weight synchronization system
             self._weight_sync_schemes = weight_sync_schemes
             self._weight_senders = {}
-            # For single-process collectors, we don't need senders/receivers
             # The policy is local and changes are immediately visible
             # Senders will be set up in multiprocess collectors during _run_processes
-
             self.weight_updater = None  # Don't use legacy system
 
         elif weight_updater is not None:
             # Use legacy weight updater system if explicitly provided
@@ -1179,7 +1305,6 @@ def __init__(
                 raise TypeError(
                     f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead."
                 )
-
             warnings.warn(
                 "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. "
                 "This will be removed in a future version.",
@@ -1191,7 +1316,6 @@ def __init__(
             self._weight_senders = {}
         else:
             # No weight sync needed for single-process collectors
-            # The policy is local and changes are immediately visible
             self.weight_updater = None
             self._weight_sync_schemes = None
             self._weight_senders = {}
@@ -2305,6 +2429,110 @@ def __init__(
         track_policy_version: bool = False,
     ):
         self.closed = True
+
+        # Set up workers and environment functions
+        create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns(
+            create_env_fn, num_workers, frames_per_batch
+        )
+
+        # Set up basic configuration
+        self.set_truncated = set_truncated
+        self.num_sub_threads = num_sub_threads
+        self.num_threads = num_threads
+        self.create_env_fn = create_env_fn
+        self._read_compile_kwargs(compile_policy, cudagraph_policy)
+
+        # Set up environment kwargs
+        self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs)
+
+        # Set up devices
+        storing_devices, policy_devices, env_devices = self._get_devices(
+            storing_device=storing_device,
+            env_device=env_device,
+            policy_device=policy_device,
+            device=device,
+        )
+        self.storing_device = storing_devices
+        self.policy_device = policy_devices
+        self.env_device = env_devices
+        self.collector_class = collector_class
+        del storing_device, env_device, policy_device, device
+        self.no_cuda_sync = no_cuda_sync
+
+        # Set up replay buffer
+        self._use_buffers = use_buffers
+        self.replay_buffer = replay_buffer
+        self._setup_multi_replay_buffer(
+            local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer
+        )
+
+        # Set up policy and weights
+        if trust_policy is None:
+            trust_policy = policy is not None and isinstance(policy, CudaGraphModule)
+        self.trust_policy = trust_policy
+
+        policy_factory = self._setup_policy_factory(policy_factory)
+        self._setup_multi_policy_and_weights(
+            policy, policy_factory, weight_updater, weight_sync_schemes
+        )
+
+        # Set up weight synchronization
+        self._setup_multi_weight_sync(weight_updater, weight_sync_schemes)
+
+        # Set up policy version tracking
+        self._setup_multi_policy_version_tracking(track_policy_version)
+
+        # Store policy and policy_factory
+        self.policy = policy
+        self.policy_factory = policy_factory
+
+        # Set up fallback policy for weight extraction
+        self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes)
+
+        # Set up total frames and other parameters
+        self._setup_multi_total_frames(
+            total_frames, total_frames_per_batch, frames_per_batch
+        )
+        self.reset_at_each_iter = reset_at_each_iter
+        self.postprocs = postproc
+        self.max_frames_per_traj = (
+            int(max_frames_per_traj) if max_frames_per_traj is not None else 0
+        )
+
+        # Set up split trajectories
+        self.requested_frames_per_batch = total_frames_per_batch
+        self.reset_when_done = reset_when_done
+        self._setup_split_trajs(split_trajs, reset_when_done)
+
+        # Set up other parameters
+        self.init_random_frames = (
+            int(init_random_frames) if init_random_frames is not None else 0
+        )
+        self.update_at_each_batch = update_at_each_batch
+        self.exploration_type = exploration_type
+        self.frames_per_worker = np.inf
+
+        # Set up preemptive threshold
+        self._setup_preemptive_threshold(preemptive_threshold)
+
+        # Run worker processes
+        self._run_processes()
+
+        # Set up frame tracking and other options
+        self._exclude_private_keys = True
+        self._frames = 0
+        self._iter = -1
+
+        # Validate cat_results
+        self._validate_cat_results(cat_results)
+
+    def _setup_workers_and_env_fns(
+        self,
+        create_env_fn: Sequence[Callable] | Callable,
+        num_workers: int | None,
+        frames_per_batch: int | Sequence[int],
+    ) -> tuple[list[Callable], int]:
+        """Set up workers and environment functions."""
         if isinstance(create_env_fn, Sequence):
             self.num_workers = len(create_env_fn)
         else:
@@ -2327,11 +2555,12 @@ def __init__(
             else frames_per_batch
         )
 
-        self.set_truncated = set_truncated
-        self.num_sub_threads = num_sub_threads
-        self.num_threads = num_threads
-        self.create_env_fn = create_env_fn
-        self._read_compile_kwargs(compile_policy, cudagraph_policy)
+        return create_env_fn, total_frames_per_batch
+
+    def _setup_env_kwargs(
+        self, create_env_kwargs: Sequence[dict] | dict | None
+    ) -> list[dict]:
+        """Set up environment kwargs for each worker."""
         if isinstance(create_env_kwargs, Mapping):
             create_env_kwargs = [create_env_kwargs] * self.num_workers
         elif create_env_kwargs is None:
@@ -2342,53 +2571,29 @@ def __init__(
             raise ValueError(
                 f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}"
             )
-        self.create_env_kwargs = create_env_kwargs
-        # Preparing devices:
-        # We want the user to be able to choose, for each worker, on which
-        # device will the policy live and which device will be used to store
-        # data. Those devices may or may not match.
-        # One caveat is that, if there is only one device for the policy, and
-        # if there are multiple workers, sending the same device and policy
-        # to be copied to each worker will result in multiple copies of the
-        # same policy on the same device.
-        # To go around this, we do the copies of the policy in the server
-        # (this object) to each possible device, and send to all the
-        # processes their copy of the policy.
-
-        storing_devices, policy_devices, env_devices = self._get_devices(
-            storing_device=storing_device,
-            env_device=env_device,
-            policy_device=policy_device,
-            device=device,
-        )
-
-        # to avoid confusion
-        self.storing_device = storing_devices
-        self.policy_device = policy_devices
-        self.env_device = env_devices
-        self.collector_class = collector_class
-
-        del storing_device, env_device, policy_device, device
-        self.no_cuda_sync = no_cuda_sync
-
-        self._use_buffers = use_buffers
-        self.replay_buffer = replay_buffer
+        return create_env_kwargs
 
+    def _setup_multi_replay_buffer(
+        self,
+        local_init_rb: bool | None,
+        replay_buffer: ReplayBuffer | None,
+        replay_buffer_chunk: bool | None,
+        extend_buffer: bool,
+    ) -> None:
+        """Set up replay buffer for multi-process collector."""
         # Handle local_init_rb deprecation
         if local_init_rb is None:
-            # v0.11: Default to False (current behavior), show deprecation warning
-            # v0.12: Default to True (new behavior)
-            local_init_rb = False  # Will become True in 0.12
+            local_init_rb = False
         if replay_buffer is not None and not local_init_rb:
             warnings.warn(
                 "local_init_rb=False is deprecated and will be removed in v0.12. "
" "The new storage-level initialization provides better performance.", FutureWarning, ) - self.local_init_rb = local_init_rb self._check_replay_buffer_init() + if replay_buffer_chunk is not None: if extend_buffer is None: replay_buffer_chunk = extend_buffer @@ -2401,6 +2606,7 @@ def __init__( "conflicting values for replay_buffer_chunk and extend_buffer." ) self.extend_buffer = extend_buffer + if ( replay_buffer is not None and hasattr(replay_buffer, "shared") @@ -2409,21 +2615,32 @@ def __init__( torchrl_logger.warning("Replay buffer is not shared. Sharing it.") replay_buffer.share() - self._policy_weights_dict = {} - - if trust_policy is None: - trust_policy = policy is not None and isinstance(policy, CudaGraphModule) - self.trust_policy = trust_policy - + def _setup_policy_factory( + self, policy_factory: Callable | list[Callable] | None + ) -> list[Callable | None]: + """Set up policy factory for each worker.""" if not isinstance(policy_factory, Sequence): policy_factory = [policy_factory] * self.num_workers + return policy_factory + + def _setup_multi_policy_and_weights( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy and extract weights for each device.""" + self._policy_weights_dict = {} + self._fallback_policy = None # Policy to use for weight extraction fallback + if any(policy_factory) and policy is not None: raise TypeError("policy_factory and policy are mutually exclusive") elif not any(policy_factory): for policy_device, env_maker, env_maker_kwargs in _zip_strict( self.policy_device, self.create_env_fn, self.create_env_kwargs ): - (policy_new_device, get_weights_fn,) = self._get_policy_and_device( + policy_new_device, get_weights_fn = self._get_policy_and_device( policy=policy, policy_device=policy_device, env_maker=env_maker, @@ -2437,15 +2654,14 @@ def __init__( else TensorDict() ) self._policy_weights_dict[policy_device] = weights + # Store the first policy instance for fallback weight extraction + if self._fallback_policy is None: + self._fallback_policy = policy_new_device self._get_weights_fn = get_weights_fn if weight_updater is None: # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default if weight_sync_schemes is None: weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} - # Don't create legacy weight updater if we have schemes - else: - # Legacy weight updater was explicitly provided - pass elif weight_updater is None: warnings.warn( "weight_updater is None, but policy_factory is provided. This means that the server will " @@ -2456,7 +2672,12 @@ def __init__( "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
             )
 
-        # Set up weight synchronization - prefer new schemes over legacy updater
+    def _setup_multi_weight_sync(
+        self,
+        weight_updater: WeightUpdaterBase | Callable | None,
+        weight_sync_schemes: dict[str, WeightSyncScheme] | None,
+    ) -> None:
+        """Set up weight synchronization for multi-process collector."""
         if weight_sync_schemes is not None:
             # Use new simplified weight synchronization system
             self._weight_sync_schemes = weight_sync_schemes
@@ -2469,14 +2690,15 @@ def __init__(
             self._weight_sync_schemes = None
             self._weight_senders = {}
 
-        # Policy version tracking setup
+    def _setup_multi_policy_version_tracking(
+        self, track_policy_version: bool | PolicyVersion
+    ) -> None:
+        """Set up policy version tracking for multi-process collector."""
         self.policy_version_tracker = track_policy_version
         if PolicyVersion is not None:
             if isinstance(track_policy_version, bool) and track_policy_version:
                 self.policy_version_tracker = PolicyVersion()
-            elif hasattr(
-                track_policy_version, "increment_version"
-            ):  # Check if it's a PolicyVersion instance
+            elif hasattr(track_policy_version, "increment_version"):
                 self.policy_version_tracker = track_policy_version
             else:
                 self.policy_version_tracker = None
@@ -2487,10 +2709,35 @@ def __init__(
             )
             self.policy_version_tracker = None
 
-        self.policy = policy
-        self.policy_factory = policy_factory
+    def _setup_fallback_policy(
+        self,
+        policy: TensorDictModule | Callable | None,
+        policy_factory: list[Callable | None],
+        weight_sync_schemes: dict[str, WeightSyncScheme] | None,
+    ) -> None:
+        """Set up fallback policy for weight extraction when using policy_factory."""
+        # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided
+        # If policy_factory was used, create a policy instance to use as fallback
+        if policy is None and any(policy_factory) and weight_sync_schemes is not None:
+            if not hasattr(self, "_fallback_policy") or self._fallback_policy is None:
+                first_factory = (
+                    policy_factory[0]
+                    if isinstance(policy_factory, list)
+                    else policy_factory
+                )
+                if first_factory is not None:
+                    # Create a policy instance for weight extraction
+                    # This will be a reference to a policy with the same structure
+                    # For shared memory, modifications to any policy will be visible here
+                    self._fallback_policy = first_factory()
 
-        remainder = 0
+    def _setup_multi_total_frames(
+        self,
+        total_frames: int,
+        total_frames_per_batch: int,
+        frames_per_batch: int | Sequence[int],
+    ) -> None:
+        """Validate and set total frames for multi-process collector."""
         if total_frames is None or total_frames < 0:
             total_frames = float("inf")
         else:
@@ -2504,27 +2751,21 @@ def __init__(
         self.total_frames = (
             int(total_frames) if total_frames != float("inf") else total_frames
         )
-        self.reset_at_each_iter = reset_at_each_iter
-        self.postprocs = postproc
-        self.max_frames_per_traj = (
-            int(max_frames_per_traj) if max_frames_per_traj is not None else 0
-        )
-        self.requested_frames_per_batch = total_frames_per_batch
-        self.reset_when_done = reset_when_done
+
+    def _setup_split_trajs(
+        self, split_trajs: bool | None, reset_when_done: bool
+    ) -> None:
+        """Set up split trajectories option."""
         if split_trajs is None:
             split_trajs = False
-        elif not self.reset_when_done and split_trajs:
+        elif not reset_when_done and split_trajs:
             raise RuntimeError(
                 "Cannot split trajectories when reset_when_done is False."
             )
         self.split_trajs = split_trajs
-        self.init_random_frames = (
-            int(init_random_frames) if init_random_frames is not None else 0
-        )
-        self.update_at_each_batch = update_at_each_batch
-        self.exploration_type = exploration_type
-        self.frames_per_worker = np.inf
+
+    def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None:
+        """Set up preemptive threshold for early stopping."""
         if preemptive_threshold is not None:
             if _is_osx:
                 raise NotImplementedError(
@@ -2537,10 +2778,9 @@ def __init__(
         else:
             self.preemptive_threshold = 1.0
             self.interruptor = None
-        self._run_processes()
-        self._exclude_private_keys = True
-        self._frames = 0
-        self._iter = -1
+
+    def _validate_cat_results(self, cat_results: str | int | None) -> None:
+        """Validate cat_results parameter."""
         if cat_results is not None and (
             not isinstance(cat_results, (int, str))
             or (isinstance(cat_results, str) and cat_results != "stack")
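
Usage sketch of the collector-side API this diff exercises. This is illustrative only: make_env, the Gym task, and the layer sizes are assumptions, not part of the change; the weight_sync_schemes argument, MultiProcessWeightSyncScheme/SharedMemWeightSyncScheme, and update_policy_weights_ are the pieces introduced or reworked above.

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.weight_update import MultiProcessWeightSyncScheme


def make_env():
    # Hypothetical env factory; any EnvBase factory works here.
    return GymEnv("Pendulum-v1")


# Pendulum has a 3-dim observation and a 1-dim action.
policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)

collector = MultiSyncDataCollector(
    [make_env] * 2,
    policy=policy,
    total_frames=200,
    frames_per_batch=20,
    # New-style weight synchronization: one scheme per model id.
    # SharedMemWeightSyncScheme would share the weights through shared
    # memory instead of sending them over the worker pipes.
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
)

for data in collector:
    # ... optimize `policy` here ...
    # Broadcast the trainer-side weights to all workers. When the collector
    # was built from a policy_factory (no trainer-side policy reference),
    # the weights must be passed explicitly instead, e.g.
    # collector.update_policy_weights_(TensorDict.from_module(policy)).
    collector.update_policy_weights_()

collector.shutdown()

With no argument, update_policy_weights_ now routes through _weight_update_impl, which resolves the "policy" model id against the registered weight senders and falls back to _fallback_policy when the trainer holds no policy instance; with a legacy weight_updater configured, the call dispatches to _legacy_weight_update_impl instead.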