Merged
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -40,7 +40,7 @@ TorchRL aims at a high modularity and good runtime performance.
tutorials/coding_dqn

.. toctree::
- :maxdepth: 2
+ :maxdepth: 3
:caption: References:

reference/index
3 changes: 2 additions & 1 deletion docs/source/reference/collectors.rst
@@ -60,14 +60,15 @@ Besides those compute parameters, users may choose to configure the following pa

Data collectors
---------------

.. currentmodule:: torchrl.collectors.collectors

.. autosummary::
:toctree: generated/
:template: rl_template.rst

MultiSyncDataCollector
MultiaSyncDataCollector
RandomPolicy
SyncDataCollector
aSyncDataCollector

3 changes: 3 additions & 0 deletions docs/source/reference/data.rst
@@ -64,6 +64,9 @@ TensorSpec

The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such
as shape, device, dtype and domain.
Your environment specs must match the inputs the environment receives and the outputs it sends, because
:obj:`ParallelEnv` creates buffers from these specs to communicate with the spawned processes.
Use the :obj:`torchrl.envs.utils.check_env_specs` function as a sanity check.


.. autosummary::
31 changes: 28 additions & 3 deletions docs/source/reference/envs.rst
@@ -60,6 +60,8 @@ With these, the following methods are implemented:

EnvBase
GymLikeEnv
EnvMetaData
Specs

Vectorized envs
---------------
@@ -75,11 +77,16 @@ environments in parallel.
As this class inherits from :obj:`EnvBase`, it exposes exactly the same API as other environments.
Of course, a :obj:`ParallelEnv` has a batch size that corresponds to its environment count.

Your environment specs must match the inputs the environment receives and the outputs it sends, because
:obj:`ParallelEnv` creates buffers from these specs to communicate with the spawned processes.
Run :obj:`torchrl.envs.utils.check_env_specs` as a sanity check, as in the following example:

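.. code-block::
  :caption: Parallel environment

    >>> from torchrl.envs import ParallelEnv
    >>> from torchrl.envs.libs.gym import GymEnv
    >>> from torchrl.envs.utils import check_env_specs
    >>> def make_env():
    ...     return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0")
    >>> check_env_specs(make_env())  # this must pass for ParallelEnv to work
    >>> env = ParallelEnv(4, make_env)
    >>> print(env.batch_size)
    torch.Size([4])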
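
The reason the spec check matters can be illustrated without TorchRL. The sketch below is not TorchRL code and does not reproduce how :obj:`check_env_specs` is implemented: ``ToySpec``, ``allocate_buffers`` and ``find_spec_mismatches`` are hypothetical names, and a spec is reduced to a shape and a dtype name. It only assumes what the text above states, namely that buffers are created from the declared specs, so any output that disagrees with the declaration cannot be written into them safely.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical stand-in for a tensor spec: only a shape and a dtype name."""

    shape: Tuple[int, ...]
    dtype: str


def allocate_buffers(specs: Dict[str, ToySpec], num_envs: int) -> Dict[str, ToySpec]:
    """Describe the buffers a parallel runner would preallocate from the specs.

    Each buffer gets a leading batch dimension of size ``num_envs``.
    """
    return {
        key: ToySpec((num_envs, *spec.shape), spec.dtype)
        for key, spec in specs.items()
    }


def find_spec_mismatches(
    specs: Dict[str, ToySpec], sample: Dict[str, ToySpec]
) -> List[str]:
    """List every way one environment output disagrees with its declared specs."""
    problems = []
    for key, spec in specs.items():
        if key not in sample:
            problems.append(f"{key}: declared but missing from the output")
        elif sample[key] != spec:
            problems.append(f"{key}: declared {spec}, got {sample[key]}")
    for key in sample:
        if key not in specs:
            problems.append(f"{key}: present in the output but not declared")
    return problems


declared = {
    "observation": ToySpec((3,), "float32"),
    "reward": ToySpec((1,), "float32"),
}
# One output that honours the declaration, and one whose dtype has drifted.
good_output = dict(declared)
bad_output = {**declared, "observation": ToySpec((3,), "float64")}

buffers = allocate_buffers(declared, num_envs=4)
print(buffers["observation"].shape)
print(find_spec_mismatches(declared, good_output))
print(find_spec_mismatches(declared, bad_output))
```

Running the comparison once in the main process, before any worker is spawned, is the design point: a mismatch found there produces a readable message, whereas the same mismatch discovered while writing into a preallocated buffer is much harder to trace back to the environment that caused it.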
@@ -135,6 +142,7 @@ behaviour of a :obj:`ParallelEnv` without launching the subprocesses.

SerialEnv
ParallelEnv
EnvCreator


Transforms
@@ -224,6 +232,21 @@ in the environment. The keys to be included in this inverse transform are passed
VIPTransform
VIPRewardTransform

Recorders
---------

.. currentmodule:: torchrl.record

Recorders are transforms that register data as they come in, for logging purposes.

.. autosummary::
:toctree: generated/
:template: rl_template_fun.rst

TensorDictRecorder
VideoRecorder


Helpers
-------
.. currentmodule:: torchrl.envs.utils
@@ -258,10 +281,12 @@ Libraries
:toctree: generated/
:template: rl_template_fun.rst

gym.GymEnv
gym.GymWrapper
brax.BraxEnv
brax.BraxWrapper
dm_control.DMControlEnv
dm_control.DMControlWrapper
gym.GymEnv
gym.GymWrapper
habitat.HabitatEnv
jumanji.JumanjiEnv
jumanji.JumanjiWrapper
habitat.HabitatEnv
27 changes: 17 additions & 10 deletions docs/source/reference/modules.rst
@@ -5,24 +5,28 @@ torchrl.modules package

TensorDict modules
------------------

.. currentmodule:: torchrl.modules.tensordict_module

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

SafeModule
SafeProbabilisticModule
SafeSequential
SafeProbabilisticSequential
Actor
ProbabilisticActor
ValueOperator
QValueActor
DistributionalQValueActor
ActorValueOperator
ActorCriticOperator
ActorCriticWrapper
ActorValueOperator
DistributionalQValueActor
ProbabilisticActor
QValueActor
ValueOperator
SafeModule
AdditiveGaussianWrapper
EGreedyWrapper
OrnsteinUhlenbeckProcessWrapper
SafeProbabilisticModule
SafeProbabilisticSequential
SafeSequential
WorldModelWrapper
tensordict_module.common.is_tensordict_compatible
tensordict_module.common.ensure_tensordict_compatible

@@ -84,6 +88,7 @@ Planners

CEMPlanner
MPCPlannerBase
MPPIPlanner


Distributions
@@ -93,6 +98,8 @@ Distributions
:template: rl_template_noinherit.rst

Delta
IndependentNormal
NormalParamWrapper
TanhNormal
TruncatedNormal
TanhDelta
20 changes: 20 additions & 0 deletions docs/source/reference/objectives.rst
@@ -60,6 +60,26 @@ A2C

A2CLoss

Reinforce
---------

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

ReinforceLoss

Dreamer
-------

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

DreamerActorLoss
DreamerModelLoss
DreamerValueLoss


Returns
-------
48 changes: 32 additions & 16 deletions docs/source/reference/trainers.rst
@@ -135,17 +135,18 @@ Trainer and hooks
:toctree: generated/
:template: rl_template.rst

Trainer
BatchSubSampler
ClearCudaCache
CountFramesLog
LogReward
OptimizerHook
Recorder
ReplayBuffer
RewardNormalizer
SelectKeys
Trainer
TrainerHookBase
UpdateWeights
ClearCudaCache


Builders
@@ -157,27 +158,27 @@ Builders
:toctree: generated/
:template: rl_template_fun.rst

make_trainer
sync_sync_collector
sync_async_collector
make_a2c_loss
make_a2c_model
make_collector_offpolicy
make_collector_onpolicy
transformed_env_constructor
parallel_env_constructor
make_sac_loss
make_a2c_loss
make_dqn_loss
make_ddpg_actor
make_ddpg_loss
make_target_updater
make_ppo_loss
make_redq_loss
make_dqn_actor
make_ddpg_actor
make_dqn_loss
make_ppo_loss
make_ppo_model
make_a2c_model
make_sac_model
make_redq_loss
make_redq_model
make_replay_buffer
make_sac_loss
make_sac_model
make_target_updater
make_trainer
parallel_env_constructor
sync_async_collector
sync_sync_collector
transformed_env_constructor

Utils
-----
@@ -188,3 +189,18 @@ Utils

correct_for_frame_skip
get_stats_random_rollout

Loggers
-------

.. currentmodule:: torchrl.trainers.loggers

.. autosummary::
:toctree: generated/
:template: rl_template_fun.rst

Logger
CSVLogger
MLFlowLogger
TensorboardLogger
WandbLogger
1 change: 1 addition & 0 deletions torchrl/modules/planners/__init__.py
@@ -5,3 +5,4 @@

from .cem import CEMPlanner
from .common import MPCPlannerBase
from .mppi import MPPIPlanner
1 change: 1 addition & 0 deletions torchrl/trainers/loggers/__init__.py
@@ -5,5 +5,6 @@

from .common import Logger
from .csv import CSVLogger
from .mlflow import MLFlowLogger
from .tensorboard import TensorboardLogger
from .wandb import WandbLogger