diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index f121a2fdb7a..57396e3c089 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -69,6 +69,8 @@
     set_exploration_type,
 )
 
+from torchrl.envs.llm.transforms.policy_version import PolicyVersion
+
 try:
     from torch.compiler import cudagraph_mark_step_begin
 except ImportError:
@@ -571,6 +573,11 @@ class SyncDataCollector(DataCollectorBase):
             or its subclass, responsible for updating the policy weights on remote inference workers.
             This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
             Consider using a constructor if the updater needs to be serialized.
+        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+            the policy version.
+            Defaults to `False`.
 
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -665,6 +672,7 @@ def __init__(
         weight_updater: WeightUpdaterBase
         | Callable[[], WeightUpdaterBase]
         | None = None,
+        track_policy_version: bool = False,
         **kwargs,
     ):
         from torchrl.envs.batched_envs import BatchedEnvBase
@@ -783,6 +791,33 @@ def __init__(
         self.env: EnvBase = env
         del env
+
+        # Policy version tracking setup
+        self.policy_version_tracker = track_policy_version
+        if PolicyVersion is not None:
+            if isinstance(track_policy_version, bool) and track_policy_version:
+                from torchrl.envs.batched_envs import BatchedEnvBase
+
+                if isinstance(self.env, BatchedEnvBase):
+                    raise RuntimeError(
+                        "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
+                        "and pass that transform to the collector."
+                    )
+                self.policy_version_tracker = PolicyVersion()
+                self.env = self.env.append_transform(self.policy_version_tracker)  # type: ignore
+            elif hasattr(
+                track_policy_version, "increment_version"
+            ):  # Check if it's a PolicyVersion instance
+                self.policy_version_tracker = track_policy_version
+                self.env = self.env.append_transform(self.policy_version_tracker)  # type: ignore
+            else:
+                self.policy_version_tracker = None
+        else:
+            if track_policy_version:
+                raise ImportError(
+                    "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False."
+                )
+            self.policy_version_tracker = None
 
         self.replay_buffer = replay_buffer
         self.extend_buffer = extend_buffer
         if self.replay_buffer is not None:
@@ -1755,6 +1790,34 @@ def __repr__(self) -> str:
         except Exception:
             return f"{type(self).__name__}(not_init)"
 
+    def increment_version(self):
+        """Increment the policy version."""
+        if self.policy_version_tracker is not None:
+            if not hasattr(self.policy_version_tracker, "increment_version"):
+                raise RuntimeError(
+                    "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
+                )
+            self.policy_version_tracker.increment_version()
+
+    @property
+    def policy_version(self) -> str | int | None:
+        """The current policy version."""
+        if not hasattr(self.policy_version_tracker, "version"):
+            return None
+        return self.policy_version_tracker.version
+
+    def get_policy_version(self) -> str | int | None:
+        """Get the current policy version.
+
+        This method exists to support remote calls in Ray actors, since properties
+        cannot be accessed directly through Ray's RPC mechanism.
+
+        Returns:
+            The current version number (int) or UUID (str), or None if version tracking is disabled.
+        """
+        return self.policy_version
+
+
 class _MultiDataCollector(DataCollectorBase):
     """Runs a given number of DataCollectors on separate processes.
 
@@ -1944,6 +2007,11 @@ class _MultiDataCollector(DataCollectorBase):
             If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default,
             which handles weight synchronization across multiple processes.
             Consider using a constructor if the updater needs to be serialized.
+        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+            the policy version.
+            Defaults to `False`.
 
     """
 
@@ -1989,6 +2057,7 @@ def __init__(
         weight_updater: WeightUpdaterBase
         | Callable[[], WeightUpdaterBase]
        | None = None,
+        track_policy_version: bool = False,
     ):
         self.closed = True
         if isinstance(create_env_fn, Sequence):
@@ -2125,6 +2194,24 @@ def __init__(
 
         self.weight_updater = weight_updater
 
+        # Policy version tracking setup
+        self.policy_version_tracker = track_policy_version
+        if PolicyVersion is not None:
+            if isinstance(track_policy_version, bool) and track_policy_version:
+                self.policy_version_tracker = PolicyVersion()
+            elif hasattr(
+                track_policy_version, "increment_version"
+            ):  # Check if it's a PolicyVersion instance
+                self.policy_version_tracker = track_policy_version
+            else:
+                self.policy_version_tracker = None
+        else:
+            if track_policy_version:
+                raise ImportError(
+                    "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False."
+                )
+            self.policy_version_tracker = None
+
         self.policy = policy
         self.policy_factory = policy_factory
 
@@ -2668,6 +2755,34 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
         self._frames = state_dict["frames"]
         self._iter = state_dict["iter"]
 
+    def increment_version(self):
+        """Increment the policy version."""
+        if self.policy_version_tracker is not None:
+            if not hasattr(self.policy_version_tracker, "increment_version"):
+                raise RuntimeError(
+                    "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
+                )
+            self.policy_version_tracker.increment_version()
+
+    @property
+    def policy_version(self) -> str | int | None:
+        """The current policy version."""
+        if not hasattr(self.policy_version_tracker, "version"):
+            return None
+        return self.policy_version_tracker.version
+
+    def get_policy_version(self) -> str | int | None:
+        """Get the current policy version.
+
+        This method exists to support remote calls in Ray actors, since properties
+        cannot be accessed directly through Ray's RPC mechanism.
+
+        Returns:
+            The current version number (int) or UUID (str), or None if version tracking is disabled.
+        """
+        return self.policy_version
+
+
 @accept_remote_rref_udf_invocation
 class MultiSyncDataCollector(_MultiDataCollector):
@@ -3473,6 +3588,11 @@ class aSyncDataCollector(MultiaSyncDataCollector):
             a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
             Truncated keys can be set through ``env.add_truncated_keys``.
             Defaults to ``False``.
+        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+            the policy version.
+            Defaults to `False`.
 
     """
 
@@ -3502,6 +3622,7 @@ def __init__(
         num_threads: int | None = None,
         num_sub_threads: int = 1,
         set_truncated: bool = False,
+        track_policy_version: bool = False,
         **kwargs,
     ):
         super().__init__(
@@ -3529,6 +3650,7 @@ def __init__(
             num_threads=num_threads,
             num_sub_threads=num_sub_threads,
             set_truncated=set_truncated,
+            track_policy_version=track_policy_version,
             **kwargs,
         )
 
@@ -3825,6 +3947,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
                 has_timed_out = False
                 continue
+
             elif msg == "close":
                 del collected_tensordict, data, next_data, data_in
                 inner_collector.shutdown()
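Usage sketch (illustrative, not part of the patch): the snippet below shows how the new ``track_policy_version`` argument and the ``increment_version`` / ``policy_version`` / ``get_policy_version`` API added to ``SyncDataCollector`` might be used. It assumes gymnasium and the torchrl LLM extras (which provide ``PolicyVersion``) are installed; the environment name and frame counts are placeholders.

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    collector = SyncDataCollector(
        GymEnv("CartPole-v1"),
        policy=None,                     # falls back to a random policy, for illustration only
        frames_per_batch=64,
        total_frames=256,
        track_policy_version=True,       # appends a PolicyVersion transform to the env
    )
    for batch in collector:
        # ... update the policy weights here ...
        collector.increment_version()    # record that a new policy version is in use
        print(collector.policy_version)  # int or UUID string; None if tracking is disabled
    collector.shutdown()

``get_policy_version()`` returns the same value as the ``policy_version`` property; as the docstring notes, it exists so that remote callers such as Ray actors, which cannot read properties over RPC, can still query the version.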
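A second illustrative sketch: the multiprocessed collectors accept the same argument, and a pre-built ``PolicyVersion`` instance can be passed instead of ``True`` when the tracker should be shared with other components (for instance a weight updater). The collector then reuses that instance as its ``policy_version_tracker``. Environment factories and frame counts below are placeholders.

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.llm.transforms.policy_version import PolicyVersion

    version_tracker = PolicyVersion()    # shared tracker instance
    collector = MultiSyncDataCollector(
        [lambda: GymEnv("CartPole-v1"), lambda: GymEnv("CartPole-v1")],
        policy=None,
        frames_per_batch=64,
        total_frames=256,
        track_policy_version=version_tracker,
    )
    for batch in collector:
        collector.increment_version()           # forwards to version_tracker.increment_version()
        print(collector.get_policy_version())   # method form, usable over Ray RPC
    collector.shutdown()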