@@ -331,11 +331,6 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
331331 and self ._fallback_policy is not None
332332 ):
333333 model = self ._fallback_policy
334- elif (
335- hasattr (self , "_fallback_policy_ref" )
336- and self ._fallback_policy_ref is not None
337- ):
338- model = self ._fallback_policy_ref ()
339334
340335 if model is not None :
341336 return strategy .extract_weights (model )
@@ -2611,6 +2606,7 @@ def _setup_multi_policy_and_weights(
26112606 ) -> None :
26122607 """Set up policy and extract weights for each device."""
26132608 self ._policy_weights_dict = {}
2609+ self ._fallback_policy = None # Policy to use for weight extraction fallback
26142610
26152611 if any (policy_factory ) and policy is not None :
26162612 raise TypeError ("policy_factory and policy are mutually exclusive" )
@@ -2632,6 +2628,9 @@ def _setup_multi_policy_and_weights(
26322628 else TensorDict ()
26332629 )
26342630 self ._policy_weights_dict [policy_device ] = weights
2631+ # Store the first policy instance for fallback weight extraction
2632+ if self ._fallback_policy is None :
2633+ self ._fallback_policy = policy_new_device
26352634 self ._get_weights_fn = get_weights_fn
26362635 if weight_updater is None :
26372636 # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default
@@ -2690,36 +2689,14 @@ def _setup_fallback_policy(
26902689 policy_factory : list [Callable | None ],
26912690 weight_sync_schemes : dict [str , WeightSyncScheme ] | None ,
26922691 ) -> None :
2693- """Set up fallback policy for weight extraction when using policy_factory."""
2694- if policy is None and any (policy_factory ) and weight_sync_schemes is not None :
2695- # Create a policy instance from the first factory for weight extraction
2696- import weakref
2697-
2698- first_factory = (
2699- policy_factory [0 ]
2700- if isinstance (policy_factory , list )
2701- else policy_factory
2702- )
2703- if first_factory is not None :
2704- fallback_policy = first_factory ()
2705- # For shared memory schemes (CPU or CUDA), store the actual policy
2706- # For pipe-based schemes, store a weak reference
2707- first_scheme = next (iter (weight_sync_schemes .values ()))
2708- from torchrl .weight_update .weight_sync_schemes import (
2709- SharedMemWeightSyncScheme ,
2710- )
2711-
2712- if isinstance (first_scheme , SharedMemWeightSyncScheme ):
2713- # Shared memory: store actual policy (weights are in shared mem)
2714- self ._fallback_policy = fallback_policy
2715- else :
2716- # Pipe-based: store weak reference
2717- self ._fallback_policy_ref = weakref .ref (fallback_policy )
2718- # Keep the policy alive by storing it
2719- self ._fallback_policy = fallback_policy
2720- else :
2721- self ._fallback_policy = None
2722- self ._fallback_policy_ref = None
2692+ """Set up fallback policy for weight extraction.
2693+
2694+ Note: _fallback_policy is set in _setup_multi_policy_and_weights when a policy
2695+ is passed directly (not policy_factory). When using policy_factory, users MUST
2696+ pass weights explicitly to update_policy_weights_().
2697+ """
2698+ # Nothing to do here - fallback is already set if a policy was provided
2699+ pass
27232700
27242701 def _setup_multi_total_frames (
27252702 self ,
0 commit comments