From 8889aaf78b8531b10a06de333ebd3ec3e810983e Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Nov 2025 12:23:35 +0000 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- .pre-commit-config.yaml | 7 + docs/source/reference/collectors.rst | 615 +------ docs/source/reference/collectors_basics.rst | 118 ++ .../reference/collectors_distributed.rst | 49 + docs/source/reference/collectors_replay.rst | 81 + docs/source/reference/collectors_single.rst | 50 + .../reference/collectors_weightsync.rst | 299 ++++ docs/source/reference/config.rst | 12 +- .../source/reference/cudnn_persistent_rnn.rst | 4 + .../reference/cudnn_rnn_determinism.rst | 8 + docs/source/reference/data.rst | 1282 +------------- docs/source/reference/data_datasets.rst | 19 + docs/source/reference/data_replaybuffers.rst | 51 + docs/source/reference/data_samplers.rst | 32 + docs/source/reference/data_specs.rst | 26 + docs/source/reference/data_storage.rst | 37 + docs/source/reference/envs.rst | 1482 +---------------- docs/source/reference/envs_api.rst | 208 +++ docs/source/reference/envs_libraries.rst | 277 +++ docs/source/reference/envs_multiagent.rst | 138 ++ docs/source/reference/envs_recorders.rst | 83 + docs/source/reference/envs_transforms.rst | 359 ++++ docs/source/reference/envs_vectorized.rst | 351 ++++ docs/source/reference/llms.rst | 1199 +------------ docs/source/reference/llms_collectors.rst | 34 + docs/source/reference/llms_data.rst | 36 + docs/source/reference/llms_envs.rst | 24 + docs/source/reference/llms_modules.rst | 45 + docs/source/reference/llms_objectives.rst | 27 + docs/source/reference/llms_transforms.rst | 26 + docs/source/reference/modules.rst | 504 +----- docs/source/reference/modules_actors.rst | 47 + docs/source/reference/modules_critics.rst | 25 + .../reference/modules_distributions.rst | 20 + docs/source/reference/modules_exploration.rst | 15 + docs/source/reference/modules_models.rst | 18 + docs/source/reference/modules_utils.rst | 16 + docs/source/reference/objectives.rst | 450 +---- .../reference/objectives_actorcritic.rst | 17 + docs/source/reference/objectives_common.rst | 27 + docs/source/reference/objectives_offline.rst | 16 + docs/source/reference/objectives_other.rst | 17 + docs/source/reference/objectives_policy.rst | 16 + docs/source/reference/objectives_value.rst | 17 + docs/source/reference/services.rst | 609 +++++++ docs/source/reference/trainers.rst | 442 +---- docs/source/reference/trainers_basics.rst | 58 + docs/source/reference/trainers_hooks.rst | 23 + docs/source/reference/trainers_loggers.rst | 33 + 49 files changed, 3668 insertions(+), 5681 deletions(-) create mode 100644 docs/source/reference/collectors_basics.rst create mode 100644 docs/source/reference/collectors_distributed.rst create mode 100644 docs/source/reference/collectors_replay.rst create mode 100644 docs/source/reference/collectors_single.rst create mode 100644 docs/source/reference/collectors_weightsync.rst create mode 100644 docs/source/reference/cudnn_persistent_rnn.rst create mode 100644 docs/source/reference/cudnn_rnn_determinism.rst create mode 100644 docs/source/reference/data_datasets.rst create mode 100644 docs/source/reference/data_replaybuffers.rst create mode 100644 docs/source/reference/data_samplers.rst create mode 100644 docs/source/reference/data_specs.rst create mode 100644 docs/source/reference/data_storage.rst create mode 100644 docs/source/reference/envs_api.rst create mode 100644 docs/source/reference/envs_libraries.rst create mode 100644 docs/source/reference/envs_multiagent.rst create mode 
100644 docs/source/reference/envs_recorders.rst create mode 100644 docs/source/reference/envs_transforms.rst create mode 100644 docs/source/reference/envs_vectorized.rst create mode 100644 docs/source/reference/llms_collectors.rst create mode 100644 docs/source/reference/llms_data.rst create mode 100644 docs/source/reference/llms_envs.rst create mode 100644 docs/source/reference/llms_modules.rst create mode 100644 docs/source/reference/llms_objectives.rst create mode 100644 docs/source/reference/llms_transforms.rst create mode 100644 docs/source/reference/modules_actors.rst create mode 100644 docs/source/reference/modules_critics.rst create mode 100644 docs/source/reference/modules_distributions.rst create mode 100644 docs/source/reference/modules_exploration.rst create mode 100644 docs/source/reference/modules_models.rst create mode 100644 docs/source/reference/modules_utils.rst create mode 100644 docs/source/reference/objectives_actorcritic.rst create mode 100644 docs/source/reference/objectives_common.rst create mode 100644 docs/source/reference/objectives_offline.rst create mode 100644 docs/source/reference/objectives_other.rst create mode 100644 docs/source/reference/objectives_policy.rst create mode 100644 docs/source/reference/objectives_value.rst create mode 100644 docs/source/reference/services.rst create mode 100644 docs/source/reference/trainers_basics.rst create mode 100644 docs/source/reference/trainers_hooks.rst create mode 100644 docs/source/reference/trainers_loggers.rst diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04c0f40c2aa..b83882b0f54 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,3 +50,10 @@ repos: entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports language: system types: [python] + - id: check-sphinx-section-underline + name: Check Sphinx section underline lengths + entry: ./check-sphinx-section-underline --fix + language: script + files: ^docs/.*\.rst$ + pass_filenames: true + description: Ensure Sphinx section underline lengths match section titles. diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index 315e17e082f..6cbfde29b30 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -5,592 +5,67 @@ torchrl.collectors package .. _ref_collectors: -Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they -collect data over non-static data sources and (2) the data is collected using a model -(likely a version of the model that is being trained). +Data collectors are the bridge between your environments and training loop, managing the process of gathering +experience data using your policy. They handle environment resets, policy execution, and data aggregation, +making it easy to collect high-quality training data efficiently. -TorchRL's data collectors accept two main arguments: an environment (or a list of -environment constructors) and a policy. They will iteratively execute an environment -step and a policy query over a defined number of steps before delivering a stack of -the data collected to the user. Environments will be reset whenever they reach a done -state, and/or after a predefined number of steps. +TorchRL provides several collector implementations optimized for different scenarios: -Because data collection is a potentially compute heavy process, it is crucial to -configure the execution hyperparameters appropriately. 
-The first parameter to take into consideration is whether the data collection should -occur serially with the optimization step or in parallel. The :obj:`SyncDataCollector` -class will execute the data collection on the training worker. The :obj:`MultiSyncDataCollector` -will split the workload across an number of workers and aggregate the results that -will be delivered to the training worker. Finally, the :obj:`MultiaSyncDataCollector` will -execute the data collection on several workers and deliver the first batch of results -that it can gather. This execution will occur continuously and concomitantly with -the training of the networks: this implies that the weights of the policy that -is used for the data collection may slightly lag the configuration of the policy -on the training worker. Therefore, although this class may be the fastest to collect -data, it comes at the price of being suitable only in settings where it is acceptable -to gather data asynchronously (e.g. off-policy RL or curriculum RL). -For remotely executed rollouts (:obj:`MultiSyncDataCollector` or :obj:`MultiaSyncDataCollector`) -it is necessary to synchronise the weights of the remote policy with the weights -from the training worker using either the `collector.update_policy_weights_()` or -by setting `update_at_each_batch=True` in the constructor. +- **SyncDataCollector**: Single-process collection on the training worker +- **MultiSyncDataCollector**: Parallel collection across multiple workers with synchronized delivery +- **MultiaSyncDataCollector**: Asynchronous collection with first-come-first-serve delivery +- **Distributed collectors**: For multi-node setups using Ray, RPC, or distributed backends -The second parameter to consider (in the remote settings) is the device where the -data will be collected and the device where the environment and policy operations -will be executed. For instance, a policy executed on CPU may be slower than one -executed on CUDA. When multiple inference workers run concomitantly, dispatching -the compute workload across the available devices may speed up the collection or -avoid OOM errors. Finally, the choice of the batch size and passing device (ie the -device where the data will be stored while waiting to be passed to the collection -worker) may also impact the memory management. The key parameters to control are -:obj:`devices` which controls the execution devices (ie the device of the policy) -and :obj:`storing_device` which will control the device where the environment and -data are stored during a rollout. A good heuristic is usually to use the same device -for storage and compute, which is the default behavior when only the `devices` argument -is being passed. 
+Key Features +------------ -Besides those compute parameters, users may choose to configure the following parameters: +- **Flexible execution**: Choose between sync, async, and distributed collection +- **Device management**: Control where environments and policies execute +- **Weight synchronization**: Keep inference policies up-to-date with training weights +- **Replay buffer integration**: Seamless compatibility with TorchRL's replay buffers +- **Batching strategies**: Multiple ways to organize collected data -- max_frames_per_traj: the number of frames after which a :obj:`env.reset()` is called -- frames_per_batch: the number of frames delivered at each iteration over the collector -- init_random_frames: the number of random steps (steps where :obj:`env.rand_step()` is being called) -- reset_at_each_iter: if :obj:`True`, the environment(s) will be reset after each batch collection -- split_trajs: if :obj:`True`, the trajectories will be split and delivered in a padded tensordict - along with a :obj:`"mask"` key that will point to a boolean mask representing the valid values. -- exploration_type: the exploration strategy to be used with the policy. -- reset_when_done: whether environments should be reset when reaching a done state. - -Collectors and batch size -------------------------- - -Because each collector has its own way of organizing the environments that are -run within, the data will come with different batch-size depending on how -the specificities of the collector. The following table summarizes what is to -be expected when collecting data: - - -+--------------------+---------------------+--------------------------------------------+------------------------------+ -| | SyncDataCollector | MultiSyncDataCollector (n=B) |MultiaSyncDataCollector (n=B) | -+====================+=====================+=============+==============+===============+==============================+ -| `cat_results` | NA | `"stack"` | `0` | `-1` | NA | -+--------------------+---------------------+-------------+--------------+---------------+------------------------------+ -| Single env | [T] | `[B, T]` | `[B*(T//B)` | `[B*(T//B)]` | [T] | -+--------------------+---------------------+-------------+--------------+---------------+------------------------------+ -| Batched env (n=P) | [P, T] | `[B, P, T]` | `[B * P, T]`| `[P, T * B]` | [P, T] | -+--------------------+---------------------+-------------+--------------+---------------+------------------------------+ - -In each of these cases, the last dimension (``T`` for ``time``) is adapted such -that the batch size equals the ``frames_per_batch`` argument passed to the -collector. - -.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be - used with ``cat_results=0``, as the data will be stacked along the batch - dimension with batched environment, or the time dimension for single environments, - which can introduce some confusion when swapping one with the other. - ``cat_results="stack"`` is a better and more consistent way of interacting - with the environments as it will keep each dimension separate, and provide - better interchangeability between configurations, collector classes and other - components. - -Whereas :class:`~torchrl.collectors.MultiSyncDataCollector` -has a dimension corresponding to the number of sub-collectors being run (``B``), -:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. 
This -is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector` -delivers batches of data on a first-come, first-serve basis, whereas -:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from -each sub-collector before delivering it. - -Collectors and policy copies ----------------------------- - -When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to -keep the training version of the policy on a device and the inference version on another. For example, if you have two -CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is the -case, a :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` can be used to copy the parameters from one -device to the other (if no copy is required, this method is a no-op). - -Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will do a deepcopy of the -policy structure and copy the parameters placed on the new device during instantiation if necessary. -Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we -try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur. - -.. figure:: /_static/img/collector-copy.png - - Policy copy decision tree in Collectors. - -Weight Synchronization using Weight Update Schemes --------------------------------------------------- - -RL pipelines are typically split in two big computational buckets: training, and inference. -While the inference pipeline sends data to the training one, the training pipeline needs to occasionally -synchronize its weights with the inference one. -In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are -used in both instances. From there, anything can happen: - -- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named - `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights - for his instance of the policy. -- In some cases, the environment or the postprocessing hooks can rely on the usage of a model which itself needs - synchronization. This means that there can be multiple ends in the data transfer API and one needs to think beyond - policy-to-policy weight synchronization strategies. -- In the LLM world, the inference engine and the training one are very different: they will use different libraries, - kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different (quantized - vs non-quantized). - This makes the weight synchronization much more complex, as one cannot simply dump and load a state dict on both ends. -- One typically also has to choose who instantiates a transfer: should this come from the inference engine who actively - asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach - is to store the weights on some intermediary server and let the workers fetch them when necessary. - -TorchRL tries to account for each of these problems in a flexible manner. 
We individuate four basic components in a weight -transfer: - -- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer; -- A `Receiver` class that casts the weights to the destination module (policy or other utility module); -- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else). -- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them. - -Each of these classes is detailed below. - -Usage Examples -~~~~~~~~~~~~~~ - -.. note:: - **Runnable versions** of these examples are available in the repository: - - - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization - - `examples/collectors/weight_sync_collectors.py `_: Collector integration - -Using Weight Update Schemes Independently -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Weight update schemes can be used outside of collectors for custom synchronization scenarios. -The new simplified API provides four core methods for weight synchronization: - -- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side -- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side -- ``get_sender()`` - Get the configured sender instance -- ``get_receiver()`` - Get the configured receiver instance - -Here's a basic example: +Quick Example +------------- .. code-block:: python - import torch - import torch.nn as nn - from torch import multiprocessing as mp - from tensordict import TensorDict - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, - ) - - # Create a simple policy - policy = nn.Linear(4, 2) - - # Example 1: Multiprocess weight synchronization with state_dict - # -------------------------------------------------------------- - # On the main process side (trainer): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + from torchrl.collectors import SyncDataCollector + from torchrl.envs import GymEnv, ParallelEnv - # Initialize scheme with pipes - parent_pipe, child_pipe = mp.Pipe() - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + # Create a batched environment + def make_env(): + return GymEnv("Pendulum-v1") - # Get the sender and send weights - sender = scheme.get_sender() - weights = policy.state_dict() - sender.send(weights) # Synchronous send - # or sender.send_async(weights); sender.wait_async() # Asynchronous send - - # On the worker process side: - # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) - # receiver = scheme.get_receiver() - # # Non-blocking check for new weights - # if receiver.receive(timeout=0.001): - # # Weights were received and applied - - # Example 2: Shared memory weight synchronization - # ------------------------------------------------ - # Create shared memory scheme with auto-registration - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - # Initialize with pipes for lazy registration - parent_pipe2, child_pipe2 = mp.Pipe() - shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2]) + env = ParallelEnv(4, make_env) - # Get sender and send weights (automatically creates shared buffer on first send) - shared_sender = shared_scheme.get_sender() - weights_td = TensorDict.from_module(policy) - shared_sender.send(weights_td) - - # Workers automatically see updates via shared memory! 
- -Using Weight Update Schemes with Collectors -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization -across multiple inference workers: - -.. code-block:: python - - import torch.nn as nn - from tensordict.nn import TensorDictModule - from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector - from torchrl.envs import GymEnv - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, - ) - - # Create environment and policy - env = GymEnv("CartPole-v1") - policy = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], - env.action_spec.shape[-1]), - in_keys=["observation"], - out_keys=["action"], - ) - - # Example 1: Single collector with multiprocess scheme - # ----------------------------------------------------- - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - + # Create collector collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=policy, - frames_per_batch=64, - total_frames=1000, - weight_sync_schemes={"policy": scheme}, - ) - - # Collect data and update weights periodically - for i, data in enumerate(collector): - # ... training step with data ... - - # Update policy weights every N iterations - if i % 10 == 0: - new_weights = policy.state_dict() - collector.update_policy_weights_(new_weights) - - collector.shutdown() - - # Example 2: Multiple collectors with shared memory - # -------------------------------------------------- - # Shared memory is more efficient for frequent updates - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - collector = MultiSyncDataCollector( - create_env_fn=[ - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - ], - policy=policy, - frames_per_batch=192, + env, + policy=my_policy, + frames_per_batch=200, total_frames=10000, - weight_sync_schemes={"policy": shared_scheme}, ) - - # Workers automatically see weight updates via shared memory + + # Collect data for data in collector: - # ... training ... - collector.update_policy_weights_(TensorDict.from_module(policy)) - + # data is a TensorDict with shape [4, 50] (4 envs, 50 steps each) + # Use data for training... + + # Update policy weights periodically + if should_update: + collector.update_policy_weights_() + collector.shutdown() -.. note:: - When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all - processes share the same memory buffers. This is ideal for frequent weight updates but requires all - processes to be on the same machine. - -.. note:: - The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state - dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured - models and supports advanced features like lazy initialization. - -Weight Senders -~~~~~~~~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightSender - RayModuleTransformSender - -Weight Receivers -~~~~~~~~~~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightReceiver - RayModuleTransformReceiver - -Transports -~~~~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template.rst - - TransportBackend - MPTransport - SharedMemTransport - RayTransport - RayActorTransport - RPCTransport - DistributedTransport - -Schemes -~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightSyncScheme - MultiProcessWeightSyncScheme - SharedMemWeightSyncScheme - NoWeightSyncScheme - RayWeightSyncScheme - RayModuleTransformScheme - RPCWeightSyncScheme - DistributedWeightSyncScheme - -Legacy: Weight Synchronization in Distributed Environments ----------------------------------------------------------- - -.. warning:: - The `WeightUpdater` is considered legacy as per the 0.11 release and will be deprecated soon. - The Weight update schemes, which provides more flexibility and a better compatibility with heavy - weight transfers (e.g., LLMs) is to be preferred. - -In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the -latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible -mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios. - -Sending and receiving model weights with WeightUpdaters -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The weight synchronization process is facilitated by one dedicated extension point: -:class:`~torchrl.collectors.WeightUpdaterBase`. These base class provides a structured interface for -implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs. - -:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to -the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary. -Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the -weight synchronization with the policy. -Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy -state-dict (assuming it is a :class:`~torch.nn.Module` instance). - -Extending the Updater Class -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations. -The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation -untouched. -This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware -setups. -By implementing the abstract methods in these base classes, users can define how weights are retrieved, -transformed, and applied, ensuring seamless integration with their existing infrastructure. - -.. currentmodule:: torchrl.collectors - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightUpdaterBase - VanillaWeightUpdater - MultiProcessedWeightUpdater - RayWeightUpdater - -.. currentmodule:: torchrl.collectors.distributed - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - RPCWeightUpdater - DistributedWeightUpdater - -Collectors and replay buffers interoperability ----------------------------------------------- - -In the simplest scenario where single transitions have to be sampled -from the replay buffer, little attention has to be given to the way -the collector is built. 
Flattening the data after collection will -be a sufficient preprocessing step before populating the storage: - - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N), - ... transform=lambda data: data.reshape(-1)) - >>> for data in collector: - ... memory.extend(data) - -If trajectory slices have to be collected, the recommended way to achieve this is to create -a multidimensional buffer and sample using the :class:`~torchrl.data.replay_buffers.SliceSampler` -sampler class. One must ensure that the data passed to the buffer is properly shaped, with the -``time`` and ``batch`` dimensions clearly separated. In practice, the following configurations -will work: - - >>> # Single environment: no need for a multi-dimensional buffer - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N), - ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) - ... ) - >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) - >>> for data in collector: - ... memory.extend(data) - >>> # Batched environments: a multi-dim buffer is required - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N, ndim=2), - ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) - ... ) - >>> env = ParallelEnv(4, make_env) - >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) - >>> for data in collector: - ... memory.extend(data) - >>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack" - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N, ndim=2), - ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) - ... ) - >>> collector = MultiSyncDataCollector([make_env] * 4, - ... policy, - ... frames_per_batch=N, - ... total_frames=-1, - ... cat_results="stack") - >>> for data in collector: - ... memory.extend(data) - >>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N, ndim=3), - ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) - ... ) - >>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4, - ... policy, - ... frames_per_batch=N, - ... total_frames=-1, - ... cat_results="stack") - >>> for data in collector: - ... memory.extend(data) - -Using replay buffers that sample trajectories with :class:`~torchrl.collectors.MultiSyncDataCollector` -isn't currently fully supported as the data batches can come from any worker and in most cases consecutive -batches written in the buffer won't come from the same source (thereby interrupting the trajectories). - -Running the Collector Asynchronously ------------------------------------- - -Passing replay buffers to a collector allows us to start the collection and get rid of the iterative nature of the -collector. -If you want to run a data collector in the background, simply run :meth:`~torchrl.DataCollectorBase.start`: - - >>> collector = SyncDataCollector(..., replay_buffer=rb) # pass your replay buffer - >>> collector.start() - >>> # little pause - >>> time.sleep(10) - >>> # Start training - >>> for i in range(optim_steps): - ... data = rb.sample() # Sampling from the replay buffer - ... # rest of the training loop - -Single-process collectors (:class:`~torchrl.collectors.SyncDataCollector`) will run the process using multithreading, -so be mindful of Python's GIL and related multithreading restrictions. 
- -Multiprocessed collectors will on the other hand let the child processes handle the filling of the buffer on their own, -which truly decouples the data collection and training. - -Data collectors that have been started with `start()` should be shut down using -:meth:`~torchrl.DataCollectorBase.async_shutdown`. - -.. warning:: Running a collector asynchronously decouples the collection from training, which means that the training - performance may be drastically different depending on the hardware, load and other factors (although it is generally - expected to provide significant speed-ups). Make sure you understand how this may affect your algorithm and if it - is a legitimate thing to do! (For example, on-policy algorithms such as PPO should not be run asynchronously - unless properly benchmarked). - -Single node data collectors ---------------------------- -.. currentmodule:: torchrl.collectors - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - DataCollectorBase - SyncDataCollector - MultiSyncDataCollector - MultiaSyncDataCollector - aSyncDataCollector - - -Distributed data collectors ---------------------------- -TorchRL provides a set of distributed data collectors. These tools support -multiple backends (``'gloo'``, ``'nccl'``, ``'mpi'`` with the :class:`~.DistributedDataCollector` -or PyTorch RPC with :class:`~.RPCDataCollector`) and launchers (``'ray'``, -``submitit`` or ``torch.multiprocessing``). -They can be efficiently used in synchronous or asynchronous mode, on a single -node or across multiple nodes. - -*Resources*: Find examples for these collectors in the -`dedicated folder `_. - -.. note:: - *Choosing the sub-collector*: All distributed collectors support the various single machine collectors. - One may wonder why using a :class:`MultiSyncDataCollector` or a :class:`~torchrl.envs.ParallelEnv` - instead. In general, multiprocessed collectors have a lower IO footprint than - parallel environments which need to communicate at each step. Yet, the model specs - play a role in the opposite direction, since using parallel environments will - result in a faster execution of the policy (and/or transforms) since these - operations will be vectorized. - -.. note:: - *Choosing the device of a collector (or a parallel environment)*: Sharing data - among processes is achieved via shared-memory buffers with parallel environment - and multiprocessed environments executed on CPU. Depending on the capabilities - of the machine being used, this may be prohibitively slow compared to sharing - data on GPU which is natively supported by cuda drivers. - In practice, this means that using the ``device="cpu"`` keyword argument when - building a parallel environment or collector can result in a slower collection - than using ``device="cuda"`` when available. - -.. note:: - Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) - warnings can quickly become quite annoying in multiprocessed / distributed settings. - By default, TorchRL filters out these warnings in sub-processes. If one still wishes to - see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. - -.. currentmodule:: torchrl.collectors.distributed - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - DistributedDataCollector - RPCDataCollector - DistributedSyncDataCollector - submitit_delayed_launcher - RayCollector - -Helper functions ----------------- - -.. 
currentmodule:: torchrl.collectors.utils +Documentation Sections +---------------------- -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst +.. toctree:: + :maxdepth: 2 - split_trajectories + collectors_basics + collectors_single + collectors_distributed + collectors_weightsync + collectors_replay diff --git a/docs/source/reference/collectors_basics.rst b/docs/source/reference/collectors_basics.rst new file mode 100644 index 00000000000..dd78cd61eb6 --- /dev/null +++ b/docs/source/reference/collectors_basics.rst @@ -0,0 +1,118 @@ +.. currentmodule:: torchrl.collectors + +.. _ref_collectors: + +Collector Basics +================ + +Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they +collect data over non-static data sources and (2) the data is collected using a model +(likely a version of the model that is being trained). + +TorchRL's data collectors accept two main arguments: an environment (or a list of +environment constructors) and a policy. They will iteratively execute an environment +step and a policy query over a defined number of steps before delivering a stack of +the data collected to the user. Environments will be reset whenever they reach a done +state, and/or after a predefined number of steps. + +Because data collection is a potentially compute heavy process, it is crucial to +configure the execution hyperparameters appropriately. +The first parameter to take into consideration is whether the data collection should +occur serially with the optimization step or in parallel. The :class:`SyncDataCollector` +class will execute the data collection on the training worker. The :class:`MultiSyncDataCollector` +will split the workload across an number of workers and aggregate the results that +will be delivered to the training worker. Finally, the :class:`MultiaSyncDataCollector` will +execute the data collection on several workers and deliver the first batch of results +that it can gather. This execution will occur continuously and concomitantly with +the training of the networks: this implies that the weights of the policy that +is used for the data collection may slightly lag the configuration of the policy +on the training worker. Therefore, although this class may be the fastest to collect +data, it comes at the price of being suitable only in settings where it is acceptable +to gather data asynchronously (e.g. off-policy RL or curriculum RL). +For remotely executed rollouts (:class:`MultiSyncDataCollector` or :class:`MultiaSyncDataCollector`) +it is necessary to synchronise the weights of the remote policy with the weights +from the training worker using either the :meth:`collector.update_policy_weights_` or +by setting ``update_at_each_batch=True`` in the constructor. + +The second parameter to consider (in the remote settings) is the device where the +data will be collected and the device where the environment and policy operations +will be executed. For instance, a policy executed on CPU may be slower than one +executed on CUDA. When multiple inference workers run concomitantly, dispatching +the compute workload across the available devices may speed up the collection or +avoid OOM errors. Finally, the choice of the batch size and passing device (ie the +device where the data will be stored while waiting to be passed to the collection +worker) may also impact the memory management. 
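+For instance, the sketch below executes the policy on a CUDA device while keeping the
+collected data on CPU, via the ``device`` and ``storing_device`` arguments of
+:class:`SyncDataCollector` (the environment, frame counts and device names are
+illustrative, and a CUDA device is assumed to be available); the storing and execution
+devices discussed next control exactly this split:
+
+.. code-block:: python
+
+    import torch.nn as nn
+    from tensordict.nn import TensorDictModule
+    from torchrl.collectors import SyncDataCollector
+    from torchrl.envs import GymEnv
+
+    # A small placeholder policy; any TensorDictModule works here.
+    policy = TensorDictModule(
+        nn.LazyLinear(1), in_keys=["observation"], out_keys=["action"]
+    )
+
+    collector = SyncDataCollector(
+        lambda: GymEnv("Pendulum-v1"),
+        policy,
+        frames_per_batch=200,
+        total_frames=1_000,
+        device="cuda:0",        # where the policy (and env steps) are executed
+        storing_device="cpu",   # where the collected rollout is kept
+    )
+    for data in collector:
+        print(data.device)      # expected: cpu, i.e. the storing device
+    collector.shutdown()
+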
The key parameters to control are
+``devices``, which controls the execution devices (i.e., the device of the policy),
+and ``storing_device``, which will control the device where the environment and
+data are stored during a rollout. A good heuristic is usually to use the same device
+for storage and compute, which is the default behavior when only the ``devices`` argument
+is being passed.
+
+Besides those compute parameters, users may choose to configure the following parameters:
+
+- max_frames_per_traj: the number of frames after which a :meth:`env.reset` is called
+- frames_per_batch: the number of frames delivered at each iteration over the collector
+- init_random_frames: the number of random steps (steps where :meth:`env.rand_step` is being called)
+- reset_at_each_iter: if ``True``, the environment(s) will be reset after each batch collection
+- split_trajs: if ``True``, the trajectories will be split and delivered in a padded tensordict
+  along with a ``"mask"`` key that will point to a boolean mask representing the valid values.
+- exploration_type: the exploration strategy to be used with the policy.
+- reset_when_done: whether environments should be reset when reaching a done state.
+
+Collectors and batch size
+-------------------------
+
+Because each collector has its own way of organizing the environments that are
+run within, the data will come with different batch sizes depending on the
+specificities of the collector. The following table summarizes what is to
+be expected when collecting data:
+
+
++--------------------+---------------------+--------------------------------------------+------------------------------+
+|                    | SyncDataCollector   | MultiSyncDataCollector (n=B)               |MultiaSyncDataCollector (n=B) |
++====================+=====================+=============+==============+===============+==============================+
+| `cat_results`      | NA                  | `"stack"`   | `0`          | `-1`          | NA                           |
++--------------------+---------------------+-------------+--------------+---------------+------------------------------+
+| Single env         | [T]                 | `[B, T]`    | `[B*(T//B)]` | `[B*(T//B)]`  | [T]                          |
++--------------------+---------------------+-------------+--------------+---------------+------------------------------+
+| Batched env (n=P)  | [P, T]              | `[B, P, T]` | `[B * P, T]` | `[P, T * B]`  | [P, T]                       |
++--------------------+---------------------+-------------+--------------+---------------+------------------------------+
+
+In each of these cases, the last dimension (``T`` for ``time``) is adapted such
+that the batch size equals the ``frames_per_batch`` argument passed to the
+collector.
+
+.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be
+    used with ``cat_results=0``, as the data will be stacked along the batch
+    dimension with batched environments, or the time dimension for single environments,
+    which can introduce some confusion when swapping one with the other.
+    ``cat_results="stack"`` is a better and more consistent way of interacting
+    with the environments as it will keep each dimension separate, and provide
+    better interchangeability between configurations, collector classes and other
+    components.
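+
+The sketch below illustrates the shapes listed in the table above (the environment,
+frame counts, and the random policy obtained by passing ``policy=None`` are
+illustrative only):
+
+.. code-block:: python
+
+    from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
+    from torchrl.envs import GymEnv, ParallelEnv
+
+    make_env = lambda: GymEnv("CartPole-v1")
+
+    # Batched env with P=2 workers: expected batch size [P, T], with P * T == frames_per_batch
+    collector = SyncDataCollector(
+        ParallelEnv(2, make_env), policy=None, frames_per_batch=100, total_frames=100
+    )
+    for data in collector:
+        print(data.batch_size)  # expected: torch.Size([2, 50])
+    collector.shutdown()
+
+    # B=2 sub-collectors, each running a batched env with P=2 workers, and
+    # cat_results="stack": expected batch size [B, P, T], with B * P * T == frames_per_batch
+    collector = MultiSyncDataCollector(
+        [ParallelEnv(2, make_env)] * 2,
+        policy=None,
+        frames_per_batch=100,
+        total_frames=100,
+        cat_results="stack",
+    )
+    for data in collector:
+        print(data.batch_size)  # expected: torch.Size([2, 2, 25])
+    collector.shutdown()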
+
+Whereas :class:`~torchrl.collectors.MultiSyncDataCollector`
+has a dimension corresponding to the number of sub-collectors being run (``B``),
+:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. This
+is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector`
+delivers batches of data on a first-come, first-served basis, whereas
+:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from
+each sub-collector before delivering it.
+
+Collectors and policy copies
+----------------------------
+
+When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to
+keep the training version of the policy on a device and the inference version on another. For example, if you have two
+CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is the
+case, :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` can be used to copy the parameters from one
+device to the other (if no copy is required, this method is a no-op).
+
+Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will do a deepcopy of the
+policy structure and copy the parameters placed on the new device during instantiation if necessary.
+Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we
+try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur.
+
+.. figure:: /_static/img/collector-copy.png
+
+    Policy copy decision tree in Collectors.
diff --git a/docs/source/reference/collectors_distributed.rst b/docs/source/reference/collectors_distributed.rst
new file mode 100644
index 00000000000..acc4c7af35b
--- /dev/null
+++ b/docs/source/reference/collectors_distributed.rst
@@ -0,0 +1,49 @@
+.. currentmodule:: torchrl.collectors.distributed
+
+Distributed Collectors
+======================
+
+TorchRL provides a set of distributed data collectors. These tools support
+multiple backends (``'gloo'``, ``'nccl'``, ``'mpi'`` with the :class:`~.DistributedDataCollector`
+or PyTorch RPC with :class:`~.RPCDataCollector`) and launchers (``'ray'``,
+``submitit`` or ``torch.multiprocessing``).
+They can be efficiently used in synchronous or asynchronous mode, on a single
+node or across multiple nodes.
+
+*Resources*: Find examples for these collectors in the
+`dedicated folder `_.
+
+.. note::
+    *Choosing the sub-collector*: All distributed collectors support the various single machine collectors.
+    One may wonder why one would use a :class:`MultiSyncDataCollector` or a :class:`~torchrl.envs.ParallelEnv`
+    instead. In general, multiprocessed collectors have a lower IO footprint than
+    parallel environments, which need to communicate at each step. Yet, the model specs
+    play a role in the opposite direction, since using parallel environments will
+    result in a faster execution of the policy (and/or transforms) since these
+    operations will be vectorized.
+
+.. note::
+    *Choosing the device of a collector (or a parallel environment)*: Sharing data
+    among processes is achieved via shared-memory buffers with parallel and
+    multiprocessed environments executed on CPU. Depending on the capabilities
+    of the machine being used, this may be prohibitively slow compared to sharing
+    data on GPU, which is natively supported by cuda drivers.
+    In practice, this means that using the ``device="cpu"`` keyword argument when
+    building a parallel environment or collector can result in a slower collection
+    than using ``device="cuda"`` when available.
+
+.. 
note:: + Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) + warnings can quickly become quite annoying in multiprocessed / distributed settings. + By default, TorchRL filters out these warnings in sub-processes. If one still wishes to + see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + DistributedDataCollector + RPCDataCollector + DistributedSyncDataCollector + submitit_delayed_launcher + RayCollector diff --git a/docs/source/reference/collectors_replay.rst b/docs/source/reference/collectors_replay.rst new file mode 100644 index 00000000000..fb0a776fdbf --- /dev/null +++ b/docs/source/reference/collectors_replay.rst @@ -0,0 +1,81 @@ +.. currentmodule:: torchrl.collectors + +Collectors and Replay Buffers +============================= + +Collectors and replay buffers interoperability +---------------------------------------------- + +In the simplest scenario where single transitions have to be sampled +from the replay buffer, little attention has to be given to the way +the collector is built. Flattening the data after collection will +be a sufficient preprocessing step before populating the storage: + + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N), + ... transform=lambda data: data.reshape(-1)) + >>> for data in collector: + ... memory.extend(data) + +If trajectory slices have to be collected, the recommended way to achieve this is to create +a multidimensional buffer and sample using the :class:`~torchrl.data.replay_buffers.SliceSampler` +sampler class. One must ensure that the data passed to the buffer is properly shaped, with the +``time`` and ``batch`` dimensions clearly separated. In practice, the following configurations +will work: + + >>> # Single environment: no need for a multi-dimensional buffer + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N), + ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) + ... ) + >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) + >>> for data in collector: + ... memory.extend(data) + >>> # Batched environments: a multi-dim buffer is required + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N, ndim=2), + ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) + ... ) + >>> env = ParallelEnv(4, make_env) + >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) + >>> for data in collector: + ... memory.extend(data) + >>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack" + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N, ndim=2), + ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) + ... ) + >>> collector = MultiSyncDataCollector([make_env] * 4, + ... policy, + ... frames_per_batch=N, + ... total_frames=-1, + ... cat_results="stack") + >>> for data in collector: + ... memory.extend(data) + >>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N, ndim=3), + ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) + ... ) + >>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4, + ... policy, + ... frames_per_batch=N, + ... total_frames=-1, + ... cat_results="stack") + >>> for data in collector: + ... 
memory.extend(data) + +Using replay buffers that sample trajectories with :class:`~torchrl.collectors.MultiSyncDataCollector` +isn't currently fully supported as the data batches can come from any worker and in most cases consecutive +batches written in the buffer won't come from the same source (thereby interrupting the trajectories). + +Helper functions +---------------- + +.. currentmodule:: torchrl.collectors.utils + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + split_trajectories diff --git a/docs/source/reference/collectors_single.rst b/docs/source/reference/collectors_single.rst new file mode 100644 index 00000000000..529b03ee4e2 --- /dev/null +++ b/docs/source/reference/collectors_single.rst @@ -0,0 +1,50 @@ +.. currentmodule:: torchrl.collectors + +Single Node Collectors +====================== + +TorchRL provides several collector classes for single-node data collection, each with different execution strategies. + +Single node data collectors +--------------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + DataCollectorBase + SyncDataCollector + MultiSyncDataCollector + MultiaSyncDataCollector + aSyncDataCollector + +Running the Collector Asynchronously +------------------------------------ + +Passing replay buffers to a collector allows us to start the collection and get rid of the iterative nature of the +collector. +If you want to run a data collector in the background, simply run :meth:`~torchrl.DataCollectorBase.start`: + + >>> collector = SyncDataCollector(..., replay_buffer=rb) # pass your replay buffer + >>> collector.start() + >>> # little pause + >>> time.sleep(10) + >>> # Start training + >>> for i in range(optim_steps): + ... data = rb.sample() # Sampling from the replay buffer + ... # rest of the training loop + +Single-process collectors (:class:`~torchrl.collectors.SyncDataCollector`) will run the process using multithreading, +so be mindful of Python's GIL and related multithreading restrictions. + +Multiprocessed collectors will on the other hand let the child processes handle the filling of the buffer on their own, +which truly decouples the data collection and training. + +Data collectors that have been started with `start()` should be shut down using +:meth:`~torchrl.DataCollectorBase.async_shutdown`. + +.. warning:: Running a collector asynchronously decouples the collection from training, which means that the training + performance may be drastically different depending on the hardware, load and other factors (although it is generally + expected to provide significant speed-ups). Make sure you understand how this may affect your algorithm and if it + is a legitimate thing to do! (For example, on-policy algorithms such as PPO should not be run asynchronously + unless properly benchmarked). diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst new file mode 100644 index 00000000000..0fcf174f3c1 --- /dev/null +++ b/docs/source/reference/collectors_weightsync.rst @@ -0,0 +1,299 @@ +.. currentmodule:: torchrl.weight_update + +Weight Synchronization +====================== + +RL pipelines are typically split in two big computational buckets: training, and inference. +While the inference pipeline sends data to the training one, the training pipeline needs to occasionally +synchronize its weights with the inference one. 
+In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
+used in both instances. From there, anything can happen:
+
+- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
+  `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
+  for its instance of the policy.
+- In some cases, the environment or the postprocessing hooks can rely on the usage of a model which itself needs
+  synchronization. This means that there can be multiple ends in the data transfer API and one needs to think beyond
+  policy-to-policy weight synchronization strategies.
+- In the LLM world, the inference engine and the training one are very different: they will use different libraries,
+  kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different (quantized
+  vs. non-quantized).
+  This makes the weight synchronization much more complex, as one cannot simply dump and load a state dict on both ends.
+- One typically also has to choose who initiates a transfer: should this come from the inference engine, which actively
+  asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach
+  is to store the weights on some intermediary server and let the workers fetch them when necessary.
+
+TorchRL tries to account for each of these problems in a flexible manner. We identify four basic components in a weight
+transfer:
+
+- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer;
+- A `Receiver` class that casts the weights to the destination module (policy or other utility module);
+- A `Transport` class that implements the actual transfer of the weights (through shared memory, NCCL or anything else).
+- A `Scheme` class that defines which sender, receiver and transport have to be used and how to initialize them.
+
+Each of these classes is detailed below.
+
+Usage Examples
+--------------
+
+.. note::
+    **Runnable versions** of these examples are available in the repository:
+
+    - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization
+    - `examples/collectors/weight_sync_collectors.py `_: Collector integration
+
+Using Weight Update Schemes Independently
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Weight update schemes can be used outside of collectors for custom synchronization scenarios.
+The new simplified API provides four core methods for weight synchronization:
+
+- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side
+- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side
+- ``get_sender()`` - Get the configured sender instance
+- ``get_receiver()`` - Get the configured receiver instance
+
+Here's a basic example:
+
+.. 
code-block:: python + + import torch + import torch.nn as nn + from torch import multiprocessing as mp + from tensordict import TensorDict + from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + ) + + # Create a simple policy + policy = nn.Linear(4, 2) + + # Example 1: Multiprocess weight synchronization with state_dict + # -------------------------------------------------------------- + # On the main process side (trainer): + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + # Initialize scheme with pipes + parent_pipe, child_pipe = mp.Pipe() + scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + + # Get the sender and send weights + sender = scheme.get_sender() + weights = policy.state_dict() + sender.send(weights) # Synchronous send + # or sender.send_async(weights); sender.wait_async() # Asynchronous send + + # On the worker process side: + # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) + # receiver = scheme.get_receiver() + # # Non-blocking check for new weights + # if receiver.receive(timeout=0.001): + # # Weights were received and applied + + # Example 2: Shared memory weight synchronization + # ------------------------------------------------ + # Create shared memory scheme with auto-registration + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + + # Initialize with pipes for lazy registration + parent_pipe2, child_pipe2 = mp.Pipe() + shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2]) + + # Get sender and send weights (automatically creates shared buffer on first send) + shared_sender = shared_scheme.get_sender() + weights_td = TensorDict.from_module(policy) + shared_sender.send(weights_td) + + # Workers automatically see updates via shared memory! + +Using Weight Update Schemes with Collectors +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization +across multiple inference workers: + +.. code-block:: python + + import torch.nn as nn + from tensordict.nn import TensorDictModule + from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector + from torchrl.envs import GymEnv + from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + ) + + # Create environment and policy + env = GymEnv("CartPole-v1") + policy = TensorDictModule( + nn.Linear(env.observation_spec["observation"].shape[-1], + env.action_spec.shape[-1]), + in_keys=["observation"], + out_keys=["action"], + ) + + # Example 1: Single collector with multiprocess scheme + # ----------------------------------------------------- + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=1000, + weight_sync_schemes={"policy": scheme}, + ) + + # Collect data and update weights periodically + for i, data in enumerate(collector): + # ... training step with data ... 
+
+        # Update policy weights every N iterations
+        if i % 10 == 0:
+            new_weights = policy.state_dict()
+            collector.update_policy_weights_(new_weights)
+
+    collector.shutdown()
+
+    # Example 2: Multiple collectors with shared memory
+    # --------------------------------------------------
+    # Shared memory is more efficient for frequent updates
+    from tensordict import TensorDict  # needed below to extract the policy weights
+
+    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+
+    collector = MultiSyncDataCollector(
+        create_env_fn=[
+            lambda: GymEnv("CartPole-v1"),
+            lambda: GymEnv("CartPole-v1"),
+            lambda: GymEnv("CartPole-v1"),
+        ],
+        policy=policy,
+        frames_per_batch=192,
+        total_frames=10000,
+        weight_sync_schemes={"policy": shared_scheme},
+    )
+
+    # Workers automatically see weight updates via shared memory
+    for data in collector:
+        # ... training ...
+        collector.update_policy_weights_(TensorDict.from_module(policy))
+
+    collector.shutdown()
+
+.. note::
+    When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all
+    processes share the same memory buffers. This is ideal for frequent weight updates but requires all
+    processes to be on the same machine.
+
+.. note::
+    The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state
+    dictionaries, while ``"tensordict"`` uses the TensorDict format, which can be more efficient for structured
+    models and supports advanced features like lazy initialization.
+
+Weight Senders
+--------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    WeightSender
+    RayModuleTransformSender
+
+Weight Receivers
+----------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    WeightReceiver
+    RayModuleTransformReceiver
+
+Transports
+----------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    TransportBackend
+    MPTransport
+    SharedMemTransport
+    RayTransport
+    RayActorTransport
+    RPCTransport
+    DistributedTransport
+
+Schemes
+-------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    WeightSyncScheme
+    MultiProcessWeightSyncScheme
+    SharedMemWeightSyncScheme
+    NoWeightSyncScheme
+    RayWeightSyncScheme
+    RayModuleTransformScheme
+    RPCWeightSyncScheme
+    DistributedWeightSyncScheme
+
+Legacy: Weight Updaters
+-----------------------
+
+.. warning::
+    The `WeightUpdater` is considered legacy as of the 0.11 release and will be deprecated soon.
+    The weight update schemes, which provide more flexibility and better compatibility with heavy
+    weight transfers (e.g., LLMs), are to be preferred.
+
+In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
+latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
+mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
+
+Sending and receiving model weights with WeightUpdaters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The weight synchronization process is facilitated by one dedicated extension point:
+:class:`~torchrl.collectors.WeightUpdaterBase`. This base class provides a structured interface for
+implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
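+
+As a rough illustration, a custom updater can be sketched as follows. This is a hedged sketch rather than the
+verbatim abstract API: the hook names (``_get_server_weights``, ``_maybe_map_weights``, ``all_worker_ids``,
+``_sync_weights_with_worker``) and the constructor arguments are assumptions made for the example and should be
+checked against the :class:`~torchrl.collectors.WeightUpdaterBase` reference below. The point is only that the
+transfer logic is customized while the collector and policy code stay untouched.
+
+.. code-block:: python
+
+    import torch.nn as nn
+    from torchrl.collectors import WeightUpdaterBase
+
+    class TrainerStateDictUpdater(WeightUpdaterBase):
+        """Illustrative updater that pulls a state-dict from a trainer-side module.
+
+        The overridden hook names are assumptions for this sketch, not a verbatim
+        copy of the abstract API.
+        """
+
+        def __init__(self, trainer_policy, worker_policies):
+            super().__init__()
+            self.trainer_policy = trainer_policy      # module living on the trainer
+            self.worker_policies = worker_policies    # one module copy per worker
+
+        def _get_server_weights(self):
+            # How the weights are retrieved on the trainer ("server") side.
+            return self.trainer_policy.state_dict()
+
+        def _maybe_map_weights(self, server_weights):
+            # Optional re-formatting (dtype casting, key remapping, ...); identity here.
+            return server_weights
+
+        def all_worker_ids(self):
+            # One entry per inference worker.
+            return list(range(len(self.worker_policies)))
+
+        def _sync_weights_with_worker(self, worker_id, server_weights):
+            # How the weights are applied on a given worker; here, a plain load.
+            self.worker_policies[worker_id].load_state_dict(server_weights)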
+ +:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to +the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary. +Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the +weight synchronization with the policy. +Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy +state-dict (assuming it is a :class:`~torch.nn.Module` instance). + +Extending the Updater Class +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations. +The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation +untouched. +This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware +setups. +By implementing the abstract methods in these base classes, users can define how weights are retrieved, +transformed, and applied, ensuring seamless integration with their existing infrastructure. + +.. currentmodule:: torchrl.collectors + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightUpdaterBase + VanillaWeightUpdater + MultiProcessedWeightUpdater + RayWeightUpdater + +.. currentmodule:: torchrl.collectors.distributed + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + RPCWeightUpdater + DistributedWeightUpdater diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index f7372a38f6f..78035570a10 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -14,7 +14,7 @@ The advantages of using a configuration system are: - Easy to version control: you can easily version control your configuration file Quick Start with a Simple Example ----------------------------------- +--------------------------------- Let's start with a simple example that creates a Gym environment. Here's a minimal configuration file: @@ -65,7 +65,7 @@ TorchRL organizes configurations into several categories using the ``@`` syntax The ``@`` syntax allows you to assign configurations to specific locations in your config structure. More Complex Example: Parallel Environment with Transforms ------------------------------------------------------------ +---------------------------------------------------------- Here's a more complex example that creates a parallel environment with multiple transforms applied to each worker: @@ -127,7 +127,7 @@ This configuration builds a **parallel environment with 4 workers**, where each 4. 
**Variable interpolation**: ``${transform0}`` and ``${transform1}`` reference the separately defined transform configurations Getting Available Options --------------------------- +------------------------- To explore all available configurations and their parameters, one can use the ``--help`` flag with any TorchRL script: @@ -141,7 +141,7 @@ This shows all configuration groups and their options, making it easy to discove Complete Training Example --------------------------- +------------------------- Here's a complete configuration for PPO training: @@ -219,7 +219,7 @@ Here's a complete configuration for PPO training: exp_name: my_experiment Running Experiments --------------------- +------------------- Basic Usage ~~~~~~~~~~~ @@ -252,7 +252,7 @@ Hyperparameter Sweeps training_env.num_workers=2,4,8 Custom Configuration Files -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash diff --git a/docs/source/reference/cudnn_persistent_rnn.rst b/docs/source/reference/cudnn_persistent_rnn.rst new file mode 100644 index 00000000000..d6a1e44380c --- /dev/null +++ b/docs/source/reference/cudnn_persistent_rnn.rst @@ -0,0 +1,4 @@ +.. note:: + In some circumstances when using the CUDNN backend with CuDNN 7.2.1, the backward + can be up to 5x slower when called with a batch_first input. This is expected to be fixed + in CuDNN 7.2.5. diff --git a/docs/source/reference/cudnn_rnn_determinism.rst b/docs/source/reference/cudnn_rnn_determinism.rst new file mode 100644 index 00000000000..92eff7c0596 --- /dev/null +++ b/docs/source/reference/cudnn_rnn_determinism.rst @@ -0,0 +1,8 @@ +.. note:: + If the following conditions are not met, the backward pass will use a slower but more + memory efficient implementation: + + * The input is a :class:`~torch.nn.utils.rnn.PackedSequence` + * The input is not batch first + * ``dropout != 0`` + * ``training == True`` diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 319eb19d0dc..46bfd25e2bd 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -5,1234 +5,54 @@ torchrl.data package .. _ref_data: -Replay Buffers --------------- - -Replay buffers are a central part of off-policy RL algorithms. TorchRL provides an efficient implementation of a few, -widely used replay buffers: - - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - ReplayBuffer - PrioritizedReplayBuffer - TensorDictReplayBuffer - TensorDictPrioritizedReplayBuffer - RayReplayBuffer - RemoteTensorDictReplayBuffer - -Composable Replay Buffers -------------------------- - -.. _ref_buffers: - -We also give users the ability to compose a replay buffer. -We provide a wide panel of solutions for replay buffer usage, including support for -almost any data type; storage in memory, on device or on physical memory; -several sampling strategies; usage of transforms etc. - -Supported data types and choosing a storage -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In theory, replay buffers support any data type but we can't guarantee that each -component will support any data type. The most crude replay buffer implementation -is made of a :class:`~torchrl.data.replay_buffers.ReplayBuffer` base with a -:class:`~torchrl.data.replay_buffers.ListStorage` storage. This is very inefficient -but it will allow you to store complex data structures with non-tensor data. 
-Storages in contiguous memory include :class:`~torchrl.data.replay_buffers.TensorStorage`, -:class:`~torchrl.data.replay_buffers.LazyTensorStorage` and -:class:`~torchrl.data.replay_buffers.LazyMemmapStorage`. -These classes support :class:`~tensordict.TensorDict` data as first-class citizens, but also -any PyTree data structure (eg, tuples, lists, dictionaries and nested versions -of these). The :class:`~torchrl.data.replay_buffers.TensorStorage` storage requires -you to provide the storage at construction time, whereas :class:`~torchrl.data.replay_buffers.TensorStorage` -(RAM, CUDA) and :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` (physical memory) -will preallocate the storage for you after they've been extended the first time. - -Here are a few examples, starting with the generic :class:`~torchrl.data.replay_buffers.ListStorage`: - - >>> from torchrl.data.replay_buffers import ReplayBuffer, ListStorage - >>> rb = ReplayBuffer(storage=ListStorage(10)) - >>> rb.add("a string!") # first element will be a string - >>> rb.extend([30, None]) # element [1] is an int, [2] is None - -The main entry points to write onto a buffer are :meth:`~torchrl.data.ReplayBuffer.add` and -:meth:`~torchrl.data.ReplayBuffer.extend`. -One can also use :meth:`~torchrl.data.ReplayBuffer.__setitem__`, in which case the data is written -where indicated without updating the length or cursor of the buffer. This can be useful when sampling -items from the buffer and them updating their values in-place afterwards. - -Using a :class:`~torchrl.data.replay_buffers.TensorStorage` we tell our RB that -we want the storage to be contiguous, which is by far more efficient but also -more restrictive: - - >>> import torch - >>> from torchrl.data.replay_buffers import ReplayBuffer, TensorStorage - >>> container = torch.empty(10, 3, 64, 64, dtype=torch.unit8) - >>> rb = ReplayBuffer(storage=TensorStorage(container)) - >>> img = torch.randint(255, (3, 64, 64), dtype=torch.uint8) - >>> rb.add(img) - -Next we can avoid creating the container and ask the storage to do it automatically. -This is very useful when using PyTrees and tensordicts! For PyTrees as other data -structures, :meth:`~torchrl.data.replay_buffers.ReplayBuffer.add` considers the sampled -passed to it as a single instance of the type. :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` -on the other hand will consider that the data is an iterable. For tensors, tensordicts -and lists (see below), the iterable is looked for at the root level. For PyTrees, -we assume that the leading dimension of all the leaves (tensors) in the tree -match. If they don't, ``extend`` will throw an exception. - - >>> import torch - >>> from tensordict import TensorDict - >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyMemmapStorage - >>> rb_td = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=1) # max 10 elements stored - >>> rb_td.add(TensorDict({"img": torch.randint(255, (3, 64, 64), dtype=torch.unit8), - ... "labels": torch.randint(100, ())}, batch_size=[])) - >>> rb_pytree = ReplayBuffer(storage=LazyMemmapStorage(10)) # max 10 elements stored - >>> # extend with a PyTree where all tensors have the same leading dim (3) - >>> rb_pytree.extend({"a": {"b": torch.randn(3), "c": [torch.zeros(3, 2), (torch.ones(3, 10),)]}}) - >>> assert len(rb_pytree) == 3 # the replay buffer has 3 elements! - -.. 
note:: :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` can have an - ambiguous signature when dealing with lists of values, which should be interpreted - either as PyTree (in which case all elements in the list will be put in a slice - in the stored PyTree in the storage) or a list of values to add one at a time. - To solve this, TorchRL makes the clear-cut distinction between list and tuple: - a tuple will be viewed as a PyTree, a list (at the root level) will be interpreted - as a stack of values to add one at a time to the buffer. - -Sampling and indexing -~~~~~~~~~~~~~~~~~~~~~ - -Replay buffers can be indexed and sampled. -Indexing and sampling collect data at given indices in the storage and then process them -through a series of transforms and ``collate_fn`` that can be passed to the `__init__` -function of the replay buffer. ``collate_fn`` comes with default values that should -match user expectations in the majority of cases, such that you should not have -to worry about it most of the time. Transforms are usually instances of :class:`~torchrl.envs.transforms.Transform` -even though regular functions will work too (in the latter case, the :meth:`~torchrl.envs.transforms.Transform.inv` -method will obviously be ignored, whereas in the first case it can be used to -preprocess the data before it is passed to the buffer). -Finally, sampling can be achieved using multithreading by passing the number of threads -to the constructor through the ``prefetch`` keyword argument. We advise users to -benchmark this technique in real life settings before adopting it, as there is -no guarantee that it will lead to a faster throughput in practice depending on -the machine and setting where it is used. - -When sampling, the ``batch_size`` can be either passed during construction -(e.g., if it's constant throughout training) or -to the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method. - -To further refine the sampling strategy, we advise you to look into our samplers! - -Here are a couple of examples of how to get data out of a replay buffer: - - >>> first_elt = rb_td[0] - >>> storage = rb_td[:] # returns all valid elements from the buffer - >>> sample = rb_td.sample(128) - >>> for data in rb_td: # iterate over the buffer using the sampler -- batch-size was set in the constructor to 1 - ... print(data) - -using the following components: - -.. currentmodule:: torchrl.data.replay_buffers - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - - CompressedListStorage - CompressedListStorageCheckpointer - FlatStorageCheckpointer - H5StorageCheckpointer - ImmutableDatasetWriter - LazyMemmapStorage - LazyTensorStorage - ListStorage - LazyStackStorage - ListStorageCheckpointer - NestedStorageCheckpointer - PrioritizedSampler - PrioritizedSliceSampler - RandomSampler - RoundRobinWriter - Sampler - SamplerWithoutReplacement - SliceSampler - SliceSamplerWithoutReplacement - Storage - StorageCheckpointerBase - StorageEnsembleCheckpointer - TensorDictMaxValueWriter - TensorDictRoundRobinWriter - TensorStorage - TensorStorageCheckpointer - Writer - - -Storage choice is very influential on replay buffer sampling latency, especially -in distributed reinforcement learning settings with larger data volumes. 
-:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` is highly -advised in distributed settings with shared storage due to the lower serialization -cost of MemoryMappedTensors as well as the ability to specify file storage locations -for improved node failure recovery. -The following mean sampling latency improvements over using :class:`~torchrl.data.replay_buffers.ListStorage` -were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/benchmarks/storage. - -+-------------------------------+-----------+ -| Storage Type | Speed up | -| | | -+===============================+===========+ -| :class:`ListStorage` | 1x | -+-------------------------------+-----------+ -| :class:`LazyTensorStorage` | 1.83x | -+-------------------------------+-----------+ -| :class:`LazyMemmapStorage` | 3.44x | -+-------------------------------+-----------+ - -Compressed Storage for Memory Efficiency -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -For applications where memory usage or memory bandwidth is a primary concern—especially when storing or transferring large sensory observations such as images, audio, or text—the :class:`~torchrl.data.replay_buffers.storages.CompressedListStorage` provides significant memory savings through compression. - -**Key features:** - -- **Memory Efficiency:** Achieves substantial memory savings via compression. -- **Data Integrity:** Maintains full data fidelity through lossless compression. -- **Flexible Compression:** Uses zstd compression by default, with support for custom compression algorithms. -- **TensorDict Support:** Seamlessly integrates with TensorDict structures. -- **Checkpointing:** Fully supports saving and loading compressed data. -- **Batched GPU Compression/Decompression:** Enables efficient replay buffer sampling directly from VRAM. - -The `CompressedListStorage` compresses data when storing and decompresses when retrieving, achieving compression ratios of 95x–122x for Atari images while maintaining full data fidelity. -We see these results in the Atari Learning Environment (ALE) from a rollout in Pong with a random policy for an episode at each compression level: - -+-------------------------------+--------+--------+--------+--------+--------+ -| Compression level of zstd | 1 | 3 | 8 | 12 | 22 | -+===============================+========+========+========+========+========+ -| Compression ratio in ALE Pong | 95x | 99x | 106x | 111x | 122x | -+-------------------------------+--------+--------+--------+--------+--------+ - -Example usage: - - >>> import torch - >>> from torchrl.data import ReplayBuffer, CompressedListStorage - >>> from tensordict import TensorDict - >>> - >>> # Create a compressed storage for image data - >>> storage = CompressedListStorage(max_size=1000, compression_level=3) - >>> rb = ReplayBuffer(storage=storage, batch_size=32) - >>> - >>> # Add image data - >>> images = torch.randn(100, 3, 84, 84) # Atari-like frames - >>> data = TensorDict({"obs": images}, batch_size=[100]) - >>> rb.extend(data) - >>> - >>> # Sample data (automatically decompressed) - >>> sample = rb.sample(32) - >>> print(sample["obs"].shape) # torch.Size([32, 3, 84, 84]) - -The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression), -with level 3 being a good default for most use cases. - -For custom compression algorithms: - - >>> def my_compress(tensor): - ... return tensor.to(torch.uint8) # Simple example - >>> - >>> def my_decompress(compressed_tensor, metadata): - ... 
return compressed_tensor.to(metadata["dtype"]) - >>> - >>> storage = CompressedListStorage( - ... max_size=1000, - ... compression_fn=my_compress, - ... decompression_fn=my_decompress - ... ) - -.. note:: The CompressedListStorage uses `zstd` for python versions of at least 3.14 and defaults to zlib otherwise. - -.. note:: Batched GPU compression relies on `nvidia.nvcomp`, please see example code - `examples/replay-buffers/compressed_replay_buffer.py `_. - -Sharing replay buffers across processes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Replay buffers can be shared between processes as long as their components are -sharable. This feature allows for multiple processes to collect data and populate a shared -replay buffer collaboratively, rather than centralizing the data on the main process -which can incur some data transmission overhead. - -Sharable storages include :class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` -or any subclass of :class:`~torchrl.data.replay_buffers.storages.TensorStorage` -as long as they are instantiated and their content is stored as memory-mapped -tensors. Stateful writers such as :class:`~torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter` -are currently not sharable, and the same goes for stateful samplers such as -:class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`. - -A shared replay buffer can be read and extended on any process that has access -to it, as the following example shows: - - >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage - >>> import torch - >>> from torch import multiprocessing as mp - >>> from tensordict import TensorDict - >>> - >>> def worker(rb): - ... # Updates the replay buffer with new data - ... td = TensorDict({"a": torch.ones(10)}, [10]) - ... rb.extend(td) - ... - >>> if __name__ == "__main__": - ... rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21)) - ... td = TensorDict({"a": torch.zeros(10)}, [10]) - ... rb.extend(td) - ... - ... proc = mp.Process(target=worker, args=(rb,)) - ... proc.start() - ... proc.join() - ... # the replay buffer now has a length of 20, since the worker updated it - ... assert len(rb) == 20 - ... assert (rb["_data", "a"][:10] == 0).all() # data from main process - ... assert (rb["_data", "a"][10:20] == 1).all() # data from remote process - - -Storing trajectories -~~~~~~~~~~~~~~~~~~~~ - -It is not too difficult to store trajectories in the replay buffer. -One element to pay attention to is that the size of the replay buffer is by default -the size of the leading dimension of the storage: in other words, creating a -replay buffer with a storage of size 1M when storing multidimensional data -does not mean storing 1M frames but 1M trajectories. However, if trajectories -(or episodes/rollouts) are flattened before being stored, the capacity will still -be 1M steps. - -There is a way to circumvent this by telling the storage how many dimensions -it should take into account when saving data. This can be done through the ``ndim`` -keyword argument which is accepted by all contiguous storages such as -:class:`~torchrl.data.replay_buffers.TensorStorage` and the likes. When a -multidimensional storage is passed to a buffer, the buffer will automatically -consider the last dimension as the "time" dimension, as it is conventional in -TorchRL. This can be overridden through the ``dim_extend`` keyword argument -in :class:`~torchrl.data.ReplayBuffer`. 
-This is the recommended way to save trajectories that are obtained through -:class:`~torchrl.envs.ParallelEnv` or its serial counterpart, as we will see -below. - -When sampling trajectories, it may be desirable to sample sub-trajectories -to diversify learning or make the sampling more efficient. -TorchRL offers two distinctive ways of accomplishing this: - -- The :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` allows to - sample a given number of slices of trajectories stored one after another - along the leading dimension of the :class:`~torchrl.data.replay_buffers.samplers.TensorStorage`. - This is the recommended way of sampling sub-trajectories in TorchRL __especially__ - when using offline datasets (which are stored using that convention). - This strategy requires to flatten the trajectories before extending the replay - buffer and reshaping them after sampling. - The :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` class docstrings - gives extensive details about this storage and sampling strategy. - Note that :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` - is compatible with multidimensional storages. The following examples show - how to use this feature with and without flattening of the tensordict. - In the first scenario, we are collecting data from a single environment. In - that case, we are happy with a storage that concatenates the data coming in - along the first dimension, since there will be no interruption introduced - by the collection schedule: - - >>> from torchrl.envs import TransformedEnv, StepCounter, GymEnv - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy - >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler - >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) - >>> collector = SyncDataCollector(env, - ... RandomPolicy(env.action_spec), - ... frames_per_batch=10, total_frames=-1) - >>> rb = ReplayBuffer( - ... storage=LazyTensorStorage(100), - ... sampler=SliceSampler(num_slices=8, traj_key=("collector", "traj_ids"), - ... truncated_key=None, strict_length=False), - ... batch_size=64) - >>> for i, data in enumerate(collector): - ... rb.extend(data) - ... if i == 10: - ... break - >>> assert len(rb) == 100, len(rb) - >>> print(rb[:]["next", "step_count"]) - tensor([[32], - [33], - [34], - [35], - [36], - [37], - [38], - [39], - [40], - [41], - [11], - [12], - [13], - [14], - [15], - [16], - [17], - [... - - If there are more than one environment run in a batch, we could still store - the data in the same buffer as before by calling ``data.reshape(-1)`` which - will flatten the ``[B, T]`` size into ``[B * T]`` but that means that the - trajectories of, say, the first environment of the batch will be interleaved - by trajectories of the other environments, a scenario that ``SliceSampler`` - cannot handle. To solve this, we suggest to use the ``ndim`` argument in the - storage constructor: - - >>> env = TransformedEnv(SerialEnv(2, - ... lambda: GymEnv("CartPole-v1")), StepCounter()) - >>> collector = SyncDataCollector(env, - ... RandomPolicy(env.action_spec), - ... frames_per_batch=1, total_frames=-1) - >>> rb = ReplayBuffer( - ... storage=LazyTensorStorage(100, ndim=2), - ... sampler=SliceSampler(num_slices=8, traj_key=("collector", "traj_ids"), - ... truncated_key=None, strict_length=False), - ... batch_size=64) - >>> for i, data in enumerate(collector): - ... rb.extend(data) - ... if i == 100: - ... 
break - >>> assert len(rb) == 100, len(rb) - >>> print(rb[:]["next", "step_count"].squeeze()) - tensor([[ 6, 5], - [ 2, 2], - [ 3, 3], - [ 4, 4], - [ 5, 5], - [ 6, 6], - [ 7, 7], - [ 8, 8], - [ 9, 9], - [10, 10], - [11, 11], - [12, 12], - [13, 13], - [14, 14], - [15, 15], - [16, 16], - [17, 17], - [18, 1], - [19, 2], - [... - - -- Trajectories can also be stored independently, with the each element of the - leading dimension pointing to a different trajectory. This requires - for the trajectories to have a congruent shape (or to be padded). - We provide a custom :class:`~torchrl.envs.Transform` class named - :class:`~torchrl.envs.RandomCropTensorDict` that allows to sample - sub-trajectories in the buffer. Note that, unlike the :class:`~torchrl.data.replay_buffers.samplers.SliceSampler`-based - strategy, here having an ``"episode"`` or ``"done"`` key pointing at the - start and stop signals isn't required. - Here is an example of how this class can be used: - - .. code-block::Python - - >>> import torch - >>> from tensordict import TensorDict - >>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer - >>> from torchrl.envs import RandomCropTensorDict - >>> - >>> obs = torch.randn(100, 50, 1) - >>> data = TensorDict({"obs": obs[:-1], "next": {"obs": obs[1:]}}, [99]) - >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000)) - >>> rb.extend(data) - >>> # subsample trajectories of length 10 - >>> rb.append_transform(RandomCropTensorDict(sub_seq_len=10)) - >>> print(rb.sample(128)) - TensorDict( - fields={ - index: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int32, is_shared=False), - next: TensorDict( - fields={ - obs: Tensor(shape=torch.Size([10, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10]), - device=None, - is_shared=False), - obs: Tensor(shape=torch.Size([10, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10]), - device=None, - is_shared=False) - -Checkpointing Replay Buffers -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. _checkpoint-rb: - -Each component of the replay buffer can potentially be stateful and, as such, -require a dedicated way of being serialized. -Our replay buffer enjoys two separate APIs for saving their state on disk: -:meth:`~torchrl.data.ReplayBuffer.dumps` and :meth:`~torchrl.data.ReplayBuffer.loads` will save the -data of each component except transforms (storage, writer, sampler) using memory-mapped -tensors and json files for the metadata. - -This will work across all classes except -:class:`~torchrl.data.replay_buffers.storages.ListStorage`, which content -cannot be anticipated (and as such does not comply with memory-mapped data -structures such as those that can be found in the tensordict library). - -This API guarantees that a buffer that is saved and then loaded back will be in -the exact same state, whether we look at the status of its sampler (eg, priority trees) -its writer (eg, max writer heaps) or its storage. - -Under the hood, a naive call to :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the public -`dumps` method in a specific folder for each of its components (except transforms -which we don't assume to be serializable using memory-mapped tensors in general). - -Saving data in :ref:`TED-format ` may however consume much more memory than required. 
If continuous -trajectories are stored in a buffer, we can avoid saving duplicated observations by saving all the -observations at the root plus only the last element of the `"next"` sub-tensordict's observations, which -can reduce the storage consumption up to two times. To enable this, three checkpointer classes are available: -:class:`~torchrl.data.FlatStorageCheckpointer` will discard duplicated observations to compress the TED format. At -load time, this class will re-write the observations in the correct format. If the buffer is saved on disk, -the operations executed by this checkpointer will not require any additional RAM. -The :class:`~torchrl.data.NestedStorageCheckpointer` will save the trajectories using nested tensors to make the data -representation more apparent (each item along the first dimension representing a distinct trajectory). -Finally, the :class:`~torchrl.data.H5StorageCheckpointer` will save the buffer in an H5DB format, enabling users to -compress the data and save some more space. - -.. warning:: The checkpointers make some restrictive assumption about the replay buffers. First, it is assumed that - the ``done`` state accurately represents the end of a trajectory (except for the last trajectory which was written - for which the writer cursor indicates where to place the truncated signal). For MARL usage, one should note that - only done states that have as many elements as the root tensordict are allowed: - if the done state has extra elements that are not represented in - the batch-size of the storage, these checkpointers will fail. For example, a done state with shape ``torch.Size([3, 4, 5])`` - within a storage of shape ``torch.Size([3, 4])`` is not allowed. - -Here is a concrete example of how an H5DB checkpointer could be used in practice: - - >>> from torchrl.data import ReplayBuffer, H5StorageCheckpointer, LazyMemmapStorage - >>> from torchrl.collectors import SyncDataCollector - >>> from torchrl.envs import GymEnv, SerialEnv - >>> import torch - >>> env = SerialEnv(3, lambda: GymEnv("CartPole-v1", device=None)) - >>> env.set_seed(0) - >>> torch.manual_seed(0) - >>> collector = SyncDataCollector( - >>> env, policy=env.rand_step, total_frames=200, frames_per_batch=22 - >>> ) - >>> rb = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2)) - >>> rb_test = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2)) - >>> rb.storage.checkpointer = H5StorageCheckpointer() - >>> rb_test.storage.checkpointer = H5StorageCheckpointer() - >>> for i, data in enumerate(collector): - ... rb.extend(data) - ... assert rb._storage.max_size == 102 - ... rb.dumps(path_to_save_dir) - ... rb_test.loads(path_to_save_dir) - ... assert_allclose_td(rb_test[:], rb[:]) - - -Whenever saving data using :meth:`~torchrl.data.ReplayBuffer.dumps` is not possible, an -alternative way is to use :meth:`~torchrl.data.ReplayBuffer.state_dict`, which returns a data -structure that can be saved using :func:`torch.save` and loaded using :func:`torch.load` -before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback -of this method is that it will struggle to save big data structures, which is a -common setting when using replay buffers. - -TorchRL Episode Data Format (TED) ---------------------------------- - -.. _TED-format: - -In TorchRL, sequential data is consistently presented in a specific format, known -as the TorchRL Episode Data Format (TED). This format is crucial for the seamless -integration and functioning of various components within TorchRL. 
- -Some components, such as replay buffers, are somewhat indifferent to the data -format. However, others, particularly environments, heavily depend on it for smooth operation. - -Therefore, it's essential to understand the TED, its purpose, and how to interact -with it. This guide will provide a clear explanation of the TED, why it's used, -and how to effectively work with it. - -The Rationale Behind TED -~~~~~~~~~~~~~~~~~~~~~~~~ - -Formatting sequential data can be a complex task, especially in the realm of -Reinforcement Learning (RL). As practitioners, we often encounter situations -where data is delivered at the reset time (though not always), and sometimes data -is provided or discarded at the final step of the trajectory. - -This variability means that we can observe data of different lengths in a dataset, -and it's not always immediately clear how to match each time step across the -various elements of this dataset. Consider the following ambiguous dataset structure: - - >>> observation.shape - [200, 3] - >>> action.shape - [199, 4] - >>> info.shape - [200, 3] - -At first glance, it seems that the info and observation were delivered -together (one of each at reset + one of each at each step call), as suggested by -the action having one less element. However, if info has one less element, we -must assume that it was either omitted at reset time or not delivered or recorded -for the last step of the trajectory. Without proper documentation of the data -structure, it's impossible to determine which info corresponds to which time step. - -Complicating matters further, some datasets provide inconsistent data formats, -where ``observations`` or ``infos`` are missing at the start or end of the -rollout, and this behavior is often not documented. -The primary aim of TED is to eliminate these ambiguities by providing a clear -and consistent data representation. - -The structure of TED -~~~~~~~~~~~~~~~~~~~~ - -TED is built upon the canonical definition of a Markov Decision Process (MDP) in RL contexts. -At each step, an observation conditions an action that results in (1) a new -observation, (2) an indicator of task completion (terminated, truncated, done), -and (3) a reward signal. - -Some elements may be missing (for example, the reward is optional in imitation -learning contexts), or additional information may be passed through a state or -info container. In some cases, additional information is required to get the -observation during a call to ``step`` (for instance, in stateless environment simulators). Furthermore, -in certain scenarios, an "action" (or any other data) cannot be represented as a -single tensor and needs to be organized differently. For example, in Multi-Agent RL -settings, actions, observations, rewards, and completion signals may be composite. - -TED accommodates all these scenarios with a single, uniform, and unambiguous -format. We distinguish what happens at time step ``t`` and ``t+1`` by setting a -limit at the time the action is executed. In other words, everything that was -present before ``env.step`` was called belongs to ``t``, and everything that -comes after belongs to ``t+1``. - -The general rule is that everything that belongs to time step ``t`` is stored -at the root of the tensordict, while everything that belongs to ``t+1`` is stored -in the ``"next"`` entry of the tensordict. 
Here's an example: - - >>> data = env.reset() - >>> data = policy(data) - >>> print(env.step(data)) - TensorDict( - fields={ - action: Tensor(...), # The action taken at time t - done: Tensor(...), # The done state when the action was taken (at reset) - next: TensorDict( # all of this content comes from the call to `step` - fields={ - done: Tensor(...), # The done state after the action has been taken - observation: Tensor(...), # The observation resulting from the action - reward: Tensor(...), # The reward resulting from the action - terminated: Tensor(...), # The terminated state after the action has been taken - truncated: Tensor(...), # The truncated state after the action has been taken - batch_size=torch.Size([]), - device=cpu, - is_shared=False), - observation: Tensor(...), # the observation at reset - terminated: Tensor(...), # the terminated at reset - truncated: Tensor(...), # the truncated at reset - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - -During a rollout (either using :class:`~torchrl.envs.EnvBase` or -:class:`~torchrl.collectors.SyncDataCollector`), the content of the ``"next"`` -tensordict is brought to the root through the :func:`~torchrl.envs.utils.step_mdp` -function when the agent resets its step count: ``t <- t+1``. You can read more -about the environment API :ref:`here `. - -In most cases, there is no `True`-valued ``"done"`` state at the root since any -done state will trigger a (partial) reset which will turn the ``"done"`` to ``False``. -However, this is only true as long as resets are automatically performed. In some -cases, partial resets will not trigger a reset, so we retain these data, which -should have a considerably lower memory footprint than observations, for instance. - -This format eliminates any ambiguity regarding the matching of an observation with -its action, info, or done state. - -A note on singleton dimensions in TED -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. _reward_done_singleton: - -In TorchRL, the standard practice is that `done` states (including terminated and truncated) and rewards should have a -dimension that can be expanded to match the shape of observations, states, and actions without recurring to anything -else than repetition (i.e., the reward must have as many dimensions as the observation and/or action, or their -embeddings). - -Essentially, this format is acceptable (though not strictly enforced): - - >>> print(rollout[t]) - ... TensorDict( - ... fields={ - ... action: Tensor(n_action), - ... done: Tensor(1), # The done state has a rightmost singleton dimension - ... next: TensorDict( - ... fields={ - ... done: Tensor(1), - ... observation: Tensor(n_obs), - ... reward: Tensor(1), # The reward has a rightmost singleton dimension - ... terminated: Tensor(1), - ... truncated: Tensor(1), - ... batch_size=torch.Size([]), - ... device=cpu, - ... is_shared=False), - ... observation: Tensor(n_obs), # the observation at reset - ... terminated: Tensor(1), # the terminated at reset - ... truncated: Tensor(1), # the truncated at reset - ... batch_size=torch.Size([]), - ... device=cpu, - ... is_shared=False) - -The rationale behind this is to ensure that the results of operations (such as value estimation) on observations and/or -actions have the same number of dimensions as the reward and `done` state. 
This consistency allows subsequent operations -to proceed without issues: - - >>> state_value = f(observation) - >>> next_state_value = state_value + reward - -Without this singleton dimension at the end of the reward, broadcasting rules (which only work when tensors can be -expanded from the left) would try to expand the reward on the left. This could lead to failures (at best) or introduce -bugs (at worst). - -Flattening TED to reduce memory consumption -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TED copies the observations twice in the memory, which can impact the feasibility of using this format -in practice. Since it is being used mostly for ease of representation, one can store the data -in a flat manner but represent it as TED during training. - -This is particularly useful when serializing replay buffers: -For instance, the :class:`~torchrl.data.TED2Flat` class ensures that a TED-formatted data -structure is flattened before being written to disk, whereas the :class:`~torchrl.data.Flat2TED` -load hook will unflatten this structure during deserialization. - - -Dimensionality of the Tensordict -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -During a rollout, all collected tensordicts will be stacked along a new dimension -positioned at the end. Both collectors and environments will label this dimension -with the ``"time"`` name. Here's an example: - - >>> rollout = env.rollout(10, policy) - >>> assert rollout.shape[-1] == 10 - >>> assert rollout.names[-1] == "time" - -This ensures that the time dimension is clearly marked and easily identifiable -in the data structure. - -Special cases and footnotes -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Multi-Agent data presentation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The multi-agent data formatting documentation can be accessed in the :ref:`MARL environment API ` section. - -Memory-based policies (RNNs and Transformers) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the examples provided above, only ``env.step(data)`` generates data that -needs to be read in the next step. However, in some cases, the policy also -outputs information that will be required in the next step. This is typically -the case for RNN-based policies, which output an action as well as a recurrent -state that needs to be used in the next step. -To accommodate this, we recommend users to adjust their RNN policy to write this -data under the ``"next"`` entry of the tensordict. This ensures that this content -will be brought to the root in the next step. More information can be found in -:class:`~torchrl.modules.GRUModule` and :class:`~torchrl.modules.LSTMModule`. - -Multi-step -^^^^^^^^^^ - -Collectors allow users to skip steps when reading the data, accumulating reward -for the upcoming n steps. This technique is popular in DQN-like algorithms like Rainbow. -The :class:`~torchrl.data.postprocs.MultiStep` class performs this data transformation -on batches coming out of collectors. In these cases, a check like the following -will fail since the next observation is shifted by n steps: - - >>> assert (data[..., 1:]["observation"] == data[..., :-1]["next", "observation"]).all() - -What about memory requirements? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Implemented naively, this data format consumes approximately twice the memory -that a flat representation would. In some memory-intensive settings -(for example, in the :class:`~torchrl.data.datasets.AtariDQNExperienceReplay` dataset), -we store only the ``T+1`` observation on disk and perform the formatting online at get time. 
-In other cases, we assume that the 2x memory cost is a small price to pay for a -clearer representation. However, generalizing the lazy representation for offline -datasets would certainly be a beneficial feature to have, and we welcome -contributions in this direction! - -Datasets --------- - -TorchRL provides wrappers around offline RL datasets. -These data are presented as :class:`~torchrl.data.ReplayBuffer` instances, which -means that they can be customized at will with transforms, samplers and storages. -For instance, entries can be filtered in or out of a dataset with :class:`~torchrl.envs.SelectTransform` -or :class:`~torchrl.envs.ExcludeTransform`. - -By default, datasets are stored as memory mapped tensors, allowing them to be -promptly sampled with virtually no memory footprint. - -Here's an example: - -.. code::Python - - >>> from torchrl.data.datasets import D4RLExperienceReplay - >>> from torchrl.data.replay_buffers import SamplerWithoutReplacement - >>> from torchrl.envs.transforms import RenameTransform - >>> dataset = D4RLExperienceReplay('kitchen-complete-v0', split_trajs=True, batch_size=10) - >>> print(dataset.sample()) # will sample 10 trajectories since split_trajs is set to True - TensorDict( - fields={ - action: Tensor(shape=torch.Size([10, 207, 9]), device=cpu, dtype=torch.float32, is_shared=False), - done: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - index: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int32, is_shared=False), - infos: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False), - mask: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10, 207]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False), - timeouts: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - traj_ids: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([10, 207]), - device=cpu, - is_shared=False) - >>> dataset.append_transform(RenameTransform(["done", ("next", "done")], ["terminal", ("next", "terminal")])) - >>> print(dataset.sample()) # The "done" has been renamed to "terminal" - TensorDict( - fields={ - action: Tensor(shape=torch.Size([10, 207, 9]), device=cpu, dtype=torch.float32, is_shared=False), - terminal: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - index: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int32, is_shared=False), - infos: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False), - mask: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - terminal: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10, 207]), - device=cpu, - 
is_shared=False), - observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False), - timeouts: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - traj_ids: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([10, 207]), - device=cpu, - is_shared=False) - >>> # we can also use a `SamplerWithoutReplacement` to iterate over the dataset with random samples: - >>> dataset = D4RLExperienceReplay( - ... 'kitchen-complete-v0', - ... sampler=SamplerWithoutReplacement(drop_last=True), - ... split_trajs=True, - ... batch_size=3) - >>> for data in dataset: - ... print(data) - ... - -.. note:: - - Installing dependencies is the responsibility of the user. For D4RL, a clone of - `the repository `_ is needed as - the latest wheels are not published on PyPI. For OpenML, `scikit-learn `_ and - `pandas `_ are required. - -Transforming datasets -~~~~~~~~~~~~~~~~~~~~~ - -In many instances, the raw data isn't going to be used as-is. -The natural solution could be to pass a :class:`~torchrl.envs.transforms.Transform` -instance to the dataset constructor and modify the sample on-the-fly. This will -work but it will incur an extra runtime for the transform. -If the transformations can be (at least a part) pre-applied to the dataset, -a conisderable disk space and some incurred overhead at sampling time can be -saved. To do this, the -:meth:`~torchrl.data.datasets.BaseDatasetExperienceReplay.preprocess` can be -used. This method will run a per-sample preprocessing pipeline on each element -of the dataset, and replace the existing dataset by its transformed version. - -Once transformed, re-creating the same dataset will produce another object with -the same transformed storage (unless ``download="force"`` is being used): - - >>> dataset = RobosetExperienceReplay( - ... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32, download="force" - ... ) - >>> - >>> def func(data): - ... return data.set("obs_norm", data.get("observation").norm(dim=-1)) - ... - >>> dataset.preprocess( - ... func, - ... num_workers=max(1, os.cpu_count() - 2), - ... num_chunks=1000, - ... mp_start_method="fork", - ... ) - >>> sample = dataset.sample() - >>> assert "obs_norm" in sample.keys() - >>> # re-recreating the dataset gives us the transformed version back. - >>> dataset = RobosetExperienceReplay( - ... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32 - ... ) - >>> sample = dataset.sample() - >>> assert "obs_norm" in sample.keys() - - -.. currentmodule:: torchrl.data.datasets - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - BaseDatasetExperienceReplay - AtariDQNExperienceReplay - D4RLExperienceReplay - GenDGRLExperienceReplay - MinariExperienceReplay - OpenMLExperienceReplay - OpenXExperienceReplay - RobosetExperienceReplay - VD4RLExperienceReplay - -Composing datasets -~~~~~~~~~~~~~~~~~~ - -In offline RL, it is customary to work with more than one dataset at the same time. -Moreover, TorchRL usually has a fine-grained dataset nomenclature, where -each task is represented separately when other libraries will represent these -datasets in a more compact way. To allow users to compose multiple datasets -together, we propose a :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble` -primitive that allows users to sample from multiple datasets at once. - -If the individual dataset formats differ, :class:`~torchrl.envs.Transform` instances -can be used. 
In the following example, we create two dummy datasets with semantically -identical entries that differ in names (``("some", "key")`` and ``"another_key"``) -and show how they can be renamed to have a matching name. We also resize images -such that they can be stacked together during sampling. - - >>> from torchrl.envs import Comopse, ToTensorImage, Resize, RenameTransform - >>> from torchrl.data import TensorDictReplayBuffer, ReplayBufferEnsemble, LazyMemmapStorage - >>> from tensordict import TensorDict - >>> import torch - >>> rb0 = TensorDictReplayBuffer( - ... storage=LazyMemmapStorage(10), - ... transform=Compose( - ... ToTensorImage(in_keys=["pixels", ("next", "pixels")]), - ... Resize(32, in_keys=["pixels", ("next", "pixels")]), - ... RenameTransform([("some", "key")], ["renamed"]), - ... ), - ... ) - >>> rb1 = TensorDictReplayBuffer( - ... storage=LazyMemmapStorage(10), - ... transform=Compose( - ... ToTensorImage(in_keys=["pixels", ("next", "pixels")]), - ... Resize(32, in_keys=["pixels", ("next", "pixels")]), - ... RenameTransform(["another_key"], ["renamed"]), - ... ), - ... ) - >>> rb = ReplayBufferEnsemble( - ... rb0, - ... rb1, - ... p=[0.5, 0.5], - ... transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]), - ... ) - >>> data0 = TensorDict( - ... { - ... "pixels": torch.randint(255, (10, 244, 244, 3)), - ... ("next", "pixels"): torch.randint(255, (10, 244, 244, 3)), - ... ("some", "key"): torch.randn(10), - ... }, - ... batch_size=[10], - ... ) - >>> data1 = TensorDict( - ... { - ... "pixels": torch.randint(255, (10, 64, 64, 3)), - ... ("next", "pixels"): torch.randint(255, (10, 64, 64, 3)), - ... "another_key": torch.randn(10), - ... }, - ... batch_size=[10], - ... ) - >>> rb[0].extend(data0) - >>> rb[1].extend(data1) - >>> for _ in range(2): - ... sample = rb.sample(10) - ... assert sample["next", "pixels"].shape == torch.Size([2, 5, 3, 32, 32]) - ... assert sample["pixels"].shape == torch.Size([2, 5, 3, 32, 32]) - ... assert sample["pixels33"].shape == torch.Size([2, 5, 3, 33, 33]) - ... assert sample["renamed"].shape == torch.Size([2, 5]) - -.. currentmodule:: torchrl.data.replay_buffers - - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - ReplayBufferEnsemble - SamplerEnsemble - StorageEnsemble - WriterEnsemble - -TensorSpec ----------- - -.. _ref_specs: - -The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of state, observations -actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain. - -It is important that your environment specs match the input and output that it sends and receives, as -:class:`~torchrl.envs.ParallelEnv` will create buffers from these specs to communicate with the spawn processes. -Check the :func:`torchrl.envs.utils.check_env_specs` method for a sanity check. - -If needed, specs can be automatically generated from data using the :func:`~torchrl.envs.utils.make_composite_from_td` -function. - -Specs fall in two main categories, numerical and categorical. - -.. table:: Numerical TensorSpec subclasses. 
- - +-------------------------------------------------------------------------------+ - | Numerical | - +=====================================+=========================================+ - | Bounded | Unbounded | - +-----------------+-------------------+-------------------+---------------------+ - | BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous | - +-----------------+-------------------+-------------------+---------------------+ - -Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or -explicitly by the `"domain"` keyword argument) will determine if the instantiated class will be of :class:`~torchrl.data.BoundedContinuous` -or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the :class:`~torchrl.data.Unbounded` class. -See these classes for further information. - -.. table:: Categorical TensorSpec subclasses. - - +------------------------------------------------------------------+ - | Categorical | - +========+=============+=============+==================+==========+ - | OneHot | MultiOneHot | Categorical | MultiCategorical | Binary | - +--------+-------------+-------------+------------------+----------+ - -Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be -combined together, TorchRL assumes that the data will be presented as dictionaries (more specifically, as -:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these -cases is the :class:`~torchrl.data.Composite` spec. - -Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be -expanded accordingly. -Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class. - -Similarly, ``TensorSpecs`` possess some common behavior with :class:`~torch.Tensor` and -:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``) -or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be. - -Specs where some dimensions are ``-1`` are said to be "dynamic" and the negative dimensions indicate that the corresponding -data has an inconsistent shape. When seen by an optimizer or an environment (e.g., batched environment such as -:class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are -not predictable. - -.. currentmodule:: torchrl.data - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - TensorSpec - Binary - Bounded - Categorical - Composite - MultiCategorical - MultiOneHot - NonTensor - OneHot - Stacked - StackedComposite - Unbounded - UnboundedContinuous - UnboundedDiscrete - -The following classes are deprecated and just point to the classes above: - -.. currentmodule:: torchrl.data - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - BinaryDiscreteTensorSpec - BoundedTensorSpec - CompositeSpec - DiscreteTensorSpec - LazyStackedCompositeSpec - LazyStackedTensorSpec - MultiDiscreteTensorSpec - MultiOneHotDiscreteTensorSpec - NonTensorSpec - OneHotDiscreteTensorSpec - UnboundedContinuousTensorSpec - UnboundedDiscreteTensorSpec - -Trees and Forests ------------------ - -TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently, -which is particularly useful for Monte Carlo Tree Search (MCTS) algorithms. 
- -TensorDictMap -~~~~~~~~~~~~~ - -At its core, the MCTS API relies on the :class:`~torchrl.data.TensorDictMap` which acts like a storage where indices can -be any numerical object. In traditional storages (e.g., :class:`~torchrl.data.TensorStorage`), only integer indices -are allowed: - - >>> storage = TensorStorage(...) - >>> data = storage[3] - -:class:`~torchrl.data.TensorDictMap` allows us to make more advanced queries in the storage. The typical example is -when we have a storage containing a set of MDPs and we want to rebuild a trajectory given its initial observation, action -pair. In tensor terms, this could be written with the following pseudocode: - - >>> next_state = storage[observation, action] - -(if there is more than one next state associated with this pair one could return a stack of ``next_states`` instead). -This API would make sense but it would be restrictive: allowing observations or actions that are composed of -multiple tensors may be hard to implement. Instead, we provide a tensordict containing these values and let the storage -know what ``in_keys`` to look at to query the next state: - - >>> td = TensorDict(observation=observation, action=action) - >>> next_td = storage[td] - -Of course, this class also allows us to extend the storage with new data: - - >>> storage[td] = next_state - -This comes in handy because it allows us to represent complex rollout structures where different actions are undertaken -at a given node (ie, for a given observation). All `(observation, action)` pairs that have been observed may lead us to -a (set of) rollout that we can use further. - -MCTSForest -~~~~~~~~~~ - -Building a tree from an initial observation then becomes just a matter of organizing data efficiently. -The :class:`~torchrl.data.MCTSForest` has at its core two storages: a first storage links observations to hashes and -indices of actions encountered in the past in the dataset: - - >>> data = TensorDict(observation=observation) - >>> metadata = forest.node_map[data] - >>> index = metadata["_index"] - -where ``forest`` is a :class:`~torchrl.data.MCTSForest` instance. -Then, a second storage keeps track of the actions and results associated with the observation: - - >>> next_data = forest.data_map[index] - -The ``next_data`` entry can have any shape, but it will usually match the shape of ``index`` (since at each index -corresponds one action). Once ``next_data`` is obtained, it can be put together with ``data`` to form a set of nodes, -and the tree can be expanded for each of these. The following figure shows how this is done. - -.. figure:: /_static/img/collector-copy.png - - Building a :class:`~torchrl.data.Tree` from a :class:`~torchrl.data.MCTSForest` object. - The flowchart represents a tree being built from an initial observation `o`. The :class:`~torchrl.data.MCTSForest.get_tree` - method passed the input data structure (the root node) to the ``node_map`` :class:`~torchrl.data.TensorDictMap` instance - that returns a set of hashes and indices. These indices are then used to query the corresponding tuples of - actions, next observations, rewards etc. that are associated with the root node. - A vertex is created from each of them (possibly with a longer rollout when a compact representation is asked). - The stack of vertices is then used to build up the tree further, and these vertices are stacked together and make - up the branches of the tree at the root. This process is repeated for a given depth or until the tree cannot be - expanded anymore. - -.. 
currentmodule:: torchrl.data
-
-.. autosummary::
-    :toctree: generated/
-    :template: rl_template.rst
-
-    BinaryToDecimal
-    HashToInt
-    MCTSForest
-    QueryModule
-    RandomProjectionHash
-    SipHash
-    TensorDictMap
-    TensorMap
-    Tree
-
-
-Large language models and Reinforcement Learning From Human Feedback (RLHF)
----------------------------------------------------------------------------
-
-.. warning::
-    These APIs are deprecated and will be removed in the future.
-    Use the :mod:`torchrl.data.llm` module instead.
-    See the full :ref:`LLM documentation ` for more information.
-
-Data is of utmost importance in LLM post-training (e.g., GRPO or Reinforcement Learning from Human Feedback (RLHF)).
-Given that these techniques are commonly employed in the realm of language,
-which is scarcely addressed in other subdomains of RL within the library,
-we offer specific utilities to facilitate interaction with external libraries
-like datasets. These utilities consist of tools for tokenizing data, formatting
-it in a manner suitable for TorchRL modules, and optimizing storage for
-efficient sampling.
-
-.. currentmodule:: torchrl.data
-
-.. autosummary::
-    :toctree: generated/
-    :template: rl_template.rst
-
-    PairwiseDataset
-    PromptData
-    PromptTensorDictTokenizer
-    RewardData
-    RolloutFromModel
-    TensorDictTokenizer
-    TokenizedDatasetLoader
-    create_infinite_iterator
-    get_dataloader
-    ConstantKLController
-    AdaptiveKLController
-
-
-Utils
------
-
-.. currentmodule:: torchrl.data
-
-.. autosummary::
-    :toctree: generated/
-    :template: rl_template.rst
-
-    DensifyReward
-    Flat2TED
-    H5Combine
-    H5Split
-    MultiStep
-    Nested2TED
-    TED2Flat
-    TED2Nested
-    check_no_exclusive_keys
-    consolidate_spec
-    contains_lazy_spec
-
-.. currentmodule:: torchrl.envs.transforms.rb_transforms
-
-.. autosummary::
-    :toctree: generated/
-    :template: rl_template.rst
-
-    MultiStepTransform
+TorchRL provides a comprehensive data management system built around replay buffers, which are central to
+off-policy RL algorithms. The library offers efficient implementations of various replay buffers with
+composable components for storage, sampling, and data transformation.
+
+Key Features
+------------
+
+- **Flexible storage backends**: Memory, memmap, and compressed storage options
+- **Advanced sampling strategies**: Prioritized, slice-based, and custom samplers
+- **Composable design**: Mix and match storage, samplers, and writers
+- **Type flexibility**: Support for tensors, tensordicts, and arbitrary data types
+- **Efficient transforms**: Apply preprocessing during sampling
+- **Distributed support**: Ray-based and remote replay buffers
+
+Quick Example
+-------------
+
+.. code-block:: python
+
+    import torch
+
+    from torchrl.data import ReplayBuffer, LazyMemmapStorage, PrioritizedSampler
+    from tensordict import TensorDict
+
+    # Create a replay buffer with memmap storage and prioritized sampling
+    buffer = ReplayBuffer(
+        storage=LazyMemmapStorage(max_size=1000000),
+        sampler=PrioritizedSampler(max_capacity=1000000, alpha=0.7, beta=0.5),
+        batch_size=256,
+    )
+
+    # Add data
+    data = TensorDict({
+        "observation": torch.randn(32, 4),
+        "action": torch.randn(32, 2),
+        "reward": torch.randn(32, 1),
+    }, batch_size=[32])
+    buffer.extend(data)
+
+    # Sample
+    sample = buffer.sample()  # Returns a TensorDict with batch_size=256
+
+Documentation Sections
+----------------------
+
+..
toctree:: + :maxdepth: 2 + + data_replaybuffers + data_storage + data_samplers + data_datasets + data_specs diff --git a/docs/source/reference/data_datasets.rst b/docs/source/reference/data_datasets.rst new file mode 100644 index 00000000000..5946c0ac4df --- /dev/null +++ b/docs/source/reference/data_datasets.rst @@ -0,0 +1,19 @@ +.. currentmodule:: torchrl.data + +Datasets +======== + +TorchRL provides dataset utilities for offline RL and data management. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + datasets.AtariDQNExperienceReplay + datasets.D4RLExperienceReplay + datasets.Gen_DGRLExperienceReplay + datasets.MinariExperienceReplay + datasets.OpenMLExperienceReplay + datasets.OpenXExperienceReplay + datasets.RobosetExperienceReplay + datasets.VD4RLExperienceReplay diff --git a/docs/source/reference/data_replaybuffers.rst b/docs/source/reference/data_replaybuffers.rst new file mode 100644 index 00000000000..2fec1b2a94a --- /dev/null +++ b/docs/source/reference/data_replaybuffers.rst @@ -0,0 +1,51 @@ +.. currentmodule:: torchrl.data + +Replay Buffers +============== + +Replay buffers are a central part of off-policy RL algorithms. TorchRL provides an efficient implementation of a few, +widely used replay buffers: + +Core Replay Buffer Classes +-------------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + ReplayBuffer + PrioritizedReplayBuffer + TensorDictReplayBuffer + TensorDictPrioritizedReplayBuffer + RayReplayBuffer + RemoteTensorDictReplayBuffer + +Composable Replay Buffers +------------------------- + +.. _ref_buffers: + +We also give users the ability to compose a replay buffer. +We provide a wide panel of solutions for replay buffer usage, including support for +almost any data type; storage in memory, on device or on physical memory; +several sampling strategies; usage of transforms etc. + +Supported data types and choosing a storage +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In theory, replay buffers support any data type but we can't guarantee that each +component will support any data type. The most crude replay buffer implementation +is made of a :class:`~torchrl.data.replay_buffers.ReplayBuffer` base with a +:class:`~torchrl.data.replay_buffers.ListStorage` storage. This is very inefficient +but it will allow you to store complex data structures with non-tensor data. +Storages in contiguous memory include :class:`~torchrl.data.replay_buffers.TensorStorage`, +:class:`~torchrl.data.replay_buffers.LazyTensorStorage` and +:class:`~torchrl.data.replay_buffers.LazyMemmapStorage`. + +Sampling and indexing +~~~~~~~~~~~~~~~~~~~~~ + +Replay buffers can be indexed and sampled. +Indexing and sampling collect data at given indices in the storage and then process them +through a series of transforms and ``collate_fn`` that can be passed to the `__init__` +function of the replay buffer. diff --git a/docs/source/reference/data_samplers.rst b/docs/source/reference/data_samplers.rst new file mode 100644 index 00000000000..0f8859408a7 --- /dev/null +++ b/docs/source/reference/data_samplers.rst @@ -0,0 +1,32 @@ +.. currentmodule:: torchrl.data.replay_buffers + +Sampling Strategies +=================== + +Samplers control how data is retrieved from the replay buffer storage. + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template.rst + + PrioritizedSampler + PrioritizedSliceSampler + RandomSampler + Sampler + SamplerWithoutReplacement + SliceSampler + SliceSamplerWithoutReplacement + +Writers +------- + +Writers control how data is written to the storage. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + RoundRobinWriter + TensorDictMaxValueWriter + TensorDictRoundRobinWriter + Writer diff --git a/docs/source/reference/data_specs.rst b/docs/source/reference/data_specs.rst new file mode 100644 index 00000000000..f1ca8a38c1f --- /dev/null +++ b/docs/source/reference/data_specs.rst @@ -0,0 +1,26 @@ +.. currentmodule:: torchrl.data + +TensorSpec System +================= + +TensorSpec classes define the shape, dtype, and domain of tensors in TorchRL. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Binary + Bounded + Categorical + Composite + DiscreteTensorSpec + LazyStackedCompositeSpec + MultiCategorical + MultiDiscreteTensorSpec + MultiOneHot + NonTensor + OneHot + TensorSpec + Unbounded + UnboundedContinuous + UnboundedDiscrete diff --git a/docs/source/reference/data_storage.rst b/docs/source/reference/data_storage.rst new file mode 100644 index 00000000000..4f691c1ca05 --- /dev/null +++ b/docs/source/reference/data_storage.rst @@ -0,0 +1,37 @@ +.. currentmodule:: torchrl.data.replay_buffers + +Storage Backends +================ + +TorchRL provides various storage backends for replay buffers, each optimized for different use cases. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + CompressedListStorage + CompressedListStorageCheckpointer + FlatStorageCheckpointer + H5StorageCheckpointer + ImmutableDatasetWriter + LazyMemmapStorage + LazyTensorStorage + ListStorage + LazyStackStorage + ListStorageCheckpointer + NestedStorageCheckpointer + Storage + StorageCheckpointerBase + StorageEnsembleCheckpointer + TensorStorage + TensorStorageCheckpointer + +Storage Performance +------------------- + +Storage choice is very influential on replay buffer sampling latency, especially +in distributed reinforcement learning settings with larger data volumes. +:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` is highly +advised in distributed settings with shared storage due to the lower serialization +cost of MemoryMappedTensors as well as the ability to specify file storage locations +for improved node failure recovery. diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 3e807cc2f93..a5cd92e1a9a 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -3,1442 +3,60 @@ torchrl.envs package ==================== -.. _Environment-API: +.. _ref_envs: -TorchRL offers an API to handle environments of different backends, such as gym, -dm-control, dm-lab, model-based environments as well as custom environments. -The goal is to be able to swap environments in an experiment with little or no effort, -even if these environments are simulated using different libraries. -TorchRL offers some out-of-the-box environment wrappers under :obj:`torchrl.envs.libs`, -which we hope can be easily imitated for other libraries. -The parent class :class:`~torchrl.envs.EnvBase` is a :class:`torch.nn.Module` subclass that implements -some typical environment methods using :class:`tensordict.TensorDict` as a data organiser. 
This allows this -class to be generic and to handle an arbitrary number of input and outputs, as well as -nested or batched data structures. +TorchRL offers a comprehensive API to handle environments of different backends, making it easy to swap +environments in an experiment with minimal effort. The library provides wrappers for popular RL frameworks +including Gym, DMControl, Brax, Jumanji, and many others. -Each env will have the following attributes: +The :class:`~torchrl.envs.EnvBase` class serves as the foundation, providing a unified interface that uses +:class:`tensordict.TensorDict` for data organization. This design allows the framework to be generic and +handle an arbitrary number of inputs and outputs, as well as nested or batched data structures. -- :obj:`env.batch_size`: a :obj:`torch.Size` representing the number of envs - batched together. -- :obj:`env.device`: the device where the input and output tensordict are expected to live. - The environment device does not mean that the actual step operations will be computed on device - (this is the responsibility of the backend, with which TorchRL can do little). The device of - an environment just represents the device where the data is to be expected when input to the - environment or retrieved from it. TorchRL takes care of mapping the data to the desired device. - This is especially useful for transforms (see below). For parametric environments (e.g. - model-based environments), the device does represent the hardware that will be used to - compute the operations. -- :obj:`env.observation_spec`: a :class:`~torchrl.data.Composite` object - containing all the observation key-spec pairs. -- :obj:`env.state_spec`: a :class:`~torchrl.data.Composite` object - containing all the input key-spec pairs (except action). For most stateful - environments, this container will be empty. -- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object - representing the action spec. -- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing - the reward spec. -- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing - the done-flag spec. See the section on trajectory termination below. -- :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing - all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`). -- :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing - all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`). +Key Features +------------ -If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensor` -instance can be used. +- **Unified API**: Consistent interface across different environment backends +- **Vectorization**: Built-in support for parallel and batched environments +- **Transforms**: Powerful transform system for preprocessing observations and actions +- **Multi-agent**: Native support for multi-agent RL with no additional infrastructure +- **Flexible backends**: Easy integration with Gym, DMControl, Brax, and custom environments -Env specs: locks and batch size -------------------------------- - -.. _Environment-lock: - -Environment specs are locked by default (through a ``spec_locked`` arg passed to the env constructor). -Locking specs means that any modification of the spec (or its children if it is a :class:`~torchrl.data.Composite` -instance) will require to unlock it. This can be done via the :meth:`~torchrl.envs.EnvBase.set_spec_lock_`. 
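
A minimal sketch of toggling the lock is given below (this assumes that ``set_spec_lock_`` accepts a boolean
mode argument; direct spec assignment, described next, is also available)::

    >>> from torchrl.envs import GymEnv
    >>> env = GymEnv("Pendulum-v1")
    >>> env.set_spec_lock_(False)            # unlock before editing the specs in place
    >>> env.observation_spec["observation"]  # children can now be inspected and modified (key name depends on the env)
    >>> env.set_spec_lock_(True)             # relock once the edits are done
    >>> # Direct assignment handles the unlocking and relocking for you:
    >>> env.observation_spec = env.observation_spec.clone()
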
-The reason specs are locked by default is that it makes it easy to cache values such as action or reset keys and the
-like.
-Unlocking an env should only be done if it is expected that the specs will be modified often (which, in principle,
-should be avoided).
-Modifications of the specs such as `env.observation_spec = new_spec` are allowed: under the hood, TorchRL will erase
-the cache, unlock the specs, make the modification and relock the specs if the env was previously locked.
-
-Importantly, the environment spec shapes should contain the batch size, e.g.
-an environment with :obj:`env.batch_size == torch.Size([4])` should have
-an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])`.
-This is helpful when preallocating tensors, checking shape consistency, etc.
-
-Env methods
------------
-
-With these, the following methods are implemented:
-
-- :meth:`env.reset`: a reset method that may (but does not have to) take
-  a :class:`tensordict.TensorDict` input. It returns the first tensordict of a rollout, usually
-  containing a :obj:`"done"` state and a set of observations. If not present,
-  a `"reward"` key will be instantiated with 0s and the appropriate shape.
-- :meth:`env.step`: a step method that takes a :class:`tensordict.TensorDict` input
-  containing an input action as well as other inputs (for model-based or stateless
-  environments, for instance).
-- :meth:`env.step_and_maybe_reset`: executes a step, and (partially) resets the
-  environments if needed. It returns the updated input with a ``"next"``
-  key containing the data of the next step, as well as a tensordict containing
-  the input data for the next step (i.e., the result of a reset or of
-  :func:`~torchrl.envs.utils.step_mdp`).
-  This is done by reading the ``done_keys`` and
-  assigning a ``"_reset"`` signal to each done state. This method makes it easy
-  to code non-stopping rollout functions:
-
-    >>> data_ = env.reset()
-    >>> result = []
-    >>> for i in range(N):
-    ...     data, data_ = env.step_and_maybe_reset(data_)
-    ...     result.append(data)
-    ...
-    >>> result = torch.stack(result)
-
-- :meth:`env.set_seed`: a seeding method that will return the next seed
-  to be used in a multi-env setting. This next seed is deterministically computed
-  from the preceding one, such that one can seed multiple environments with a different
-  seed without risking overlapping seeds in consecutive experiments, while still
-  having reproducible results.
-- :meth:`env.rollout`: executes a rollout in the environment for
-  a maximum number of steps (``max_steps=N``), using a policy (``policy=model``).
-  The policy should be coded using a :class:`tensordict.nn.TensorDictModule`
-  (or any other :class:`tensordict.TensorDict`-compatible module).
-  The resulting :class:`tensordict.TensorDict` instance will be marked with
-  a trailing ``"time"`` named dimension that can be used by other modules
-  to treat this batched dimension as it should be treated.
-
-The following figure summarizes how a rollout is executed in TorchRL.
-
-.. figure:: /_static/img/rollout.gif
-
-    TorchRL rollouts using TensorDict.
-
-In brief, a TensorDict is created by the :meth:`~.EnvBase.reset` method,
-then populated with an action by the policy before being passed to the
-:meth:`~.EnvBase.step` method which writes the observations, done flag(s) and
-reward under the ``"next"`` entry. The result of this call is stored for
-delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp`
-function.
-
-..
note:: - In general, all TorchRL environment have a ``"done"`` and ``"terminated"`` - entry in their output tensordict. If they are not present by design, - the :class:`~.EnvBase` metaclass will ensure that every done or terminated - is flanked with its dual. - In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory - signals and should be interpreted as "the last step of a trajectory" or - equivalently "a signal indicating the need to reset". - If the environment provides it (eg, Gymnasium), the truncation entry is also - written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry. - If the environment carries a single value, it will interpreted as a ``"terminated"`` - signal by default. - By default, TorchRL's collectors and rollout methods will be looking for the ``"done"`` - entry to assess if the environment should be reset. - -.. note:: - - The `torchrl.collectors.utils.split_trajectories` function can be used to - slice adjacent trajectories. It relies on a ``"traj_ids"`` entry in the - input tensordict, or to the junction of ``"done"`` and ``"truncated"`` key - if the ``"traj_ids"`` is missing. - - -.. note:: - - In some contexts, it can be useful to mark the first step of a trajectory. - TorchRL provides such functionality through the :class:`~torchrl.envs.InitTracker` - transform. - - -Our environment :ref:`tutorial ` -provides more information on how to design a custom environment from scratch. - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - EnvBase - GymLikeEnv - EnvMetaData - -Partial steps and partial resets --------------------------------- - -TorchRL allows environments to reset some but not all the environments, or run a step in one but not all environments. -If there is only one environment in the batch, then a partial reset / step is also allowed with the behavior detailed -below. - -Batching environments and locking the batch -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. _ref_batch_locked: - -Before detailing what partial resets and partial steps do, we must distinguish cases where an environment has -a batch size of its own (mostly stateful environments) or when the environment is just a mere module that, given an -input of arbitrary size, batches the operations over all elements (mostly stateless environments). - -This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input -tensordicts to have the same batch-size as the env's. Typical examples of these environments are -:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size. -Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`. - -Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the -tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous -input. - -Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with -partial steps easily, they just pass the actions to the sub-environments that are required to be executed. - -In all other cases, TorchRL assumes that the environment handles the partial steps correctly. - -.. warning:: This means that custom environments may silently run the non-required steps as there is no way for torchrl - to control what happens within the `_step` method! - -Partial Steps -~~~~~~~~~~~~~ - -.. 
_ref_partial_steps:
-
-Partial steps are controlled via the temporary key `"_step"`, which points to a boolean mask of the
-size of the tensordict that holds it. The classes equipped to deal with this are:
-
-- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the
-  action to, and only to, the environments where `"_step"` is `True`;
-- Batch-unlocked environments;
-- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step`
-  method will first look for a `"_step"` entry and, if present, act accordingly.
-  If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by
-  :class:`~torchrl.envs.TransformedEnv`'s own `_step` method which will skip the `base_env.step` as well as any further
-  transformation.
-
-When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous
-content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that
-if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for
-all the non-stepped elements. Within batched environments, data collectors and rollout utils, this issue is
-not observed because these classes handle the passing of data properly.
-
-Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`,
-as the environments with a `True` done state will need to be skipped during calls to `_step`.
-
-The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips.
-
-Partial Resets
-~~~~~~~~~~~~~~
-
-.. _ref_partial_resets:
-
-Partial resets work pretty much like partial steps, but with the `"_reset"` entry.
-
-The same restrictions as for partial steps apply to partial resets.
-
-Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `True`,
-as the environments with a `True` done state will need to be reset, but not the others.
-
-See the following paragraph for a deep dive into partial resets within batched and vectorized environments.
-
-Vectorized envs
----------------
-
-Vectorized (or, more accurately, parallel) environments are a common feature in Reinforcement Learning,
-where executing the environment step can be CPU-intensive.
-Some libraries such as `gym3 `_ or `EnvPool `_
-offer interfaces to execute batches of environments simultaneously.
-While they often offer a very competitive computational advantage, they do not
-necessarily scale to the wide variety of environment libraries supported by TorchRL.
-Therefore, TorchRL offers its own, generic :class:`ParallelEnv` class to run multiple
-environments in parallel.
-As this class inherits from :class:`SerialEnv`, it enjoys the exact same API as other environments.
-Of course, a :class:`ParallelEnv` will have a batch size that corresponds to its environment count:
-
-.. note::
-    Given the library's many optional dependencies (e.g., Gym, Gymnasium, and many others),
-    warnings can quickly become quite annoying in multiprocessed / distributed settings.
-    By default, TorchRL filters out these warnings in sub-processes. If one still wishes to
-    see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``.
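
As a quick illustration of how batched environments and the partial-step mechanism described in the previous
section interact, here is a minimal sketch (the environment choice and step budget are arbitrary)::

    >>> from torchrl.envs import GymEnv, SerialEnv
    >>> env = SerialEnv(4, lambda: GymEnv("CartPole-v1"))
    >>> # With break_when_all_done=True, sub-environments that reach a done state are skipped
    >>> # (through the "_step" mask) while the remaining ones keep stepping until all are done.
    >>> rollout = env.rollout(100, break_when_all_done=True, break_when_any_done=False)
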
- -It is important that your environment specs match the input and output that it sends and receives, as -:class:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes. -Check the :func:`~torchrl.envs.utils.check_env_specs` method for a sanity check. - -.. code-block:: - :caption: Parallel environment - - >>> def make_env(): - ... return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0") - >>> check_env_specs(env) # this must pass for ParallelEnv to work - >>> env = ParallelEnv(4, make_env) - >>> print(env.batch_size) - torch.Size([4]) - -:class:`ParallelEnv` allows to retrieve the attributes from its contained environments: -one can simply call: - -.. code-block:: - :caption: Parallel environment attributes - - >>> a, b, c, d = env.g # gets the g-force of the various envs, which we set to 9.81 before - >>> print(a) - 9.81 - -TorchRL uses a private ``"_reset"`` key to indicate to the environment which -component (sub-environments or agents) should be reset. -This allows to reset some but not all of the components. - -The ``"_reset"`` key has two distinct functionalities: - -1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may - not be present in the input tensordict. TorchRL's convention is that the - absence of the ``"_reset"`` key at a given ``"done"`` level indicates - a total reset of that level (unless a ``"_reset"`` key was found at a level - above, see details below). - If it is present, it is expected that those entries and only those components - where the ``"_reset"`` entry is ``True`` (along key and shape dimension) will be reset. - - The way an environment deals with the ``"_reset"`` keys in its :meth:`~.EnvBase._reset` - method is proper to its class. - Designing an environment that behaves according to ``"_reset"`` inputs is the - developer's responsibility, as TorchRL has no control over the inner logic - of :meth:`~.EnvBase._reset`. Nevertheless, the following point should be - kept in mind when designing that method. - -2. After a call to :meth:`~.EnvBase._reset`, the output will be masked with the - ``"_reset"`` entries and the output of the previous :meth:`~.EnvBase.step` - will be written wherever the ``"_reset"`` was ``False``. In practice, this - means that if a ``"_reset"`` modifies data that isn't exposed by it, this - modification will be lost. After this masking operation, the ``"_reset"`` - entries will be erased from the :meth:`~.EnvBase.reset` outputs. - -It must be pointed out that ``"_reset"`` is a private key, and it should only be -used when coding specific environment features that are internal facing. -In other words, this should NOT be used outside of the library, and developers -will keep the right to modify the logic of partial resets through ``"_reset"`` -setting without preliminary warranty, as long as they don't affect TorchRL -internal tests. - -Finally, the following assumptions are made and should be kept in mind when -designing reset functionalities: - -- Each ``"_reset"`` is paired with a ``"done"`` entry (+ ``"terminated"`` and, - possibly, ``"truncated"``). This means that the following structure is not - allowed: ``TensorDict({"done": done, "nested": {"_reset": reset}}, [])``, as - the ``"_reset"`` lives at a different nesting level than the ``"done"``. -- A reset at one level does not preclude the presence of a ``"_reset"`` at lower - levels, but it annihilates its effects. 
The reason is simply that - whether the ``"_reset"`` at the root level corresponds to an ``all()``, ``any()`` - or custom call to the nested ``"done"`` entries cannot be known in advance, - and it is explicitly assumed that the ``"_reset"`` at the root was placed - there to supersede the nested values (for an example, have a look at - :class:`~.PettingZooWrapper` implementation where each group has one or more - ``"done"`` entries associated which is aggregated at the root level with a - ``any`` or ``all`` logic depending on the task). -- When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry - that will reset some but not all the done sub-environments, the input data - should contain the data of the sub-environments that are __not__ being reset. - The reason for this constrain lies in the fact that the output of the - ``env._reset(data)`` can only be predicted for the entries that are reset. - For the others, TorchRL cannot know in advance if they will be meaningful or - not. For instance, one could perfectly just pad the values of the non-reset - components, in which case the non-reset data will be meaningless and should - be discarded. - -Below, we give some examples of the expected effect that ``"_reset"`` keys will -have on an environment returning zeros after reset: - - >>> # single reset at the root - >>> data = TensorDict({"val": [1, 1], "_reset": [False, True]}, []) - >>> env.reset(data) - >>> print(data.get("val")) # only the second value is 0 - tensor([1, 0]) - >>> # nested resets - >>> data = TensorDict({ - ... ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True], - ... ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False], - ... }, []) - >>> env.reset(data) - >>> print(data.get(("agent0", "val"))) # only the second value is 0 - tensor([1, 0]) - >>> print(data.get(("agent1", "val"))) # only the first value is 0 - tensor([0, 2]) - >>> # nested resets are overridden by a "_reset" at the root - >>> data = TensorDict({ - ... "_reset": [True, True], - ... ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True], - ... ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False], - ... }, []) - >>> env.reset(data) - >>> print(data.get(("agent0", "val"))) # reset at the root overrides nested - tensor([0, 0]) - >>> print(data.get(("agent1", "val"))) # reset at the root overrides nested - tensor([0, 0]) - -.. code-block:: - :caption: Parallel environment reset - - >>> tensordict = TensorDict({"_reset": [[True], [False], [True], [True]]}, [4]) - >>> env.reset(tensordict) # eliminates the "_reset" entry - TensorDict( - fields={ - terminated: Tensor(torch.Size([4, 1]), dtype=torch.bool), - done: Tensor(torch.Size([4, 1]), dtype=torch.bool), - pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8), - truncated: Tensor(torch.Size([4, 1]), dtype=torch.bool), - batch_size=torch.Size([4]), - device=None, - is_shared=True) - - -.. note:: - - *A note on performance*: launching a :class:`~.ParallelEnv` can take quite some time - as it requires to launch as many python instances as there are processes. Due to - the time that it takes to run ``import torch`` (and other imports), starting the - parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow. - Once the environment is launched, a great speedup should be observed. - -.. 
note:: - - *TorchRL requires precise specs*: Another thing to take in consideration is - that :class:`ParallelEnv` (as well as data collectors) - will create data buffers based on the environment specs to pass data from one process - to another. This means that a misspecified spec (input, observation or reward) will - cause a breakage at runtime as the data can't be written on the preallocated buffer. - In general, an environment should be tested using the :func:`~.utils.check_env_specs` - test function before being used in a :class:`ParallelEnv`. This function will raise - an assertion error whenever the preallocated buffer and the collected data mismatch. - -We also offer the :class:`~.SerialEnv` class that enjoys the exact same API but is executed -serially. This is mostly useful for testing purposes, when one wants to assess the -behavior of a :class:`~.ParallelEnv` without launching the subprocesses. - -In addition to :class:`~.ParallelEnv`, which offers process-based parallelism, we also provide a way to create -multithreaded environments with :obj:`~.MultiThreadedEnv`. This class uses `EnvPool `_ -library underneath, which allows for higher performance, but at the same time restricts flexibility - one can only -create environments implemented in ``EnvPool``. This covers many popular RL environments types (Atari, Classic Control, -etc.), but one can not use an arbitrary TorchRL environment, as it is possible with :class:`~.ParallelEnv`. Run -`benchmarks/benchmark_batched_envs.py` to compare performance of different ways to parallelize batched environments. - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - SerialEnv - ParallelEnv - EnvCreator - -Async environments ------------------- - -Asynchronous environments allow for parallel execution of multiple environments, which can significantly speed up the -data collection process in reinforcement learning. - -The `AsyncEnvPool` class and its subclasses provide a flexible interface for managing these environments using different -backends, such as threading and multiprocessing. - -The `AsyncEnvPool` class serves as a base class for asynchronous environment pools, providing a common interface for -managing multiple environments concurrently. It supports different backends for parallel execution, such as threading -and multiprocessing, and provides methods for asynchronous stepping and resetting of environments. - -Contrary to :class:`~torchrl.envs.ParallelEnv`, :class:`~torchrl.envs.AsyncEnvPool` and its subclasses permit the -execution of a given set of sub-environments while another task performed, allowing for complex asynchronous jobs to be -run at the same time. For instance, it is possible to execute some environments while the policy is running based on -the output of others. - -This family of classes is particularly interesting when dealing with environments that have a high (and/or variable) -latency. - -.. note:: This class and its subclasses should work when nested in with :class:`~torchrl.envs.TransformedEnv` and - batched environments, but users won't currently be able to use the async features of the base environment when - it's nested in these classes. One should prefer nested transformed envs within an `AsyncEnvPool` instead. - If this is not possible, please raise an issue. - -Classes -~~~~~~~ - -- :class:`~torchrl.envs.AsyncEnvPool`: A base class for asynchronous environment pools. 
It determines the backend - implementation to use based on the provided arguments and manages the lifecycle of the environments. -- :class:`~torchrl.envs.ProcessorAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using - multiprocessing for parallel execution of environments. This class manages a pool of environments, each running in - its own process, and provides methods for asynchronous stepping and resetting of environments using inter-process - communication. It is automatically instantiated when `"multiprocessing"` is passed as a backend during the - :class:`~torchrl.envs.AsyncEnvPool` instantiation. -- :class:`~torchrl.envs.ThreadingAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using - threading for parallel execution of environments. This class manages a pool of environments, each running in its own - thread, and provides methods for asynchronous stepping and resetting of environments using a thread pool executor. - It is automatically instantiated when `"threading"` is passed as a backend during the - :class:`~torchrl.envs.AsyncEnvPool` instantiation. - -Example -~~~~~~~ - - >>> from functools import partial - >>> from torchrl.envs import AsyncEnvPool, GymEnv - >>> import torch - >>> # Choose backend - >>> backend = "threading" - >>> env = AsyncEnvPool( - >>> [partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "CartPole-v1")], - >>> stack="lazy", - >>> backend=backend - >>> ) - >>> # Execute a synchronous reset - >>> reset = env.reset() - >>> print(reset) - >>> # Execute a synchronous step - >>> s = env.rand_step(reset) - >>> print(s) - >>> # Execute an asynchronous step in env 0 - >>> s0 = s[0] - >>> s0["action"] = torch.randn(1).clamp(-1, 1) - >>> s0["env_index"] = 0 - >>> env.async_step_send(s0) - >>> # Receive data - >>> s0_result = env.async_step_recv() - >>> print('result', s0_result) - >>> # Close env - >>> env.close() - - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - AsyncEnvPool - ProcessorAsyncEnvPool - ThreadingAsyncEnvPool - - -Custom native TorchRL environments ----------------------------------- - -TorchRL offers a series of custom built-in environments. - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - ChessEnv - PendulumEnv - TicTacToeEnv - LLMHashingEnv - - -Multi-agent environments ------------------------- - -.. _MARL-environment-API: - -.. currentmodule:: torchrl.envs - -TorchRL supports multi-agent learning out-of-the-box. -*The same classes used in a single-agent learning pipeline can be seamlessly used in multi-agent contexts, -without any modification or dedicated multi-agent infrastructure.* - -In this view, environments play a core role for multi-agent. In multi-agent environments, -many decision-making agents act in a shared world. -Agents can observe different things, act in different ways and also be rewarded differently. -Therefore, many paradigms exist to model multi-agent environments (DecPODPs, Markov Games). -Some of the main differences between these paradigms include: - -- **observation** can be per-agent and also have some shared components -- **reward** can be per-agent or shared -- **done** (and ``"truncated"`` or ``"terminated"``) can be per-agent or shared. - -TorchRL accommodates all these possible paradigms thanks to its :class:`tensordict.TensorDict` data carrier. -In particular, in multi-agent environments, per-agent keys will be carried in a nested "agents" TensorDict. 
-This TensorDict will have the additional agent dimension and thus group data that is different for each agent. -The shared keys, on the other hand, will be kept in the first level, as in single-agent cases. - -Let's look at an example to understand this better. For this example we are going to use -`VMAS `_, a multi-robot task simulator also -based on PyTorch, which runs parallel batched simulation on device. - -We can create a VMAS environment and look at what the output from a random step looks like: - -.. code-block:: - :caption: Example of multi-agent step tensordict - - >>> from torchrl.envs.libs.vmas import VmasEnv - >>> env = VmasEnv("balance", num_envs=3, n_agents=5) - >>> td = env.rand_step() - >>> td - TensorDict( - fields={ - agents: TensorDict( - fields={ - action: Tensor(shape=torch.Size([3, 5, 2]))}, - batch_size=torch.Size([3, 5])), - next: TensorDict( - fields={ - agents: TensorDict( - fields={ - info: TensorDict( - fields={ - ground_rew: Tensor(shape=torch.Size([3, 5, 1])), - pos_rew: Tensor(shape=torch.Size([3, 5, 1]))}, - batch_size=torch.Size([3, 5])), - observation: Tensor(shape=torch.Size([3, 5, 16])), - reward: Tensor(shape=torch.Size([3, 5, 1]))}, - batch_size=torch.Size([3, 5])), - done: Tensor(shape=torch.Size([3, 1]))}, - batch_size=torch.Size([3]))}, - batch_size=torch.Size([3])) - -We can observe that *keys that are shared by all agents*, such as **done** are present in the root tensordict with -batch size `(num_envs,)`, which represents the number of environments simulated. - -On the other hand, *keys that are different between agents*, such as **action**, **reward**, **observation**, -and **info** are present in the nested "agents" tensordict with batch size `(num_envs, n_agents)`, -which represents the additional agent dimension. - -Multi-agent tensor specs will follow the same style as in tensordicts. -Specs relating to values that vary between agents will need to be nested in the "agents" entry. - -Here is an example of how specs can be created in a multi-agent environment where -only the done flag is shared across agents (as in VMAS): - -.. code-block:: - :caption: Example of multi-agent spec creation - - >>> action_specs = [] - >>> observation_specs = [] - >>> reward_specs = [] - >>> info_specs = [] - >>> for i in range(env.n_agents): - ... action_specs.append(agent_i_action_spec) - ... reward_specs.append(agent_i_reward_spec) - ... observation_specs.append(agent_i_observation_spec) - >>> env.action_spec = Composite( - ... { - ... "agents": Composite( - ... {"action": torch.stack(action_specs)}, shape=(env.n_agents,) - ... ) - ... } - ...) - >>> env.reward_spec = Composite( - ... { - ... "agents": Composite( - ... {"reward": torch.stack(reward_specs)}, shape=(env.n_agents,) - ... ) - ... } - ...) - >>> env.observation_spec = Composite( - ... { - ... "agents": Composite( - ... {"observation": torch.stack(observation_specs)}, shape=(env.n_agents,) - ... ) - ... } - ...) - >>> env.done_spec = Categorical( - ... n=2, - ... shape=torch.Size((1,)), - ... dtype=torch.bool, - ... ) - -As you can see, it is very simple! Per-agent keys will have the nested composite spec and shared keys will follow -single agent standards. - -.. note:: - Since reward, done and action keys may have the additional "agent" prefix (e.g., `("agents","action")`), - the default keys used in the arguments of other TorchRL components (e.g. "action") will not match exactly. 
- Therefore, TorchRL provides the `env.action_key`, `env.reward_key`, and `env.done_key` attributes, - which will automatically point to the right key to use. Make sure you pass these attributes to the various - components in TorchRL to inform them of the right key (e.g., the `loss.set_keys()` function). - -.. note:: - TorchRL abstracts these nested specs away for ease of use. - This means that accessing `env.reward_spec` will always return the leaf - spec if the accessed spec is Composite. Therefore, if in the example above - we run `env.reward_spec` after env creation, we would get the same output as `torch.stack(reward_specs)}`. - To get the full composite spec with the "agents" key, you can run - `env.output_spec["full_reward_spec"]`. The same is valid for action and done specs. - Note that `env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]`. - - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - MarlGroupMapType - check_marl_grouping - -Auto-resetting Envs -------------------- - -.. _autoresetting_envs: - -Auto-resetting environments are environments where calls to :meth:`~torchrl.envs.EnvBase.reset` are not expected when -the environment reaches a ``"done"`` state during a rollout, as the reset happens automatically. -Usually, in such cases the observations delivered with the done and reward (which effectively result from performing the -action in the environment) are actually the first observations of a new episode, and not the last observations of the -current episode. - -To handle these cases, torchrl provides a :class:`~torchrl.envs.AutoResetTransform` that will copy the observations -that result from the call to `step` to the next `reset` and skip the calls to `reset` during rollouts (in both -:meth:`~torchrl.envs.EnvBase.rollout` and :class:`~torchrl.collectors.SyncDataCollector` iterations). -This transform class also provides a fine-grained control over the behavior to be adopted for the invalid observations, -which can be masked with `"nan"` or any other values, or not masked at all. - -To tell torchrl that an environment is auto-resetting, it is sufficient to provide an ``auto_reset`` argument -during construction. If provided, an ``auto_reset_replace`` argument can also control whether the values of the last -observation of an episode should be replaced with some placeholder or not. - - >>> from torchrl.envs import GymEnv - >>> from torchrl.envs import set_gym_backend - >>> import torch - >>> torch.manual_seed(0) - >>> - >>> class AutoResettingGymEnv(GymEnv): - ... def _step(self, tensordict): - ... tensordict = super()._step(tensordict) - ... if tensordict["done"].any(): - ... td_reset = super().reset() - ... tensordict.update(td_reset.exclude(*self.done_keys)) - ... return tensordict - ... - ... def _reset(self, tensordict=None): - ... if tensordict is not None and "_reset" in tensordict: - ... return tensordict.copy() - ... return super()._reset(tensordict) - >>> - >>> with set_gym_backend("gym"): - ... env = AutoResettingGymEnv("CartPole-v1", auto_reset=True, auto_reset_replace=True) - ... env.set_seed(0) - ... 
r = env.rollout(30, break_when_any_done=False) - >>> print(r["next", "done"].squeeze()) - tensor([False, False, False, False, False, False, False, False, False, False, - False, False, False, True, False, False, False, False, False, False, - False, False, False, False, False, True, False, False, False, False]) - >>> print("observation after reset are set as nan", r["next", "observation"]) - observation after reset are set as nan tensor([[-4.3633e-02, -1.4877e-01, 1.2849e-02, 2.7584e-01], - [-4.6609e-02, 4.6166e-02, 1.8366e-02, -1.2761e-02], - [-4.5685e-02, 2.4102e-01, 1.8111e-02, -2.9959e-01], - [-4.0865e-02, 4.5644e-02, 1.2119e-02, -1.2542e-03], - [-3.9952e-02, 2.4059e-01, 1.2094e-02, -2.9009e-01], - [-3.5140e-02, 4.3554e-01, 6.2920e-03, -5.7893e-01], - [-2.6429e-02, 6.3057e-01, -5.2867e-03, -8.6963e-01], - [-1.3818e-02, 8.2576e-01, -2.2679e-02, -1.1640e+00], - [ 2.6972e-03, 1.0212e+00, -4.5959e-02, -1.4637e+00], - [ 2.3121e-02, 1.2168e+00, -7.5232e-02, -1.7704e+00], - [ 4.7457e-02, 1.4127e+00, -1.1064e-01, -2.0854e+00], - [ 7.5712e-02, 1.2189e+00, -1.5235e-01, -1.8289e+00], - [ 1.0009e-01, 1.0257e+00, -1.8893e-01, -1.5872e+00], - [ nan, nan, nan, nan], - [-3.9405e-02, -1.7766e-01, -1.0403e-02, 3.0626e-01], - [-4.2959e-02, -3.7263e-01, -4.2775e-03, 5.9564e-01], - [-5.0411e-02, -5.6769e-01, 7.6354e-03, 8.8698e-01], - [-6.1765e-02, -7.6292e-01, 2.5375e-02, 1.1820e+00], - [-7.7023e-02, -9.5836e-01, 4.9016e-02, 1.4826e+00], - [-9.6191e-02, -7.6387e-01, 7.8667e-02, 1.2056e+00], - [-1.1147e-01, -9.5991e-01, 1.0278e-01, 1.5219e+00], - [-1.3067e-01, -7.6617e-01, 1.3322e-01, 1.2629e+00], - [-1.4599e-01, -5.7298e-01, 1.5848e-01, 1.0148e+00], - [-1.5745e-01, -7.6982e-01, 1.7877e-01, 1.3527e+00], - [-1.7285e-01, -9.6668e-01, 2.0583e-01, 1.6956e+00], - [ nan, nan, nan, nan], - [-4.3962e-02, 1.9845e-01, -4.5015e-02, -2.5903e-01], - [-3.9993e-02, 3.9418e-01, -5.0196e-02, -5.6557e-01], - [-3.2109e-02, 5.8997e-01, -6.1507e-02, -8.7363e-01], - [-2.0310e-02, 3.9574e-01, -7.8980e-02, -6.0090e-01]]) - -Dynamic Specs +Quick Example ------------- -.. _dynamic_envs: - -Running environments in parallel is usually done via the creation of memory buffers used to pass information from one -process to another. In some cases, it may be impossible to forecast whether an environment will or will not have -consistent inputs or outputs during a rollout, as their shape may be variable. We refer to this as dynamic specs. - -TorchRL is capable of handling dynamic specs, but the batched environments and collectors will need to be made -aware of this feature. Note that, in practice, this is detected automatically. - -To indicate that a tensor will have a variable size along a dimension, one can set the size value as ``-1`` for the -desired dimensions. Because the data cannot be stacked contiguously, calls to ``env.rollout`` need to be made with -the ``return_contiguous=False`` argument. -Here is a working example: - - >>> from torchrl.envs import EnvBase - >>> from torchrl.data import Unbounded, Composite, Bounded, Binary - >>> import torch - >>> from tensordict import TensorDict, TensorDictBase - >>> - >>> class EnvWithDynamicSpec(EnvBase): - ... def __init__(self, max_count=5): - ... super().__init__(batch_size=()) - ... self.observation_spec = Composite( - ... observation=Unbounded(shape=(3, -1, 2)), - ... ) - ... self.action_spec = Bounded(low=-1, high=1, shape=(2,)) - ... self.full_done_spec = Composite( - ... done=Binary(1, shape=(1,), dtype=torch.bool), - ... terminated=Binary(1, shape=(1,), dtype=torch.bool), - ... 
truncated=Binary(1, shape=(1,), dtype=torch.bool), - ... ) - ... self.reward_spec = Unbounded((1,), dtype=torch.float) - ... self.count = 0 - ... self.max_count = max_count - ... - ... def _reset(self, tensordict=None): - ... self.count = 0 - ... data = TensorDict( - ... { - ... "observation": torch.full( - ... (3, self.count + 1, 2), - ... self.count, - ... dtype=self.observation_spec["observation"].dtype, - ... ) - ... } - ... ) - ... data.update(self.done_spec.zero()) - ... return data - ... - ... def _step( - ... self, - ... tensordict: TensorDictBase, - ... ) -> TensorDictBase: - ... self.count += 1 - ... done = self.count >= self.max_count - ... observation = TensorDict( - ... { - ... "observation": torch.full( - ... (3, self.count + 1, 2), - ... self.count, - ... dtype=self.observation_spec["observation"].dtype, - ... ) - ... } - ... ) - ... done = self.full_done_spec.zero() | done - ... reward = self.full_reward_spec.zero() - ... return observation.update(done).update(reward) - ... - ... def _set_seed(self, seed: Optional[int]) -> None: - ... self.manual_seed = seed - ... return seed - >>> env = EnvWithDynamicSpec() - >>> print(env.rollout(5, return_contiguous=False)) - LazyStackedTensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.float32, is_shared=False), - done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: LazyStackedTensorDict( - fields={ - done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), - terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - exclusive_fields={ - }, - batch_size=torch.Size([5]), - device=None, - is_shared=False, - stack_dim=0), - observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False), - terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - exclusive_fields={ - }, - batch_size=torch.Size([5]), - device=None, - is_shared=False, - stack_dim=0) - -.. warning:: - The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in - data collectors can impact performance of these classes dramatically. Any - such usage should be carefully benchmarked against a plain execution on a - single process, as serializing and deserializing large numbers of tensors - can be very expensive. - -Currently, :func:`~torchrl.envs.utils.check_env_specs` will pass for dynamic specs where a shape varies along some -dimensions, but not when a key is present during a step and absent during others, or when the number of dimensions -varies. - -Transforms ----------- - -.. _transforms: - -.. currentmodule:: torchrl.envs.transforms - -In most cases, the raw output of an environment must be treated before being passed to another object (such as a -policy or a value operator). To do this, TorchRL provides a set of transforms that aim at reproducing the transform -logic of `torch.distributions.Transform` and `torchvision.transforms`. -Our environment :ref:`tutorial ` -provides more information on how to design a custom transform. 
-
-Transformed environments are built using the :class:`TransformedEnv` primitive.
-Composed transforms are built using the :class:`Compose` class:
-
-.. code-block::
-    :caption: Transformed environment
-
-    >>> base_env = GymEnv("Pendulum-v1", from_pixels=True, device="cuda:0")
-    >>> transform = Compose(ToTensorImage(in_keys=["pixels"]), Resize(64, 64, in_keys=["pixels"]))
-    >>> env = TransformedEnv(base_env, transform)
-
-Transforms are usually subclasses of :class:`~torchrl.envs.transforms.Transform`, although any
-``Callable[[TensorDictBase], TensorDictBase]`` will work as well.
-
-By default, the transformed environment will inherit the device of the
-:obj:`base_env` that is passed to it. The transforms will then be executed on that device.
-Depending on the kind of operations to be computed, this can bring a significant speedup.
-
-A great advantage of environment wrappers is that one can consult the environment up to that wrapper.
-The same can be achieved with TorchRL transformed environments: the ``parent`` attribute will
-return a new :class:`TransformedEnv` with all the transforms up to the transform of interest.
-Re-using the example above:
-
-.. code-block::
-    :caption: Transform parent
-
-    >>> resize_parent = env.transform[-1].parent  # returns the same as TransformedEnv(base_env, transform[:-1])
-
-
-Transformed environments can be used with vectorized environments.
-Since each transform uses an ``"in_keys"``/``"out_keys"`` set of keyword arguments, it is
-also easy to root the transform graph to each component of the observation data (e.g.,
-pixels or states).
-
-Forward and inverse transforms
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called, in reverse
-order over the composed transform chain, before the action is applied in the environment. This allows transforms
-to be applied to the input data (e.g., actions) before it reaches the base environment. The keys to be included
-in this inverse transform are passed through the `"in_keys_inv"`
-keyword argument, and the out-keys default to these values in most cases:
-
-.. code-block::
-    :caption: Inverse transform
-
-    >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"]))  # will map the action from float32 to float64 before calling the base_env.step
-
-The following paragraphs detail how one can think about which entries should be considered `in_` or `out_` features.
-
-Understanding Transform Keys
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-In transforms, `in_keys` and `out_keys` define the interaction between the base environment and the outside world
-(e.g., your policy):
-
-- `in_keys` refers to the base environment's perspective (inner = `base_env` of the
-  :class:`~torchrl.envs.TransformedEnv`).
-- `out_keys` refers to the outside world (outer = `policy`, `agent`, etc.).
-
-For example, with `in_keys=["obs"]` and `out_keys=["obs_standardized"]`, the policy will "see" a standardized
-observation, while the base environment outputs a regular observation.
-
-Similarly, for inverse keys:
-
-- `in_keys_inv` refers to entries as seen by the base environment.
-- `out_keys_inv` refers to entries as seen or produced by the policy.
-
-The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input
-`TensorDict` of the `step` function must include the `out_keys_inv` as they are part of the outside world. The
-transform changes these names to match the names of the inner, base environment using the `in_keys_inv`.
-The inverse process is executed with the output tensordict, where the `in_keys` are mapped to the corresponding -`out_keys`. - -.. figure:: /_static/img/rename_transform.png - - Rename transform logic - -.. note:: During a call to `inv`, the transforms are executed in reversed order (compared to the forward / step mode). - -Transforming Tensors and Specs -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When transforming actual tensors (coming from the policy), the process is schematically represented as: - - >>> for t in reversed(self.transform): - ... td = t.inv(td) - -This starts with the outermost transform to the innermost transform, ensuring the action value exposed to the policy -is properly transformed. - -For transforming the action spec, the process should go from innermost to outermost (similar to observation specs): - - >>> def transform_action_spec(self, action_spec): - ... for t in self.transform: - ... action_spec = t.transform_action_spec(action_spec) - ... return action_spec - -A pseudocode for a single transform_action_spec could be: - - >>> def transform_action_spec(self, action_spec): - ... return spec_from_random_values(self._apply_transform(action_spec.rand())) - -This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we did not call -`_inv_apply_transform` but `_apply_transform` on purpose! - -Exposing Specs to the Outside World -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -`TransformedEnv` will expose the specs corresponding to the `out_keys_inv` for actions and states. -For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., `"action"`) is a float-valued -tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed -environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the -transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`. - -Designing your own Transform -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To create a basic, custom transform, you need to subclass the `Transform` class and implement the -:meth:`~torchrl.envs._apply_transform` method. Here's an example of a simple transform that adds 1 to the observation -tensor: - - >>> class AddOneToObs(Transform): - ... """A transform that adds 1 to the observation tensor.""" - ... - ... def __init__(self): - ... super().__init__(in_keys=["observation"], out_keys=["observation"]) - ... - ... def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: - ... return obs + 1 - - -Tips for subclassing `Transform` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -There are various ways of subclassing a transform. The things to take into considerations are: - -- Is the transform identical for each tensor / item being transformed? Use - :meth:`~torchrl.envs.Transform._apply_transform` and :meth:`~torchrl.envs.Transform._inv_apply_transform`. -- The transform needs access to the input data to env.step as well as output? Rewrite - :meth:`~torchrl.envs.Transform._step`. - Otherwise, rewrite :meth:`~torchrl.envs.Transform._call` (or :meth:`~torchrl.envs.Transform._inv_call`). -- Is the transform to be used within a replay buffer? Overwrite :meth:`~torchrl.envs.Transform.forward`, - :meth:`~torchrl.envs.Transform.inv`, :meth:`~torchrl.envs.Transform._apply_transform` or - :meth:`~torchrl.envs.Transform._inv_apply_transform`. 
-- Within a transform, you can access (and make calls to) the parent environment using - :attr:`~torchrl.envs.Transform.parent` (the base env + all transforms till this one) or - :meth:`~torchrl.envs.Transform.container` (The object that encapsulates the transform). -- Don't forget to edits the specs if needed: top level: :meth:`~torchrl.envs.Transform.transform_output_spec`, - :meth:`~torchrl.envs.Transform.transform_input_spec`. - Leaf level: :meth:`~torchrl.envs.Transform.transform_observation_spec`, - :meth:`~torchrl.envs.Transform.transform_action_spec`, :meth:`~torchrl.envs.Transform.transform_state_spec`, - :meth:`~torchrl.envs.Transform.transform_reward_spec` and - :meth:`~torchrl.envs.Transform.transform_reward_spec`. - -For practical examples, see the methods listed above. - -You can use a transform in an environment by passing it to the TransformedEnv constructor: - - >>> env = TransformedEnv(GymEnv("Pendulum-v1"), AddOneToObs()) - -You can compose multiple transforms together using the Compose class: - - >>> transform = Compose(AddOneToObs(), RewardSum()) - >>> env = TransformedEnv(GymEnv("Pendulum-v1"), transform) - -Inverse Transforms -^^^^^^^^^^^^^^^^^^ - -Some transforms have an inverse transform that can be used to undo the transformation. For example, the AddOneToAction -transform has an inverse transform that subtracts 1 from the action tensor: - - >>> class AddOneToAction(Transform): - ... """A transform that adds 1 to the action tensor.""" - ... def __init__(self): - ... super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"]) - ... def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor: - ... return action + 1 - -Using a Transform with a Replay Buffer -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You can use a transform with a replay buffer by passing it to the ReplayBuffer constructor: - -Cloning transforms -~~~~~~~~~~~~~~~~~~ - -Because transforms appended to an environment are "registered" to this environment -through the ``transform.parent`` property, when manipulating transforms we should keep -in mind that the parent may come and go following what is being done with the transform. -Here are some examples: if we get a single transform from a :class:`Compose` object, -this transform will keep its parent: - - >>> third_transform = env.transform[2] - >>> assert third_transform.parent is not None - -This means that using this transform for another environment is prohibited, as -the other environment would replace the parent and this may lead to unexpected -behviours. Fortunately, the :class:`Transform` class comes with a :func:`clone` -method that will erase the parent while keeping the identity of all the -registered buffers: - - >>> TransformedEnv(base_env, third_transform) # raises an Exception as third_transform already has a parent - >>> TransformedEnv(base_env, third_transform.clone()) # works - -On a single process or if the buffers are placed in shared memory, this will -result in all the clone transforms to keep the same behavior even if the -buffers are changed in place (which is what will happen with the :class:`CatFrames` -transform, for instance). In distributed settings, this may not hold and one -should be careful about the expected behavior of the cloned transforms in this -context. 
-Finally, notice that indexing multiple transforms from a :class:`Compose` transform -may also result in loss of parenthood for these transforms: the reason is that -indexing a :class:`Compose` transform results in another :class:`Compose` transform -that does not have a parent environment. Hence, we have to clone the sub-transforms -to be able to create this other composition: - - >>> env = TransformedEnv(base_env, Compose(transform1, transform2, transform3)) - >>> last_two = env.transform[-2:] - >>> assert isinstance(last_two, Compose) - >>> assert last_two.parent is None - >>> assert last_two[0] is not transform2 - >>> assert isinstance(last_two[0], type(transform2)) # and the buffers will match - >>> assert last_two[1] is not transform3 - >>> assert isinstance(last_two[1], type(transform3)) # and the buffers will match - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - Transform - TransformedEnv - ActionDiscretizer - ActionMask - AutoResetEnv - AutoResetTransform - BatchSizeTransform - BinarizeReward - BurnInTransform - CatFrames - CatTensors - CenterCrop - ClipTransform - Compose - ConditionalPolicySwitch - ConditionalSkip - Crop - DTypeCastTransform - DeviceCastTransform - DiscreteActionProjection - DoubleToFloat - EndOfLifeTransform - ExcludeTransform - FiniteTensorDictCheck - FlattenObservation - FrameSkipTransform - GrayScale - Hash - InitTracker - KLRewardTransform - LineariseRewards - ModuleTransform - MultiAction - NoopResetEnv - ObservationNorm - ObservationTransform - PermuteTransform - PinMemoryTransform - R3MTransform - RandomCropTensorDict - RemoveEmptySpecs - RenameTransform - Resize - Reward2GoTransform - RewardClipping - RewardScaling - RewardSum - SelectTransform - SignTransform - SqueezeTransform - Stack - StepCounter - TargetReturn - TensorDictPrimer - TimeMaxPool - Timer - Tokenizer - ToTensorImage - TrajCounter - UnaryTransform - UnsqueezeTransform - VC1Transform - VIPRewardTransform - VIPTransform - VecGymEnvTransform - VecNorm - VecNormV2 - gSDENoise - -Environments with masked actions --------------------------------- - -In some environments with discrete actions, the actions available to the agent might change throughout execution. -In such cases the environments will output an action mask (under the ``"action_mask"`` key by default). -This mask needs to be used to filter out unavailable actions for that step. - -If you are using a custom policy you can pass this mask to your probability distribution like so: - -.. code-block:: - :caption: Categorical policy with action mask - - >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential - >>> import torch.nn as nn - >>> from torchrl.modules import MaskedCategorical - >>> module = TensorDictModule( - >>> nn.Linear(in_feats, out_feats), - >>> in_keys=["observation"], - >>> out_keys=["logits"], - >>> ) - >>> dist = ProbabilisticTensorDictModule( - >>> in_keys={"logits": "logits", "mask": "action_mask"}, - >>> out_keys=["action"], - >>> distribution_class=MaskedCategorical, - >>> ) - >>> actor = TensorDictSequential(module, dist) - -If you want to use a default policy, you will need to wrap your environment in the :class:`~torchrl.envs.transforms.ActionMask` -transform. This transform can take care of updating the action mask in the action spec in order for the default policy -to always know what the latest available actions are. You can do this like so: - -.. 
code-block:: - :caption: How to use the action mask transform - - >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential - >>> import torch.nn as nn - >>> from torchrl.envs.transforms import TransformedEnv, ActionMask - >>> env = TransformedEnv( - >>> your_base_env - >>> ActionMask(action_key="action", mask_key="action_mask"), - >>> ) - -.. note:: - In case you are using a parallel environment it is important to add the transform to the parallel environment itself - and not to its sub-environments. - - - -Recorders ---------- - -.. _Environment-Recorders: - -Recording data during environment rollout execution is crucial to keep an eye on the algorithm performance as well as -reporting results after training. - -TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable -can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function will be called upon the collected -tensordict at each iteration of the rollout (if some iterations have to be skipped, an internal variable should be added -to keep track of the call count within ``callback``). - -To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used. - -Recording videos -~~~~~~~~~~~~~~~~ - -Several backends offer the possibility of recording rendered images from the environment. -If the pixels are already part of the environment output (e.g. Atari or other game simulators), a -:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input -a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger` -or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved. -For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a `video_format="mp4"` -argument. - -The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch -formatted images (WHC or CWH). - - >>> logger = CSVLogger("dummy-exp", video_format="mp4") - >>> env = GymEnv("ALE/Pong-v5") - >>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"])) - >>> env.rollout(10) - >>> env.transform.dump() # Save the video and clear cache - -Note that the cache of the transform will keep on growing until dump is called. It is the user responsibility to -take care of calling `dump` when needed to avoid OOM issues. - -In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible -(some libraries only allow one environment instance per workspace). -In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform` -can be used to call `render` on the parent environment and save the images in the rollout data stream. 
-This class works over single and batched environments alike: - - >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator - >>> from torchrl.record.loggers import CSVLogger - >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder - >>> - >>> def make_env(): - >>> env = GymEnv("CartPole-v1", render_mode="rgb_array") - >>> # Uncomment this line to execute per-env - >>> # env = env.append_transform(PixelRenderTransform()) - >>> return env - >>> - >>> if __name__ == "__main__": - ... logger = CSVLogger("dummy", video_format="mp4") - ... - ... env = ParallelEnv(16, EnvCreator(make_env)) - ... env.start() - ... # Comment this line to execute per-env - ... env = env.append_transform(PixelRenderTransform()) - ... - ... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) - ... env.rollout(3) - ... - ... check_env_specs(env) - ... - ... r = env.rollout(30) - ... env.transform.dump() - ... env.close() - - -.. currentmodule:: torchrl.record - -Recorders are transforms that register data as they come in, for logging purposes. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - TensorDictRecorder - VideoRecorder - PixelRenderTransform - - -Helpers -------- -.. currentmodule:: torchrl.envs - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - RandomPolicy - check_env_specs - exploration_type - get_available_libraries - make_composite_from_td - set_exploration_type - step_mdp - terminated_or_truncated - -Domain-specific ---------------- -.. currentmodule:: torchrl.envs - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - ModelBasedEnvBase - model_based.dreamer.DreamerEnv - model_based.dreamer.DreamerDecoder - - -Libraries ---------- - -.. currentmodule:: torchrl.envs - -TorchRL's mission is to make the training of control and decision algorithm as -easy as it gets, irrespective of the simulator being used (if any). -Multiple wrappers are available for DMControl, Habitat, Jumanji and, naturally, -for Gym. - -This last library has a special status in the RL community as being the mostly -used framework for coding simulators. Its successful API has been foundational -and inspired many other frameworks, among which TorchRL. -However, Gym has gone through multiple design changes and it is sometimes hard -to accommodate these as an external adoption library: users usually have their -"preferred" version of the library. Moreover, gym is now being maintained -by another group under the "gymnasium" name, which does not facilitate code -compatibility. In practice, we must consider that users may have a version of -gym *and* gymnasium installed in the same virtual environment, and we must -allow both to work concomittantly. -Fortunately, TorchRL provides a solution for this problem: a special decorator -:class:`~.gym.set_gym_backend` allows to control which library will be used -in the relevant functions: - - >>> from torchrl.envs.libs.gym import GymEnv, set_gym_backend, gym_backend - >>> import gymnasium, gym - >>> with set_gym_backend(gymnasium): - ... print(gym_backend()) - ... env1 = GymEnv("Pendulum-v1") - - >>> with set_gym_backend(gym): - ... print(gym_backend()) - ... 
env2 = GymEnv("Pendulum-v1") - - >>> print(env1._env.env.env) - - >>> print(env2._env.env.env) - - -We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()` -which can be further used to indicate which library needs to be used for -the current computation. :class:`~.gym.set_gym_backend` is also a decorator: -we can use it to tell to a specific function what gym backend needs to be used -during its execution. -The :func:`torchrl.envs.libs.gym.gym_backend` function allows you to gather -the current gym backend or any of its modules: - - >>> import mo_gymnasium - >>> with set_gym_backend("gym"): - ... wrappers = gym_backend('wrappers') - ... print(wrappers) - - >>> with set_gym_backend("gymnasium"): - ... wrappers = gym_backend('wrappers') - ... print(wrappers) - - -Another tool that comes in handy with gym and other external dependencies is -the :class:`torchrl._utils.implement_for` class. Decorating a function -with ``@implement_for`` will tell torchrl that, depending on the version -indicated, a specific behavior is to be expected. This allows us to easily -support multiple versions of gym without requiring any effort from the user side. -For example, considering that our virtual environment has the v0.26.2 installed, -the following function will return ``1`` when queried: - - >>> from torchrl._utils import implement_for - >>> @implement_for("gym", None, "0.26.0") - ... def fun(): - ... return 0 - >>> @implement_for("gym", "0.26.0", None) - ... def fun(): - ... return 1 - >>> fun() - 1 - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - BraxEnv - BraxWrapper - DMControlEnv - DMControlWrapper - GymEnv - GymWrapper - HabitatEnv - IsaacGymEnv - IsaacGymWrapper - IsaacLabWrapper - JumanjiEnv - JumanjiWrapper - MeltingpotEnv - MeltingpotWrapper - MOGymEnv - MOGymWrapper - MultiThreadedEnv - MultiThreadedEnvWrapper - OpenMLEnv - OpenSpielWrapper - OpenSpielEnv - PettingZooEnv - PettingZooWrapper - RoboHiveEnv - SMACv2Env - SMACv2Wrapper - UnityMLAgentsEnv - UnityMLAgentsWrapper - VmasEnv - VmasWrapper - gym_backend - set_gym_backend - register_gym_spec_conversion +.. code-block:: python + + from torchrl.envs import GymEnv, ParallelEnv, TransformedEnv + from torchrl.envs.transforms import RewardSum, StepCounter + + # Create a single environment + env = GymEnv("Pendulum-v1") + + # Add transforms + env = TransformedEnv(env, RewardSum()) + + # Create parallel environments + def make_env(): + return TransformedEnv( + GymEnv("Pendulum-v1"), + StepCounter(max_steps=200) + ) + + parallel_env = ParallelEnv(4, make_env) + + # Run a rollout + rollout = parallel_env.rollout(100) + +Documentation Sections +---------------------- + +.. toctree:: + :maxdepth: 2 + + envs_api + envs_vectorized + envs_transforms + envs_multiagent + envs_libraries + envs_recorders diff --git a/docs/source/reference/envs_api.rst b/docs/source/reference/envs_api.rst new file mode 100644 index 00000000000..03cb747ece0 --- /dev/null +++ b/docs/source/reference/envs_api.rst @@ -0,0 +1,208 @@ +.. currentmodule:: torchrl.envs + +.. _Environment-API: + +Environment API +=============== + +TorchRL offers an API to handle environments of different backends, such as gym, +dm-control, dm-lab, model-based environments as well as custom environments. +The goal is to be able to swap environments in an experiment with little or no effort, +even if these environments are simulated using different libraries. 
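+
+For instance, the following sketch runs the exact same rollout code on two different backends (the environment
+names are illustrative and assume that a Gym/Gymnasium backend and ``dm_control`` are both installed):
+
+.. code-block::
+    :caption: Swapping simulation backends
+
+    >>> from torchrl.envs.libs.gym import GymEnv
+    >>> from torchrl.envs.libs.dm_control import DMControlEnv
+    >>>
+    >>> # Both wrappers expose the same TensorDict-based API
+    >>> for env in (GymEnv("Pendulum-v1"), DMControlEnv("cheetah", "run")):
+    ...     rollout = env.rollout(max_steps=10)  # no policy passed: random actions drawn from the action spec
+    ...     print(rollout.batch_size)  # torch.Size([10]) for both backends
+    ...     env.close()
+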
+TorchRL offers some out-of-the-box environment wrappers under :mod:`torchrl.envs.libs`, +which we hope can be easily imitated for other libraries. +The parent class :class:`~torchrl.envs.EnvBase` is a :class:`torch.nn.Module` subclass that implements +some typical environment methods using :class:`tensordict.TensorDict` as a data organiser. This allows this +class to be generic and to handle an arbitrary number of input and outputs, as well as +nested or batched data structures. + +Each env will have the following attributes: + +- :attr:`env.batch_size`: a :class:`torch.Size` representing the number of envs + batched together. +- :attr:`env.device`: the device where the input and output tensordict are expected to live. + The environment device does not mean that the actual step operations will be computed on device + (this is the responsibility of the backend, with which TorchRL can do little). The device of + an environment just represents the device where the data is to be expected when input to the + environment or retrieved from it. TorchRL takes care of mapping the data to the desired device. + This is especially useful for transforms (see below). For parametric environments (e.g. + model-based environments), the device does represent the hardware that will be used to + compute the operations. +- :attr:`env.observation_spec`: a :class:`~torchrl.data.Composite` object + containing all the observation key-spec pairs. +- :attr:`env.state_spec`: a :class:`~torchrl.data.Composite` object + containing all the input key-spec pairs (except action). For most stateful + environments, this container will be empty. +- :attr:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object + representing the action spec. +- :attr:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing + the reward spec. +- :attr:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing + the done-flag spec. See the section on trajectory termination below. +- :attr:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing + all the input keys (``"full_action_spec"`` and ``"full_state_spec"``). +- :attr:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing + all the output keys (``"full_observation_spec"``, ``"full_reward_spec"`` and ``"full_done_spec"``). + +If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensor` +instance can be used. + +Env specs: locks and batch size +------------------------------- + +.. _Environment-lock: + +Environment specs are locked by default (through a ``spec_locked`` arg passed to the env constructor). +Locking specs means that any modification of the spec (or its children if it is a :class:`~torchrl.data.Composite` +instance) will require to unlock it. This can be done via the :meth:`~torchrl.envs.EnvBase.set_spec_lock_`. +The reason specs are locked by default is that it makes it easy to cache values such as action or reset keys and the +likes. +Unlocking an env should only be done if it expected that the specs will be modified often (which, in principle, should +be avoided). +Modifications of the specs such as `env.observation_spec = new_spec` are allowed: under the hood, TorchRL will erase +the cache, unlock the specs, make the modification and relock the specs if the env was previously locked. + +Importantly, the environment spec shapes should contain the batch size, e.g. 
+an environment with :attr:`env.batch_size` ``== torch.Size([4])`` should have +an :attr:`env.action_spec` with shape :class:`torch.Size` ``([4, action_size])``. +This is helpful when preallocation tensors, checking shape consistency etc. + +Env methods +----------- + +With these, the following methods are implemented: + +- :meth:`env.reset`: a reset method that may (but not necessarily requires to) take + a :class:`tensordict.TensorDict` input. It return the first tensordict of a rollout, usually + containing a ``"done"`` state and a set of observations. If not present, + a ``"reward"`` key will be instantiated with 0s and the appropriate shape. +- :meth:`env.step`: a step method that takes a :class:`tensordict.TensorDict` input + containing an input action as well as other inputs (for model-based or stateless + environments, for instance). +- :meth:`env.step_and_maybe_reset`: executes a step, and (partially) resets the + environments if it needs to. It returns the updated input with a ``"next"`` + key containing the data of the next step, as well as a tensordict containing + the input data for the next step (ie, reset or result or + :func:`~torchrl.envs.utils.step_mdp`) + This is done by reading the ``done_keys`` and + assigning a ``"_reset"`` signal to each done state. This method allows + to code non-stopping rollout functions with little effort: + + >>> data_ = env.reset() + >>> result = [] + >>> for i in range(N): + ... data, data_ = env.step_and_maybe_reset(data_) + ... result.append(data) + ... + >>> result = torch.stack(result) + +- :meth:`env.set_seed`: a seeding method that will return the next seed + to be used in a multi-env setting. This next seed is deterministically computed + from the preceding one, such that one can seed multiple environments with a different + seed without risking to overlap seeds in consecutive experiments, while still + having reproducible results. +- :meth:`env.rollout`: executes a rollout in the environment for + a maximum number of steps (``max_steps=N``) and using a policy (``policy=model``). + The policy should be coded using a :class:`tensordict.nn.TensorDictModule` + (or any other :class:`tensordict.TensorDict`-compatible module). + The resulting :class:`tensordict.TensorDict` instance will be marked with + a trailing ``"time"`` named dimension that can be used by other modules + to treat this batched dimension as it should. + +The following figure summarizes how a rollout is executed in torchrl. + +.. figure:: /_static/img/rollout.gif + + TorchRL rollouts using TensorDict. + +In brief, a TensorDict is created by the :meth:`~.EnvBase.reset` method, +then populated with an action by the policy before being passed to the +:meth:`~.EnvBase.step` method which writes the observations, done flag(s) and +reward under the ``"next"`` entry. The result of this call is stored for +delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp` +function. + +.. note:: + In general, all TorchRL environment have a ``"done"`` and ``"terminated"`` + entry in their output tensordict. If they are not present by design, + the :class:`~.EnvBase` metaclass will ensure that every done or terminated + is flanked with its dual. + In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory + signals and should be interpreted as "the last step of a trajectory" or + equivalently "a signal indicating the need to reset". 
+ If the environment provides it (eg, Gymnasium), the truncation entry is also + written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry. + If the environment carries a single value, it will interpreted as a ``"terminated"`` + signal by default. + By default, TorchRL's collectors and rollout methods will be looking for the ``"done"`` + entry to assess if the environment should be reset. + +.. note:: + + The `torchrl.collectors.utils.split_trajectories` function can be used to + slice adjacent trajectories. It relies on a ``"traj_ids"`` entry in the + input tensordict, or to the junction of ``"done"`` and ``"truncated"`` key + if the ``"traj_ids"`` is missing. + + +.. note:: + + In some contexts, it can be useful to mark the first step of a trajectory. + TorchRL provides such functionality through the :class:`~torchrl.envs.InitTracker` + transform. + + +Our environment :ref:`tutorial ` +provides more information on how to design a custom environment from scratch. + +Base classes +------------ + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + EnvBase + GymLikeEnv + EnvMetaData + +Custom native TorchRL environments +---------------------------------- + +TorchRL offers a series of custom built-in environments. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + ChessEnv + PendulumEnv + TicTacToeEnv + LLMHashingEnv + +Domain-specific +--------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + ModelBasedEnvBase + model_based.dreamer.DreamerEnv + model_based.dreamer.DreamerDecoder + +Helpers +------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + RandomPolicy + check_env_specs + exploration_type + get_available_libraries + make_composite_from_td + set_exploration_type + step_mdp + terminated_or_truncated diff --git a/docs/source/reference/envs_libraries.rst b/docs/source/reference/envs_libraries.rst new file mode 100644 index 00000000000..b4b1b78b547 --- /dev/null +++ b/docs/source/reference/envs_libraries.rst @@ -0,0 +1,277 @@ +.. currentmodule:: torchrl.envs + +Library Wrappers +================ + +TorchRL's mission is to make the training of control and decision algorithm as +easy as it gets, irrespective of the simulator being used (if any). +Multiple wrappers are available for DMControl, Habitat, Jumanji and, naturally, +for Gym. + +This last library has a special status in the RL community as being the mostly +used framework for coding simulators. Its successful API has been foundational +and inspired many other frameworks, among which TorchRL. +However, Gym has gone through multiple design changes and it is sometimes hard +to accommodate these as an external adoption library: users usually have their +"preferred" version of the library. Moreover, gym is now being maintained +by another group under the "gymnasium" name, which does not facilitate code +compatibility. In practice, we must consider that users may have a version of +gym *and* gymnasium installed in the same virtual environment, and we must +allow both to work concomittantly. +Fortunately, TorchRL provides a solution for this problem: a special decorator +:class:`~.gym.set_gym_backend` allows to control which library will be used +in the relevant functions: + + >>> from torchrl.envs.libs.gym import GymEnv, set_gym_backend, gym_backend + >>> import gymnasium, gym + >>> with set_gym_backend(gymnasium): + ... print(gym_backend()) + ... 
env1 = GymEnv("Pendulum-v1") + + >>> with set_gym_backend(gym): + ... print(gym_backend()) + ... env2 = GymEnv("Pendulum-v1") + + >>> print(env1._env.env.env) + + >>> print(env2._env.env.env) + + +We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()` +which can be further used to indicate which library needs to be used for +the current computation. :class:`~.gym.set_gym_backend` is also a decorator: +we can use it to tell to a specific function what gym backend needs to be used +during its execution. +The :func:`torchrl.envs.libs.gym.gym_backend` function allows you to gather +the current gym backend or any of its modules: + + >>> import mo_gymnasium + >>> with set_gym_backend("gym"): + ... wrappers = gym_backend('wrappers') + ... print(wrappers) + + >>> with set_gym_backend("gymnasium"): + ... wrappers = gym_backend('wrappers') + ... print(wrappers) + + +Another tool that comes in handy with gym and other external dependencies is +the :class:`torchrl._utils.implement_for` class. Decorating a function +with ``@implement_for`` will tell torchrl that, depending on the version +indicated, a specific behavior is to be expected. This allows us to easily +support multiple versions of gym without requiring any effort from the user side. +For example, considering that our virtual environment has the v0.26.2 installed, +the following function will return ``1`` when queried: + + >>> from torchrl._utils import implement_for + >>> @implement_for("gym", None, "0.26.0") + ... def fun(): + ... return 0 + >>> @implement_for("gym", "0.26.0", None) + ... def fun(): + ... return 1 + >>> fun() + 1 + +Available wrappers +------------------ + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + BraxEnv + BraxWrapper + DMControlEnv + DMControlWrapper + GymEnv + GymWrapper + HabitatEnv + IsaacGymEnv + IsaacGymWrapper + IsaacLabWrapper + JumanjiEnv + JumanjiWrapper + MeltingpotEnv + MeltingpotWrapper + MOGymEnv + MOGymWrapper + MultiThreadedEnv + MultiThreadedEnvWrapper + OpenMLEnv + OpenSpielWrapper + OpenSpielEnv + PettingZooEnv + PettingZooWrapper + RoboHiveEnv + SMACv2Env + SMACv2Wrapper + UnityMLAgentsEnv + UnityMLAgentsWrapper + VmasEnv + VmasWrapper + gym_backend + set_gym_backend + register_gym_spec_conversion + +Auto-resetting Environments +--------------------------- + +.. _autoresetting_envs: + +Auto-resetting environments are environments where calls to :meth:`~torchrl.envs.EnvBase.reset` are not expected when +the environment reaches a ``"done"`` state during a rollout, as the reset happens automatically. +Usually, in such cases the observations delivered with the done and reward (which effectively result from performing the +action in the environment) are actually the first observations of a new episode, and not the last observations of the +current episode. + +To handle these cases, torchrl provides a :class:`~torchrl.envs.AutoResetTransform` that will copy the observations +that result from the call to `step` to the next `reset` and skip the calls to `reset` during rollouts (in both +:meth:`~torchrl.envs.EnvBase.rollout` and :class:`~torchrl.collectors.SyncDataCollector` iterations). +This transform class also provides a fine-grained control over the behavior to be adopted for the invalid observations, +which can be masked with `"nan"` or any other values, or not masked at all. + +To tell torchrl that an environment is auto-resetting, it is sufficient to provide an ``auto_reset`` argument +during construction. 
If provided, an ``auto_reset_replace`` argument can also control whether the values of the last +observation of an episode should be replaced with some placeholder or not. + + >>> from torchrl.envs import GymEnv + >>> from torchrl.envs import set_gym_backend + >>> import torch + >>> torch.manual_seed(0) + >>> + >>> class AutoResettingGymEnv(GymEnv): + ... def _step(self, tensordict): + ... tensordict = super()._step(tensordict) + ... if tensordict["done"].any(): + ... td_reset = super().reset() + ... tensordict.update(td_reset.exclude(*self.done_keys)) + ... return tensordict + ... + ... def _reset(self, tensordict=None): + ... if tensordict is not None and "_reset" in tensordict: + ... return tensordict.copy() + ... return super()._reset(tensordict) + >>> + >>> with set_gym_backend("gym"): + ... env = AutoResettingGymEnv("CartPole-v1", auto_reset=True, auto_reset_replace=True) + ... env.set_seed(0) + ... r = env.rollout(30, break_when_any_done=False) + >>> print(r["next", "done"].squeeze()) + tensor([False, False, False, False, False, False, False, False, False, False, + False, False, False, True, False, False, False, False, False, False, + False, False, False, False, False, True, False, False, False, False]) + +Dynamic Specs +------------- + +.. _dynamic_envs: + +Running environments in parallel is usually done via the creation of memory buffers used to pass information from one +process to another. In some cases, it may be impossible to forecast whether an environment will or will not have +consistent inputs or outputs during a rollout, as their shape may be variable. We refer to this as dynamic specs. + +TorchRL is capable of handling dynamic specs, but the batched environments and collectors will need to be made +aware of this feature. Note that, in practice, this is detected automatically. + +To indicate that a tensor will have a variable size along a dimension, one can set the size value as ``-1`` for the +desired dimensions. Because the data cannot be stacked contiguously, calls to ``env.rollout`` need to be made with +the ``return_contiguous=False`` argument. +Here is a working example: + + >>> from torchrl.envs import EnvBase + >>> from torchrl.data import Unbounded, Composite, Bounded, Binary + >>> import torch + >>> from tensordict import TensorDict, TensorDictBase + >>> + >>> class EnvWithDynamicSpec(EnvBase): + ... def __init__(self, max_count=5): + ... super().__init__(batch_size=()) + ... self.observation_spec = Composite( + ... observation=Unbounded(shape=(3, -1, 2)), + ... ) + ... self.action_spec = Bounded(low=-1, high=1, shape=(2,)) + ... self.full_done_spec = Composite( + ... done=Binary(1, shape=(1,), dtype=torch.bool), + ... terminated=Binary(1, shape=(1,), dtype=torch.bool), + ... truncated=Binary(1, shape=(1,), dtype=torch.bool), + ... ) + ... self.reward_spec = Unbounded((1,), dtype=torch.float) + ... self.count = 0 + ... self.max_count = max_count + ... + ... def _reset(self, tensordict=None): + ... self.count = 0 + ... data = TensorDict( + ... { + ... "observation": torch.full( + ... (3, self.count + 1, 2), + ... self.count, + ... dtype=self.observation_spec["observation"].dtype, + ... ) + ... } + ... ) + ... data.update(self.done_spec.zero()) + ... return data + ... + ... def _step( + ... self, + ... tensordict: TensorDictBase, + ... ) -> TensorDictBase: + ... self.count += 1 + ... done = self.count >= self.max_count + ... observation = TensorDict( + ... { + ... "observation": torch.full( + ... (3, self.count + 1, 2), + ... self.count, + ... 
dtype=self.observation_spec["observation"].dtype, + ... ) + ... } + ... ) + ... done = self.full_done_spec.zero() | done + ... reward = self.full_reward_spec.zero() + ... return observation.update(done).update(reward) + ... + ... def _set_seed(self, seed: Optional[int]) -> None: + ... self.manual_seed = seed + ... return seed + >>> env = EnvWithDynamicSpec() + >>> print(env.rollout(5, return_contiguous=False)) + LazyStackedTensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: LazyStackedTensorDict( + fields={ + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([5]), + device=None, + is_shared=False, + stack_dim=0), + observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([5]), + device=None, + is_shared=False, + stack_dim=0) + +.. warning:: + The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in + data collectors can impact performance of these classes dramatically. Any + such usage should be carefully benchmarked against a plain execution on a + single process, as serializing and deserializing large numbers of tensors + can be very expensive. + +Currently, :func:`~torchrl.envs.utils.check_env_specs` will pass for dynamic specs where a shape varies along some +dimensions, but not when a key is present during a step and absent during others, or when the number of dimensions +varies. diff --git a/docs/source/reference/envs_multiagent.rst b/docs/source/reference/envs_multiagent.rst new file mode 100644 index 00000000000..13f0a7cb9ca --- /dev/null +++ b/docs/source/reference/envs_multiagent.rst @@ -0,0 +1,138 @@ +.. currentmodule:: torchrl.envs + +.. _MARL-environment-API: + +Multi-agent Environments +======================== + +TorchRL supports multi-agent learning out-of-the-box. +*The same classes used in a single-agent learning pipeline can be seamlessly used in multi-agent contexts, +without any modification or dedicated multi-agent infrastructure.* + +In this view, environments play a core role for multi-agent. In multi-agent environments, +many decision-making agents act in a shared world. +Agents can observe different things, act in different ways and also be rewarded differently. +Therefore, many paradigms exist to model multi-agent environments (DecPODPs, Markov Games). +Some of the main differences between these paradigms include: + +- **observation** can be per-agent and also have some shared components +- **reward** can be per-agent or shared +- **done** (and ``"truncated"`` or ``"terminated"``) can be per-agent or shared. + +TorchRL accommodates all these possible paradigms thanks to its :class:`tensordict.TensorDict` data carrier. 
+In particular, in multi-agent environments, per-agent keys will be carried in a nested "agents" TensorDict. +This TensorDict will have the additional agent dimension and thus group data that is different for each agent. +The shared keys, on the other hand, will be kept in the first level, as in single-agent cases. + +Let's look at an example to understand this better. For this example we are going to use +`VMAS `_, a multi-robot task simulator also +based on PyTorch, which runs parallel batched simulation on device. + +We can create a VMAS environment and look at what the output from a random step looks like: + +.. code-block:: + :caption: Example of multi-agent step tensordict + + >>> from torchrl.envs.libs.vmas import VmasEnv + >>> env = VmasEnv("balance", num_envs=3, n_agents=5) + >>> td = env.rand_step() + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 5, 2]))}, + batch_size=torch.Size([3, 5])), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + info: TensorDict( + fields={ + ground_rew: Tensor(shape=torch.Size([3, 5, 1])), + pos_rew: Tensor(shape=torch.Size([3, 5, 1]))}, + batch_size=torch.Size([3, 5])), + observation: Tensor(shape=torch.Size([3, 5, 16])), + reward: Tensor(shape=torch.Size([3, 5, 1]))}, + batch_size=torch.Size([3, 5])), + done: Tensor(shape=torch.Size([3, 1]))}, + batch_size=torch.Size([3]))}, + batch_size=torch.Size([3])) + +We can observe that *keys that are shared by all agents*, such as **done** are present in the root tensordict with +batch size `(num_envs,)`, which represents the number of environments simulated. + +On the other hand, *keys that are different between agents*, such as **action**, **reward**, **observation**, +and **info** are present in the nested "agents" tensordict with batch size `(num_envs, n_agents)`, +which represents the additional agent dimension. + +Multi-agent tensor specs will follow the same style as in tensordicts. +Specs relating to values that vary between agents will need to be nested in the "agents" entry. + +Here is an example of how specs can be created in a multi-agent environment where +only the done flag is shared across agents (as in VMAS): + +.. code-block:: + :caption: Example of multi-agent spec creation + + >>> action_specs = [] + >>> observation_specs = [] + >>> reward_specs = [] + >>> info_specs = [] + >>> for i in range(env.n_agents): + ... action_specs.append(agent_i_action_spec) + ... reward_specs.append(agent_i_reward_spec) + ... observation_specs.append(agent_i_observation_spec) + >>> env.action_spec = Composite( + ... { + ... "agents": Composite( + ... {"action": torch.stack(action_specs)}, shape=(env.n_agents,) + ... ) + ... } + ...) + >>> env.reward_spec = Composite( + ... { + ... "agents": Composite( + ... {"reward": torch.stack(reward_specs)}, shape=(env.n_agents,) + ... ) + ... } + ...) + >>> env.observation_spec = Composite( + ... { + ... "agents": Composite( + ... {"observation": torch.stack(observation_specs)}, shape=(env.n_agents,) + ... ) + ... } + ...) + >>> env.done_spec = Categorical( + ... n=2, + ... shape=torch.Size((1,)), + ... dtype=torch.bool, + ... ) + +As you can see, it is very simple! Per-agent keys will have the nested composite spec and shared keys will follow +single agent standards. + +.. note:: + Since reward, done and action keys may have the additional "agent" prefix (e.g., `("agents","action")`), + the default keys used in the arguments of other TorchRL components (e.g. "action") will not match exactly. 
+ Therefore, TorchRL provides the `env.action_key`, `env.reward_key`, and `env.done_key` attributes, + which will automatically point to the right key to use. Make sure you pass these attributes to the various + components in TorchRL to inform them of the right key (e.g., the `loss.set_keys()` function). + +.. note:: + TorchRL abstracts these nested specs away for ease of use. + This means that accessing `env.reward_spec` will always return the leaf + spec if the accessed spec is Composite. Therefore, if in the example above + we run `env.reward_spec` after env creation, we would get the same output as `torch.stack(reward_specs)}`. + To get the full composite spec with the "agents" key, you can run + `env.output_spec["full_reward_spec"]`. The same is valid for action and done specs. + Note that `env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]`. + + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + MarlGroupMapType + check_marl_grouping diff --git a/docs/source/reference/envs_recorders.rst b/docs/source/reference/envs_recorders.rst new file mode 100644 index 00000000000..143a8f19be2 --- /dev/null +++ b/docs/source/reference/envs_recorders.rst @@ -0,0 +1,83 @@ +.. currentmodule:: torchrl.record + +.. _Environment-Recorders: + +Recorders +========= + +Recording data during environment rollout execution is crucial to keep an eye on the algorithm performance as well as +reporting results after training. + +TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable +can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function will be called upon the collected +tensordict at each iteration of the rollout (if some iterations have to be skipped, an internal variable should be added +to keep track of the call count within ``callback``). + +To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used. + +Recording videos +---------------- + +Several backends offer the possibility of recording rendered images from the environment. +If the pixels are already part of the environment output (e.g. Atari or other game simulators), a +:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input +a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger` +or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved. +For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a `video_format="mp4"` +argument. + +The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch +formatted images (WHC or CWH). + + >>> logger = CSVLogger("dummy-exp", video_format="mp4") + >>> env = GymEnv("ALE/Pong-v5") + >>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"])) + >>> env.rollout(10) + >>> env.transform.dump() # Save the video and clear cache + +Note that the cache of the transform will keep on growing until dump is called. It is the user responsibility to +take care of calling `dump` when needed to avoid OOM issues. + +In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible +(some libraries only allow one environment instance per workspace). 
+In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform`
+can be used to call `render` on the parent environment and save the images in the rollout data stream.
+This class works over single and batched environments alike:
+
+    >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
+    >>> from torchrl.record.loggers import CSVLogger
+    >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
+    >>>
+    >>> def make_env():
+    >>>     env = GymEnv("CartPole-v1", render_mode="rgb_array")
+    >>>     # Uncomment this line to execute per-env
+    >>>     # env = env.append_transform(PixelRenderTransform())
+    >>>     return env
+    >>>
+    >>> if __name__ == "__main__":
+    ...     logger = CSVLogger("dummy", video_format="mp4")
+    ...
+    ...     env = ParallelEnv(16, EnvCreator(make_env))
+    ...     env.start()
+    ...     # Comment this line to execute per-env
+    ...     env = env.append_transform(PixelRenderTransform())
+    ...
+    ...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
+    ...     env.rollout(3)
+    ...
+    ...     check_env_specs(env)
+    ...
+    ...     r = env.rollout(30)
+    ...     env.transform.dump()
+    ...     env.close()
+
+
+Recorders are transforms that register data as they come in, for logging purposes.
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_fun.rst
+
+    TensorDictRecorder
+    VideoRecorder
+    PixelRenderTransform
diff --git a/docs/source/reference/envs_transforms.rst b/docs/source/reference/envs_transforms.rst
new file mode 100644
index 00000000000..d68d3e115f4
--- /dev/null
+++ b/docs/source/reference/envs_transforms.rst
@@ -0,0 +1,359 @@
+.. currentmodule:: torchrl.envs.transforms
+
+.. _transforms:
+
+Transforms
+==========
+
+In most cases, the raw output of an environment must be treated before being passed to another object (such as a
+policy or a value operator). To do this, TorchRL provides a set of transforms that aim at reproducing the transform
+logic of `torch.distributions.Transform` and `torchvision.transforms`.
+Our environment :ref:`tutorial `
+provides more information on how to design a custom transform.
+
+Transformed environments are built using the :class:`TransformedEnv` primitive.
+Composed transforms are built using the :class:`Compose` class:
+
+.. code-block::
+    :caption: Transformed environment
+
+    >>> base_env = GymEnv("Pendulum-v1", from_pixels=True, device="cuda:0")
+    >>> transform = Compose(ToTensorImage(in_keys=["pixels"]), Resize(64, 64, in_keys=["pixels"]))
+    >>> env = TransformedEnv(base_env, transform)
+
+Transforms are usually subclasses of :class:`~torchrl.envs.transforms.Transform`, although any
+``Callable[[TensorDictBase], TensorDictBase]`` is accepted.
+
+By default, the transformed environment will inherit the device of the
+``base_env`` that is passed to it. The transforms will then be executed on that device.
+It is now apparent that this can bring a significant speedup depending on the kind of
+operations that are to be computed.
+
+A great advantage of environment wrappers is that one can consult the environment up to that wrapper.
+The same can be achieved with TorchRL transformed environments: the ``parent`` attribute will
+return a new :class:`TransformedEnv` with all the transforms up to the transform of interest.
+Re-using the example above:
+
+.. 
code-block:: + :caption: Transform parent + + >>> resize_parent = env.transform[-1].parent # returns the same as TransformedEnv(base_env, transform[:-1]) + + +Transformed environment can be used with vectorized environments. +Since each transform uses a ``"in_keys"``/``"out_keys"`` set of keyword argument, it is +also easy to root the transform graph to each component of the observation data (e.g. +pixels or states etc). + +Forward and inverse transforms +------------------------------ + +Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called before the action is applied in reverse +order over the composed transform chain. This allows applying transforms to data in the environment before the action is +taken in the environment. The keys to be included in this inverse transform are passed through the `"in_keys_inv"` +keyword argument, and the out-keys default to these values in most cases: + +.. code-block:: + :caption: Inverse transform + + >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step + +The following paragraphs detail how one can think about what is to be considered `in_` or `out_` features. + +Understanding Transform Keys +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In transforms, `in_keys` and `out_keys` define the interaction between the base environment and the outside world +(e.g., your policy): + +- `in_keys` refers to the base environment's perspective (inner = `base_env` of the + :class:`~torchrl.envs.TransformedEnv`). +- `out_keys` refers to the outside world (outer = `policy`, `agent`, etc.). + +For example, with `in_keys=["obs"]` and `out_keys=["obs_standardized"]`, the policy will "see" a standardized +observation, while the base environment outputs a regular observation. + +Similarly, for inverse keys: + +- `in_keys_inv` refers to entries as seen by the base environment. +- `out_keys_inv` refers to entries as seen or produced by the policy. + +The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input +`TensorDict` of the `step` function must include the `out_keys_inv` as they are part of the outside world. The +transform changes these names to match the names of the inner, base environment using the `in_keys_inv`. +The inverse process is executed with the output tensordict, where the `in_keys` are mapped to the corresponding +`out_keys`. + +.. figure:: /_static/img/rename_transform.png + + Rename transform logic + +.. note:: During a call to `inv`, the transforms are executed in reversed order (compared to the forward / step mode). + +Transforming Tensors and Specs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When transforming actual tensors (coming from the policy), the process is schematically represented as: + + >>> for t in reversed(self.transform): + ... td = t.inv(td) + +This starts with the outermost transform to the innermost transform, ensuring the action value exposed to the policy +is properly transformed. + +For transforming the action spec, the process should go from innermost to outermost (similar to observation specs): + + >>> def transform_action_spec(self, action_spec): + ... for t in self.transform: + ... action_spec = t.transform_action_spec(action_spec) + ... return action_spec + +A pseudocode for a single transform_action_spec could be: + + >>> def transform_action_spec(self, action_spec): + ... 
return spec_from_random_values(self._apply_transform(action_spec.rand()))
+
+This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we did not call
+`_inv_apply_transform` but `_apply_transform` on purpose!
+
+Exposing Specs to the Outside World
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+`TransformedEnv` will expose the specs corresponding to the `out_keys_inv` for actions and states.
+For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., `"action"`) is a float-valued
+tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed
+environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the
+transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`.
+
+Designing your own Transform
+----------------------------
+
+To create a basic, custom transform, you need to subclass the `Transform` class and implement the
+:meth:`~torchrl.envs.Transform._apply_transform` method. Here's an example of a simple transform that adds 1 to the observation
+tensor:
+
+    >>> class AddOneToObs(Transform):
+    ...     """A transform that adds 1 to the observation tensor."""
+    ...
+    ...     def __init__(self):
+    ...         super().__init__(in_keys=["observation"], out_keys=["observation"])
+    ...
+    ...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
+    ...         return obs + 1
+
+
+Tips for subclassing `Transform`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There are various ways of subclassing a transform. The things to take into consideration are:
+
+- Is the transform identical for each tensor / item being transformed? Use
+  :meth:`~torchrl.envs.Transform._apply_transform` and :meth:`~torchrl.envs.Transform._inv_apply_transform`.
+- Does the transform need access to the input data of ``env.step`` as well as its output? Rewrite
+  :meth:`~torchrl.envs.Transform._step`.
+  Otherwise, rewrite :meth:`~torchrl.envs.Transform._call` (or :meth:`~torchrl.envs.Transform._inv_call`).
+- Is the transform to be used within a replay buffer? Overwrite :meth:`~torchrl.envs.Transform.forward`,
+  :meth:`~torchrl.envs.Transform.inv`, :meth:`~torchrl.envs.Transform._apply_transform` or
+  :meth:`~torchrl.envs.Transform._inv_apply_transform`.
+- Within a transform, you can access (and make calls to) the parent environment using
+  :attr:`~torchrl.envs.Transform.parent` (the base env + all transforms till this one) or
+  :meth:`~torchrl.envs.Transform.container` (the object that encapsulates the transform).
+- Don't forget to edit the specs if needed: top level: :meth:`~torchrl.envs.Transform.transform_output_spec`,
+  :meth:`~torchrl.envs.Transform.transform_input_spec`.
+  Leaf level: :meth:`~torchrl.envs.Transform.transform_observation_spec`,
+  :meth:`~torchrl.envs.Transform.transform_action_spec`, :meth:`~torchrl.envs.Transform.transform_state_spec`,
+  :meth:`~torchrl.envs.Transform.transform_reward_spec` and
+  :meth:`~torchrl.envs.Transform.transform_done_spec`.
+
+For practical examples, see the methods listed above and the sketch below.
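+
+As an illustration of the spec-editing tip above, here is a minimal sketch of a transform that writes the
+observation to a new entry and registers the corresponding spec (the class name and keys are made up for this
+example; for real dtype conversions, prefer :class:`~torchrl.envs.DoubleToFloat` or
+:class:`~torchrl.envs.DTypeCastTransform`):
+
+    >>> from torchrl.data import Composite, Unbounded
+    >>>
+    >>> class CastObsToFloat(Transform):
+    ...     """A sketch: copy the observation to a new float32 entry and expose its spec."""
+    ...     def __init__(self):
+    ...         super().__init__(in_keys=["observation"], out_keys=["observation_f32"])
+    ...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
+    ...         return obs.to(torch.float32)
+    ...     def transform_observation_spec(self, observation_spec: Composite) -> Composite:
+    ...         # Register the new entry so that the transformed env exposes a valid spec
+    ...         observation_spec["observation_f32"] = Unbounded(
+    ...             shape=observation_spec["observation"].shape,
+    ...             dtype=torch.float32,
+    ...             device=observation_spec.device,
+    ...         )
+    ...         return observation_spec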
+
+You can use a transform in an environment by passing it to the TransformedEnv constructor:
+
+    >>> env = TransformedEnv(GymEnv("Pendulum-v1"), AddOneToObs())
+
+You can compose multiple transforms together using the Compose class:
+
+    >>> transform = Compose(AddOneToObs(), RewardSum())
+    >>> env = TransformedEnv(GymEnv("Pendulum-v1"), transform)
+
+Inverse Transforms
+~~~~~~~~~~~~~~~~~~
+
+Some transforms also have an inverse transform that is applied to the input data (typically the action) before it is
+passed to the base environment. For example, the AddOneToAction transform adds 1 to the action tensor before the
+base environment receives it:
+
+    >>> class AddOneToAction(Transform):
+    ...     """A transform that adds 1 to the action tensor."""
+    ...     def __init__(self):
+    ...         super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"])
+    ...     def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
+    ...         return action + 1
+
+Using a Transform with a Replay Buffer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can use a transform with a replay buffer by passing it to the ReplayBuffer constructor:
+
+    >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
+    >>> buffer = ReplayBuffer(storage=LazyTensorStorage(1000), transform=AddOneToObs())
+
+Cloning transforms
+~~~~~~~~~~~~~~~~~~
+
+Because transforms appended to an environment are "registered" to this environment
+through the ``transform.parent`` property, when manipulating transforms we should keep
+in mind that the parent may come and go following what is being done with the transform.
+Here are some examples: if we get a single transform from a :class:`Compose` object,
+this transform will keep its parent:
+
+    >>> third_transform = env.transform[2]
+    >>> assert third_transform.parent is not None
+
+This means that using this transform for another environment is prohibited, as
+the other environment would replace the parent and this may lead to unexpected
+behaviours. Fortunately, the :class:`Transform` class comes with a :func:`clone`
+method that will erase the parent while keeping the identity of all the
+registered buffers:
+
+    >>> TransformedEnv(base_env, third_transform)  # raises an Exception as third_transform already has a parent
+    >>> TransformedEnv(base_env, third_transform.clone())  # works
+
+On a single process or if the buffers are placed in shared memory, this will
+result in all the cloned transforms keeping the same behavior even if the
+buffers are changed in place (which is what will happen with the :class:`CatFrames`
+transform, for instance). In distributed settings, this may not hold and one
+should be careful about the expected behavior of the cloned transforms in this
+context.
+Finally, notice that indexing multiple transforms from a :class:`Compose` transform
+may also result in loss of parenthood for these transforms: the reason is that
+indexing a :class:`Compose` transform results in another :class:`Compose` transform
+that does not have a parent environment. Hence, we have to clone the sub-transforms
+to be able to create this other composition:
+
+    >>> env = TransformedEnv(base_env, Compose(transform1, transform2, transform3))
+    >>> last_two = env.transform[-2:]
+    >>> assert isinstance(last_two, Compose)
+    >>> assert last_two.parent is None
+    >>> assert last_two[0] is not transform2
+    >>> assert isinstance(last_two[0], type(transform2))  # and the buffers will match
+    >>> assert last_two[1] is not transform3
+    >>> assert isinstance(last_two[1], type(transform3))  # and the buffers will match
+
+Available Transforms
+--------------------
+
+.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + Transform + TransformedEnv + ActionDiscretizer + ActionMask + AutoResetEnv + AutoResetTransform + BatchSizeTransform + BinarizeReward + BurnInTransform + CatFrames + CatTensors + CenterCrop + ClipTransform + Compose + ConditionalPolicySwitch + ConditionalSkip + Crop + DTypeCastTransform + DeviceCastTransform + DiscreteActionProjection + DoubleToFloat + EndOfLifeTransform + ExcludeTransform + FiniteTensorDictCheck + FlattenObservation + FrameSkipTransform + GrayScale + Hash + InitTracker + KLRewardTransform + LineariseRewards + ModuleTransform + MultiAction + NoopResetEnv + ObservationNorm + ObservationTransform + PermuteTransform + PinMemoryTransform + R3MTransform + RandomCropTensorDict + RemoveEmptySpecs + RenameTransform + Resize + Reward2GoTransform + RewardClipping + RewardScaling + RewardSum + SelectTransform + SignTransform + SqueezeTransform + Stack + StepCounter + TargetReturn + TensorDictPrimer + TimeMaxPool + Timer + Tokenizer + ToTensorImage + TrajCounter + UnaryTransform + UnsqueezeTransform + VC1Transform + VIPRewardTransform + VIPTransform + VecGymEnvTransform + VecNorm + VecNormV2 + gSDENoise + +Environments with masked actions +-------------------------------- + +In some environments with discrete actions, the actions available to the agent might change throughout execution. +In such cases the environments will output an action mask (under the ``"action_mask"`` key by default). +This mask needs to be used to filter out unavailable actions for that step. + +If you are using a custom policy you can pass this mask to your probability distribution like so: + +.. code-block:: + :caption: Categorical policy with action mask + + >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential + >>> import torch.nn as nn + >>> from torchrl.modules import MaskedCategorical + >>> module = TensorDictModule( + >>> nn.Linear(in_feats, out_feats), + >>> in_keys=["observation"], + >>> out_keys=["logits"], + >>> ) + >>> dist = ProbabilisticTensorDictModule( + >>> in_keys={"logits": "logits", "mask": "action_mask"}, + >>> out_keys=["action"], + >>> distribution_class=MaskedCategorical, + >>> ) + >>> actor = TensorDictSequential(module, dist) + +If you want to use a default policy, you will need to wrap your environment in the :class:`~torchrl.envs.transforms.ActionMask` +transform. This transform can take care of updating the action mask in the action spec in order for the default policy +to always know what the latest available actions are. You can do this like so: + +.. code-block:: + :caption: How to use the action mask transform + + >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential + >>> import torch.nn as nn + >>> from torchrl.envs.transforms import TransformedEnv, ActionMask + >>> env = TransformedEnv( + >>> your_base_env + >>> ActionMask(action_key="action", mask_key="action_mask"), + >>> ) + +.. note:: + In case you are using a parallel environment it is important to add the transform to the parallel environment itself + and not to its sub-environments. diff --git a/docs/source/reference/envs_vectorized.rst b/docs/source/reference/envs_vectorized.rst new file mode 100644 index 00000000000..29a7d67e7a2 --- /dev/null +++ b/docs/source/reference/envs_vectorized.rst @@ -0,0 +1,351 @@ +.. 
currentmodule:: torchrl.envs
+
+Vectorized and Parallel Environments
+====================================
+
+Vectorized (or better: parallel) environments are a common feature in Reinforcement Learning
+where executing the environment step can be CPU-intensive.
+Some libraries such as `gym3 `_ or `EnvPool `_
+offer interfaces to execute batches of environments simultaneously.
+While they often offer a very competitive computational advantage, they do not
+necessarily scale to the wide variety of environment libraries supported by TorchRL.
+Therefore, TorchRL offers its own, generic :class:`ParallelEnv` class to run multiple
+environments in parallel.
+As this class inherits from :class:`SerialEnv`, it enjoys the exact same API as other environments.
+Of course, a :class:`ParallelEnv` will have a batch size that corresponds to its environment count:
+
+.. note::
+    Given the library's many optional dependencies (e.g., Gym, Gymnasium, and many others)
+    warnings can quickly become quite annoying in multiprocessed / distributed settings.
+    By default, TorchRL filters out these warnings in sub-processes. If one still wishes to
+    see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``.
+
+It is important that your environment specs match the input and output that it sends and receives, as
+:class:`ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
+Check the :func:`~torchrl.envs.utils.check_env_specs` method for a sanity check.
+
+.. code-block::
+    :caption: Parallel environment
+
+    >>> def make_env():
+    ...     return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0")
+    >>> check_env_specs(make_env())  # this must pass for ParallelEnv to work
+    >>> env = ParallelEnv(4, make_env)
+    >>> print(env.batch_size)
+    torch.Size([4])
+
+:class:`ParallelEnv` allows you to retrieve the attributes from its contained environments:
+one can simply call:
+
+.. code-block::
+    :caption: Parallel environment attributes
+
+    >>> a, b, c, d = env.g  # gets the g-force of the various envs, which we set to 9.81 before
+    >>> print(a)
+    9.81
+
+.. note::
+
+    *A note on performance*: launching a :class:`~.ParallelEnv` can take quite some time
+    as it requires launching as many Python instances as there are processes. Due to
+    the time that it takes to run ``import torch`` (and other imports), starting the
+    parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow.
+    Once the environment is launched, a great speedup should be observed.
+
+.. note::
+
+    *TorchRL requires precise specs*: Another thing to take into consideration is
+    that :class:`ParallelEnv` (as well as data collectors)
+    will create data buffers based on the environment specs to pass data from one process
+    to another. This means that a misspecified spec (input, observation or reward) will
+    cause a breakage at runtime as the data can't be written to the preallocated buffer.
+    In general, an environment should be tested using the :func:`~.utils.check_env_specs`
+    test function before being used in a :class:`ParallelEnv`. This function will raise
+    an assertion error whenever the preallocated buffer and the collected data mismatch.
+
+We also offer the :class:`~.SerialEnv` class that enjoys the exact same API but is executed
+serially. This is mostly useful for testing purposes, when one wants to assess the
+behavior of a :class:`~.ParallelEnv` without launching the subprocesses.
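+
+For instance, a :class:`~.SerialEnv` can be constructed and queried exactly like a :class:`~.ParallelEnv`;
+the sketch below (reusing the ``make_env`` helper from the example above) is purely illustrative:
+
+.. code-block::
+    :caption: Serial environment
+
+    >>> env = SerialEnv(4, make_env)  # same constructor signature as ParallelEnv
+    >>> print(env.batch_size)
+    torch.Size([4])
+    >>> rollout = env.rollout(3)      # sub-environments are stepped one after the other
+    >>> print(rollout.batch_size)     # [num_envs, num_steps]
+    torch.Size([4, 3])
+    >>> env.close()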
+
+In addition to :class:`~.ParallelEnv`, which offers process-based parallelism, we also provide a way to create
+multithreaded environments with :class:`~.MultiThreadedEnv`. This class uses the `EnvPool `_
+library underneath, which allows for higher performance but at the same time restricts flexibility: one can only
+create environments implemented in ``EnvPool``. This covers many popular RL environment types (Atari, Classic Control,
+etc.), but one cannot use an arbitrary TorchRL environment, as is possible with :class:`~.ParallelEnv`. Run
+`benchmarks/benchmark_batched_envs.py` to compare the performance of different ways to parallelize batched environments.
+
+Vectorized environment classes
+------------------------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    SerialEnv
+    ParallelEnv
+    EnvCreator
+
+Partial steps and partial resets
+--------------------------------
+
+TorchRL allows resetting some but not all of the environments in a batch, or running a step in some but not all of them.
+If there is only one environment in the batch, then a partial reset / step is also allowed with the behavior detailed
+below.
+
+Batching environments and locking the batch
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. _ref_batch_locked:
+
+Before detailing what partial resets and partial steps do, we must distinguish cases where an environment has
+a batch size of its own (mostly stateful environments) from cases where the environment is just a mere module that, given an
+input of arbitrary size, batches the operations over all elements (mostly stateless environments).
+
+This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input
+tensordicts to have the same batch-size as the env's. Typical examples of these environments are
+:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size.
+Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`.
+
+Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the
+tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous
+input.
+
+Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with
+partial steps easily: they simply pass the actions to the sub-environments that are required to be executed.
+
+In all other cases, TorchRL assumes that the environment handles the partial steps correctly.
+
+.. warning:: This means that custom environments may silently run the non-required steps as there is no way for TorchRL
+   to control what happens within the `_step` method!
+
+Partial Steps
+~~~~~~~~~~~~~
+
+.. _ref_partial_steps:
+
+Partial steps are controlled via the temporary key `"_step"` which points to a boolean mask of the
+size of the tensordict that holds it. The classes equipped to deal with this are:
+
+- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the
+  action to and only to the environments where `"_step"` is `True`;
+- Batch-unlocked environments;
+- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step`
+  method will first look for a `"_step"` entry and, if present, act accordingly.
+  If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by
+  :class:`~torchrl.envs.TransformedEnv`'s own `_step` method which will skip the `base_env.step` as well as any further
+  transformation.
+
+When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous
+content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that
+if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for
+all the non-stepped elements. Within batched environments, data collectors and rollout utils, this is an issue that
+is not observed because these classes handle the passing of data properly.
+
+Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`,
+as the environments with a `True` done state will need to be skipped during calls to `_step`.
+
+The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips.
+
+Partial Resets
+~~~~~~~~~~~~~~
+
+.. _ref_partial_resets:
+
+Partial resets work pretty much like partial steps, but with the `"_reset"` entry.
+
+The same restrictions of partial steps apply to partial resets.
+
+Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `True`,
+as the environments with a `True` done state will need to be reset, but not others.
+
+See the following paragraph for a deep dive into partial resets within batched and vectorized environments.
+
+Partial resets in detail
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+TorchRL uses a private ``"_reset"`` key to indicate to the environment which
+component (sub-environments or agents) should be reset.
+This allows resetting some but not all of the components.
+
+The ``"_reset"`` key has two distinct functionalities:
+
+1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
+   not be present in the input tensordict. TorchRL's convention is that the
+   absence of the ``"_reset"`` key at a given ``"done"`` level indicates
+   a total reset of that level (unless a ``"_reset"`` key was found at a level
+   above, see details below).
+   If it is present, it is expected that only those components
+   where the ``"_reset"`` entry is ``True`` (along key and shape dimension) will be reset.
+
+   The way an environment deals with the ``"_reset"`` keys in its :meth:`~.EnvBase._reset`
+   method is specific to its class.
+   Designing an environment that behaves according to ``"_reset"`` inputs is the
+   developer's responsibility, as TorchRL has no control over the inner logic
+   of :meth:`~.EnvBase._reset`. Nevertheless, the following point should be
+   kept in mind when designing that method.
+
+2. After a call to :meth:`~.EnvBase._reset`, the output will be masked with the
+   ``"_reset"`` entries and the output of the previous :meth:`~.EnvBase.step`
+   will be written wherever the ``"_reset"`` was ``False``. In practice, this
+   means that if a ``"_reset"`` modifies data that isn't exposed by it, this
+   modification will be lost. After this masking operation, the ``"_reset"``
+   entries will be erased from the :meth:`~.EnvBase.reset` outputs.
+
+It must be pointed out that ``"_reset"`` is a private key, and it should only be
+used when coding specific environment features that are internal facing.
+In other words, this should NOT be used outside of the library, and developers
+will keep the right to modify the logic of partial resets through ``"_reset"``
+settings without prior notice, as long as they don't affect TorchRL
+internal tests.
+
+Finally, the following assumptions are made and should be kept in mind when
+designing reset functionalities:
+
+- Each ``"_reset"`` is paired with a ``"done"`` entry (+ ``"terminated"`` and,
+  possibly, ``"truncated"``). This means that the following structure is not
+  allowed: ``TensorDict({"done": done, "nested": {"_reset": reset}}, [])``, as
+  the ``"_reset"`` lives at a different nesting level than the ``"done"``.
+- A reset at one level does not preclude the presence of a ``"_reset"`` at lower
+  levels, but it annihilates its effects. The reason is simply that
+  whether the ``"_reset"`` at the root level corresponds to an ``all()``, ``any()``
+  or custom call to the nested ``"done"`` entries cannot be known in advance,
+  and it is explicitly assumed that the ``"_reset"`` at the root was placed
+  there to supersede the nested values (for an example, have a look at the
+  :class:`~.PettingZooWrapper` implementation where each group has one or more
+  ``"done"`` entries associated with it, which are aggregated at the root level with an
+  ``any`` or ``all`` logic depending on the task).
+- When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry
+  that will reset some but not all the done sub-environments, the input data
+  should contain the data of the sub-environments that are __not__ being reset.
+  The reason for this constraint lies in the fact that the output of the
+  ``env._reset(data)`` can only be predicted for the entries that are reset.
+  For the others, TorchRL cannot know in advance if they will be meaningful or
+  not. For instance, one could perfectly well just pad the values of the non-reset
+  components, in which case the non-reset data will be meaningless and should
+  be discarded.
+
+Below, we give some examples of the expected effect that ``"_reset"`` keys will
+have on an environment returning zeros after reset:
+
+    >>> # single reset at the root
+    >>> data = TensorDict({"val": [1, 1], "_reset": [False, True]}, [])
+    >>> env.reset(data)
+    >>> print(data.get("val"))  # only the second value is 0
+    tensor([1, 0])
+    >>> # nested resets
+    >>> data = TensorDict({
+    ...     ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
+    ...     ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
+    ... }, [])
+    >>> env.reset(data)
+    >>> print(data.get(("agent0", "val")))  # only the second value is 0
+    tensor([1, 0])
+    >>> print(data.get(("agent1", "val")))  # only the first value is 0
+    tensor([0, 2])
+    >>> # nested resets are overridden by a "_reset" at the root
+    >>> data = TensorDict({
+    ...     "_reset": [True, True],
+    ...     ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
+    ...     ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
+    ... }, [])
+    >>> env.reset(data)
+    >>> print(data.get(("agent0", "val")))  # reset at the root overrides nested
+    tensor([0, 0])
+    >>> print(data.get(("agent1", "val")))  # reset at the root overrides nested
+    tensor([0, 0])
+
+.. code-block::
+    :caption: Parallel environment reset
+
+    >>> tensordict = TensorDict({"_reset": [[True], [False], [True], [True]]}, [4])
+    >>> env.reset(tensordict)  # eliminates the "_reset" entry
+    TensorDict(
+        fields={
+            terminated: Tensor(torch.Size([4, 1]), dtype=torch.bool),
+            done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
+            pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
+            truncated: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
+        batch_size=torch.Size([4]),
+        device=None,
+        is_shared=True)
+
+Async environments
+------------------
+
+Asynchronous environments allow for parallel execution of multiple environments, which can significantly speed up the
+data collection process in reinforcement learning.
+
+The `AsyncEnvPool` class and its subclasses provide a flexible interface for managing these environments using different
+backends, such as threading and multiprocessing.
+
+The `AsyncEnvPool` class serves as a base class for asynchronous environment pools, providing a common interface for
+managing multiple environments concurrently. It supports different backends for parallel execution, such as threading
+and multiprocessing, and provides methods for asynchronous stepping and resetting of environments.
+
+Contrary to :class:`~torchrl.envs.ParallelEnv`, :class:`~torchrl.envs.AsyncEnvPool` and its subclasses permit the
+execution of a given set of sub-environments while another task is being performed, allowing for complex asynchronous jobs
+to be run at the same time. For instance, it is possible to execute some environments while the policy is running based on
+the output of others.
+
+This family of classes is particularly interesting when dealing with environments that have a high (and/or variable)
+latency.
+
+.. note:: This class and its subclasses should work when nested in :class:`~torchrl.envs.TransformedEnv` and
+   batched environments, but users won't currently be able to use the async features of the base environment when
+   it's nested in these classes. One should prefer nested transformed envs within an `AsyncEnvPool` instead.
+   If this is not possible, please raise an issue.
+
+Classes
+~~~~~~~
+
+- :class:`~torchrl.envs.AsyncEnvPool`: A base class for asynchronous environment pools. It determines the backend
+  implementation to use based on the provided arguments and manages the lifecycle of the environments.
+- :class:`~torchrl.envs.ProcessorAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using
+  multiprocessing for parallel execution of environments. This class manages a pool of environments, each running in
+  its own process, and provides methods for asynchronous stepping and resetting of environments using inter-process
+  communication. It is automatically instantiated when `"multiprocessing"` is passed as a backend during the
+  :class:`~torchrl.envs.AsyncEnvPool` instantiation.
+- :class:`~torchrl.envs.ThreadingAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using
+  threading for parallel execution of environments. This class manages a pool of environments, each running in its own
+  thread, and provides methods for asynchronous stepping and resetting of environments using a thread pool executor.
+  It is automatically instantiated when `"threading"` is passed as a backend during the
+  :class:`~torchrl.envs.AsyncEnvPool` instantiation.
+ +Example +~~~~~~~ + + >>> from functools import partial + >>> from torchrl.envs import AsyncEnvPool, GymEnv + >>> import torch + >>> # Choose backend + >>> backend = "threading" + >>> env = AsyncEnvPool( + >>> [partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "CartPole-v1")], + >>> stack="lazy", + >>> backend=backend + >>> ) + >>> # Execute a synchronous reset + >>> reset = env.reset() + >>> print(reset) + >>> # Execute a synchronous step + >>> s = env.rand_step(reset) + >>> print(s) + >>> # Execute an asynchronous step in env 0 + >>> s0 = s[0] + >>> s0["action"] = torch.randn(1).clamp(-1, 1) + >>> s0["env_index"] = 0 + >>> env.async_step_send(s0) + >>> # Receive data + >>> s0_result = env.async_step_recv() + >>> print('result', s0_result) + >>> # Close env + >>> env.close() + + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + AsyncEnvPool + ProcessorAsyncEnvPool + ThreadingAsyncEnvPool diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst index 3a04bc5c95c..bdaba6ff15d 100644 --- a/docs/source/reference/llms.rst +++ b/docs/source/reference/llms.rst @@ -6,1187 +6,48 @@ LLM Interface .. _ref_llms: TorchRL provides a comprehensive framework for LLM post-training and fine-tuning. The LLM API is built around five core concepts that work -together to create a complete reinforcement learning pipeline for language models: +together to create a complete reinforcement learning pipeline for language models. -1. **Data Representation** (`Data Structures`_): The foundation for handling conversations, text parsing, and LLM - output classes. This includes the :class:`~torchrl.data.llm.History` class for managing conversation context and structured output classes for - tokens, log-probabilities, and text. +Key Components +-------------- -2. **LLM Wrapper API** (`Modules`_): Unified interfaces for different LLM backends, including :class:`~torchrl.modules.llm.TransformersWrapper` for - Hugging Face models, :class:`~torchrl.modules.llm.vLLMWrapper` for vLLM inference, and :class:`~torchrl.modules.llm.AsyncVLLM` for high-performance - distributed vLLM inference (recommended). These wrappers provide consistent input/output formats across different backends and an integrated - interface for loss computation, data storage, grading, weight synchronization, etc. +1. **Data Structures**: History class for conversation management, structured output classes +2. **LLM Wrappers**: Unified interfaces for Transformers, vLLM, and AsyncVLLM +3. **Environments**: ChatEnv, task-specific environments, and transforms +4. **Collectors**: LLMCollector and RayLLMCollector for data collection +5. **Objectives**: GRPOLoss, SFTLoss for training -3. **Environments** (`Environments`_): The orchestration layer that manages data loading, tool execution, reward computation, and formatting. This includes - :class:`~torchrl.envs.llm.ChatEnv` for conversation management, dataset environments, and various transforms for tool integration. - -4. **Objectives** (`Objectives`_): Specialized loss functions for LLM training, including :class:`~torchrl.objectives.llm.GRPOLoss` for Group Relative - Policy Optimization and :class:`~torchrl.objectives.llm.SFTLoss` for supervised fine-tuning. - -5. **Collectors** (`Collectors`_): Collectors are used to collect data from the environment and store it in a format that can be used for training. 
This includes - :class:`~torchrl.collectors.llm.LLMCollector` for collecting data from the environment and :class:`~torchrl.collectors.llm.RayLLMCollector` for collecting - data in distributed settings using Ray. - -These components work together to create a complete pipeline: environments load and format data, LLM wrappers handle inference, data structures maintain -conversation context, and objectives compute training losses. The modular design allows you to mix and match components based on your specific use case. - -A complete example of how to use the LLM API can be found in the `sota-implementations/grpo/` directory. The training orchestration involves three main components: - -- The Data Collector: holds a reference to the environment and the inference model or engine. It collects data, puts it in the buffer, and handles weight updates. -- The Replay Buffer: stores the collected data and executes any pre or post-processing steps. These may include: - - Advantage estimation with Monte-Carlo based method (using the :class:`~torchrl.objectives.llm.MCAdvantage` transform); - - Grading of the outputs; - - Logging etc. -- The trainer: handles the training loop, including the optimization step, serialization, logging and weight updates initialization. - -.. warning:: The LLM API is still under development and may change in the future. Feedback, issues and PRs are welcome! - -Data Structures ---------------- - -The data representation layer provides the foundation for handling conversations and LLM outputs in a structured way. - -History Class -~~~~~~~~~~~~~ - -The :class:`~torchrl.data.llm.History` class is a TensorClass version of the chat format usually found in transformers -(see `Hugging Face chat documentation `_). -It provides a comprehensive API for managing conversation data with features including: - -- **Text parsing and formatting**: Convert between text and structured conversation format using :meth:`~torchrl.data.llm.chat.History.from_text` - and :meth:`~torchrl.data.llm.chat.History.apply_chat_template` -- **Dynamic conversation building**: Append and extend conversations with :meth:`~torchrl.data.llm.chat.History.append` and - :meth:`~torchrl.data.llm.chat.History.extend` methods -- **Multi-model support**: Automatic template detection for various model families (Qwen, DialoGPT, Falcon, DeepSeek, etc.) -- **Assistant token masking**: Identify which tokens were generated by the assistant for reinforcement learning applications -- **Tool calling support**: Handle function calls and tool responses in conversations -- **Batch operations**: Efficient tensor operations for processing multiple conversations simultaneously. - -.. currentmodule:: torchrl.data.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - History - ContentBase - -Supported Model Families -^^^^^^^^^^^^^^^^^^^^^^^^ - -We currently support the following model families for string to History parsing or assistant token masking: - -- **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support -- **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format -- **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format -- **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format - -Other models are supported, but you will need to provide a custom template for them. -LLAMA, Mistral, OPT, GPT, MPT, BLOOM, Pythia, Phi, etc. 
will use the default `chatml_format` template. - -Usage -^^^^^ - -.. code-block:: python - - >>> from torchrl.data.llm.chat import History - >>> from transformers import AutoTokenizer - >>> - >>> # Create a conversation history - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"}, - ... {"role": "user", "content": "How are you?"}, - ... {"role": "assistant", "content": "I'm doing well, thanks!"} - ... ]]) - >>> - >>> # Load any supported tokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") - >>> - >>> # Apply chat template with assistant token masking - >>> result = history.apply_chat_template( - ... chat_template_name="qwen", - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - >>> - >>> # The result contains an assistant_masks tensor - >>> assistant_masks = result["assistant_masks"] - >>> print(f"Assistant tokens: {assistant_masks.sum().item()}") - -Adding Custom Templates -^^^^^^^^^^^^^^^^^^^^^^^ - -You can add custom chat templates for new model families using the :func:`torchrl.data.llm.add_chat_template` function. - -.. autofunction:: torchrl.data.llm.add_chat_template - -Usage Examples -^^^^^^^^^^^^^^ - -Adding a Llama Template -""""""""""""""""""""""" - -.. code-block:: python - - >>> from torchrl.data.llm import add_chat_template, History - >>> from transformers import AutoTokenizer - >>> - >>> # Define the Llama chat template - >>> llama_template = ''' - ... {% for message in messages %} - ... {%- if message['role'] == 'user' %} - ... {{ '[INST] ' + message['content'] + ' [/INST]' }} - ... {%- elif message['role'] == 'assistant' %} - ... {% generation %}{{ message['content'] + '' }}{% endgeneration %} - ... {%- endif %} - ... {% endfor %} - ... {%- if add_generation_prompt %} - ... {% generation %}{{ ' ' }}{% endgeneration %} - ... {%- endif %} - ... ''' - >>> - >>> # Define the inverse parser for Llama format - >>> def parse_llama_text(text: str) -> History: - ... import re - ... pattern = r'\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)' - ... matches = re.findall(pattern, text, re.DOTALL) - ... messages = [] - ... for user_content, assistant_content in matches: - ... messages.append(History(role="user", content=user_content.strip())) - ... messages.append(History(role="assistant", content=assistant_content.strip())) - ... return lazy_stack(messages) - >>> - >>> # Add the template with auto-detection - >>> add_chat_template( - ... template_name="llama", - ... template=llama_template, - ... inverse_parser=parse_llama_text, - ... model_family_keywords=["llama", "meta-llama"] - ... ) - >>> - >>> # Now you can use it with auto-detection - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"} - ... ]]) - >>> - >>> # Auto-detection will use the llama template - >>> result = history.apply_chat_template( - ... tokenizer=tokenizer, - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - -Testing Your Custom Templates -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When adding custom templates, you should test them to ensure they work correctly. Here are the recommended tests: - -Assistant Token Masking Test -"""""""""""""""""""""""""""" - -Test that your template supports assistant token masking: +Quick Example +------------- .. 
code-block:: python - import pytest - from torchrl.data.llm.chat import History, add_chat_template - from transformers import AutoTokenizer - - def test_my_model_assistant_masking(): - """Test that your model supports assistant token masking.""" - # Add your template first - add_chat_template( - template_name="my_model", - template="your_template_here", - model_family_keywords=["my_model"] - ) - - tokenizer = AutoTokenizer.from_pretrained("your/model/name") - history = History.from_chats([[ - {'role': 'user', 'content': 'Hello'}, - {'role': 'assistant', 'content': 'Hi there!'} - ]]) - - result = history.apply_chat_template( - tokenizer=tokenizer, - chat_template_name="my_model", - add_generation_prompt=False, - return_dict=True, - return_assistant_tokens_mask=True, - ) - - # Verify assistant mask is present - assert 'assistant_masks' in result - assert result['assistant_masks'].shape[0] == 1, "Should have batch dimension of 1" - assert result['assistant_masks'].shape[1] > 0, "Should have sequence length > 0" - - # Verify some assistant tokens are masked - assistant_token_count = result['assistant_masks'].sum().item() - assert assistant_token_count > 0, "Should have assistant tokens masked" - print(f"✓ {assistant_token_count} assistant tokens masked") - -Template Equivalence Test -""""""""""""""""""""""""" - -Test that your custom template produces the same output as the model's default template (except for masking): - -.. code-block:: python - - def test_my_model_template_equivalence(): - """Test that your template matches the model's default template.""" - tokenizer = AutoTokenizer.from_pretrained("your/model/name") - history = History.from_chats([[ - {'role': 'user', 'content': 'Hello'}, - {'role': 'assistant', 'content': 'Hi there!'}, - {'role': 'user', 'content': 'How are you?'}, - {'role': 'assistant', 'content': 'I\'m good, thanks!'}, - ]]) - - # Get output with model's default template - try: - default_out = history.apply_chat_template( - tokenizer=tokenizer, - add_generation_prompt=False, - chat_template=tokenizer.chat_template, - tokenize=False, - ) - except Exception as e: - default_out = None - print(f"[WARN] Could not get default template: {e}") - - # Get output with your custom template - custom_out = history.apply_chat_template( - tokenizer=tokenizer, - add_generation_prompt=False, - chat_template_name="my_model", - tokenize=False, - ) - - if default_out is not None: - # Normalize whitespace for comparison - import re - def norm(s): - return re.sub(r"\s+", " ", s.strip()) - - assert norm(default_out) == norm(custom_out), ( - f"Custom template does not match default!\n" - f"Default: {default_out}\nCustom: {custom_out}" - ) - print("✓ Template equivalence verified") - else: - print("[INFO] Skipped equivalence check (no default template available)") - -Inverse Parsing Test -"""""""""""""""""""" - -If you provided an inverse parser, test that it works correctly: - -.. 
code-block:: python - - def test_my_model_inverse_parsing(): - """Test that your inverse parser works correctly.""" - history = History.from_chats([[ - {'role': 'user', 'content': 'Hello'}, - {'role': 'assistant', 'content': 'Hi there!'} - ]]) - - # Format using your template - formatted = history.apply_chat_template( - tokenizer=tokenizer, - chat_template_name="my_model", - add_generation_prompt=False, - tokenize=False, - ) - - # Parse back using your inverse parser - parsed = History.from_text(formatted, chat_template_name="my_model") - - # Verify the parsing worked - assert parsed.role == history.role - assert parsed.content == history.content - print("✓ Inverse parsing verified") - -LLM Wrapper API -~~~~~~~~~~~~~~~ - -The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent input/output formats across training and inference pipelines. The main wrappers are :class:`~torchrl.modules.llm.TransformersWrapper` for Hugging Face models and :class:`~torchrl.modules.llm.vLLMWrapper` for vLLM inference. - -**Data Structure Classes** - -The wrappers use structured :class:`~tensordict.TensorClass` objects to represent different aspects of LLM data: - -- **:class:`~torchrl.modules.llm.policies.Text`**: Contains text data with `prompt`, `response`, and `full` fields -- **:class:`~torchrl.modules.llm.policies.ChatHistory`**: Contains :class:`~torchrl.data.llm.History` objects with `prompt`, `response`, and `full` fields -- **:class:`~torchrl.modules.llm.policies.Tokens`**: Contains tokenized data with `prompt`, `response`, and `full` fields -- **:class:`~torchrl.modules.llm.policies.LogProbs`**: Contains log probabilities with `prompt`, `response`, and `full` fields -- **:class:`~torchrl.modules.llm.policies.Masks`**: Contains attention and assistant masks - -**API Flow** - -The wrappers operate in two distinct modes: - -**Generation Mode (`generate=True`)**: -- **Input**: Reads from `prompt` fields (e.g., `history.prompt`, `text.prompt`, `tokens.prompt`) -- **Output**: Writes to both `response` and `full` fields - - `response`: Contains only the newly generated content - - `full`: Contains the complete sequence (prompt + response) - -**Log-Probability Mode (`generate=False`)**: -- **Input**: Reads from `full` fields (e.g., `history.full`, `text.full`, `tokens.full`) -- **Output**: Writes log probabilities to the corresponding `full` fields - -**LLM-Environment Interaction Loop** - -.. figure:: /_static/img/llm-env.png - :alt: LLM-Environment interaction loop - :align: center - :width: 80% - - LLM-Environment interaction: the LLM generates a response, the environment updates the conversation, and transforms can inject new messages or tools. - -In a typical RL or tool-augmented setting, the LLM and environment interact in a loop: - -1. **LLM Generation**: The LLM wrapper receives a `prompt` (the current conversation history), generates a `response`, and outputs a `full` field - containing the concatenation of the prompt and response. -2. **Environment Step**: The environment takes the `full` field and makes it the next `prompt` for the LLM. This ensures that the conversation - context grows with each turn. See :ref:`ref_env_llm_step` for more details. -3. **Transforms**: Before the next LLM step, transforms can modify the conversation—for example, by inserting a new user message, a tool call, - or a reward annotation. -4. **Repeat**: This process repeats for as many turns as needed, enabling multi-turn dialogue, tool use, and RL training. 
- -This design allows for flexible augmentation of the conversation at each step, supporting advanced RL and tool-use scenarios. - -A typical pseudocode loop: - -.. code-block:: python - - # Get the first prompt out of an initial query - obs = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size, device=env.device)) - while not done: - # LLM generates a response given the current prompt - llm_output = llm(obs) - # Environment steps: creates a ("next", "history") field with the new prompt (from the previous `"full"` field) - obs = env.step(llm_output) - -**Integration with History** - -When using `input_mode="history"`, the wrapper integrates seamlessly with the :class:`~torchrl.data.llm.History` class: - -- **Input**: Takes a :class:`~torchrl.modules.llm.policies.ChatHistory` object containing a History in the `prompt` field -- **Generation**: Applies chat templates to convert History to tokens, generates response, then parses the full text back into a History object -- **Output**: Returns a ChatHistory with: - - `prompt`: Original conversation history - - `response`: New History object containing only the assistant's response - - `full`: Complete conversation history with the new response appended - -This design allows for natural conversation flow where each generation step extends the conversation history, making it ideal for multi-turn -dialogue systems. - - -Prompt vs. Response and padding -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. figure:: /_static/img/llm-data.svg - :alt: LLM output data format (Tokens, Masks, Padded vs. Sparse) - :align: center - :width: 80% - - Structure of LLM outputs: padded vs. sparse representations for Tokens, LogProbs, and Masks. - -The diagram above illustrates the structure of the main output classes used in TorchRL's LLM API: - -- **Tokens** (and by extension, **LogProbs**): - - *Padded* format: All sequences in a batch are padded to the same length (with a special pad token), making them suitable for tensor operations. The prompt and response are concatenated to form `tokens.full`, and masks indicate valid vs. padded positions. - - *Sparse* format: Each sequence retains its original length (no padding), represented as lists of tensors. This is more memory-efficient for variable-length data. -- **Masks**: Two main masks are shown: - - `mask.attention_mask_all` marks valid (non-pad) tokens. - - `mask.assistant_mask_all` marks which tokens were generated by the assistant (useful for RLHF and SFT training). -- **Text**: Not shown in detail, as it is simply the decoded string representation of the prompt, response, or full sequence. - -This format ensures that all LLM outputs (Tokens, LogProbs, Masks, Text) are consistent and easy to manipulate, regardless of whether you use padded or sparse batching. - -In general, we recommend working with unpadded data, as it is more memory-efficient and easier to manipulate. -For instance, when collecting multiple padded elements from the buffer, it may be hard to clearly understand how to re-pad them -to combine them in a cohesive batch. Working with unpadded data is more straightforward. - -Modules -------- - -The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent input/output formats across training and inference pipelines. 
- -Wrappers -~~~~~~~~ - -The main goal of these primitives is to: - -- Unify the input/output data format across training and inference pipelines -- Unify the input/output data format across backends (to be able to use different backends across losses and collectors) -- Provide appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution, weight update, etc.) - -.. currentmodule:: torchrl.modules.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - LLMWrapperBase - TransformersWrapper - vLLMWrapper - RemoteTransformersWrapper - AsyncVLLM - ChatHistory - Text - LogProbs - Masks - Tokens - -Async vLLM Engine (Recommended) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:class:`~torchrl.modules.llm.AsyncVLLM` is the recommended approach for high-performance vLLM inference in TorchRL. -It provides a distributed, async-capable inference service built on Ray that offers superior performance and resource utilization compared to synchronous vLLM engines. - -**Key Features:** - -- **Distributed Architecture**: Runs multiple vLLM engine replicas as Ray actors for horizontal scaling -- **Load Balancing**: Automatically distributes requests across available replicas -- **Native vLLM Batching**: Leverages vLLM's optimized batching for maximum throughput. - Every thread or actor in your code will be able to make requests to the vLLM engine(s), put the query in - the queue and let the engine handle the batching. -- **Resource Management**: Automatic GPU allocation and cleanup through Ray placement groups -- **Simple API**: Single-import convenience with :meth:`~torchrl.modules.llm.AsyncVLLM.from_pretrained` - -**Basic Usage:** - -.. code-block:: python - - from torchrl.modules.llm import AsyncVLLM, vLLMWrapper - from vllm import SamplingParams - - # Create async vLLM service (recommended) - async_engine = AsyncVLLM.from_pretrained( - "Qwen/Qwen2.5-7B", - num_devices=2, # Use 2 GPUs per replica (tensor parallel) - num_replicas=2, # Create 2 replicas for higher throughput - max_model_len=4096 - ) - - # Use with vLLMWrapper for TorchRL integration - wrapper = vLLMWrapper(async_engine, input_mode="history", generate=True) - - # Direct generation (also supported) - sampling_params = SamplingParams(temperature=0.7, max_tokens=100) - result = async_engine.generate("Hello, world!", sampling_params) - - # Cleanup when done - async_engine.shutdown() - -These objects (AsyncVLLM and vLLMWrapper) can be shared across multiple collectors, environments, or workers efficiently. -They can be directly passed from one worker to another: under the hood, Ray will handle the handler sharing and remote execution. - -**Performance Benefits:** - -- **Higher Throughput**: Multiple replicas process requests concurrently -- **Better GPU Utilization**: Ray ensures optimal GPU allocation and co-location -- **Reduced Latency**: Native batching reduces per-request overhead -- **Fault Tolerance**: Ray provides automatic error recovery and resource management - -**Resource Sharing:** - -AsyncVLLM instances can be shared across multiple collectors, environments, or workers efficiently: - -.. 
code-block:: python - - from torchrl.modules.llm import AsyncVLLM, vLLMWrapper + from torchrl.modules.llm import vLLMWrapper, AsyncVLLM + from torchrl.envs.llm import ChatEnv from torchrl.collectors.llm import LLMCollector - # Create a shared AsyncVLLM service - shared_async_engine = AsyncVLLM.from_pretrained( - "Qwen/Qwen2.5-7B", - num_devices=2, - num_replicas=4, # High throughput for multiple consumers - max_model_len=4096 - ) - - # Multiple wrappers can use the same AsyncVLLM service - wrapper1 = vLLMWrapper(shared_async_engine, input_mode="history") - wrapper2 = vLLMWrapper(shared_async_engine, input_mode="text") - - # Multiple collectors can share the same service - collector1 = LLMCollector(env1, policy=wrapper1) - collector2 = LLMCollector(env2, policy=wrapper2) - - # The AsyncVLLM service automatically load-balances across replicas - # No additional coordination needed between consumers - -This approach is more efficient than creating separate vLLM instances for each consumer, as it: - -- **Reduces Memory Usage**: Single model loading shared across consumers -- **Automatic Load Balancing**: Requests are distributed across replicas -- **Better Resource Utilization**: GPUs are used more efficiently -- **Simplified Management**: Single service to monitor and manage - -.. note:: - **AsyncVLLM vs. Traditional Actor Sharing** - - Unlike traditional Ray actor sharing patterns where you manually create named actors and share references, - AsyncVLLM handles the distributed architecture internally. You simply create one AsyncVLLM service and - pass it to multiple consumers. The service automatically: - - - Creates and manages multiple Ray actors (replicas) internally - - Load-balances requests across replicas without manual coordination - - Handles actor lifecycle and resource cleanup - - This eliminates the need for manual actor name management and reference sharing that was required - with the previous `RemotevLLMWrapper` approach. - -Remote Wrappers -^^^^^^^^^^^^^^^ - -TorchRL provides remote wrapper classes that enable distributed execution of LLM wrappers using Ray. These wrappers provide a simplified interface that doesn't require explicit `remote()` and `get()` calls, making them easy to use in distributed settings. - -.. note:: - **For vLLM: Use AsyncVLLM Instead** - - For vLLM-based inference, we recommend using :class:`~torchrl.modules.llm.AsyncVLLM` directly rather than - remote wrappers. AsyncVLLM provides better performance, resource utilization, and built-in load balancing. - See the `Async vLLM Engine (Recommended)`_ section above for details. - - Remote wrappers are primarily intended for Transformers-based models or other use cases where AsyncVLLM - is not applicable. - -**Key Features:** - -- **Simplified Interface**: No need to call `remote()` and `get()` explicitly -- **Full API Compatibility**: Exposes all public methods from the base `LLMWrapperBase` class -- **Automatic Ray Management**: Handles Ray initialization and remote execution internally -- **Property Access**: All properties are accessible through the remote wrapper -- **Error Handling**: Proper error propagation from remote actors -- **Resource Management**: Context manager support for automatic cleanup - -**Model Parameter Requirements:** - -- **RemoteTransformersWrapper**: Only accepts string model names/paths. Transformers models are not serializable. - -**Supported Backends:** - -Currently, only Transformers-based models are supported through remote wrappers. 
For vLLM models, use :class:`~torchrl.modules.llm.AsyncVLLM` instead. - -**Usage Examples:** - -.. code-block:: python - - import ray - from torchrl.modules.llm.policies import RemoteTransformersWrapper - from torchrl.data.llm import History - from torchrl.modules.llm.policies import ChatHistory, Text - from tensordict import TensorDict - - # Initialize Ray (if not already done) - if not ray.is_initialized(): - ray.init() - - # Transformers wrapper (only string models supported) - # The remote wrappers implement context managers for proper resource cleanup: - with RemoteTransformersWrapper( - model="gpt2", - max_concurrency=16, - input_mode="text", - generate=True, - generate_kwargs={"max_new_tokens": 30} - ) as remote_transformers: - - text_input = TensorDict({"text": Text(prompt="Hello world")}, batch_size=(1,)) - result = remote_transformers(text_input) - print(result["text"].response) -**Performance Considerations:** - -- **Network Overhead**: Remote execution adds network communication overhead -- **Serialization**: Data is serialized when sent to remote actors -- **Memory**: Each remote actor maintains its own copy of the model -- **Concurrency**: Multiple remote wrappers can run concurrently -- **Max Concurrency**: Use the `max_concurrency` parameter to control the number of concurrent calls to each remote actor -- **Cleanup**: Always use context managers or call `cleanup_batching()` to prevent hanging due to batching locks - -Utils -^^^^^ - -.. currentmodule:: torchrl.modules.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - make_async_vllm_engine - stateless_init_process_group_async - make_vllm_worker - stateless_init_process_group - -Collectors ----------- - -.. _Collectors: - -TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`) -that are tailored for LLM use cases. We also provide weight synchronization schemes for vLLM inference engines. - -See :ref:`ref_collectors` for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline -in a dedicated class. -A collector usually takes as input a policy and an environment, and alternate between running one and the other. -In "classical" settings, the policy is similar to the policy being trained (with some optional extra-exploration). In the context of LLM fine-tuning, -the policy will usually be a specialized inference engine, such as a vLLM server. -Collectors are defined by the following parameters and features: - -- **Sync/Async**: Whether the collector should run in sync or async mode. - In sync mode, the collector will run the inference step in alternate with the optimization/training step. - In async mode, the collector will run the inference step in parallel with the optimization/training step. - A replay buffer can be passed to the collector, in such a way that the collector can directly write to it. - In other cases, the collector can be iterated over to collect data. -- **Steps**: A collector is built with a certain number of steps budget, as well as a number of steps to be - included in each batch yield during collection. -- **Weight Synchronization Schemes**: Weight sync schemes handle the synchronization of weights between the training model - and the inference engine. The new scheme-based approach provides flexible, high-performance weight updates for vLLM and - other inference backends. 
- -vLLM Weight Synchronization Schemes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TorchRL provides two weight synchronization schemes for vLLM engines, offering different trade-offs between -performance and simplicity: - -**1. NCCL-Based Synchronization** (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme`) - -Uses NCCL collectives for high-bandwidth GPU-to-GPU weight transfers. Best for: - -- High-frequency weight updates -- Large models where transfer speed is critical -- Setups with GPU interconnect (NVLink, InfiniBand) - -**2. Double-Buffer Synchronization** (:class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`) - -Uses memory-mapped file storage for asynchronous weight transfers. Best for: - -- Simpler setup without NCCL coordination -- Distributed setups with shared filesystems (NFS) -- Cases where update frequency is lower - -**Usage Example with NCCL:** - -.. code-block:: python - - from torchrl.collectors.llm import RayLLMCollector - from torchrl.weight_update.llm import VLLMWeightSyncScheme - from torchrl.modules.llm import AsyncVLLM, vLLMWrapper - # Create vLLM engine - vllm_engine = AsyncVLLM.from_pretrained( - "Qwen/Qwen2.5-7B", - num_devices=2, - num_replicas=2, - ) - policy = vLLMWrapper(vllm_engine, input_mode="history") - - # Create NCCL weight sync scheme - weight_sync_scheme = VLLMWeightSyncScheme( - master_address="localhost", - master_port=29500, - gpus_per_replica=2, # tp_size × dp_size × pp_size - num_replicas=2, - strategy="state_dict" - ) - - # Create collector with weight sync scheme - collector = RayLLMCollector( - env=make_env, - policy=policy, - dialog_turns_per_batch=256, - total_dialog_turns=10000, - weight_sync_schemes={"policy": weight_sync_scheme}, - track_policy_version=True, - ) - - # During training, get the sender and update weights - sender = collector._weight_senders["policy"] - sender.register_model(training_model) - - # Initialize collective group (must be called before first update) - metadata = get_model_metadata(training_model) - sender.init_all_workers_group(metadata, vllm_engine=vllm_engine) + engine = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-7B", num_replicas=2) + policy = vLLMWrapper(engine, input_mode="history") - # Update weights during training - for i, data in enumerate(collector): - # ... training step ... - if i % 10 == 0: - sender.update_weights() # Broadcasts via NCCL - -**Usage Example with Double-Buffer:** - -.. code-block:: python - - from torchrl.collectors.llm import RayLLMCollector - from torchrl.weight_update.llm import VLLMDoubleBufferSyncScheme - from torchrl.modules.llm import AsyncVLLM, vLLMWrapper + # Create environment + env = ChatEnv(tokenizer=tokenizer) - # Create vLLM engine - vllm_engine = AsyncVLLM.from_pretrained( - "Qwen/Qwen2.5-7B", - num_devices=2, - num_replicas=1, - ) - policy = vLLMWrapper(vllm_engine, input_mode="history") - - # Create double-buffer weight sync scheme - weight_sync_scheme = VLLMDoubleBufferSyncScheme( - remote_addr="/tmp/weights", # Or "/mnt/shared/weights" for NFS - num_threads=128, - strategy="state_dict" - ) - - # Create collector with weight sync scheme - collector = RayLLMCollector( - env=make_env, - policy=policy, - dialog_turns_per_batch=256, - total_dialog_turns=10000, - weight_sync_schemes={"policy": weight_sync_scheme}, - track_policy_version=True, - ) - - # During training, get the sender and receiver - sender = collector._weight_senders["policy"] - sender.register_model(training_model) - - # No initialization needed for double-buffer scheme! 
- - # Update weights during training - for i, data in enumerate(collector): - # ... training step ... - if i % 10 == 0: - sender.update_weights() # Writes to shared storage - # vLLM workers can poll and apply: receiver.poll_and_apply() - -Policy Version Tracking -~~~~~~~~~~~~~~~~~~~~~~~ - -LLM Collectors also allow to track the version of the policy, which is useful for some use cases. -This is done by adding a :class:`~torchrl.envs.llm.transforms.PolicyVersion` transform to the environment, which is -then incremented by the collector after each weight update. To do this, one either provides the stateful version of the -transform, or a boolean to the collector constructor. - - >>> from torchrl.envs.llm.transforms import PolicyVersion - >>> from torchrl.collectors.llm import LLMCollector - >>> from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata - >>> env = make_env() # place your code here - >>> policy = make_policy() # place your code here - >>> scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1) - >>> collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True) - >>> # Get the sender and register model - >>> sender = collector._weight_senders["policy"] - >>> sender.register_model(training_model) - >>> # Initialize the collective group - >>> metadata = get_model_metadata(training_model) - >>> sender.init_all_workers_group(metadata, vllm_engine=policy.model) - >>> # Update weights - >>> sender.update_weights() - >>> print(collector.policy_version_tracker.version) - >>> # the policy version is written in the data - >>> for data in collector: - ... print(data["policy_version"]) - -.. currentmodule:: torchrl.weight_update.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - VLLMWeightSyncScheme - VLLMWeightSender - VLLMWeightReceiver - VLLMCollectiveTransport - VLLMDoubleBufferSyncScheme - VLLMDoubleBufferWeightSender - VLLMDoubleBufferWeightReceiver - VLLMDoubleBufferTransport - get_model_metadata - -Legacy Weight Updaters (Deprecated) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. deprecated:: 0.11 - The `vLLMUpdater` and `vLLMUpdaterV2` classes are deprecated in favor of the new weight synchronization schemes - (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme` and :class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`). - These schemes provide better performance, more flexibility, and cleaner integration with collectors. - The legacy updaters will be removed in a future release. - - The legacy weight updaters (`vLLMUpdater` and `vLLMUpdaterV2`) are still available but are no longer recommended. - Please migrate to the new weight synchronization schemes shown above. - -.. currentmodule:: torchrl.collectors.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - vLLMUpdater - vLLMUpdaterV2 - LLMCollector - RayLLMCollector - -Environments ------------- - -The environment layer orchestrates data loading, tool execution, reward computation, and formatting. When fine-tuning an LLM using TorchRL, the environment is a -crucial component of the inference pipeline, alongside the policy and collector. - -ChatEnv -~~~~~~~ - -:class:`~torchrl.envs.llm.ChatEnv` serves as a blank canvas for LLM environments - it's a basic tool designed to be extended with transforms that add -specific functionality. 
The base ChatEnv provides the fundamental structure for managing conversation state using the -:class:`~torchrl.data.llm.History` format, but it's intentionally minimal to allow maximum flexibility. - -Core Functionality -^^^^^^^^^^^^^^^^^^ - -ChatEnv operates in three main modes: -- **History mode**: Uses :class:`~torchrl.data.llm.History` objects for conversation management -- **Text mode**: Uses simple text strings for input/output -- **Tokens mode**: Uses tokenized data for input/output - -The environment maintains conversation state by: -- **Reset**: Initializes a new conversation with an optional system prompt -- **Step**: Takes the LLM's response and updates the conversation history, preparing the next prompt - -Transform-Based Architecture -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Transforms are the main way to extend ChatEnv with specific capabilities: - -- **Reward computation**: :class:`~torchrl.envs.llm.transforms.KLRewardTransform` for KL divergence rewards -- **Tool execution**: :class:`~torchrl.envs.llm.transforms.PythonInterpreter` for Python code - execution, :class:`~torchrl.envs.llm.transforms.MCPToolTransform` for general tool calling. -- **Data loading**: :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` for loading prompts from datasets -- **Thinking prompts**: :class:`~torchrl.envs.llm.transforms.AddThinkingPrompt` for chain-of-thought reasoning -- **Policy tracking**: :class:`~torchrl.envs.llm.transforms.PolicyVersion` for version control -- **Step counting**: Built-in step tracking and reset management using :class:`~torchrl.envs.transforms.StepCounter`. - -Integration with LLM Wrappers -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. _ref_env_llm_step: - -ChatEnv is designed to work seamlessly with both :class:`~torchrl.modules.llm.TransformersWrapper` and :class:`~torchrl.modules.llm.vLLMWrapper`. -The environment handles the conversation state management while the wrapper handles the actual LLM inference, creating a clean separation of concerns. - -On each call to `step`, the environment: - -- Takes the LLM's output, specifically the `full` field, which contains the entire conversation so far, including the new response (e.g., `history.full`, `text.full`, `tokens.full`). -- Sets this `full` field as the new `prompt` for the next LLM step (e.g., `td["next", "history"].prompt`, `td["next", "text"].prompt`, `td["next", "tokens"].prompt`). -- Optionally, applies transforms to insert new user messages, tool calls, or other modifications to the conversation before the next LLM step to refine the prompt. - -This mechanism enables seamless multi-turn interactions and supports complex workflows such as tool use and reward shaping. - -Task-Specific Environments -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -We provide a few task-specific environments, such as :class:`~torchrl.envs.llm.GSM8KEnv` for the GSM8K dataset, -:class:`~torchrl.envs.llm.IFEvalEnv` for the IFEval dataset, and :class:`~torchrl.envs.llm.MLGymEnv` for MLGym integration. - -These environments wrap a :class:`~torchrl.envs.llm.ChatEnv` and add a :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` transform -(plus an optional reward parsing transform) in a :class:`~torchrl.envs.TransformedEnv` class. - - - -.. currentmodule:: torchrl.envs.llm - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template.rst - - ChatEnv - DatasetChatEnv - GSM8KEnv - make_gsm8k_env - GSM8KPrepareQuestion - IFEvalEnv - IfEvalScorer - IFEvalScoreData - LLMEnv - LLMHashingEnv - make_mlgym - MLGymWrapper - GSM8KRewardParser - -Transforms -~~~~~~~~~~ - -Transforms are used to modify the data before it is passed to the LLM. -Tools are usually implemented as transforms, and appended to a base environment -such as :class:`~torchrl.envs.llm.ChatEnv`. - -An example of a tool transform is the :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform, which is used -to execute Python code in the context of the LLM. - - >>> from torchrl.envs.llm.transforms import PythonInterpreter - >>> from torchrl.envs.llm import ChatEnv - >>> from tensordict import TensorDict, set_list_to_stack - >>> from transformers import AutoTokenizer - >>> from pprint import pprint - >>> set_list_to_stack(True).set() - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct") - >>> base_env = ChatEnv( - ... tokenizer=tokenizer, - ... system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.", - ... user_role="user", - ... system_role="system", - ... batch_size=[1], - ... ) - >>> env = base_env.append_transform(PythonInterpreter()) - >>> env.set_seed(0) - >>> # Pass the reset data - the prompt - to the environment - >>> reset_data = env.reset(TensorDict( - ... text="Let's write a Python function that returns the square of a number.", - ... batch_size=[1]) - ... ) - >>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM) - >>> action = """Here is a block of code to be executed in python: - ... ```python - ... def square(x): - ... return x * x - ... print('testing the square function with input 2:', square(2)) - ... ``` - ... <|im_end|> - ... """ - >>> step_data = reset_data.set("text_response", [action]) - >>> s, s_ = env.step_and_maybe_reset(reset_data) - >>> # The history is a stack of chat messages. - >>> # The python interpreter transform has executed the code in the last message. - >>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer)) - ['<|im_start|>system\n' - 'You are an assistant that can execute Python code. Decorate your code with ' - '```python``` tags.<|im_end|>\n' - '<|im_start|>user\n' - "Let's write a Python function that returns the square of a " - 'number.<|im_end|>\n' - '<|im_start|>assistant\n' - 'Here is a block of code to be executed in python:\n' - '```python\n' - 'def square(x):\n' - ' return x * x\n' - "print('testing the square function with input 2:', square(2))\n" - '```<|im_end|>\n' - '<|im_start|>user\n' - '\n' - 'Code block 1 executed successfully:\n' - 'testing the square function with input 2: 4\n' - '\n' - '<|im_end|>\n' - '<|im_start|>assistant\n'] - -Similarly, environments that load data from a dataset are just special instances of the :class:`~torchrl.envs.llm.ChatEnv` -augmented with a :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` transforms (and some dedicated reward parsing -transforms). - -Designing Reward Transforms -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When designing reward transforms for LLM environments, several key considerations must be -addressed to ensure proper integration with the training pipeline. -The examples of :class:`~torchrl.envs.llm.GSM8KRewardParser` and -:class:`~torchrl.envs.llm.IfEvalScorer` provide excellent templates for reward transform design. 
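+Before going through the individual considerations below, here is a minimal, toy skeleton of such
+a reward transform (not taken from the library; the constant reward and key handling are
+placeholders to be replaced by real parsing logic):
+
+.. code-block:: python
+
+    import torch
+
+    from torchrl.data import Composite, Unbounded
+    from torchrl.envs import Transform
+
+    class ConstantRewardParser(Transform):
+        """Toy reward transform that emits a constant sparse reward of 1.0 per sequence."""
+
+        def __init__(self):
+            super().__init__(in_keys=[], out_keys=["reward"])
+
+        def transform_reward_spec(self, reward_spec: Composite) -> Composite:
+            # Sparse reward: one scalar per sequence, shape (*batch, 1, 1)
+            shape = reward_spec.shape + (1, 1)
+            reward_spec.update(Composite(reward=Unbounded(shape)))
+            return reward_spec
+
+        def _step(self, tensordict, next_tensordict):
+            # A real parser would read the response (e.g. from ("history", "full")) and score it
+            reward = torch.ones(next_tensordict.batch_size + (1, 1))
+            next_tensordict.set("reward", reward)
+            return next_tensordict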
- -**Reward Shape Requirements** - -The reward tensor must have the same number of dimensions as the logits, which is typically -two more dimensions than the environment batch size: - -- **Sparse rewards**: Shape ``(*bsz, 1, 1)`` - single reward per sequence -- **Dense rewards**: Shape ``(*bsz, num_tokens, 1)`` - per-token rewards - -This shape requirement ensures compatibility with the loss computation pipeline. -For example, in the GSM8K reward parser: - -.. code-block:: python - - # Rewards need to have shape broadcastable to [batch x tokens x 1] - tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1)) - -**Done State Management** - -It is crucial to properly manage the done state to prevent endless generation. Common strategies include: - -1. **Completion-based termination**: Set done when the response is complete (e.g., ``History.complete=True``) -2. **Content-based termination**: Set done when specific content is detected (e.g., ```` blocks) -3. **Step-based termination**: Use :class:`~torchrl.envs.transforms.StepCounter` for predetermined step limits - -Example from IFEvalScorer: - -.. code-block:: python - - if self.set_done_if_answer and bool(answer_blocks): - next_tensordict.set("done", torch.ones(...)) - next_tensordict.set("terminated", torch.ones(...)) - -**Input Mode Handling** - -Reward transforms must handle different input modes correctly: - -- **History mode**: Extract text from ``("history", "full")`` or ``("history", "response")`` -- **Text mode**: Use text directly from ``("text", "full")`` or ``("text", "response")`` -- **Tokens mode**: Decode tokens from ``("tokens", "full")`` or ``("tokens", "response")`` - -The GSM8K reward parser demonstrates this pattern: - -.. code-block:: python - - if input_mode == "history": - responses = lazy_stack([r[..., -1] for r in responses.unbind(0)]) - if hasattr(responses, "content"): - text_completion = responses.content - elif input_mode == "text": - text_completion = responses - elif input_mode == "tokens": - text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist()) - -**Specification Management** - -Accurate specification of reward and observation specs is essential for proper environment initialization. Both GSM8K and IFEval provide good examples: - -.. code-block:: python - - def transform_reward_spec(self, reward_spec: Composite) -> Composite: - shape = reward_spec.shape + (1, 1) - reward_spec.update( - Composite( - reward_answer=Unbounded(shape), - reward_think=Unbounded(shape), - reward_right=Unbounded(shape), - reward_contained=Unbounded(shape), - reward=Unbounded(shape), - success=Unbounded(shape, dtype=torch.bool), - ) - ) - return reward_spec - -**Batch Processing Considerations** - -For efficient processing, handle batched data appropriately: - -1. **Flatten batch dimensions**: Use ``tensordict.view(-1)`` for processing -2. **Reshape results**: Restore original batch structure after processing -3. **Handle variable-length sequences**: Use proper padding and masking - -**Reward Aggregation Strategies** - -Consider different reward aggregation approaches: - -1. **Simple aggregation**: Sum or average multiple reward components -2. **Weighted aggregation**: Apply different weights to different components -3. **Conditional rewards**: Base rewards on specific conditions or thresholds - -The IFEvalScorer demonstrates a sophisticated aggregation strategy: - -.. 
code-block:: python - - def default_reward_aggregator(self, score: IFEvalScoreData, ...): - # Format score (max 1.0) - format_score = (format_components * weights).sum(dim=-1, keepdim=True) - - # Structure score (max 1.0) - structure_score = think_score + answer_score - - # Completion bonus (max 0.2) - completion_bonus = float(complete) * 0.2 - - return format_score + structure_score + completion_bonus - -**Post-Processing in Replay Buffers** - -Rewards can also be computed after the fact by appending transforms to the replay buffer. However, done state capture must remain in the environment transform since it needs to occur on-the-fly during data collection. - -**Error Handling and Robustness** - -Implement robust error handling for parsing failures: - -.. code-block:: python - - try: - cot, potential_answer = self.extract_tags(compl) - except ET.ParseError: - cot, potential_answer = ("", "") - -**Performance Considerations** - -1. **Avoid redundant computations**: Cache parsed results when possible -2. **Use efficient text processing**: Leverage regex or XML parsing as appropriate -3. **Minimize memory allocations**: Reuse tensors and avoid unnecessary copies - -By following these design principles, reward transforms can be effectively integrated into the LLM training pipeline while maintaining performance and reliability. - -.. currentmodule:: torchrl.envs.llm.transforms - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - AddThinkingPrompt - BrowserTransform - DataLoadingPrimer - KLComputation - KLRewardTransform - MCPToolTransform - PolicyVersion - PythonInterpreter - RayDataLoadingPrimer - RetrieveKL - RetrieveLogProb - TemplateTransform - Tokenizer - as_nested_tensor - as_padded_tensor - -Objectives ----------- - -LLM post-training requires specialized loss functions that are adapted to the unique characteristics of language models. - -GRPO -~~~~ - -The :class:`~torchrl.objectives.llm.GRPOLoss` class is a thin wrapper around the :class:`~torchrl.objectives.PPOLoss` class -that codes the LLM-specific functionalities. - -.. currentmodule:: torchrl.objectives.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - GRPOLoss - GRPOLossOutput - MCAdvantage - -SFT -^^^ - -.. currentmodule:: torchrl.objectives.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst + # Create collector + collector = LLMCollector(env, policy, dialog_turns_per_batch=256) - SFTLoss - SFTLossOutput +.. warning:: The LLM API is still under development and may change in the future. + Feedback, issues and PRs are welcome! -.. currentmodule:: torchrl.data.llm +Documentation Sections +---------------------- -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst +.. toctree:: + :maxdepth: 2 - TopKRewardSelector + llms_data + llms_modules + llms_envs + llms_transforms + llms_collectors + llms_objectives diff --git a/docs/source/reference/llms_collectors.rst b/docs/source/reference/llms_collectors.rst new file mode 100644 index 00000000000..314b0afb4a4 --- /dev/null +++ b/docs/source/reference/llms_collectors.rst @@ -0,0 +1,34 @@ +.. currentmodule:: torchrl.collectors.llm + +LLM Collectors +============== + +Specialized collector classes for LLM use cases. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + LLMCollector + RayLLMCollector + vLLMUpdater + vLLMUpdaterV2 + +Weight Synchronization Schemes +------------------------------ + +.. currentmodule:: torchrl.weight_update.llm + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template.rst + + VLLMWeightSyncScheme + VLLMWeightSender + VLLMWeightReceiver + VLLMCollectiveTransport + VLLMDoubleBufferSyncScheme + VLLMDoubleBufferWeightSender + VLLMDoubleBufferWeightReceiver + VLLMDoubleBufferTransport + get_model_metadata diff --git a/docs/source/reference/llms_data.rst b/docs/source/reference/llms_data.rst new file mode 100644 index 00000000000..382caf68489 --- /dev/null +++ b/docs/source/reference/llms_data.rst @@ -0,0 +1,36 @@ +.. currentmodule:: torchrl.data.llm + +Data Structures +=============== + +The data representation layer provides the foundation for handling conversations and LLM outputs. + +History Class +------------- + +The :class:`~torchrl.data.llm.History` class is a TensorClass version of the chat format usually found in transformers. +It provides a comprehensive API for managing conversation data with features including: + +- **Text parsing and formatting**: Convert between text and structured conversation format +- **Dynamic conversation building**: Append and extend conversations +- **Multi-model support**: Automatic template detection for various model families +- **Assistant token masking**: Identify which tokens were generated by the assistant +- **Tool calling support**: Handle function calls and tool responses +- **Batch operations**: Efficient tensor operations for multiple conversations + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + History + ContentBase + add_chat_template + +TopK Reward Selector +-------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + TopKRewardSelector diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst new file mode 100644 index 00000000000..e04b42a6f7b --- /dev/null +++ b/docs/source/reference/llms_envs.rst @@ -0,0 +1,24 @@ +.. currentmodule:: torchrl.envs.llm + +LLM Environments +================ + +The environment layer orchestrates data loading, tool execution, reward computation, and formatting. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + ChatEnv + DatasetChatEnv + GSM8KEnv + make_gsm8k_env + GSM8KPrepareQuestion + GSM8KRewardParser + IFEvalEnv + IfEvalScorer + IFEvalScoreData + LLMEnv + LLMHashingEnv + make_mlgym + MLGymWrapper diff --git a/docs/source/reference/llms_modules.rst b/docs/source/reference/llms_modules.rst new file mode 100644 index 00000000000..53779611a0e --- /dev/null +++ b/docs/source/reference/llms_modules.rst @@ -0,0 +1,45 @@ +.. currentmodule:: torchrl.modules.llm + +LLM Wrappers +============ + +The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent +input/output formats across training and inference pipelines. + +Wrappers +-------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + LLMWrapperBase + TransformersWrapper + vLLMWrapper + RemoteTransformersWrapper + AsyncVLLM + +Data Structure Classes +---------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + ChatHistory + Text + LogProbs + Masks + Tokens + +Utilities +--------- + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template.rst + + make_async_vllm_engine + stateless_init_process_group_async + make_vllm_worker + stateless_init_process_group diff --git a/docs/source/reference/llms_objectives.rst b/docs/source/reference/llms_objectives.rst new file mode 100644 index 00000000000..e7863a081b8 --- /dev/null +++ b/docs/source/reference/llms_objectives.rst @@ -0,0 +1,27 @@ +.. currentmodule:: torchrl.objectives.llm + +LLM Objectives +============== + +Specialized loss functions for LLM training. + +GRPO +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + GRPOLoss + GRPOLossOutput + MCAdvantage + +SFT +--- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + SFTLoss + SFTLossOutput diff --git a/docs/source/reference/llms_transforms.rst b/docs/source/reference/llms_transforms.rst new file mode 100644 index 00000000000..4ab638b73d3 --- /dev/null +++ b/docs/source/reference/llms_transforms.rst @@ -0,0 +1,26 @@ +.. currentmodule:: torchrl.envs.llm.transforms + +LLM Transforms +============== + +Transforms for LLM environments, including tools and utilities. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + AddThinkingPrompt + BrowserTransform + DataLoadingPrimer + KLComputation + KLRewardTransform + MCPToolTransform + PolicyVersion + PythonInterpreter + RayDataLoadingPrimer + RetrieveKL + RetrieveLogProb + TemplateTransform + Tokenizer + as_nested_tensor + as_padded_tensor diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 28c97982b4d..2e1e932829e 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -5,467 +5,53 @@ torchrl.modules package .. _ref_modules: -TensorDict modules: Actors, exploration, value models and generative models ---------------------------------------------------------------------------- +TorchRL offers a comprehensive collection of RL-specific neural network modules built on top of +:class:`tensordict.nn.TensorDictModule`. These modules are designed to work seamlessly with +tensordict data structures, making it easy to build and compose RL models. -.. _tdmodules: +Key Features +------------ -TorchRL offers a series of module wrappers aimed at making it easy to build -RL models from the ground up. These wrappers are exclusively based on -:class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`. -They can loosely be split in three categories: -policies (actors), including exploration strategies, -value model and simulation models (in model-based contexts). +- **Spec-based construction**: Automatically configure output layers based on action specs +- **Probabilistic modules**: Built-in support for stochastic policies +- **Exploration strategies**: Modular exploration wrappers (ε-greedy, Ornstein-Uhlenbeck, etc.) +- **Value networks**: Q-value, distributional, and dueling architectures +- **Safe modules**: Automatic projection to satisfy action constraints +- **Model-based RL**: World model and dynamics modules -The main features are: - -- Integration of the specs in your model to ensure that the model output matches - what your environment expects as input; -- Probabilistic modules that can automatically sample from a chosen distribution - and/or return the distribution of interest; -- Custom containers for Q-Value learning, model-based agents and others. 
- -TensorDictModules and SafeModules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TorchRL :class:`~torchrl.modules.tensordict_module.SafeModule` allows you to -check the you model output matches what is to be expected for the environment. -This should be used whenever your model is to be recycled across multiple -environments for instance, and when you want to make sure that the outputs -(e.g. the action) always satisfies the bounds imposed by the environment. -Here is an example of how to use that feature with the -:class:`~torchrl.modules.tensordict_module.Actor` class: - - >>> env = GymEnv("Pendulum-v1") - >>> action_spec = env.action_spec - >>> model = nn.LazyLinear(action_spec.shape[-1]) - >>> policy = Actor(model, in_keys=["observation"], spec=action_spec, safe=True) - -The ``safe`` flag ensures that the output is always within the bounds of the -``action_spec`` domain: if the network output violates these bounds it will be -projected (in a L1-manner) into the desired domain. - -.. currentmodule:: torchrl.modules.tensordict_module - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - Actor - MultiStepActorWrapper - SafeModule - SafeSequential - TanhModule - -Exploration wrappers and modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To efficiently explore the environment, TorchRL proposes a series of modules -that will override the action sampled by the policy by a noisier version. -Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_type`: -if the exploration is set to ``ExplorationType.RANDOM``, the exploration is active. In all -other cases, the action written in the tensordict is simply the network output. - -.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` - uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. - The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on - this module. - -.. currentmodule:: torchrl.modules - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - AdditiveGaussianModule - ConsistentDropoutModule - EGreedyModule - OrnsteinUhlenbeckProcessModule - -Probabilistic actors -~~~~~~~~~~~~~~~~~~~~ - -Some algorithms such as PPO require a probabilistic policy to be implemented. -In TorchRL, these policies take the form of a model, followed by a distribution -constructor. - - .. note:: The choice of a probabilistic or regular actor class depends on the algorithm - that is being implemented. On-policy algorithms usually require a probabilistic - actor, off-policy usually have a deterministic actor with an extra exploration - strategy. There are, however, many exceptions to this rule. - -The model reads an input (typically some observation from the environment) -and outputs the parameters of a distribution, while the distribution constructor -reads these parameters and gets a random sample from the distribution and/or -provides a :class:`torch.distributions.Distribution` object. 
- - >>> from tensordict.nn import NormalParamExtractor, TensorDictSequential, TensorDictModule - >>> from torchrl.modules import SafeProbabilisticModule - >>> from torchrl.envs import GymEnv - >>> from torch.distributions import Normal - >>> from torch import nn - >>> - >>> env = GymEnv("Pendulum-v1") - >>> action_spec = env.action_spec - >>> model = nn.Sequential(nn.LazyLinear(action_spec.shape[-1] * 2), NormalParamExtractor()) - >>> # build the first module, which maps the observation on the mean and sd of the normal distribution - >>> model = TensorDictModule(model, in_keys=["observation"], out_keys=["loc", "scale"]) - >>> # build the distribution constructor - >>> prob_module = SafeProbabilisticModule( - ... in_keys=["loc", "scale"], - ... out_keys=["action"], - ... distribution_class=Normal, - ... return_log_prob=True, - ... spec=action_spec, - ... ) - >>> policy = TensorDictSequential(model, prob_module) - >>> # execute a rollout - >>> env.rollout(3, policy) - -To facilitate the construction of probabilistic policies, we provide a dedicated -:class:`~torchrl.modules.tensordict_module.ProbabilisticActor`: - - >>> from torchrl.modules import ProbabilisticActor - >>> policy = ProbabilisticActor( - ... model, - ... in_keys=["loc", "scale"], - ... out_keys=["action"], - ... distribution_class=Normal, - ... return_log_prob=True, - ... spec=action_spec, - ... ) - -which alleviates the need to specify a constructor and putting it with the -module in a sequence. - -Outputs of this policy will contain a ``"loc"`` and ``"scale"`` entries, an -``"action"`` sampled according to the normal distribution and the log-probability -of this action. - -.. currentmodule:: torchrl.modules.tensordict_module - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - ProbabilisticActor - SafeProbabilisticModule - SafeProbabilisticTensorDictSequential - -Q-Value actors -~~~~~~~~~~~~~~ - -Q-Value actors are a type of policy that selects actions based on the maximum value -(or "quality") of a state-action pair. This value can be represented as a table or a -function. For discrete action spaces with continuous states, it's common to use a non-linear -model like a neural network to represent this function. - -QValueActor -^^^^^^^^^^^ - -The :class:`~torchrl.modules.QValueActor` class takes in a module and an action -specification, and outputs the selected action and its corresponding value. 
- - >>> import torch - >>> from tensordict import TensorDict - >>> from torch import nn - >>> from torchrl.data import OneHot - >>> from torchrl.modules.tensordict_module.actors import QValueActor - >>> # Create a tensor dict with an observation - >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5]) - >>> # Define the action space - >>> action_spec = OneHot(4) - >>> # Create a linear module to output action values - >>> module = nn.Linear(3, 4) - >>> # Create a QValueActor instance - >>> qvalue_actor = QValueActor(module=module, spec=action_spec) - >>> # Run the actor on the tensor dict - >>> qvalue_actor(td) - >>> print(td) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), - chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5]), - device=None, - is_shared=False) - -This will output a tensor dict with the selected action and its corresponding value. - -Distributional Q-Learning -^^^^^^^^^^^^^^^^^^^^^^^^^ - -Distributional Q-learning is a variant of Q-learning that represents the value function as a -probability distribution over possible values, rather than a single scalar value. -This allows the agent to learn about the uncertainty in the environment and make more informed -decisions. -In TorchRL, distributional Q-learning is implemented using the :class:`~torchrl.modules.DistributionalQValueActor` -class. This class takes in a module, an action specification, and a support vector, and outputs the selected -action and its corresponding value distribution. - - - >>> import torch - >>> from tensordict import TensorDict - >>> from torch import nn - >>> from torchrl.data import OneHot - >>> from torchrl.modules import DistributionalQValueActor, MLP - >>> # Create a tensor dict with an observation - >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) - >>> # Define the action space - >>> action_spec = OneHot(4) - >>> # Define the number of bins for the value distribution - >>> nbins = 3 - >>> # Create an MLP module to output logits for the value distribution - >>> module = MLP(out_features=(nbins, 4), depth=2) - >>> # Create a DistributionalQValueActor instance - >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) - >>> # Run the actor on the tensor dict - >>> td = qvalue_actor(td) - >>> print(td) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5]), - device=None, - is_shared=False) - -This will output a tensor dict with the selected action and its corresponding value distribution. - -.. currentmodule:: torchrl.modules.tensordict_module - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - QValueActor - QValueModule - DistributionalQValueActor - DistributionalQValueModule - -Value operators and joined models -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
currentmodule:: torchrl.modules.tensordict_module - -TorchRL provides a series of value operators that wrap value networks to -soften the interface with the rest of the library. -The basic building block is :class:`torchrl.modules.tensordict_module.ValueOperator`: -given an input state (and possibly action), it will automatically write a ``"state_value"`` -(or ``"state_action_value"``) in the tensordict, depending on what the input is. -As such, this class accounts for both value and quality networks. -Three classes are also proposed to group together a policy and a value network. -The :class:`~.ActorCriticOperator` is an joined actor-quality network with shared parameters: -it reads an observation, pass it through a -common backbone, writes a hidden state, feeds this hidden state to the policy, -then takes the hidden state and the action and provides the quality of the state-action -pair. -The :class:`~.ActorValueOperator` is a joined actor-value network with shared parameters: -it reads an observation, pass it through a -common backbone, writes a hidden state, feeds this hidden state to the policy -and value modules to output an action and a state value. -Finally, the :class:`~.ActorCriticWrapper` is a joined actor and value network -without shared parameters. It is mainly intended as a replacement for -:class:`~.ActorValueOperator` when a script needs to account for both options. - - >>> actor = make_actor() - >>> value = make_value() - >>> if shared_params: - ... common = make_common() - ... model = ActorValueOperator(common, actor, value) - ... else: - ... model = ActorValueOperator(actor, value) - >>> policy = model.get_policy_operator() # will work in both cases - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - ActorCriticOperator - ActorCriticWrapper - ActorValueOperator - ValueOperator - DecisionTransformerInferenceWrapper - -Domain-specific TensorDict modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. currentmodule:: torchrl.modules.tensordict_module - -These modules include dedicated solutions for MBRL or RLHF pipelines. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - LMHeadActorValueOperator - WorldModelWrapper - -Hooks ------ -.. currentmodule:: torchrl.modules - -The Q-value hooks are used by the :class:`~.QValueActor` and :class:`~.DistributionalQValueActor` -modules and those should be preferred in general as they are easier to create -and use. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - QValueHook - DistributionalQValueHook - -Models ------- -.. currentmodule:: torchrl.modules - -TorchRL provides a series of useful "regular" (ie non-tensordict) nn.Module -classes for RL usage. - -Regular modules -~~~~~~~~~~~~~~~ - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - BatchRenorm1d - ConsistentDropout - Conv3dNet - ConvNet - MLP - Squeeze2dLayer - SqueezeLayer - -Algorithm-specific modules -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -These networks implement sub-networks that have shown to be useful for specific -algorithms, such as DQN, DDPG or Dreamer. - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DTActor - DdpgCnnActor - DdpgCnnQNet - DdpgMlpActor - DdpgMlpQNet - DecisionTransformer - DistributionalDQNnet - DreamerActor - DuelingCnnDQNet - GRUCell - GRU - GRUModule - LSTMCell - LSTM - LSTMModule - ObsDecoder - ObsEncoder - OnlineDTActor - RSSMPosterior - RSSMPrior - set_recurrent_mode - recurrent_mode - -Multi-agent-specific modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -These networks implement models that can be used in multi-agent contexts. -They use :func:`~torch.vmap` to execute multiple networks all at once on the -network inputs. Because the parameters are batched, initialization may differ -from what is usually done with other PyTorch modules, see -:meth:`~torchrl.modules.MultiAgentNetBase.get_stateful_net` -for more information. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - MultiAgentNetBase - MultiAgentMLP - MultiAgentConvNet - QMixer - VDNMixer - - -Exploration ------------ -.. currentmodule:: torchrl.modules - -Noisy linear layers are a popular way of exploring the environment without -altering the actions, but by integrating the stochasticity in the weight -configuration. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - NoisyLinear - NoisyLazyLinear - reset_noise - - -Planners --------- -.. currentmodule:: torchrl.modules - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - CEMPlanner - MPCPlannerBase - MPPIPlanner - - -Distributions +Quick Example ------------- -.. currentmodule:: torchrl.modules - -Some distributions are typically used in RL scripts. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - Delta - IndependentNormal - TanhNormal - TruncatedNormal - TanhDelta - OneHotCategorical - LLMMaskedCategorical - MaskedCategorical - MaskedOneHotCategorical - Ordinal - OneHotOrdinal - -Utils ------ -.. currentmodule:: torchrl.modules.utils - -The module utils include functionals used to do some custom mappings as well as a tool to -build :class:`~torchrl.envs.TensorDictPrimer` instances from a given module. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - mappings - inv_softplus - biased_softplus - get_primers_from_module - -.. currentmodule:: torchrl.modules - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - VmapModule +.. code-block:: python + + from torchrl.modules import ProbabilisticActor, TanhNormal + from torchrl.envs import GymEnv + from tensordict.nn import TensorDictModule + import torch.nn as nn + + env = GymEnv("Pendulum-v1") + + # Create a probabilistic actor + actor = ProbabilisticActor( + module=TensorDictModule( + nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 2)), + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + distribution_class=TanhNormal, + spec=env.action_spec, + ) + +Documentation Sections +---------------------- + +.. toctree:: + :maxdepth: 2 + + modules_actors + modules_exploration + modules_critics + modules_models + modules_distributions + modules_utils diff --git a/docs/source/reference/modules_actors.rst b/docs/source/reference/modules_actors.rst new file mode 100644 index 00000000000..afbf90ba702 --- /dev/null +++ b/docs/source/reference/modules_actors.rst @@ -0,0 +1,47 @@ +.. currentmodule:: torchrl.modules + +Actor Modules +============= + +Actor modules represent policies in RL. 
They map observations to actions, either deterministically +or stochastically. + +TensorDictModules and SafeModules +--------------------------------- + +.. currentmodule:: torchrl.modules.tensordict_module + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + Actor + MultiStepActorWrapper + SafeModule + SafeSequential + TanhModule + +Probabilistic actors +-------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ProbabilisticActor + SafeProbabilisticModule + SafeProbabilisticTensorDictSequential + +Q-Value actors +-------------- + +.. currentmodule:: torchrl.modules + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + QValueActor + DistributionalQValueActor + QValueModule + DistributionalQValueModule diff --git a/docs/source/reference/modules_critics.rst b/docs/source/reference/modules_critics.rst new file mode 100644 index 00000000000..ba2f7110bf5 --- /dev/null +++ b/docs/source/reference/modules_critics.rst @@ -0,0 +1,25 @@ +.. currentmodule:: torchrl.modules + +Value Networks and Critics +========================== + +Value networks estimate the value of states or state-action pairs. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ValueOperator + DuelingCnnDQNet + DistributionalDQNnet + ConvNet + MLP + DdpgCnnActor + DdpgCnnQNet + DdpgMlpActor + DdpgMlpQNet + LSTMModule + GRUModule + OnlineDTActor + DTActor + DecisionTransformer diff --git a/docs/source/reference/modules_distributions.rst b/docs/source/reference/modules_distributions.rst new file mode 100644 index 00000000000..65ba204cee0 --- /dev/null +++ b/docs/source/reference/modules_distributions.rst @@ -0,0 +1,20 @@ +.. currentmodule:: torchrl.modules + +Distribution Classes +==================== + +Custom distribution classes for RL, extending PyTorch distributions. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + Delta + IndependentNormal + MaskedCategorical + NormalParamExtractor + OneHotCategorical + ReparamGradientStrategy + TanhDelta + TanhNormal + TruncatedNormal diff --git a/docs/source/reference/modules_exploration.rst b/docs/source/reference/modules_exploration.rst new file mode 100644 index 00000000000..01b25ef05cf --- /dev/null +++ b/docs/source/reference/modules_exploration.rst @@ -0,0 +1,15 @@ +.. currentmodule:: torchrl.modules + +Exploration Strategies +====================== + +Exploration modules add noise to actions to enable exploration during training. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + AdditiveGaussianModule + ConsistentDropoutModule + EGreedyModule + OrnsteinUhlenbeckProcessModule diff --git a/docs/source/reference/modules_models.rst b/docs/source/reference/modules_models.rst new file mode 100644 index 00000000000..be3e74ef0c7 --- /dev/null +++ b/docs/source/reference/modules_models.rst @@ -0,0 +1,18 @@ +.. currentmodule:: torchrl.modules + +World Models and Model-Based RL +=============================== + +Modules for model-based reinforcement learning, including world models and dynamics models. + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + WorldModelWrapper + DreamerActor + ObsEncoder + ObsDecoder + RSSMPosterior + RSSMPrior + RSSMRollout diff --git a/docs/source/reference/modules_utils.rst b/docs/source/reference/modules_utils.rst new file mode 100644 index 00000000000..e202d8af458 --- /dev/null +++ b/docs/source/reference/modules_utils.rst @@ -0,0 +1,16 @@ +.. currentmodule:: torchrl.modules + +Utilities and Helpers +===================== + +Utility modules and helper functions for building RL networks. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ActorValueOperator + ActorCriticOperator + ActorCriticWrapper + Shift + SquashDims diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index f2741809bd3..08c832a4cc3 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -5,413 +5,49 @@ torchrl.objectives package .. _ref_objectives: -TorchRL provides a series of losses to use in your training scripts. -The aim is to have losses that are easily reusable/swappable and that have -a simple signature. - -The main characteristics of TorchRL losses are: - -- they are stateful objects: they contain a copy of the trainable parameters - such that ``loss_module.parameters()`` gives whatever is needed to train the - algorithm. -- They follow the ``tensordict`` convention: the :meth:`torch.nn.Module.forward` - method will receive a tensordict as input that contains all the necessary - information to return a loss value. -- They output a :class:`tensordict.TensorDict` instance with the loss values - written under a ``"loss_"`` where ``smth`` is a string describing the - loss. Additional keys in the tensordict may be useful metrics to log during - training time. - -.. note:: - The reason we return independent losses is to let the user use a different - optimizer for different sets of parameters for instance. Summing the losses - can be simply done via - - >>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_")) - -.. note:: - Initializing parameters in losses can be done via a query to :meth:`~torchrl.objectives.LossModule.get_stateful_net` - which will return a stateful version of the network that can be initialized like any other module. - If the modification is done in-place, it will be downstreamed to any other module that uses the same parameter - set (within and outside of the loss): for instance, modifying the ``actor_network`` parameters from the loss - will also modify the actor in the collector. - If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be - used to reset the parameters in the loss to the new value. - -torch.vmap and randomness -------------------------- - -TorchRL loss modules have plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models -in a loop, and instead vectorize these operations. `vmap` needs to be told explicitly what to do when random numbers -need to be generated within the call. To do this, a randomness mode need to be set and must be one of `"error"` (default, -errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"` -(each element of the batch is treated separately). -Relying on the default will typically result in an error such as this one: - - >>> RuntimeError: vmap: called random operation while in randomness error mode. 
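+A short sketch of the fix (``actor`` and ``qvalue`` are placeholders; the ``vmap_randomness``
+attribute is the interface described in the paragraph that follows):
+
+.. code-block:: python
+
+    from torchrl.objectives import SACLoss
+
+    loss = SACLoss(actor_network=actor, qvalue_network=qvalue)
+    # tell the vmapped calls how to handle random ops (e.g. dropout in the networks)
+    loss.vmap_randomness = "different"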
- -Since the calls to `vmap` are buried down the loss modules, TorchRL -provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see -:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information. - -``LossModule.vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in -other cases. By default, only a limited number of modules are listed as random, but the list can be extended -using the :func:`~torchrl.objectives.common.add_random_module` function. - -Training value functions ------------------------- - -TorchRL provides a range of **value estimators** such as TD(0), TD(1), TD(:math:`\lambda`) -and GAE. -In a nutshell, a value estimator is a function of data (mostly -rewards and done states) and a state value (ie. the value -returned by a function that is fit to estimate state-values). -To learn more about value estimators, check the introduction to RL from `Sutton -and Barto `_, -in particular the chapters about value iteration and TD learning. -It gives a somewhat biased estimation of the discounted return following a state -or a state-action pair based on data and proxy maps. These estimators are -used in two contexts: - -- To train the value network to learn the "true" state value (or state-action - value) map, one needs a target value to fit it to. The better (less bias, - less variance) the estimator, the better the value network will be, which in - turn can speed up the policy training significantly. Typically, the value - network loss will look like: - - >>> value = value_network(states) - >>> target_value = value_estimator(rewards, done, value_network(next_state)) - >>> value_net_loss = (value - target_value).pow(2).mean() - -- Computing an "advantage" signal for policy-optimization. The advantage is - the delta between the value estimate (from the estimator, ie from "real" data) - and the output of the value network (ie the proxy to this value). A positive - advantage can be seen as a signal that the policy actually performed better - than expected, thereby signaling that there is room for improvement if that - trajectory is to be taken as example. Conversely, a negative advantage signifies - that the policy underperformed compared to what was to be expected. - -Thins are not always as easy as in the example above and the formula to compute -the value estimator or the advantage may be slightly more intricate than this. -To help users flexibly use one or another value estimator, we provide a simple -API to change it on-the-fly. Here is an example with DQN, but all modules will -follow a similar structure: - - >>> from torchrl.objectives import DQNLoss, ValueEstimators - >>> loss_module = DQNLoss(actor) - >>> kwargs = {"gamma": 0.9, "lmbda": 0.9} - >>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs) - -The :class:`~torchrl.objectives.ValueEstimators` class enumerates the value -estimators to choose from. This makes it easy for the users to rely on -auto-completion to make their choice. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - LossModule - add_random_module - -DQN ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DQNLoss - DistributionalDQNLoss - -DDPG ----- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DDPGLoss - -SAC ---- - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - SACLoss - DiscreteSACLoss - -REDQ ----- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - REDQLoss - -CrossQ ------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - CrossQLoss - -IQL ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - IQLLoss - DiscreteIQLLoss - -CQL ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - CQLLoss - DiscreteCQLLoss - -GAIL ----- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - GAILLoss - -DT --- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DTLoss - OnlineDTLoss - -TD3 ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - TD3Loss - -TD3+BC ------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - TD3BCLoss - -PPO ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - PPOLoss - ClipPPOLoss - KLPENPPOLoss - -Using PPO with multi-head action policies -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. note:: The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`, - :class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`. - When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the - beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not - be aggregated but rather written as leaves in the tensordict. - -In some cases, we have a single advantage value but more than one action undertaken. Each action has its own -log-probability, and shape. For instance, it can be that the action space is structured as follows: - - >>> action_td = TensorDict( - ... agents=TensorDict( - ... action0=Tensor(batch, n_agents, f0), - ... action1=Tensor(batch, n_agents, f1, f2), - ... batch_size=torch.Size((batch, n_agents)) - ... ), - ... batch_size=torch.Size((batch,)) - ... ) - -where `f0`, `f1` and `f2` are some arbitrary integers. - -Note that, in TorchRL, the root tensordict has the shape of the environment (if the environment is batch-locked, otherwise it -has the shape of the number of batched environments being run). If the tensordict is sampled from the buffer, it will -also have the shape of the replay buffer `batch_size`. The `n_agent` dimension, although common to each action, does not -in general appear in the root tensordict's batch-size (although it appears in the sub-tensordict containing the -agent-specific data according to the :ref:`MARL API `). - -There is a legitimate reason why this is the case: the number of agent may condition some but not all the specs of the -environment. For example, some environments have a shared done state among all agents. A more complete tensordict -would in this case look like - - >>> action_td = TensorDict( - ... agents=TensorDict( - ... action0=Tensor(batch, n_agents, f0), - ... action1=Tensor(batch, n_agents, f1, f2), - ... observation=Tensor(batch, n_agents, f3), - ... batch_size=torch.Size((batch, n_agents)) - ... ), - ... done=Tensor(batch, 1), - ... [...] # etc - ... batch_size=torch.Size((batch,)) - ... 
) - -Notice that `done` states and `reward` are usually flanked by a rightmost singleton dimension. See this :ref:`part of the doc ` -to learn more about this restriction. - -The log-probability of our actions given their respective distributions may look like anything like - - >>> action_td = TensorDict( - ... agents=TensorDict( - ... action0_log_prob=Tensor(batch, n_agents), - ... action1_log_prob=Tensor(batch, n_agents, f1), - ... batch_size=torch.Size((batch, n_agents)) - ... ), - ... batch_size=torch.Size((batch,)) - ... ) - -or - - >>> action_td = TensorDict( - ... agents=TensorDict( - ... action0_log_prob=Tensor(batch, n_agents), - ... action1_log_prob=Tensor(batch, n_agents), - ... batch_size=torch.Size((batch, n_agents)) - ... ), - ... batch_size=torch.Size((batch,)) - ... ) - -ie, the number of dimensions of distributions log-probabilities generally varies from the sample's dimensionality to -anything inferior to that, e.g. if the distribution is multivariate -- :class:`~torch.distributions.Dirichlet` for -instance -- or an :class:`~torch.distributions.Independent` instance. -The dimension of the tensordict, on the contrary, still matches the env's / replay-buffer's batch-size. - -During a call to the PPO loss, the loss module will schematically execute the following set of operations: - - >>> def ppo(tensordict): - ... prev_log_prob = tensordict.select(*log_prob_keys) - ... action = tensordict.select(*action_keys) - ... new_log_prob = dist.log_prob(action) - ... log_weight = new_log_prob - prev_log_prob - ... advantage = tensordict.get("advantage") # computed by GAE earlier - ... # attempt to map shape - ... log_weight.batch_size = advantage.batch_size[:-1] - ... log_weight = sum(log_weight.sum(dim="feature").values(True, True)) # get a single tensor of log_weights - ... return minimum(log_weight.exp() * advantage, log_weight.exp().clamp(1-eps, 1+eps) * advantage) - -To appreciate what a PPO pipeline looks like with multihead policies, an example can be found in the library's -`example directory `__. - - -A2C ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - A2CLoss - -Reinforce ---------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - ReinforceLoss - -Dreamer -------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DreamerActorLoss - DreamerModelLoss - DreamerValueLoss - -Multi-agent objectives +TorchRL provides a comprehensive collection of loss modules for reinforcement learning algorithms. +These losses are designed to be stateful, reusable, and follow the tensordict convention. + +Key Features +------------ + +- **Stateful objects**: Contain trainable parameters accessible via ``loss_module.parameters()`` +- **TensorDict convention**: Input and output use TensorDict format +- **Structured output**: Loss values returned with ``"loss_"`` keys +- **Value estimators**: Support for TD(0), TD(λ), GAE, and more +- **Vmap support**: Efficient batched operations with customizable randomness modes + +Quick Example +------------- + +.. 
code-block:: python + + from torchrl.objectives import DDPGLoss + from torchrl.modules import Actor, ValueOperator + + # Create loss module + loss = DDPGLoss( + actor_network=actor, + value_network=value, + gamma=0.99, + ) + + # Compute loss + td = collector.rollout() + loss_vals = loss(td) + + # Get total loss + total_loss = sum(v for k, v in loss_vals.items() if k.startswith("loss_")) + +Documentation Sections ---------------------- -.. currentmodule:: torchrl.objectives.multiagent - -These objectives are specific to multi-agent algorithms. - -QMixer -~~~~~~ - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - QMixerLoss - - -Returns -------- - -.. _ref_returns: - -.. currentmodule:: torchrl.objectives.value - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - ValueEstimatorBase - TD0Estimator - TD1Estimator - TDLambdaEstimator - GAE - functional.td0_return_estimate - functional.td0_advantage_estimate - functional.td1_return_estimate - functional.vec_td1_return_estimate - functional.td1_advantage_estimate - functional.vec_td1_advantage_estimate - functional.td_lambda_return_estimate - functional.vec_td_lambda_return_estimate - functional.td_lambda_advantage_estimate - functional.vec_td_lambda_advantage_estimate - functional.generalized_advantage_estimate - functional.vec_generalized_advantage_estimate - functional.reward2go - - -Utils ------ - -.. currentmodule:: torchrl.objectives - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst +.. toctree:: + :maxdepth: 2 - HardUpdate - SoftUpdate - ValueEstimators - default_value_kwargs - distance_loss - group_optimizers - hold_out_net - hold_out_params - next_state_value + objectives_common + objectives_value + objectives_policy + objectives_actorcritic + objectives_offline + objectives_other diff --git a/docs/source/reference/objectives_actorcritic.rst b/docs/source/reference/objectives_actorcritic.rst new file mode 100644 index 00000000000..7ba0690f185 --- /dev/null +++ b/docs/source/reference/objectives_actorcritic.rst @@ -0,0 +1,17 @@ +.. currentmodule:: torchrl.objectives + +Actor-Critic Methods +==================== + +Loss modules for actor-critic algorithms. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DDPGLoss + SACLoss + DiscreteSACLoss + TD3Loss + REDQLoss + CrossQLoss diff --git a/docs/source/reference/objectives_common.rst b/docs/source/reference/objectives_common.rst new file mode 100644 index 00000000000..34cd00be168 --- /dev/null +++ b/docs/source/reference/objectives_common.rst @@ -0,0 +1,27 @@ +.. currentmodule:: torchrl.objectives + +Common Components +================= + +Base classes and common utilities for all loss modules. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + LossModule + add_random_module + +Value Estimators +---------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ValueEstimatorBase + ValueEstimators + TD0Estimator + TD1Estimator + TDLambdaEstimator + GAE diff --git a/docs/source/reference/objectives_offline.rst b/docs/source/reference/objectives_offline.rst new file mode 100644 index 00000000000..44a160cc4d2 --- /dev/null +++ b/docs/source/reference/objectives_offline.rst @@ -0,0 +1,16 @@ +.. currentmodule:: torchrl.objectives + +Offline RL Methods +================== + +Loss modules for offline reinforcement learning. + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + CQLLoss + DiscreteCQLLoss + IQLLoss + DiscreteIQLLoss + TD3BCLoss diff --git a/docs/source/reference/objectives_other.rst b/docs/source/reference/objectives_other.rst new file mode 100644 index 00000000000..018268ed7f6 --- /dev/null +++ b/docs/source/reference/objectives_other.rst @@ -0,0 +1,17 @@ +.. currentmodule:: torchrl.objectives + +Other Loss Modules +================== + +Additional loss modules for specialized algorithms. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + GAILLoss + DTLoss + OnlineDTLoss + DreamerActorLoss + DreamerModelLoss + DreamerValueLoss diff --git a/docs/source/reference/objectives_policy.rst b/docs/source/reference/objectives_policy.rst new file mode 100644 index 00000000000..5a50bbd9c3a --- /dev/null +++ b/docs/source/reference/objectives_policy.rst @@ -0,0 +1,16 @@ +.. currentmodule:: torchrl.objectives + +Policy Gradient Methods +======================= + +Loss modules for policy gradient algorithms. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + PPOLoss + ClipPPOLoss + KLPENPPOLoss + A2CLoss + ReinforceLoss diff --git a/docs/source/reference/objectives_value.rst b/docs/source/reference/objectives_value.rst new file mode 100644 index 00000000000..63f9a37c394 --- /dev/null +++ b/docs/source/reference/objectives_value.rst @@ -0,0 +1,17 @@ +.. currentmodule:: torchrl.objectives + +Value-Based Methods +=================== + +Loss modules for value-based RL algorithms. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DQNLoss + DistributionalDQNLoss + IQLLoss + DiscreteIQLLoss + CQLLoss + DiscreteCQLLoss diff --git a/docs/source/reference/services.rst b/docs/source/reference/services.rst new file mode 100644 index 00000000000..7ed25012abe --- /dev/null +++ b/docs/source/reference/services.rst @@ -0,0 +1,609 @@ +.. currentmodule:: torchrl + +Service Registry +================ + +.. _ref_services: + +TorchRL provides a service registry system for managing distributed services across workers in distributed applications. +This is particularly useful for sharing resources like tokenizers, replay buffers, or Python executor pools across +multiple environments or collectors. + +The service registry provides a **backend-agnostic API** for distributed service management. While the current +implementation focuses on Ray as the primary backend, the design allows for future backends (e.g., Monarch, +local multiprocessing) without changing the core API. + +Overview +-------- + +The service registry provides a centralized way to register and access distributed services that can be shared across +different parts of your application. Services are registered once and can be accessed by any worker, with the underlying +backend handling the distributed communication and resource management. 
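+
+The examples below use a ``TokenizerService`` class as a stand-in for whatever resource you want to share.
+This class is purely illustrative (it is not part of TorchRL); a minimal sketch might look like:
+
+.. code-block:: python
+
+    class TokenizerService:
+        """Hypothetical service shared across workers via the registry."""
+
+        def __init__(self, vocab_size: int):
+            self.vocab_size = vocab_size
+
+        def encode(self, text: str) -> list[int]:
+            # Toy encoding; a real service would wrap an actual tokenizer
+            return [ord(char) % self.vocab_size for char in text]
+
+Any plain Python class can be registered this way; the backend wraps it as a distributed service.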
+ +**Current Backend Support:** + +- **Ray**: Full support for Ray-based distributed services (recommended for production use) +- **Other backends**: Planned for future releases (e.g., Monarch, local multiprocessing) + +Key Features +~~~~~~~~~~~~ + +- **Centralized Management**: Register services once and access them from anywhere in your distributed system +- **Namespace Isolation**: Services are isolated within namespaces for multi-tenant support +- **Type Safety**: Dict-like access with ``services["name"]`` syntax +- **Automatic Cleanup**: Reset all services in a namespace with a single call +- **Backend Flexibility**: Designed to support multiple distributed backends (currently Ray) + +Basic Usage +----------- + +Getting Started +~~~~~~~~~~~~~~~ + +The service registry API is backend-agnostic, but you need to specify which backend to use when getting the registry. +Currently, Ray is the only supported backend. + +.. code-block:: python + + import ray + from torchrl.services import get_services + + # Initialize your backend (Ray in this example) + ray.init() + + # Get the service registry for your chosen backend + services = get_services(backend="ray", namespace="my_namespace") + + # Register a service (the class will become a distributed service) + services.register( + "tokenizer", + TokenizerService, + vocab_size=50000, + num_cpus=1, # Backend-specific option (Ray) + ) + + # Access the service from any worker + # (other workers just need to call get_services with the same backend and namespace) + services = get_services(backend="ray", namespace="my_namespace") + tokenizer = services["tokenizer"] + + # Call the service (syntax depends on backend) + # For Ray, you need to use .remote() and ray.get() + result = ray.get(tokenizer.encode.remote("Hello world")) + + # Cleanup when done + services.reset() + ray.shutdown() + +Service Registration +~~~~~~~~~~~~~~~~~~~~ + +Services are registered by providing a name, a class (that will become a distributed service), and any initialization arguments. +The exact behavior depends on the backend being used. + +**Basic Registration (Backend-Agnostic):** + +.. code-block:: python + + # Register a service with constructor arguments + services.register( + "my_service", + MyServiceClass, + arg1="value1", + arg2="value2", + ) + +The ``register`` method accepts: + +- **name** (str): Unique identifier for the service +- **service_factory** (type): Class to instantiate as a distributed service +- **kwargs**: Arguments passed to the service constructor and/or backend-specific options + +**Backend-Specific Options (Ray):** + +When using the Ray backend, you can pass Ray actor options alongside constructor arguments: + +.. code-block:: python + + # Ray-specific: Mix actor options and constructor arguments + services.register( + "gpu_service", + GPUService, + model_name="gpt2", # Constructor argument + num_cpus=4, # Ray actor option + num_gpus=1, # Ray actor option + max_concurrency=16, # Ray actor option + ) + +For more explicit separation of backend options and constructor arguments, the Ray backend provides +``register_with_options`` (note that options are expected not to collide with constructor arguments): + +.. code-block:: python + + # Ray-specific: Explicit separation of options + services.register_with_options( + "my_service", + MyServiceClass, + actor_options={ + "num_cpus": 4, + "num_gpus": 1, + "max_concurrency": 16, + }, + model_name="gpt2", # Constructor argument + batch_size=32, # Constructor argument + ) + +.. 
note:: + The ``register_with_options`` method is specific to the Ray backend. Other backends may have + different mechanisms for separating backend options from constructor arguments. + +Service Access +~~~~~~~~~~~~~~ + +Services can be accessed using dict-like syntax: + +.. code-block:: python + + # Check if service exists + if "tokenizer" in services: + tokenizer = services["tokenizer"] + + # Get service (raises KeyError if not found) + tokenizer = services["tokenizer"] + + # Alternative: use get() method + tokenizer = services.get("tokenizer") + + # List all services + service_names = services.list() + print(f"Available services: {service_names}") + +Cross-Worker Visibility +~~~~~~~~~~~~~~~~~~~~~~~ + +Services registered by one worker are immediately visible to all other workers in the same namespace. +This is a core feature of the service registry, enabled by the underlying distributed backend. + +**Example with Ray Backend:** + +.. code-block:: python + + import ray + from torchrl.services import get_services + + @ray.remote + class Worker: + def register_service(self): + # Worker 1: Register a service + services = get_services(backend="ray", namespace="shared") + services.register("shared_tokenizer", TokenizerService, vocab_size=50000) + return "registered" + + def use_service(self): + # Worker 2: Use the service registered by Worker 1 + services = get_services(backend="ray", namespace="shared") + tokenizer = services["shared_tokenizer"] + return ray.get(tokenizer.encode.remote("Hello")) + + # Worker 1 registers the service + worker1 = Worker.remote() + ray.get(worker1.register_service.remote()) + + # Worker 2 can immediately use it + worker2 = Worker.remote() + result = ray.get(worker2.use_service.remote()) + +The key insight is that both workers access the same service registry by using the same ``backend`` and +``namespace`` parameters in ``get_services()``. The backend handles the distributed coordination. + +Namespace Isolation +~~~~~~~~~~~~~~~~~~~ + +Different namespaces provide complete isolation between service registries: + +.. code-block:: python + + # Training namespace + train_services = get_services(backend="ray", namespace="training") + train_services.register("tokenizer", TokenizerService, vocab_size=50000) + + # Evaluation namespace + eval_services = get_services(backend="ray", namespace="evaluation") + eval_services.register("tokenizer", TokenizerService, vocab_size=30000) + + # These are completely independent services + assert "tokenizer" in train_services + assert "tokenizer" in eval_services + # But they have different configurations + +Cleanup +~~~~~~~ + +Always clean up services when done to free resources: + +.. code-block:: python + + # Reset all services in a namespace + services.reset() + + # This terminates all service actors and clears the registry + # After reset(), the registry is empty + assert services.list() == [] + +Python Executor Service +----------------------- + +One of the most useful built-in services is the :class:`~torchrl.envs.llm.transforms.PythonExecutorService`, +which provides a shared pool of Python interpreter processes for executing code across multiple environments. +This service is designed to work with any backend, though it's currently optimized for Ray. + +Overview +~~~~~~~~ + +The Python Executor Service allows you to share a fixed pool of Python interpreters (e.g., 32 processes) across +many environments (e.g., 128 environments). 
This provides significant resource savings compared to giving each +environment its own interpreter process. The service is registered through the service registry and can be +accessed by any worker using the :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform. + +**Resource Efficiency:** + ++---------------------------+---------------+------------+------------------+ +| Configuration | Environments | Processes | Resource Usage | ++===========================+===============+============+==================+ +| Local (persistent) | 128 | 128 | 100% | ++---------------------------+---------------+------------+------------------+ +| Service (pool=32) | 128 | 32 | **25%** | ++---------------------------+---------------+------------+------------------+ +| Service (pool=64) | 128 | 64 | **50%** | ++---------------------------+---------------+------------+------------------+ + +Basic Usage +~~~~~~~~~~~ + +The Python Executor Service is registered like any other service, then accessed through the +:class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform by specifying ``services="ray"`` +(or the appropriate backend name). + +**Example with Ray Backend:** + +.. code-block:: python + + import ray + from torchrl.services import get_services + from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter + from torchrl.envs.llm import ChatEnv + + # Initialize your backend + ray.init() + + # Register the Python executor service + services = get_services(backend="ray", namespace="my_namespace") + services.register( + "python_executor", + PythonExecutorService, + pool_size=32, # 32 interpreter processes + timeout=10.0, # 10 second timeout + num_cpus=32, # Ray-specific: Allocate 32 CPUs + max_concurrency=32, # Ray-specific: Allow 32 concurrent executions + ) + + # Create environments that use the service + env = ChatEnv( + batch_size=(128,), # 128 parallel environments + system_prompt="Execute Python code when requested.", + ) + + # Add PythonInterpreter transform configured to use the service + env = env.append_transform( + PythonInterpreter( + services="ray", # Use Ray backend + namespace="my_namespace", # Same namespace as registration + ) + ) + + # All 128 environments now share the 32 interpreters! + # The backend (Ray) automatically queues requests when all interpreters are busy + +Optional Service Usage +~~~~~~~~~~~~~~~~~~~~~~ + +The :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform supports optional service usage. +You can easily switch between using a shared service or local processes: + +.. code-block:: python + + # Option 1: Use shared Ray service (recommended for many envs) + env = env.append_transform( + PythonInterpreter( + services="ray", + namespace="my_namespace", + ) + ) + + # Option 2: Use local persistent processes (good for few envs) + env = env.append_transform( + PythonInterpreter( + services=None, + persistent=True, + ) + ) + + # Option 3: Use temporary processes (good for infrequent use) + env = env.append_transform( + PythonInterpreter( + services=None, + persistent=False, + ) + ) + +Conditional Usage Pattern +~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can decide at runtime whether to use a distributed service based on your configuration: + +.. 
code-block:: python + + import ray + from torchrl.services import get_services + from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter + + num_envs = 128 + use_distributed_service = ray.is_initialized() and num_envs > 16 + + if use_distributed_service: + # Use distributed service for efficient resource usage + services = get_services(backend="ray") # Could be other backends in the future + if "python_executor" not in services: + services.register( + "python_executor", + PythonExecutorService, + pool_size=32, + timeout=10.0, + num_cpus=32, # Backend-specific option + max_concurrency=32, # Backend-specific option + ) + + # Configure transform to use the service + interpreter = PythonInterpreter(services="ray") + else: + # Use local processes (no distributed backend) + interpreter = PythonInterpreter(services=None, persistent=True) + + env = env.append_transform(interpreter) + +How It Works +~~~~~~~~~~~~ + +The Python Executor Service uses a simple round-robin assignment strategy to distribute work across +a pool of interpreter processes. The backend handles concurrency control and request queuing. + +**Architecture:** + +1. **Pool of Interpreters**: The service maintains a fixed pool of ``PersistentPythonProcess`` instances +2. **Round-Robin Assignment**: Each request is assigned to the next interpreter in the pool +3. **Backend Queuing**: When all interpreters are busy, the backend queues additional requests +4. **Concurrent Execution**: The backend controls how many requests can execute simultaneously + +.. code-block:: python + + # Inside PythonExecutorService + def execute(self, code: str) -> dict: + # Simple round-robin assignment + with self._lock: + process = self.processes[self.next_idx] + self.next_idx = (self.next_idx + 1) % self.pool_size + + # Backend handles queuing (e.g., Ray's max_concurrency parameter) + return process.execute(code) + +**Backend-Specific Behavior:** + +- **Ray**: Uses the ``max_concurrency`` parameter to control concurrent executions. Requests beyond + this limit are automatically queued by Ray's actor system. +- **Other backends**: Will have their own mechanisms for concurrency control and queuing. + +Performance Considerations +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**When to Use Service Mode (Distributed):** + +- Running > 16 parallel environments +- Resource efficiency is important +- Code execution is frequent +- Have a distributed backend available (e.g., Ray) + +**When to Use Local Persistent Mode:** + +- Running < 16 environments +- Need strict isolation between environments +- Latency is critical +- Don't want distributed backend dependency + +**When to Use Local Temp File Mode:** + +- Code execution is infrequent +- Don't want persistent processes +- Memory is more important than speed + +Advanced Usage +-------------- + +Multiple Service Configurations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can register multiple services with different configurations: + +.. 
code-block:: python + + services = get_services(backend="ray") + + # Fast service for simple code + services.register( + "python_executor_fast", + PythonExecutorService, + pool_size=16, + timeout=5.0, + num_cpus=16, + max_concurrency=16, + ) + + # Heavy service for complex code + services.register( + "python_executor_heavy", + PythonExecutorService, + pool_size=64, + timeout=30.0, + num_cpus=64, + max_concurrency=64, + ) + + # Use different services for different environments + fast_env = env.append_transform( + PythonInterpreter(services="ray", service_name="python_executor_fast") + ) + heavy_env = env.append_transform( + PythonInterpreter(services="ray", service_name="python_executor_heavy") + ) + +Custom Services +~~~~~~~~~~~~~~~ + +You can create your own services by defining a class and registering it: + +.. code-block:: python + + class MyCustomService: + """A custom service for your application.""" + + def __init__(self, config: dict): + self.config = config + # Initialize your service + + def process(self, data: str) -> dict: + # Process data and return results + return {"result": f"Processed: {data}"} + + # Register the custom service + services = get_services(backend="ray") + services.register( + "my_service", + MyCustomService, + config={"param1": "value1"}, + num_cpus=2, + ) + + # Use the service + my_service = services["my_service"] + result = ray.get(my_service.process.remote("Hello")) + +API Reference +------------- + +Service Registry +~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchrl.services + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + get_services + reset_services + ServiceBase + RayService + +Python Executor Service +~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchrl.envs.llm.transforms + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + PythonExecutorService + PythonInterpreter + +Best Practices +-------------- + +1. **Specify Backend and Namespace**: Always explicitly specify both the backend and namespace when calling + ``get_services()`` to ensure services are registered and accessed from the correct location. + +2. **Clean Up**: Always call ``services.reset()`` when done to free resources and terminate distributed services. + +3. **Service Naming**: Use descriptive names that indicate the service's purpose (e.g., ``"python_executor"``, + ``"tokenizer_service"``). + +4. **Backend-Specific Options**: Understand which options are backend-specific (e.g., ``num_cpus``, ``num_gpus``, + ``max_concurrency`` for Ray) and which are constructor arguments for your service class. + +5. **Error Handling**: Check if services exist before accessing them: + + .. code-block:: python + + if "my_service" in services: + service = services["my_service"] + else: + # Register or handle missing service + +6. **Conditional Registration**: Only register services if they don't already exist: + + .. code-block:: python + + if "python_executor" not in services: + services.register("python_executor", PythonExecutorService, ...) + +7. **Context Managers**: Consider using context managers for automatic cleanup: + + .. code-block:: python + + class ServiceContext: + def __init__(self, backend, namespace): + self.services = get_services(backend=backend, namespace=namespace) + + def __enter__(self): + return self.services + + def __exit__(self, *args): + self.services.reset() + + with ServiceContext("ray", "my_namespace") as services: + services.register("my_service", MyService) + # Use services... + # Automatic cleanup + +8. 
**Backend Portability**: When writing code that should work with multiple backends, avoid using + backend-specific methods like ``register_with_options()`` (Ray-only). Stick to the common ``register()`` + API for maximum portability. + +Examples +-------- + +For complete examples, see: + +- ``examples/services/distributed_services.py`` - Basic service registry usage +- ``examples/llm/python_executor_service.py`` - Python executor service examples +- ``test/test_services.py`` - Comprehensive test suite +- ``test/test_python_executor_service.py`` - Python executor service tests + +See Also +-------- + +- :ref:`ref_llms` - LLM API documentation +- :ref:`ref_collectors` - Collector documentation +- `Ray Documentation `_ - Ray distributed framework documentation + +.. note:: + **Future Backend Support** + + The service registry is designed to be backend-agnostic. While Ray is currently the only supported + backend, the API is structured to easily accommodate additional backends in the future, such as: + + - **Monarch**: For specialized distributed computing scenarios + - **Local Multiprocessing**: For single-node parallelism without external dependencies + - **Custom Backends**: You can implement your own backend by subclassing :class:`~torchrl.services.ServiceBase` + + The core API (``get_services()``, ``register()``, ``get()``, ``list()``, ``reset()``) will remain + consistent across all backends, ensuring your code remains portable. diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index c47436d11a8..adb98e78445 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -7,417 +7,45 @@ torchrl.trainers package The trainer package provides utilities to write re-usable training scripts. The core idea is to use a trainer that implements a nested loop, where the outer loop runs the data collection steps and the inner -loop the optimization steps. We believe this fits multiple RL training schemes, such as -on-policy, off-policy, model-based and model-free solutions, offline RL and others. -More particular cases, such as meta-RL algorithms may have training schemes that differ substantially. +loop the optimization steps. -The ``trainer.train()`` method can be sketched as follows: +Key Features +------------ -.. code-block:: - :caption: Trainer loops +- **Modular hook system**: Customize training at 10 different points in the loop +- **Checkpointing support**: Save and restore training state with torch or torchsnapshot +- **Algorithm trainers**: High-level trainers for PPO, SAC with Hydra configuration +- **Builder helpers**: Utilities for constructing collectors, losses, and replay buffers - >>> for batch in collector: - ... batch = self._process_batch_hook(batch) # "batch_process" - ... self._pre_steps_log_hook(batch) # "pre_steps_log" - ... self._pre_optim_hook() # "pre_optim_steps" - ... for j in range(self.optim_steps_per_batch): - ... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch" - ... losses = self.loss_module(sub_batch) - ... self._post_loss_hook(sub_batch) # "post_loss" - ... self.optimizer.step() - ... self.optimizer.zero_grad() - ... self._post_optim_hook() # "post_optim" - ... self._post_optim_log(sub_batch) # "post_optim_log" - ... self._post_steps_hook() # "post_steps" - ... self._post_steps_log_hook(batch) # "post_steps_log" - - There are 10 hooks that can be used in a trainer loop: - - >>> for batch in collector: - ... batch = self._process_batch_hook(batch) # "batch_process" - ... 
self._pre_steps_log_hook(batch) # "pre_steps_log" - ... self._pre_optim_hook() # "pre_optim_steps" - ... for j in range(self.optim_steps_per_batch): - ... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch" - ... losses = self.loss_module(sub_batch) - ... self._post_loss_hook(sub_batch) # "post_loss" - ... self.optimizer.step() - ... self.optimizer.zero_grad() - ... self._post_optim_hook() # "post_optim" - ... self._post_optim_log(sub_batch) # "post_optim_log" - ... self._post_steps_hook() # "post_steps" - ... self._post_steps_log_hook(batch) # "post_steps_log" - - There are 10 hooks that can be used in a trainer loop: - - >>> for batch in collector: - ... batch = self._process_batch_hook(batch) # "batch_process" - ... self._pre_steps_log_hook(batch) # "pre_steps_log" - ... self._pre_optim_hook() # "pre_optim_steps" - ... for j in range(self.optim_steps_per_batch): - ... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch" - ... losses = self.loss_module(sub_batch) - ... self._post_loss_hook(sub_batch) # "post_loss" - ... self.optimizer.step() - ... self.optimizer.zero_grad() - ... self._post_optim_hook() # "post_optim" - ... self._post_optim_log(sub_batch) # "post_optim_log" - ... self._post_steps_hook() # "post_steps" - ... self._post_steps_log_hook(batch) # "post_steps_log" - -There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``, -``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``, -``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied. -Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``), -**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operations** hook -(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``). - -- **Data processing** hooks update a tensordict of data. Hooks ``__call__`` method should accept - a ``TensorDict`` object as input and update it given some strategy. - Examples of such hooks include Replay Buffer extension (``ReplayBufferTrainer.extend``), data normalization (including normalization - constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such. - -- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger - some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward - logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the - data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value - should be displayed on the progression bar printed on the training log. - -- **Operation** hooks are hooks that execute specific operations over the models, data collectors, - target network updates and such. For instance, syncing the weights of the collectors using ``UpdateWeights`` - or update the priority of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples - of operation hooks. They are data-independent (they do not require a ``TensorDict`` - input), they are just supposed to be executed once at every iteration (or every N iterations). 
- -The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``, -and all implement three base methods: a ``state_dict`` and ``load_state_dict`` method for -checkpointing and a ``register`` method that registers the hook at the default value in the -trainer. This method takes a trainer and a module name as input. For instance, the following logging -hook is executed every 10 calls to ``"post_optim_log"``: - -.. code-block:: - - >>> class LoggingHook(TrainerHookBase): - ... def __init__(self): - ... self.counter = 0 - ... - ... def register(self, trainer, name): - ... trainer.register_module(self, "logging_hook") - ... trainer.register_op("post_optim_log", self) - ... - ... def save_dict(self): - ... return {"counter": self.counter} - ... - ... def load_state_dict(self, state_dict): - ... self.counter = state_dict["counter"] - ... - ... def __call__(self, batch): - ... if self.counter % 10 == 0: - ... self.counter += 1 - ... out = {"some_value": batch["some_value"].item(), "log_pbar": False} - ... else: - ... out = None - ... self.counter += 1 - ... return out - -Checkpointing +Quick Example ------------- -The trainer class and hooks support checkpointing, which can be achieved either -using the `torchsnapshot `_ backend or -the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``: - -.. code-block:: - - $ CKPT_BACKEND=torchsnapshot python script.py - -``CKPT_BACKEND`` defaults to ``torch``. The advantage of torchsnapshot over pytorch -is that it is a more flexible API, which supports distributed checkpointing and -also allows users to load tensors from a file stored on disk to a tensor with a -physical storage (which pytorch currently does not support). This allows, for instance, -to load tensors from and to a replay buffer that would otherwise not fit in memory. - -When building a trainer, one can provide a path where the checkpoints are to -be written. With the ``torchsnapshot`` backend, a directory path is expected, -whereas the ``torch`` backend expects a file path (typically a ``.pt`` file). - -.. code-block:: - - >>> filepath = "path/to/dir/or/file" - >>> trainer = Trainer( - ... collector=collector, - ... total_frames=total_frames, - ... frame_skip=frame_skip, - ... loss_module=loss_module, - ... optimizer=optimizer, - ... save_trainer_file=filepath, - ... ) - >>> select_keys = SelectKeys(["action", "observation"]) - >>> select_keys.register(trainer) - >>> # to save to a path - >>> trainer.save_trainer(True) - >>> # to load from a path - >>> trainer.load_from_file(filepath) - -The ``Trainer.train()`` method can be used to execute the above loop with all of -its hooks, although using the :obj:`Trainer` class for its checkpointing capability -only is also a perfectly valid use. - - -Trainer and hooks ------------------ - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - BatchSubSampler - ClearCudaCache - CountFramesLog - LogScalar - OptimizerHook - LogValidationReward - ReplayBufferTrainer - RewardNormalizer - SelectKeys - Trainer - TrainerHookBase - UpdateWeights - TargetNetUpdaterHook - UTDRHook - - -Algorithm-specific trainers (Experimental) ------------------------------------------- - -.. warning:: - The following trainers are experimental/prototype features. The API may change in future versions. - Please report any issues or feedback to help improve these implementations. 
- -TorchRL provides high-level, algorithm-specific trainers that combine the modular components -into complete training solutions with sensible defaults and comprehensive configuration options. - -.. currentmodule:: torchrl.trainers.algorithms - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - PPOTrainer - SACTrainer - -Algorithm Trainers -~~~~~~~~~~~~~~~~~~ - -TorchRL provides high-level algorithm trainers that offer complete training solutions with minimal code. -These trainers feature comprehensive configuration systems built on Hydra, enabling both simple usage -and sophisticated customization. - -**Currently Available:** - -- :class:`~torchrl.trainers.algorithms.PPOTrainer` - Proximal Policy Optimization -- :class:`~torchrl.trainers.algorithms.SACTrainer` - Soft Actor-Critic - -**Key Features:** - -- **Complete pipeline**: Environment setup, data collection, and optimization -- **Hydra configuration**: Extensive dataclass-based configuration system -- **Built-in logging**: Rewards, actions, and algorithm-specific metrics -- **Modular design**: Built on existing TorchRL components -- **Minimal code**: Complete SOTA implementations in ~20 lines! - -.. warning:: - Algorithm trainers are experimental features. The API may change in future versions. - We welcome feedback and contributions to help improve these implementations! - -Quick Start Examples -^^^^^^^^^^^^^^^^^^^^ - -**PPO Training:** - -.. code-block:: bash - - # Train PPO on Pendulum-v1 with default settings - python sota-implementations/ppo_trainer/train.py - -**SAC Training:** - -.. code-block:: bash - - # Train SAC on a continuous control task - python sota-implementations/sac_trainer/train.py - -**Custom Configuration:** - -.. code-block:: bash - - # Override parameters for any algorithm - python sota-implementations/ppo_trainer/train.py \ - trainer.total_frames=2000000 \ - training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \ - networks.policy_network.num_cells=[256,256] \ - optimizer.lr=0.0003 - -**Environment Switching:** - -.. code-block:: bash - - # Switch environment and logger for any trainer - python sota-implementations/sac_trainer/train.py \ - training_env.create_env_fn.base_env.env_name=Walker2d-v4 \ - logger=tensorboard \ - logger.exp_name=sac_walker2d - -**View Configuration Options:** - -.. code-block:: bash - - # See all available options for any trainer - python sota-implementations/ppo_trainer/train.py --help - python sota-implementations/sac_trainer/train.py --help - -Universal Configuration System -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -All algorithm trainers share a unified configuration architecture organized into logical groups: - -- **Environment**: ``training_env.create_env_fn.base_env.env_name``, ``training_env.num_workers`` -- **Networks**: ``networks.policy_network.num_cells``, ``networks.value_network.num_cells`` -- **Training**: ``trainer.total_frames``, ``trainer.clip_norm``, ``optimizer.lr`` -- **Data**: ``collector.frames_per_batch``, ``replay_buffer.batch_size``, ``replay_buffer.storage.max_size`` -- **Logging**: ``logger.exp_name``, ``logger.project``, ``trainer.log_interval`` - -**Working Example:** - -All trainer implementations follow the same simple pattern: - .. 
code-block:: python - import hydra - from torchrl.trainers.algorithms.configs import * - - @hydra.main(config_path="config", config_name="config", version_base="1.1") - def main(cfg): - trainer = hydra.utils.instantiate(cfg.trainer) - trainer.train() - - if __name__ == "__main__": - main() - -*Complete algorithm training with full configurability in ~20 lines!* - -Configuration Classes -^^^^^^^^^^^^^^^^^^^^^ - -The trainer system uses a hierarchical configuration system with shared components. - -.. note:: - The configuration system requires Python 3.10+ due to its use of modern type annotation syntax. - -**Algorithm-Specific Trainers:** - -- **PPO**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig` -- **SAC**: :class:`~torchrl.trainers.algorithms.configs.trainers.SACTrainerConfig` - -**Shared Configuration Components:** - -- **Environment**: :class:`~torchrl.trainers.algorithms.configs.envs_libs.GymEnvConfig`, :class:`~torchrl.trainers.algorithms.configs.envs.BatchedEnvConfig` -- **Networks**: :class:`~torchrl.trainers.algorithms.configs.modules.MLPConfig`, :class:`~torchrl.trainers.algorithms.configs.modules.TanhNormalModelConfig` -- **Data**: :class:`~torchrl.trainers.algorithms.configs.data.TensorDictReplayBufferConfig`, :class:`~torchrl.trainers.algorithms.configs.collectors.MultiaSyncDataCollectorConfig` -- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`, :class:`~torchrl.trainers.algorithms.configs.objectives.SACLossConfig` -- **Optimizers**: :class:`~torchrl.trainers.algorithms.configs.utils.AdamConfig`, :class:`~torchrl.trainers.algorithms.configs.utils.AdamWConfig` -- **Logging**: :class:`~torchrl.trainers.algorithms.configs.logging.WandbLoggerConfig`, :class:`~torchrl.trainers.algorithms.configs.logging.TensorboardLoggerConfig` - -Algorithm-Specific Features -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -**PPOTrainer:** - -- On-policy learning with advantage estimation -- Policy clipping and value function optimization -- Configurable number of epochs per batch -- Built-in GAE (Generalized Advantage Estimation) - -**SACTrainer:** - -- Off-policy learning with replay buffer -- Entropy-regularized policy optimization -- Target network soft updates -- Continuous action space optimization - -**Future Development:** - -The trainer system is actively expanding. Upcoming features include: - -- Additional algorithms: TD3, DQN, A2C, DDPG, and more -- Enhanced distributed training support -- Advanced configuration validation and error reporting -- Integration with more TorchRL ecosystem components - -See the complete `configuration system documentation `_ for all available options and examples. - - -Builders --------- - -.. currentmodule:: torchrl.trainers.helpers - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - make_collector_offpolicy - make_collector_onpolicy - make_dqn_loss - make_replay_buffer - make_target_updater - make_trainer - parallel_env_constructor - sync_async_collector - sync_sync_collector - transformed_env_constructor - -Utils ------ - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - correct_for_frame_skip - get_stats_random_rollout - -Loggers -------- - -.. _ref_loggers: - -.. currentmodule:: torchrl.record.loggers - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - Logger - csv.CSVLogger - mlflow.MLFlowLogger - tensorboard.TensorboardLogger - wandb.WandbLogger - get_logger - generate_exp_name - - -Recording utils ---------------- - -Recording utils are detailed :ref:`here `. - -.. currentmodule:: torchrl.record - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - VideoRecorder - TensorDictRecorder - PixelRenderTransform + from torchrl.trainers import Trainer + from torchrl.trainers import UpdateWeights, LogScalar + + # Create trainer + trainer = Trainer( + collector=collector, + total_frames=1000000, + loss_module=loss, + optimizer=optimizer, + ) + + # Register hooks + UpdateWeights(collector, 10).register(trainer) + LogScalar("reward").register(trainer) + + # Train + trainer.train() + +Documentation Sections +---------------------- + +.. toctree:: + :maxdepth: 2 + + trainers_basics + trainers_loggers + trainers_hooks diff --git a/docs/source/reference/trainers_basics.rst b/docs/source/reference/trainers_basics.rst new file mode 100644 index 00000000000..217ad92b2ec --- /dev/null +++ b/docs/source/reference/trainers_basics.rst @@ -0,0 +1,58 @@ +.. currentmodule:: torchrl.trainers + +Trainer Basics +============== + +Core trainer classes and builder utilities. + +Trainer and hooks +----------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Trainer + TrainerHookBase + +Algorithm-specific trainers +--------------------------- + +.. currentmodule:: torchrl.trainers.algorithms + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + PPOTrainer + SACTrainer + +Builders +-------- + +.. currentmodule:: torchrl.trainers.helpers + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + make_collector_offpolicy + make_collector_onpolicy + make_dqn_loss + make_replay_buffer + make_target_updater + make_trainer + parallel_env_constructor + sync_async_collector + sync_sync_collector + transformed_env_constructor + +Utils +----- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + correct_for_frame_skip + get_stats_random_rollout diff --git a/docs/source/reference/trainers_hooks.rst b/docs/source/reference/trainers_hooks.rst new file mode 100644 index 00000000000..2a6ed8ba7e8 --- /dev/null +++ b/docs/source/reference/trainers_hooks.rst @@ -0,0 +1,23 @@ +.. currentmodule:: torchrl.trainers + +Training Hooks +============== + +Hooks for customizing the training loop at various points. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + BatchSubSampler + ClearCudaCache + CountFramesLog + LogScalar + OptimizerHook + LogValidationReward + ReplayBufferTrainer + RewardNormalizer + SelectKeys + UpdateWeights + TargetNetUpdaterHook + UTDRHook diff --git a/docs/source/reference/trainers_loggers.rst b/docs/source/reference/trainers_loggers.rst new file mode 100644 index 00000000000..d1bf25c48ba --- /dev/null +++ b/docs/source/reference/trainers_loggers.rst @@ -0,0 +1,33 @@ +.. currentmodule:: torchrl.record.loggers + +.. _ref_loggers: + +Loggers +======= + +Logger classes for experiment tracking and visualization. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + Logger + csv.CSVLogger + mlflow.MLFlowLogger + tensorboard.TensorboardLogger + wandb.WandbLogger + get_logger + generate_exp_name + +Recording utils +--------------- + +.. currentmodule:: torchrl.record + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + VideoRecorder + TensorDictRecorder + PixelRenderTransform From 95a8461983650f00d0ccaea23fe07da06758680d Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Nov 2025 13:49:47 +0000 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- .pre-commit-config.yaml | 2 +- docs/source/reference/data_datasets.rst | 2 +- docs/source/reference/modules_utils.rst | 8 +- docs/source/reference/objectives_common.rst | 11 +- scripts/check-sphinx-section-underline | 127 ++++++++++++++++++++ 5 files changed, 146 insertions(+), 4 deletions(-) create mode 100755 scripts/check-sphinx-section-underline diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b83882b0f54..17ad88a8413 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: types: [python] - id: check-sphinx-section-underline name: Check Sphinx section underline lengths - entry: ./check-sphinx-section-underline --fix + entry: ./scripts/check-sphinx-section-underline --fix language: script files: ^docs/.*\.rst$ pass_filenames: true diff --git a/docs/source/reference/data_datasets.rst b/docs/source/reference/data_datasets.rst index 5946c0ac4df..e7fc8038a04 100644 --- a/docs/source/reference/data_datasets.rst +++ b/docs/source/reference/data_datasets.rst @@ -11,7 +11,7 @@ TorchRL provides dataset utilities for offline RL and data management. datasets.AtariDQNExperienceReplay datasets.D4RLExperienceReplay - datasets.Gen_DGRLExperienceReplay + datasets.GenDGRLExperienceReplay datasets.MinariExperienceReplay datasets.OpenMLExperienceReplay datasets.OpenXExperienceReplay diff --git a/docs/source/reference/modules_utils.rst b/docs/source/reference/modules_utils.rst index e202d8af458..92ec06b2645 100644 --- a/docs/source/reference/modules_utils.rst +++ b/docs/source/reference/modules_utils.rst @@ -12,5 +12,11 @@ Utility modules and helper functions for building RL networks. ActorValueOperator ActorCriticOperator ActorCriticWrapper - Shift + +.. currentmodule:: torchrl.modules.models.utils + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + SquashDims diff --git a/docs/source/reference/objectives_common.rst b/docs/source/reference/objectives_common.rst index 34cd00be168..9ed8197715a 100644 --- a/docs/source/reference/objectives_common.rst +++ b/docs/source/reference/objectives_common.rst @@ -15,13 +15,22 @@ Base classes and common utilities for all loss modules. Value Estimators ---------------- +.. currentmodule:: torchrl.objectives.value + .. autosummary:: :toctree: generated/ :template: rl_template_noinherit.rst ValueEstimatorBase - ValueEstimators TD0Estimator TD1Estimator TDLambdaEstimator GAE + +.. currentmodule:: torchrl.objectives + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ValueEstimators diff --git a/scripts/check-sphinx-section-underline b/scripts/check-sphinx-section-underline new file mode 100755 index 00000000000..cb0ca1be9c2 --- /dev/null +++ b/scripts/check-sphinx-section-underline @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +"""Check that Sphinx/ReST section underlines match title lengths.""" +import sys +import re +from pathlib import Path + +# Pattern to match Sphinx section titles followed by underlines +# Supports: = - ` : . 
' " ~ ^ _ * + # < > +SECTION_PATTERN = re.compile(r"^([^\n]+)\n([~=\-^`':\"#\*_\+.<>]+)\n", re.MULTILINE) + + +def fix_file(path): + """Fix underline length mismatches in a file.""" + try: + text = Path(path).read_text(encoding="utf-8") + except Exception as e: + print(f"Warning: Could not read {path}: {e}") + return False + + original_text = text + fixed_count = 0 + + def replace_underline(match): + nonlocal fixed_count + title, underline = match.groups() + title_stripped = title.strip() + underline_stripped = underline.strip() + + # Skip if title is empty or looks like it might be a code block or other content + if not title_stripped or title_stripped.startswith('..'): + return match.group(0) + + if len(title_stripped) != len(underline_stripped): + # Get the underline character and create correct length underline + underline_char = underline_stripped[0] + correct_underline = underline_char * len(title_stripped) + fixed_count += 1 + return f"{title}\n{correct_underline}\n" + + return match.group(0) + + text = SECTION_PATTERN.sub(replace_underline, text) + + if text != original_text: + Path(path).write_text(text, encoding="utf-8") + return fixed_count + + return 0 + + +def check_file(path): + """Check a single file for underline length mismatches.""" + try: + text = Path(path).read_text(encoding="utf-8") + except Exception as e: + print(f"Warning: Could not read {path}: {e}") + return [] + + errors = [] + + for match in SECTION_PATTERN.finditer(text): + title, underline = match.groups() + title_stripped = title.strip() + underline_stripped = underline.strip() + + # Skip if title is empty or looks like it might be a code block or other content + if not title_stripped or title_stripped.startswith('..'): + continue + + if len(title_stripped) != len(underline_stripped): + # Calculate line number + line_num = text.count('\n', 0, match.start()) + 1 + errors.append( + f"{path}:{line_num}: " + f"title '{title_stripped}' length {len(title_stripped)}, " + f"underline length {len(underline_stripped)}" + ) + + return errors + + +def main(argv): + """Main entry point for the hook.""" + # Check for --fix flag + fix_mode = "--fix" in argv + if fix_mode: + argv = [arg for arg in argv if arg != "--fix"] + + if len(argv) < 2: + print("✅ Sphinx section underline check: no files to check.") + sys.exit(0) + + if fix_mode: + total_fixed = 0 + for path in argv[1:]: + fixed_count = fix_file(path) + if fixed_count: + print(f"✏️ Fixed {fixed_count} section(s) in {path}") + total_fixed += fixed_count + + if total_fixed: + print(f"\n✅ Fixed {total_fixed} section underline(s) total.") + sys.exit(0) + else: + print("✅ Sphinx section underline check: no fixes needed.") + sys.exit(0) + else: + all_errors = [] + for path in argv[1:]: + all_errors.extend(check_file(path)) + + if all_errors: + print("❌ Sphinx section underline length errors:\n") + for e in all_errors: + print(" ", e) + print("\nFix underline lengths to match title text.") + print("Or run with --fix flag to automatically fix them:") + print(f" python3 scripts/check-sphinx-section-underline --fix {' '.join(argv[1:])}") + sys.exit(1) + else: + print("✅ Sphinx section underline check passed.") + sys.exit(0) + + +if __name__ == "__main__": + main(sys.argv) + From 54917c1e669bcc45a6b03bca8d1d5b6466254609 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Nov 2025 14:11:18 +0000 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- docs/source/reference/llms.rst | 2 +- scripts/check-sphinx-section-underline | 47 ++++++++++---------- setup.cfg | 1 + 
torchrl/weight_update/weight_sync_schemes.py | 2 +- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst index dc1f748ea42..9f7f111c46d 100644 --- a/docs/source/reference/llms.rst +++ b/docs/source/reference/llms.rst @@ -471,7 +471,7 @@ SFT :template: rl_template.rst TopKRewardSelector -======= +================== llms_data llms_modules llms_envs diff --git a/scripts/check-sphinx-section-underline b/scripts/check-sphinx-section-underline index cb0ca1be9c2..9d9de2524c9 100755 --- a/scripts/check-sphinx-section-underline +++ b/scripts/check-sphinx-section-underline @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """Check that Sphinx/ReST section underlines match title lengths.""" -import sys import re +import sys from pathlib import Path # Pattern to match Sphinx section titles followed by underlines @@ -16,35 +16,35 @@ def fix_file(path): except Exception as e: print(f"Warning: Could not read {path}: {e}") return False - + original_text = text fixed_count = 0 - + def replace_underline(match): nonlocal fixed_count title, underline = match.groups() title_stripped = title.strip() underline_stripped = underline.strip() - + # Skip if title is empty or looks like it might be a code block or other content - if not title_stripped or title_stripped.startswith('..'): + if not title_stripped or title_stripped.startswith(".."): return match.group(0) - + if len(title_stripped) != len(underline_stripped): # Get the underline character and create correct length underline underline_char = underline_stripped[0] correct_underline = underline_char * len(title_stripped) fixed_count += 1 return f"{title}\n{correct_underline}\n" - + return match.group(0) - + text = SECTION_PATTERN.sub(replace_underline, text) - + if text != original_text: Path(path).write_text(text, encoding="utf-8") return fixed_count - + return 0 @@ -55,27 +55,27 @@ def check_file(path): except Exception as e: print(f"Warning: Could not read {path}: {e}") return [] - + errors = [] - + for match in SECTION_PATTERN.finditer(text): title, underline = match.groups() title_stripped = title.strip() underline_stripped = underline.strip() - + # Skip if title is empty or looks like it might be a code block or other content - if not title_stripped or title_stripped.startswith('..'): + if not title_stripped or title_stripped.startswith(".."): continue - + if len(title_stripped) != len(underline_stripped): # Calculate line number - line_num = text.count('\n', 0, match.start()) + 1 + line_num = text.count("\n", 0, match.start()) + 1 errors.append( f"{path}:{line_num}: " f"title '{title_stripped}' length {len(title_stripped)}, " f"underline length {len(underline_stripped)}" ) - + return errors @@ -85,11 +85,11 @@ def main(argv): fix_mode = "--fix" in argv if fix_mode: argv = [arg for arg in argv if arg != "--fix"] - + if len(argv) < 2: print("✅ Sphinx section underline check: no files to check.") sys.exit(0) - + if fix_mode: total_fixed = 0 for path in argv[1:]: @@ -97,7 +97,7 @@ def main(argv): if fixed_count: print(f"✏️ Fixed {fixed_count} section(s) in {path}") total_fixed += fixed_count - + if total_fixed: print(f"\n✅ Fixed {total_fixed} section underline(s) total.") sys.exit(0) @@ -108,14 +108,16 @@ def main(argv): all_errors = [] for path in argv[1:]: all_errors.extend(check_file(path)) - + if all_errors: print("❌ Sphinx section underline length errors:\n") for e in all_errors: print(" ", e) print("\nFix underline lengths to match title text.") print("Or run with --fix flag to 
automatically fix them:") - print(f" python3 scripts/check-sphinx-section-underline --fix {' '.join(argv[1:])}") + print( + f" python3 scripts/check-sphinx-section-underline --fix {' '.join(argv[1:])}" + ) sys.exit(1) else: print("✅ Sphinx section underline check passed.") @@ -124,4 +126,3 @@ def main(argv): if __name__ == "__main__": main(sys.argv) - diff --git a/setup.cfg b/setup.cfg index 2c7173763fb..525b66f6394 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ per-file-ignores = torchrl/objectives/td3.py: TOR101 torchrl/objectives/value/advantages.py: TOR101 tutorials/*/**.py: T001, T201 + scripts/*: T001, T201 examples/*.py: T001, T201 packaging/verify_nightly_version.py: T001, T201 test/opengl_rendering.py: T001, T201 diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 34b35da9446..42d13108a0f 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -222,7 +222,7 @@ def _send_buffer_to_workers( # Wait for acknowledgments from all workers for pipe in self._pipes: if not pipe.poll(timeout): - raise TimeoutError(f"Timeout waiting for acknowledgment from worker") + raise TimeoutError("Timeout waiting for acknowledgment from worker") _, msg = pipe.recv() if msg != "registered": raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'")