Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
set_exploration_type,
)

from torchrl.envs.llm.transforms.policy_version import PolicyVersion

try:
from torch.compiler import cudagraph_mark_step_begin
except ImportError:
Expand Down Expand Up @@ -571,6 +573,11 @@ class SyncDataCollector(DataCollectorBase):
or its subclass, responsible for updating the policy weights on remote inference workers.
This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
Consider using a constructor if the updater needs to be serialized.
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
the policy version.
Defaults to `False`.

Examples:
>>> from torchrl.envs.libs.gym import GymEnv
Expand Down Expand Up @@ -665,6 +672,7 @@ def __init__(
weight_updater: WeightUpdaterBase
| Callable[[], WeightUpdaterBase]
| None = None,
track_policy_version: bool = False,
**kwargs,
):
from torchrl.envs.batched_envs import BatchedEnvBase
Expand Down Expand Up @@ -783,6 +791,33 @@ def __init__(

self.env: EnvBase = env
del env

# Policy version tracking setup
self.policy_version_tracker = track_policy_version
if PolicyVersion is not None:
if isinstance(track_policy_version, bool) and track_policy_version:
from torchrl.envs.batched_envs import BatchedEnvBase

if isinstance(self.env, BatchedEnvBase):
raise RuntimeError(
"BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
"and pass that transform to the collector."
)
self.policy_version_tracker = PolicyVersion()
self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore
elif hasattr(
track_policy_version, "increment_version"
): # Check if it's a PolicyVersion instance
self.policy_version_tracker = track_policy_version
self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore
else:
self.policy_version_tracker = None
else:
if track_policy_version:
raise ImportError(
"PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False."
)
self.policy_version_tracker = None
self.replay_buffer = replay_buffer
self.extend_buffer = extend_buffer
if self.replay_buffer is not None:
Expand Down Expand Up @@ -1755,6 +1790,34 @@ def __repr__(self) -> str:
except Exception:
return f"{type(self).__name__}(not_init)"

def increment_version(self):
    """Bump the tracked policy version by one.

    No-op when version tracking is disabled (tracker is ``None``).

    Raises:
        RuntimeError: if a tracker was set but does not expose
            ``increment_version`` (i.e. it is not a ``PolicyVersion``).
    """
    tracker = self.policy_version_tracker
    if tracker is None:
        # Version tracking disabled: silently do nothing.
        return
    if not hasattr(tracker, "increment_version"):
        raise RuntimeError(
            "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
        )
    tracker.increment_version()

@property
def policy_version(self) -> str | int | None:
    """The current policy version, or ``None`` when tracking is disabled.

    The tracker may report an ``int`` counter or a ``str`` UUID; any object
    without a ``version`` attribute (including ``None``) yields ``None``.
    """
    tracker = self.policy_version_tracker
    return tracker.version if hasattr(tracker, "version") else None

def get_policy_version(self) -> str | int | None:
    """Return the current policy version as a plain method call.

    Thin wrapper around the :attr:`policy_version` property; it exists
    because Ray actors expose remote methods, not properties, over RPC.

    Returns:
        The version counter (``int``), a UUID (``str``), or ``None`` when
        version tracking is disabled.
    """
    return self.policy_version



class _MultiDataCollector(DataCollectorBase):
"""Runs a given number of DataCollectors on separate processes.
Expand Down Expand Up @@ -1944,6 +2007,11 @@ class _MultiDataCollector(DataCollectorBase):
If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default,
which handles weight synchronization across multiple processes.
Consider using a constructor if the updater needs to be serialized.
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
the policy version.
Defaults to `False`.

"""

Expand Down Expand Up @@ -1989,6 +2057,7 @@ def __init__(
weight_updater: WeightUpdaterBase
| Callable[[], WeightUpdaterBase]
| None = None,
track_policy_version: bool = False,
):
self.closed = True
if isinstance(create_env_fn, Sequence):
Expand Down Expand Up @@ -2125,6 +2194,24 @@ def __init__(

self.weight_updater = weight_updater

# Policy version tracking setup
self.policy_version_tracker = track_policy_version
if PolicyVersion is not None:
if isinstance(track_policy_version, bool) and track_policy_version:
self.policy_version_tracker = PolicyVersion()
elif hasattr(
track_policy_version, "increment_version"
): # Check if it's a PolicyVersion instance
self.policy_version_tracker = track_policy_version
else:
self.policy_version_tracker = None
else:
if track_policy_version:
raise ImportError(
"PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False."
)
self.policy_version_tracker = None

self.policy = policy
self.policy_factory = policy_factory

Expand Down Expand Up @@ -2668,6 +2755,34 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
self._frames = state_dict["frames"]
self._iter = state_dict["iter"]

def increment_version(self):
    """Advance the policy version counter on the attached tracker.

    Does nothing when no tracker is configured (``None``).

    Raises:
        RuntimeError: if the configured tracker lacks an
            ``increment_version`` method (not a ``PolicyVersion``).
    """
    tracker = self.policy_version_tracker
    if tracker is not None:
        if hasattr(tracker, "increment_version"):
            tracker.increment_version()
        else:
            raise RuntimeError(
                "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
            )

@property
def policy_version(self) -> str | int | None:
    """Current policy version, or ``None`` if version tracking is off.

    Mirrors the same property on ``SyncDataCollector``: anything without a
    ``version`` attribute (including a ``None`` tracker) maps to ``None``.
    """
    # EAFP: equivalent to a hasattr() probe, which itself swallows
    # AttributeError under the hood.
    try:
        return self.policy_version_tracker.version
    except AttributeError:
        return None

def get_policy_version(self) -> str | int | None:
    """Fetch the current policy version via a regular method.

    Delegates to the :attr:`policy_version` property. Provided because
    properties cannot be invoked remotely on Ray actors, while methods can.

    Returns:
        ``int`` counter, ``str`` UUID, or ``None`` when tracking is disabled.
    """
    return self.policy_version



@accept_remote_rref_udf_invocation
class MultiSyncDataCollector(_MultiDataCollector):
Expand Down Expand Up @@ -3473,6 +3588,11 @@ class aSyncDataCollector(MultiaSyncDataCollector):
a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
the policy version.
Defaults to `False`.

"""

Expand Down Expand Up @@ -3502,6 +3622,7 @@ def __init__(
num_threads: int | None = None,
num_sub_threads: int = 1,
set_truncated: bool = False,
track_policy_version: bool = False,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -3529,6 +3650,7 @@ def __init__(
num_threads=num_threads,
num_sub_threads=num_sub_threads,
set_truncated=set_truncated,
track_policy_version=track_policy_version,
**kwargs,
)

Expand Down Expand Up @@ -3825,6 +3947,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
has_timed_out = False
continue


elif msg == "close":
del collected_tensordict, data, next_data, data_in
inner_collector.shutdown()
Expand Down
Loading