diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 4eb55d3ac7f..0ee129fdcba 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -1123,6 +1123,7 @@ to be able to create this other composition: ExcludeTransform FiniteTensorDictCheck FlattenObservation + FlattenTensorDict FrameSkipTransform GrayScale Hash diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 063a0f26735..c47436d11a8 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -184,6 +184,7 @@ Trainer and hooks TrainerHookBase UpdateWeights TargetNetUpdaterHook + UTDRHook Algorithm-specific trainers (Experimental) diff --git a/sota-implementations/ppo_trainer/config/config.yaml b/sota-implementations/ppo_trainer/config/config.yaml index bb811a4b9cc..3570d21625c 100644 --- a/sota-implementations/ppo_trainer/config/config.yaml +++ b/sota-implementations/ppo_trainer/config/config.yaml @@ -5,6 +5,7 @@ defaults: - transform@transform0: noop_reset - transform@transform1: step_counter + - transform@transform2: reward_sum - env@training_env: batched_env - env@training_env.create_env_fn: transformed_env @@ -64,6 +65,10 @@ transform1: max_steps: 200 step_count_key: "step_count" +transform2: + in_keys: ["reward"] + out_keys: ["reward_sum"] + training_env: num_workers: 1 create_env_fn: @@ -73,6 +78,7 @@ training_env: transforms: - ${transform0} - ${transform1} + - ${transform2} _partial_: true # Loss configuration @@ -92,6 +98,7 @@ collector: total_frames: 1_000_000 frames_per_batch: 1024 num_workers: 2 + _partial_: true # Replay buffer configuration replay_buffer: @@ -129,3 +136,4 @@ trainer: save_trainer_file: null optim_steps_per_batch: null num_epochs: 2 + async_collection: false diff --git a/sota-implementations/sac_trainer/config/config.yaml b/sota-implementations/sac_trainer/config/config.yaml index 2f794c9bfa2..edb146157ff 100644 --- a/sota-implementations/sac_trainer/config/config.yaml +++ b/sota-implementations/sac_trainer/config/config.yaml @@ -5,6 +5,7 @@ defaults: - transform@transform0: step_counter - transform@transform1: double_to_float + - transform@transform2: reward_sum - env@training_env: batched_env - env@training_env.create_env_fn: transformed_env @@ -72,6 +73,11 @@ transform1: in_keys: null out_keys: null +transform2: + # RewardSumTransform - sums up the rewards + in_keys: ["reward"] + out_keys: ["reward_sum"] + training_env: num_workers: 4 create_env_fn: @@ -81,6 +87,7 @@ training_env: transforms: - ${transform0} - ${transform1} + - ${transform2} _partial_: true # Loss configuration @@ -107,19 +114,21 @@ collector: total_frames: 1_000_000 frames_per_batch: 1000 num_workers: 4 - init_random_frames: 25000 + init_random_frames: 2500 track_policy_version: true + _partial_: true # Replay buffer configuration replay_buffer: storage: - max_size: 1_000_000 + max_size: 100_000 device: cpu ndim: 1 sampler: writer: compilable: false - batch_size: 256 + batch_size: 64 + shared: true logger: exp_name: sac_halfcheetah_v4 @@ -134,7 +143,7 @@ trainer: target_net_updater: ${target_net_updater} loss_module: ${loss} logger: ${logger} - total_frames: 1_000_000 + total_frames: ${collector.total_frames} frame_skip: 1 clip_grad_norm: false # SAC typically doesn't use gradient clipping clip_norm: null @@ -144,3 +153,4 @@ trainer: log_interval: 25000 save_trainer_file: null optim_steps_per_batch: 64 # Match SOTA utd_ratio + async_collection: false diff --git 
a/sota-implementations/sac_trainer/config/config_async.yaml b/sota-implementations/sac_trainer/config/config_async.yaml new file mode 100644 index 00000000000..d1509bd334f --- /dev/null +++ b/sota-implementations/sac_trainer/config/config_async.yaml @@ -0,0 +1,164 @@ +# SAC Trainer Configuration for HalfCheetah-v4 +# Run with `python sota-implementations/sac_trainer/train.py --config-name=config_async` +# This configuration uses the new configurable trainer system and matches SOTA SAC implementation + +defaults: + + - transform@transform0: step_counter + - transform@transform1: double_to_float + - transform@transform2: reward_sum + - transform@transform3: flatten_tensordict + + - env@training_env: batched_env + - env@training_env.create_env_fn: transformed_env + - env@training_env.create_env_fn.base_env: gym + - transform@training_env.create_env_fn.transform: compose + + - model@models.policy_model: tanh_normal + - model@models.value_model: value + - model@models.qvalue_model: value + + - network@networks.policy_network: mlp + - network@networks.value_network: mlp + - network@networks.qvalue_network: mlp + + - collector@collector: multi_async + + - replay_buffer@replay_buffer: base + - storage@replay_buffer.storage: lazy_tensor + - writer@replay_buffer.writer: round_robin + - sampler@replay_buffer.sampler: random + - trainer@trainer: sac + - optimizer@optimizer: adam + - loss@loss: sac + - target_net_updater@target_net_updater: soft + - logger@logger: wandb + - _self_ + +# Network configurations +networks: + policy_network: + out_features: 12 # HalfCheetah action space is 6-dimensional (loc + scale) = 2 * 6 + in_features: 17 # HalfCheetah observation space is 17-dimensional + num_cells: [256, 256] + + value_network: + out_features: 1 # Value output + in_features: 17 # HalfCheetah observation space + num_cells: [256, 256] + + qvalue_network: + out_features: 1 # Q-value output + in_features: 23 # HalfCheetah observation space (17) + action space (6) + num_cells: [256, 256] + +# Model configurations +models: + policy_model: + return_log_prob: true + in_keys: ["observation"] + param_keys: ["loc", "scale"] + out_keys: ["action"] + network: ${networks.policy_network} + # Configure NormalParamExtractor for higher exploration + scale_mapping: "biased_softplus_2.0" # Higher bias for more exploration (default: 1.0) + scale_lb: 1e-2 # Minimum scale value (default: 1e-4) + + qvalue_model: + in_keys: ["observation", "action"] + out_keys: ["state_action_value"] + network: ${networks.qvalue_network} + +transform0: + max_steps: 1000 + step_count_key: "step_count" + +transform1: + # DoubleToFloatTransform - converts double precision to float to fix dtype mismatch + in_keys: null + out_keys: null + +transform2: + # RewardSumTransform - sums up the rewards + in_keys: ["reward"] + out_keys: ["reward_sum"] + +training_env: + num_workers: 4 + create_env_fn: + base_env: + env_name: HalfCheetah-v4 + transform: + transforms: + - ${transform0} + - ${transform1} + - ${transform2} + _partial_: true + +# Loss configuration +loss: + actor_network: ${models.policy_model} + qvalue_network: ${models.qvalue_model} + target_entropy: "auto" + loss_function: l2 + alpha_init: 1.0 + delay_qvalue: true + num_qvalue_nets: 2 + +target_net_updater: + tau: 0.001 + +# Optimizer configuration +optimizer: + lr: 3.0e-4 + +# Collector configuration +collector: + create_env_fn: ${training_env} + policy: ${models.policy_model} + total_frames: 5_000_000 + frames_per_batch: 1000 + num_workers: 8 + # Incompatible with async collection + 
init_random_frames: 0 + track_policy_version: true + extend_buffer: true + _partial_: true + +# Replay buffer configuration +replay_buffer: + storage: + max_size: 10_000 + device: cpu + ndim: 1 + sampler: + writer: + compilable: false + batch_size: 256 + shared: true + transform: ${transform3} + +logger: + exp_name: sac_halfcheetah_v4 + offline: false + project: torchrl-sota-implementations + +# Trainer configuration +trainer: + collector: ${collector} + optimizer: ${optimizer} + replay_buffer: ${replay_buffer} + target_net_updater: ${target_net_updater} + loss_module: ${loss} + logger: ${logger} + total_frames: ${collector.total_frames} + frame_skip: 1 + clip_grad_norm: false # SAC typically doesn't use gradient clipping + clip_norm: null + progress_bar: true + seed: 42 + save_trainer_interval: 25000 # Match SOTA eval_iter + log_interval: 25000 + save_trainer_file: null + optim_steps_per_batch: 16 # Match SOTA utd_ratio + async_collection: true diff --git a/sota-implementations/sac_trainer/train.py b/sota-implementations/sac_trainer/train.py index 2df69106df9..2c4a6b0832f 100644 --- a/sota-implementations/sac_trainer/train.py +++ b/sota-implementations/sac_trainer/train.py @@ -3,17 +3,12 @@ # LICENSE file in the root directory of this source tree. import hydra -import torchrl from torchrl.trainers.algorithms.configs import * # noqa: F401, F403 @hydra.main(config_path="config", config_name="config", version_base="1.1") def main(cfg): - def print_reward(td): - torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}") - trainer = hydra.utils.instantiate(cfg.trainer) - trainer.register_op(dest="batch_process", op=print_reward) trainer.train() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index cb7d66a42e7..34937c22d47 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -60,6 +60,8 @@ from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import _do_nothing, EnvBase from torchrl.envs.env_creator import EnvCreator + +from torchrl.envs.llm.transforms.policy_version import PolicyVersion from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( _aggregate_end_of_traj, @@ -69,8 +71,6 @@ set_exploration_type, ) -from torchrl.envs.llm.transforms.policy_version import PolicyVersion - try: from torch.compiler import cudagraph_mark_step_begin except ImportError: @@ -1818,13 +1818,20 @@ def get_policy_version(self) -> str | int | None: return self.policy_version def getattr_policy(self, attr): + """Get an attribute from the policy.""" # send command to policy to return the attr return getattr(self.policy, attr) def getattr_env(self, attr): + """Get an attribute from the environment.""" # send command to env to return the attr return getattr(self.env, attr) + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + # send command to rb to return the attr + return getattr(self.replay_buffer, attr) + class _MultiDataCollector(DataCollectorBase): """Runs a given number of DataCollectors on separate processes. @@ -2153,6 +2160,7 @@ def __init__( and hasattr(replay_buffer, "shared") and not replay_buffer.shared ): + torchrl_logger.warning("Replay buffer is not shared. 
Sharing it.") replay_buffer.share() self._policy_weights_dict = {} @@ -2306,8 +2314,8 @@ def _check_replay_buffer_init(self): fake_td["collector", "traj_ids"] = torch.zeros( fake_td.shape, dtype=torch.long ) - - self.replay_buffer.add(fake_td) + # Use extend to avoid time-related transforms to fail + self.replay_buffer.extend(fake_td.unsqueeze(-1)) self.replay_buffer.empty() @classmethod @@ -2841,6 +2849,10 @@ def getattr_env(self, attr): return result + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + return getattr(self.replay_buffer, attr) + @accept_remote_rref_udf_invocation class MultiSyncDataCollector(_MultiDataCollector): diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index fe319b4497a..3fe6ce7b919 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1172,6 +1172,10 @@ def max_size_along_dim0(data_shape): self._storage = out self.initialized = True + if hasattr(self._storage, "shape"): + torchrl_logger.info( + f"Initialized LazyTensorStorage with {self._storage.shape} shape" + ) class LazyMemmapStorage(LazyTensorStorage): @@ -1391,6 +1395,10 @@ def max_size_along_dim0(data_shape): else: out = _init_pytree(self.scratch_dir, max_size_along_dim0, data) self._storage = out + if hasattr(self._storage, "shape"): + torchrl_logger.info( + f"Initialized LazyMemmapStorage with {self._storage.shape} shape" + ) self.initialized = True def get(self, index: int | Sequence[int] | slice) -> Any: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 11799cc79dc..8f218c218e0 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -7261,7 +7261,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec class RewardSum(Transform): """Tracks episode cumulative rewards. - This transform accepts a list of tensordict reward keys (i.e. ´in_keys´) and tracks their cumulative + This transform accepts a list of tensordict reward keys (i.e. 'in_keys') and tracks their cumulative value along the time dimension for each episode. When called, the transform writes a new tensordict entry for each ``in_key`` named @@ -7269,7 +7269,7 @@ class RewardSum(Transform): Args: in_keys (list of NestedKeys, optional): Input reward keys. - All ´in_keys´ should be part of the environment reward_spec. + All 'in_keys' should be part of the environment reward_spec. If no ``in_keys`` are specified, this transform assumes ``"reward"`` to be the input key. However, multiple rewards (e.g. ``"reward1"`` and ``"reward2""``) can also be specified. out_keys (list of NestedKeys, optional): The output sum keys, should be one per each input key. @@ -11610,3 +11610,123 @@ def forward(self, tensordict: TensorDictBase) -> Any: raise RuntimeError( "ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional." ) + + +class FlattenTensorDict(Transform): + """Flattens TensorDict batch dimensions during inverse pass for replay buffer usage. + + This transform is specifically designed for replay buffers where data needs + to be flattened before being stored. It performs an identity operation during + the forward pass and flattens the batch dimensions during the inverse pass. + + This is useful when collecting batched data that needs to be stored as + individual experiences in a replay buffer. + + .. warning:: + This transform is NOT intended for use with environments. 
If you try to use + it as an environment transform, it will raise an exception. For reshaping + environment batch dimensions, use :class:`~torchrl.envs.BatchSizeTransform` + instead. + + .. note:: + This transform should be applied to replay buffers, not to environments. + It is designed to be used with :meth:`~torchrl.data.ReplayBuffer.append_transform`. + + Examples: + Using with a replay buffer: + + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.envs.transforms import FlattenTensorDict + >>> from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage + >>> + >>> # Create a replay buffer with the transform + >>> transform = FlattenTensorDict() + >>> rb = TensorDictReplayBuffer( + ... storage=LazyTensorStorage(1000), + ... transform=transform, + ... batch_size=32 + ... ) + >>> + >>> # Create batched data (e.g., from multiple environments) + >>> td = TensorDict({ + ... "observation": torch.randn(4, 2, 3), + ... "action": torch.randn(4, 2, 1), + ... "reward": torch.randn(4, 2, 1), + ... }, batch_size=[4, 2]) + >>> + >>> # When extending the buffer, data gets flattened automatically + >>> rb.extend(td) # Data is flattened from [4, 2] to [8] before storage + >>> + >>> # When sampling, data comes out in the requested batch size + >>> sample = rb.sample(4) # Shape will be [4, ...] + + Direct usage (for testing): + + >>> # Forward pass (identity) + >>> td_forward = transform(td) + >>> print(td_forward.batch_size) # [4, 2] + >>> + >>> # Inverse pass (flatten) + >>> td_inverse = transform.inv(td) + >>> print(td_inverse.batch_size) # [8] + """ + + _ENV_ERROR_MSG = ( + "FlattenTensorDict is designed for replay buffers and should not be used " + "as an environment transform. For reshaping environment batch dimensions, " + "use BatchSizeTransform instead." 
+ ) + + def __init__(self): + super().__init__(in_keys=[], out_keys=[]) + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + """Forward pass - identity operation.""" + return tensordict + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + """Inverse pass - flatten the tensordict.""" + return tensordict.reshape(-1) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Forward pass - identity operation.""" + return self._call(tensordict) + + def inv(self, tensordict: TensorDictBase) -> TensorDictBase: + """Inverse pass - flatten the tensordict.""" + return self._inv_call(tensordict) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + """Reset pass - identity operation.""" + return self._call(tensordict_reset) + + def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + """Transform input spec - not supported for environments.""" + raise RuntimeError(self._ENV_ERROR_MSG) + + def transform_output_spec(self, output_spec: Composite) -> Composite: + """Transform output spec - not supported for environments.""" + raise RuntimeError(self._ENV_ERROR_MSG) + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + """Transform observation spec - not supported for environments.""" + raise RuntimeError(self._ENV_ERROR_MSG) + + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + """Transform action spec - not supported for environments.""" + raise RuntimeError(self._ENV_ERROR_MSG) + + def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec: + """Transform state spec - not supported for environments.""" + raise RuntimeError(self._ENV_ERROR_MSG) + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + """Transform reward spec - not supported for environments.""" + raise RuntimeError(self._ENV_ERROR_MSG) + + def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec: + """Transform done spec - not supported for environments.""" + raise RuntimeError(self._ENV_ERROR_MSG) diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 3eb3e7ca75d..93ea6134aca 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -20,6 +20,7 @@ Trainer, TrainerHookBase, UpdateWeights, + UTDRHook, ) __all__ = [ @@ -39,4 +40,5 @@ "TrainerHookBase", "UpdateWeights", "TargetNetUpdaterHook", + "UTDRHook", ] diff --git a/torchrl/trainers/algorithms/configs/__init__.py b/torchrl/trainers/algorithms/configs/__init__.py index 52c50ed8e11..f76b83b1387 100644 --- a/torchrl/trainers/algorithms/configs/__init__.py +++ b/torchrl/trainers/algorithms/configs/__init__.py @@ -119,6 +119,7 @@ ExcludeTransformConfig, FiniteTensorDictCheckConfig, FlattenObservationConfig, + FlattenTensorDictConfig, FrameSkipTransformConfig, GrayScaleConfig, HashConfig, @@ -265,6 +266,7 @@ "ExcludeTransformConfig", "FiniteTensorDictCheckConfig", "FlattenObservationConfig", + "FlattenTensorDictConfig", "FrameSkipTransformConfig", "GrayScaleConfig", "HashConfig", @@ -413,6 +415,7 @@ def _register_configs(): cs.store( group="transform", name="flatten_observation", node=FlattenObservationConfig ) + cs.store(group="transform", name="flatten_tensordict", node=FlattenTensorDictConfig) cs.store(group="transform", name="gray_scale", node=GrayScaleConfig) cs.store(group="transform", name="observation_norm", node=ObservationNormConfig) cs.store(group="transform", name="cat_frames", node=CatFramesConfig) diff --git 
a/torchrl/trainers/algorithms/configs/collectors.py b/torchrl/trainers/algorithms/configs/collectors.py index 34eb778b9b2..9e57b7c19fa 100644 --- a/torchrl/trainers/algorithms/configs/collectors.py +++ b/torchrl/trainers/algorithms/configs/collectors.py @@ -96,6 +96,7 @@ class AsyncDataCollectorConfig(DataCollectorConfig): weight_updater: Any = None track_policy_version: bool = False _target_: str = "torchrl.collectors.aSyncDataCollector" + _partial_: bool = False def __post_init__(self): self.create_env_fn._partial_ = True @@ -137,6 +138,7 @@ class MultiSyncDataCollectorConfig(DataCollectorConfig): weight_updater: Any = None track_policy_version: bool = False _target_: str = "torchrl.collectors.MultiSyncDataCollector" + _partial_: bool = False def __post_init__(self): for env_cfg in self.create_env_fn: @@ -179,6 +181,7 @@ class MultiaSyncDataCollectorConfig(DataCollectorConfig): weight_updater: Any = None track_policy_version: bool = False _target_: str = "torchrl.collectors.MultiaSyncDataCollector" + _partial_: bool = False def __post_init__(self): for env_cfg in self.create_env_fn: diff --git a/torchrl/trainers/algorithms/configs/data.py b/torchrl/trainers/algorithms/configs/data.py index 40c90c1a808..08a8eb44cc3 100644 --- a/torchrl/trainers/algorithms/configs/data.py +++ b/torchrl/trainers/algorithms/configs/data.py @@ -50,7 +50,6 @@ class RandomSamplerConfig(SamplerConfig): """Configuration for random sampling from replay buffer.""" _target_: str = "torchrl.data.replay_buffers.RandomSampler" - batch_size: int | None = None def __post_init__(self) -> None: """Post-initialization hook for random sampler configurations.""" @@ -304,3 +303,4 @@ class ReplayBufferConfig(ReplayBufferBaseConfig): writer: Any = None transform: Any = None batch_size: int | None = None + shared: bool = False diff --git a/torchrl/trainers/algorithms/configs/modules.py b/torchrl/trainers/algorithms/configs/modules.py index 7a582b11dba..8ec1a4df984 100644 --- a/torchrl/trainers/algorithms/configs/modules.py +++ b/torchrl/trainers/algorithms/configs/modules.py @@ -246,6 +246,8 @@ class TanhNormalModelConfig(ModelConfig): eval_mode: bool = False extract_normal_params: bool = True + scale_mapping: str = "biased_softplus_1.0" + scale_lb: float = 1e-4 param_keys: Any = None @@ -305,6 +307,8 @@ def _make_tanh_normal_model(*args, **kwargs): param_keys = list(kwargs.pop("param_keys", ["loc", "scale"])) out_keys = list(kwargs.pop("out_keys", ["action"])) extract_normal_params = kwargs.pop("extract_normal_params", True) + scale_mapping = kwargs.pop("scale_mapping", "biased_softplus_1.0") + scale_lb = kwargs.pop("scale_lb", 1e-4) return_log_prob = kwargs.pop("return_log_prob", False) eval_mode = kwargs.pop("eval_mode", False) exploration_type = kwargs.pop("exploration_type", "RANDOM") @@ -318,7 +322,10 @@ def _make_tanh_normal_model(*args, **kwargs): # Create the sequential if extract_normal_params: # Add NormalParamExtractor to split the output - network = torch.nn.Sequential(network, NormalParamExtractor()) + network = torch.nn.Sequential( + network, + NormalParamExtractor(scale_mapping=scale_mapping, scale_lb=scale_lb), + ) module = TensorDictModule(network, in_keys=in_keys, out_keys=param_keys) diff --git a/torchrl/trainers/algorithms/configs/trainers.py b/torchrl/trainers/algorithms/configs/trainers.py index cf2f02893bc..da467c61461 100644 --- a/torchrl/trainers/algorithms/configs/trainers.py +++ b/torchrl/trainers/algorithms/configs/trainers.py @@ -53,6 +53,7 @@ class SACTrainerConfig(TrainerConfig): actor_network: Any 
= None critic_network: Any = None target_net_updater: Any = None + async_collection: bool = False _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_sac_trainer" @@ -83,8 +84,9 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer: seed = kwargs.pop("seed") actor_network = kwargs.pop("actor_network") critic_network = kwargs.pop("critic_network") - create_env_fn = kwargs.pop("create_env_fn") + kwargs.pop("create_env_fn") target_net_updater = kwargs.pop("target_net_updater") + async_collection = kwargs.pop("async_collection", False) # Instantiate networks first if actor_network is not None: @@ -94,7 +96,16 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer: if not isinstance(collector, DataCollectorBase): # then it's a partial config - collector = collector(create_env_fn=create_env_fn, policy=actor_network) + if not async_collection: + collector = collector() + elif replay_buffer is not None: + collector = collector(replay_buffer=replay_buffer) + elif getattr(collector, "replay_buffer", None) is None: + if collector.replay_buffer is None or replay_buffer is None: + raise ValueError( + "replay_buffer must be provided when async_collection is True" + ) + if not isinstance(loss_module, LossModule): # then it's a partial config loss_module = loss_module( @@ -138,6 +149,7 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer: save_trainer_file=save_trainer_file, replay_buffer=replay_buffer, target_net_updater=target_net_updater, + async_collection=async_collection, ) @@ -168,6 +180,7 @@ class PPOTrainerConfig(TrainerConfig): actor_network: Any = None critic_network: Any = None num_epochs: int = 4 + async_collection: bool = False _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer" @@ -199,7 +212,11 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer: actor_network = kwargs.pop("actor_network") critic_network = kwargs.pop("critic_network") create_env_fn = kwargs.pop("create_env_fn") + if create_env_fn is not None: + # could be referenced somewhere else, no need to raise an error + pass num_epochs = kwargs.pop("num_epochs", 4) + async_collection = kwargs.pop("async_collection", False) # Instantiate networks first if actor_network is not None: @@ -209,7 +226,14 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer: if not isinstance(collector, DataCollectorBase): # then it's a partial config - collector = collector(create_env_fn=create_env_fn, policy=actor_network) + if not async_collection: + collector = collector() + else: + collector = collector(replay_buffer=replay_buffer) + elif getattr(collector, "replay_buffer", None) is None: + raise RuntimeError( + "replay_buffer must be provided when async_collection is True" + ) if not isinstance(loss_module, LossModule): # then it's a partial config loss_module = loss_module( @@ -250,4 +274,5 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer: save_trainer_file=save_trainer_file, replay_buffer=replay_buffer, num_epochs=num_epochs, + async_collection=async_collection, ) diff --git a/torchrl/trainers/algorithms/configs/transforms.py b/torchrl/trainers/algorithms/configs/transforms.py index 52646551d65..4a60e2da9b0 100644 --- a/torchrl/trainers/algorithms/configs/transforms.py +++ b/torchrl/trainers/algorithms/configs/transforms.py @@ -922,3 +922,19 @@ class CropConfig(TransformConfig): def __post_init__(self) -> None: """Post-initialization hook for Crop configuration.""" super().__post_init__() + + +@dataclass +class FlattenTensorDictConfig(TransformConfig): + """Configuration for flattening 
TensorDict during inverse pass. + + This transform reshapes the tensordict to have a flat batch dimension + during the inverse pass, which is useful for replay buffers that need + to store data with a flat batch structure. + """ + + _target_: str = "torchrl.envs.transforms.transforms.FlattenTensorDict" + + def __post_init__(self) -> None: + """Post-initialization hook for FlattenTensorDict configuration.""" + super().__post_init__() diff --git a/torchrl/trainers/algorithms/ppo.py b/torchrl/trainers/algorithms/ppo.py index d08dd41a061..4b09788894a 100644 --- a/torchrl/trainers/algorithms/ppo.py +++ b/torchrl/trainers/algorithms/ppo.py @@ -110,6 +110,7 @@ def __init__( log_rewards: bool = True, log_actions: bool = True, log_observations: bool = False, + async_collection: bool = False, ) -> None: warnings.warn( "PPOTrainer is an experimental/prototype feature. The API may change in future versions. " @@ -133,8 +134,10 @@ def __init__( log_interval=log_interval, save_trainer_file=save_trainer_file, num_epochs=num_epochs, + async_collection=async_collection, ) self.replay_buffer = replay_buffer + self.async_collection = async_collection gae = GAE( gamma=gamma, @@ -144,8 +147,10 @@ def __init__( ) self.register_op("pre_epoch", gae) - if replay_buffer is not None and not isinstance( - replay_buffer.sampler, SamplerWithoutReplacement + if ( + not self.async_collection + and replay_buffer is not None + and not isinstance(replay_buffer.sampler, SamplerWithoutReplacement) ): warnings.warn( "Sampler is not a SamplerWithoutReplacement, which is required for PPO." @@ -161,7 +166,8 @@ def __init__( iterate=True, ) - self.register_op("pre_epoch", rb_trainer.extend) + if not self.async_collection: + self.register_op("pre_epoch", rb_trainer.extend) self.register_op("process_optim_batch", rb_trainer.sample) self.register_op("post_loss", rb_trainer.update_priority) @@ -201,7 +207,10 @@ def _setup_ppo_logging(self): include_std=False, # No std for binary values reduction="mean", ) - self.register_op("pre_steps_log", log_done_percentage) + if not self.async_collection: + self.register_op("pre_steps_log", log_done_percentage) + else: + self.register_op("post_optim_log", log_done_percentage) # Log rewards if enabled if self.log_rewards: @@ -213,7 +222,10 @@ def _setup_ppo_logging(self): include_std=True, reduction="mean", ) - self.register_op("pre_steps_log", log_rewards) + if not self.async_collection: + self.register_op("pre_steps_log", log_rewards) + else: + self.register_op("post_optim_log", log_rewards) # 2. Log maximum reward in batch (for monitoring best performance) log_max_reward = LogScalar( @@ -223,7 +235,10 @@ def _setup_ppo_logging(self): include_std=False, reduction="max", ) - self.register_op("pre_steps_log", log_max_reward) + if not self.async_collection: + self.register_op("pre_steps_log", log_max_reward) + else: + self.register_op("post_optim_log", log_max_reward) # 3. 
Log total reward in batch (for monitoring cumulative performance) log_total_reward = LogScalar( @@ -233,7 +248,10 @@ def _setup_ppo_logging(self): include_std=False, reduction="sum", ) - self.register_op("pre_steps_log", log_total_reward) + if not self.async_collection: + self.register_op("pre_steps_log", log_total_reward) + else: + self.register_op("post_optim_log", log_total_reward) # Log actions if enabled if self.log_actions: @@ -245,7 +263,10 @@ def _setup_ppo_logging(self): include_std=True, reduction="mean", ) - self.register_op("pre_steps_log", log_action_norm) + if not self.async_collection: + self.register_op("pre_steps_log", log_action_norm) + else: + self.register_op("post_optim_log", log_action_norm) # Log observations if enabled if self.log_observations: @@ -257,4 +278,7 @@ def _setup_ppo_logging(self): include_std=True, reduction="mean", ) - self.register_op("pre_steps_log", log_obs_norm) + if not self.async_collection: + self.register_op("pre_steps_log", log_obs_norm) + else: + self.register_op("post_optim_log", log_obs_norm) diff --git a/torchrl/trainers/algorithms/sac.py b/torchrl/trainers/algorithms/sac.py index caf4180925a..6bf4956bcd5 100644 --- a/torchrl/trainers/algorithms/sac.py +++ b/torchrl/trainers/algorithms/sac.py @@ -27,6 +27,7 @@ TargetNetUpdaterHook, Trainer, UpdateWeights, + UTDRHook, ) @@ -123,6 +124,7 @@ def __init__( log_actions: bool = True, log_observations: bool = False, target_net_updater: TargetNetUpdater | None = None, + async_collection: bool = False, ) -> None: warnings.warn( "SACTrainer is an experimental/prototype feature. The API may change in future versions. " @@ -148,8 +150,10 @@ def __init__( save_trainer_interval=save_trainer_interval, log_interval=log_interval, save_trainer_file=save_trainer_file, + async_collection=async_collection, ) self.replay_buffer = replay_buffer + self.async_collection = async_collection # Note: SAC can use any sampler type, unlike PPO which requires SamplerWithoutReplacement @@ -162,8 +166,8 @@ def __init__( device=getattr(replay_buffer.storage, "device", "cpu"), iterate=True, ) - - self.register_op("pre_epoch", rb_trainer.extend) + if not self.async_collection: + self.register_op("pre_epoch", rb_trainer.extend) self.register_op("process_optim_batch", rb_trainer.sample) self.register_op("post_loss", rb_trainer.update_priority) self.register_op("post_optim", TargetNetUpdaterHook(target_net_updater)) @@ -222,7 +226,10 @@ def _setup_sac_logging(self): include_std=False, # No std for binary values reduction="mean", ) - self.register_op("pre_steps_log", log_done_percentage) + if not self.async_collection: + self.register_op("pre_steps_log", log_done_percentage) + else: + self.register_op("post_optim_log", log_done_percentage) # Log rewards if enabled if self.log_rewards: @@ -234,7 +241,11 @@ def _setup_sac_logging(self): include_std=True, reduction="mean", ) - self.register_op("pre_steps_log", log_rewards) + if not self.async_collection: + self.register_op("pre_steps_log", log_rewards) + else: + # In the async case, use the batch passed to the optimizer + self.register_op("post_optim_log", log_rewards) # 2. Log maximum reward in batch (for monitoring best performance) log_max_reward = LogScalar( @@ -244,17 +255,23 @@ def _setup_sac_logging(self): include_std=False, reduction="max", ) - self.register_op("pre_steps_log", log_max_reward) + if not self.async_collection: + self.register_op("pre_steps_log", log_max_reward) + else: + self.register_op("post_optim_log", log_max_reward) # 3. 
Log total reward in batch (for monitoring cumulative performance) log_total_reward = LogScalar( - key=("next", "reward"), + key=("next", "reward_sum"), logname="r_total", log_pbar=False, include_std=False, - reduction="sum", + reduction="max", ) - self.register_op("pre_steps_log", log_total_reward) + if not self.async_collection: + self.register_op("pre_steps_log", log_total_reward) + else: + self.register_op("post_optim_log", log_total_reward) # Log actions if enabled if self.log_actions: @@ -266,7 +283,10 @@ def _setup_sac_logging(self): include_std=True, reduction="mean", ) - self.register_op("pre_steps_log", log_action_norm) + if not self.async_collection: + self.register_op("pre_steps_log", log_action_norm) + else: + self.register_op("post_optim_log", log_action_norm) # Log observations if enabled if self.log_observations: @@ -278,4 +298,9 @@ def _setup_sac_logging(self): include_std=True, reduction="mean", ) - self.register_op("pre_steps_log", log_obs_norm) + if not self.async_collection: + self.register_op("pre_steps_log", log_obs_norm) + else: + self.register_op("post_optim_log", log_obs_norm) + + self.register_op("pre_steps_log", UTDRHook(self)) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 4f7006b97ff..d93cef3375d 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -8,6 +8,7 @@ import abc import itertools import pathlib +import time import warnings from collections import defaultdict, OrderedDict from collections.abc import Callable, Sequence @@ -27,6 +28,7 @@ _CKPT_BACKEND, KeyDependentDefaultDict, logger as torchrl_logger, + RL_WARNINGS, VERBOSE, ) from torchrl.collectors import DataCollectorBase @@ -144,6 +146,10 @@ class Trainer: in frame count. Default is 10000. save_trainer_file (path, optional): path where to save the trainer. Default is None (no saving) + async_collection (bool, optional): Whether to collect data asynchronously. + This will only work if the replay buffer is registered within the data collector. + If using this, the UTD ratio (Update to Data) will be logged under the key "utd_ratio". + Default is False.
""" @classmethod @@ -179,6 +185,7 @@ def __init__( log_interval: int = 10000, save_trainer_file: str | pathlib.Path | None = None, num_epochs: int = 1, + async_collection: bool = False, ) -> None: # objects @@ -187,6 +194,7 @@ def __init__( self.loss_module = loss_module self.optimizer = optimizer self.logger = logger + self.async_collection = async_collection # Logging frequency control - how often to log each metric (in frames) self._log_interval = log_interval @@ -596,20 +604,37 @@ def train(self): self._pbar = tqdm(total=self.total_frames) self._pbar_str = {} - for batch in self.collector: - batch = self._process_batch_hook(batch) - current_frames = ( - batch.get(("collector", "mask"), torch.tensor(batch.numel())) - .sum() - .item() - * self.frame_skip - ) - self.collected_frames += current_frames + if self.async_collection: + self.collector.start() + while not self.collector.getattr_rb("write_count"): + time.sleep(0.1) + + # Create async iterator that monitors write_count progress + iterator = self._async_iterator() + else: + iterator = self.collector + + for batch in iterator: + if not self.async_collection: + batch = self._process_batch_hook(batch) + current_frames = ( + batch.get(("collector", "mask"), torch.tensor(batch.numel())) + .sum() + .item() + * self.frame_skip + ) + self.collected_frames += current_frames + else: + # In async mode, batch is None and we track frames via write_count + batch = None + cf = self.collected_frames + self.collected_frames = self.collector.getattr_rb("write_count") + current_frames = self.collected_frames - cf # LOGGING POINT 1: Pre-optimization logging (e.g., rewards, frame counts) self._pre_steps_log_hook(batch) - if self.collected_frames > self.collector.init_random_frames: + if self.collected_frames >= self.collector.init_random_frames: self.optim_steps(batch) self._post_steps_hook() @@ -627,6 +652,21 @@ def train(self): self.collector.shutdown() + def _async_iterator(self): + """Create an iterator for async collection that monitors replay buffer write_count. + + This iterator yields None batches and terminates when total_frames is reached + based on the replay buffer's write_count rather than using a fixed range. + This ensures the training loop properly consumes the entire collector output. + """ + while True: + current_write_count = self.collector.getattr_rb("write_count") + # Check if we've reached the target frames + if current_write_count >= self.total_frames: + break + else: + yield None + def __del__(self): try: self.collector.shutdown() @@ -1723,7 +1763,7 @@ def flatten_dict(d): return out -class TargetNetUpdaterHook: +class TargetNetUpdaterHook(TrainerHookBase): """A hook for target parameters update. Examples: @@ -1744,6 +1784,77 @@ def __init__(self, target_params_updater: TargetNetUpdater): ) self.target_params_updater = target_params_updater - def __call__(self, tensordict: TensorCollection): + def __call__(self, tensordict: TensorCollection | None = None): self.target_params_updater.step() return tensordict + + def register(self, trainer: Trainer, name: str): + trainer.register_op("post_steps", self) + + +class UTDRHook(TrainerHookBase): + """Hook for logging Update-to-Data (UTD) ratio during async collection. + + The UTD ratio measures how many optimization steps are performed per + collected data sample, providing insight into training efficiency during + asynchronous data collection. This metric is particularly useful for + off-policy algorithms where data collection and training happen concurrently. 
+ + The UTD ratio is calculated as: (batch_size * update_count) / write_count + where: + - batch_size: Size of batches sampled from replay buffer + - update_count: Total number of optimization steps performed + - write_count: Total number of samples written to replay buffer + + Args: + trainer (Trainer): The trainer instance to monitor for UTD calculation. + Must have async_collection=True for meaningful results. + + Note: + This hook is only meaningful when async_collection is enabled, as it + relies on the replay buffer's write_count to track data collection progress. + """ + + def __init__(self, trainer: Trainer): + self.trainer = trainer + + def __call__(self, batch: TensorDictBase | None = None) -> dict: + if ( + hasattr(self.trainer, "replay_buffer") + and self.trainer.replay_buffer is not None + ): + write_count = self.trainer.replay_buffer.write_count + batch_size = self.trainer.replay_buffer.batch_size + else: + write_count = self.trainer.collector.getattr_rb("write_count") + batch_size = self.trainer.collector.getattr_rb("batch_size") + if not write_count: + return {} + if batch_size is None and RL_WARNINGS: + warnings.warn("Batch size is not set. Using 1.") + batch_size = 1 + update_count = self.trainer._optim_count + utd_ratio = batch_size * update_count / write_count + return { + "utd_ratio": utd_ratio, + "write_count": write_count, + "update_count": update_count, + "log_pbar": False, + } + + def register(self, trainer: Trainer, name: str = "utdr_hook"): + """Register the UTD ratio hook with the trainer. + + Args: + trainer (Trainer): The trainer to register with. + name (str): Name to use when registering the hook module. + """ + trainer.register_op("pre_steps_log", self) + trainer.register_module(name, self) + + def state_dict(self) -> dict[str, Any]: + """Return state dictionary for checkpointing.""" + return {} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load state from dictionary."""
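
For reference, a minimal sketch of what the reward_sum transform added to the PPO/SAC configs does (assuming a Gym backend with Pendulum-v1 installed): RewardSum accumulates the instantaneous reward over the episode under a new key.

from torchrl.envs import GymEnv, RewardSum, TransformedEnv

env = TransformedEnv(
    GymEnv("Pendulum-v1"),
    RewardSum(in_keys=["reward"], out_keys=["reward_sum"]),
)
rollout = env.rollout(10)
# the running episode return is written next to the instantaneous reward
print(rollout["next", "reward"].shape)      # expected: torch.Size([10, 1])
print(rollout["next", "reward_sum"].shape)  # expected: torch.Size([10, 1])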
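The new scale_mapping/scale_lb fields on TanhNormalModelConfig are forwarded to tensordict's NormalParamExtractor; a small sketch of the effect, using the HalfCheetah sizes from config_async.yaml (biased_softplus_2.0 widens the initial scale, scale_lb clamps it from below):

import torch
from torch import nn
from tensordict.nn import NormalParamExtractor

# 17 observation features -> 12 outputs = 6 loc + 6 scale
net = nn.Sequential(
    nn.Linear(17, 12),
    NormalParamExtractor(scale_mapping="biased_softplus_2.0", scale_lb=1e-2),
)
loc, scale = net(torch.zeros(1, 17))
print(loc.shape, scale.shape)        # torch.Size([1, 6]) torch.Size([1, 6])
print(bool((scale >= 1e-2).all()))   # True: scale is clamped at scale_lb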
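A hedged sketch of how the _partial_: true flags on the collector configs are consumed (build_components and its cfg argument are hypothetical; the actual wiring lives in _make_sac_trainer/_make_ppo_trainer): Hydra returns a functools.partial, and the missing arguments are supplied at trainer-build time, with the shared replay buffer handed to the collector only when async_collection is enabled.

from hydra.utils import instantiate

def build_components(cfg):
    # hypothetical helper mirroring the factory logic in this PR
    replay_buffer = instantiate(cfg.replay_buffer)
    collector_partial = instantiate(cfg.collector)  # functools.partial, thanks to _partial_: true
    if cfg.trainer.async_collection:
        # async: collector workers write straight into the shared buffer
        collector = collector_partial(replay_buffer=replay_buffer)
    else:
        collector = collector_partial()
    return collector, replay_buffer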
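A schematic of the async_collection control flow added to Trainer.train (an illustrative loop, not the Trainer API; loss_fn is assumed to return a scalar): the collector runs in the background and progress is measured through the replay buffer's write_count, which is also what _async_iterator polls.

import time

def async_train_sketch(collector, replay_buffer, loss_fn, optimizer,
                       total_frames, batch_size=256):
    collector.start()                        # background workers extend replay_buffer
    while replay_buffer.write_count == 0:    # wait for the first samples to land
        time.sleep(0.1)
    while replay_buffer.write_count < total_frames:
        batch = replay_buffer.sample(batch_size)
        loss = loss_fn(batch)                # assumed scalar for this sketch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    collector.shutdown()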
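Finally, a worked instance of the formula UTDRHook logs, utd_ratio = batch_size * update_count / write_count, with illustrative numbers:

batch_size = 256           # samples drawn from the buffer per optimization step
update_count = 16_000      # optimization steps performed so far
write_count = 1_000_000    # frames the collector has written to the buffer
utd_ratio = batch_size * update_count / write_count
print(utd_ratio)           # 4.096 sampled transitions consumed per collected frame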