diff --git a/docs/source/index.rst b/docs/source/index.rst
index c7550d79e80..480848ce6fa 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -40,7 +40,7 @@ TorchRL aims at a high modularity and good runtime performance.
     tutorials/coding_dqn

 .. toctree::
-    :maxdepth: 2
+    :maxdepth: 3
     :caption: References:

     reference/index
diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst
index b1eef7305bd..14a9fac8b59 100644
--- a/docs/source/reference/collectors.rst
+++ b/docs/source/reference/collectors.rst
@@ -60,7 +60,7 @@ Besides those compute parameters, users may choose to configure the following pa

 Data collectors
 ---------------
-
+.. currentmodule:: torchrl.collectors.collectors
 .. autosummary::
     :toctree: generated/

@@ -68,6 +68,7 @@ Data collectors

     MultiSyncDataCollector
     MultiaSyncDataCollector
+    RandomPolicy
     SyncDataCollector
     aSyncDataCollector

diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index 330f96c725a..d709abf6cf3 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -64,6 +64,9 @@ TensorSpec

 The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in
 TorchRL, such as shape, device, dtype and domain.
+It is important that your environment's specs match the inputs and outputs it actually sends and receives, as
+:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
+Use the :obj:`torchrl.envs.utils.check_env_specs` function as a sanity check.


 .. autosummary::
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 4fa10ffc581..33a9c79d158 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -60,6 +60,8 @@ With these, the following methods are implemented:

     EnvBase
     GymLikeEnv
+    EnvMetaData
+    Specs

 Vectorized envs
 ---------------
@@ -75,11 +77,16 @@ environments in parallel.
 As this class inherits from :obj:`EnvBase`, it enjoys the exact same API as other environment.
 Of course, a :obj:`ParallelEnv` will have a batch size that corresponds to its environment count:
+It is important that your environment's specs match the inputs and outputs it actually sends and receives, as
+:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
+Use the :obj:`torchrl.envs.utils.check_env_specs` function as a sanity check.
+

 .. code-block::
     :caption: Parallel environment

     >>> def make_env():
     ...     return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0")
+    >>> check_env_specs(make_env())  # this must pass for ParallelEnv to work
     >>> env = ParallelEnv(4, make_env)
     >>> print(env.batch_size)
     torch.Size([4])
@@ -135,6 +142,7 @@ behaviour of a :obj:`ParallelEnv` without launching the subprocesses.

     SerialEnv
     ParallelEnv
+    EnvCreator


 Transforms
@@ -224,6 +232,21 @@ in the environment. The keys to be included in this inverse transform are passed
     VIPTransform
     VIPRewardTransform

+Recorders
+---------
+
+.. currentmodule:: torchrl.record
+
+Recorders are transforms that register data as it comes in, for logging purposes.
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_fun.rst
+
+    TensorDictRecorder
+    VideoRecorder
+
+
 Helpers
 -------
 .. currentmodule:: torchrl.envs.utils
@@ -258,10 +281,12 @@ Libraries
     :toctree: generated/
     :template: rl_template_fun.rst

-    gym.GymEnv
-    gym.GymWrapper
+    brax.BraxEnv
+    brax.BraxWrapper
     dm_control.DMControlEnv
     dm_control.DMControlWrapper
+    gym.GymEnv
+    gym.GymWrapper
+    habitat.HabitatEnv
     jumanji.JumanjiEnv
     jumanji.JumanjiWrapper
-    habitat.HabitatEnv
diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index 3de40fb1e20..face9fdf0c3 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -5,24 +5,28 @@ torchrl.modules package

 TensorDict modules
 ------------------
-
+.. currentmodule:: torchrl.modules.tensordict_module
 .. autosummary::
     :toctree: generated/
     :template: rl_template_noinherit.rst

-    SafeModule
-    SafeProbabilisticModule
-    SafeSequential
-    SafeProbabilisticSequential
     Actor
-    ProbabilisticActor
-    ValueOperator
-    QValueActor
-    DistributionalQValueActor
-    ActorValueOperator
     ActorCriticOperator
     ActorCriticWrapper
+    ActorValueOperator
+    DistributionalQValueActor
+    ProbabilisticActor
+    QValueActor
+    ValueOperator
+    SafeModule
+    AdditiveGaussianWrapper
+    EGreedyWrapper
+    OrnsteinUhlenbeckProcessWrapper
+    SafeProbabilisticModule
+    SafeProbabilisticSequential
+    SafeSequential
+    WorldModelWrapper

     tensordict_module.common.is_tensordict_compatible
     tensordict_module.common.ensure_tensordict_compatible

@@ -84,6 +88,7 @@ Planners

     CEMPlanner
     MPCPlannerBase
+    MPPIPlanner


 Distributions
@@ -93,6 +98,8 @@ Distributions
     :template: rl_template_noinherit.rst

     Delta
+    IndependentNormal
+    NormalParamWrapper
     TanhNormal
     TruncatedNormal
     TanhDelta
diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst
index b7dcb436d31..20f5de4b923 100644
--- a/docs/source/reference/objectives.rst
+++ b/docs/source/reference/objectives.rst
@@ -60,6 +60,26 @@ A2C

     A2CLoss

+Reinforce
+---------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    ReinforceLoss
+
+Dreamer
+-------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    DreamerActorLoss
+    DreamerModelLoss
+    DreamerValueLoss
+
 Returns
 -------
diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst
index d6e411dc2f1..596009157bf 100644
--- a/docs/source/reference/trainers.rst
+++ b/docs/source/reference/trainers.rst
@@ -135,8 +135,8 @@ Trainer and hooks
     :toctree: generated/
     :template: rl_template.rst

-    Trainer
     BatchSubSampler
+    ClearCudaCache
     CountFramesLog
     LogReward
     OptimizerHook
@@ -144,8 +144,9 @@ Trainer and hooks
     ReplayBuffer
     RewardNormalizer
     SelectKeys
+    Trainer
+    TrainerHookBase
     UpdateWeights
-    ClearCudaCache


 Builders
@@ -157,27 +158,27 @@ Builders
     :toctree: generated/
     :template: rl_template_fun.rst

-    make_trainer
-    sync_sync_collector
-    sync_async_collector
+    make_a2c_loss
+    make_a2c_model
     make_collector_offpolicy
     make_collector_onpolicy
-    transformed_env_constructor
-    parallel_env_constructor
-    make_sac_loss
-    make_a2c_loss
-    make_dqn_loss
+    make_ddpg_actor
     make_ddpg_loss
-    make_target_updater
-    make_ppo_loss
-    make_redq_loss
     make_dqn_actor
-    make_ddpg_actor
+    make_dqn_loss
+    make_ppo_loss
     make_ppo_model
-    make_a2c_model
-    make_sac_model
+    make_redq_loss
     make_redq_model
     make_replay_buffer
+    make_sac_loss
+    make_sac_model
+    make_target_updater
+    make_trainer
+    parallel_env_constructor
+    sync_async_collector
+    sync_sync_collector
+    transformed_env_constructor

 Utils
 -----
@@ -188,3 +189,18 @@ Utils

     correct_for_frame_skip
     get_stats_random_rollout
+
+Loggers
+-------
+
+.. currentmodule:: torchrl.trainers.loggers
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_fun.rst
+
+    Logger
+    CSVLogger
+    MLFlowLogger
+    TensorboardLogger
+    WandbLogger
diff --git a/torchrl/modules/planners/__init__.py b/torchrl/modules/planners/__init__.py
index ab6c72595f0..56c0e48bc65 100644
--- a/torchrl/modules/planners/__init__.py
+++ b/torchrl/modules/planners/__init__.py
@@ -5,3 +5,4 @@

 from .cem import CEMPlanner
 from .common import MPCPlannerBase
+from .mppi import MPPIPlanner
diff --git a/torchrl/trainers/loggers/__init__.py b/torchrl/trainers/loggers/__init__.py
index 6db613dad47..87558181125 100644
--- a/torchrl/trainers/loggers/__init__.py
+++ b/torchrl/trainers/loggers/__init__.py
@@ -5,5 +5,6 @@

 from .common import Logger
 from .csv import CSVLogger
+from .mlflow import MLFlowLogger
 from .tensorboard import TensorboardLogger
 from .wandb import WandbLogger
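The spec-matching note added to ``data.rst`` and ``envs.rst`` above can be exercised end to end with a short, self-contained sketch. This is an illustrative example rather than part of the patch: it uses the same ``Pendulum-v1`` environment as the documented snippet (here without ``from_pixels``/``device`` so it runs on CPU), with ``check_env_specs``, ``GymEnv`` and ``ParallelEnv`` imported from ``torchrl.envs.utils``, ``torchrl.envs.libs.gym`` and ``torchrl.envs`` respectively:

    >>> from torchrl.envs import ParallelEnv
    >>> from torchrl.envs.libs.gym import GymEnv
    >>> from torchrl.envs.utils import check_env_specs
    >>> def make_env():
    ...     return GymEnv("Pendulum-v1")
    >>> # validate the single-env specs first: ParallelEnv builds its
    >>> # inter-process buffers from them, so any mismatch should fail here
    >>> check_env_specs(make_env())
    >>> env = ParallelEnv(4, make_env)
    >>> print(env.batch_size)
    torch.Size([4])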
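The Recorders entry above describes transforms that hand data to a logger, and the new Loggers reference lists the logger backends themselves. A hypothetical sketch of the logger side follows; it assumes ``CSVLogger`` accepts an ``exp_name`` (plus an optional ``log_dir``) and exposes the ``log_scalar(name, value, step=...)`` method of the :obj:`Logger` base class, which should be verified against the actual signatures:

    >>> from torchrl.trainers.loggers import CSVLogger
    >>> # assumed constructor: CSVLogger(exp_name=..., log_dir=...)
    >>> logger = CSVLogger(exp_name="demo", log_dir="./logs")
    >>> for step in range(3):
    ...     # assumed Logger API: log_scalar(name, value, step=...)
    ...     logger.log_scalar("train/reward", 0.1 * step, step=step)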