Skip to content

Commit 87047ec

Browse files
committed
Update
[ghstack-poisoned]
1 parent e97e406 commit 87047ec

File tree

1 file changed

+82
-49
lines changed

1 file changed

+82
-49
lines changed

torchrl/collectors/collectors.py

Lines changed: 82 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,10 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
366366
else:
367367
return weights
368368

369+
@property
370+
def _legacy_weight_updater(self) -> bool:
371+
return self._weight_updater is not None
372+
369373
def update_policy_weights_(
370374
self,
371375
policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
@@ -408,6 +412,50 @@ def update_policy_weights_(
408412
:meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`.
409413
410414
"""
415+
if self._legacy_weight_updater:
416+
return self._legacy_weight_update_impl(
417+
policy_or_weights=policy_or_weights,
418+
worker_ids=worker_ids,
419+
model_id=model_id,
420+
weights_dict=weights_dict,
421+
**kwargs,
422+
)
423+
else:
424+
return self._weight_update_impl(
425+
policy_or_weights=policy_or_weights,
426+
worker_ids=worker_ids,
427+
model_id=model_id,
428+
weights_dict=weights_dict,
429+
**kwargs,
430+
)
431+
432+
def _legacy_weight_update_impl(
433+
self,
434+
policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
435+
*,
436+
worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
437+
model_id: str | None = None,
438+
weights_dict: dict[str, Any] | None = None,
439+
**kwargs,
440+
) -> None:
441+
if weights_dict is not None:
442+
raise ValueError("weights_dict is not supported with legacy weight updater")
443+
if model_id is not None:
444+
raise ValueError("model_id is not supported with legacy weight updater")
445+
# Fall back to old weight updater system
446+
self.weight_updater(
447+
policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
448+
)
449+
450+
def _weight_update_impl(
451+
self,
452+
policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
453+
*,
454+
worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
455+
model_id: str | None = None,
456+
weights_dict: dict[str, Any] | None = None,
457+
**kwargs,
458+
) -> None:
411459
if "policy_weights" in kwargs:
412460
warnings.warn(
413461
"`policy_weights` is deprecated. Use `policy_or_weights` instead.",
@@ -428,54 +476,32 @@ def update_policy_weights_(
428476

429477
# Priority: new weight sync schemes > old weight updater system
430478
if self._weight_senders:
431-
if weights_dict is not None:
432-
for target_model_id, weights in weights_dict.items():
433-
if target_model_id not in self._weight_senders:
434-
raise KeyError(
435-
f"Model '{target_model_id}' not found in registered weight senders. "
436-
f"Available models: {list(self._weight_senders.keys())}"
437-
)
438-
processed_weights = self._extract_weights_if_needed(
439-
weights, target_model_id
440-
)
441-
self._weight_senders[target_model_id].update_weights(
442-
processed_weights
479+
if model_id is not None:
480+
# Compose weight_dict
481+
weights_dict = {model_id: policy_or_weights}
482+
if weights_dict is None:
483+
if "policy" in self._weight_senders:
484+
weights_dict = {"policy": policy_or_weights}
485+
elif len(self._weight_senders) == 1:
486+
single_model_id = next(iter(self._weight_senders.keys()))
487+
weights_dict = {single_model_id: policy_or_weights}
488+
else:
489+
raise ValueError(
490+
"Cannot determine the model to update. Please provide a weights_dict."
443491
)
444-
elif model_id is not None:
445-
if model_id not in self._weight_senders:
492+
for target_model_id, weights in weights_dict.items():
493+
if target_model_id not in self._weight_senders:
446494
raise KeyError(
447-
f"Model '{model_id}' not found in registered weight senders. "
495+
f"Model '{target_model_id}' not found in registered weight senders. "
448496
f"Available models: {list(self._weight_senders.keys())}"
449497
)
450498
processed_weights = self._extract_weights_if_needed(
451-
policy_or_weights, model_id
499+
weights, target_model_id
452500
)
453-
self._weight_senders[model_id].update_weights(processed_weights)
454-
else:
455-
if "policy" in self._weight_senders:
456-
processed_weights = self._extract_weights_if_needed(
457-
policy_or_weights, "policy"
458-
)
459-
self._weight_senders["policy"].update_weights(processed_weights)
460-
elif len(self._weight_senders) == 1:
461-
single_model_id = next(iter(self._weight_senders.keys()))
462-
single_sender = self._weight_senders[single_model_id]
463-
processed_weights = self._extract_weights_if_needed(
464-
policy_or_weights, single_model_id
465-
)
466-
single_sender.update_weights(processed_weights)
467-
else:
468-
for target_model_id, sender in self._weight_senders.items():
469-
processed_weights = self._extract_weights_if_needed(
470-
policy_or_weights, target_model_id
471-
)
472-
sender.update_weights(processed_weights)
473-
501+
self._weight_senders[target_model_id].update_weights(processed_weights)
474502
elif self._weight_updater is not None:
475-
# Fall back to old weight updater system
476-
self.weight_updater(
477-
policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
478-
)
503+
# unreachable
504+
raise RuntimeError
479505
else:
480506
# No weight updater configured
481507
# For single-process collectors, apply weights locally if explicitly provided
@@ -2689,14 +2715,21 @@ def _setup_fallback_policy(
26892715
policy_factory: list[Callable | None],
26902716
weight_sync_schemes: dict[str, WeightSyncScheme] | None,
26912717
) -> None:
2692-
"""Set up fallback policy for weight extraction.
2693-
2694-
Note: _fallback_policy is set in _setup_multi_policy_and_weights when a policy
2695-
is passed directly (not policy_factory). When using policy_factory, users MUST
2696-
pass weights explicitly to update_policy_weights_().
2697-
"""
2698-
# Nothing to do here - fallback is already set if a policy was provided
2699-
pass
2718+
"""Set up fallback policy for weight extraction when using policy_factory."""
2719+
# _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided
2720+
# If policy_factory was used, create a policy instance to use as fallback
2721+
if policy is None and any(policy_factory) and weight_sync_schemes is not None:
2722+
if not hasattr(self, "_fallback_policy") or self._fallback_policy is None:
2723+
first_factory = (
2724+
policy_factory[0]
2725+
if isinstance(policy_factory, list)
2726+
else policy_factory
2727+
)
2728+
if first_factory is not None:
2729+
# Create a policy instance for weight extraction
2730+
# This will be a reference to a policy with the same structure
2731+
# For shared memory, modifications to any policy will be visible here
2732+
self._fallback_policy = first_factory()
27002733

27012734
def _setup_multi_total_frames(
27022735
self,

0 commit comments

Comments
 (0)