@@ -366,6 +366,10 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
366366 else :
367367 return weights
368368
@property
def _legacy_weight_updater(self) -> bool:
    """Tell whether a legacy weight updater is attached to this object.

    Returns:
        bool: ``True`` when ``self._weight_updater`` is anything other than
        ``None``, ``False`` otherwise.
    """
    updater = self._weight_updater
    return updater is not None
372+
369373 def update_policy_weights_ (
370374 self ,
371375 policy_or_weights : TensorDictBase | TensorDictModuleBase | dict | None = None ,
@@ -408,6 +412,50 @@ def update_policy_weights_(
408412 :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`.
409413
410414 """
415+ if self ._legacy_weight_updater :
416+ return self ._legacy_weight_update_impl (
417+ policy_or_weights = policy_or_weights ,
418+ worker_ids = worker_ids ,
419+ model_id = model_id ,
420+ weights_dict = weights_dict ,
421+ ** kwargs ,
422+ )
423+ else :
424+ return self ._weight_update_impl (
425+ policy_or_weights = policy_or_weights ,
426+ worker_ids = worker_ids ,
427+ model_id = model_id ,
428+ weights_dict = weights_dict ,
429+ ** kwargs ,
430+ )
431+
def _legacy_weight_update_impl(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    *,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    model_id: str | None = None,
    weights_dict: dict[str, Any] | None = None,
    **kwargs,
) -> None:
    """Push weights through the legacy ``weight_updater`` callable.

    Args:
        policy_or_weights: value forwarded to ``self.weight_updater`` as
            ``policy_or_weights``.
        worker_ids: value forwarded to ``self.weight_updater`` as
            ``worker_ids``.
        model_id: must be ``None``; the legacy updater has no notion of
            model identifiers.
        weights_dict: must be ``None``; the legacy updater does not accept
            a per-model mapping.
        **kwargs: extra keyword arguments forwarded to
            ``self.weight_updater``. The deprecated ``policy_weights`` key
            is translated to ``policy_or_weights`` before forwarding.

    Raises:
        ValueError: if ``weights_dict`` or ``model_id`` is provided.
    """
    if weights_dict is not None:
        raise ValueError("weights_dict is not supported with legacy weight updater")
    if model_id is not None:
        raise ValueError("model_id is not supported with legacy weight updater")

    # The non-legacy implementation translates the deprecated `policy_weights`
    # keyword; do the same here so the legacy path does not forward an
    # unexpected keyword to the updater.
    # NOTE(review): warning category and override semantics are assumed to
    # match the non-legacy shim - confirm.
    if "policy_weights" in kwargs:
        warnings.warn(
            "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
            DeprecationWarning,
        )
        policy_or_weights = kwargs.pop("policy_weights")

    # Fall back to old weight updater system
    self.weight_updater(
        policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
    )
449+
450+ def _weight_update_impl (
451+ self ,
452+ policy_or_weights : TensorDictBase | TensorDictModuleBase | dict | None = None ,
453+ * ,
454+ worker_ids : int | list [int ] | torch .device | list [torch .device ] | None = None ,
455+ model_id : str | None = None ,
456+ weights_dict : dict [str , Any ] | None = None ,
457+ ** kwargs ,
458+ ) -> None :
411459 if "policy_weights" in kwargs :
412460 warnings .warn (
413461 "`policy_weights` is deprecated. Use `policy_or_weights` instead." ,
@@ -428,54 +476,32 @@ def update_policy_weights_(
428476
429477 # Priority: new weight sync schemes > old weight updater system
430478 if self ._weight_senders :
431- if weights_dict is not None :
432- for target_model_id , weights in weights_dict . items ():
433- if target_model_id not in self . _weight_senders :
434- raise KeyError (
435- f"Model ' { target_model_id } ' not found in registered weight senders. "
436- f"Available models: { list ( self . _weight_senders . keys ()) } "
437- )
438- processed_weights = self ._extract_weights_if_needed (
439- weights , target_model_id
440- )
441- self . _weight_senders [ target_model_id ]. update_weights (
442- processed_weights
479+ if model_id is not None :
480+ # Compose weight_dict
481+ weights_dict = { model_id : policy_or_weights }
482+ if weights_dict is None :
483+ if "policy" in self . _weight_senders :
484+ weights_dict = { "policy" : policy_or_weights }
485+ elif len ( self . _weight_senders ) == 1 :
486+ single_model_id = next ( iter ( self ._weight_senders . keys ()))
487+ weights_dict = { single_model_id : policy_or_weights }
488+ else :
489+ raise ValueError (
490+ "Cannot determine the model to update. Please provide a weights_dict."
443491 )
444- elif model_id is not None :
445- if model_id not in self ._weight_senders :
492+ for target_model_id , weights in weights_dict . items () :
493+ if target_model_id not in self ._weight_senders :
446494 raise KeyError (
447- f"Model '{ model_id } ' not found in registered weight senders. "
495+ f"Model '{ target_model_id } ' not found in registered weight senders. "
448496 f"Available models: { list (self ._weight_senders .keys ())} "
449497 )
450498 processed_weights = self ._extract_weights_if_needed (
451- policy_or_weights , model_id
499+ weights , target_model_id
452500 )
453- self ._weight_senders [model_id ].update_weights (processed_weights )
454- else :
455- if "policy" in self ._weight_senders :
456- processed_weights = self ._extract_weights_if_needed (
457- policy_or_weights , "policy"
458- )
459- self ._weight_senders ["policy" ].update_weights (processed_weights )
460- elif len (self ._weight_senders ) == 1 :
461- single_model_id = next (iter (self ._weight_senders .keys ()))
462- single_sender = self ._weight_senders [single_model_id ]
463- processed_weights = self ._extract_weights_if_needed (
464- policy_or_weights , single_model_id
465- )
466- single_sender .update_weights (processed_weights )
467- else :
468- for target_model_id , sender in self ._weight_senders .items ():
469- processed_weights = self ._extract_weights_if_needed (
470- policy_or_weights , target_model_id
471- )
472- sender .update_weights (processed_weights )
473-
501+ self ._weight_senders [target_model_id ].update_weights (processed_weights )
474502 elif self ._weight_updater is not None :
475- # Fall back to old weight updater system
476- self .weight_updater (
477- policy_or_weights = policy_or_weights , worker_ids = worker_ids , ** kwargs
478- )
503+ # unreachable
504+ raise RuntimeError
479505 else :
480506 # No weight updater configured
481507 # For single-process collectors, apply weights locally if explicitly provided
@@ -2689,14 +2715,21 @@ def _setup_fallback_policy(
26892715 policy_factory : list [Callable | None ],
26902716 weight_sync_schemes : dict [str , WeightSyncScheme ] | None ,
26912717 ) -> None :
2692- """Set up fallback policy for weight extraction.
2693-
2694- Note: _fallback_policy is set in _setup_multi_policy_and_weights when a policy
2695- is passed directly (not policy_factory). When using policy_factory, users MUST
2696- pass weights explicitly to update_policy_weights_().
2697- """
2698- # Nothing to do here - fallback is already set if a policy was provided
2699- pass
2718+ """Set up fallback policy for weight extraction when using policy_factory."""
2719+ # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided
2720+ # If policy_factory was used, create a policy instance to use as fallback
2721+ if policy is None and any (policy_factory ) and weight_sync_schemes is not None :
2722+ if not hasattr (self , "_fallback_policy" ) or self ._fallback_policy is None :
2723+ first_factory = (
2724+ policy_factory [0 ]
2725+ if isinstance (policy_factory , list )
2726+ else policy_factory
2727+ )
2728+ if first_factory is not None :
2729+ # Create a policy instance for weight extraction
2730+ # This will be a reference to a policy with the same structure
2731+ # For shared memory, modifications to any policy will be visible here
2732+ self ._fallback_policy = first_factory ()
27002733
27012734 def _setup_multi_total_frames (
27022735 self ,
0 commit comments