Skip to content

Commit e97e406

Browse files
committed
Update
[ghstack-poisoned]
1 parent e913282 commit e97e406

File tree

2 files changed

+16
-37
lines changed

2 files changed

+16
-37
lines changed

test/test_collector.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3884,14 +3884,16 @@ def test_weight_update(self, weight_updater):
38843884
**kwargs,
38853885
)
38863886

3887-
collector.update_policy_weights_()
3887+
# When using policy_factory, must pass weights explicitly
3888+
collector.update_policy_weights_(policy_weights)
38883889
try:
38893890
for i, data in enumerate(collector):
38903891
if i == 2:
38913892
assert (data["action"] != 0).any()
38923893
# zero the policy
38933894
policy_weights.data.zero_()
3894-
collector.update_policy_weights_()
3895+
# When using policy_factory, must pass weights explicitly
3896+
collector.update_policy_weights_(policy_weights)
38953897
elif i == 3:
38963898
assert (data["action"] == 0).all(), data["action"]
38973899
break

torchrl/collectors/collectors.py

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,6 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
331331
and self._fallback_policy is not None
332332
):
333333
model = self._fallback_policy
334-
elif (
335-
hasattr(self, "_fallback_policy_ref")
336-
and self._fallback_policy_ref is not None
337-
):
338-
model = self._fallback_policy_ref()
339334

340335
if model is not None:
341336
return strategy.extract_weights(model)
@@ -2611,6 +2606,7 @@ def _setup_multi_policy_and_weights(
26112606
) -> None:
26122607
"""Set up policy and extract weights for each device."""
26132608
self._policy_weights_dict = {}
2609+
self._fallback_policy = None # Policy to use for weight extraction fallback
26142610

26152611
if any(policy_factory) and policy is not None:
26162612
raise TypeError("policy_factory and policy are mutually exclusive")
@@ -2632,6 +2628,9 @@ def _setup_multi_policy_and_weights(
26322628
else TensorDict()
26332629
)
26342630
self._policy_weights_dict[policy_device] = weights
2631+
# Store the first policy instance for fallback weight extraction
2632+
if self._fallback_policy is None:
2633+
self._fallback_policy = policy_new_device
26352634
self._get_weights_fn = get_weights_fn
26362635
if weight_updater is None:
26372636
# For multiprocessed collectors, use MultiProcessWeightSyncScheme by default
@@ -2690,36 +2689,14 @@ def _setup_fallback_policy(
26902689
policy_factory: list[Callable | None],
26912690
weight_sync_schemes: dict[str, WeightSyncScheme] | None,
26922691
) -> None:
2693-
"""Set up fallback policy for weight extraction when using policy_factory."""
2694-
if policy is None and any(policy_factory) and weight_sync_schemes is not None:
2695-
# Create a policy instance from the first factory for weight extraction
2696-
import weakref
2697-
2698-
first_factory = (
2699-
policy_factory[0]
2700-
if isinstance(policy_factory, list)
2701-
else policy_factory
2702-
)
2703-
if first_factory is not None:
2704-
fallback_policy = first_factory()
2705-
# For shared memory schemes (CPU or CUDA), store the actual policy
2706-
# For pipe-based schemes, store a weak reference
2707-
first_scheme = next(iter(weight_sync_schemes.values()))
2708-
from torchrl.weight_update.weight_sync_schemes import (
2709-
SharedMemWeightSyncScheme,
2710-
)
2711-
2712-
if isinstance(first_scheme, SharedMemWeightSyncScheme):
2713-
# Shared memory: store actual policy (weights are in shared mem)
2714-
self._fallback_policy = fallback_policy
2715-
else:
2716-
# Pipe-based: store weak reference
2717-
self._fallback_policy_ref = weakref.ref(fallback_policy)
2718-
# Keep the policy alive by storing it
2719-
self._fallback_policy = fallback_policy
2720-
else:
2721-
self._fallback_policy = None
2722-
self._fallback_policy_ref = None
2692+
"""Set up fallback policy for weight extraction.
2693+
2694+
Note: _fallback_policy is set in _setup_multi_policy_and_weights when a policy
2695+
is passed directly (not policy_factory). When using policy_factory, users MUST
2696+
pass weights explicitly to update_policy_weights_().
2697+
"""
2698+
# Nothing to do here - fallback is already set if a policy was provided
2699+
pass
27232700

27242701
def _setup_multi_total_frames(
27252702
self,

0 commit comments

Comments
 (0)