diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04c0f40c2aa..17ad88a8413 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,3 +50,10 @@ repos: entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports language: system types: [python] + - id: check-sphinx-section-underline + name: Check Sphinx section underline lengths + entry: ./scripts/check-sphinx-section-underline --fix + language: script + files: ^docs/.*\.rst$ + pass_filenames: true + description: Ensure Sphinx section underline lengths match section titles. diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index 315e17e082f..6cbfde29b30 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -5,592 +5,67 @@ torchrl.collectors package .. _ref_collectors: -Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they -collect data over non-static data sources and (2) the data is collected using a model -(likely a version of the model that is being trained). +Data collectors are the bridge between your environments and training loop, managing the process of gathering +experience data using your policy. They handle environment resets, policy execution, and data aggregation, +making it easy to collect high-quality training data efficiently. -TorchRL's data collectors accept two main arguments: an environment (or a list of -environment constructors) and a policy. They will iteratively execute an environment -step and a policy query over a defined number of steps before delivering a stack of -the data collected to the user. Environments will be reset whenever they reach a done -state, and/or after a predefined number of steps. +TorchRL provides several collector implementations optimized for different scenarios: -Because data collection is a potentially compute heavy process, it is crucial to -configure the execution hyperparameters appropriately. -The first parameter to take into consideration is whether the data collection should -occur serially with the optimization step or in parallel. The :obj:`SyncDataCollector` -class will execute the data collection on the training worker. The :obj:`MultiSyncDataCollector` -will split the workload across an number of workers and aggregate the results that -will be delivered to the training worker. Finally, the :obj:`MultiaSyncDataCollector` will -execute the data collection on several workers and deliver the first batch of results -that it can gather. This execution will occur continuously and concomitantly with -the training of the networks: this implies that the weights of the policy that -is used for the data collection may slightly lag the configuration of the policy -on the training worker. Therefore, although this class may be the fastest to collect -data, it comes at the price of being suitable only in settings where it is acceptable -to gather data asynchronously (e.g. off-policy RL or curriculum RL). -For remotely executed rollouts (:obj:`MultiSyncDataCollector` or :obj:`MultiaSyncDataCollector`) -it is necessary to synchronise the weights of the remote policy with the weights -from the training worker using either the `collector.update_policy_weights_()` or -by setting `update_at_each_batch=True` in the constructor. 
+- **SyncDataCollector**: Single-process collection on the training worker +- **MultiSyncDataCollector**: Parallel collection across multiple workers with synchronized delivery +- **MultiaSyncDataCollector**: Asynchronous collection with first-come-first-serve delivery +- **Distributed collectors**: For multi-node setups using Ray, RPC, or distributed backends -The second parameter to consider (in the remote settings) is the device where the -data will be collected and the device where the environment and policy operations -will be executed. For instance, a policy executed on CPU may be slower than one -executed on CUDA. When multiple inference workers run concomitantly, dispatching -the compute workload across the available devices may speed up the collection or -avoid OOM errors. Finally, the choice of the batch size and passing device (ie the -device where the data will be stored while waiting to be passed to the collection -worker) may also impact the memory management. The key parameters to control are -:obj:`devices` which controls the execution devices (ie the device of the policy) -and :obj:`storing_device` which will control the device where the environment and -data are stored during a rollout. A good heuristic is usually to use the same device -for storage and compute, which is the default behavior when only the `devices` argument -is being passed. +Key Features +------------ -Besides those compute parameters, users may choose to configure the following parameters: +- **Flexible execution**: Choose between sync, async, and distributed collection +- **Device management**: Control where environments and policies execute +- **Weight synchronization**: Keep inference policies up-to-date with training weights +- **Replay buffer integration**: Seamless compatibility with TorchRL's replay buffers +- **Batching strategies**: Multiple ways to organize collected data -- max_frames_per_traj: the number of frames after which a :obj:`env.reset()` is called -- frames_per_batch: the number of frames delivered at each iteration over the collector -- init_random_frames: the number of random steps (steps where :obj:`env.rand_step()` is being called) -- reset_at_each_iter: if :obj:`True`, the environment(s) will be reset after each batch collection -- split_trajs: if :obj:`True`, the trajectories will be split and delivered in a padded tensordict - along with a :obj:`"mask"` key that will point to a boolean mask representing the valid values. -- exploration_type: the exploration strategy to be used with the policy. -- reset_when_done: whether environments should be reset when reaching a done state. - -Collectors and batch size -------------------------- - -Because each collector has its own way of organizing the environments that are -run within, the data will come with different batch-size depending on how -the specificities of the collector. 
The following table summarizes what is to -be expected when collecting data: - - -+--------------------+---------------------+--------------------------------------------+------------------------------+ -| | SyncDataCollector | MultiSyncDataCollector (n=B) |MultiaSyncDataCollector (n=B) | -+====================+=====================+=============+==============+===============+==============================+ -| `cat_results` | NA | `"stack"` | `0` | `-1` | NA | -+--------------------+---------------------+-------------+--------------+---------------+------------------------------+ -| Single env | [T] | `[B, T]` | `[B*(T//B)` | `[B*(T//B)]` | [T] | -+--------------------+---------------------+-------------+--------------+---------------+------------------------------+ -| Batched env (n=P) | [P, T] | `[B, P, T]` | `[B * P, T]`| `[P, T * B]` | [P, T] | -+--------------------+---------------------+-------------+--------------+---------------+------------------------------+ - -In each of these cases, the last dimension (``T`` for ``time``) is adapted such -that the batch size equals the ``frames_per_batch`` argument passed to the -collector. - -.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be - used with ``cat_results=0``, as the data will be stacked along the batch - dimension with batched environment, or the time dimension for single environments, - which can introduce some confusion when swapping one with the other. - ``cat_results="stack"`` is a better and more consistent way of interacting - with the environments as it will keep each dimension separate, and provide - better interchangeability between configurations, collector classes and other - components. - -Whereas :class:`~torchrl.collectors.MultiSyncDataCollector` -has a dimension corresponding to the number of sub-collectors being run (``B``), -:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. This -is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector` -delivers batches of data on a first-come, first-serve basis, whereas -:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from -each sub-collector before delivering it. - -Collectors and policy copies ----------------------------- - -When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to -keep the training version of the policy on a device and the inference version on another. For example, if you have two -CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is the -case, a :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` can be used to copy the parameters from one -device to the other (if no copy is required, this method is a no-op). - -Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will do a deepcopy of the -policy structure and copy the parameters placed on the new device during instantiation if necessary. -Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we -try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur. - -.. figure:: /_static/img/collector-copy.png - - Policy copy decision tree in Collectors. 
- -Weight Synchronization using Weight Update Schemes --------------------------------------------------- - -RL pipelines are typically split in two big computational buckets: training, and inference. -While the inference pipeline sends data to the training one, the training pipeline needs to occasionally -synchronize its weights with the inference one. -In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are -used in both instances. From there, anything can happen: - -- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named - `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights - for his instance of the policy. -- In some cases, the environment or the postprocessing hooks can rely on the usage of a model which itself needs - synchronization. This means that there can be multiple ends in the data transfer API and one needs to think beyond - policy-to-policy weight synchronization strategies. -- In the LLM world, the inference engine and the training one are very different: they will use different libraries, - kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different (quantized - vs non-quantized). - This makes the weight synchronization much more complex, as one cannot simply dump and load a state dict on both ends. -- One typically also has to choose who instantiates a transfer: should this come from the inference engine who actively - asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach - is to store the weights on some intermediary server and let the workers fetch them when necessary. - -TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight -transfer: - -- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer; -- A `Receiver` class that casts the weights to the destination module (policy or other utility module); -- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else). -- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them. - -Each of these classes is detailed below. - -Usage Examples -~~~~~~~~~~~~~~ - -.. note:: - **Runnable versions** of these examples are available in the repository: - - - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization - - `examples/collectors/weight_sync_collectors.py `_: Collector integration - -Using Weight Update Schemes Independently -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Weight update schemes can be used outside of collectors for custom synchronization scenarios. -The new simplified API provides four core methods for weight synchronization: - -- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side -- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side -- ``get_sender()`` - Get the configured sender instance -- ``get_receiver()`` - Get the configured receiver instance - -Here's a basic example: +Quick Example +------------- .. 
code-block:: python - import torch - import torch.nn as nn - from torch import multiprocessing as mp - from tensordict import TensorDict - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, - ) - - # Create a simple policy - policy = nn.Linear(4, 2) - - # Example 1: Multiprocess weight synchronization with state_dict - # -------------------------------------------------------------- - # On the main process side (trainer): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + from torchrl.collectors import SyncDataCollector + from torchrl.envs import GymEnv, ParallelEnv - # Initialize scheme with pipes - parent_pipe, child_pipe = mp.Pipe() - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + # Create a batched environment + def make_env(): + return GymEnv("Pendulum-v1") - # Get the sender and send weights - sender = scheme.get_sender() - weights = policy.state_dict() - sender.send(weights) # Synchronous send - # or sender.send_async(weights); sender.wait_async() # Asynchronous send - - # On the worker process side: - # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) - # receiver = scheme.get_receiver() - # # Non-blocking check for new weights - # if receiver.receive(timeout=0.001): - # # Weights were received and applied - - # Example 2: Shared memory weight synchronization - # ------------------------------------------------ - # Create shared memory scheme with auto-registration - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - # Initialize with pipes for lazy registration - parent_pipe2, child_pipe2 = mp.Pipe() - shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2]) + env = ParallelEnv(4, make_env) - # Get sender and send weights (automatically creates shared buffer on first send) - shared_sender = shared_scheme.get_sender() - weights_td = TensorDict.from_module(policy) - shared_sender.send(weights_td) - - # Workers automatically see updates via shared memory! - -Using Weight Update Schemes with Collectors -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization -across multiple inference workers: - -.. code-block:: python - - import torch.nn as nn - from tensordict.nn import TensorDictModule - from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector - from torchrl.envs import GymEnv - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, - ) - - # Create environment and policy - env = GymEnv("CartPole-v1") - policy = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], - env.action_spec.shape[-1]), - in_keys=["observation"], - out_keys=["action"], - ) - - # Example 1: Single collector with multiprocess scheme - # ----------------------------------------------------- - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - + # Create collector collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=policy, - frames_per_batch=64, - total_frames=1000, - weight_sync_schemes={"policy": scheme}, - ) - - # Collect data and update weights periodically - for i, data in enumerate(collector): - # ... training step with data ... 
- - # Update policy weights every N iterations - if i % 10 == 0: - new_weights = policy.state_dict() - collector.update_policy_weights_(new_weights) - - collector.shutdown() - - # Example 2: Multiple collectors with shared memory - # -------------------------------------------------- - # Shared memory is more efficient for frequent updates - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - collector = MultiSyncDataCollector( - create_env_fn=[ - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - ], - policy=policy, - frames_per_batch=192, + env, + policy=my_policy, + frames_per_batch=200, total_frames=10000, - weight_sync_schemes={"policy": shared_scheme}, ) - - # Workers automatically see weight updates via shared memory + + # Collect data for data in collector: - # ... training ... - collector.update_policy_weights_(TensorDict.from_module(policy)) - + # data is a TensorDict with shape [4, 50] (4 envs, 50 steps each) + # Use data for training... + + # Update policy weights periodically + if should_update: + collector.update_policy_weights_() + collector.shutdown() -.. note:: - When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all - processes share the same memory buffers. This is ideal for frequent weight updates but requires all - processes to be on the same machine. - -.. note:: - The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state - dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured - models and supports advanced features like lazy initialization. - -Weight Senders -~~~~~~~~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightSender - RayModuleTransformSender - -Weight Receivers -~~~~~~~~~~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightReceiver - RayModuleTransformReceiver - -Transports -~~~~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - TransportBackend - MPTransport - SharedMemTransport - RayTransport - RayActorTransport - RPCTransport - DistributedTransport - -Schemes -~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightSyncScheme - MultiProcessWeightSyncScheme - SharedMemWeightSyncScheme - NoWeightSyncScheme - RayWeightSyncScheme - RayModuleTransformScheme - RPCWeightSyncScheme - DistributedWeightSyncScheme - -Legacy: Weight Synchronization in Distributed Environments ----------------------------------------------------------- - -.. warning:: - The `WeightUpdater` is considered legacy as per the 0.11 release and will be deprecated soon. - The Weight update schemes, which provides more flexibility and a better compatibility with heavy - weight transfers (e.g., LLMs) is to be preferred. - -In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the -latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible -mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios. 
- -Sending and receiving model weights with WeightUpdaters -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The weight synchronization process is facilitated by one dedicated extension point: -:class:`~torchrl.collectors.WeightUpdaterBase`. These base class provides a structured interface for -implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs. - -:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to -the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary. -Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the -weight synchronization with the policy. -Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy -state-dict (assuming it is a :class:`~torch.nn.Module` instance). - -Extending the Updater Class -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations. -The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation -untouched. -This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware -setups. -By implementing the abstract methods in these base classes, users can define how weights are retrieved, -transformed, and applied, ensuring seamless integration with their existing infrastructure. - -.. currentmodule:: torchrl.collectors - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightUpdaterBase - VanillaWeightUpdater - MultiProcessedWeightUpdater - RayWeightUpdater - -.. currentmodule:: torchrl.collectors.distributed - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - RPCWeightUpdater - DistributedWeightUpdater - -Collectors and replay buffers interoperability ----------------------------------------------- - -In the simplest scenario where single transitions have to be sampled -from the replay buffer, little attention has to be given to the way -the collector is built. Flattening the data after collection will -be a sufficient preprocessing step before populating the storage: - - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N), - ... transform=lambda data: data.reshape(-1)) - >>> for data in collector: - ... memory.extend(data) - -If trajectory slices have to be collected, the recommended way to achieve this is to create -a multidimensional buffer and sample using the :class:`~torchrl.data.replay_buffers.SliceSampler` -sampler class. One must ensure that the data passed to the buffer is properly shaped, with the -``time`` and ``batch`` dimensions clearly separated. In practice, the following configurations -will work: - - >>> # Single environment: no need for a multi-dimensional buffer - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N), - ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) - ... ) - >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) - >>> for data in collector: - ... memory.extend(data) - >>> # Batched environments: a multi-dim buffer is required - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N, ndim=2), - ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) - ... 
) - >>> env = ParallelEnv(4, make_env) - >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) - >>> for data in collector: - ... memory.extend(data) - >>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack" - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N, ndim=2), - ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) - ... ) - >>> collector = MultiSyncDataCollector([make_env] * 4, - ... policy, - ... frames_per_batch=N, - ... total_frames=-1, - ... cat_results="stack") - >>> for data in collector: - ... memory.extend(data) - >>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly - >>> memory = ReplayBuffer( - ... storage=LazyTensorStorage(N, ndim=3), - ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) - ... ) - >>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4, - ... policy, - ... frames_per_batch=N, - ... total_frames=-1, - ... cat_results="stack") - >>> for data in collector: - ... memory.extend(data) - -Using replay buffers that sample trajectories with :class:`~torchrl.collectors.MultiSyncDataCollector` -isn't currently fully supported as the data batches can come from any worker and in most cases consecutive -batches written in the buffer won't come from the same source (thereby interrupting the trajectories). - -Running the Collector Asynchronously ------------------------------------- - -Passing replay buffers to a collector allows us to start the collection and get rid of the iterative nature of the -collector. -If you want to run a data collector in the background, simply run :meth:`~torchrl.DataCollectorBase.start`: - - >>> collector = SyncDataCollector(..., replay_buffer=rb) # pass your replay buffer - >>> collector.start() - >>> # little pause - >>> time.sleep(10) - >>> # Start training - >>> for i in range(optim_steps): - ... data = rb.sample() # Sampling from the replay buffer - ... # rest of the training loop - -Single-process collectors (:class:`~torchrl.collectors.SyncDataCollector`) will run the process using multithreading, -so be mindful of Python's GIL and related multithreading restrictions. - -Multiprocessed collectors will on the other hand let the child processes handle the filling of the buffer on their own, -which truly decouples the data collection and training. - -Data collectors that have been started with `start()` should be shut down using -:meth:`~torchrl.DataCollectorBase.async_shutdown`. - -.. warning:: Running a collector asynchronously decouples the collection from training, which means that the training - performance may be drastically different depending on the hardware, load and other factors (although it is generally - expected to provide significant speed-ups). Make sure you understand how this may affect your algorithm and if it - is a legitimate thing to do! (For example, on-policy algorithms such as PPO should not be run asynchronously - unless properly benchmarked). - -Single node data collectors ---------------------------- -.. currentmodule:: torchrl.collectors - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - DataCollectorBase - SyncDataCollector - MultiSyncDataCollector - MultiaSyncDataCollector - aSyncDataCollector - - -Distributed data collectors ---------------------------- -TorchRL provides a set of distributed data collectors. 
These tools support -multiple backends (``'gloo'``, ``'nccl'``, ``'mpi'`` with the :class:`~.DistributedDataCollector` -or PyTorch RPC with :class:`~.RPCDataCollector`) and launchers (``'ray'``, -``submitit`` or ``torch.multiprocessing``). -They can be efficiently used in synchronous or asynchronous mode, on a single -node or across multiple nodes. - -*Resources*: Find examples for these collectors in the -`dedicated folder `_. - -.. note:: - *Choosing the sub-collector*: All distributed collectors support the various single machine collectors. - One may wonder why using a :class:`MultiSyncDataCollector` or a :class:`~torchrl.envs.ParallelEnv` - instead. In general, multiprocessed collectors have a lower IO footprint than - parallel environments which need to communicate at each step. Yet, the model specs - play a role in the opposite direction, since using parallel environments will - result in a faster execution of the policy (and/or transforms) since these - operations will be vectorized. - -.. note:: - *Choosing the device of a collector (or a parallel environment)*: Sharing data - among processes is achieved via shared-memory buffers with parallel environment - and multiprocessed environments executed on CPU. Depending on the capabilities - of the machine being used, this may be prohibitively slow compared to sharing - data on GPU which is natively supported by cuda drivers. - In practice, this means that using the ``device="cpu"`` keyword argument when - building a parallel environment or collector can result in a slower collection - than using ``device="cuda"`` when available. - -.. note:: - Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) - warnings can quickly become quite annoying in multiprocessed / distributed settings. - By default, TorchRL filters out these warnings in sub-processes. If one still wishes to - see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. - -.. currentmodule:: torchrl.collectors.distributed - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - DistributedDataCollector - RPCDataCollector - DistributedSyncDataCollector - submitit_delayed_launcher - RayCollector - -Helper functions ----------------- - -.. currentmodule:: torchrl.collectors.utils +Documentation Sections +---------------------- -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst +.. toctree:: + :maxdepth: 2 - split_trajectories + collectors_basics + collectors_single + collectors_distributed + collectors_weightsync + collectors_replay diff --git a/docs/source/reference/collectors_basics.rst b/docs/source/reference/collectors_basics.rst new file mode 100644 index 00000000000..dd78cd61eb6 --- /dev/null +++ b/docs/source/reference/collectors_basics.rst @@ -0,0 +1,118 @@ +.. currentmodule:: torchrl.collectors + +.. _ref_collectors: + +Collector Basics +================ + +Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they +collect data over non-static data sources and (2) the data is collected using a model +(likely a version of the model that is being trained). + +TorchRL's data collectors accept two main arguments: an environment (or a list of +environment constructors) and a policy. They will iteratively execute an environment +step and a policy query over a defined number of steps before delivering a stack of +the data collected to the user. 
Environments will be reset whenever they reach a done
+state, and/or after a predefined number of steps.
+
+Because data collection is a potentially compute-heavy process, it is crucial to
+configure the execution hyperparameters appropriately.
+The first parameter to take into consideration is whether the data collection should
+occur serially with the optimization step or in parallel. The :class:`SyncDataCollector`
+class will execute the data collection on the training worker. The :class:`MultiSyncDataCollector`
+will split the workload across a number of workers and aggregate the results that
+will be delivered to the training worker. Finally, the :class:`MultiaSyncDataCollector` will
+execute the data collection on several workers and deliver the first batch of results
+that it can gather. This execution will occur continuously and concomitantly with
+the training of the networks: this implies that the weights of the policy that
+is used for the data collection may slightly lag the configuration of the policy
+on the training worker. Therefore, although this class may be the fastest to collect
+data, it comes at the price of being suitable only in settings where it is acceptable
+to gather data asynchronously (e.g., off-policy RL or curriculum RL).
+For remotely executed rollouts (:class:`MultiSyncDataCollector` or :class:`MultiaSyncDataCollector`)
+it is necessary to synchronize the weights of the remote policy with the weights
+from the training worker using either :meth:`collector.update_policy_weights_` or
+by setting ``update_at_each_batch=True`` in the constructor.
+
+The second parameter to consider (in the remote settings) is the device where the
+data will be collected and the device where the environment and policy operations
+will be executed. For instance, a policy executed on CPU may be slower than one
+executed on CUDA. When multiple inference workers run concomitantly, dispatching
+the compute workload across the available devices may speed up the collection or
+avoid OOM errors. Finally, the choice of the batch size and passing device (i.e., the
+device where the data will be stored while waiting to be passed to the collection
+worker) may also impact memory management. The key parameters to control are
+``devices``, which controls the execution devices (i.e., the device of the policy),
+and ``storing_device``, which controls the device where the environment and
+data are stored during a rollout. A good heuristic is usually to use the same device
+for storage and compute, which is the default behavior when only the ``devices`` argument
+is being passed.
+
+Besides those compute parameters, users may choose to configure the following parameters
+(an example combining several of them is given below):
+
+- max_frames_per_traj: the number of frames after which a :meth:`env.reset` is called
+- frames_per_batch: the number of frames delivered at each iteration over the collector
+- init_random_frames: the number of random steps (steps where :meth:`env.rand_step` is being called)
+- reset_at_each_iter: if ``True``, the environment(s) will be reset after each batch collection
+- split_trajs: if ``True``, the trajectories will be split and delivered in a padded tensordict
+  along with a ``"mask"`` key that will point to a boolean mask representing the valid values.
+- exploration_type: the exploration strategy to be used with the policy.
+- reset_when_done: whether environments should be reset when reaching a done state.
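+
+The snippet below is a minimal sketch that combines a few of these parameters in a
+:class:`SyncDataCollector`. It assumes a Gym backend is installed; the environment name and the
+random policy are placeholders used purely for illustration and can be swapped for your own setup.
+
+.. code-block:: python
+
+    from torchrl.collectors import SyncDataCollector
+    from torchrl.envs import GymEnv
+    from torchrl.envs.utils import RandomPolicy
+
+    env = GymEnv("Pendulum-v1")
+    policy = RandomPolicy(env.action_spec)  # illustrative policy
+
+    collector = SyncDataCollector(
+        env,
+        policy,
+        frames_per_batch=200,      # frames delivered at each iteration
+        total_frames=10_000,       # stop iterating after 10k frames
+        max_frames_per_traj=50,    # call env.reset() every 50 steps
+        init_random_frames=400,    # first 400 frames use env.rand_step()
+        reset_at_each_iter=False,  # do not force a reset between batches
+        split_trajs=False,         # deliver unpadded, contiguous batches
+    )
+
+    for data in collector:
+        # data is a TensorDict holding 200 frames
+        break
+
+    collector.shutdown()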
+ +Collectors and batch size +------------------------- + +Because each collector has its own way of organizing the environments that are +run within, the data will come with different batch-size depending on how +the specificities of the collector. The following table summarizes what is to +be expected when collecting data: + + ++--------------------+---------------------+--------------------------------------------+------------------------------+ +| | SyncDataCollector | MultiSyncDataCollector (n=B) |MultiaSyncDataCollector (n=B) | ++====================+=====================+=============+==============+===============+==============================+ +| `cat_results` | NA | `"stack"` | `0` | `-1` | NA | ++--------------------+---------------------+-------------+--------------+---------------+------------------------------+ +| Single env | [T] | `[B, T]` | `[B*(T//B)` | `[B*(T//B)]` | [T] | ++--------------------+---------------------+-------------+--------------+---------------+------------------------------+ +| Batched env (n=P) | [P, T] | `[B, P, T]` | `[B * P, T]`| `[P, T * B]` | [P, T] | ++--------------------+---------------------+-------------+--------------+---------------+------------------------------+ + +In each of these cases, the last dimension (``T`` for ``time``) is adapted such +that the batch size equals the ``frames_per_batch`` argument passed to the +collector. + +.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be + used with ``cat_results=0``, as the data will be stacked along the batch + dimension with batched environment, or the time dimension for single environments, + which can introduce some confusion when swapping one with the other. + ``cat_results="stack"`` is a better and more consistent way of interacting + with the environments as it will keep each dimension separate, and provide + better interchangeability between configurations, collector classes and other + components. + +Whereas :class:`~torchrl.collectors.MultiSyncDataCollector` +has a dimension corresponding to the number of sub-collectors being run (``B``), +:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. This +is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector` +delivers batches of data on a first-come, first-serve basis, whereas +:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from +each sub-collector before delivering it. + +Collectors and policy copies +---------------------------- + +When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to +keep the training version of the policy on a device and the inference version on another. For example, if you have two +CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is the +case, a :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` can be used to copy the parameters from one +device to the other (if no copy is required, this method is a no-op). + +Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will do a deepcopy of the +policy structure and copy the parameters placed on the new device during instantiation if necessary. +Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we +try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur. + +.. 
figure:: /_static/img/collector-copy.png + + Policy copy decision tree in Collectors. diff --git a/docs/source/reference/collectors_distributed.rst b/docs/source/reference/collectors_distributed.rst new file mode 100644 index 00000000000..acc4c7af35b --- /dev/null +++ b/docs/source/reference/collectors_distributed.rst @@ -0,0 +1,49 @@ +.. currentmodule:: torchrl.collectors.distributed + +Distributed Collectors +====================== + +TorchRL provides a set of distributed data collectors. These tools support +multiple backends (``'gloo'``, ``'nccl'``, ``'mpi'`` with the :class:`~.DistributedDataCollector` +or PyTorch RPC with :class:`~.RPCDataCollector`) and launchers (``'ray'``, +``submitit`` or ``torch.multiprocessing``). +They can be efficiently used in synchronous or asynchronous mode, on a single +node or across multiple nodes. + +*Resources*: Find examples for these collectors in the +`dedicated folder `_. + +.. note:: + *Choosing the sub-collector*: All distributed collectors support the various single machine collectors. + One may wonder why using a :class:`MultiSyncDataCollector` or a :class:`~torchrl.envs.ParallelEnv` + instead. In general, multiprocessed collectors have a lower IO footprint than + parallel environments which need to communicate at each step. Yet, the model specs + play a role in the opposite direction, since using parallel environments will + result in a faster execution of the policy (and/or transforms) since these + operations will be vectorized. + +.. note:: + *Choosing the device of a collector (or a parallel environment)*: Sharing data + among processes is achieved via shared-memory buffers with parallel environment + and multiprocessed environments executed on CPU. Depending on the capabilities + of the machine being used, this may be prohibitively slow compared to sharing + data on GPU which is natively supported by cuda drivers. + In practice, this means that using the ``device="cpu"`` keyword argument when + building a parallel environment or collector can result in a slower collection + than using ``device="cuda"`` when available. + +.. note:: + Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) + warnings can quickly become quite annoying in multiprocessed / distributed settings. + By default, TorchRL filters out these warnings in sub-processes. If one still wishes to + see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + DistributedDataCollector + RPCDataCollector + DistributedSyncDataCollector + submitit_delayed_launcher + RayCollector diff --git a/docs/source/reference/collectors_replay.rst b/docs/source/reference/collectors_replay.rst new file mode 100644 index 00000000000..fb0a776fdbf --- /dev/null +++ b/docs/source/reference/collectors_replay.rst @@ -0,0 +1,81 @@ +.. currentmodule:: torchrl.collectors + +Collectors and Replay Buffers +============================= + +Collectors and replay buffers interoperability +---------------------------------------------- + +In the simplest scenario where single transitions have to be sampled +from the replay buffer, little attention has to be given to the way +the collector is built. Flattening the data after collection will +be a sufficient preprocessing step before populating the storage: + + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N), + ... transform=lambda data: data.reshape(-1)) + >>> for data in collector: + ... 
memory.extend(data) + +If trajectory slices have to be collected, the recommended way to achieve this is to create +a multidimensional buffer and sample using the :class:`~torchrl.data.replay_buffers.SliceSampler` +sampler class. One must ensure that the data passed to the buffer is properly shaped, with the +``time`` and ``batch`` dimensions clearly separated. In practice, the following configurations +will work: + + >>> # Single environment: no need for a multi-dimensional buffer + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N), + ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) + ... ) + >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) + >>> for data in collector: + ... memory.extend(data) + >>> # Batched environments: a multi-dim buffer is required + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N, ndim=2), + ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) + ... ) + >>> env = ParallelEnv(4, make_env) + >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) + >>> for data in collector: + ... memory.extend(data) + >>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack" + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N, ndim=2), + ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) + ... ) + >>> collector = MultiSyncDataCollector([make_env] * 4, + ... policy, + ... frames_per_batch=N, + ... total_frames=-1, + ... cat_results="stack") + >>> for data in collector: + ... memory.extend(data) + >>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly + >>> memory = ReplayBuffer( + ... storage=LazyTensorStorage(N, ndim=3), + ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) + ... ) + >>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4, + ... policy, + ... frames_per_batch=N, + ... total_frames=-1, + ... cat_results="stack") + >>> for data in collector: + ... memory.extend(data) + +Using replay buffers that sample trajectories with :class:`~torchrl.collectors.MultiSyncDataCollector` +isn't currently fully supported as the data batches can come from any worker and in most cases consecutive +batches written in the buffer won't come from the same source (thereby interrupting the trajectories). + +Helper functions +---------------- + +.. currentmodule:: torchrl.collectors.utils + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + split_trajectories diff --git a/docs/source/reference/collectors_single.rst b/docs/source/reference/collectors_single.rst new file mode 100644 index 00000000000..529b03ee4e2 --- /dev/null +++ b/docs/source/reference/collectors_single.rst @@ -0,0 +1,50 @@ +.. currentmodule:: torchrl.collectors + +Single Node Collectors +====================== + +TorchRL provides several collector classes for single-node data collection, each with different execution strategies. + +Single node data collectors +--------------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + DataCollectorBase + SyncDataCollector + MultiSyncDataCollector + MultiaSyncDataCollector + aSyncDataCollector + +Running the Collector Asynchronously +------------------------------------ + +Passing replay buffers to a collector allows us to start the collection and get rid of the iterative nature of the +collector. 
+
+If you want to run a data collector in the background, simply run :meth:`~torchrl.DataCollectorBase.start`:
+
+    >>> collector = SyncDataCollector(..., replay_buffer=rb)  # pass your replay buffer
+    >>> collector.start()
+    >>> # little pause
+    >>> time.sleep(10)
+    >>> # Start training
+    >>> for i in range(optim_steps):
+    ...     data = rb.sample()  # Sampling from the replay buffer
+    ...     # rest of the training loop
+
+Single-process collectors (:class:`~torchrl.collectors.SyncDataCollector`) will run the process using multithreading,
+so be mindful of Python's GIL and related multithreading restrictions.
+
+Multiprocessed collectors, on the other hand, will let the child processes handle the filling of the buffer on their own,
+which truly decouples data collection from training.
+
+Data collectors that have been started with `start()` should be shut down using
+:meth:`~torchrl.DataCollectorBase.async_shutdown`.
+
+.. warning:: Running a collector asynchronously decouples the collection from training, which means that the training
+    performance may be drastically different depending on the hardware, load and other factors (although it is generally
+    expected to provide significant speed-ups). Make sure you understand how this may affect your algorithm and if it
+    is a legitimate thing to do! (For example, on-policy algorithms such as PPO should not be run asynchronously
+    unless properly benchmarked).
diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst
new file mode 100644
index 00000000000..0fcf174f3c1
--- /dev/null
+++ b/docs/source/reference/collectors_weightsync.rst
@@ -0,0 +1,299 @@
+.. currentmodule:: torchrl.weight_update
+
+Weight Synchronization
+======================
+
+RL pipelines are typically split into two big computational buckets: training and inference.
+While the inference pipeline sends data to the training one, the training pipeline needs to occasionally
+synchronize its weights with the inference one.
+In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
+used in both instances. From there, anything can happen:
+
+- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
+  `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
+  for its instance of the policy.
+- In some cases, the environment or the postprocessing hooks can rely on the usage of a model which itself needs
+  synchronization. This means that there can be multiple ends in the data transfer API and one needs to think beyond
+  policy-to-policy weight synchronization strategies.
+- In the LLM world, the inference engine and the training one are very different: they will use different libraries,
+  kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different (quantized
+  vs. non-quantized).
+  This makes the weight synchronization much more complex, as one cannot simply dump and load a state dict on both ends.
+- One typically also has to choose who initiates a transfer: should the inference engine actively ask for new
+  weights, or should the trainer push its weights to the workers? An intermediate approach is to store the weights
+  on some intermediary server and let the workers fetch them when necessary.
+
+TorchRL tries to account for each of these problems in a flexible manner. 
We individuate four basic components in a weight +transfer: + +- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer; +- A `Receiver` class that casts the weights to the destination module (policy or other utility module); +- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else). +- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them. + +Each of these classes is detailed below. + +Usage Examples +-------------- + +.. note:: + **Runnable versions** of these examples are available in the repository: + + - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization + - `examples/collectors/weight_sync_collectors.py `_: Collector integration + +Using Weight Update Schemes Independently +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Weight update schemes can be used outside of collectors for custom synchronization scenarios. +The new simplified API provides four core methods for weight synchronization: + +- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side +- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side +- ``get_sender()`` - Get the configured sender instance +- ``get_receiver()`` - Get the configured receiver instance + +Here's a basic example: + +.. code-block:: python + + import torch + import torch.nn as nn + from torch import multiprocessing as mp + from tensordict import TensorDict + from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + ) + + # Create a simple policy + policy = nn.Linear(4, 2) + + # Example 1: Multiprocess weight synchronization with state_dict + # -------------------------------------------------------------- + # On the main process side (trainer): + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + # Initialize scheme with pipes + parent_pipe, child_pipe = mp.Pipe() + scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + + # Get the sender and send weights + sender = scheme.get_sender() + weights = policy.state_dict() + sender.send(weights) # Synchronous send + # or sender.send_async(weights); sender.wait_async() # Asynchronous send + + # On the worker process side: + # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) + # receiver = scheme.get_receiver() + # # Non-blocking check for new weights + # if receiver.receive(timeout=0.001): + # # Weights were received and applied + + # Example 2: Shared memory weight synchronization + # ------------------------------------------------ + # Create shared memory scheme with auto-registration + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + + # Initialize with pipes for lazy registration + parent_pipe2, child_pipe2 = mp.Pipe() + shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2]) + + # Get sender and send weights (automatically creates shared buffer on first send) + shared_sender = shared_scheme.get_sender() + weights_td = TensorDict.from_module(policy) + shared_sender.send(weights_td) + + # Workers automatically see updates via shared memory! + +Using Weight Update Schemes with Collectors +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization +across multiple inference workers: + +.. 
code-block:: python + + import torch.nn as nn + from tensordict.nn import TensorDictModule + from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector + from torchrl.envs import GymEnv + from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + ) + + # Create environment and policy + env = GymEnv("CartPole-v1") + policy = TensorDictModule( + nn.Linear(env.observation_spec["observation"].shape[-1], + env.action_spec.shape[-1]), + in_keys=["observation"], + out_keys=["action"], + ) + + # Example 1: Single collector with multiprocess scheme + # ----------------------------------------------------- + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=1000, + weight_sync_schemes={"policy": scheme}, + ) + + # Collect data and update weights periodically + for i, data in enumerate(collector): + # ... training step with data ... + + # Update policy weights every N iterations + if i % 10 == 0: + new_weights = policy.state_dict() + collector.update_policy_weights_(new_weights) + + collector.shutdown() + + # Example 2: Multiple collectors with shared memory + # -------------------------------------------------- + # Shared memory is more efficient for frequent updates + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + + collector = MultiSyncDataCollector( + create_env_fn=[ + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + ], + policy=policy, + frames_per_batch=192, + total_frames=10000, + weight_sync_schemes={"policy": shared_scheme}, + ) + + # Workers automatically see weight updates via shared memory + for data in collector: + # ... training ... + collector.update_policy_weights_(TensorDict.from_module(policy)) + + collector.shutdown() + +.. note:: + When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all + processes share the same memory buffers. This is ideal for frequent weight updates but requires all + processes to be on the same machine. + +.. note:: + The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state + dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured + models and supports advanced features like lazy initialization. + +Weight Senders +-------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightSender + RayModuleTransformSender + +Weight Receivers +---------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightReceiver + RayModuleTransformReceiver + +Transports +---------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + TransportBackend + MPTransport + SharedMemTransport + RayTransport + RayActorTransport + RPCTransport + DistributedTransport + +Schemes +------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightSyncScheme + MultiProcessWeightSyncScheme + SharedMemWeightSyncScheme + NoWeightSyncScheme + RayWeightSyncScheme + RayModuleTransformScheme + RPCWeightSyncScheme + DistributedWeightSyncScheme + +Legacy: Weight Updaters +----------------------- + +.. warning:: + The `WeightUpdater` is considered legacy as per the 0.11 release and will be deprecated soon. 
+ The Weight update schemes, which provides more flexibility and a better compatibility with heavy + weight transfers (e.g., LLMs) is to be preferred. + +In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the +latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible +mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios. + +Sending and receiving model weights with WeightUpdaters +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The weight synchronization process is facilitated by one dedicated extension point: +:class:`~torchrl.collectors.WeightUpdaterBase`. These base class provides a structured interface for +implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs. + +:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to +the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary. +Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the +weight synchronization with the policy. +Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy +state-dict (assuming it is a :class:`~torch.nn.Module` instance). + +Extending the Updater Class +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations. +The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation +untouched. +This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware +setups. +By implementing the abstract methods in these base classes, users can define how weights are retrieved, +transformed, and applied, ensuring seamless integration with their existing infrastructure. + +.. currentmodule:: torchrl.collectors + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightUpdaterBase + VanillaWeightUpdater + MultiProcessedWeightUpdater + RayWeightUpdater + +.. currentmodule:: torchrl.collectors.distributed + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + RPCWeightUpdater + DistributedWeightUpdater diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index f7372a38f6f..78035570a10 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -14,7 +14,7 @@ The advantages of using a configuration system are: - Easy to version control: you can easily version control your configuration file Quick Start with a Simple Example ----------------------------------- +--------------------------------- Let's start with a simple example that creates a Gym environment. Here's a minimal configuration file: @@ -65,7 +65,7 @@ TorchRL organizes configurations into several categories using the ``@`` syntax The ``@`` syntax allows you to assign configurations to specific locations in your config structure. 
More Complex Example: Parallel Environment with Transforms ------------------------------------------------------------ +---------------------------------------------------------- Here's a more complex example that creates a parallel environment with multiple transforms applied to each worker: @@ -127,7 +127,7 @@ This configuration builds a **parallel environment with 4 workers**, where each 4. **Variable interpolation**: ``${transform0}`` and ``${transform1}`` reference the separately defined transform configurations Getting Available Options --------------------------- +------------------------- To explore all available configurations and their parameters, one can use the ``--help`` flag with any TorchRL script: @@ -141,7 +141,7 @@ This shows all configuration groups and their options, making it easy to discove Complete Training Example --------------------------- +------------------------- Here's a complete configuration for PPO training: @@ -219,7 +219,7 @@ Here's a complete configuration for PPO training: exp_name: my_experiment Running Experiments --------------------- +------------------- Basic Usage ~~~~~~~~~~~ @@ -252,7 +252,7 @@ Hyperparameter Sweeps training_env.num_workers=2,4,8 Custom Configuration Files -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash diff --git a/docs/source/reference/cudnn_persistent_rnn.rst b/docs/source/reference/cudnn_persistent_rnn.rst new file mode 100644 index 00000000000..d6a1e44380c --- /dev/null +++ b/docs/source/reference/cudnn_persistent_rnn.rst @@ -0,0 +1,4 @@ +.. note:: + In some circumstances when using the CUDNN backend with CuDNN 7.2.1, the backward + can be up to 5x slower when called with a batch_first input. This is expected to be fixed + in CuDNN 7.2.5. diff --git a/docs/source/reference/cudnn_rnn_determinism.rst b/docs/source/reference/cudnn_rnn_determinism.rst new file mode 100644 index 00000000000..92eff7c0596 --- /dev/null +++ b/docs/source/reference/cudnn_rnn_determinism.rst @@ -0,0 +1,8 @@ +.. note:: + If the following conditions are not met, the backward pass will use a slower but more + memory efficient implementation: + + * The input is a :class:`~torch.nn.utils.rnn.PackedSequence` + * The input is not batch first + * ``dropout != 0`` + * ``training == True`` diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 319eb19d0dc..46bfd25e2bd 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -5,1234 +5,54 @@ torchrl.data package .. _ref_data: -Replay Buffers --------------- - -Replay buffers are a central part of off-policy RL algorithms. TorchRL provides an efficient implementation of a few, -widely used replay buffers: - - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - ReplayBuffer - PrioritizedReplayBuffer - TensorDictReplayBuffer - TensorDictPrioritizedReplayBuffer - RayReplayBuffer - RemoteTensorDictReplayBuffer - -Composable Replay Buffers -------------------------- - -.. _ref_buffers: - -We also give users the ability to compose a replay buffer. -We provide a wide panel of solutions for replay buffer usage, including support for -almost any data type; storage in memory, on device or on physical memory; -several sampling strategies; usage of transforms etc. - -Supported data types and choosing a storage -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In theory, replay buffers support any data type but we can't guarantee that each -component will support any data type. 
The most crude replay buffer implementation -is made of a :class:`~torchrl.data.replay_buffers.ReplayBuffer` base with a -:class:`~torchrl.data.replay_buffers.ListStorage` storage. This is very inefficient -but it will allow you to store complex data structures with non-tensor data. -Storages in contiguous memory include :class:`~torchrl.data.replay_buffers.TensorStorage`, -:class:`~torchrl.data.replay_buffers.LazyTensorStorage` and -:class:`~torchrl.data.replay_buffers.LazyMemmapStorage`. -These classes support :class:`~tensordict.TensorDict` data as first-class citizens, but also -any PyTree data structure (eg, tuples, lists, dictionaries and nested versions -of these). The :class:`~torchrl.data.replay_buffers.TensorStorage` storage requires -you to provide the storage at construction time, whereas :class:`~torchrl.data.replay_buffers.TensorStorage` -(RAM, CUDA) and :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` (physical memory) -will preallocate the storage for you after they've been extended the first time. - -Here are a few examples, starting with the generic :class:`~torchrl.data.replay_buffers.ListStorage`: - - >>> from torchrl.data.replay_buffers import ReplayBuffer, ListStorage - >>> rb = ReplayBuffer(storage=ListStorage(10)) - >>> rb.add("a string!") # first element will be a string - >>> rb.extend([30, None]) # element [1] is an int, [2] is None - -The main entry points to write onto a buffer are :meth:`~torchrl.data.ReplayBuffer.add` and -:meth:`~torchrl.data.ReplayBuffer.extend`. -One can also use :meth:`~torchrl.data.ReplayBuffer.__setitem__`, in which case the data is written -where indicated without updating the length or cursor of the buffer. This can be useful when sampling -items from the buffer and them updating their values in-place afterwards. - -Using a :class:`~torchrl.data.replay_buffers.TensorStorage` we tell our RB that -we want the storage to be contiguous, which is by far more efficient but also -more restrictive: - - >>> import torch - >>> from torchrl.data.replay_buffers import ReplayBuffer, TensorStorage - >>> container = torch.empty(10, 3, 64, 64, dtype=torch.unit8) - >>> rb = ReplayBuffer(storage=TensorStorage(container)) - >>> img = torch.randint(255, (3, 64, 64), dtype=torch.uint8) - >>> rb.add(img) - -Next we can avoid creating the container and ask the storage to do it automatically. -This is very useful when using PyTrees and tensordicts! For PyTrees as other data -structures, :meth:`~torchrl.data.replay_buffers.ReplayBuffer.add` considers the sampled -passed to it as a single instance of the type. :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` -on the other hand will consider that the data is an iterable. For tensors, tensordicts -and lists (see below), the iterable is looked for at the root level. For PyTrees, -we assume that the leading dimension of all the leaves (tensors) in the tree -match. If they don't, ``extend`` will throw an exception. - - >>> import torch - >>> from tensordict import TensorDict - >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyMemmapStorage - >>> rb_td = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=1) # max 10 elements stored - >>> rb_td.add(TensorDict({"img": torch.randint(255, (3, 64, 64), dtype=torch.unit8), - ... 
"labels": torch.randint(100, ())}, batch_size=[])) - >>> rb_pytree = ReplayBuffer(storage=LazyMemmapStorage(10)) # max 10 elements stored - >>> # extend with a PyTree where all tensors have the same leading dim (3) - >>> rb_pytree.extend({"a": {"b": torch.randn(3), "c": [torch.zeros(3, 2), (torch.ones(3, 10),)]}}) - >>> assert len(rb_pytree) == 3 # the replay buffer has 3 elements! - -.. note:: :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` can have an - ambiguous signature when dealing with lists of values, which should be interpreted - either as PyTree (in which case all elements in the list will be put in a slice - in the stored PyTree in the storage) or a list of values to add one at a time. - To solve this, TorchRL makes the clear-cut distinction between list and tuple: - a tuple will be viewed as a PyTree, a list (at the root level) will be interpreted - as a stack of values to add one at a time to the buffer. - -Sampling and indexing -~~~~~~~~~~~~~~~~~~~~~ - -Replay buffers can be indexed and sampled. -Indexing and sampling collect data at given indices in the storage and then process them -through a series of transforms and ``collate_fn`` that can be passed to the `__init__` -function of the replay buffer. ``collate_fn`` comes with default values that should -match user expectations in the majority of cases, such that you should not have -to worry about it most of the time. Transforms are usually instances of :class:`~torchrl.envs.transforms.Transform` -even though regular functions will work too (in the latter case, the :meth:`~torchrl.envs.transforms.Transform.inv` -method will obviously be ignored, whereas in the first case it can be used to -preprocess the data before it is passed to the buffer). -Finally, sampling can be achieved using multithreading by passing the number of threads -to the constructor through the ``prefetch`` keyword argument. We advise users to -benchmark this technique in real life settings before adopting it, as there is -no guarantee that it will lead to a faster throughput in practice depending on -the machine and setting where it is used. - -When sampling, the ``batch_size`` can be either passed during construction -(e.g., if it's constant throughout training) or -to the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method. - -To further refine the sampling strategy, we advise you to look into our samplers! - -Here are a couple of examples of how to get data out of a replay buffer: - - >>> first_elt = rb_td[0] - >>> storage = rb_td[:] # returns all valid elements from the buffer - >>> sample = rb_td.sample(128) - >>> for data in rb_td: # iterate over the buffer using the sampler -- batch-size was set in the constructor to 1 - ... print(data) - -using the following components: - -.. currentmodule:: torchrl.data.replay_buffers - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template.rst - - - CompressedListStorage - CompressedListStorageCheckpointer - FlatStorageCheckpointer - H5StorageCheckpointer - ImmutableDatasetWriter - LazyMemmapStorage - LazyTensorStorage - ListStorage - LazyStackStorage - ListStorageCheckpointer - NestedStorageCheckpointer - PrioritizedSampler - PrioritizedSliceSampler - RandomSampler - RoundRobinWriter - Sampler - SamplerWithoutReplacement - SliceSampler - SliceSamplerWithoutReplacement - Storage - StorageCheckpointerBase - StorageEnsembleCheckpointer - TensorDictMaxValueWriter - TensorDictRoundRobinWriter - TensorStorage - TensorStorageCheckpointer - Writer - - -Storage choice is very influential on replay buffer sampling latency, especially -in distributed reinforcement learning settings with larger data volumes. -:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` is highly -advised in distributed settings with shared storage due to the lower serialization -cost of MemoryMappedTensors as well as the ability to specify file storage locations -for improved node failure recovery. -The following mean sampling latency improvements over using :class:`~torchrl.data.replay_buffers.ListStorage` -were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/benchmarks/storage. - -+-------------------------------+-----------+ -| Storage Type | Speed up | -| | | -+===============================+===========+ -| :class:`ListStorage` | 1x | -+-------------------------------+-----------+ -| :class:`LazyTensorStorage` | 1.83x | -+-------------------------------+-----------+ -| :class:`LazyMemmapStorage` | 3.44x | -+-------------------------------+-----------+ - -Compressed Storage for Memory Efficiency -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -For applications where memory usage or memory bandwidth is a primary concern—especially when storing or transferring large sensory observations such as images, audio, or text—the :class:`~torchrl.data.replay_buffers.storages.CompressedListStorage` provides significant memory savings through compression. - -**Key features:** - -- **Memory Efficiency:** Achieves substantial memory savings via compression. -- **Data Integrity:** Maintains full data fidelity through lossless compression. -- **Flexible Compression:** Uses zstd compression by default, with support for custom compression algorithms. -- **TensorDict Support:** Seamlessly integrates with TensorDict structures. -- **Checkpointing:** Fully supports saving and loading compressed data. -- **Batched GPU Compression/Decompression:** Enables efficient replay buffer sampling directly from VRAM. - -The `CompressedListStorage` compresses data when storing and decompresses when retrieving, achieving compression ratios of 95x–122x for Atari images while maintaining full data fidelity. 
-We see these results in the Atari Learning Environment (ALE) from a rollout in Pong with a random policy for an episode at each compression level: - -+-------------------------------+--------+--------+--------+--------+--------+ -| Compression level of zstd | 1 | 3 | 8 | 12 | 22 | -+===============================+========+========+========+========+========+ -| Compression ratio in ALE Pong | 95x | 99x | 106x | 111x | 122x | -+-------------------------------+--------+--------+--------+--------+--------+ - -Example usage: - - >>> import torch - >>> from torchrl.data import ReplayBuffer, CompressedListStorage - >>> from tensordict import TensorDict - >>> - >>> # Create a compressed storage for image data - >>> storage = CompressedListStorage(max_size=1000, compression_level=3) - >>> rb = ReplayBuffer(storage=storage, batch_size=32) - >>> - >>> # Add image data - >>> images = torch.randn(100, 3, 84, 84) # Atari-like frames - >>> data = TensorDict({"obs": images}, batch_size=[100]) - >>> rb.extend(data) - >>> - >>> # Sample data (automatically decompressed) - >>> sample = rb.sample(32) - >>> print(sample["obs"].shape) # torch.Size([32, 3, 84, 84]) - -The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression), -with level 3 being a good default for most use cases. - -For custom compression algorithms: - - >>> def my_compress(tensor): - ... return tensor.to(torch.uint8) # Simple example - >>> - >>> def my_decompress(compressed_tensor, metadata): - ... return compressed_tensor.to(metadata["dtype"]) - >>> - >>> storage = CompressedListStorage( - ... max_size=1000, - ... compression_fn=my_compress, - ... decompression_fn=my_decompress - ... ) - -.. note:: The CompressedListStorage uses `zstd` for python versions of at least 3.14 and defaults to zlib otherwise. - -.. note:: Batched GPU compression relies on `nvidia.nvcomp`, please see example code - `examples/replay-buffers/compressed_replay_buffer.py `_. - -Sharing replay buffers across processes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Replay buffers can be shared between processes as long as their components are -sharable. This feature allows for multiple processes to collect data and populate a shared -replay buffer collaboratively, rather than centralizing the data on the main process -which can incur some data transmission overhead. - -Sharable storages include :class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` -or any subclass of :class:`~torchrl.data.replay_buffers.storages.TensorStorage` -as long as they are instantiated and their content is stored as memory-mapped -tensors. Stateful writers such as :class:`~torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter` -are currently not sharable, and the same goes for stateful samplers such as -:class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`. - -A shared replay buffer can be read and extended on any process that has access -to it, as the following example shows: - - >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage - >>> import torch - >>> from torch import multiprocessing as mp - >>> from tensordict import TensorDict - >>> - >>> def worker(rb): - ... # Updates the replay buffer with new data - ... td = TensorDict({"a": torch.ones(10)}, [10]) - ... rb.extend(td) - ... - >>> if __name__ == "__main__": - ... rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21)) - ... td = TensorDict({"a": torch.zeros(10)}, [10]) - ... rb.extend(td) - ... - ... 
proc = mp.Process(target=worker, args=(rb,)) - ... proc.start() - ... proc.join() - ... # the replay buffer now has a length of 20, since the worker updated it - ... assert len(rb) == 20 - ... assert (rb["_data", "a"][:10] == 0).all() # data from main process - ... assert (rb["_data", "a"][10:20] == 1).all() # data from remote process - - -Storing trajectories -~~~~~~~~~~~~~~~~~~~~ - -It is not too difficult to store trajectories in the replay buffer. -One element to pay attention to is that the size of the replay buffer is by default -the size of the leading dimension of the storage: in other words, creating a -replay buffer with a storage of size 1M when storing multidimensional data -does not mean storing 1M frames but 1M trajectories. However, if trajectories -(or episodes/rollouts) are flattened before being stored, the capacity will still -be 1M steps. - -There is a way to circumvent this by telling the storage how many dimensions -it should take into account when saving data. This can be done through the ``ndim`` -keyword argument which is accepted by all contiguous storages such as -:class:`~torchrl.data.replay_buffers.TensorStorage` and the likes. When a -multidimensional storage is passed to a buffer, the buffer will automatically -consider the last dimension as the "time" dimension, as it is conventional in -TorchRL. This can be overridden through the ``dim_extend`` keyword argument -in :class:`~torchrl.data.ReplayBuffer`. -This is the recommended way to save trajectories that are obtained through -:class:`~torchrl.envs.ParallelEnv` or its serial counterpart, as we will see -below. - -When sampling trajectories, it may be desirable to sample sub-trajectories -to diversify learning or make the sampling more efficient. -TorchRL offers two distinctive ways of accomplishing this: - -- The :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` allows to - sample a given number of slices of trajectories stored one after another - along the leading dimension of the :class:`~torchrl.data.replay_buffers.samplers.TensorStorage`. - This is the recommended way of sampling sub-trajectories in TorchRL __especially__ - when using offline datasets (which are stored using that convention). - This strategy requires to flatten the trajectories before extending the replay - buffer and reshaping them after sampling. - The :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` class docstrings - gives extensive details about this storage and sampling strategy. - Note that :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` - is compatible with multidimensional storages. The following examples show - how to use this feature with and without flattening of the tensordict. - In the first scenario, we are collecting data from a single environment. In - that case, we are happy with a storage that concatenates the data coming in - along the first dimension, since there will be no interruption introduced - by the collection schedule: - - >>> from torchrl.envs import TransformedEnv, StepCounter, GymEnv - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy - >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler - >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) - >>> collector = SyncDataCollector(env, - ... RandomPolicy(env.action_spec), - ... frames_per_batch=10, total_frames=-1) - >>> rb = ReplayBuffer( - ... storage=LazyTensorStorage(100), - ... sampler=SliceSampler(num_slices=8, traj_key=("collector", "traj_ids"), - ... 
truncated_key=None, strict_length=False), - ... batch_size=64) - >>> for i, data in enumerate(collector): - ... rb.extend(data) - ... if i == 10: - ... break - >>> assert len(rb) == 100, len(rb) - >>> print(rb[:]["next", "step_count"]) - tensor([[32], - [33], - [34], - [35], - [36], - [37], - [38], - [39], - [40], - [41], - [11], - [12], - [13], - [14], - [15], - [16], - [17], - [... - - If there are more than one environment run in a batch, we could still store - the data in the same buffer as before by calling ``data.reshape(-1)`` which - will flatten the ``[B, T]`` size into ``[B * T]`` but that means that the - trajectories of, say, the first environment of the batch will be interleaved - by trajectories of the other environments, a scenario that ``SliceSampler`` - cannot handle. To solve this, we suggest to use the ``ndim`` argument in the - storage constructor: - - >>> env = TransformedEnv(SerialEnv(2, - ... lambda: GymEnv("CartPole-v1")), StepCounter()) - >>> collector = SyncDataCollector(env, - ... RandomPolicy(env.action_spec), - ... frames_per_batch=1, total_frames=-1) - >>> rb = ReplayBuffer( - ... storage=LazyTensorStorage(100, ndim=2), - ... sampler=SliceSampler(num_slices=8, traj_key=("collector", "traj_ids"), - ... truncated_key=None, strict_length=False), - ... batch_size=64) - >>> for i, data in enumerate(collector): - ... rb.extend(data) - ... if i == 100: - ... break - >>> assert len(rb) == 100, len(rb) - >>> print(rb[:]["next", "step_count"].squeeze()) - tensor([[ 6, 5], - [ 2, 2], - [ 3, 3], - [ 4, 4], - [ 5, 5], - [ 6, 6], - [ 7, 7], - [ 8, 8], - [ 9, 9], - [10, 10], - [11, 11], - [12, 12], - [13, 13], - [14, 14], - [15, 15], - [16, 16], - [17, 17], - [18, 1], - [19, 2], - [... - - -- Trajectories can also be stored independently, with the each element of the - leading dimension pointing to a different trajectory. This requires - for the trajectories to have a congruent shape (or to be padded). - We provide a custom :class:`~torchrl.envs.Transform` class named - :class:`~torchrl.envs.RandomCropTensorDict` that allows to sample - sub-trajectories in the buffer. Note that, unlike the :class:`~torchrl.data.replay_buffers.samplers.SliceSampler`-based - strategy, here having an ``"episode"`` or ``"done"`` key pointing at the - start and stop signals isn't required. - Here is an example of how this class can be used: - - .. code-block::Python - - >>> import torch - >>> from tensordict import TensorDict - >>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer - >>> from torchrl.envs import RandomCropTensorDict - >>> - >>> obs = torch.randn(100, 50, 1) - >>> data = TensorDict({"obs": obs[:-1], "next": {"obs": obs[1:]}}, [99]) - >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000)) - >>> rb.extend(data) - >>> # subsample trajectories of length 10 - >>> rb.append_transform(RandomCropTensorDict(sub_seq_len=10)) - >>> print(rb.sample(128)) - TensorDict( - fields={ - index: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int32, is_shared=False), - next: TensorDict( - fields={ - obs: Tensor(shape=torch.Size([10, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10]), - device=None, - is_shared=False), - obs: Tensor(shape=torch.Size([10, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10]), - device=None, - is_shared=False) - -Checkpointing Replay Buffers -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
_checkpoint-rb: - -Each component of the replay buffer can potentially be stateful and, as such, -require a dedicated way of being serialized. -Our replay buffer enjoys two separate APIs for saving their state on disk: -:meth:`~torchrl.data.ReplayBuffer.dumps` and :meth:`~torchrl.data.ReplayBuffer.loads` will save the -data of each component except transforms (storage, writer, sampler) using memory-mapped -tensors and json files for the metadata. - -This will work across all classes except -:class:`~torchrl.data.replay_buffers.storages.ListStorage`, which content -cannot be anticipated (and as such does not comply with memory-mapped data -structures such as those that can be found in the tensordict library). - -This API guarantees that a buffer that is saved and then loaded back will be in -the exact same state, whether we look at the status of its sampler (eg, priority trees) -its writer (eg, max writer heaps) or its storage. - -Under the hood, a naive call to :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the public -`dumps` method in a specific folder for each of its components (except transforms -which we don't assume to be serializable using memory-mapped tensors in general). - -Saving data in :ref:`TED-format ` may however consume much more memory than required. If continuous -trajectories are stored in a buffer, we can avoid saving duplicated observations by saving all the -observations at the root plus only the last element of the `"next"` sub-tensordict's observations, which -can reduce the storage consumption up to two times. To enable this, three checkpointer classes are available: -:class:`~torchrl.data.FlatStorageCheckpointer` will discard duplicated observations to compress the TED format. At -load time, this class will re-write the observations in the correct format. If the buffer is saved on disk, -the operations executed by this checkpointer will not require any additional RAM. -The :class:`~torchrl.data.NestedStorageCheckpointer` will save the trajectories using nested tensors to make the data -representation more apparent (each item along the first dimension representing a distinct trajectory). -Finally, the :class:`~torchrl.data.H5StorageCheckpointer` will save the buffer in an H5DB format, enabling users to -compress the data and save some more space. - -.. warning:: The checkpointers make some restrictive assumption about the replay buffers. First, it is assumed that - the ``done`` state accurately represents the end of a trajectory (except for the last trajectory which was written - for which the writer cursor indicates where to place the truncated signal). For MARL usage, one should note that - only done states that have as many elements as the root tensordict are allowed: - if the done state has extra elements that are not represented in - the batch-size of the storage, these checkpointers will fail. For example, a done state with shape ``torch.Size([3, 4, 5])`` - within a storage of shape ``torch.Size([3, 4])`` is not allowed. 
- -Here is a concrete example of how an H5DB checkpointer could be used in practice: - - >>> from torchrl.data import ReplayBuffer, H5StorageCheckpointer, LazyMemmapStorage - >>> from torchrl.collectors import SyncDataCollector - >>> from torchrl.envs import GymEnv, SerialEnv - >>> import torch - >>> env = SerialEnv(3, lambda: GymEnv("CartPole-v1", device=None)) - >>> env.set_seed(0) - >>> torch.manual_seed(0) - >>> collector = SyncDataCollector( - >>> env, policy=env.rand_step, total_frames=200, frames_per_batch=22 - >>> ) - >>> rb = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2)) - >>> rb_test = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2)) - >>> rb.storage.checkpointer = H5StorageCheckpointer() - >>> rb_test.storage.checkpointer = H5StorageCheckpointer() - >>> for i, data in enumerate(collector): - ... rb.extend(data) - ... assert rb._storage.max_size == 102 - ... rb.dumps(path_to_save_dir) - ... rb_test.loads(path_to_save_dir) - ... assert_allclose_td(rb_test[:], rb[:]) - - -Whenever saving data using :meth:`~torchrl.data.ReplayBuffer.dumps` is not possible, an -alternative way is to use :meth:`~torchrl.data.ReplayBuffer.state_dict`, which returns a data -structure that can be saved using :func:`torch.save` and loaded using :func:`torch.load` -before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback -of this method is that it will struggle to save big data structures, which is a -common setting when using replay buffers. - -TorchRL Episode Data Format (TED) ---------------------------------- - -.. _TED-format: - -In TorchRL, sequential data is consistently presented in a specific format, known -as the TorchRL Episode Data Format (TED). This format is crucial for the seamless -integration and functioning of various components within TorchRL. - -Some components, such as replay buffers, are somewhat indifferent to the data -format. However, others, particularly environments, heavily depend on it for smooth operation. - -Therefore, it's essential to understand the TED, its purpose, and how to interact -with it. This guide will provide a clear explanation of the TED, why it's used, -and how to effectively work with it. - -The Rationale Behind TED -~~~~~~~~~~~~~~~~~~~~~~~~ - -Formatting sequential data can be a complex task, especially in the realm of -Reinforcement Learning (RL). As practitioners, we often encounter situations -where data is delivered at the reset time (though not always), and sometimes data -is provided or discarded at the final step of the trajectory. - -This variability means that we can observe data of different lengths in a dataset, -and it's not always immediately clear how to match each time step across the -various elements of this dataset. Consider the following ambiguous dataset structure: - - >>> observation.shape - [200, 3] - >>> action.shape - [199, 4] - >>> info.shape - [200, 3] - -At first glance, it seems that the info and observation were delivered -together (one of each at reset + one of each at each step call), as suggested by -the action having one less element. However, if info has one less element, we -must assume that it was either omitted at reset time or not delivered or recorded -for the last step of the trajectory. Without proper documentation of the data -structure, it's impossible to determine which info corresponds to which time step. 
- -Complicating matters further, some datasets provide inconsistent data formats, -where ``observations`` or ``infos`` are missing at the start or end of the -rollout, and this behavior is often not documented. -The primary aim of TED is to eliminate these ambiguities by providing a clear -and consistent data representation. - -The structure of TED -~~~~~~~~~~~~~~~~~~~~ - -TED is built upon the canonical definition of a Markov Decision Process (MDP) in RL contexts. -At each step, an observation conditions an action that results in (1) a new -observation, (2) an indicator of task completion (terminated, truncated, done), -and (3) a reward signal. - -Some elements may be missing (for example, the reward is optional in imitation -learning contexts), or additional information may be passed through a state or -info container. In some cases, additional information is required to get the -observation during a call to ``step`` (for instance, in stateless environment simulators). Furthermore, -in certain scenarios, an "action" (or any other data) cannot be represented as a -single tensor and needs to be organized differently. For example, in Multi-Agent RL -settings, actions, observations, rewards, and completion signals may be composite. - -TED accommodates all these scenarios with a single, uniform, and unambiguous -format. We distinguish what happens at time step ``t`` and ``t+1`` by setting a -limit at the time the action is executed. In other words, everything that was -present before ``env.step`` was called belongs to ``t``, and everything that -comes after belongs to ``t+1``. - -The general rule is that everything that belongs to time step ``t`` is stored -at the root of the tensordict, while everything that belongs to ``t+1`` is stored -in the ``"next"`` entry of the tensordict. Here's an example: - - >>> data = env.reset() - >>> data = policy(data) - >>> print(env.step(data)) - TensorDict( - fields={ - action: Tensor(...), # The action taken at time t - done: Tensor(...), # The done state when the action was taken (at reset) - next: TensorDict( # all of this content comes from the call to `step` - fields={ - done: Tensor(...), # The done state after the action has been taken - observation: Tensor(...), # The observation resulting from the action - reward: Tensor(...), # The reward resulting from the action - terminated: Tensor(...), # The terminated state after the action has been taken - truncated: Tensor(...), # The truncated state after the action has been taken - batch_size=torch.Size([]), - device=cpu, - is_shared=False), - observation: Tensor(...), # the observation at reset - terminated: Tensor(...), # the terminated at reset - truncated: Tensor(...), # the truncated at reset - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - -During a rollout (either using :class:`~torchrl.envs.EnvBase` or -:class:`~torchrl.collectors.SyncDataCollector`), the content of the ``"next"`` -tensordict is brought to the root through the :func:`~torchrl.envs.utils.step_mdp` -function when the agent resets its step count: ``t <- t+1``. You can read more -about the environment API :ref:`here `. - -In most cases, there is no `True`-valued ``"done"`` state at the root since any -done state will trigger a (partial) reset which will turn the ``"done"`` to ``False``. -However, this is only true as long as resets are automatically performed. 
In some -cases, partial resets will not trigger a reset, so we retain these data, which -should have a considerably lower memory footprint than observations, for instance. - -This format eliminates any ambiguity regarding the matching of an observation with -its action, info, or done state. - -A note on singleton dimensions in TED -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. _reward_done_singleton: - -In TorchRL, the standard practice is that `done` states (including terminated and truncated) and rewards should have a -dimension that can be expanded to match the shape of observations, states, and actions without recurring to anything -else than repetition (i.e., the reward must have as many dimensions as the observation and/or action, or their -embeddings). - -Essentially, this format is acceptable (though not strictly enforced): - - >>> print(rollout[t]) - ... TensorDict( - ... fields={ - ... action: Tensor(n_action), - ... done: Tensor(1), # The done state has a rightmost singleton dimension - ... next: TensorDict( - ... fields={ - ... done: Tensor(1), - ... observation: Tensor(n_obs), - ... reward: Tensor(1), # The reward has a rightmost singleton dimension - ... terminated: Tensor(1), - ... truncated: Tensor(1), - ... batch_size=torch.Size([]), - ... device=cpu, - ... is_shared=False), - ... observation: Tensor(n_obs), # the observation at reset - ... terminated: Tensor(1), # the terminated at reset - ... truncated: Tensor(1), # the truncated at reset - ... batch_size=torch.Size([]), - ... device=cpu, - ... is_shared=False) - -The rationale behind this is to ensure that the results of operations (such as value estimation) on observations and/or -actions have the same number of dimensions as the reward and `done` state. This consistency allows subsequent operations -to proceed without issues: - - >>> state_value = f(observation) - >>> next_state_value = state_value + reward - -Without this singleton dimension at the end of the reward, broadcasting rules (which only work when tensors can be -expanded from the left) would try to expand the reward on the left. This could lead to failures (at best) or introduce -bugs (at worst). - -Flattening TED to reduce memory consumption -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TED copies the observations twice in the memory, which can impact the feasibility of using this format -in practice. Since it is being used mostly for ease of representation, one can store the data -in a flat manner but represent it as TED during training. - -This is particularly useful when serializing replay buffers: -For instance, the :class:`~torchrl.data.TED2Flat` class ensures that a TED-formatted data -structure is flattened before being written to disk, whereas the :class:`~torchrl.data.Flat2TED` -load hook will unflatten this structure during deserialization. - - -Dimensionality of the Tensordict -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -During a rollout, all collected tensordicts will be stacked along a new dimension -positioned at the end. Both collectors and environments will label this dimension -with the ``"time"`` name. Here's an example: - - >>> rollout = env.rollout(10, policy) - >>> assert rollout.shape[-1] == 10 - >>> assert rollout.names[-1] == "time" - -This ensures that the time dimension is clearly marked and easily identifiable -in the data structure. 
- -Special cases and footnotes -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Multi-Agent data presentation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The multi-agent data formatting documentation can be accessed in the :ref:`MARL environment API ` section. - -Memory-based policies (RNNs and Transformers) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the examples provided above, only ``env.step(data)`` generates data that -needs to be read in the next step. However, in some cases, the policy also -outputs information that will be required in the next step. This is typically -the case for RNN-based policies, which output an action as well as a recurrent -state that needs to be used in the next step. -To accommodate this, we recommend users to adjust their RNN policy to write this -data under the ``"next"`` entry of the tensordict. This ensures that this content -will be brought to the root in the next step. More information can be found in -:class:`~torchrl.modules.GRUModule` and :class:`~torchrl.modules.LSTMModule`. - -Multi-step -^^^^^^^^^^ - -Collectors allow users to skip steps when reading the data, accumulating reward -for the upcoming n steps. This technique is popular in DQN-like algorithms like Rainbow. -The :class:`~torchrl.data.postprocs.MultiStep` class performs this data transformation -on batches coming out of collectors. In these cases, a check like the following -will fail since the next observation is shifted by n steps: - - >>> assert (data[..., 1:]["observation"] == data[..., :-1]["next", "observation"]).all() - -What about memory requirements? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Implemented naively, this data format consumes approximately twice the memory -that a flat representation would. In some memory-intensive settings -(for example, in the :class:`~torchrl.data.datasets.AtariDQNExperienceReplay` dataset), -we store only the ``T+1`` observation on disk and perform the formatting online at get time. -In other cases, we assume that the 2x memory cost is a small price to pay for a -clearer representation. However, generalizing the lazy representation for offline -datasets would certainly be a beneficial feature to have, and we welcome -contributions in this direction! - -Datasets --------- - -TorchRL provides wrappers around offline RL datasets. -These data are presented as :class:`~torchrl.data.ReplayBuffer` instances, which -means that they can be customized at will with transforms, samplers and storages. -For instance, entries can be filtered in or out of a dataset with :class:`~torchrl.envs.SelectTransform` -or :class:`~torchrl.envs.ExcludeTransform`. - -By default, datasets are stored as memory mapped tensors, allowing them to be -promptly sampled with virtually no memory footprint. - -Here's an example: - -.. 
code::Python - - >>> from torchrl.data.datasets import D4RLExperienceReplay - >>> from torchrl.data.replay_buffers import SamplerWithoutReplacement - >>> from torchrl.envs.transforms import RenameTransform - >>> dataset = D4RLExperienceReplay('kitchen-complete-v0', split_trajs=True, batch_size=10) - >>> print(dataset.sample()) # will sample 10 trajectories since split_trajs is set to True - TensorDict( - fields={ - action: Tensor(shape=torch.Size([10, 207, 9]), device=cpu, dtype=torch.float32, is_shared=False), - done: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - index: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int32, is_shared=False), - infos: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False), - mask: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10, 207]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False), - timeouts: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - traj_ids: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([10, 207]), - device=cpu, - is_shared=False) - >>> dataset.append_transform(RenameTransform(["done", ("next", "done")], ["terminal", ("next", "terminal")])) - >>> print(dataset.sample()) # The "done" has been renamed to "terminal" - TensorDict( - fields={ - action: Tensor(shape=torch.Size([10, 207, 9]), device=cpu, dtype=torch.float32, is_shared=False), - terminal: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - index: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int32, is_shared=False), - infos: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False), - mask: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - terminal: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10, 207]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False), - timeouts: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False), - traj_ids: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([10, 207]), - device=cpu, - is_shared=False) - >>> # we can also use a `SamplerWithoutReplacement` to iterate over the dataset with random samples: - >>> dataset = D4RLExperienceReplay( - ... 'kitchen-complete-v0', - ... sampler=SamplerWithoutReplacement(drop_last=True), - ... split_trajs=True, - ... batch_size=3) - >>> for data in dataset: - ... print(data) - ... - -.. note:: - - Installing dependencies is the responsibility of the user. 
For D4RL, a clone of - `the repository `_ is needed as - the latest wheels are not published on PyPI. For OpenML, `scikit-learn `_ and - `pandas `_ are required. - -Transforming datasets -~~~~~~~~~~~~~~~~~~~~~ - -In many instances, the raw data isn't going to be used as-is. -The natural solution could be to pass a :class:`~torchrl.envs.transforms.Transform` -instance to the dataset constructor and modify the sample on-the-fly. This will -work but it will incur an extra runtime for the transform. -If the transformations can be (at least a part) pre-applied to the dataset, -a conisderable disk space and some incurred overhead at sampling time can be -saved. To do this, the -:meth:`~torchrl.data.datasets.BaseDatasetExperienceReplay.preprocess` can be -used. This method will run a per-sample preprocessing pipeline on each element -of the dataset, and replace the existing dataset by its transformed version. - -Once transformed, re-creating the same dataset will produce another object with -the same transformed storage (unless ``download="force"`` is being used): - - >>> dataset = RobosetExperienceReplay( - ... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32, download="force" - ... ) - >>> - >>> def func(data): - ... return data.set("obs_norm", data.get("observation").norm(dim=-1)) - ... - >>> dataset.preprocess( - ... func, - ... num_workers=max(1, os.cpu_count() - 2), - ... num_chunks=1000, - ... mp_start_method="fork", - ... ) - >>> sample = dataset.sample() - >>> assert "obs_norm" in sample.keys() - >>> # re-recreating the dataset gives us the transformed version back. - >>> dataset = RobosetExperienceReplay( - ... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32 - ... ) - >>> sample = dataset.sample() - >>> assert "obs_norm" in sample.keys() - - -.. currentmodule:: torchrl.data.datasets - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - BaseDatasetExperienceReplay - AtariDQNExperienceReplay - D4RLExperienceReplay - GenDGRLExperienceReplay - MinariExperienceReplay - OpenMLExperienceReplay - OpenXExperienceReplay - RobosetExperienceReplay - VD4RLExperienceReplay - -Composing datasets -~~~~~~~~~~~~~~~~~~ - -In offline RL, it is customary to work with more than one dataset at the same time. -Moreover, TorchRL usually has a fine-grained dataset nomenclature, where -each task is represented separately when other libraries will represent these -datasets in a more compact way. To allow users to compose multiple datasets -together, we propose a :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble` -primitive that allows users to sample from multiple datasets at once. - -If the individual dataset formats differ, :class:`~torchrl.envs.Transform` instances -can be used. In the following example, we create two dummy datasets with semantically -identical entries that differ in names (``("some", "key")`` and ``"another_key"``) -and show how they can be renamed to have a matching name. We also resize images -such that they can be stacked together during sampling. - - >>> from torchrl.envs import Comopse, ToTensorImage, Resize, RenameTransform - >>> from torchrl.data import TensorDictReplayBuffer, ReplayBufferEnsemble, LazyMemmapStorage - >>> from tensordict import TensorDict - >>> import torch - >>> rb0 = TensorDictReplayBuffer( - ... storage=LazyMemmapStorage(10), - ... transform=Compose( - ... ToTensorImage(in_keys=["pixels", ("next", "pixels")]), - ... Resize(32, in_keys=["pixels", ("next", "pixels")]), - ... 
RenameTransform([("some", "key")], ["renamed"]), - ... ), - ... ) - >>> rb1 = TensorDictReplayBuffer( - ... storage=LazyMemmapStorage(10), - ... transform=Compose( - ... ToTensorImage(in_keys=["pixels", ("next", "pixels")]), - ... Resize(32, in_keys=["pixels", ("next", "pixels")]), - ... RenameTransform(["another_key"], ["renamed"]), - ... ), - ... ) - >>> rb = ReplayBufferEnsemble( - ... rb0, - ... rb1, - ... p=[0.5, 0.5], - ... transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]), - ... ) - >>> data0 = TensorDict( - ... { - ... "pixels": torch.randint(255, (10, 244, 244, 3)), - ... ("next", "pixels"): torch.randint(255, (10, 244, 244, 3)), - ... ("some", "key"): torch.randn(10), - ... }, - ... batch_size=[10], - ... ) - >>> data1 = TensorDict( - ... { - ... "pixels": torch.randint(255, (10, 64, 64, 3)), - ... ("next", "pixels"): torch.randint(255, (10, 64, 64, 3)), - ... "another_key": torch.randn(10), - ... }, - ... batch_size=[10], - ... ) - >>> rb[0].extend(data0) - >>> rb[1].extend(data1) - >>> for _ in range(2): - ... sample = rb.sample(10) - ... assert sample["next", "pixels"].shape == torch.Size([2, 5, 3, 32, 32]) - ... assert sample["pixels"].shape == torch.Size([2, 5, 3, 32, 32]) - ... assert sample["pixels33"].shape == torch.Size([2, 5, 3, 33, 33]) - ... assert sample["renamed"].shape == torch.Size([2, 5]) - -.. currentmodule:: torchrl.data.replay_buffers - - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - ReplayBufferEnsemble - SamplerEnsemble - StorageEnsemble - WriterEnsemble - -TensorSpec ----------- - -.. _ref_specs: - -The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of state, observations -actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain. - -It is important that your environment specs match the input and output that it sends and receives, as -:class:`~torchrl.envs.ParallelEnv` will create buffers from these specs to communicate with the spawn processes. -Check the :func:`torchrl.envs.utils.check_env_specs` method for a sanity check. - -If needed, specs can be automatically generated from data using the :func:`~torchrl.envs.utils.make_composite_from_td` -function. - -Specs fall in two main categories, numerical and categorical. - -.. table:: Numerical TensorSpec subclasses. - - +-------------------------------------------------------------------------------+ - | Numerical | - +=====================================+=========================================+ - | Bounded | Unbounded | - +-----------------+-------------------+-------------------+---------------------+ - | BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous | - +-----------------+-------------------+-------------------+---------------------+ - -Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or -explicitly by the `"domain"` keyword argument) will determine if the instantiated class will be of :class:`~torchrl.data.BoundedContinuous` -or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the :class:`~torchrl.data.Unbounded` class. -See these classes for further information. - -.. table:: Categorical TensorSpec subclasses. 
- - +------------------------------------------------------------------+ - | Categorical | - +========+=============+=============+==================+==========+ - | OneHot | MultiOneHot | Categorical | MultiCategorical | Binary | - +--------+-------------+-------------+------------------+----------+ - -Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be -combined together, TorchRL assumes that the data will be presented as dictionaries (more specifically, as -:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these -cases is the :class:`~torchrl.data.Composite` spec. - -Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be -expanded accordingly. -Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class. - -Similarly, ``TensorSpecs`` possess some common behavior with :class:`~torch.Tensor` and -:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``) -or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be. - -Specs where some dimensions are ``-1`` are said to be "dynamic" and the negative dimensions indicate that the corresponding -data has an inconsistent shape. When seen by an optimizer or an environment (e.g., batched environment such as -:class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are -not predictable. - -.. currentmodule:: torchrl.data - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - TensorSpec - Binary - Bounded - Categorical - Composite - MultiCategorical - MultiOneHot - NonTensor - OneHot - Stacked - StackedComposite - Unbounded - UnboundedContinuous - UnboundedDiscrete - -The following classes are deprecated and just point to the classes above: - -.. currentmodule:: torchrl.data - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - BinaryDiscreteTensorSpec - BoundedTensorSpec - CompositeSpec - DiscreteTensorSpec - LazyStackedCompositeSpec - LazyStackedTensorSpec - MultiDiscreteTensorSpec - MultiOneHotDiscreteTensorSpec - NonTensorSpec - OneHotDiscreteTensorSpec - UnboundedContinuousTensorSpec - UnboundedDiscreteTensorSpec - -Trees and Forests ------------------ - -TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently, -which is particularly useful for Monte Carlo Tree Search (MCTS) algorithms. - -TensorDictMap -~~~~~~~~~~~~~ - -At its core, the MCTS API relies on the :class:`~torchrl.data.TensorDictMap` which acts like a storage where indices can -be any numerical object. In traditional storages (e.g., :class:`~torchrl.data.TensorStorage`), only integer indices -are allowed: - - >>> storage = TensorStorage(...) - >>> data = storage[3] - -:class:`~torchrl.data.TensorDictMap` allows us to make more advanced queries in the storage. The typical example is -when we have a storage containing a set of MDPs and we want to rebuild a trajectory given its initial observation, action -pair. In tensor terms, this could be written with the following pseudocode: - - >>> next_state = storage[observation, action] - -(if there is more than one next state associated with this pair one could return a stack of ``next_states`` instead). 
-This API would make sense but it would be restrictive: allowing observations or actions that are composed of -multiple tensors may be hard to implement. Instead, we provide a tensordict containing these values and let the storage -know what ``in_keys`` to look at to query the next state: - - >>> td = TensorDict(observation=observation, action=action) - >>> next_td = storage[td] - -Of course, this class also allows us to extend the storage with new data: - - >>> storage[td] = next_state - -This comes in handy because it allows us to represent complex rollout structures where different actions are undertaken -at a given node (ie, for a given observation). All `(observation, action)` pairs that have been observed may lead us to -a (set of) rollout that we can use further. - -MCTSForest -~~~~~~~~~~ - -Building a tree from an initial observation then becomes just a matter of organizing data efficiently. -The :class:`~torchrl.data.MCTSForest` has at its core two storages: a first storage links observations to hashes and -indices of actions encountered in the past in the dataset: - - >>> data = TensorDict(observation=observation) - >>> metadata = forest.node_map[data] - >>> index = metadata["_index"] - -where ``forest`` is a :class:`~torchrl.data.MCTSForest` instance. -Then, a second storage keeps track of the actions and results associated with the observation: - - >>> next_data = forest.data_map[index] - -The ``next_data`` entry can have any shape, but it will usually match the shape of ``index`` (since at each index -corresponds one action). Once ``next_data`` is obtained, it can be put together with ``data`` to form a set of nodes, -and the tree can be expanded for each of these. The following figure shows how this is done. - -.. figure:: /_static/img/collector-copy.png - - Building a :class:`~torchrl.data.Tree` from a :class:`~torchrl.data.MCTSForest` object. - The flowchart represents a tree being built from an initial observation `o`. The :class:`~torchrl.data.MCTSForest.get_tree` - method passed the input data structure (the root node) to the ``node_map`` :class:`~torchrl.data.TensorDictMap` instance - that returns a set of hashes and indices. These indices are then used to query the corresponding tuples of - actions, next observations, rewards etc. that are associated with the root node. - A vertex is created from each of them (possibly with a longer rollout when a compact representation is asked). - The stack of vertices is then used to build up the tree further, and these vertices are stacked together and make - up the branches of the tree at the root. This process is repeated for a given depth or until the tree cannot be - expanded anymore. - -.. currentmodule:: torchrl.data - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - BinaryToDecimal - HashToInt - MCTSForest - QueryModule - RandomProjectionHash - SipHash - TensorDictMap - TensorMap - Tree - - -Large language models and Reinforcement Learning From Human Feedback (RLHF) ---------------------------------------------------------------------------- - -.. warning:: - These APIs are deprecated and will be removed in the future. - Use the :mod:`torchrl.data.llm` module instead. - See the full :ref:`LLM documentation ` for more information. - -Data is of utmost importance in LLM post-training (e.g., GRPO or Reinforcement Learning from Human Feedback (RLHF)). 
-Given that these techniques are commonly employed in the realm of language, -which is scarcely addressed in other subdomains of RL within the library, -we offer specific utilities to facilitate interaction with external libraries -like datasets. These utilities consist of tools for tokenizing data, formatting -it in a manner suitable for TorchRL modules, and optimizing storage for -efficient sampling. - -.. currentmodule:: torchrl.data - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - PairwiseDataset - PromptData - PromptTensorDictTokenizer - RewardData - RolloutFromModel - TensorDictTokenizer - TokenizedDatasetLoader - create_infinite_iterator - get_dataloader - ConstantKLController - AdaptiveKLController - - -Utils ------ - -.. currentmodule:: torchrl.data - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - DensifyReward - Flat2TED - H5Combine - H5Split - MultiStep - Nested2TED - TED2Flat - TED2Nested - check_no_exclusive_keys - consolidate_spec - contains_lazy_spec - -.. currentmodule:: torchrl.envs.transforms.rb_transforms - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - MultiStepTransform +TorchRL provides a comprehensive data management system built around replay buffers, which are central to +off-policy RL algorithms. The library offers efficient implementations of various replay buffers with +composable components for storage, sampling, and data transformation. + +Key Features +------------ + +- **Flexible storage backends**: Memory, memmap, and compressed storage options +- **Advanced sampling strategies**: Prioritized, slice-based, and custom samplers +- **Composable design**: Mix and match storage, samplers, and writers +- **Type flexibility**: Support for tensors, tensordicts, and arbitrary data types +- **Efficient transforms**: Apply preprocessing during sampling +- **Distributed support**: Ray-based and remote replay buffers + +Quick Example +------------- + +.. code-block:: python + + import torch + from torchrl.data import ReplayBuffer, LazyMemmapStorage, PrioritizedSampler + from tensordict import TensorDict + + # Create a replay buffer with memmap storage and prioritized sampling + buffer = ReplayBuffer( + storage=LazyMemmapStorage(max_size=1000000), + sampler=PrioritizedSampler(max_capacity=1000000, alpha=0.7, beta=0.5), + batch_size=256, + ) + + # Add data + data = TensorDict({ + "observation": torch.randn(32, 4), + "action": torch.randn(32, 2), + "reward": torch.randn(32, 1), + }, batch_size=[32]) + buffer.extend(data) + + # Sample + sample = buffer.sample() # Returns batch_size=256 + +Documentation Sections +---------------------- + +.. toctree:: + :maxdepth: 2 + + data_replaybuffers + data_storage + data_samplers + data_datasets + data_specs diff --git a/docs/source/reference/data_datasets.rst b/docs/source/reference/data_datasets.rst new file mode 100644 index 00000000000..e7fc8038a04 --- /dev/null +++ b/docs/source/reference/data_datasets.rst @@ -0,0 +1,19 @@ +.. currentmodule:: torchrl.data + +Datasets +======== + +TorchRL provides dataset utilities for offline RL and data management. + +.. 
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    datasets.AtariDQNExperienceReplay
+    datasets.D4RLExperienceReplay
+    datasets.GenDGRLExperienceReplay
+    datasets.MinariExperienceReplay
+    datasets.OpenMLExperienceReplay
+    datasets.OpenXExperienceReplay
+    datasets.RobosetExperienceReplay
+    datasets.VD4RLExperienceReplay
diff --git a/docs/source/reference/data_replaybuffers.rst b/docs/source/reference/data_replaybuffers.rst
new file mode 100644
index 00000000000..375a93a7798
--- /dev/null
+++ b/docs/source/reference/data_replaybuffers.rst
@@ -0,0 +1,52 @@
+.. currentmodule:: torchrl.data
+
+Replay Buffers
+==============
+
+Replay buffers are a central part of off-policy RL algorithms. TorchRL provides efficient implementations of
+several widely used replay buffers:
+
+Core Replay Buffer Classes
+--------------------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    ReplayBuffer
+    ReplayBufferEnsemble
+    PrioritizedReplayBuffer
+    TensorDictReplayBuffer
+    TensorDictPrioritizedReplayBuffer
+    RayReplayBuffer
+    RemoteTensorDictReplayBuffer
+
+Composable Replay Buffers
+-------------------------
+
+.. _ref_buffers:
+
+Users can also compose a replay buffer from its building blocks.
+TorchRL offers a wide range of options, including support for
+almost any data type; storage in memory, on device, or on physical (memory-mapped) memory;
+several sampling strategies; and the use of transforms.
+
+Supported data types and choosing a storage
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In theory, replay buffers support any data type, but not every component is
+guaranteed to handle every data type. The most rudimentary replay buffer implementation
+is made of a :class:`~torchrl.data.replay_buffers.ReplayBuffer` base with a
+:class:`~torchrl.data.replay_buffers.ListStorage` storage. This is inefficient,
+but it allows you to store complex data structures with non-tensor data.
+Storages in contiguous memory include :class:`~torchrl.data.replay_buffers.TensorStorage`,
+:class:`~torchrl.data.replay_buffers.LazyTensorStorage` and
+:class:`~torchrl.data.replay_buffers.LazyMemmapStorage`.
+
+Sampling and indexing
+~~~~~~~~~~~~~~~~~~~~~
+
+Replay buffers can be indexed and sampled.
+Indexing and sampling collect data at given indices in the storage and then process it
+through a series of transforms and a ``collate_fn``, both of which can be passed to the `__init__`
+function of the replay buffer.
diff --git a/docs/source/reference/data_samplers.rst b/docs/source/reference/data_samplers.rst
new file mode 100644
index 00000000000..50bcd5f19f5
--- /dev/null
+++ b/docs/source/reference/data_samplers.rst
@@ -0,0 +1,34 @@
+.. currentmodule:: torchrl.data.replay_buffers
+
+Sampling Strategies
+===================
+
+Samplers control how data is retrieved from the replay buffer storage.
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    PrioritizedSampler
+    PrioritizedSliceSampler
+    RandomSampler
+    Sampler
+    SamplerEnsemble
+    SamplerWithoutReplacement
+    SliceSampler
+    SliceSamplerWithoutReplacement
+
+Writers
+-------
+
+Writers control how data is written to the storage.
+
+.. 
autosummary:: + :toctree: generated/ + :template: rl_template.rst + + RoundRobinWriter + TensorDictMaxValueWriter + TensorDictRoundRobinWriter + Writer + WriterEnsemble diff --git a/docs/source/reference/data_specs.rst b/docs/source/reference/data_specs.rst new file mode 100644 index 00000000000..f1ca8a38c1f --- /dev/null +++ b/docs/source/reference/data_specs.rst @@ -0,0 +1,26 @@ +.. currentmodule:: torchrl.data + +TensorSpec System +================= + +TensorSpec classes define the shape, dtype, and domain of tensors in TorchRL. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Binary + Bounded + Categorical + Composite + DiscreteTensorSpec + LazyStackedCompositeSpec + MultiCategorical + MultiDiscreteTensorSpec + MultiOneHot + NonTensor + OneHot + TensorSpec + Unbounded + UnboundedContinuous + UnboundedDiscrete diff --git a/docs/source/reference/data_storage.rst b/docs/source/reference/data_storage.rst new file mode 100644 index 00000000000..9225c9351e5 --- /dev/null +++ b/docs/source/reference/data_storage.rst @@ -0,0 +1,38 @@ +.. currentmodule:: torchrl.data.replay_buffers + +Storage Backends +================ + +TorchRL provides various storage backends for replay buffers, each optimized for different use cases. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + CompressedListStorage + CompressedListStorageCheckpointer + FlatStorageCheckpointer + H5StorageCheckpointer + ImmutableDatasetWriter + LazyMemmapStorage + LazyTensorStorage + ListStorage + LazyStackStorage + ListStorageCheckpointer + NestedStorageCheckpointer + Storage + StorageCheckpointerBase + StorageEnsemble + StorageEnsembleCheckpointer + TensorStorage + TensorStorageCheckpointer + +Storage Performance +------------------- + +Storage choice is very influential on replay buffer sampling latency, especially +in distributed reinforcement learning settings with larger data volumes. +:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` is highly +advised in distributed settings with shared storage due to the lower serialization +cost of MemoryMappedTensors as well as the ability to specify file storage locations +for improved node failure recovery. diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 3e807cc2f93..a5cd92e1a9a 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -3,1442 +3,60 @@ torchrl.envs package ==================== -.. _Environment-API: +.. _ref_envs: -TorchRL offers an API to handle environments of different backends, such as gym, -dm-control, dm-lab, model-based environments as well as custom environments. -The goal is to be able to swap environments in an experiment with little or no effort, -even if these environments are simulated using different libraries. -TorchRL offers some out-of-the-box environment wrappers under :obj:`torchrl.envs.libs`, -which we hope can be easily imitated for other libraries. -The parent class :class:`~torchrl.envs.EnvBase` is a :class:`torch.nn.Module` subclass that implements -some typical environment methods using :class:`tensordict.TensorDict` as a data organiser. This allows this -class to be generic and to handle an arbitrary number of input and outputs, as well as -nested or batched data structures. +TorchRL offers a comprehensive API to handle environments of different backends, making it easy to swap +environments in an experiment with minimal effort. 
The library provides wrappers for popular RL frameworks +including Gym, DMControl, Brax, Jumanji, and many others. -Each env will have the following attributes: +The :class:`~torchrl.envs.EnvBase` class serves as the foundation, providing a unified interface that uses +:class:`tensordict.TensorDict` for data organization. This design allows the framework to be generic and +handle an arbitrary number of inputs and outputs, as well as nested or batched data structures. -- :obj:`env.batch_size`: a :obj:`torch.Size` representing the number of envs - batched together. -- :obj:`env.device`: the device where the input and output tensordict are expected to live. - The environment device does not mean that the actual step operations will be computed on device - (this is the responsibility of the backend, with which TorchRL can do little). The device of - an environment just represents the device where the data is to be expected when input to the - environment or retrieved from it. TorchRL takes care of mapping the data to the desired device. - This is especially useful for transforms (see below). For parametric environments (e.g. - model-based environments), the device does represent the hardware that will be used to - compute the operations. -- :obj:`env.observation_spec`: a :class:`~torchrl.data.Composite` object - containing all the observation key-spec pairs. -- :obj:`env.state_spec`: a :class:`~torchrl.data.Composite` object - containing all the input key-spec pairs (except action). For most stateful - environments, this container will be empty. -- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object - representing the action spec. -- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing - the reward spec. -- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing - the done-flag spec. See the section on trajectory termination below. -- :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing - all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`). -- :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing - all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`). +Key Features +------------ -If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensor` -instance can be used. +- **Unified API**: Consistent interface across different environment backends +- **Vectorization**: Built-in support for parallel and batched environments +- **Transforms**: Powerful transform system for preprocessing observations and actions +- **Multi-agent**: Native support for multi-agent RL with no additional infrastructure +- **Flexible backends**: Easy integration with Gym, DMControl, Brax, and custom environments -Env specs: locks and batch size -------------------------------- - -.. _Environment-lock: - -Environment specs are locked by default (through a ``spec_locked`` arg passed to the env constructor). -Locking specs means that any modification of the spec (or its children if it is a :class:`~torchrl.data.Composite` -instance) will require to unlock it. This can be done via the :meth:`~torchrl.envs.EnvBase.set_spec_lock_`. -The reason specs are locked by default is that it makes it easy to cache values such as action or reset keys and the -likes. -Unlocking an env should only be done if it expected that the specs will be modified often (which, in principle, should -be avoided). 
-Modifications of the specs such as `env.observation_spec = new_spec` are allowed: under the hood, TorchRL will erase -the cache, unlock the specs, make the modification and relock the specs if the env was previously locked. - -Importantly, the environment spec shapes should contain the batch size, e.g. -an environment with :obj:`env.batch_size == torch.Size([4])` should have -an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])`. -This is helpful when preallocation tensors, checking shape consistency etc. - -Env methods ------------ - -With these, the following methods are implemented: - -- :meth:`env.reset`: a reset method that may (but not necessarily requires to) take - a :class:`tensordict.TensorDict` input. It return the first tensordict of a rollout, usually - containing a :obj:`"done"` state and a set of observations. If not present, - a `"reward"` key will be instantiated with 0s and the appropriate shape. -- :meth:`env.step`: a step method that takes a :class:`tensordict.TensorDict` input - containing an input action as well as other inputs (for model-based or stateless - environments, for instance). -- :meth:`env.step_and_maybe_reset`: executes a step, and (partially) resets the - environments if it needs to. It returns the updated input with a ``"next"`` - key containing the data of the next step, as well as a tensordict containing - the input data for the next step (ie, reset or result or - :func:`~torchrl.envs.utils.step_mdp`) - This is done by reading the ``done_keys`` and - assigning a ``"_reset"`` signal to each done state. This method allows - to code non-stopping rollout functions with little effort: - - >>> data_ = env.reset() - >>> result = [] - >>> for i in range(N): - ... data, data_ = env.step_and_maybe_reset(data_) - ... result.append(data) - ... - >>> result = torch.stack(result) - -- :meth:`env.set_seed`: a seeding method that will return the next seed - to be used in a multi-env setting. This next seed is deterministically computed - from the preceding one, such that one can seed multiple environments with a different - seed without risking to overlap seeds in consecutive experiments, while still - having reproducible results. -- :meth:`env.rollout`: executes a rollout in the environment for - a maximum number of steps (``max_steps=N``) and using a policy (``policy=model``). - The policy should be coded using a :class:`tensordict.nn.TensorDictModule` - (or any other :class:`tensordict.TensorDict`-compatible module). - The resulting :class:`tensordict.TensorDict` instance will be marked with - a trailing ``"time"`` named dimension that can be used by other modules - to treat this batched dimension as it should. - -The following figure summarizes how a rollout is executed in torchrl. - -.. figure:: /_static/img/rollout.gif - - TorchRL rollouts using TensorDict. - -In brief, a TensorDict is created by the :meth:`~.EnvBase.reset` method, -then populated with an action by the policy before being passed to the -:meth:`~.EnvBase.step` method which writes the observations, done flag(s) and -reward under the ``"next"`` entry. The result of this call is stored for -delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp` -function. - -.. note:: - In general, all TorchRL environment have a ``"done"`` and ``"terminated"`` - entry in their output tensordict. If they are not present by design, - the :class:`~.EnvBase` metaclass will ensure that every done or terminated - is flanked with its dual. 
- In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory - signals and should be interpreted as "the last step of a trajectory" or - equivalently "a signal indicating the need to reset". - If the environment provides it (eg, Gymnasium), the truncation entry is also - written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry. - If the environment carries a single value, it will interpreted as a ``"terminated"`` - signal by default. - By default, TorchRL's collectors and rollout methods will be looking for the ``"done"`` - entry to assess if the environment should be reset. - -.. note:: - - The `torchrl.collectors.utils.split_trajectories` function can be used to - slice adjacent trajectories. It relies on a ``"traj_ids"`` entry in the - input tensordict, or to the junction of ``"done"`` and ``"truncated"`` key - if the ``"traj_ids"`` is missing. - - -.. note:: - - In some contexts, it can be useful to mark the first step of a trajectory. - TorchRL provides such functionality through the :class:`~torchrl.envs.InitTracker` - transform. - - -Our environment :ref:`tutorial ` -provides more information on how to design a custom environment from scratch. - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - EnvBase - GymLikeEnv - EnvMetaData - -Partial steps and partial resets --------------------------------- - -TorchRL allows environments to reset some but not all the environments, or run a step in one but not all environments. -If there is only one environment in the batch, then a partial reset / step is also allowed with the behavior detailed -below. - -Batching environments and locking the batch -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. _ref_batch_locked: - -Before detailing what partial resets and partial steps do, we must distinguish cases where an environment has -a batch size of its own (mostly stateful environments) or when the environment is just a mere module that, given an -input of arbitrary size, batches the operations over all elements (mostly stateless environments). - -This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input -tensordicts to have the same batch-size as the env's. Typical examples of these environments are -:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size. -Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`. - -Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the -tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous -input. - -Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with -partial steps easily, they just pass the actions to the sub-environments that are required to be executed. - -In all other cases, TorchRL assumes that the environment handles the partial steps correctly. - -.. warning:: This means that custom environments may silently run the non-required steps as there is no way for torchrl - to control what happens within the `_step` method! - -Partial Steps -~~~~~~~~~~~~~ - -.. _ref_partial_steps: - -Partial steps are controlled via the temporary key `"_step"` which points to a boolean mask of the -size of the tensordict that holds it. 
The classes armed to deal with this are: - -- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the - action to and only to the environments where `"_step"` is `True`; -- Batch-unlocked environments; -- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step` - method will first look for a `"_step"` entry and, if present, act accordingly. - If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by - :class:`~torchrl.envs.TransformedEnv`'s own `_step` method which will skip the `base_env.step` as well as any further - transformation. - -When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous -content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that -if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for -all the non-stepped elements. Within batched environments, data collectors and rollouts utils, this is an issue that -is not observed because these classes handle the passing of data properly. - -Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`, -as the environments with a `True` done state will need to be skipped during calls to `_step`. - -The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips. - -Partial Resets -~~~~~~~~~~~~~~ - -.. _ref_partial_resets: - -Partial resets work pretty much like partial steps, but with the `"_reset"` entry. - -The same restrictions of partial steps apply to partial resets. - -Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `True`, -as the environments with a `True` done state will need to be reset, but not others. - -See te following paragraph for a deep dive in partial resets within batched and vectorized environments. - -Vectorized envs ---------------- - -Vectorized (or better: parallel) environments is a common feature in Reinforcement Learning -where executing the environment step can be cpu-intensive. -Some libraries such as `gym3 `_ or `EnvPool `_ -offer interfaces to execute batches of environments simultaneously. -While they often offer a very competitive computational advantage, they do not -necessarily scale to the wide variety of environment libraries supported by TorchRL. -Therefore, TorchRL offers its own, generic :class:`ParallelEnv` class to run multiple -environments in parallel. -As this class inherits from :class:`SerialEnv`, it enjoys the exact same API as other environment. -Of course, a :class:`ParallelEnv` will have a batch size that corresponds to its environment count: - -.. note:: - Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) - warnings can quickly become quite annoying in multiprocessed / distributed settings. - By default, TorchRL filters out these warnings in sub-processes. If one still wishes to - see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. - -It is important that your environment specs match the input and output that it sends and receives, as -:class:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes. 
-Check the :func:`~torchrl.envs.utils.check_env_specs` method for a sanity check. - -.. code-block:: - :caption: Parallel environment - - >>> def make_env(): - ... return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0") - >>> check_env_specs(env) # this must pass for ParallelEnv to work - >>> env = ParallelEnv(4, make_env) - >>> print(env.batch_size) - torch.Size([4]) - -:class:`ParallelEnv` allows to retrieve the attributes from its contained environments: -one can simply call: - -.. code-block:: - :caption: Parallel environment attributes - - >>> a, b, c, d = env.g # gets the g-force of the various envs, which we set to 9.81 before - >>> print(a) - 9.81 - -TorchRL uses a private ``"_reset"`` key to indicate to the environment which -component (sub-environments or agents) should be reset. -This allows to reset some but not all of the components. - -The ``"_reset"`` key has two distinct functionalities: - -1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may - not be present in the input tensordict. TorchRL's convention is that the - absence of the ``"_reset"`` key at a given ``"done"`` level indicates - a total reset of that level (unless a ``"_reset"`` key was found at a level - above, see details below). - If it is present, it is expected that those entries and only those components - where the ``"_reset"`` entry is ``True`` (along key and shape dimension) will be reset. - - The way an environment deals with the ``"_reset"`` keys in its :meth:`~.EnvBase._reset` - method is proper to its class. - Designing an environment that behaves according to ``"_reset"`` inputs is the - developer's responsibility, as TorchRL has no control over the inner logic - of :meth:`~.EnvBase._reset`. Nevertheless, the following point should be - kept in mind when designing that method. - -2. After a call to :meth:`~.EnvBase._reset`, the output will be masked with the - ``"_reset"`` entries and the output of the previous :meth:`~.EnvBase.step` - will be written wherever the ``"_reset"`` was ``False``. In practice, this - means that if a ``"_reset"`` modifies data that isn't exposed by it, this - modification will be lost. After this masking operation, the ``"_reset"`` - entries will be erased from the :meth:`~.EnvBase.reset` outputs. - -It must be pointed out that ``"_reset"`` is a private key, and it should only be -used when coding specific environment features that are internal facing. -In other words, this should NOT be used outside of the library, and developers -will keep the right to modify the logic of partial resets through ``"_reset"`` -setting without preliminary warranty, as long as they don't affect TorchRL -internal tests. - -Finally, the following assumptions are made and should be kept in mind when -designing reset functionalities: - -- Each ``"_reset"`` is paired with a ``"done"`` entry (+ ``"terminated"`` and, - possibly, ``"truncated"``). This means that the following structure is not - allowed: ``TensorDict({"done": done, "nested": {"_reset": reset}}, [])``, as - the ``"_reset"`` lives at a different nesting level than the ``"done"``. -- A reset at one level does not preclude the presence of a ``"_reset"`` at lower - levels, but it annihilates its effects. 
The reason is simply that - whether the ``"_reset"`` at the root level corresponds to an ``all()``, ``any()`` - or custom call to the nested ``"done"`` entries cannot be known in advance, - and it is explicitly assumed that the ``"_reset"`` at the root was placed - there to supersede the nested values (for an example, have a look at - :class:`~.PettingZooWrapper` implementation where each group has one or more - ``"done"`` entries associated which is aggregated at the root level with a - ``any`` or ``all`` logic depending on the task). -- When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry - that will reset some but not all the done sub-environments, the input data - should contain the data of the sub-environments that are __not__ being reset. - The reason for this constrain lies in the fact that the output of the - ``env._reset(data)`` can only be predicted for the entries that are reset. - For the others, TorchRL cannot know in advance if they will be meaningful or - not. For instance, one could perfectly just pad the values of the non-reset - components, in which case the non-reset data will be meaningless and should - be discarded. - -Below, we give some examples of the expected effect that ``"_reset"`` keys will -have on an environment returning zeros after reset: - - >>> # single reset at the root - >>> data = TensorDict({"val": [1, 1], "_reset": [False, True]}, []) - >>> env.reset(data) - >>> print(data.get("val")) # only the second value is 0 - tensor([1, 0]) - >>> # nested resets - >>> data = TensorDict({ - ... ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True], - ... ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False], - ... }, []) - >>> env.reset(data) - >>> print(data.get(("agent0", "val"))) # only the second value is 0 - tensor([1, 0]) - >>> print(data.get(("agent1", "val"))) # only the first value is 0 - tensor([0, 2]) - >>> # nested resets are overridden by a "_reset" at the root - >>> data = TensorDict({ - ... "_reset": [True, True], - ... ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True], - ... ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False], - ... }, []) - >>> env.reset(data) - >>> print(data.get(("agent0", "val"))) # reset at the root overrides nested - tensor([0, 0]) - >>> print(data.get(("agent1", "val"))) # reset at the root overrides nested - tensor([0, 0]) - -.. code-block:: - :caption: Parallel environment reset - - >>> tensordict = TensorDict({"_reset": [[True], [False], [True], [True]]}, [4]) - >>> env.reset(tensordict) # eliminates the "_reset" entry - TensorDict( - fields={ - terminated: Tensor(torch.Size([4, 1]), dtype=torch.bool), - done: Tensor(torch.Size([4, 1]), dtype=torch.bool), - pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8), - truncated: Tensor(torch.Size([4, 1]), dtype=torch.bool), - batch_size=torch.Size([4]), - device=None, - is_shared=True) - - -.. note:: - - *A note on performance*: launching a :class:`~.ParallelEnv` can take quite some time - as it requires to launch as many python instances as there are processes. Due to - the time that it takes to run ``import torch`` (and other imports), starting the - parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow. - Once the environment is launched, a great speedup should be observed. - -.. 
note:: - - *TorchRL requires precise specs*: Another thing to take in consideration is - that :class:`ParallelEnv` (as well as data collectors) - will create data buffers based on the environment specs to pass data from one process - to another. This means that a misspecified spec (input, observation or reward) will - cause a breakage at runtime as the data can't be written on the preallocated buffer. - In general, an environment should be tested using the :func:`~.utils.check_env_specs` - test function before being used in a :class:`ParallelEnv`. This function will raise - an assertion error whenever the preallocated buffer and the collected data mismatch. - -We also offer the :class:`~.SerialEnv` class that enjoys the exact same API but is executed -serially. This is mostly useful for testing purposes, when one wants to assess the -behavior of a :class:`~.ParallelEnv` without launching the subprocesses. - -In addition to :class:`~.ParallelEnv`, which offers process-based parallelism, we also provide a way to create -multithreaded environments with :obj:`~.MultiThreadedEnv`. This class uses `EnvPool `_ -library underneath, which allows for higher performance, but at the same time restricts flexibility - one can only -create environments implemented in ``EnvPool``. This covers many popular RL environments types (Atari, Classic Control, -etc.), but one can not use an arbitrary TorchRL environment, as it is possible with :class:`~.ParallelEnv`. Run -`benchmarks/benchmark_batched_envs.py` to compare performance of different ways to parallelize batched environments. - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - SerialEnv - ParallelEnv - EnvCreator - -Async environments ------------------- - -Asynchronous environments allow for parallel execution of multiple environments, which can significantly speed up the -data collection process in reinforcement learning. - -The `AsyncEnvPool` class and its subclasses provide a flexible interface for managing these environments using different -backends, such as threading and multiprocessing. - -The `AsyncEnvPool` class serves as a base class for asynchronous environment pools, providing a common interface for -managing multiple environments concurrently. It supports different backends for parallel execution, such as threading -and multiprocessing, and provides methods for asynchronous stepping and resetting of environments. - -Contrary to :class:`~torchrl.envs.ParallelEnv`, :class:`~torchrl.envs.AsyncEnvPool` and its subclasses permit the -execution of a given set of sub-environments while another task performed, allowing for complex asynchronous jobs to be -run at the same time. For instance, it is possible to execute some environments while the policy is running based on -the output of others. - -This family of classes is particularly interesting when dealing with environments that have a high (and/or variable) -latency. - -.. note:: This class and its subclasses should work when nested in with :class:`~torchrl.envs.TransformedEnv` and - batched environments, but users won't currently be able to use the async features of the base environment when - it's nested in these classes. One should prefer nested transformed envs within an `AsyncEnvPool` instead. - If this is not possible, please raise an issue. - -Classes -~~~~~~~ - -- :class:`~torchrl.envs.AsyncEnvPool`: A base class for asynchronous environment pools. 
It determines the backend - implementation to use based on the provided arguments and manages the lifecycle of the environments. -- :class:`~torchrl.envs.ProcessorAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using - multiprocessing for parallel execution of environments. This class manages a pool of environments, each running in - its own process, and provides methods for asynchronous stepping and resetting of environments using inter-process - communication. It is automatically instantiated when `"multiprocessing"` is passed as a backend during the - :class:`~torchrl.envs.AsyncEnvPool` instantiation. -- :class:`~torchrl.envs.ThreadingAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using - threading for parallel execution of environments. This class manages a pool of environments, each running in its own - thread, and provides methods for asynchronous stepping and resetting of environments using a thread pool executor. - It is automatically instantiated when `"threading"` is passed as a backend during the - :class:`~torchrl.envs.AsyncEnvPool` instantiation. - -Example -~~~~~~~ - - >>> from functools import partial - >>> from torchrl.envs import AsyncEnvPool, GymEnv - >>> import torch - >>> # Choose backend - >>> backend = "threading" - >>> env = AsyncEnvPool( - >>> [partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "CartPole-v1")], - >>> stack="lazy", - >>> backend=backend - >>> ) - >>> # Execute a synchronous reset - >>> reset = env.reset() - >>> print(reset) - >>> # Execute a synchronous step - >>> s = env.rand_step(reset) - >>> print(s) - >>> # Execute an asynchronous step in env 0 - >>> s0 = s[0] - >>> s0["action"] = torch.randn(1).clamp(-1, 1) - >>> s0["env_index"] = 0 - >>> env.async_step_send(s0) - >>> # Receive data - >>> s0_result = env.async_step_recv() - >>> print('result', s0_result) - >>> # Close env - >>> env.close() - - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - AsyncEnvPool - ProcessorAsyncEnvPool - ThreadingAsyncEnvPool - - -Custom native TorchRL environments ----------------------------------- - -TorchRL offers a series of custom built-in environments. - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - ChessEnv - PendulumEnv - TicTacToeEnv - LLMHashingEnv - - -Multi-agent environments ------------------------- - -.. _MARL-environment-API: - -.. currentmodule:: torchrl.envs - -TorchRL supports multi-agent learning out-of-the-box. -*The same classes used in a single-agent learning pipeline can be seamlessly used in multi-agent contexts, -without any modification or dedicated multi-agent infrastructure.* - -In this view, environments play a core role for multi-agent. In multi-agent environments, -many decision-making agents act in a shared world. -Agents can observe different things, act in different ways and also be rewarded differently. -Therefore, many paradigms exist to model multi-agent environments (DecPODPs, Markov Games). -Some of the main differences between these paradigms include: - -- **observation** can be per-agent and also have some shared components -- **reward** can be per-agent or shared -- **done** (and ``"truncated"`` or ``"terminated"``) can be per-agent or shared. - -TorchRL accommodates all these possible paradigms thanks to its :class:`tensordict.TensorDict` data carrier. -In particular, in multi-agent environments, per-agent keys will be carried in a nested "agents" TensorDict. 
-This TensorDict will have the additional agent dimension and thus group data that is different for each agent. -The shared keys, on the other hand, will be kept in the first level, as in single-agent cases. - -Let's look at an example to understand this better. For this example we are going to use -`VMAS `_, a multi-robot task simulator also -based on PyTorch, which runs parallel batched simulation on device. - -We can create a VMAS environment and look at what the output from a random step looks like: - -.. code-block:: - :caption: Example of multi-agent step tensordict - - >>> from torchrl.envs.libs.vmas import VmasEnv - >>> env = VmasEnv("balance", num_envs=3, n_agents=5) - >>> td = env.rand_step() - >>> td - TensorDict( - fields={ - agents: TensorDict( - fields={ - action: Tensor(shape=torch.Size([3, 5, 2]))}, - batch_size=torch.Size([3, 5])), - next: TensorDict( - fields={ - agents: TensorDict( - fields={ - info: TensorDict( - fields={ - ground_rew: Tensor(shape=torch.Size([3, 5, 1])), - pos_rew: Tensor(shape=torch.Size([3, 5, 1]))}, - batch_size=torch.Size([3, 5])), - observation: Tensor(shape=torch.Size([3, 5, 16])), - reward: Tensor(shape=torch.Size([3, 5, 1]))}, - batch_size=torch.Size([3, 5])), - done: Tensor(shape=torch.Size([3, 1]))}, - batch_size=torch.Size([3]))}, - batch_size=torch.Size([3])) - -We can observe that *keys that are shared by all agents*, such as **done** are present in the root tensordict with -batch size `(num_envs,)`, which represents the number of environments simulated. - -On the other hand, *keys that are different between agents*, such as **action**, **reward**, **observation**, -and **info** are present in the nested "agents" tensordict with batch size `(num_envs, n_agents)`, -which represents the additional agent dimension. - -Multi-agent tensor specs will follow the same style as in tensordicts. -Specs relating to values that vary between agents will need to be nested in the "agents" entry. - -Here is an example of how specs can be created in a multi-agent environment where -only the done flag is shared across agents (as in VMAS): - -.. code-block:: - :caption: Example of multi-agent spec creation - - >>> action_specs = [] - >>> observation_specs = [] - >>> reward_specs = [] - >>> info_specs = [] - >>> for i in range(env.n_agents): - ... action_specs.append(agent_i_action_spec) - ... reward_specs.append(agent_i_reward_spec) - ... observation_specs.append(agent_i_observation_spec) - >>> env.action_spec = Composite( - ... { - ... "agents": Composite( - ... {"action": torch.stack(action_specs)}, shape=(env.n_agents,) - ... ) - ... } - ...) - >>> env.reward_spec = Composite( - ... { - ... "agents": Composite( - ... {"reward": torch.stack(reward_specs)}, shape=(env.n_agents,) - ... ) - ... } - ...) - >>> env.observation_spec = Composite( - ... { - ... "agents": Composite( - ... {"observation": torch.stack(observation_specs)}, shape=(env.n_agents,) - ... ) - ... } - ...) - >>> env.done_spec = Categorical( - ... n=2, - ... shape=torch.Size((1,)), - ... dtype=torch.bool, - ... ) - -As you can see, it is very simple! Per-agent keys will have the nested composite spec and shared keys will follow -single agent standards. - -.. note:: - Since reward, done and action keys may have the additional "agent" prefix (e.g., `("agents","action")`), - the default keys used in the arguments of other TorchRL components (e.g. "action") will not match exactly. 
- Therefore, TorchRL provides the `env.action_key`, `env.reward_key`, and `env.done_key` attributes, - which will automatically point to the right key to use. Make sure you pass these attributes to the various - components in TorchRL to inform them of the right key (e.g., the `loss.set_keys()` function). - -.. note:: - TorchRL abstracts these nested specs away for ease of use. - This means that accessing `env.reward_spec` will always return the leaf - spec if the accessed spec is Composite. Therefore, if in the example above - we run `env.reward_spec` after env creation, we would get the same output as `torch.stack(reward_specs)}`. - To get the full composite spec with the "agents" key, you can run - `env.output_spec["full_reward_spec"]`. The same is valid for action and done specs. - Note that `env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]`. - - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - MarlGroupMapType - check_marl_grouping - -Auto-resetting Envs -------------------- - -.. _autoresetting_envs: - -Auto-resetting environments are environments where calls to :meth:`~torchrl.envs.EnvBase.reset` are not expected when -the environment reaches a ``"done"`` state during a rollout, as the reset happens automatically. -Usually, in such cases the observations delivered with the done and reward (which effectively result from performing the -action in the environment) are actually the first observations of a new episode, and not the last observations of the -current episode. - -To handle these cases, torchrl provides a :class:`~torchrl.envs.AutoResetTransform` that will copy the observations -that result from the call to `step` to the next `reset` and skip the calls to `reset` during rollouts (in both -:meth:`~torchrl.envs.EnvBase.rollout` and :class:`~torchrl.collectors.SyncDataCollector` iterations). -This transform class also provides a fine-grained control over the behavior to be adopted for the invalid observations, -which can be masked with `"nan"` or any other values, or not masked at all. - -To tell torchrl that an environment is auto-resetting, it is sufficient to provide an ``auto_reset`` argument -during construction. If provided, an ``auto_reset_replace`` argument can also control whether the values of the last -observation of an episode should be replaced with some placeholder or not. - - >>> from torchrl.envs import GymEnv - >>> from torchrl.envs import set_gym_backend - >>> import torch - >>> torch.manual_seed(0) - >>> - >>> class AutoResettingGymEnv(GymEnv): - ... def _step(self, tensordict): - ... tensordict = super()._step(tensordict) - ... if tensordict["done"].any(): - ... td_reset = super().reset() - ... tensordict.update(td_reset.exclude(*self.done_keys)) - ... return tensordict - ... - ... def _reset(self, tensordict=None): - ... if tensordict is not None and "_reset" in tensordict: - ... return tensordict.copy() - ... return super()._reset(tensordict) - >>> - >>> with set_gym_backend("gym"): - ... env = AutoResettingGymEnv("CartPole-v1", auto_reset=True, auto_reset_replace=True) - ... env.set_seed(0) - ... 
r = env.rollout(30, break_when_any_done=False) - >>> print(r["next", "done"].squeeze()) - tensor([False, False, False, False, False, False, False, False, False, False, - False, False, False, True, False, False, False, False, False, False, - False, False, False, False, False, True, False, False, False, False]) - >>> print("observation after reset are set as nan", r["next", "observation"]) - observation after reset are set as nan tensor([[-4.3633e-02, -1.4877e-01, 1.2849e-02, 2.7584e-01], - [-4.6609e-02, 4.6166e-02, 1.8366e-02, -1.2761e-02], - [-4.5685e-02, 2.4102e-01, 1.8111e-02, -2.9959e-01], - [-4.0865e-02, 4.5644e-02, 1.2119e-02, -1.2542e-03], - [-3.9952e-02, 2.4059e-01, 1.2094e-02, -2.9009e-01], - [-3.5140e-02, 4.3554e-01, 6.2920e-03, -5.7893e-01], - [-2.6429e-02, 6.3057e-01, -5.2867e-03, -8.6963e-01], - [-1.3818e-02, 8.2576e-01, -2.2679e-02, -1.1640e+00], - [ 2.6972e-03, 1.0212e+00, -4.5959e-02, -1.4637e+00], - [ 2.3121e-02, 1.2168e+00, -7.5232e-02, -1.7704e+00], - [ 4.7457e-02, 1.4127e+00, -1.1064e-01, -2.0854e+00], - [ 7.5712e-02, 1.2189e+00, -1.5235e-01, -1.8289e+00], - [ 1.0009e-01, 1.0257e+00, -1.8893e-01, -1.5872e+00], - [ nan, nan, nan, nan], - [-3.9405e-02, -1.7766e-01, -1.0403e-02, 3.0626e-01], - [-4.2959e-02, -3.7263e-01, -4.2775e-03, 5.9564e-01], - [-5.0411e-02, -5.6769e-01, 7.6354e-03, 8.8698e-01], - [-6.1765e-02, -7.6292e-01, 2.5375e-02, 1.1820e+00], - [-7.7023e-02, -9.5836e-01, 4.9016e-02, 1.4826e+00], - [-9.6191e-02, -7.6387e-01, 7.8667e-02, 1.2056e+00], - [-1.1147e-01, -9.5991e-01, 1.0278e-01, 1.5219e+00], - [-1.3067e-01, -7.6617e-01, 1.3322e-01, 1.2629e+00], - [-1.4599e-01, -5.7298e-01, 1.5848e-01, 1.0148e+00], - [-1.5745e-01, -7.6982e-01, 1.7877e-01, 1.3527e+00], - [-1.7285e-01, -9.6668e-01, 2.0583e-01, 1.6956e+00], - [ nan, nan, nan, nan], - [-4.3962e-02, 1.9845e-01, -4.5015e-02, -2.5903e-01], - [-3.9993e-02, 3.9418e-01, -5.0196e-02, -5.6557e-01], - [-3.2109e-02, 5.8997e-01, -6.1507e-02, -8.7363e-01], - [-2.0310e-02, 3.9574e-01, -7.8980e-02, -6.0090e-01]]) - -Dynamic Specs +Quick Example ------------- -.. _dynamic_envs: - -Running environments in parallel is usually done via the creation of memory buffers used to pass information from one -process to another. In some cases, it may be impossible to forecast whether an environment will or will not have -consistent inputs or outputs during a rollout, as their shape may be variable. We refer to this as dynamic specs. - -TorchRL is capable of handling dynamic specs, but the batched environments and collectors will need to be made -aware of this feature. Note that, in practice, this is detected automatically. - -To indicate that a tensor will have a variable size along a dimension, one can set the size value as ``-1`` for the -desired dimensions. Because the data cannot be stacked contiguously, calls to ``env.rollout`` need to be made with -the ``return_contiguous=False`` argument. -Here is a working example: - - >>> from torchrl.envs import EnvBase - >>> from torchrl.data import Unbounded, Composite, Bounded, Binary - >>> import torch - >>> from tensordict import TensorDict, TensorDictBase - >>> - >>> class EnvWithDynamicSpec(EnvBase): - ... def __init__(self, max_count=5): - ... super().__init__(batch_size=()) - ... self.observation_spec = Composite( - ... observation=Unbounded(shape=(3, -1, 2)), - ... ) - ... self.action_spec = Bounded(low=-1, high=1, shape=(2,)) - ... self.full_done_spec = Composite( - ... done=Binary(1, shape=(1,), dtype=torch.bool), - ... terminated=Binary(1, shape=(1,), dtype=torch.bool), - ... 
truncated=Binary(1, shape=(1,), dtype=torch.bool), - ... ) - ... self.reward_spec = Unbounded((1,), dtype=torch.float) - ... self.count = 0 - ... self.max_count = max_count - ... - ... def _reset(self, tensordict=None): - ... self.count = 0 - ... data = TensorDict( - ... { - ... "observation": torch.full( - ... (3, self.count + 1, 2), - ... self.count, - ... dtype=self.observation_spec["observation"].dtype, - ... ) - ... } - ... ) - ... data.update(self.done_spec.zero()) - ... return data - ... - ... def _step( - ... self, - ... tensordict: TensorDictBase, - ... ) -> TensorDictBase: - ... self.count += 1 - ... done = self.count >= self.max_count - ... observation = TensorDict( - ... { - ... "observation": torch.full( - ... (3, self.count + 1, 2), - ... self.count, - ... dtype=self.observation_spec["observation"].dtype, - ... ) - ... } - ... ) - ... done = self.full_done_spec.zero() | done - ... reward = self.full_reward_spec.zero() - ... return observation.update(done).update(reward) - ... - ... def _set_seed(self, seed: Optional[int]) -> None: - ... self.manual_seed = seed - ... return seed - >>> env = EnvWithDynamicSpec() - >>> print(env.rollout(5, return_contiguous=False)) - LazyStackedTensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.float32, is_shared=False), - done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: LazyStackedTensorDict( - fields={ - done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), - terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - exclusive_fields={ - }, - batch_size=torch.Size([5]), - device=None, - is_shared=False, - stack_dim=0), - observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False), - terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - exclusive_fields={ - }, - batch_size=torch.Size([5]), - device=None, - is_shared=False, - stack_dim=0) - -.. warning:: - The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in - data collectors can impact performance of these classes dramatically. Any - such usage should be carefully benchmarked against a plain execution on a - single process, as serializing and deserializing large numbers of tensors - can be very expensive. - -Currently, :func:`~torchrl.envs.utils.check_env_specs` will pass for dynamic specs where a shape varies along some -dimensions, but not when a key is present during a step and absent during others, or when the number of dimensions -varies. - -Transforms ----------- - -.. _transforms: - -.. currentmodule:: torchrl.envs.transforms - -In most cases, the raw output of an environment must be treated before being passed to another object (such as a -policy or a value operator). To do this, TorchRL provides a set of transforms that aim at reproducing the transform -logic of `torch.distributions.Transform` and `torchvision.transforms`. -Our environment :ref:`tutorial ` -provides more information on how to design a custom transform. 
- -Transformed environments are build using the :class:`TransformedEnv` primitive. -Composed transforms are built using the :class:`Compose` class: - -.. code-block:: - :caption: Transformed environment - - >>> base_env = GymEnv("Pendulum-v1", from_pixels=True, device="cuda:0") - >>> transform = Compose(ToTensorImage(in_keys=["pixels"]), Resize(64, 64, in_keys=["pixels"])) - >>> env = TransformedEnv(base_env, transform) - -Transforms are usually subclasses of :class:`~torchrl.envs.transforms.Transform`, although any -``Callable[[TensorDictBase], TensorDictBase]``. - -By default, the transformed environment will inherit the device of the -:obj:`base_env` that is passed to it. The transforms will then be executed on that device. -It is now apparent that this can bring a significant speedup depending on the kind of -operations that is to be computed. - -A great advantage of environment wrappers is that one can consult the environment up to that wrapper. -The same can be achieved with TorchRL transformed environments: the ``parent`` attribute will -return a new :class:`TransformedEnv` with all the transforms up to the transform of interest. -Re-using the example above: - -.. code-block:: - :caption: Transform parent - - >>> resize_parent = env.transform[-1].parent # returns the same as TransformedEnv(base_env, transform[:-1]) - - -Transformed environment can be used with vectorized environments. -Since each transform uses a ``"in_keys"``/``"out_keys"`` set of keyword argument, it is -also easy to root the transform graph to each component of the observation data (e.g. -pixels or states etc). - -Forward and inverse transforms -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called before the action is applied in reverse -order over the composed transform chain. This allows applying transforms to data in the environment before the action is -taken in the environment. The keys to be included in this inverse transform are passed through the `"in_keys_inv"` -keyword argument, and the out-keys default to these values in most cases: - -.. code-block:: - :caption: Inverse transform - - >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step - -The following paragraphs detail how one can think about what is to be considered `in_` or `out_` features. - -Understanding Transform Keys -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In transforms, `in_keys` and `out_keys` define the interaction between the base environment and the outside world -(e.g., your policy): - -- `in_keys` refers to the base environment's perspective (inner = `base_env` of the - :class:`~torchrl.envs.TransformedEnv`). -- `out_keys` refers to the outside world (outer = `policy`, `agent`, etc.). - -For example, with `in_keys=["obs"]` and `out_keys=["obs_standardized"]`, the policy will "see" a standardized -observation, while the base environment outputs a regular observation. - -Similarly, for inverse keys: - -- `in_keys_inv` refers to entries as seen by the base environment. -- `out_keys_inv` refers to entries as seen or produced by the policy. - -The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input -`TensorDict` of the `step` function must include the `out_keys_inv` as they are part of the outside world. The -transform changes these names to match the names of the inner, base environment using the `in_keys_inv`. 
-The inverse process is executed with the output tensordict, where the `in_keys` are mapped to the corresponding -`out_keys`. - -.. figure:: /_static/img/rename_transform.png - - Rename transform logic - -.. note:: During a call to `inv`, the transforms are executed in reversed order (compared to the forward / step mode). - -Transforming Tensors and Specs -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When transforming actual tensors (coming from the policy), the process is schematically represented as: - - >>> for t in reversed(self.transform): - ... td = t.inv(td) - -This starts with the outermost transform to the innermost transform, ensuring the action value exposed to the policy -is properly transformed. - -For transforming the action spec, the process should go from innermost to outermost (similar to observation specs): - - >>> def transform_action_spec(self, action_spec): - ... for t in self.transform: - ... action_spec = t.transform_action_spec(action_spec) - ... return action_spec - -A pseudocode for a single transform_action_spec could be: - - >>> def transform_action_spec(self, action_spec): - ... return spec_from_random_values(self._apply_transform(action_spec.rand())) - -This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we did not call -`_inv_apply_transform` but `_apply_transform` on purpose! - -Exposing Specs to the Outside World -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -`TransformedEnv` will expose the specs corresponding to the `out_keys_inv` for actions and states. -For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., `"action"`) is a float-valued -tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed -environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the -transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`. - -Designing your own Transform -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To create a basic, custom transform, you need to subclass the `Transform` class and implement the -:meth:`~torchrl.envs._apply_transform` method. Here's an example of a simple transform that adds 1 to the observation -tensor: - - >>> class AddOneToObs(Transform): - ... """A transform that adds 1 to the observation tensor.""" - ... - ... def __init__(self): - ... super().__init__(in_keys=["observation"], out_keys=["observation"]) - ... - ... def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: - ... return obs + 1 - - -Tips for subclassing `Transform` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -There are various ways of subclassing a transform. The things to take into considerations are: - -- Is the transform identical for each tensor / item being transformed? Use - :meth:`~torchrl.envs.Transform._apply_transform` and :meth:`~torchrl.envs.Transform._inv_apply_transform`. -- The transform needs access to the input data to env.step as well as output? Rewrite - :meth:`~torchrl.envs.Transform._step`. - Otherwise, rewrite :meth:`~torchrl.envs.Transform._call` (or :meth:`~torchrl.envs.Transform._inv_call`). -- Is the transform to be used within a replay buffer? Overwrite :meth:`~torchrl.envs.Transform.forward`, - :meth:`~torchrl.envs.Transform.inv`, :meth:`~torchrl.envs.Transform._apply_transform` or - :meth:`~torchrl.envs.Transform._inv_apply_transform`. 
-- Within a transform, you can access (and make calls to) the parent environment using - :attr:`~torchrl.envs.Transform.parent` (the base env + all transforms till this one) or - :meth:`~torchrl.envs.Transform.container` (The object that encapsulates the transform). -- Don't forget to edits the specs if needed: top level: :meth:`~torchrl.envs.Transform.transform_output_spec`, - :meth:`~torchrl.envs.Transform.transform_input_spec`. - Leaf level: :meth:`~torchrl.envs.Transform.transform_observation_spec`, - :meth:`~torchrl.envs.Transform.transform_action_spec`, :meth:`~torchrl.envs.Transform.transform_state_spec`, - :meth:`~torchrl.envs.Transform.transform_reward_spec` and - :meth:`~torchrl.envs.Transform.transform_reward_spec`. - -For practical examples, see the methods listed above. - -You can use a transform in an environment by passing it to the TransformedEnv constructor: - - >>> env = TransformedEnv(GymEnv("Pendulum-v1"), AddOneToObs()) - -You can compose multiple transforms together using the Compose class: - - >>> transform = Compose(AddOneToObs(), RewardSum()) - >>> env = TransformedEnv(GymEnv("Pendulum-v1"), transform) - -Inverse Transforms -^^^^^^^^^^^^^^^^^^ - -Some transforms have an inverse transform that can be used to undo the transformation. For example, the AddOneToAction -transform has an inverse transform that subtracts 1 from the action tensor: - - >>> class AddOneToAction(Transform): - ... """A transform that adds 1 to the action tensor.""" - ... def __init__(self): - ... super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"]) - ... def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor: - ... return action + 1 - -Using a Transform with a Replay Buffer -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You can use a transform with a replay buffer by passing it to the ReplayBuffer constructor: - -Cloning transforms -~~~~~~~~~~~~~~~~~~ - -Because transforms appended to an environment are "registered" to this environment -through the ``transform.parent`` property, when manipulating transforms we should keep -in mind that the parent may come and go following what is being done with the transform. -Here are some examples: if we get a single transform from a :class:`Compose` object, -this transform will keep its parent: - - >>> third_transform = env.transform[2] - >>> assert third_transform.parent is not None - -This means that using this transform for another environment is prohibited, as -the other environment would replace the parent and this may lead to unexpected -behviours. Fortunately, the :class:`Transform` class comes with a :func:`clone` -method that will erase the parent while keeping the identity of all the -registered buffers: - - >>> TransformedEnv(base_env, third_transform) # raises an Exception as third_transform already has a parent - >>> TransformedEnv(base_env, third_transform.clone()) # works - -On a single process or if the buffers are placed in shared memory, this will -result in all the clone transforms to keep the same behavior even if the -buffers are changed in place (which is what will happen with the :class:`CatFrames` -transform, for instance). In distributed settings, this may not hold and one -should be careful about the expected behavior of the cloned transforms in this -context. 
-Finally, notice that indexing multiple transforms from a :class:`Compose` transform -may also result in loss of parenthood for these transforms: the reason is that -indexing a :class:`Compose` transform results in another :class:`Compose` transform -that does not have a parent environment. Hence, we have to clone the sub-transforms -to be able to create this other composition: - - >>> env = TransformedEnv(base_env, Compose(transform1, transform2, transform3)) - >>> last_two = env.transform[-2:] - >>> assert isinstance(last_two, Compose) - >>> assert last_two.parent is None - >>> assert last_two[0] is not transform2 - >>> assert isinstance(last_two[0], type(transform2)) # and the buffers will match - >>> assert last_two[1] is not transform3 - >>> assert isinstance(last_two[1], type(transform3)) # and the buffers will match - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - Transform - TransformedEnv - ActionDiscretizer - ActionMask - AutoResetEnv - AutoResetTransform - BatchSizeTransform - BinarizeReward - BurnInTransform - CatFrames - CatTensors - CenterCrop - ClipTransform - Compose - ConditionalPolicySwitch - ConditionalSkip - Crop - DTypeCastTransform - DeviceCastTransform - DiscreteActionProjection - DoubleToFloat - EndOfLifeTransform - ExcludeTransform - FiniteTensorDictCheck - FlattenObservation - FrameSkipTransform - GrayScale - Hash - InitTracker - KLRewardTransform - LineariseRewards - ModuleTransform - MultiAction - NoopResetEnv - ObservationNorm - ObservationTransform - PermuteTransform - PinMemoryTransform - R3MTransform - RandomCropTensorDict - RemoveEmptySpecs - RenameTransform - Resize - Reward2GoTransform - RewardClipping - RewardScaling - RewardSum - SelectTransform - SignTransform - SqueezeTransform - Stack - StepCounter - TargetReturn - TensorDictPrimer - TimeMaxPool - Timer - Tokenizer - ToTensorImage - TrajCounter - UnaryTransform - UnsqueezeTransform - VC1Transform - VIPRewardTransform - VIPTransform - VecGymEnvTransform - VecNorm - VecNormV2 - gSDENoise - -Environments with masked actions --------------------------------- - -In some environments with discrete actions, the actions available to the agent might change throughout execution. -In such cases the environments will output an action mask (under the ``"action_mask"`` key by default). -This mask needs to be used to filter out unavailable actions for that step. - -If you are using a custom policy you can pass this mask to your probability distribution like so: - -.. code-block:: - :caption: Categorical policy with action mask - - >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential - >>> import torch.nn as nn - >>> from torchrl.modules import MaskedCategorical - >>> module = TensorDictModule( - >>> nn.Linear(in_feats, out_feats), - >>> in_keys=["observation"], - >>> out_keys=["logits"], - >>> ) - >>> dist = ProbabilisticTensorDictModule( - >>> in_keys={"logits": "logits", "mask": "action_mask"}, - >>> out_keys=["action"], - >>> distribution_class=MaskedCategorical, - >>> ) - >>> actor = TensorDictSequential(module, dist) - -If you want to use a default policy, you will need to wrap your environment in the :class:`~torchrl.envs.transforms.ActionMask` -transform. This transform can take care of updating the action mask in the action spec in order for the default policy -to always know what the latest available actions are. You can do this like so: - -.. 
code-block:: - :caption: How to use the action mask transform - - >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential - >>> import torch.nn as nn - >>> from torchrl.envs.transforms import TransformedEnv, ActionMask - >>> env = TransformedEnv( - >>> your_base_env - >>> ActionMask(action_key="action", mask_key="action_mask"), - >>> ) - -.. note:: - In case you are using a parallel environment it is important to add the transform to the parallel environment itself - and not to its sub-environments. - - - -Recorders ---------- - -.. _Environment-Recorders: - -Recording data during environment rollout execution is crucial to keep an eye on the algorithm performance as well as -reporting results after training. - -TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable -can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function will be called upon the collected -tensordict at each iteration of the rollout (if some iterations have to be skipped, an internal variable should be added -to keep track of the call count within ``callback``). - -To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used. - -Recording videos -~~~~~~~~~~~~~~~~ - -Several backends offer the possibility of recording rendered images from the environment. -If the pixels are already part of the environment output (e.g. Atari or other game simulators), a -:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input -a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger` -or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved. -For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a `video_format="mp4"` -argument. - -The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch -formatted images (WHC or CWH). - - >>> logger = CSVLogger("dummy-exp", video_format="mp4") - >>> env = GymEnv("ALE/Pong-v5") - >>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"])) - >>> env.rollout(10) - >>> env.transform.dump() # Save the video and clear cache - -Note that the cache of the transform will keep on growing until dump is called. It is the user responsibility to -take care of calling `dump` when needed to avoid OOM issues. - -In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible -(some libraries only allow one environment instance per workspace). -In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform` -can be used to call `render` on the parent environment and save the images in the rollout data stream. 
-This class works over single and batched environments alike: - - >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator - >>> from torchrl.record.loggers import CSVLogger - >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder - >>> - >>> def make_env(): - >>> env = GymEnv("CartPole-v1", render_mode="rgb_array") - >>> # Uncomment this line to execute per-env - >>> # env = env.append_transform(PixelRenderTransform()) - >>> return env - >>> - >>> if __name__ == "__main__": - ... logger = CSVLogger("dummy", video_format="mp4") - ... - ... env = ParallelEnv(16, EnvCreator(make_env)) - ... env.start() - ... # Comment this line to execute per-env - ... env = env.append_transform(PixelRenderTransform()) - ... - ... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) - ... env.rollout(3) - ... - ... check_env_specs(env) - ... - ... r = env.rollout(30) - ... env.transform.dump() - ... env.close() - - -.. currentmodule:: torchrl.record - -Recorders are transforms that register data as they come in, for logging purposes. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - TensorDictRecorder - VideoRecorder - PixelRenderTransform - - -Helpers -------- -.. currentmodule:: torchrl.envs - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - RandomPolicy - check_env_specs - exploration_type - get_available_libraries - make_composite_from_td - set_exploration_type - step_mdp - terminated_or_truncated - -Domain-specific ---------------- -.. currentmodule:: torchrl.envs - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - ModelBasedEnvBase - model_based.dreamer.DreamerEnv - model_based.dreamer.DreamerDecoder - - -Libraries ---------- - -.. currentmodule:: torchrl.envs - -TorchRL's mission is to make the training of control and decision algorithm as -easy as it gets, irrespective of the simulator being used (if any). -Multiple wrappers are available for DMControl, Habitat, Jumanji and, naturally, -for Gym. - -This last library has a special status in the RL community as being the mostly -used framework for coding simulators. Its successful API has been foundational -and inspired many other frameworks, among which TorchRL. -However, Gym has gone through multiple design changes and it is sometimes hard -to accommodate these as an external adoption library: users usually have their -"preferred" version of the library. Moreover, gym is now being maintained -by another group under the "gymnasium" name, which does not facilitate code -compatibility. In practice, we must consider that users may have a version of -gym *and* gymnasium installed in the same virtual environment, and we must -allow both to work concomittantly. -Fortunately, TorchRL provides a solution for this problem: a special decorator -:class:`~.gym.set_gym_backend` allows to control which library will be used -in the relevant functions: - - >>> from torchrl.envs.libs.gym import GymEnv, set_gym_backend, gym_backend - >>> import gymnasium, gym - >>> with set_gym_backend(gymnasium): - ... print(gym_backend()) - ... env1 = GymEnv("Pendulum-v1") - - >>> with set_gym_backend(gym): - ... print(gym_backend()) - ... 
env2 = GymEnv("Pendulum-v1") - - >>> print(env1._env.env.env) - - >>> print(env2._env.env.env) - - -We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()` -which can be further used to indicate which library needs to be used for -the current computation. :class:`~.gym.set_gym_backend` is also a decorator: -we can use it to tell to a specific function what gym backend needs to be used -during its execution. -The :func:`torchrl.envs.libs.gym.gym_backend` function allows you to gather -the current gym backend or any of its modules: - - >>> import mo_gymnasium - >>> with set_gym_backend("gym"): - ... wrappers = gym_backend('wrappers') - ... print(wrappers) - - >>> with set_gym_backend("gymnasium"): - ... wrappers = gym_backend('wrappers') - ... print(wrappers) - - -Another tool that comes in handy with gym and other external dependencies is -the :class:`torchrl._utils.implement_for` class. Decorating a function -with ``@implement_for`` will tell torchrl that, depending on the version -indicated, a specific behavior is to be expected. This allows us to easily -support multiple versions of gym without requiring any effort from the user side. -For example, considering that our virtual environment has the v0.26.2 installed, -the following function will return ``1`` when queried: - - >>> from torchrl._utils import implement_for - >>> @implement_for("gym", None, "0.26.0") - ... def fun(): - ... return 0 - >>> @implement_for("gym", "0.26.0", None) - ... def fun(): - ... return 1 - >>> fun() - 1 - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - BraxEnv - BraxWrapper - DMControlEnv - DMControlWrapper - GymEnv - GymWrapper - HabitatEnv - IsaacGymEnv - IsaacGymWrapper - IsaacLabWrapper - JumanjiEnv - JumanjiWrapper - MeltingpotEnv - MeltingpotWrapper - MOGymEnv - MOGymWrapper - MultiThreadedEnv - MultiThreadedEnvWrapper - OpenMLEnv - OpenSpielWrapper - OpenSpielEnv - PettingZooEnv - PettingZooWrapper - RoboHiveEnv - SMACv2Env - SMACv2Wrapper - UnityMLAgentsEnv - UnityMLAgentsWrapper - VmasEnv - VmasWrapper - gym_backend - set_gym_backend - register_gym_spec_conversion +.. code-block:: python + + from torchrl.envs import GymEnv, ParallelEnv, TransformedEnv + from torchrl.envs.transforms import RewardSum, StepCounter + + # Create a single environment + env = GymEnv("Pendulum-v1") + + # Add transforms + env = TransformedEnv(env, RewardSum()) + + # Create parallel environments + def make_env(): + return TransformedEnv( + GymEnv("Pendulum-v1"), + StepCounter(max_steps=200) + ) + + parallel_env = ParallelEnv(4, make_env) + + # Run a rollout + rollout = parallel_env.rollout(100) + +Documentation Sections +---------------------- + +.. toctree:: + :maxdepth: 2 + + envs_api + envs_vectorized + envs_transforms + envs_multiagent + envs_libraries + envs_recorders diff --git a/docs/source/reference/envs_api.rst b/docs/source/reference/envs_api.rst new file mode 100644 index 00000000000..03cb747ece0 --- /dev/null +++ b/docs/source/reference/envs_api.rst @@ -0,0 +1,208 @@ +.. currentmodule:: torchrl.envs + +.. _Environment-API: + +Environment API +=============== + +TorchRL offers an API to handle environments of different backends, such as gym, +dm-control, dm-lab, model-based environments as well as custom environments. +The goal is to be able to swap environments in an experiment with little or no effort, +even if these environments are simulated using different libraries. 
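+
+For instance (a minimal sketch, assuming the ``gymnasium`` and ``dm_control`` backends are installed), a
+Gym-based environment and a DeepMind Control one can be swapped without changing the surrounding code, since
+both wrappers read and return :class:`tensordict.TensorDict` instances through the same methods:
+
+    >>> from torchrl.envs import GymEnv, DMControlEnv
+    >>> env = GymEnv("Pendulum-v1")               # gym / gymnasium backend
+    >>> # env = DMControlEnv("cheetah", "run")    # dm-control backend, same API
+    >>> reset_td = env.reset()                    # a TensorDict holding the initial observations
+    >>> rollout_td = env.rollout(max_steps=10)    # a TensorDict with a trailing "time" dimension
+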
+TorchRL offers some out-of-the-box environment wrappers under :mod:`torchrl.envs.libs`,
+which we hope can be easily imitated for other libraries.
+The parent class :class:`~torchrl.envs.EnvBase` is a :class:`torch.nn.Module` subclass that implements
+some typical environment methods using :class:`tensordict.TensorDict` as a data organiser. This allows the
+class to be generic and to handle an arbitrary number of inputs and outputs, as well as
+nested or batched data structures.
+
+Each env will have the following attributes:
+
+- :attr:`env.batch_size`: a :class:`torch.Size` representing the number of envs
+  batched together.
+- :attr:`env.device`: the device where the input and output tensordicts are expected to live.
+  The environment device does not mean that the actual step operations will be computed on device
+  (this is the responsibility of the backend, with which TorchRL can do little). The device of
+  an environment just represents the device where the data is to be expected when input to the
+  environment or retrieved from it. TorchRL takes care of mapping the data to the desired device.
+  This is especially useful for transforms (see below). For parametric environments (e.g.
+  model-based environments), the device does represent the hardware that will be used to
+  compute the operations.
+- :attr:`env.observation_spec`: a :class:`~torchrl.data.Composite` object
+  containing all the observation key-spec pairs.
+- :attr:`env.state_spec`: a :class:`~torchrl.data.Composite` object
+  containing all the input key-spec pairs (except action). For most stateful
+  environments, this container will be empty.
+- :attr:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object
+  representing the action spec.
+- :attr:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing
+  the reward spec.
+- :attr:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
+  the done-flag spec. See the section on trajectory termination below.
+- :attr:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing
+  all the input keys (``"full_action_spec"`` and ``"full_state_spec"``).
+- :attr:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing
+  all the output keys (``"full_observation_spec"``, ``"full_reward_spec"`` and ``"full_done_spec"``).
+
+If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensor`
+instance can be used.
+
+Env specs: locks and batch size
+-------------------------------
+
+.. _Environment-lock:
+
+Environment specs are locked by default (through a ``spec_locked`` arg passed to the env constructor).
+Locking specs means that any modification of the spec (or of its children if it is a :class:`~torchrl.data.Composite`
+instance) requires unlocking it first. This can be done via :meth:`~torchrl.envs.EnvBase.set_spec_lock_`.
+The reason specs are locked by default is that it makes it easy to cache values such as action or reset keys and the
+like.
+Unlocking an env should only be done if it is expected that the specs will be modified often (which, in principle,
+should be avoided).
+Modifications of the specs such as `env.observation_spec = new_spec` are allowed: under the hood, TorchRL will erase
+the cache, unlock the specs, make the modification and relock the specs if the env was previously locked.
+
+Importantly, the environment spec shapes should contain the batch size, e.g.
+an environment with ``env.batch_size == torch.Size([4])`` should have
+an :attr:`env.action_spec` of shape ``torch.Size([4, action_size])``.
+This is helpful when preallocating tensors, checking shape consistency, etc.
+
+Env methods
+-----------
+
+With these, the following methods are implemented:
+
+- :meth:`env.reset`: a reset method that may (but does not necessarily have to) take
+  a :class:`tensordict.TensorDict` input. It returns the first tensordict of a rollout, usually
+  containing a ``"done"`` state and a set of observations. If not present,
+  a ``"reward"`` key will be instantiated with 0s and the appropriate shape.
+- :meth:`env.step`: a step method that takes a :class:`tensordict.TensorDict` input
+  containing an input action as well as other inputs (for model-based or stateless
+  environments, for instance).
+- :meth:`env.step_and_maybe_reset`: executes a step, and (partially) resets the
+  environments if it needs to. It returns the updated input with a ``"next"``
+  key containing the data of the next step, as well as a tensordict containing
+  the input data for the next step (i.e., the result of a reset or of
+  :func:`~torchrl.envs.utils.step_mdp`).
+  This is done by reading the ``done_keys`` and
+  assigning a ``"_reset"`` signal to each done state. This method makes it easy
+  to code non-stopping rollout functions:
+
+    >>> data_ = env.reset()
+    >>> result = []
+    >>> for i in range(N):
+    ...     data, data_ = env.step_and_maybe_reset(data_)
+    ...     result.append(data)
+    ...
+    >>> result = torch.stack(result)
+
+- :meth:`env.set_seed`: a seeding method that will return the next seed
+  to be used in a multi-env setting. This next seed is deterministically computed
+  from the preceding one, such that one can seed multiple environments with different
+  seeds without the risk of overlapping seeds in consecutive experiments, while still
+  having reproducible results.
+- :meth:`env.rollout`: executes a rollout in the environment for
+  a maximum number of steps (``max_steps=N``) and using a policy (``policy=model``).
+  The policy should be coded using a :class:`tensordict.nn.TensorDictModule`
+  (or any other :class:`tensordict.TensorDict`-compatible module).
+  The resulting :class:`tensordict.TensorDict` instance will be marked with
+  a trailing ``"time"`` named dimension that can be used by other modules
+  to treat this batched dimension appropriately.
+
+The following figure summarizes how a rollout is executed in TorchRL.
+
+.. figure:: /_static/img/rollout.gif
+
+  TorchRL rollouts using TensorDict.
+
+In brief, a TensorDict is created by the :meth:`~.EnvBase.reset` method,
+then populated with an action by the policy before being passed to the
+:meth:`~.EnvBase.step` method which writes the observations, done flag(s) and
+reward under the ``"next"`` entry. The result of this call is stored for
+delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp`
+function.
+
+.. note::
+  In general, all TorchRL environments have a ``"done"`` and a ``"terminated"``
+  entry in their output tensordict. If they are not present by design,
+  the :class:`~.EnvBase` metaclass will ensure that every done or terminated
+  entry is flanked with its dual.
+  In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory
+  signals and should be interpreted as "the last step of a trajectory" or
+  equivalently "a signal indicating the need to reset".
+ If the environment provides it (eg, Gymnasium), the truncation entry is also + written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry. + If the environment carries a single value, it will interpreted as a ``"terminated"`` + signal by default. + By default, TorchRL's collectors and rollout methods will be looking for the ``"done"`` + entry to assess if the environment should be reset. + +.. note:: + + The `torchrl.collectors.utils.split_trajectories` function can be used to + slice adjacent trajectories. It relies on a ``"traj_ids"`` entry in the + input tensordict, or to the junction of ``"done"`` and ``"truncated"`` key + if the ``"traj_ids"`` is missing. + + +.. note:: + + In some contexts, it can be useful to mark the first step of a trajectory. + TorchRL provides such functionality through the :class:`~torchrl.envs.InitTracker` + transform. + + +Our environment :ref:`tutorial ` +provides more information on how to design a custom environment from scratch. + +Base classes +------------ + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + EnvBase + GymLikeEnv + EnvMetaData + +Custom native TorchRL environments +---------------------------------- + +TorchRL offers a series of custom built-in environments. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + ChessEnv + PendulumEnv + TicTacToeEnv + LLMHashingEnv + +Domain-specific +--------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + ModelBasedEnvBase + model_based.dreamer.DreamerEnv + model_based.dreamer.DreamerDecoder + +Helpers +------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + RandomPolicy + check_env_specs + exploration_type + get_available_libraries + make_composite_from_td + set_exploration_type + step_mdp + terminated_or_truncated diff --git a/docs/source/reference/envs_libraries.rst b/docs/source/reference/envs_libraries.rst new file mode 100644 index 00000000000..b4b1b78b547 --- /dev/null +++ b/docs/source/reference/envs_libraries.rst @@ -0,0 +1,277 @@ +.. currentmodule:: torchrl.envs + +Library Wrappers +================ + +TorchRL's mission is to make the training of control and decision algorithm as +easy as it gets, irrespective of the simulator being used (if any). +Multiple wrappers are available for DMControl, Habitat, Jumanji and, naturally, +for Gym. + +This last library has a special status in the RL community as being the mostly +used framework for coding simulators. Its successful API has been foundational +and inspired many other frameworks, among which TorchRL. +However, Gym has gone through multiple design changes and it is sometimes hard +to accommodate these as an external adoption library: users usually have their +"preferred" version of the library. Moreover, gym is now being maintained +by another group under the "gymnasium" name, which does not facilitate code +compatibility. In practice, we must consider that users may have a version of +gym *and* gymnasium installed in the same virtual environment, and we must +allow both to work concomittantly. +Fortunately, TorchRL provides a solution for this problem: a special decorator +:class:`~.gym.set_gym_backend` allows to control which library will be used +in the relevant functions: + + >>> from torchrl.envs.libs.gym import GymEnv, set_gym_backend, gym_backend + >>> import gymnasium, gym + >>> with set_gym_backend(gymnasium): + ... print(gym_backend()) + ... 
env1 = GymEnv("Pendulum-v1") + + >>> with set_gym_backend(gym): + ... print(gym_backend()) + ... env2 = GymEnv("Pendulum-v1") + + >>> print(env1._env.env.env) + + >>> print(env2._env.env.env) + + +We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()` +which can be further used to indicate which library needs to be used for +the current computation. :class:`~.gym.set_gym_backend` is also a decorator: +we can use it to tell to a specific function what gym backend needs to be used +during its execution. +The :func:`torchrl.envs.libs.gym.gym_backend` function allows you to gather +the current gym backend or any of its modules: + + >>> import mo_gymnasium + >>> with set_gym_backend("gym"): + ... wrappers = gym_backend('wrappers') + ... print(wrappers) + + >>> with set_gym_backend("gymnasium"): + ... wrappers = gym_backend('wrappers') + ... print(wrappers) + + +Another tool that comes in handy with gym and other external dependencies is +the :class:`torchrl._utils.implement_for` class. Decorating a function +with ``@implement_for`` will tell torchrl that, depending on the version +indicated, a specific behavior is to be expected. This allows us to easily +support multiple versions of gym without requiring any effort from the user side. +For example, considering that our virtual environment has the v0.26.2 installed, +the following function will return ``1`` when queried: + + >>> from torchrl._utils import implement_for + >>> @implement_for("gym", None, "0.26.0") + ... def fun(): + ... return 0 + >>> @implement_for("gym", "0.26.0", None) + ... def fun(): + ... return 1 + >>> fun() + 1 + +Available wrappers +------------------ + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + BraxEnv + BraxWrapper + DMControlEnv + DMControlWrapper + GymEnv + GymWrapper + HabitatEnv + IsaacGymEnv + IsaacGymWrapper + IsaacLabWrapper + JumanjiEnv + JumanjiWrapper + MeltingpotEnv + MeltingpotWrapper + MOGymEnv + MOGymWrapper + MultiThreadedEnv + MultiThreadedEnvWrapper + OpenMLEnv + OpenSpielWrapper + OpenSpielEnv + PettingZooEnv + PettingZooWrapper + RoboHiveEnv + SMACv2Env + SMACv2Wrapper + UnityMLAgentsEnv + UnityMLAgentsWrapper + VmasEnv + VmasWrapper + gym_backend + set_gym_backend + register_gym_spec_conversion + +Auto-resetting Environments +--------------------------- + +.. _autoresetting_envs: + +Auto-resetting environments are environments where calls to :meth:`~torchrl.envs.EnvBase.reset` are not expected when +the environment reaches a ``"done"`` state during a rollout, as the reset happens automatically. +Usually, in such cases the observations delivered with the done and reward (which effectively result from performing the +action in the environment) are actually the first observations of a new episode, and not the last observations of the +current episode. + +To handle these cases, torchrl provides a :class:`~torchrl.envs.AutoResetTransform` that will copy the observations +that result from the call to `step` to the next `reset` and skip the calls to `reset` during rollouts (in both +:meth:`~torchrl.envs.EnvBase.rollout` and :class:`~torchrl.collectors.SyncDataCollector` iterations). +This transform class also provides a fine-grained control over the behavior to be adopted for the invalid observations, +which can be masked with `"nan"` or any other values, or not masked at all. + +To tell torchrl that an environment is auto-resetting, it is sufficient to provide an ``auto_reset`` argument +during construction. 
If provided, an ``auto_reset_replace`` argument can also control whether the values of the last +observation of an episode should be replaced with some placeholder or not. + + >>> from torchrl.envs import GymEnv + >>> from torchrl.envs import set_gym_backend + >>> import torch + >>> torch.manual_seed(0) + >>> + >>> class AutoResettingGymEnv(GymEnv): + ... def _step(self, tensordict): + ... tensordict = super()._step(tensordict) + ... if tensordict["done"].any(): + ... td_reset = super().reset() + ... tensordict.update(td_reset.exclude(*self.done_keys)) + ... return tensordict + ... + ... def _reset(self, tensordict=None): + ... if tensordict is not None and "_reset" in tensordict: + ... return tensordict.copy() + ... return super()._reset(tensordict) + >>> + >>> with set_gym_backend("gym"): + ... env = AutoResettingGymEnv("CartPole-v1", auto_reset=True, auto_reset_replace=True) + ... env.set_seed(0) + ... r = env.rollout(30, break_when_any_done=False) + >>> print(r["next", "done"].squeeze()) + tensor([False, False, False, False, False, False, False, False, False, False, + False, False, False, True, False, False, False, False, False, False, + False, False, False, False, False, True, False, False, False, False]) + +Dynamic Specs +------------- + +.. _dynamic_envs: + +Running environments in parallel is usually done via the creation of memory buffers used to pass information from one +process to another. In some cases, it may be impossible to forecast whether an environment will or will not have +consistent inputs or outputs during a rollout, as their shape may be variable. We refer to this as dynamic specs. + +TorchRL is capable of handling dynamic specs, but the batched environments and collectors will need to be made +aware of this feature. Note that, in practice, this is detected automatically. + +To indicate that a tensor will have a variable size along a dimension, one can set the size value as ``-1`` for the +desired dimensions. Because the data cannot be stacked contiguously, calls to ``env.rollout`` need to be made with +the ``return_contiguous=False`` argument. +Here is a working example: + + >>> from torchrl.envs import EnvBase + >>> from torchrl.data import Unbounded, Composite, Bounded, Binary + >>> import torch + >>> from tensordict import TensorDict, TensorDictBase + >>> + >>> class EnvWithDynamicSpec(EnvBase): + ... def __init__(self, max_count=5): + ... super().__init__(batch_size=()) + ... self.observation_spec = Composite( + ... observation=Unbounded(shape=(3, -1, 2)), + ... ) + ... self.action_spec = Bounded(low=-1, high=1, shape=(2,)) + ... self.full_done_spec = Composite( + ... done=Binary(1, shape=(1,), dtype=torch.bool), + ... terminated=Binary(1, shape=(1,), dtype=torch.bool), + ... truncated=Binary(1, shape=(1,), dtype=torch.bool), + ... ) + ... self.reward_spec = Unbounded((1,), dtype=torch.float) + ... self.count = 0 + ... self.max_count = max_count + ... + ... def _reset(self, tensordict=None): + ... self.count = 0 + ... data = TensorDict( + ... { + ... "observation": torch.full( + ... (3, self.count + 1, 2), + ... self.count, + ... dtype=self.observation_spec["observation"].dtype, + ... ) + ... } + ... ) + ... data.update(self.done_spec.zero()) + ... return data + ... + ... def _step( + ... self, + ... tensordict: TensorDictBase, + ... ) -> TensorDictBase: + ... self.count += 1 + ... done = self.count >= self.max_count + ... observation = TensorDict( + ... { + ... "observation": torch.full( + ... (3, self.count + 1, 2), + ... self.count, + ... 
dtype=self.observation_spec["observation"].dtype, + ... ) + ... } + ... ) + ... done = self.full_done_spec.zero() | done + ... reward = self.full_reward_spec.zero() + ... return observation.update(done).update(reward) + ... + ... def _set_seed(self, seed: Optional[int]) -> None: + ... self.manual_seed = seed + ... return seed + >>> env = EnvWithDynamicSpec() + >>> print(env.rollout(5, return_contiguous=False)) + LazyStackedTensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: LazyStackedTensorDict( + fields={ + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([5]), + device=None, + is_shared=False, + stack_dim=0), + observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([5]), + device=None, + is_shared=False, + stack_dim=0) + +.. warning:: + The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in + data collectors can impact performance of these classes dramatically. Any + such usage should be carefully benchmarked against a plain execution on a + single process, as serializing and deserializing large numbers of tensors + can be very expensive. + +Currently, :func:`~torchrl.envs.utils.check_env_specs` will pass for dynamic specs where a shape varies along some +dimensions, but not when a key is present during a step and absent during others, or when the number of dimensions +varies. diff --git a/docs/source/reference/envs_multiagent.rst b/docs/source/reference/envs_multiagent.rst new file mode 100644 index 00000000000..13f0a7cb9ca --- /dev/null +++ b/docs/source/reference/envs_multiagent.rst @@ -0,0 +1,138 @@ +.. currentmodule:: torchrl.envs + +.. _MARL-environment-API: + +Multi-agent Environments +======================== + +TorchRL supports multi-agent learning out-of-the-box. +*The same classes used in a single-agent learning pipeline can be seamlessly used in multi-agent contexts, +without any modification or dedicated multi-agent infrastructure.* + +In this view, environments play a core role for multi-agent. In multi-agent environments, +many decision-making agents act in a shared world. +Agents can observe different things, act in different ways and also be rewarded differently. +Therefore, many paradigms exist to model multi-agent environments (DecPODPs, Markov Games). +Some of the main differences between these paradigms include: + +- **observation** can be per-agent and also have some shared components +- **reward** can be per-agent or shared +- **done** (and ``"truncated"`` or ``"terminated"``) can be per-agent or shared. + +TorchRL accommodates all these possible paradigms thanks to its :class:`tensordict.TensorDict` data carrier. 
+In particular, in multi-agent environments, per-agent keys will be carried in a nested "agents" TensorDict. +This TensorDict will have the additional agent dimension and thus group data that is different for each agent. +The shared keys, on the other hand, will be kept in the first level, as in single-agent cases. + +Let's look at an example to understand this better. For this example we are going to use +`VMAS `_, a multi-robot task simulator also +based on PyTorch, which runs parallel batched simulation on device. + +We can create a VMAS environment and look at what the output from a random step looks like: + +.. code-block:: + :caption: Example of multi-agent step tensordict + + >>> from torchrl.envs.libs.vmas import VmasEnv + >>> env = VmasEnv("balance", num_envs=3, n_agents=5) + >>> td = env.rand_step() + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 5, 2]))}, + batch_size=torch.Size([3, 5])), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + info: TensorDict( + fields={ + ground_rew: Tensor(shape=torch.Size([3, 5, 1])), + pos_rew: Tensor(shape=torch.Size([3, 5, 1]))}, + batch_size=torch.Size([3, 5])), + observation: Tensor(shape=torch.Size([3, 5, 16])), + reward: Tensor(shape=torch.Size([3, 5, 1]))}, + batch_size=torch.Size([3, 5])), + done: Tensor(shape=torch.Size([3, 1]))}, + batch_size=torch.Size([3]))}, + batch_size=torch.Size([3])) + +We can observe that *keys that are shared by all agents*, such as **done** are present in the root tensordict with +batch size `(num_envs,)`, which represents the number of environments simulated. + +On the other hand, *keys that are different between agents*, such as **action**, **reward**, **observation**, +and **info** are present in the nested "agents" tensordict with batch size `(num_envs, n_agents)`, +which represents the additional agent dimension. + +Multi-agent tensor specs will follow the same style as in tensordicts. +Specs relating to values that vary between agents will need to be nested in the "agents" entry. + +Here is an example of how specs can be created in a multi-agent environment where +only the done flag is shared across agents (as in VMAS): + +.. code-block:: + :caption: Example of multi-agent spec creation + + >>> action_specs = [] + >>> observation_specs = [] + >>> reward_specs = [] + >>> info_specs = [] + >>> for i in range(env.n_agents): + ... action_specs.append(agent_i_action_spec) + ... reward_specs.append(agent_i_reward_spec) + ... observation_specs.append(agent_i_observation_spec) + >>> env.action_spec = Composite( + ... { + ... "agents": Composite( + ... {"action": torch.stack(action_specs)}, shape=(env.n_agents,) + ... ) + ... } + ...) + >>> env.reward_spec = Composite( + ... { + ... "agents": Composite( + ... {"reward": torch.stack(reward_specs)}, shape=(env.n_agents,) + ... ) + ... } + ...) + >>> env.observation_spec = Composite( + ... { + ... "agents": Composite( + ... {"observation": torch.stack(observation_specs)}, shape=(env.n_agents,) + ... ) + ... } + ...) + >>> env.done_spec = Categorical( + ... n=2, + ... shape=torch.Size((1,)), + ... dtype=torch.bool, + ... ) + +As you can see, it is very simple! Per-agent keys will have the nested composite spec and shared keys will follow +single agent standards. + +.. note:: + Since reward, done and action keys may have the additional "agent" prefix (e.g., `("agents","action")`), + the default keys used in the arguments of other TorchRL components (e.g. "action") will not match exactly. 
+ Therefore, TorchRL provides the `env.action_key`, `env.reward_key`, and `env.done_key` attributes, + which will automatically point to the right key to use. Make sure you pass these attributes to the various + components in TorchRL to inform them of the right key (e.g., the `loss.set_keys()` function). + +.. note:: + TorchRL abstracts these nested specs away for ease of use. + This means that accessing `env.reward_spec` will always return the leaf + spec if the accessed spec is Composite. Therefore, if in the example above + we run `env.reward_spec` after env creation, we would get the same output as `torch.stack(reward_specs)}`. + To get the full composite spec with the "agents" key, you can run + `env.output_spec["full_reward_spec"]`. The same is valid for action and done specs. + Note that `env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]`. + + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + MarlGroupMapType + check_marl_grouping diff --git a/docs/source/reference/envs_recorders.rst b/docs/source/reference/envs_recorders.rst new file mode 100644 index 00000000000..143a8f19be2 --- /dev/null +++ b/docs/source/reference/envs_recorders.rst @@ -0,0 +1,83 @@ +.. currentmodule:: torchrl.record + +.. _Environment-Recorders: + +Recorders +========= + +Recording data during environment rollout execution is crucial to keep an eye on the algorithm performance as well as +reporting results after training. + +TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable +can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function will be called upon the collected +tensordict at each iteration of the rollout (if some iterations have to be skipped, an internal variable should be added +to keep track of the call count within ``callback``). + +To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used. + +Recording videos +---------------- + +Several backends offer the possibility of recording rendered images from the environment. +If the pixels are already part of the environment output (e.g. Atari or other game simulators), a +:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input +a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger` +or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved. +For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a `video_format="mp4"` +argument. + +The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch +formatted images (WHC or CWH). + + >>> logger = CSVLogger("dummy-exp", video_format="mp4") + >>> env = GymEnv("ALE/Pong-v5") + >>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"])) + >>> env.rollout(10) + >>> env.transform.dump() # Save the video and clear cache + +Note that the cache of the transform will keep on growing until dump is called. It is the user responsibility to +take care of calling `dump` when needed to avoid OOM issues. + +In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible +(some libraries only allow one environment instance per workspace). 
+In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform`
+can be used to call `render` on the parent environment and save the images in the rollout data stream.
+This class works over single and batched environments alike:
+
+    >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
+    >>> from torchrl.record.loggers import CSVLogger
+    >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
+    >>>
+    >>> def make_env():
+    ...     env = GymEnv("CartPole-v1", render_mode="rgb_array")
+    ...     # Uncomment this line to execute per-env
+    ...     # env = env.append_transform(PixelRenderTransform())
+    ...     return env
+    >>>
+    >>> if __name__ == "__main__":
+    ...     logger = CSVLogger("dummy", video_format="mp4")
+    ...
+    ...     env = ParallelEnv(16, EnvCreator(make_env))
+    ...     env.start()
+    ...     # Comment this line to execute per-env
+    ...     env = env.append_transform(PixelRenderTransform())
+    ...
+    ...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
+    ...     env.rollout(3)
+    ...
+    ...     check_env_specs(env)
+    ...
+    ...     r = env.rollout(30)
+    ...     env.transform.dump()
+    ...     env.close()
+
+
+Recorders are transforms that register data as they come in, for logging purposes.
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_fun.rst
+
+    TensorDictRecorder
+    VideoRecorder
+    PixelRenderTransform
diff --git a/docs/source/reference/envs_transforms.rst b/docs/source/reference/envs_transforms.rst
new file mode 100644
index 00000000000..d68d3e115f4
--- /dev/null
+++ b/docs/source/reference/envs_transforms.rst
@@ -0,0 +1,359 @@
+.. currentmodule:: torchrl.envs.transforms
+
+.. _transforms:
+
+Transforms
+==========
+
+In most cases, the raw output of an environment must be treated before being passed to another object (such as a
+policy or a value operator). To do this, TorchRL provides a set of transforms that aim at reproducing the transform
+logic of `torch.distributions.Transform` and `torchvision.transforms`.
+Our environment :ref:`tutorial `
+provides more information on how to design a custom transform.
+
+Transformed environments are built using the :class:`TransformedEnv` primitive.
+Composed transforms are built using the :class:`Compose` class:
+
+.. code-block::
+    :caption: Transformed environment
+
+    >>> base_env = GymEnv("Pendulum-v1", from_pixels=True, device="cuda:0")
+    >>> transform = Compose(ToTensorImage(in_keys=["pixels"]), Resize(64, 64, in_keys=["pixels"]))
+    >>> env = TransformedEnv(base_env, transform)
+
+Transforms are usually subclasses of :class:`~torchrl.envs.transforms.Transform`, although any
+``Callable[[TensorDictBase], TensorDictBase]`` can be used as well.
+
+By default, the transformed environment will inherit the device of the
+``base_env`` that is passed to it. The transforms will then be executed on that device.
+Depending on the kind of operations that are to be computed, this can bring a significant speedup.
+
+A great advantage of environment wrappers is that one can consult the environment up to that wrapper.
+The same can be achieved with TorchRL transformed environments: the ``parent`` attribute will
+return a new :class:`TransformedEnv` with all the transforms up to the transform of interest.
+Re-using the example above:
+
+.. code-block::
+    :caption: Transform parent
+
+    >>> resize_parent = env.transform[-1].parent  # returns the same as TransformedEnv(base_env, transform[:-1])
+
+
+Transformed environments can be used with vectorized environments.
+Since each transform uses an ``"in_keys"``/``"out_keys"`` set of keyword arguments, it is
+also easy to route the transform graph to each component of the observation data (e.g.
+pixels or states).
+
+Forward and inverse transforms
+------------------------------
+
+Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called, in reverse
+order over the composed transform chain, before the action is passed to the base environment.
+This allows transforms to be applied to the data that flows towards the base environment (such as the action)
+before the step is executed. The keys to be included in this inverse transform are passed through the `"in_keys_inv"`
+keyword argument, and the out-keys default to these values in most cases:
+
+.. code-block::
+    :caption: Inverse transform
+
+    >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"]))  # will map the action from float32 to float64 before calling the base_env.step
+
+The following paragraphs detail how one can think about what is to be considered `in_` or `out_` features.
+
+Understanding Transform Keys
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In transforms, `in_keys` and `out_keys` define the interaction between the base environment and the outside world
+(e.g., your policy):
+
+- `in_keys` refers to the base environment's perspective (inner = `base_env` of the
+  :class:`~torchrl.envs.TransformedEnv`).
+- `out_keys` refers to the outside world (outer = `policy`, `agent`, etc.).
+
+For example, with `in_keys=["obs"]` and `out_keys=["obs_standardized"]`, the policy will "see" a standardized
+observation, while the base environment outputs a regular observation.
+
+Similarly, for inverse keys:
+
+- `in_keys_inv` refers to entries as seen by the base environment.
+- `out_keys_inv` refers to entries as seen or produced by the policy.
+
+The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input
+`TensorDict` of the `step` function must include the `out_keys_inv` as they are part of the outside world. The
+transform changes these names to match the names of the inner, base environment using the `in_keys_inv`.
+The inverse process is executed with the output tensordict, where the `in_keys` are mapped to the corresponding
+`out_keys`.
+
+.. figure:: /_static/img/rename_transform.png
+
+  Rename transform logic
+
+.. note:: During a call to `inv`, the transforms are executed in reversed order (compared to the forward / step mode).
+
+Transforming Tensors and Specs
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When transforming actual tensors (coming from the policy), the process is schematically represented as:
+
+    >>> for t in reversed(self.transform):
+    ...     td = t.inv(td)
+
+This starts with the outermost transform and moves to the innermost one, ensuring the action value exposed to the policy
+is properly transformed.
+
+For transforming the action spec, the process should go from innermost to outermost (similar to observation specs):
+
+    >>> def transform_action_spec(self, action_spec):
+    ...     for t in self.transform:
+    ...         action_spec = t.transform_action_spec(action_spec)
+    ...     return action_spec
+
+A pseudocode for a single transform_action_spec could be:
+
+    >>> def transform_action_spec(self, action_spec):
+    ...     return spec_from_random_values(self._apply_transform(action_spec.rand()))
+
+This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we did not call
+`_inv_apply_transform` but `_apply_transform` on purpose!
+
+Exposing Specs to the Outside World
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+`TransformedEnv` will expose the specs corresponding to the `out_keys_inv` for actions and states.
+For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., `"action"`) is a float-valued
+tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed
+environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the
+transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`.
+
+Designing your own Transform
+----------------------------
+
+To create a basic, custom transform, you need to subclass the `Transform` class and implement the
+:meth:`~torchrl.envs.Transform._apply_transform` method. Here's an example of a simple transform that adds 1 to the
+observation tensor:
+
+    >>> class AddOneToObs(Transform):
+    ...     """A transform that adds 1 to the observation tensor."""
+    ...
+    ...     def __init__(self):
+    ...         super().__init__(in_keys=["observation"], out_keys=["observation"])
+    ...
+    ...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
+    ...         return obs + 1
+
+
+Tips for subclassing `Transform`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There are various ways of subclassing a transform. The things to take into consideration are:
+
+- Is the transform identical for each tensor / item being transformed? Use
+  :meth:`~torchrl.envs.Transform._apply_transform` and :meth:`~torchrl.envs.Transform._inv_apply_transform`.
+- Does the transform need access to the input data of ``env.step`` as well as its output? Rewrite
+  :meth:`~torchrl.envs.Transform._step`.
+  Otherwise, rewrite :meth:`~torchrl.envs.Transform._call` (or :meth:`~torchrl.envs.Transform._inv_call`).
+- Is the transform to be used within a replay buffer? Overwrite :meth:`~torchrl.envs.Transform.forward`,
+  :meth:`~torchrl.envs.Transform.inv`, :meth:`~torchrl.envs.Transform._apply_transform` or
+  :meth:`~torchrl.envs.Transform._inv_apply_transform`.
+- Within a transform, you can access (and make calls to) the parent environment using
+  :attr:`~torchrl.envs.Transform.parent` (the base env + all transforms till this one) or
+  :meth:`~torchrl.envs.Transform.container` (the object that encapsulates the transform).
+- Don't forget to edit the specs if needed: top level: :meth:`~torchrl.envs.Transform.transform_output_spec` and
+  :meth:`~torchrl.envs.Transform.transform_input_spec`.
+  Leaf level: :meth:`~torchrl.envs.Transform.transform_observation_spec`,
+  :meth:`~torchrl.envs.Transform.transform_action_spec`, :meth:`~torchrl.envs.Transform.transform_state_spec`,
+  :meth:`~torchrl.envs.Transform.transform_reward_spec` and
+  :meth:`~torchrl.envs.Transform.transform_done_spec`.
+
+For practical examples, see the methods listed above and the sketch below.
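+
+The following sketch (the ``NegateObs`` name and the ``Bounded``-only spec handling are illustrative
+assumptions, not part of the library) shows how an element-wise ``_apply_transform`` can be paired with a
+leaf-level spec edit so that the spec exposed to the outside world stays consistent with the transformed values:
+
+    >>> import torch
+    >>> from torchrl.data import Bounded, Composite
+    >>> from torchrl.envs.transforms import Transform
+    >>>
+    >>> class NegateObs(Transform):
+    ...     """Illustrative sketch: flips the sign of the observation and updates its spec."""
+    ...
+    ...     def __init__(self):
+    ...         super().__init__(in_keys=["observation"], out_keys=["observation"])
+    ...
+    ...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
+    ...         # identical processing for every item: _apply_transform is enough
+    ...         return -obs
+    ...
+    ...     def transform_observation_spec(self, observation_spec):
+    ...         # defensive sketch: handle the spec whether it is passed as a Composite or as a leaf
+    ...         if isinstance(observation_spec, Composite):
+    ...             for key in self.in_keys:
+    ...                 observation_spec[key] = self.transform_observation_spec(observation_spec[key])
+    ...             return observation_spec
+    ...         if isinstance(observation_spec, Bounded):
+    ...             # negating the values swaps and negates the bounds
+    ...             return Bounded(
+    ...                 low=-observation_spec.high,
+    ...                 high=-observation_spec.low,
+    ...                 shape=observation_spec.shape,
+    ...                 dtype=observation_spec.dtype,
+    ...                 device=observation_spec.device,
+    ...             )
+    ...         return observation_spec
+
+Appending ``NegateObs()`` to a :class:`TransformedEnv` then exposes negated observations together with a matching
+observation spec.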
+ +You can use a transform in an environment by passing it to the TransformedEnv constructor: + + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), AddOneToObs()) + +You can compose multiple transforms together using the Compose class: + + >>> transform = Compose(AddOneToObs(), RewardSum()) + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), transform) + +Inverse Transforms +~~~~~~~~~~~~~~~~~~ + +Some transforms have an inverse transform that can be used to undo the transformation. For example, the AddOneToAction +transform has an inverse transform that subtracts 1 from the action tensor: + + >>> class AddOneToAction(Transform): + ... """A transform that adds 1 to the action tensor.""" + ... def __init__(self): + ... super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"]) + ... def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor: + ... return action + 1 + +Using a Transform with a Replay Buffer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can use a transform with a replay buffer by passing it to the ReplayBuffer constructor: + +Cloning transforms +~~~~~~~~~~~~~~~~~~ + +Because transforms appended to an environment are "registered" to this environment +through the ``transform.parent`` property, when manipulating transforms we should keep +in mind that the parent may come and go following what is being done with the transform. +Here are some examples: if we get a single transform from a :class:`Compose` object, +this transform will keep its parent: + + >>> third_transform = env.transform[2] + >>> assert third_transform.parent is not None + +This means that using this transform for another environment is prohibited, as +the other environment would replace the parent and this may lead to unexpected +behviours. Fortunately, the :class:`Transform` class comes with a :func:`clone` +method that will erase the parent while keeping the identity of all the +registered buffers: + + >>> TransformedEnv(base_env, third_transform) # raises an Exception as third_transform already has a parent + >>> TransformedEnv(base_env, third_transform.clone()) # works + +On a single process or if the buffers are placed in shared memory, this will +result in all the clone transforms to keep the same behavior even if the +buffers are changed in place (which is what will happen with the :class:`CatFrames` +transform, for instance). In distributed settings, this may not hold and one +should be careful about the expected behavior of the cloned transforms in this +context. +Finally, notice that indexing multiple transforms from a :class:`Compose` transform +may also result in loss of parenthood for these transforms: the reason is that +indexing a :class:`Compose` transform results in another :class:`Compose` transform +that does not have a parent environment. Hence, we have to clone the sub-transforms +to be able to create this other composition: + + >>> env = TransformedEnv(base_env, Compose(transform1, transform2, transform3)) + >>> last_two = env.transform[-2:] + >>> assert isinstance(last_two, Compose) + >>> assert last_two.parent is None + >>> assert last_two[0] is not transform2 + >>> assert isinstance(last_two[0], type(transform2)) # and the buffers will match + >>> assert last_two[1] is not transform3 + >>> assert isinstance(last_two[1], type(transform3)) # and the buffers will match + +Available Transforms +-------------------- + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + Transform + TransformedEnv + ActionDiscretizer + ActionMask + AutoResetEnv + AutoResetTransform + BatchSizeTransform + BinarizeReward + BurnInTransform + CatFrames + CatTensors + CenterCrop + ClipTransform + Compose + ConditionalPolicySwitch + ConditionalSkip + Crop + DTypeCastTransform + DeviceCastTransform + DiscreteActionProjection + DoubleToFloat + EndOfLifeTransform + ExcludeTransform + FiniteTensorDictCheck + FlattenObservation + FrameSkipTransform + GrayScale + Hash + InitTracker + KLRewardTransform + LineariseRewards + ModuleTransform + MultiAction + NoopResetEnv + ObservationNorm + ObservationTransform + PermuteTransform + PinMemoryTransform + R3MTransform + RandomCropTensorDict + RemoveEmptySpecs + RenameTransform + Resize + Reward2GoTransform + RewardClipping + RewardScaling + RewardSum + SelectTransform + SignTransform + SqueezeTransform + Stack + StepCounter + TargetReturn + TensorDictPrimer + TimeMaxPool + Timer + Tokenizer + ToTensorImage + TrajCounter + UnaryTransform + UnsqueezeTransform + VC1Transform + VIPRewardTransform + VIPTransform + VecGymEnvTransform + VecNorm + VecNormV2 + gSDENoise + +Environments with masked actions +-------------------------------- + +In some environments with discrete actions, the actions available to the agent might change throughout execution. +In such cases the environments will output an action mask (under the ``"action_mask"`` key by default). +This mask needs to be used to filter out unavailable actions for that step. + +If you are using a custom policy you can pass this mask to your probability distribution like so: + +.. code-block:: + :caption: Categorical policy with action mask + + >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential + >>> import torch.nn as nn + >>> from torchrl.modules import MaskedCategorical + >>> module = TensorDictModule( + >>> nn.Linear(in_feats, out_feats), + >>> in_keys=["observation"], + >>> out_keys=["logits"], + >>> ) + >>> dist = ProbabilisticTensorDictModule( + >>> in_keys={"logits": "logits", "mask": "action_mask"}, + >>> out_keys=["action"], + >>> distribution_class=MaskedCategorical, + >>> ) + >>> actor = TensorDictSequential(module, dist) + +If you want to use a default policy, you will need to wrap your environment in the :class:`~torchrl.envs.transforms.ActionMask` +transform. This transform can take care of updating the action mask in the action spec in order for the default policy +to always know what the latest available actions are. You can do this like so: + +.. code-block:: + :caption: How to use the action mask transform + + >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential + >>> import torch.nn as nn + >>> from torchrl.envs.transforms import TransformedEnv, ActionMask + >>> env = TransformedEnv( + >>> your_base_env + >>> ActionMask(action_key="action", mask_key="action_mask"), + >>> ) + +.. note:: + In case you are using a parallel environment it is important to add the transform to the parallel environment itself + and not to its sub-environments. diff --git a/docs/source/reference/envs_vectorized.rst b/docs/source/reference/envs_vectorized.rst new file mode 100644 index 00000000000..29a7d67e7a2 --- /dev/null +++ b/docs/source/reference/envs_vectorized.rst @@ -0,0 +1,351 @@ +.. 
currentmodule:: torchrl.envs
+
+Vectorized and Parallel Environments
+====================================
+
+Vectorized (or, more accurately, parallel) environments are a common feature in Reinforcement Learning
+where executing the environment step can be CPU-intensive.
+Some libraries such as `gym3 `_ or `EnvPool `_
+offer interfaces to execute batches of environments simultaneously.
+While they often offer a very competitive computational advantage, they do not
+necessarily scale to the wide variety of environment libraries supported by TorchRL.
+Therefore, TorchRL offers its own, generic :class:`ParallelEnv` class to run multiple
+environments in parallel.
+As this class inherits from :class:`SerialEnv`, it enjoys the exact same API as other environments.
+Of course, a :class:`ParallelEnv` will have a batch size that corresponds to its environment count:
+
+.. note::
+    Given the library's many optional dependencies (e.g., Gym, Gymnasium, and many others)
+    warnings can quickly become quite annoying in multiprocessed / distributed settings.
+    By default, TorchRL filters out these warnings in sub-processes. If one still wishes to
+    see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``.
+
+It is important that your environment specs match the input and output that it sends and receives, as
+:class:`ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
+Check the :func:`~torchrl.envs.utils.check_env_specs` method for a sanity check.
+
+.. code-block::
+    :caption: Parallel environment
+
+    >>> def make_env():
+    ...     return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0")
+    >>> check_env_specs(make_env())  # the single env must pass this check for ParallelEnv to work
+    >>> env = ParallelEnv(4, make_env)
+    >>> print(env.batch_size)
+    torch.Size([4])
+
+:class:`ParallelEnv` allows you to retrieve the attributes of its contained environments:
+one can simply call:
+
+.. code-block::
+    :caption: Parallel environment attributes
+
+    >>> a, b, c, d = env.g  # gets the g-force of the various envs, which we set to 9.81 before
+    >>> print(a)
+    9.81
+
+.. note::
+
+    *A note on performance*: launching a :class:`~.ParallelEnv` can take quite some time
+    as it requires launching as many python instances as there are processes. Due to
+    the time that it takes to run ``import torch`` (and other imports), starting the
+    parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow.
+    Once the environment is launched, a great speedup should be observed.
+
+.. note::
+
+    *TorchRL requires precise specs*: Another thing to take into consideration is
+    that :class:`ParallelEnv` (as well as data collectors)
+    will create data buffers based on the environment specs to pass data from one process
+    to another. This means that a misspecified spec (input, observation or reward) will
+    cause a failure at runtime, as the data can't be written to the preallocated buffer.
+    In general, an environment should be tested using the :func:`~.utils.check_env_specs`
+    test function before being used in a :class:`ParallelEnv`. This function will raise
+    an assertion error whenever the preallocated buffer and the collected data mismatch.
+
+We also offer the :class:`~.SerialEnv` class that enjoys the exact same API but is executed
+serially. This is mostly useful for testing purposes, when one wants to assess the
+behavior of a :class:`~.ParallelEnv` without launching the subprocesses.
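+
+As a quick illustration (a minimal sketch reusing the ``make_env`` constructor defined above), a
+:class:`~.SerialEnv` is built with the same arguments as a :class:`~.ParallelEnv` and runs the rollout in the
+current process:
+
+.. code-block::
+    :caption: Serial environment for debugging
+
+    >>> env = SerialEnv(4, make_env)  # same signature as ParallelEnv, but no subprocesses are spawned
+    >>> rollout = env.rollout(10)     # a batch of 4 Pendulum environments, 10 steps each
+    >>> print(rollout.shape)
+    torch.Size([4, 10])
+    >>> env.close()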
+
+In addition to :class:`~.ParallelEnv`, which offers process-based parallelism, we also provide a way to create
+multithreaded environments with :class:`~.MultiThreadedEnv`. This class uses the `EnvPool `_
+library underneath, which allows for higher performance but at the same time restricts flexibility: one can only
+create environments implemented in ``EnvPool``. This covers many popular RL environment types (Atari, Classic Control,
+etc.), but one cannot use an arbitrary TorchRL environment, as is possible with :class:`~.ParallelEnv`. Run
+`benchmarks/benchmark_batched_envs.py` to compare the performance of different ways to parallelize batched environments.
+
+Vectorized environment classes
+------------------------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    SerialEnv
+    ParallelEnv
+    EnvCreator
+
+Partial steps and partial resets
+--------------------------------
+
+TorchRL allows you to reset some but not all of the environments in a batch, or to run a step in some but not all of
+them. If there is only one environment in the batch, a partial reset / step is also allowed, with the behavior detailed
+below.
+
+Batching environments and locking the batch
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. _ref_batch_locked:
+
+Before detailing what partial resets and partial steps do, we must distinguish cases where an environment has
+a batch size of its own (mostly stateful environments) from cases where the environment is just a mere module that, given an
+input of arbitrary size, batches the operations over all elements (mostly stateless environments).
+
+This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input
+tensordicts to have the same batch-size as the env's. Typical examples of these environments are
+:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size.
+Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`.
+
+Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the
+tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous
+input.
+
+Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with
+partial steps easily: they just pass the actions to the sub-environments that are required to be executed.
+
+In all other cases, TorchRL assumes that the environment handles the partial steps correctly.
+
+.. warning:: This means that custom environments may silently run the non-required steps as there is no way for TorchRL
+    to control what happens within the `_step` method!
+
+Partial Steps
+~~~~~~~~~~~~~
+
+.. _ref_partial_steps:
+
+Partial steps are controlled via the temporary key `"_step"`, which points to a boolean mask of the
+size of the tensordict that holds it. The classes equipped to deal with this are:
+
+- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the
+  action to, and only to, the environments where `"_step"` is `True`;
+- Batch-unlocked environments;
+- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step`
+  method will first look for a `"_step"` entry and, if present, act accordingly.
+  If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by
+  :class:`~torchrl.envs.TransformedEnv`'s own `_step` method, which will skip the `base_env.step` as well as any further
+  transformation.
+
+When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous
+content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that
+if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for
+all the non-stepped elements. Within batched environments, data collectors and rollout utils, this issue is not
+observed because these classes handle the passing of data properly.
+
+Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`,
+as the environments with a `True` done state will need to be skipped during calls to `_step`.
+
+The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips.
+
+Partial Resets
+~~~~~~~~~~~~~~
+
+.. _ref_partial_resets:
+
+Partial resets work pretty much like partial steps, but with the `"_reset"` entry.
+
+The same restrictions that apply to partial steps apply to partial resets.
+
+Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `True`,
+as the environments with a `True` done state will need to be reset, but not the others.
+
+See the following paragraph for a deep dive into partial resets within batched and vectorized environments.
+
+Partial resets in detail
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+TorchRL uses a private ``"_reset"`` key to indicate to the environment which
+components (sub-environments or agents) should be reset.
+This allows resetting some but not all of the components.
+
+The ``"_reset"`` key has two distinct functionalities:
+
+1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
+   not be present in the input tensordict. TorchRL's convention is that the
+   absence of the ``"_reset"`` key at a given ``"done"`` level indicates
+   a total reset of that level (unless a ``"_reset"`` key was found at a level
+   above, see details below).
+   If it is present, it is expected that only those components
+   where the ``"_reset"`` entry is ``True`` (along the key and shape dimensions) will be reset.
+
+   The way an environment deals with the ``"_reset"`` keys in its :meth:`~.EnvBase._reset`
+   method is specific to its class.
+   Designing an environment that behaves according to ``"_reset"`` inputs is the
+   developer's responsibility, as TorchRL has no control over the inner logic
+   of :meth:`~.EnvBase._reset`. Nevertheless, the following point should be
+   kept in mind when designing that method.
+
+2. After a call to :meth:`~.EnvBase._reset`, the output will be masked with the
+   ``"_reset"`` entries and the output of the previous :meth:`~.EnvBase.step`
+   will be written wherever the ``"_reset"`` was ``False``. In practice, this
+   means that if a ``"_reset"`` modifies data that it does not expose, this
+   modification will be lost. After this masking operation, the ``"_reset"``
+   entries will be erased from the :meth:`~.EnvBase.reset` outputs.
+
+It must be pointed out that ``"_reset"`` is a private key, and it should only be
+used when coding specific environment features that are internal facing.
+In other words, this should NOT be used outside of the library, and developers
+reserve the right to modify the logic of partial resets through the ``"_reset"``
+entry without prior notice, as long as these changes do not affect TorchRL
+internal tests.
+
+Finally, the following assumptions are made and should be kept in mind when
+designing reset functionalities:
+
+- Each ``"_reset"`` is paired with a ``"done"`` entry (+ ``"terminated"`` and,
+  possibly, ``"truncated"``). This means that the following structure is not
+  allowed: ``TensorDict({"done": done, "nested": {"_reset": reset}}, [])``, as
+  the ``"_reset"`` lives at a different nesting level than the ``"done"``.
+- A reset at one level does not preclude the presence of a ``"_reset"`` at lower
+  levels, but it annihilates its effects. The reason is simply that
+  whether the ``"_reset"`` at the root level corresponds to an ``all()``, ``any()``
+  or custom call to the nested ``"done"`` entries cannot be known in advance,
+  and it is explicitly assumed that the ``"_reset"`` at the root was placed
+  there to supersede the nested values (for an example, have a look at the
+  :class:`~.PettingZooWrapper` implementation, where each group has one or more
+  associated ``"done"`` entries, which are aggregated at the root level with an
+  ``any`` or ``all`` logic depending on the task).
+- When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry
+  that will reset some but not all the done sub-environments, the input data
+  should contain the data of the sub-environments that are **not** being reset.
+  The reason for this constraint lies in the fact that the output of
+  ``env._reset(data)`` can only be predicted for the entries that are reset.
+  For the others, TorchRL cannot know in advance if they will be meaningful or
+  not. For instance, one could perfectly well just pad the values of the non-reset
+  components, in which case the non-reset data will be meaningless and should
+  be discarded.
+
+Below, we give some examples of the expected effect that ``"_reset"`` keys will
+have on an environment returning zeros after reset:
+
+    >>> # single reset at the root
+    >>> data = TensorDict({"val": [1, 1], "_reset": [False, True]}, [])
+    >>> env.reset(data)
+    >>> print(data.get("val"))  # only the second value is 0
+    tensor([1, 0])
+    >>> # nested resets
+    >>> data = TensorDict({
+    ...     ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
+    ...     ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
+    ... }, [])
+    >>> env.reset(data)
+    >>> print(data.get(("agent0", "val")))  # only the second value is 0
+    tensor([1, 0])
+    >>> print(data.get(("agent1", "val")))  # only the first value is 0
+    tensor([0, 2])
+    >>> # nested resets are overridden by a "_reset" at the root
+    >>> data = TensorDict({
+    ...     "_reset": [True, True],
+    ...     ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
+    ...     ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
+    ... }, [])
+    >>> env.reset(data)
+    >>> print(data.get(("agent0", "val")))  # reset at the root overrides nested
+    tensor([0, 0])
+    >>> print(data.get(("agent1", "val")))  # reset at the root overrides nested
+    tensor([0, 0])
+
+.. 
code-block::
+    :caption: Parallel environment reset
+
+    >>> tensordict = TensorDict({"_reset": [[True], [False], [True], [True]]}, [4])
+    >>> env.reset(tensordict)  # eliminates the "_reset" entry
+    TensorDict(
+        fields={
+            terminated: Tensor(torch.Size([4, 1]), dtype=torch.bool),
+            done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
+            pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
+            truncated: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
+        batch_size=torch.Size([4]),
+        device=None,
+        is_shared=True)
+
+Async environments
+------------------
+
+Asynchronous environments allow for parallel execution of multiple environments, which can significantly speed up the
+data collection process in reinforcement learning.
+
+The `AsyncEnvPool` class and its subclasses provide a flexible interface for managing these environments using different
+backends, such as threading and multiprocessing.
+
+The `AsyncEnvPool` class serves as a base class for asynchronous environment pools, providing a common interface for
+managing multiple environments concurrently. It supports different backends for parallel execution, such as threading
+and multiprocessing, and provides methods for asynchronous stepping and resetting of environments.
+
+Contrary to :class:`~torchrl.envs.ParallelEnv`, :class:`~torchrl.envs.AsyncEnvPool` and its subclasses permit the
+execution of a given set of sub-environments while another task is being performed, allowing for complex asynchronous jobs to be
+run at the same time. For instance, it is possible to execute some environments while the policy is running based on
+the output of others.
+
+This family of classes is particularly interesting when dealing with environments that have a high (and/or variable)
+latency.
+
+.. note:: This class and its subclasses should work when nested in :class:`~torchrl.envs.TransformedEnv` and
+    batched environments, but users won't currently be able to use the async features of the base environment when
+    it's nested in these classes. One should prefer nested transformed envs within an `AsyncEnvPool` instead.
+    If this is not possible, please raise an issue.
+
+Classes
+~~~~~~~
+
+- :class:`~torchrl.envs.AsyncEnvPool`: A base class for asynchronous environment pools. It determines the backend
+  implementation to use based on the provided arguments and manages the lifecycle of the environments.
+- :class:`~torchrl.envs.ProcessorAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using
+  multiprocessing for parallel execution of environments. This class manages a pool of environments, each running in
+  its own process, and provides methods for asynchronous stepping and resetting of environments using inter-process
+  communication. It is automatically instantiated when `"multiprocessing"` is passed as a backend during the
+  :class:`~torchrl.envs.AsyncEnvPool` instantiation.
+- :class:`~torchrl.envs.ThreadingAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using
+  threading for parallel execution of environments. This class manages a pool of environments, each running in its own
+  thread, and provides methods for asynchronous stepping and resetting of environments using a thread pool executor.
+  It is automatically instantiated when `"threading"` is passed as a backend during the
+  :class:`~torchrl.envs.AsyncEnvPool` instantiation.
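+
+As a short illustration of the dispatch described above (a minimal sketch, assuming the Gym backend is available),
+passing a ``backend`` argument to :class:`~torchrl.envs.AsyncEnvPool` returns the corresponding specialized pool:
+
+    >>> from functools import partial
+    >>> from torchrl.envs import AsyncEnvPool, GymEnv, ProcessorAsyncEnvPool
+    >>> env = AsyncEnvPool([partial(GymEnv, "Pendulum-v1")] * 2, backend="multiprocessing")
+    >>> assert isinstance(env, ProcessorAsyncEnvPool)  # the backend argument selects the concrete class
+    >>> env.close()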
+ +Example +~~~~~~~ + + >>> from functools import partial + >>> from torchrl.envs import AsyncEnvPool, GymEnv + >>> import torch + >>> # Choose backend + >>> backend = "threading" + >>> env = AsyncEnvPool( + >>> [partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "CartPole-v1")], + >>> stack="lazy", + >>> backend=backend + >>> ) + >>> # Execute a synchronous reset + >>> reset = env.reset() + >>> print(reset) + >>> # Execute a synchronous step + >>> s = env.rand_step(reset) + >>> print(s) + >>> # Execute an asynchronous step in env 0 + >>> s0 = s[0] + >>> s0["action"] = torch.randn(1).clamp(-1, 1) + >>> s0["env_index"] = 0 + >>> env.async_step_send(s0) + >>> # Receive data + >>> s0_result = env.async_step_recv() + >>> print('result', s0_result) + >>> # Close env + >>> env.close() + + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + AsyncEnvPool + ProcessorAsyncEnvPool + ThreadingAsyncEnvPool diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst index 7dd464fd7db..9f7f111c46d 100644 --- a/docs/source/reference/llms.rst +++ b/docs/source/reference/llms.rst @@ -6,770 +6,46 @@ LLM Interface .. _ref_llms: TorchRL provides a comprehensive framework for LLM post-training and fine-tuning. The LLM API is built around five core concepts that work -together to create a complete reinforcement learning pipeline for language models: +together to create a complete reinforcement learning pipeline for language models. -1. **Data Representation** (`Data Structures`_): The foundation for handling conversations, text parsing, and LLM - output classes. This includes the :class:`~torchrl.data.llm.History` class for managing conversation context and structured output classes for - tokens, log-probabilities, and text. +Key Components +-------------- -2. **LLM Wrapper API** (`Modules`_): Unified interfaces for different LLM backends, including :class:`~torchrl.modules.llm.TransformersWrapper` for - Hugging Face models, :class:`~torchrl.modules.llm.vLLMWrapper` for vLLM inference, and :class:`~torchrl.modules.llm.AsyncVLLM` for high-performance - distributed vLLM inference (recommended). These wrappers provide consistent input/output formats across different backends and an integrated - interface for loss computation, data storage, grading, weight synchronization, etc. +1. **Data Structures**: History class for conversation management, structured output classes +2. **LLM Wrappers**: Unified interfaces for Transformers, vLLM, and AsyncVLLM +3. **Environments**: ChatEnv, task-specific environments, and transforms +4. **Collectors**: LLMCollector and RayLLMCollector for data collection +5. **Objectives**: GRPOLoss, SFTLoss for training -3. **Environments** (`Environments`_): The orchestration layer that manages data loading, tool execution, reward computation, and formatting. This includes - :class:`~torchrl.envs.llm.ChatEnv` for conversation management, dataset environments, and various transforms for tool integration. - -4. **Objectives** (`Objectives`_): Specialized loss functions for LLM training, including :class:`~torchrl.objectives.llm.GRPOLoss` for Group Relative - Policy Optimization and :class:`~torchrl.objectives.llm.SFTLoss` for supervised fine-tuning. - -5. **Collectors** (`Collectors`_): Collectors are used to collect data from the environment and store it in a format that can be used for training. 
This includes - :class:`~torchrl.collectors.llm.LLMCollector` for collecting data from the environment and :class:`~torchrl.collectors.llm.RayLLMCollector` for collecting - data in distributed settings using Ray. - -These components work together to create a complete pipeline: environments load and format data, LLM wrappers handle inference, data structures maintain -conversation context, and objectives compute training losses. The modular design allows you to mix and match components based on your specific use case. - -A complete example of how to use the LLM API can be found in the `sota-implementations/grpo/` directory. The training orchestration involves three main components: - -- The Data Collector: holds a reference to the environment and the inference model or engine. It collects data, puts it in the buffer, and handles weight updates. -- The Replay Buffer: stores the collected data and executes any pre or post-processing steps. These may include: - - Advantage estimation with Monte-Carlo based method (using the :class:`~torchrl.objectives.llm.MCAdvantage` transform); - - Grading of the outputs; - - Logging etc. -- The trainer: handles the training loop, including the optimization step, serialization, logging and weight updates initialization. - -.. warning:: The LLM API is still under development and may change in the future. Feedback, issues and PRs are welcome! - -Data Structures ---------------- - -The data representation layer provides the foundation for handling conversations and LLM outputs in a structured way. - -History Class -~~~~~~~~~~~~~ - -The :class:`~torchrl.data.llm.History` class is a TensorClass version of the chat format usually found in transformers -(see `Hugging Face chat documentation `_). -It provides a comprehensive API for managing conversation data with features including: - -- **Text parsing and formatting**: Convert between text and structured conversation format using :meth:`~torchrl.data.llm.chat.History.from_text` - and :meth:`~torchrl.data.llm.chat.History.apply_chat_template` -- **Dynamic conversation building**: Append and extend conversations with :meth:`~torchrl.data.llm.chat.History.append` and - :meth:`~torchrl.data.llm.chat.History.extend` methods -- **Multi-model support**: Automatic template detection for various model families (Qwen, DialoGPT, Falcon, DeepSeek, etc.) -- **Assistant token masking**: Identify which tokens were generated by the assistant for reinforcement learning applications -- **Tool calling support**: Handle function calls and tool responses in conversations -- **Batch operations**: Efficient tensor operations for processing multiple conversations simultaneously. - -.. currentmodule:: torchrl.data.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - History - ContentBase - -Supported Model Families -^^^^^^^^^^^^^^^^^^^^^^^^ - -We currently support the following model families for string to History parsing or assistant token masking: - -- **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support -- **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format -- **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format -- **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format - -Other models are supported, but you will need to provide a custom template for them. -LLAMA, Mistral, OPT, GPT, MPT, BLOOM, Pythia, Phi, etc. 
will use the default `chatml_format` template. - -Usage -^^^^^ - -.. code-block:: python - - >>> from torchrl.data.llm.chat import History - >>> from transformers import AutoTokenizer - >>> - >>> # Create a conversation history - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"}, - ... {"role": "user", "content": "How are you?"}, - ... {"role": "assistant", "content": "I'm doing well, thanks!"} - ... ]]) - >>> - >>> # Load any supported tokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") - >>> - >>> # Apply chat template with assistant token masking - >>> result = history.apply_chat_template( - ... chat_template_name="qwen", - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - >>> - >>> # The result contains an assistant_masks tensor - >>> assistant_masks = result["assistant_masks"] - >>> print(f"Assistant tokens: {assistant_masks.sum().item()}") - -Adding Custom Templates -^^^^^^^^^^^^^^^^^^^^^^^ - -You can add custom chat templates for new model families using the :func:`torchrl.data.llm.add_chat_template` function. - -.. autofunction:: torchrl.data.llm.add_chat_template - -Usage Examples -^^^^^^^^^^^^^^ - -Adding a Llama Template -""""""""""""""""""""""" - -.. code-block:: python - - >>> from torchrl.data.llm import add_chat_template, History - >>> from transformers import AutoTokenizer - >>> - >>> # Define the Llama chat template - >>> llama_template = ''' - ... {% for message in messages %} - ... {%- if message['role'] == 'user' %} - ... {{ '[INST] ' + message['content'] + ' [/INST]' }} - ... {%- elif message['role'] == 'assistant' %} - ... {% generation %}{{ message['content'] + '' }}{% endgeneration %} - ... {%- endif %} - ... {% endfor %} - ... {%- if add_generation_prompt %} - ... {% generation %}{{ ' ' }}{% endgeneration %} - ... {%- endif %} - ... ''' - >>> - >>> # Define the inverse parser for Llama format - >>> def parse_llama_text(text: str) -> History: - ... import re - ... pattern = r'\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)' - ... matches = re.findall(pattern, text, re.DOTALL) - ... messages = [] - ... for user_content, assistant_content in matches: - ... messages.append(History(role="user", content=user_content.strip())) - ... messages.append(History(role="assistant", content=assistant_content.strip())) - ... return lazy_stack(messages) - >>> - >>> # Add the template with auto-detection - >>> add_chat_template( - ... template_name="llama", - ... template=llama_template, - ... inverse_parser=parse_llama_text, - ... model_family_keywords=["llama", "meta-llama"] - ... ) - >>> - >>> # Now you can use it with auto-detection - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"} - ... ]]) - >>> - >>> # Auto-detection will use the llama template - >>> result = history.apply_chat_template( - ... tokenizer=tokenizer, - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - -Testing Your Custom Templates -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When adding custom templates, you should test them to ensure they work correctly. Here are the recommended tests: - -Assistant Token Masking Test -"""""""""""""""""""""""""""" - -Test that your template supports assistant token masking: - -.. 
code-block:: python - - import pytest - from torchrl.data.llm.chat import History, add_chat_template - from transformers import AutoTokenizer - - def test_my_model_assistant_masking(): - """Test that your model supports assistant token masking.""" - # Add your template first - add_chat_template( - template_name="my_model", - template="your_template_here", - model_family_keywords=["my_model"] - ) - - tokenizer = AutoTokenizer.from_pretrained("your/model/name") - history = History.from_chats([[ - {'role': 'user', 'content': 'Hello'}, - {'role': 'assistant', 'content': 'Hi there!'} - ]]) - - result = history.apply_chat_template( - tokenizer=tokenizer, - chat_template_name="my_model", - add_generation_prompt=False, - return_dict=True, - return_assistant_tokens_mask=True, - ) - - # Verify assistant mask is present - assert 'assistant_masks' in result - assert result['assistant_masks'].shape[0] == 1, "Should have batch dimension of 1" - assert result['assistant_masks'].shape[1] > 0, "Should have sequence length > 0" - - # Verify some assistant tokens are masked - assistant_token_count = result['assistant_masks'].sum().item() - assert assistant_token_count > 0, "Should have assistant tokens masked" - print(f"✓ {assistant_token_count} assistant tokens masked") - -Template Equivalence Test -""""""""""""""""""""""""" - -Test that your custom template produces the same output as the model's default template (except for masking): - -.. code-block:: python - - def test_my_model_template_equivalence(): - """Test that your template matches the model's default template.""" - tokenizer = AutoTokenizer.from_pretrained("your/model/name") - history = History.from_chats([[ - {'role': 'user', 'content': 'Hello'}, - {'role': 'assistant', 'content': 'Hi there!'}, - {'role': 'user', 'content': 'How are you?'}, - {'role': 'assistant', 'content': 'I\'m good, thanks!'}, - ]]) - - # Get output with model's default template - try: - default_out = history.apply_chat_template( - tokenizer=tokenizer, - add_generation_prompt=False, - chat_template=tokenizer.chat_template, - tokenize=False, - ) - except Exception as e: - default_out = None - print(f"[WARN] Could not get default template: {e}") - - # Get output with your custom template - custom_out = history.apply_chat_template( - tokenizer=tokenizer, - add_generation_prompt=False, - chat_template_name="my_model", - tokenize=False, - ) - - if default_out is not None: - # Normalize whitespace for comparison - import re - def norm(s): - return re.sub(r"\s+", " ", s.strip()) - - assert norm(default_out) == norm(custom_out), ( - f"Custom template does not match default!\n" - f"Default: {default_out}\nCustom: {custom_out}" - ) - print("✓ Template equivalence verified") - else: - print("[INFO] Skipped equivalence check (no default template available)") - -Inverse Parsing Test -"""""""""""""""""""" - -If you provided an inverse parser, test that it works correctly: - -.. 
code-block:: python - - def test_my_model_inverse_parsing(): - """Test that your inverse parser works correctly.""" - history = History.from_chats([[ - {'role': 'user', 'content': 'Hello'}, - {'role': 'assistant', 'content': 'Hi there!'} - ]]) - - # Format using your template - formatted = history.apply_chat_template( - tokenizer=tokenizer, - chat_template_name="my_model", - add_generation_prompt=False, - tokenize=False, - ) - - # Parse back using your inverse parser - parsed = History.from_text(formatted, chat_template_name="my_model") - - # Verify the parsing worked - assert parsed.role == history.role - assert parsed.content == history.content - print("✓ Inverse parsing verified") - -LLM Wrapper API -~~~~~~~~~~~~~~~ - -The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent input/output formats across training and inference pipelines. The main wrappers are :class:`~torchrl.modules.llm.TransformersWrapper` for Hugging Face models and :class:`~torchrl.modules.llm.vLLMWrapper` for vLLM inference. - -**Data Structure Classes** - -The wrappers use structured :class:`~tensordict.TensorClass` objects to represent different aspects of LLM data: - -- **:class:`~torchrl.modules.llm.policies.Text`**: Contains text data with `prompt`, `response`, and `full` fields -- **:class:`~torchrl.modules.llm.policies.ChatHistory`**: Contains :class:`~torchrl.data.llm.History` objects with `prompt`, `response`, and `full` fields -- **:class:`~torchrl.modules.llm.policies.Tokens`**: Contains tokenized data with `prompt`, `response`, and `full` fields -- **:class:`~torchrl.modules.llm.policies.LogProbs`**: Contains log probabilities with `prompt`, `response`, and `full` fields -- **:class:`~torchrl.modules.llm.policies.Masks`**: Contains attention and assistant masks - -**API Flow** - -The wrappers operate in two distinct modes: - -**Generation Mode (`generate=True`)**: -- **Input**: Reads from `prompt` fields (e.g., `history.prompt`, `text.prompt`, `tokens.prompt`) -- **Output**: Writes to both `response` and `full` fields - - `response`: Contains only the newly generated content - - `full`: Contains the complete sequence (prompt + response) - -**Log-Probability Mode (`generate=False`)**: -- **Input**: Reads from `full` fields (e.g., `history.full`, `text.full`, `tokens.full`) -- **Output**: Writes log probabilities to the corresponding `full` fields - -**LLM-Environment Interaction Loop** - -.. figure:: /_static/img/llm-env.png - :alt: LLM-Environment interaction loop - :align: center - :width: 80% - - LLM-Environment interaction: the LLM generates a response, the environment updates the conversation, and transforms can inject new messages or tools. - -In a typical RL or tool-augmented setting, the LLM and environment interact in a loop: - -1. **LLM Generation**: The LLM wrapper receives a `prompt` (the current conversation history), generates a `response`, and outputs a `full` field - containing the concatenation of the prompt and response. -2. **Environment Step**: The environment takes the `full` field and makes it the next `prompt` for the LLM. This ensures that the conversation - context grows with each turn. See :ref:`ref_env_llm_step` for more details. -3. **Transforms**: Before the next LLM step, transforms can modify the conversation—for example, by inserting a new user message, a tool call, - or a reward annotation. -4. **Repeat**: This process repeats for as many turns as needed, enabling multi-turn dialogue, tool use, and RL training. 
- -This design allows for flexible augmentation of the conversation at each step, supporting advanced RL and tool-use scenarios. - -A typical pseudocode loop: - -.. code-block:: python - - # Get the first prompt out of an initial query - obs = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size, device=env.device)) - while not done: - # LLM generates a response given the current prompt - llm_output = llm(obs) - # Environment steps: creates a ("next", "history") field with the new prompt (from the previous `"full"` field) - obs = env.step(llm_output) - -**Integration with History** - -When using `input_mode="history"`, the wrapper integrates seamlessly with the :class:`~torchrl.data.llm.History` class: - -- **Input**: Takes a :class:`~torchrl.modules.llm.policies.ChatHistory` object containing a History in the `prompt` field -- **Generation**: Applies chat templates to convert History to tokens, generates response, then parses the full text back into a History object -- **Output**: Returns a ChatHistory with: - - `prompt`: Original conversation history - - `response`: New History object containing only the assistant's response - - `full`: Complete conversation history with the new response appended - -This design allows for natural conversation flow where each generation step extends the conversation history, making it ideal for multi-turn -dialogue systems. - - -Prompt vs. Response and padding -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. figure:: /_static/img/llm-data.svg - :alt: LLM output data format (Tokens, Masks, Padded vs. Sparse) - :align: center - :width: 80% - - Structure of LLM outputs: padded vs. sparse representations for Tokens, LogProbs, and Masks. - -The diagram above illustrates the structure of the main output classes used in TorchRL's LLM API: - -- **Tokens** (and by extension, **LogProbs**): - - *Padded* format: All sequences in a batch are padded to the same length (with a special pad token), making them suitable for tensor operations. The prompt and response are concatenated to form `tokens.full`, and masks indicate valid vs. padded positions. - - *Sparse* format: Each sequence retains its original length (no padding), represented as lists of tensors. This is more memory-efficient for variable-length data. -- **Masks**: Two main masks are shown: - - `mask.attention_mask_all` marks valid (non-pad) tokens. - - `mask.assistant_mask_all` marks which tokens were generated by the assistant (useful for RLHF and SFT training). -- **Text**: Not shown in detail, as it is simply the decoded string representation of the prompt, response, or full sequence. - -This format ensures that all LLM outputs (Tokens, LogProbs, Masks, Text) are consistent and easy to manipulate, regardless of whether you use padded or sparse batching. - -In general, we recommend working with unpadded data, as it is more memory-efficient and easier to manipulate. -For instance, when collecting multiple padded elements from the buffer, it may be hard to clearly understand how to re-pad them -to combine them in a cohesive batch. Working with unpadded data is more straightforward. - -Modules -------- - -The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent input/output formats across training and inference pipelines. 
- -Wrappers -~~~~~~~~ - -The main goal of these primitives is to: - -- Unify the input/output data format across training and inference pipelines -- Unify the input/output data format across backends (to be able to use different backends across losses and collectors) -- Provide appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution, weight update, etc.) - -.. currentmodule:: torchrl.modules.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - LLMWrapperBase - TransformersWrapper - vLLMWrapper - RemoteTransformersWrapper - AsyncVLLM - ChatHistory - Text - LogProbs - Masks - Tokens - -Async vLLM Engine (Recommended) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:class:`~torchrl.modules.llm.AsyncVLLM` is the recommended approach for high-performance vLLM inference in TorchRL. -It provides a distributed, async-capable inference service built on Ray that offers superior performance and resource utilization compared to synchronous vLLM engines. - -**Key Features:** - -- **Distributed Architecture**: Runs multiple vLLM engine replicas as Ray actors for horizontal scaling -- **Load Balancing**: Automatically distributes requests across available replicas -- **Native vLLM Batching**: Leverages vLLM's optimized batching for maximum throughput. - Every thread or actor in your code will be able to make requests to the vLLM engine(s), put the query in - the queue and let the engine handle the batching. -- **Resource Management**: Automatic GPU allocation and cleanup through Ray placement groups -- **Simple API**: Single-import convenience with :meth:`~torchrl.modules.llm.AsyncVLLM.from_pretrained` - -**Basic Usage:** +Quick Example +------------- .. code-block:: python - from torchrl.modules.llm import AsyncVLLM, vLLMWrapper - from vllm import SamplingParams - - # Create async vLLM service (recommended) - async_engine = AsyncVLLM.from_pretrained( - "Qwen/Qwen2.5-7B", - num_devices=2, # Use 2 GPUs per replica (tensor parallel) - num_replicas=2, # Create 2 replicas for higher throughput - max_model_len=4096 - ) - - # Use with vLLMWrapper for TorchRL integration - wrapper = vLLMWrapper(async_engine, input_mode="history", generate=True) - - # Direct generation (also supported) - sampling_params = SamplingParams(temperature=0.7, max_tokens=100) - result = async_engine.generate("Hello, world!", sampling_params) - - # Cleanup when done - async_engine.shutdown() - -These objects (AsyncVLLM and vLLMWrapper) can be shared across multiple collectors, environments, or workers efficiently. -They can be directly passed from one worker to another: under the hood, Ray will handle the handler sharing and remote execution. - -**Performance Benefits:** - -- **Higher Throughput**: Multiple replicas process requests concurrently -- **Better GPU Utilization**: Ray ensures optimal GPU allocation and co-location -- **Reduced Latency**: Native batching reduces per-request overhead -- **Fault Tolerance**: Ray provides automatic error recovery and resource management - -**Resource Sharing:** - -AsyncVLLM instances can be shared across multiple collectors, environments, or workers efficiently: - -.. 
code-block:: python - - from torchrl.modules.llm import AsyncVLLM, vLLMWrapper + from torchrl.modules.llm import vLLMWrapper, AsyncVLLM + from torchrl.envs.llm import ChatEnv from torchrl.collectors.llm import LLMCollector - # Create a shared AsyncVLLM service - shared_async_engine = AsyncVLLM.from_pretrained( - "Qwen/Qwen2.5-7B", - num_devices=2, - num_replicas=4, # High throughput for multiple consumers - max_model_len=4096 - ) - - # Multiple wrappers can use the same AsyncVLLM service - wrapper1 = vLLMWrapper(shared_async_engine, input_mode="history") - wrapper2 = vLLMWrapper(shared_async_engine, input_mode="text") - - # Multiple collectors can share the same service - collector1 = LLMCollector(env1, policy=wrapper1) - collector2 = LLMCollector(env2, policy=wrapper2) - - # The AsyncVLLM service automatically load-balances across replicas - # No additional coordination needed between consumers - -This approach is more efficient than creating separate vLLM instances for each consumer, as it: - -- **Reduces Memory Usage**: Single model loading shared across consumers -- **Automatic Load Balancing**: Requests are distributed across replicas -- **Better Resource Utilization**: GPUs are used more efficiently -- **Simplified Management**: Single service to monitor and manage - -.. note:: - **AsyncVLLM vs. Traditional Actor Sharing** - - Unlike traditional Ray actor sharing patterns where you manually create named actors and share references, - AsyncVLLM handles the distributed architecture internally. You simply create one AsyncVLLM service and - pass it to multiple consumers. The service automatically: - - - Creates and manages multiple Ray actors (replicas) internally - - Load-balances requests across replicas without manual coordination - - Handles actor lifecycle and resource cleanup - - This eliminates the need for manual actor name management and reference sharing that was required - with the previous `RemotevLLMWrapper` approach. - -Remote Wrappers -^^^^^^^^^^^^^^^ - -TorchRL provides remote wrapper classes that enable distributed execution of LLM wrappers using Ray. These wrappers provide a simplified interface that doesn't require explicit `remote()` and `get()` calls, making them easy to use in distributed settings. - -.. note:: - **For vLLM: Use AsyncVLLM Instead** - - For vLLM-based inference, we recommend using :class:`~torchrl.modules.llm.AsyncVLLM` directly rather than - remote wrappers. AsyncVLLM provides better performance, resource utilization, and built-in load balancing. - See the `Async vLLM Engine (Recommended)`_ section above for details. - - Remote wrappers are primarily intended for Transformers-based models or other use cases where AsyncVLLM - is not applicable. - -**Key Features:** - -- **Simplified Interface**: No need to call `remote()` and `get()` explicitly -- **Full API Compatibility**: Exposes all public methods from the base `LLMWrapperBase` class -- **Automatic Ray Management**: Handles Ray initialization and remote execution internally -- **Property Access**: All properties are accessible through the remote wrapper -- **Error Handling**: Proper error propagation from remote actors -- **Resource Management**: Context manager support for automatic cleanup - -**Model Parameter Requirements:** - -- **RemoteTransformersWrapper**: Only accepts string model names/paths. Transformers models are not serializable. - -**Supported Backends:** - -Currently, only Transformers-based models are supported through remote wrappers. 
For vLLM models, use :class:`~torchrl.modules.llm.AsyncVLLM` instead. - -**Usage Examples:** - -.. code-block:: python - - import ray - from torchrl.modules.llm.policies import RemoteTransformersWrapper - from torchrl.data.llm import History - from torchrl.modules.llm.policies import ChatHistory, Text - from tensordict import TensorDict - - # Initialize Ray (if not already done) - if not ray.is_initialized(): - ray.init() - - # Transformers wrapper (only string models supported) - # The remote wrappers implement context managers for proper resource cleanup: - with RemoteTransformersWrapper( - model="gpt2", - max_concurrency=16, - input_mode="text", - generate=True, - generate_kwargs={"max_new_tokens": 30} - ) as remote_transformers: - - text_input = TensorDict({"text": Text(prompt="Hello world")}, batch_size=(1,)) - result = remote_transformers(text_input) - print(result["text"].response) -**Performance Considerations:** - -- **Network Overhead**: Remote execution adds network communication overhead -- **Serialization**: Data is serialized when sent to remote actors -- **Memory**: Each remote actor maintains its own copy of the model -- **Concurrency**: Multiple remote wrappers can run concurrently -- **Max Concurrency**: Use the `max_concurrency` parameter to control the number of concurrent calls to each remote actor -- **Cleanup**: Always use context managers or call `cleanup_batching()` to prevent hanging due to batching locks - -Utils -^^^^^ - -.. currentmodule:: torchrl.modules.llm - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - make_async_vllm_engine - stateless_init_process_group_async - make_vllm_worker - stateless_init_process_group - -Collectors ----------- - -.. _Collectors: - -TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`) -that are tailored for LLM use cases. We also provide weight synchronization schemes for vLLM inference engines. - -See :ref:`ref_collectors` for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline -in a dedicated class. -A collector usually takes as input a policy and an environment, and alternate between running one and the other. -In "classical" settings, the policy is similar to the policy being trained (with some optional extra-exploration). In the context of LLM fine-tuning, -the policy will usually be a specialized inference engine, such as a vLLM server. -Collectors are defined by the following parameters and features: - -- **Sync/Async**: Whether the collector should run in sync or async mode. - In sync mode, the collector will run the inference step in alternate with the optimization/training step. - In async mode, the collector will run the inference step in parallel with the optimization/training step. - A replay buffer can be passed to the collector, in such a way that the collector can directly write to it. - In other cases, the collector can be iterated over to collect data. -- **Steps**: A collector is built with a certain number of steps budget, as well as a number of steps to be - included in each batch yield during collection. -- **Weight Synchronization Schemes**: Weight sync schemes handle the synchronization of weights between the training model - and the inference engine. The new scheme-based approach provides flexible, high-performance weight updates for vLLM and - other inference backends. 
- -vLLM Weight Synchronization Schemes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TorchRL provides two weight synchronization schemes for vLLM engines, offering different trade-offs between -performance and simplicity: - -**1. NCCL-Based Synchronization** (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme`) - -Uses NCCL collectives for high-bandwidth GPU-to-GPU weight transfers. Best for: - -- High-frequency weight updates -- Large models where transfer speed is critical -- Setups with GPU interconnect (NVLink, InfiniBand) - -**2. Double-Buffer Synchronization** (:class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`) - -Uses memory-mapped file storage for asynchronous weight transfers. Best for: - -- Simpler setup without NCCL coordination -- Distributed setups with shared filesystems (NFS) -- Cases where update frequency is lower - -**Usage Example with NCCL:** - -.. code-block:: python - - from torchrl.collectors.llm import RayLLMCollector - from torchrl.weight_update.llm import VLLMWeightSyncScheme - from torchrl.modules.llm import AsyncVLLM, vLLMWrapper - # Create vLLM engine - vllm_engine = AsyncVLLM.from_pretrained( - "Qwen/Qwen2.5-7B", - num_devices=2, - num_replicas=2, - ) - policy = vLLMWrapper(vllm_engine, input_mode="history") - - # Create NCCL weight sync scheme - weight_sync_scheme = VLLMWeightSyncScheme( - master_address="localhost", - master_port=29500, - gpus_per_replica=2, # tp_size × dp_size × pp_size - num_replicas=2, - strategy="state_dict" - ) - - # Create collector with weight sync scheme - collector = RayLLMCollector( - env=make_env, - policy=policy, - dialog_turns_per_batch=256, - total_dialog_turns=10000, - weight_sync_schemes={"policy": weight_sync_scheme}, - track_policy_version=True, - ) + engine = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-7B", num_replicas=2) + policy = vLLMWrapper(engine, input_mode="history") - # During training, get the sender and update weights - sender = collector._weight_senders["policy"] - sender.register_model(training_model) + # Create environment + env = ChatEnv(tokenizer=tokenizer) - # Initialize collective group (must be called before first update) - metadata = get_model_metadata(training_model) - sender.init_all_workers_group(metadata, vllm_engine=vllm_engine) - - # Update weights during training - for i, data in enumerate(collector): - # ... training step ... - if i % 10 == 0: - sender.update_weights() # Broadcasts via NCCL + # Create collector + collector = LLMCollector(env, policy, dialog_turns_per_batch=256) -**Usage Example with Double-Buffer:** +.. warning:: The LLM API is still under development and may change in the future. + Feedback, issues and PRs are welcome! -.. 
code-block:: python +Documentation Sections +---------------------- - from torchrl.collectors.llm import RayLLMCollector - from torchrl.weight_update.llm import VLLMDoubleBufferSyncScheme - from torchrl.modules.llm import AsyncVLLM, vLLMWrapper - - # Create vLLM engine - vllm_engine = AsyncVLLM.from_pretrained( - "Qwen/Qwen2.5-7B", - num_devices=2, - num_replicas=1, - ) - policy = vLLMWrapper(vllm_engine, input_mode="history") - - # Create double-buffer weight sync scheme - weight_sync_scheme = VLLMDoubleBufferSyncScheme( - remote_addr="/tmp/weights", # Or "/mnt/shared/weights" for NFS - num_threads=128, - strategy="state_dict" - ) - - # Create collector with weight sync scheme - collector = RayLLMCollector( - env=make_env, - policy=policy, - dialog_turns_per_batch=256, - total_dialog_turns=10000, - weight_sync_schemes={"policy": weight_sync_scheme}, - track_policy_version=True, - ) - - # During training, get the sender and receiver - sender = collector._weight_senders["policy"] - sender.register_model(training_model) - - # No initialization needed for double-buffer scheme! - - # Update weights during training - for i, data in enumerate(collector): - # ... training step ... - if i % 10 == 0: - sender.update_weights() # Writes to shared storage - # vLLM workers can poll and apply: receiver.poll_and_apply() +.. toctree:: + :maxdepth: 2 +<<<<<<< HEAD Policy Version Tracking ~~~~~~~~~~~~~~~~~~~~~~~ @@ -1195,3 +471,11 @@ SFT :template: rl_template.rst TopKRewardSelector +================== + llms_data + llms_modules + llms_envs + llms_transforms + llms_collectors + llms_objectives +>>>>>>> 571142f4e ([Doc] Huge doc refactoring) diff --git a/docs/source/reference/llms_collectors.rst b/docs/source/reference/llms_collectors.rst new file mode 100644 index 00000000000..314b0afb4a4 --- /dev/null +++ b/docs/source/reference/llms_collectors.rst @@ -0,0 +1,34 @@ +.. currentmodule:: torchrl.collectors.llm + +LLM Collectors +============== + +Specialized collector classes for LLM use cases. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + LLMCollector + RayLLMCollector + vLLMUpdater + vLLMUpdaterV2 + +Weight Synchronization Schemes +------------------------------ + +.. currentmodule:: torchrl.weight_update.llm + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + VLLMWeightSyncScheme + VLLMWeightSender + VLLMWeightReceiver + VLLMCollectiveTransport + VLLMDoubleBufferSyncScheme + VLLMDoubleBufferWeightSender + VLLMDoubleBufferWeightReceiver + VLLMDoubleBufferTransport + get_model_metadata diff --git a/docs/source/reference/llms_data.rst b/docs/source/reference/llms_data.rst new file mode 100644 index 00000000000..382caf68489 --- /dev/null +++ b/docs/source/reference/llms_data.rst @@ -0,0 +1,36 @@ +.. currentmodule:: torchrl.data.llm + +Data Structures +=============== + +The data representation layer provides the foundation for handling conversations and LLM outputs. + +History Class +------------- + +The :class:`~torchrl.data.llm.History` class is a TensorClass version of the chat format usually found in transformers. 
+It provides a comprehensive API for managing conversation data with features including: + +- **Text parsing and formatting**: Convert between text and structured conversation format +- **Dynamic conversation building**: Append and extend conversations +- **Multi-model support**: Automatic template detection for various model families +- **Assistant token masking**: Identify which tokens were generated by the assistant +- **Tool calling support**: Handle function calls and tool responses +- **Batch operations**: Efficient tensor operations for multiple conversations + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + History + ContentBase + add_chat_template + +TopK Reward Selector +-------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + TopKRewardSelector diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst new file mode 100644 index 00000000000..e04b42a6f7b --- /dev/null +++ b/docs/source/reference/llms_envs.rst @@ -0,0 +1,24 @@ +.. currentmodule:: torchrl.envs.llm + +LLM Environments +================ + +The environment layer orchestrates data loading, tool execution, reward computation, and formatting. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + ChatEnv + DatasetChatEnv + GSM8KEnv + make_gsm8k_env + GSM8KPrepareQuestion + GSM8KRewardParser + IFEvalEnv + IfEvalScorer + IFEvalScoreData + LLMEnv + LLMHashingEnv + make_mlgym + MLGymWrapper diff --git a/docs/source/reference/llms_modules.rst b/docs/source/reference/llms_modules.rst new file mode 100644 index 00000000000..53779611a0e --- /dev/null +++ b/docs/source/reference/llms_modules.rst @@ -0,0 +1,45 @@ +.. currentmodule:: torchrl.modules.llm + +LLM Wrappers +============ + +The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent +input/output formats across training and inference pipelines. + +Wrappers +-------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + LLMWrapperBase + TransformersWrapper + vLLMWrapper + RemoteTransformersWrapper + AsyncVLLM + +Data Structure Classes +---------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + ChatHistory + Text + LogProbs + Masks + Tokens + +Utilities +--------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + make_async_vllm_engine + stateless_init_process_group_async + make_vllm_worker + stateless_init_process_group diff --git a/docs/source/reference/llms_objectives.rst b/docs/source/reference/llms_objectives.rst new file mode 100644 index 00000000000..e7863a081b8 --- /dev/null +++ b/docs/source/reference/llms_objectives.rst @@ -0,0 +1,27 @@ +.. currentmodule:: torchrl.objectives.llm + +LLM Objectives +============== + +Specialized loss functions for LLM training. + +GRPO +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + GRPOLoss + GRPOLossOutput + MCAdvantage + +SFT +--- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + SFTLoss + SFTLossOutput diff --git a/docs/source/reference/llms_transforms.rst b/docs/source/reference/llms_transforms.rst new file mode 100644 index 00000000000..4b34ad7ec0b --- /dev/null +++ b/docs/source/reference/llms_transforms.rst @@ -0,0 +1,34 @@ +.. currentmodule:: torchrl.envs.llm.transforms + +LLM Transforms +============== + +Transforms for LLM environments, including tools and utilities. + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template.rst + + AddThinkingPrompt + BrowserTransform + DataLoadingPrimer + ExecuteToolsInOrder + JSONCallParser + KLComputation + KLRewardTransform + MCPToolTransform + PolicyVersion + PythonExecutorService + PythonInterpreter + RayDataLoadingPrimer + RetrieveKL + RetrieveLogProb + SimpleToolTransform + TemplateTransform + Tokenizer + ToolCall + ToolRegistry + ToolService + XMLBlockParser + as_nested_tensor + as_padded_tensor diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 28c97982b4d..2e1e932829e 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -5,467 +5,53 @@ torchrl.modules package .. _ref_modules: -TensorDict modules: Actors, exploration, value models and generative models ---------------------------------------------------------------------------- +TorchRL offers a comprehensive collection of RL-specific neural network modules built on top of +:class:`tensordict.nn.TensorDictModule`. These modules are designed to work seamlessly with +tensordict data structures, making it easy to build and compose RL models. -.. _tdmodules: +Key Features +------------ -TorchRL offers a series of module wrappers aimed at making it easy to build -RL models from the ground up. These wrappers are exclusively based on -:class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`. -They can loosely be split in three categories: -policies (actors), including exploration strategies, -value model and simulation models (in model-based contexts). +- **Spec-based construction**: Automatically configure output layers based on action specs +- **Probabilistic modules**: Built-in support for stochastic policies +- **Exploration strategies**: Modular exploration wrappers (ε-greedy, Ornstein-Uhlenbeck, etc.) +- **Value networks**: Q-value, distributional, and dueling architectures +- **Safe modules**: Automatic projection to satisfy action constraints +- **Model-based RL**: World model and dynamics modules -The main features are: - -- Integration of the specs in your model to ensure that the model output matches - what your environment expects as input; -- Probabilistic modules that can automatically sample from a chosen distribution - and/or return the distribution of interest; -- Custom containers for Q-Value learning, model-based agents and others. - -TensorDictModules and SafeModules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TorchRL :class:`~torchrl.modules.tensordict_module.SafeModule` allows you to -check the you model output matches what is to be expected for the environment. -This should be used whenever your model is to be recycled across multiple -environments for instance, and when you want to make sure that the outputs -(e.g. the action) always satisfies the bounds imposed by the environment. -Here is an example of how to use that feature with the -:class:`~torchrl.modules.tensordict_module.Actor` class: - - >>> env = GymEnv("Pendulum-v1") - >>> action_spec = env.action_spec - >>> model = nn.LazyLinear(action_spec.shape[-1]) - >>> policy = Actor(model, in_keys=["observation"], spec=action_spec, safe=True) - -The ``safe`` flag ensures that the output is always within the bounds of the -``action_spec`` domain: if the network output violates these bounds it will be -projected (in a L1-manner) into the desired domain. - -.. currentmodule:: torchrl.modules.tensordict_module - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - Actor - MultiStepActorWrapper - SafeModule - SafeSequential - TanhModule - -Exploration wrappers and modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To efficiently explore the environment, TorchRL proposes a series of modules -that will override the action sampled by the policy by a noisier version. -Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_type`: -if the exploration is set to ``ExplorationType.RANDOM``, the exploration is active. In all -other cases, the action written in the tensordict is simply the network output. - -.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` - uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. - The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on - this module. - -.. currentmodule:: torchrl.modules - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - AdditiveGaussianModule - ConsistentDropoutModule - EGreedyModule - OrnsteinUhlenbeckProcessModule - -Probabilistic actors -~~~~~~~~~~~~~~~~~~~~ - -Some algorithms such as PPO require a probabilistic policy to be implemented. -In TorchRL, these policies take the form of a model, followed by a distribution -constructor. - - .. note:: The choice of a probabilistic or regular actor class depends on the algorithm - that is being implemented. On-policy algorithms usually require a probabilistic - actor, off-policy usually have a deterministic actor with an extra exploration - strategy. There are, however, many exceptions to this rule. - -The model reads an input (typically some observation from the environment) -and outputs the parameters of a distribution, while the distribution constructor -reads these parameters and gets a random sample from the distribution and/or -provides a :class:`torch.distributions.Distribution` object. - - >>> from tensordict.nn import NormalParamExtractor, TensorDictSequential, TensorDictModule - >>> from torchrl.modules import SafeProbabilisticModule - >>> from torchrl.envs import GymEnv - >>> from torch.distributions import Normal - >>> from torch import nn - >>> - >>> env = GymEnv("Pendulum-v1") - >>> action_spec = env.action_spec - >>> model = nn.Sequential(nn.LazyLinear(action_spec.shape[-1] * 2), NormalParamExtractor()) - >>> # build the first module, which maps the observation on the mean and sd of the normal distribution - >>> model = TensorDictModule(model, in_keys=["observation"], out_keys=["loc", "scale"]) - >>> # build the distribution constructor - >>> prob_module = SafeProbabilisticModule( - ... in_keys=["loc", "scale"], - ... out_keys=["action"], - ... distribution_class=Normal, - ... return_log_prob=True, - ... spec=action_spec, - ... ) - >>> policy = TensorDictSequential(model, prob_module) - >>> # execute a rollout - >>> env.rollout(3, policy) - -To facilitate the construction of probabilistic policies, we provide a dedicated -:class:`~torchrl.modules.tensordict_module.ProbabilisticActor`: - - >>> from torchrl.modules import ProbabilisticActor - >>> policy = ProbabilisticActor( - ... model, - ... in_keys=["loc", "scale"], - ... out_keys=["action"], - ... distribution_class=Normal, - ... return_log_prob=True, - ... spec=action_spec, - ... ) - -which alleviates the need to specify a constructor and putting it with the -module in a sequence. 
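+
+A minimal sketch of what a single call to this policy writes into the tensordict,
+assuming the ``env``, ``model`` and ``policy`` objects built just above (the key
+names follow that example and may differ in your own setup):
+
+    >>> td = policy(env.reset())
+    >>> assert "loc" in td.keys() and "scale" in td.keys()  # distribution parameters
+    >>> assert "action" in td.keys()                        # sample drawn from the Normal distribution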
- -Outputs of this policy will contain a ``"loc"`` and ``"scale"`` entries, an -``"action"`` sampled according to the normal distribution and the log-probability -of this action. - -.. currentmodule:: torchrl.modules.tensordict_module - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - ProbabilisticActor - SafeProbabilisticModule - SafeProbabilisticTensorDictSequential - -Q-Value actors -~~~~~~~~~~~~~~ - -Q-Value actors are a type of policy that selects actions based on the maximum value -(or "quality") of a state-action pair. This value can be represented as a table or a -function. For discrete action spaces with continuous states, it's common to use a non-linear -model like a neural network to represent this function. - -QValueActor -^^^^^^^^^^^ - -The :class:`~torchrl.modules.QValueActor` class takes in a module and an action -specification, and outputs the selected action and its corresponding value. - - >>> import torch - >>> from tensordict import TensorDict - >>> from torch import nn - >>> from torchrl.data import OneHot - >>> from torchrl.modules.tensordict_module.actors import QValueActor - >>> # Create a tensor dict with an observation - >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5]) - >>> # Define the action space - >>> action_spec = OneHot(4) - >>> # Create a linear module to output action values - >>> module = nn.Linear(3, 4) - >>> # Create a QValueActor instance - >>> qvalue_actor = QValueActor(module=module, spec=action_spec) - >>> # Run the actor on the tensor dict - >>> qvalue_actor(td) - >>> print(td) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), - chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5]), - device=None, - is_shared=False) - -This will output a tensor dict with the selected action and its corresponding value. - -Distributional Q-Learning -^^^^^^^^^^^^^^^^^^^^^^^^^ - -Distributional Q-learning is a variant of Q-learning that represents the value function as a -probability distribution over possible values, rather than a single scalar value. -This allows the agent to learn about the uncertainty in the environment and make more informed -decisions. -In TorchRL, distributional Q-learning is implemented using the :class:`~torchrl.modules.DistributionalQValueActor` -class. This class takes in a module, an action specification, and a support vector, and outputs the selected -action and its corresponding value distribution. 
- - - >>> import torch - >>> from tensordict import TensorDict - >>> from torch import nn - >>> from torchrl.data import OneHot - >>> from torchrl.modules import DistributionalQValueActor, MLP - >>> # Create a tensor dict with an observation - >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) - >>> # Define the action space - >>> action_spec = OneHot(4) - >>> # Define the number of bins for the value distribution - >>> nbins = 3 - >>> # Create an MLP module to output logits for the value distribution - >>> module = MLP(out_features=(nbins, 4), depth=2) - >>> # Create a DistributionalQValueActor instance - >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) - >>> # Run the actor on the tensor dict - >>> td = qvalue_actor(td) - >>> print(td) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5]), - device=None, - is_shared=False) - -This will output a tensor dict with the selected action and its corresponding value distribution. - -.. currentmodule:: torchrl.modules.tensordict_module - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - QValueActor - QValueModule - DistributionalQValueActor - DistributionalQValueModule - -Value operators and joined models -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torchrl.modules.tensordict_module - -TorchRL provides a series of value operators that wrap value networks to -soften the interface with the rest of the library. -The basic building block is :class:`torchrl.modules.tensordict_module.ValueOperator`: -given an input state (and possibly action), it will automatically write a ``"state_value"`` -(or ``"state_action_value"``) in the tensordict, depending on what the input is. -As such, this class accounts for both value and quality networks. -Three classes are also proposed to group together a policy and a value network. -The :class:`~.ActorCriticOperator` is an joined actor-quality network with shared parameters: -it reads an observation, pass it through a -common backbone, writes a hidden state, feeds this hidden state to the policy, -then takes the hidden state and the action and provides the quality of the state-action -pair. -The :class:`~.ActorValueOperator` is a joined actor-value network with shared parameters: -it reads an observation, pass it through a -common backbone, writes a hidden state, feeds this hidden state to the policy -and value modules to output an action and a state value. -Finally, the :class:`~.ActorCriticWrapper` is a joined actor and value network -without shared parameters. It is mainly intended as a replacement for -:class:`~.ActorValueOperator` when a script needs to account for both options. - - >>> actor = make_actor() - >>> value = make_value() - >>> if shared_params: - ... common = make_common() - ... model = ActorValueOperator(common, actor, value) - ... else: - ... model = ActorValueOperator(actor, value) - >>> policy = model.get_policy_operator() # will work in both cases - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - ActorCriticOperator - ActorCriticWrapper - ActorValueOperator - ValueOperator - DecisionTransformerInferenceWrapper - -Domain-specific TensorDict modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. currentmodule:: torchrl.modules.tensordict_module - -These modules include dedicated solutions for MBRL or RLHF pipelines. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - LMHeadActorValueOperator - WorldModelWrapper - -Hooks ------ -.. currentmodule:: torchrl.modules - -The Q-value hooks are used by the :class:`~.QValueActor` and :class:`~.DistributionalQValueActor` -modules and those should be preferred in general as they are easier to create -and use. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - QValueHook - DistributionalQValueHook - -Models ------- -.. currentmodule:: torchrl.modules - -TorchRL provides a series of useful "regular" (ie non-tensordict) nn.Module -classes for RL usage. - -Regular modules -~~~~~~~~~~~~~~~ - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - BatchRenorm1d - ConsistentDropout - Conv3dNet - ConvNet - MLP - Squeeze2dLayer - SqueezeLayer - -Algorithm-specific modules -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -These networks implement sub-networks that have shown to be useful for specific -algorithms, such as DQN, DDPG or Dreamer. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DTActor - DdpgCnnActor - DdpgCnnQNet - DdpgMlpActor - DdpgMlpQNet - DecisionTransformer - DistributionalDQNnet - DreamerActor - DuelingCnnDQNet - GRUCell - GRU - GRUModule - LSTMCell - LSTM - LSTMModule - ObsDecoder - ObsEncoder - OnlineDTActor - RSSMPosterior - RSSMPrior - set_recurrent_mode - recurrent_mode - -Multi-agent-specific modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -These networks implement models that can be used in multi-agent contexts. -They use :func:`~torch.vmap` to execute multiple networks all at once on the -network inputs. Because the parameters are batched, initialization may differ -from what is usually done with other PyTorch modules, see -:meth:`~torchrl.modules.MultiAgentNetBase.get_stateful_net` -for more information. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - MultiAgentNetBase - MultiAgentMLP - MultiAgentConvNet - QMixer - VDNMixer - - -Exploration ------------ -.. currentmodule:: torchrl.modules - -Noisy linear layers are a popular way of exploring the environment without -altering the actions, but by integrating the stochasticity in the weight -configuration. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - NoisyLinear - NoisyLazyLinear - reset_noise - - -Planners --------- -.. currentmodule:: torchrl.modules - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - CEMPlanner - MPCPlannerBase - MPPIPlanner - - -Distributions +Quick Example ------------- -.. currentmodule:: torchrl.modules - -Some distributions are typically used in RL scripts. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - Delta - IndependentNormal - TanhNormal - TruncatedNormal - TanhDelta - OneHotCategorical - LLMMaskedCategorical - MaskedCategorical - MaskedOneHotCategorical - Ordinal - OneHotOrdinal - -Utils ------ -.. 
currentmodule:: torchrl.modules.utils - -The module utils include functionals used to do some custom mappings as well as a tool to -build :class:`~torchrl.envs.TensorDictPrimer` instances from a given module. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - mappings - inv_softplus - biased_softplus - get_primers_from_module - -.. currentmodule:: torchrl.modules - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - VmapModule +.. code-block:: python + + from torchrl.modules import ProbabilisticActor, TanhNormal + from torchrl.envs import GymEnv + from tensordict.nn import TensorDictModule + import torch.nn as nn + + env = GymEnv("Pendulum-v1") + + # Create a probabilistic actor + actor = ProbabilisticActor( + module=TensorDictModule( + nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 2)), + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + distribution_class=TanhNormal, + spec=env.action_spec, + ) + +Documentation Sections +---------------------- + +.. toctree:: + :maxdepth: 2 + + modules_actors + modules_exploration + modules_critics + modules_models + modules_distributions + modules_utils diff --git a/docs/source/reference/modules_actors.rst b/docs/source/reference/modules_actors.rst new file mode 100644 index 00000000000..afbf90ba702 --- /dev/null +++ b/docs/source/reference/modules_actors.rst @@ -0,0 +1,47 @@ +.. currentmodule:: torchrl.modules + +Actor Modules +============= + +Actor modules represent policies in RL. They map observations to actions, either deterministically +or stochastically. + +TensorDictModules and SafeModules +--------------------------------- + +.. currentmodule:: torchrl.modules.tensordict_module + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + Actor + MultiStepActorWrapper + SafeModule + SafeSequential + TanhModule + +Probabilistic actors +-------------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ProbabilisticActor + SafeProbabilisticModule + SafeProbabilisticTensorDictSequential + +Q-Value actors +-------------- + +.. currentmodule:: torchrl.modules + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + QValueActor + DistributionalQValueActor + QValueModule + DistributionalQValueModule diff --git a/docs/source/reference/modules_critics.rst b/docs/source/reference/modules_critics.rst new file mode 100644 index 00000000000..ba2f7110bf5 --- /dev/null +++ b/docs/source/reference/modules_critics.rst @@ -0,0 +1,25 @@ +.. currentmodule:: torchrl.modules + +Value Networks and Critics +========================== + +Value networks estimate the value of states or state-action pairs. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ValueOperator + DuelingCnnDQNet + DistributionalDQNnet + ConvNet + MLP + DdpgCnnActor + DdpgCnnQNet + DdpgMlpActor + DdpgMlpQNet + LSTMModule + GRUModule + OnlineDTActor + DTActor + DecisionTransformer diff --git a/docs/source/reference/modules_distributions.rst b/docs/source/reference/modules_distributions.rst new file mode 100644 index 00000000000..65ba204cee0 --- /dev/null +++ b/docs/source/reference/modules_distributions.rst @@ -0,0 +1,20 @@ +.. currentmodule:: torchrl.modules + +Distribution Classes +==================== + +Custom distribution classes for RL, extending PyTorch distributions. + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + Delta + IndependentNormal + MaskedCategorical + NormalParamExtractor + OneHotCategorical + ReparamGradientStrategy + TanhDelta + TanhNormal + TruncatedNormal diff --git a/docs/source/reference/modules_exploration.rst b/docs/source/reference/modules_exploration.rst new file mode 100644 index 00000000000..01b25ef05cf --- /dev/null +++ b/docs/source/reference/modules_exploration.rst @@ -0,0 +1,15 @@ +.. currentmodule:: torchrl.modules + +Exploration Strategies +====================== + +Exploration modules add noise to actions to enable exploration during training. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + AdditiveGaussianModule + ConsistentDropoutModule + EGreedyModule + OrnsteinUhlenbeckProcessModule diff --git a/docs/source/reference/modules_models.rst b/docs/source/reference/modules_models.rst new file mode 100644 index 00000000000..be3e74ef0c7 --- /dev/null +++ b/docs/source/reference/modules_models.rst @@ -0,0 +1,18 @@ +.. currentmodule:: torchrl.modules + +World Models and Model-Based RL +=============================== + +Modules for model-based reinforcement learning, including world models and dynamics models. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + WorldModelWrapper + DreamerActor + ObsEncoder + ObsDecoder + RSSMPosterior + RSSMPrior + RSSMRollout diff --git a/docs/source/reference/modules_utils.rst b/docs/source/reference/modules_utils.rst new file mode 100644 index 00000000000..92ec06b2645 --- /dev/null +++ b/docs/source/reference/modules_utils.rst @@ -0,0 +1,22 @@ +.. currentmodule:: torchrl.modules + +Utilities and Helpers +===================== + +Utility modules and helper functions for building RL networks. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ActorValueOperator + ActorCriticOperator + ActorCriticWrapper + +.. currentmodule:: torchrl.modules.models.utils + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + SquashDims diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index f2741809bd3..08c832a4cc3 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -5,413 +5,49 @@ torchrl.objectives package .. _ref_objectives: -TorchRL provides a series of losses to use in your training scripts. -The aim is to have losses that are easily reusable/swappable and that have -a simple signature. - -The main characteristics of TorchRL losses are: - -- they are stateful objects: they contain a copy of the trainable parameters - such that ``loss_module.parameters()`` gives whatever is needed to train the - algorithm. -- They follow the ``tensordict`` convention: the :meth:`torch.nn.Module.forward` - method will receive a tensordict as input that contains all the necessary - information to return a loss value. -- They output a :class:`tensordict.TensorDict` instance with the loss values - written under a ``"loss_"`` where ``smth`` is a string describing the - loss. Additional keys in the tensordict may be useful metrics to log during - training time. - -.. note:: - The reason we return independent losses is to let the user use a different - optimizer for different sets of parameters for instance. Summing the losses - can be simply done via - - >>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_")) - -.. 
note:: - Initializing parameters in losses can be done via a query to :meth:`~torchrl.objectives.LossModule.get_stateful_net` - which will return a stateful version of the network that can be initialized like any other module. - If the modification is done in-place, it will be downstreamed to any other module that uses the same parameter - set (within and outside of the loss): for instance, modifying the ``actor_network`` parameters from the loss - will also modify the actor in the collector. - If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be - used to reset the parameters in the loss to the new value. - -torch.vmap and randomness -------------------------- - -TorchRL loss modules have plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models -in a loop, and instead vectorize these operations. `vmap` needs to be told explicitly what to do when random numbers -need to be generated within the call. To do this, a randomness mode need to be set and must be one of `"error"` (default, -errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"` -(each element of the batch is treated separately). -Relying on the default will typically result in an error such as this one: - - >>> RuntimeError: vmap: called random operation while in randomness error mode. - -Since the calls to `vmap` are buried down the loss modules, TorchRL -provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see -:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information. - -``LossModule.vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in -other cases. By default, only a limited number of modules are listed as random, but the list can be extended -using the :func:`~torchrl.objectives.common.add_random_module` function. - -Training value functions ------------------------- - -TorchRL provides a range of **value estimators** such as TD(0), TD(1), TD(:math:`\lambda`) -and GAE. -In a nutshell, a value estimator is a function of data (mostly -rewards and done states) and a state value (ie. the value -returned by a function that is fit to estimate state-values). -To learn more about value estimators, check the introduction to RL from `Sutton -and Barto `_, -in particular the chapters about value iteration and TD learning. -It gives a somewhat biased estimation of the discounted return following a state -or a state-action pair based on data and proxy maps. These estimators are -used in two contexts: - -- To train the value network to learn the "true" state value (or state-action - value) map, one needs a target value to fit it to. The better (less bias, - less variance) the estimator, the better the value network will be, which in - turn can speed up the policy training significantly. Typically, the value - network loss will look like: - - >>> value = value_network(states) - >>> target_value = value_estimator(rewards, done, value_network(next_state)) - >>> value_net_loss = (value - target_value).pow(2).mean() - -- Computing an "advantage" signal for policy-optimization. The advantage is - the delta between the value estimate (from the estimator, ie from "real" data) - and the output of the value network (ie the proxy to this value). 
A positive - advantage can be seen as a signal that the policy actually performed better - than expected, thereby signaling that there is room for improvement if that - trajectory is to be taken as example. Conversely, a negative advantage signifies - that the policy underperformed compared to what was to be expected. - -Thins are not always as easy as in the example above and the formula to compute -the value estimator or the advantage may be slightly more intricate than this. -To help users flexibly use one or another value estimator, we provide a simple -API to change it on-the-fly. Here is an example with DQN, but all modules will -follow a similar structure: - - >>> from torchrl.objectives import DQNLoss, ValueEstimators - >>> loss_module = DQNLoss(actor) - >>> kwargs = {"gamma": 0.9, "lmbda": 0.9} - >>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs) - -The :class:`~torchrl.objectives.ValueEstimators` class enumerates the value -estimators to choose from. This makes it easy for the users to rely on -auto-completion to make their choice. - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - LossModule - add_random_module - -DQN ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DQNLoss - DistributionalDQNLoss - -DDPG ----- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DDPGLoss - -SAC ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - SACLoss - DiscreteSACLoss - -REDQ ----- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - REDQLoss - -CrossQ ------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - CrossQLoss - -IQL ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - IQLLoss - DiscreteIQLLoss - -CQL ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - CQLLoss - DiscreteCQLLoss - -GAIL ----- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - GAILLoss - -DT --- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DTLoss - OnlineDTLoss - -TD3 ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - TD3Loss - -TD3+BC ------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - TD3BCLoss - -PPO ---- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - PPOLoss - ClipPPOLoss - KLPENPPOLoss - -Using PPO with multi-head action policies -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. note:: The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`, - :class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`. - When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the - beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not - be aggregated but rather written as leaves in the tensordict. - -In some cases, we have a single advantage value but more than one action undertaken. Each action has its own -log-probability, and shape. For instance, it can be that the action space is structured as follows: - - >>> action_td = TensorDict( - ... agents=TensorDict( - ... 
action0=Tensor(batch, n_agents, f0), - ... action1=Tensor(batch, n_agents, f1, f2), - ... batch_size=torch.Size((batch, n_agents)) - ... ), - ... batch_size=torch.Size((batch,)) - ... ) - -where `f0`, `f1` and `f2` are some arbitrary integers. - -Note that, in TorchRL, the root tensordict has the shape of the environment (if the environment is batch-locked, otherwise it -has the shape of the number of batched environments being run). If the tensordict is sampled from the buffer, it will -also have the shape of the replay buffer `batch_size`. The `n_agent` dimension, although common to each action, does not -in general appear in the root tensordict's batch-size (although it appears in the sub-tensordict containing the -agent-specific data according to the :ref:`MARL API `). - -There is a legitimate reason why this is the case: the number of agent may condition some but not all the specs of the -environment. For example, some environments have a shared done state among all agents. A more complete tensordict -would in this case look like - - >>> action_td = TensorDict( - ... agents=TensorDict( - ... action0=Tensor(batch, n_agents, f0), - ... action1=Tensor(batch, n_agents, f1, f2), - ... observation=Tensor(batch, n_agents, f3), - ... batch_size=torch.Size((batch, n_agents)) - ... ), - ... done=Tensor(batch, 1), - ... [...] # etc - ... batch_size=torch.Size((batch,)) - ... ) - -Notice that `done` states and `reward` are usually flanked by a rightmost singleton dimension. See this :ref:`part of the doc ` -to learn more about this restriction. - -The log-probability of our actions given their respective distributions may look like anything like - - >>> action_td = TensorDict( - ... agents=TensorDict( - ... action0_log_prob=Tensor(batch, n_agents), - ... action1_log_prob=Tensor(batch, n_agents, f1), - ... batch_size=torch.Size((batch, n_agents)) - ... ), - ... batch_size=torch.Size((batch,)) - ... ) - -or - - >>> action_td = TensorDict( - ... agents=TensorDict( - ... action0_log_prob=Tensor(batch, n_agents), - ... action1_log_prob=Tensor(batch, n_agents), - ... batch_size=torch.Size((batch, n_agents)) - ... ), - ... batch_size=torch.Size((batch,)) - ... ) - -ie, the number of dimensions of distributions log-probabilities generally varies from the sample's dimensionality to -anything inferior to that, e.g. if the distribution is multivariate -- :class:`~torch.distributions.Dirichlet` for -instance -- or an :class:`~torch.distributions.Independent` instance. -The dimension of the tensordict, on the contrary, still matches the env's / replay-buffer's batch-size. - -During a call to the PPO loss, the loss module will schematically execute the following set of operations: - - >>> def ppo(tensordict): - ... prev_log_prob = tensordict.select(*log_prob_keys) - ... action = tensordict.select(*action_keys) - ... new_log_prob = dist.log_prob(action) - ... log_weight = new_log_prob - prev_log_prob - ... advantage = tensordict.get("advantage") # computed by GAE earlier - ... # attempt to map shape - ... log_weight.batch_size = advantage.batch_size[:-1] - ... log_weight = sum(log_weight.sum(dim="feature").values(True, True)) # get a single tensor of log_weights - ... return minimum(log_weight.exp() * advantage, log_weight.exp().clamp(1-eps, 1+eps) * advantage) - -To appreciate what a PPO pipeline looks like with multihead policies, an example can be found in the library's -`example directory `__. - - -A2C ---- - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - A2CLoss - -Reinforce ---------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - ReinforceLoss - -Dreamer -------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - DreamerActorLoss - DreamerModelLoss - DreamerValueLoss - -Multi-agent objectives +TorchRL provides a comprehensive collection of loss modules for reinforcement learning algorithms. +These losses are designed to be stateful, reusable, and follow the tensordict convention. + +Key Features +------------ + +- **Stateful objects**: Contain trainable parameters accessible via ``loss_module.parameters()`` +- **TensorDict convention**: Input and output use TensorDict format +- **Structured output**: Loss values returned with ``"loss_"`` keys +- **Value estimators**: Support for TD(0), TD(λ), GAE, and more +- **Vmap support**: Efficient batched operations with customizable randomness modes + +Quick Example +------------- + +.. code-block:: python + + from torchrl.objectives import DDPGLoss + from torchrl.modules import Actor, ValueOperator + + # Create loss module + loss = DDPGLoss( + actor_network=actor, + value_network=value, + gamma=0.99, + ) + + # Compute loss + td = collector.rollout() + loss_vals = loss(td) + + # Get total loss + total_loss = sum(v for k, v in loss_vals.items() if k.startswith("loss_")) + +Documentation Sections ---------------------- -.. currentmodule:: torchrl.objectives.multiagent - -These objectives are specific to multi-agent algorithms. - -QMixer -~~~~~~ - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - QMixerLoss - - -Returns -------- - -.. _ref_returns: - -.. currentmodule:: torchrl.objectives.value - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - - ValueEstimatorBase - TD0Estimator - TD1Estimator - TDLambdaEstimator - GAE - functional.td0_return_estimate - functional.td0_advantage_estimate - functional.td1_return_estimate - functional.vec_td1_return_estimate - functional.td1_advantage_estimate - functional.vec_td1_advantage_estimate - functional.td_lambda_return_estimate - functional.vec_td_lambda_return_estimate - functional.td_lambda_advantage_estimate - functional.vec_td_lambda_advantage_estimate - functional.generalized_advantage_estimate - functional.vec_generalized_advantage_estimate - functional.reward2go - - -Utils ------ - -.. currentmodule:: torchrl.objectives - -.. autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst +.. toctree:: + :maxdepth: 2 - HardUpdate - SoftUpdate - ValueEstimators - default_value_kwargs - distance_loss - group_optimizers - hold_out_net - hold_out_params - next_state_value + objectives_common + objectives_value + objectives_policy + objectives_actorcritic + objectives_offline + objectives_other diff --git a/docs/source/reference/objectives_actorcritic.rst b/docs/source/reference/objectives_actorcritic.rst new file mode 100644 index 00000000000..7ba0690f185 --- /dev/null +++ b/docs/source/reference/objectives_actorcritic.rst @@ -0,0 +1,17 @@ +.. currentmodule:: torchrl.objectives + +Actor-Critic Methods +==================== + +Loss modules for actor-critic algorithms. + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DDPGLoss + SACLoss + DiscreteSACLoss + TD3Loss + REDQLoss + CrossQLoss diff --git a/docs/source/reference/objectives_common.rst b/docs/source/reference/objectives_common.rst new file mode 100644 index 00000000000..9ed8197715a --- /dev/null +++ b/docs/source/reference/objectives_common.rst @@ -0,0 +1,36 @@ +.. currentmodule:: torchrl.objectives + +Common Components +================= + +Base classes and common utilities for all loss modules. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + LossModule + add_random_module + +Value Estimators +---------------- + +.. currentmodule:: torchrl.objectives.value + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ValueEstimatorBase + TD0Estimator + TD1Estimator + TDLambdaEstimator + GAE + +.. currentmodule:: torchrl.objectives + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ValueEstimators diff --git a/docs/source/reference/objectives_offline.rst b/docs/source/reference/objectives_offline.rst new file mode 100644 index 00000000000..44a160cc4d2 --- /dev/null +++ b/docs/source/reference/objectives_offline.rst @@ -0,0 +1,16 @@ +.. currentmodule:: torchrl.objectives + +Offline RL Methods +================== + +Loss modules for offline reinforcement learning. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + CQLLoss + DiscreteCQLLoss + IQLLoss + DiscreteIQLLoss + TD3BCLoss diff --git a/docs/source/reference/objectives_other.rst b/docs/source/reference/objectives_other.rst new file mode 100644 index 00000000000..018268ed7f6 --- /dev/null +++ b/docs/source/reference/objectives_other.rst @@ -0,0 +1,17 @@ +.. currentmodule:: torchrl.objectives + +Other Loss Modules +================== + +Additional loss modules for specialized algorithms. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + GAILLoss + DTLoss + OnlineDTLoss + DreamerActorLoss + DreamerModelLoss + DreamerValueLoss diff --git a/docs/source/reference/objectives_policy.rst b/docs/source/reference/objectives_policy.rst new file mode 100644 index 00000000000..5a50bbd9c3a --- /dev/null +++ b/docs/source/reference/objectives_policy.rst @@ -0,0 +1,16 @@ +.. currentmodule:: torchrl.objectives + +Policy Gradient Methods +======================= + +Loss modules for policy gradient algorithms. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + PPOLoss + ClipPPOLoss + KLPENPPOLoss + A2CLoss + ReinforceLoss diff --git a/docs/source/reference/objectives_value.rst b/docs/source/reference/objectives_value.rst new file mode 100644 index 00000000000..63f9a37c394 --- /dev/null +++ b/docs/source/reference/objectives_value.rst @@ -0,0 +1,17 @@ +.. currentmodule:: torchrl.objectives + +Value-Based Methods +=================== + +Loss modules for value-based RL algorithms. + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DQNLoss + DistributionalDQNLoss + IQLLoss + DiscreteIQLLoss + CQLLoss + DiscreteCQLLoss diff --git a/docs/source/reference/services.rst b/docs/source/reference/services.rst index 72d26fd3c54..37b4e752267 100644 --- a/docs/source/reference/services.rst +++ b/docs/source/reference/services.rst @@ -337,7 +337,7 @@ You can easily switch between using a shared service or local processes: ) Conditional Usage Pattern -~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~ You can decide at runtime whether to use a distributed service based on your configuration: @@ -509,7 +509,6 @@ Service Registry :template: rl_template.rst get_services - reset_services ServiceBase RayService diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index c47436d11a8..adb98e78445 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -7,417 +7,45 @@ torchrl.trainers package The trainer package provides utilities to write re-usable training scripts. The core idea is to use a trainer that implements a nested loop, where the outer loop runs the data collection steps and the inner -loop the optimization steps. We believe this fits multiple RL training schemes, such as -on-policy, off-policy, model-based and model-free solutions, offline RL and others. -More particular cases, such as meta-RL algorithms may have training schemes that differ substantially. +loop the optimization steps. -The ``trainer.train()`` method can be sketched as follows: +Key Features +------------ -.. code-block:: - :caption: Trainer loops +- **Modular hook system**: Customize training at 10 different points in the loop +- **Checkpointing support**: Save and restore training state with torch or torchsnapshot +- **Algorithm trainers**: High-level trainers for PPO, SAC with Hydra configuration +- **Builder helpers**: Utilities for constructing collectors, losses, and replay buffers - >>> for batch in collector: - ... batch = self._process_batch_hook(batch) # "batch_process" - ... self._pre_steps_log_hook(batch) # "pre_steps_log" - ... self._pre_optim_hook() # "pre_optim_steps" - ... for j in range(self.optim_steps_per_batch): - ... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch" - ... losses = self.loss_module(sub_batch) - ... self._post_loss_hook(sub_batch) # "post_loss" - ... self.optimizer.step() - ... self.optimizer.zero_grad() - ... self._post_optim_hook() # "post_optim" - ... self._post_optim_log(sub_batch) # "post_optim_log" - ... self._post_steps_hook() # "post_steps" - ... self._post_steps_log_hook(batch) # "post_steps_log" - - There are 10 hooks that can be used in a trainer loop: - - >>> for batch in collector: - ... batch = self._process_batch_hook(batch) # "batch_process" - ... self._pre_steps_log_hook(batch) # "pre_steps_log" - ... self._pre_optim_hook() # "pre_optim_steps" - ... for j in range(self.optim_steps_per_batch): - ... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch" - ... losses = self.loss_module(sub_batch) - ... self._post_loss_hook(sub_batch) # "post_loss" - ... self.optimizer.step() - ... self.optimizer.zero_grad() - ... self._post_optim_hook() # "post_optim" - ... self._post_optim_log(sub_batch) # "post_optim_log" - ... self._post_steps_hook() # "post_steps" - ... self._post_steps_log_hook(batch) # "post_steps_log" - - There are 10 hooks that can be used in a trainer loop: - - >>> for batch in collector: - ... 
batch = self._process_batch_hook(batch) # "batch_process" - ... self._pre_steps_log_hook(batch) # "pre_steps_log" - ... self._pre_optim_hook() # "pre_optim_steps" - ... for j in range(self.optim_steps_per_batch): - ... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch" - ... losses = self.loss_module(sub_batch) - ... self._post_loss_hook(sub_batch) # "post_loss" - ... self.optimizer.step() - ... self.optimizer.zero_grad() - ... self._post_optim_hook() # "post_optim" - ... self._post_optim_log(sub_batch) # "post_optim_log" - ... self._post_steps_hook() # "post_steps" - ... self._post_steps_log_hook(batch) # "post_steps_log" - -There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``, -``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``, -``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied. -Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``), -**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operations** hook -(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``). - -- **Data processing** hooks update a tensordict of data. Hooks ``__call__`` method should accept - a ``TensorDict`` object as input and update it given some strategy. - Examples of such hooks include Replay Buffer extension (``ReplayBufferTrainer.extend``), data normalization (including normalization - constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such. - -- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger - some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward - logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the - data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value - should be displayed on the progression bar printed on the training log. - -- **Operation** hooks are hooks that execute specific operations over the models, data collectors, - target network updates and such. For instance, syncing the weights of the collectors using ``UpdateWeights`` - or update the priority of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples - of operation hooks. They are data-independent (they do not require a ``TensorDict`` - input), they are just supposed to be executed once at every iteration (or every N iterations). - -The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``, -and all implement three base methods: a ``state_dict`` and ``load_state_dict`` method for -checkpointing and a ``register`` method that registers the hook at the default value in the -trainer. This method takes a trainer and a module name as input. For instance, the following logging -hook is executed every 10 calls to ``"post_optim_log"``: - -.. code-block:: - - >>> class LoggingHook(TrainerHookBase): - ... def __init__(self): - ... self.counter = 0 - ... - ... def register(self, trainer, name): - ... trainer.register_module(self, "logging_hook") - ... trainer.register_op("post_optim_log", self) - ... - ... def save_dict(self): - ... return {"counter": self.counter} - ... - ... def load_state_dict(self, state_dict): - ... self.counter = state_dict["counter"] - ... - ... 
def __call__(self, batch): - ... if self.counter % 10 == 0: - ... self.counter += 1 - ... out = {"some_value": batch["some_value"].item(), "log_pbar": False} - ... else: - ... out = None - ... self.counter += 1 - ... return out - -Checkpointing +Quick Example ------------- -The trainer class and hooks support checkpointing, which can be achieved either -using the `torchsnapshot `_ backend or -the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``: - -.. code-block:: - - $ CKPT_BACKEND=torchsnapshot python script.py - -``CKPT_BACKEND`` defaults to ``torch``. The advantage of torchsnapshot over pytorch -is that it is a more flexible API, which supports distributed checkpointing and -also allows users to load tensors from a file stored on disk to a tensor with a -physical storage (which pytorch currently does not support). This allows, for instance, -to load tensors from and to a replay buffer that would otherwise not fit in memory. - -When building a trainer, one can provide a path where the checkpoints are to -be written. With the ``torchsnapshot`` backend, a directory path is expected, -whereas the ``torch`` backend expects a file path (typically a ``.pt`` file). - -.. code-block:: - - >>> filepath = "path/to/dir/or/file" - >>> trainer = Trainer( - ... collector=collector, - ... total_frames=total_frames, - ... frame_skip=frame_skip, - ... loss_module=loss_module, - ... optimizer=optimizer, - ... save_trainer_file=filepath, - ... ) - >>> select_keys = SelectKeys(["action", "observation"]) - >>> select_keys.register(trainer) - >>> # to save to a path - >>> trainer.save_trainer(True) - >>> # to load from a path - >>> trainer.load_from_file(filepath) - -The ``Trainer.train()`` method can be used to execute the above loop with all of -its hooks, although using the :obj:`Trainer` class for its checkpointing capability -only is also a perfectly valid use. - - -Trainer and hooks ------------------ - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - BatchSubSampler - ClearCudaCache - CountFramesLog - LogScalar - OptimizerHook - LogValidationReward - ReplayBufferTrainer - RewardNormalizer - SelectKeys - Trainer - TrainerHookBase - UpdateWeights - TargetNetUpdaterHook - UTDRHook - - -Algorithm-specific trainers (Experimental) ------------------------------------------- - -.. warning:: - The following trainers are experimental/prototype features. The API may change in future versions. - Please report any issues or feedback to help improve these implementations. - -TorchRL provides high-level, algorithm-specific trainers that combine the modular components -into complete training solutions with sensible defaults and comprehensive configuration options. - -.. currentmodule:: torchrl.trainers.algorithms - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - PPOTrainer - SACTrainer - -Algorithm Trainers -~~~~~~~~~~~~~~~~~~ - -TorchRL provides high-level algorithm trainers that offer complete training solutions with minimal code. -These trainers feature comprehensive configuration systems built on Hydra, enabling both simple usage -and sophisticated customization. 
- -**Currently Available:** - -- :class:`~torchrl.trainers.algorithms.PPOTrainer` - Proximal Policy Optimization -- :class:`~torchrl.trainers.algorithms.SACTrainer` - Soft Actor-Critic - -**Key Features:** - -- **Complete pipeline**: Environment setup, data collection, and optimization -- **Hydra configuration**: Extensive dataclass-based configuration system -- **Built-in logging**: Rewards, actions, and algorithm-specific metrics -- **Modular design**: Built on existing TorchRL components -- **Minimal code**: Complete SOTA implementations in ~20 lines! - -.. warning:: - Algorithm trainers are experimental features. The API may change in future versions. - We welcome feedback and contributions to help improve these implementations! - -Quick Start Examples -^^^^^^^^^^^^^^^^^^^^ - -**PPO Training:** - -.. code-block:: bash - - # Train PPO on Pendulum-v1 with default settings - python sota-implementations/ppo_trainer/train.py - -**SAC Training:** - -.. code-block:: bash - - # Train SAC on a continuous control task - python sota-implementations/sac_trainer/train.py - -**Custom Configuration:** - -.. code-block:: bash - - # Override parameters for any algorithm - python sota-implementations/ppo_trainer/train.py \ - trainer.total_frames=2000000 \ - training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \ - networks.policy_network.num_cells=[256,256] \ - optimizer.lr=0.0003 - -**Environment Switching:** - -.. code-block:: bash - - # Switch environment and logger for any trainer - python sota-implementations/sac_trainer/train.py \ - training_env.create_env_fn.base_env.env_name=Walker2d-v4 \ - logger=tensorboard \ - logger.exp_name=sac_walker2d - -**View Configuration Options:** - -.. code-block:: bash - - # See all available options for any trainer - python sota-implementations/ppo_trainer/train.py --help - python sota-implementations/sac_trainer/train.py --help - -Universal Configuration System -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -All algorithm trainers share a unified configuration architecture organized into logical groups: - -- **Environment**: ``training_env.create_env_fn.base_env.env_name``, ``training_env.num_workers`` -- **Networks**: ``networks.policy_network.num_cells``, ``networks.value_network.num_cells`` -- **Training**: ``trainer.total_frames``, ``trainer.clip_norm``, ``optimizer.lr`` -- **Data**: ``collector.frames_per_batch``, ``replay_buffer.batch_size``, ``replay_buffer.storage.max_size`` -- **Logging**: ``logger.exp_name``, ``logger.project``, ``trainer.log_interval`` - -**Working Example:** - -All trainer implementations follow the same simple pattern: - .. code-block:: python - import hydra - from torchrl.trainers.algorithms.configs import * - - @hydra.main(config_path="config", config_name="config", version_base="1.1") - def main(cfg): - trainer = hydra.utils.instantiate(cfg.trainer) - trainer.train() - - if __name__ == "__main__": - main() - -*Complete algorithm training with full configurability in ~20 lines!* - -Configuration Classes -^^^^^^^^^^^^^^^^^^^^^ - -The trainer system uses a hierarchical configuration system with shared components. - -.. note:: - The configuration system requires Python 3.10+ due to its use of modern type annotation syntax. 
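+
+The same configuration stack can also be composed programmatically, for instance from a
+notebook or a test, rather than through ``@hydra.main``. The sketch below only relies on
+public Hydra entry points (``initialize``, ``compose`` and ``hydra.utils.instantiate``);
+the override keys are borrowed from the examples above, while the ``config_path`` is an
+assumption (the ``config`` folder next to ``sota-implementations/ppo_trainer/train.py``)
+and may need adjusting to your checkout layout:
+
+.. code-block:: python
+
+    from hydra import compose, initialize
+    from hydra.utils import instantiate
+
+    from torchrl.trainers.algorithms.configs import *  # registers the config dataclasses
+
+    # Compose the PPO trainer config and override a few fields, then build the trainer
+    with initialize(config_path="sota-implementations/ppo_trainer/config", version_base="1.1"):
+        cfg = compose(
+            config_name="config",
+            overrides=["trainer.total_frames=100000", "optimizer.lr=0.0003"],
+        )
+
+    trainer = instantiate(cfg.trainer)
+    trainer.train()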
- -**Algorithm-Specific Trainers:** - -- **PPO**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig` -- **SAC**: :class:`~torchrl.trainers.algorithms.configs.trainers.SACTrainerConfig` - -**Shared Configuration Components:** - -- **Environment**: :class:`~torchrl.trainers.algorithms.configs.envs_libs.GymEnvConfig`, :class:`~torchrl.trainers.algorithms.configs.envs.BatchedEnvConfig` -- **Networks**: :class:`~torchrl.trainers.algorithms.configs.modules.MLPConfig`, :class:`~torchrl.trainers.algorithms.configs.modules.TanhNormalModelConfig` -- **Data**: :class:`~torchrl.trainers.algorithms.configs.data.TensorDictReplayBufferConfig`, :class:`~torchrl.trainers.algorithms.configs.collectors.MultiaSyncDataCollectorConfig` -- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`, :class:`~torchrl.trainers.algorithms.configs.objectives.SACLossConfig` -- **Optimizers**: :class:`~torchrl.trainers.algorithms.configs.utils.AdamConfig`, :class:`~torchrl.trainers.algorithms.configs.utils.AdamWConfig` -- **Logging**: :class:`~torchrl.trainers.algorithms.configs.logging.WandbLoggerConfig`, :class:`~torchrl.trainers.algorithms.configs.logging.TensorboardLoggerConfig` - -Algorithm-Specific Features -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -**PPOTrainer:** - -- On-policy learning with advantage estimation -- Policy clipping and value function optimization -- Configurable number of epochs per batch -- Built-in GAE (Generalized Advantage Estimation) - -**SACTrainer:** - -- Off-policy learning with replay buffer -- Entropy-regularized policy optimization -- Target network soft updates -- Continuous action space optimization - -**Future Development:** - -The trainer system is actively expanding. Upcoming features include: - -- Additional algorithms: TD3, DQN, A2C, DDPG, and more -- Enhanced distributed training support -- Advanced configuration validation and error reporting -- Integration with more TorchRL ecosystem components - -See the complete `configuration system documentation `_ for all available options and examples. - - -Builders --------- - -.. currentmodule:: torchrl.trainers.helpers - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - make_collector_offpolicy - make_collector_onpolicy - make_dqn_loss - make_replay_buffer - make_target_updater - make_trainer - parallel_env_constructor - sync_async_collector - sync_sync_collector - transformed_env_constructor - -Utils ------ - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - correct_for_frame_skip - get_stats_random_rollout - -Loggers -------- - -.. _ref_loggers: - -.. currentmodule:: torchrl.record.loggers - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - Logger - csv.CSVLogger - mlflow.MLFlowLogger - tensorboard.TensorboardLogger - wandb.WandbLogger - get_logger - generate_exp_name - - -Recording utils ---------------- - -Recording utils are detailed :ref:`here `. - -.. currentmodule:: torchrl.record - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - VideoRecorder - TensorDictRecorder - PixelRenderTransform + from torchrl.trainers import Trainer + from torchrl.trainers import UpdateWeights, LogScalar + + # Create trainer + trainer = Trainer( + collector=collector, + total_frames=1000000, + loss_module=loss, + optimizer=optimizer, + ) + + # Register hooks + UpdateWeights(collector, 10).register(trainer) + LogScalar("reward").register(trainer) + + # Train + trainer.train() + +Documentation Sections +---------------------- + +.. toctree:: + :maxdepth: 2 + + trainers_basics + trainers_loggers + trainers_hooks diff --git a/docs/source/reference/trainers_basics.rst b/docs/source/reference/trainers_basics.rst new file mode 100644 index 00000000000..217ad92b2ec --- /dev/null +++ b/docs/source/reference/trainers_basics.rst @@ -0,0 +1,58 @@ +.. currentmodule:: torchrl.trainers + +Trainer Basics +============== + +Core trainer classes and builder utilities. + +Trainer and hooks +----------------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Trainer + TrainerHookBase + +Algorithm-specific trainers +--------------------------- + +.. currentmodule:: torchrl.trainers.algorithms + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + PPOTrainer + SACTrainer + +Builders +-------- + +.. currentmodule:: torchrl.trainers.helpers + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + make_collector_offpolicy + make_collector_onpolicy + make_dqn_loss + make_replay_buffer + make_target_updater + make_trainer + parallel_env_constructor + sync_async_collector + sync_sync_collector + transformed_env_constructor + +Utils +----- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + correct_for_frame_skip + get_stats_random_rollout diff --git a/docs/source/reference/trainers_hooks.rst b/docs/source/reference/trainers_hooks.rst new file mode 100644 index 00000000000..2a6ed8ba7e8 --- /dev/null +++ b/docs/source/reference/trainers_hooks.rst @@ -0,0 +1,23 @@ +.. currentmodule:: torchrl.trainers + +Training Hooks +============== + +Hooks for customizing the training loop at various points. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + BatchSubSampler + ClearCudaCache + CountFramesLog + LogScalar + OptimizerHook + LogValidationReward + ReplayBufferTrainer + RewardNormalizer + SelectKeys + UpdateWeights + TargetNetUpdaterHook + UTDRHook diff --git a/docs/source/reference/trainers_loggers.rst b/docs/source/reference/trainers_loggers.rst new file mode 100644 index 00000000000..d1bf25c48ba --- /dev/null +++ b/docs/source/reference/trainers_loggers.rst @@ -0,0 +1,33 @@ +.. currentmodule:: torchrl.record.loggers + +.. _ref_loggers: + +Loggers +======= + +Logger classes for experiment tracking and visualization. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + Logger + csv.CSVLogger + mlflow.MLFlowLogger + tensorboard.TensorboardLogger + wandb.WandbLogger + get_logger + generate_exp_name + +Recording utils +--------------- + +.. currentmodule:: torchrl.record + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + VideoRecorder + TensorDictRecorder + PixelRenderTransform diff --git a/scripts/check-sphinx-section-underline b/scripts/check-sphinx-section-underline new file mode 100755 index 00000000000..9d9de2524c9 --- /dev/null +++ b/scripts/check-sphinx-section-underline @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +"""Check that Sphinx/ReST section underlines match title lengths.""" +import re +import sys +from pathlib import Path + +# Pattern to match Sphinx section titles followed by underlines +# Supports: = - ` : . ' " ~ ^ _ * + # < > +SECTION_PATTERN = re.compile(r"^([^\n]+)\n([~=\-^`':\"#\*_\+.<>]+)\n", re.MULTILINE) + + +def fix_file(path): + """Fix underline length mismatches in a file.""" + try: + text = Path(path).read_text(encoding="utf-8") + except Exception as e: + print(f"Warning: Could not read {path}: {e}") + return False + + original_text = text + fixed_count = 0 + + def replace_underline(match): + nonlocal fixed_count + title, underline = match.groups() + title_stripped = title.strip() + underline_stripped = underline.strip() + + # Skip if title is empty or looks like it might be a code block or other content + if not title_stripped or title_stripped.startswith(".."): + return match.group(0) + + if len(title_stripped) != len(underline_stripped): + # Get the underline character and create correct length underline + underline_char = underline_stripped[0] + correct_underline = underline_char * len(title_stripped) + fixed_count += 1 + return f"{title}\n{correct_underline}\n" + + return match.group(0) + + text = SECTION_PATTERN.sub(replace_underline, text) + + if text != original_text: + Path(path).write_text(text, encoding="utf-8") + return fixed_count + + return 0 + + +def check_file(path): + """Check a single file for underline length mismatches.""" + try: + text = Path(path).read_text(encoding="utf-8") + except Exception as e: + print(f"Warning: Could not read {path}: {e}") + return [] + + errors = [] + + for match in SECTION_PATTERN.finditer(text): + title, underline = match.groups() + title_stripped = title.strip() + underline_stripped = underline.strip() + + # Skip if title is empty or looks like it might be a code block or other content + if not title_stripped or title_stripped.startswith(".."): + continue + + if len(title_stripped) != len(underline_stripped): + # Calculate line number + line_num = text.count("\n", 0, match.start()) + 1 + errors.append( + f"{path}:{line_num}: " + f"title '{title_stripped}' length {len(title_stripped)}, " + f"underline length {len(underline_stripped)}" + ) + + return errors + + +def main(argv): + """Main entry point for the hook.""" + # Check for --fix flag + fix_mode = "--fix" in argv + if fix_mode: + argv = [arg for arg in argv if arg != "--fix"] + + if len(argv) < 2: + print("✅ Sphinx section underline check: no files to check.") + sys.exit(0) + + if fix_mode: + total_fixed = 0 + for path in argv[1:]: + fixed_count = fix_file(path) + if fixed_count: + print(f"✏️ Fixed {fixed_count} section(s) in {path}") + total_fixed += fixed_count + + if total_fixed: + print(f"\n✅ Fixed {total_fixed} section underline(s) total.") + sys.exit(0) + else: + print("✅ Sphinx section underline check: no fixes needed.") + sys.exit(0) + else: + all_errors = [] + for path in argv[1:]: + all_errors.extend(check_file(path)) + + if all_errors: + print("❌ Sphinx section underline length errors:\n") + for e in all_errors: + print(" ", e) + print("\nFix underline lengths to match title text.") + 
print("Or run with --fix flag to automatically fix them:") + print( + f" python3 scripts/check-sphinx-section-underline --fix {' '.join(argv[1:])}" + ) + sys.exit(1) + else: + print("✅ Sphinx section underline check passed.") + sys.exit(0) + + +if __name__ == "__main__": + main(sys.argv) diff --git a/setup.cfg b/setup.cfg index 2c7173763fb..525b66f6394 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ per-file-ignores = torchrl/objectives/td3.py: TOR101 torchrl/objectives/value/advantages.py: TOR101 tutorials/*/**.py: T001, T201 + scripts/*: T001, T201 examples/*.py: T001, T201 packaging/verify_nightly_version.py: T001, T201 test/opengl_rendering.py: T001, T201 diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 34b35da9446..42d13108a0f 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -222,7 +222,7 @@ def _send_buffer_to_workers( # Wait for acknowledgments from all workers for pipe in self._pipes: if not pipe.poll(timeout): - raise TimeoutError(f"Timeout waiting for acknowledgment from worker") + raise TimeoutError("Timeout waiting for acknowledgment from worker") _, msg = pipe.recv() if msg != "registered": raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'")