From 2b7945c0cb0e021d79844e5dffc867ec1bd383c0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 5 Jan 2023 16:22:29 +0000 Subject: [PATCH 1/5] init --- docs/source/index.rst | 2 +- docs/source/reference/collectors.rst | 3 +- docs/source/reference/data.rst | 3 ++ docs/source/reference/envs.rst | 31 ++++++++++++++++-- docs/source/reference/modules.rst | 27 ++++++++++------ docs/source/reference/objectives.rst | 20 ++++++++++++ docs/source/reference/trainers.rst | 48 ++++++++++++++++++---------- torchrl/modules/planners/__init__.py | 1 + torchrl/trainers/loggers/__init__.py | 1 + 9 files changed, 105 insertions(+), 31 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index c7550d79e80..480848ce6fa 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -40,7 +40,7 @@ TorchRL aims at a high modularity and good runtime performance. tutorials/coding_dqn .. toctree:: - :maxdepth: 2 + :maxdepth: 3 :caption: References: reference/index diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index b1eef7305bd..14a9fac8b59 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -60,7 +60,7 @@ Besides those compute parameters, users may choose to configure the following pa Data collectors --------------- - +.. currentmodule:: torchrl.collectors.collectors .. autosummary:: :toctree: generated/ @@ -68,6 +68,7 @@ Data collectors MultiSyncDataCollector MultiaSyncDataCollector + RandomPolicy SyncDataCollector aSyncDataCollector diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 330f96c725a..d709abf6cf3 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -64,6 +64,9 @@ TensorSpec The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such as shape, device, dtype and domain. 
+It is important that your environment specs match the input and output that it sends and receives, as +:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawned processes. +Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. .. autosummary:: diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 4fa10ffc581..33a9c79d158 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -60,6 +60,8 @@ With these, the following methods are implemented: EnvBase GymLikeEnv + EnvMetaData + Specs Vectorized envs --------------- @@ -75,11 +77,16 @@ environments in parallel. As this class inherits from :obj:`EnvBase`, it enjoys the exact same API as other environment. Of course, a :obj:`ParallelEnv` will have a batch size that corresponds to its environment count: +It is important that your environment specs match the input and output that it sends and receives, as +:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawned processes. +Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. + .. code-block:: :caption: Parallel environment >>> def make_env(): ... return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0") + >>> check_env_specs(make_env()) # this must pass for ParallelEnv to work >>> env = ParallelEnv(4, make_env) >>> print(env.batch_size) torch.Size([4]) @@ -135,6 +142,7 @@ behaviour of a :obj:`ParallelEnv` without launching the subprocesses. SerialEnv ParallelEnv + EnvCreator Transforms @@ -224,6 +232,21 @@ in the environment. The keys to be included in this inverse transform are passed VIPTransform VIPRewardTransform +Recorders +--------- + +.. currentmodule:: torchrl.record + +Recorders are transforms that register data as they come in, for logging purposes. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + TensorDictRecorder + VideoRecorder + + Helpers ------- .. 
currentmodule:: torchrl.envs.utils @@ -258,10 +281,12 @@ Libraries :toctree: generated/ :template: rl_template_fun.rst - gym.GymEnv - gym.GymWrapper + brax.BraxEnv + brax.BraxWrapper dm_control.DMControlEnv dm_control.DMControlWrapper + gym.GymEnv + gym.GymWrapper + habitat.HabitatEnv jumanji.JumanjiEnv jumanji.JumanjiWrapper - habitat.HabitatEnv diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 3de40fb1e20..face9fdf0c3 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -5,24 +5,28 @@ torchrl.modules package TensorDict modules ------------------ - +.. currentmodule:: torchrl.modules.tensordict_module .. autosummary:: :toctree: generated/ :template: rl_template_noinherit.rst - SafeModule - SafeProbabilisticModule - SafeSequential - SafeProbabilisticSequential Actor - ProbabilisticActor - ValueOperator - QValueActor - DistributionalQValueActor - ActorValueOperator ActorCriticOperator ActorCriticWrapper + ActorValueOperator + DistributionalQValueActor + ProbabilisticActor + QValueActor + ValueOperator + SafeModule + AdditiveGaussianWrapper + EGreedyWrapper + OrnsteinUhlenbeckProcessWrapper + SafeProbabilisticModule + SafeProbabilisticSequential + SafeSequential + WorldModelWrapper tensordict_module.common.is_tensordict_compatible tensordict_module.common.ensure_tensordict_compatible @@ -84,6 +88,7 @@ Planners CEMPlanner MPCPlannerBase + MPPIPlanner Distributions @@ -93,6 +98,8 @@ Distributions :template: rl_template_noinherit.rst Delta + IndependentNormal + NormalParamWrapper TanhNormal TruncatedNormal TanhDelta diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index b7dcb436d31..20f5de4b923 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -60,6 +60,26 @@ A2C A2CLoss +Reinforce +--------- + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + ReinforceLoss + +Dreamer +------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DreamerActorLoss + DreamerModelLoss + DreamerValueLoss + Returns ------- diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index d6e411dc2f1..596009157bf 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -135,8 +135,8 @@ Trainer and hooks :toctree: generated/ :template: rl_template.rst - Trainer BatchSubSampler + ClearCudaCache CountFramesLog LogReward OptimizerHook @@ -144,8 +144,9 @@ Trainer and hooks ReplayBuffer RewardNormalizer SelectKeys + Trainer + TrainerHookBase UpdateWeights - ClearCudaCache Builders @@ -157,27 +158,27 @@ Builders :toctree: generated/ :template: rl_template_fun.rst - make_trainer - sync_sync_collector - sync_async_collector + make_a2c_loss + make_a2c_model make_collector_offpolicy make_collector_onpolicy - transformed_env_constructor - parallel_env_constructor - make_sac_loss - make_a2c_loss - make_dqn_loss + make_ddpg_actor make_ddpg_loss - make_target_updater - make_ppo_loss - make_redq_loss make_dqn_actor - make_ddpg_actor + make_dqn_loss + make_ppo_loss make_ppo_model - make_a2c_model - make_sac_model + make_redq_loss make_redq_model make_replay_buffer + make_sac_loss + make_sac_model + make_target_updater + make_trainer + parallel_env_constructor + sync_async_collector + sync_sync_collector + transformed_env_constructor Utils ----- @@ -188,3 +189,18 @@ Utils correct_for_frame_skip get_stats_random_rollout + +Loggers +------- + +.. currentmodule:: torchrl.trainers.loggers + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + Logger + CSVLogger + MLFlowLogger + TensorboardLogger + WandbLogger diff --git a/torchrl/modules/planners/__init__.py b/torchrl/modules/planners/__init__.py index ab6c72595f0..56c0e48bc65 100644 --- a/torchrl/modules/planners/__init__.py +++ b/torchrl/modules/planners/__init__.py @@ -5,3 +5,4 @@ from .cem import CEMPlanner from .common import MPCPlannerBase +from .mppi import MPPIPlanner diff --git a/torchrl/trainers/loggers/__init__.py b/torchrl/trainers/loggers/__init__.py index 6db613dad47..87558181125 100644 --- a/torchrl/trainers/loggers/__init__.py +++ b/torchrl/trainers/loggers/__init__.py @@ -5,5 +5,6 @@ from .common import Logger from .csv import CSVLogger +from .mlflow import MLFlowLogger from .tensorboard import TensorboardLogger from .wandb import WandbLogger From eb5bd2676f8a127195895dcd5adf24e24c04c40e Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 5 Jan 2023 17:10:34 +0000 Subject: [PATCH 2/5] missing headers --- benchmarks/storage/benchmark_sample_latency_over_rpc.py | 5 +++++ docs/source/reference/trainers.rst | 2 ++ test/opengl_rendering.py | 5 +++++ test/smoke_test.py | 6 ++++++ test/smoke_test_deps.py | 5 +++++ test/test_actors.py | 5 +++++ test/test_loggers.py | 5 +++++ test/test_rb_distributed.py | 5 +++++ torchrl/data/replay_buffers/rb_prototype.py | 5 +++++ torchrl/data/replay_buffers/samplers.py | 5 +++++ torchrl/data/replay_buffers/writers.py | 5 +++++ torchrl/envs/libs/brax.py | 5 +++++ torchrl/envs/libs/habitat.py | 5 +++++ torchrl/envs/libs/jax_utils.py | 5 +++++ torchrl/envs/libs/jumanji.py | 5 +++++ torchrl/modules/__init__.py | 2 +- torchrl/objectives/__init__.py | 1 + torchrl/objectives/deprecated.py | 5 +++++ torchrl/objectives/reinforce.py | 5 +++++ torchrl/trainers/__init__.py | 2 -- torchrl/trainers/loggers/__init__.py | 1 + torchrl/trainers/loggers/utils.py | 6 ++++++ 22 files changed, 92 insertions(+), 3 deletions(-) diff --git 
a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index d922095de5f..6bd1619419c 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + """ Sample latency benchmarking (using RPC) ====================================== diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 596009157bf..6be11a5497e 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -204,3 +204,5 @@ Loggers MLFlowLogger TensorboardLogger WandbLogger + get_logger + generate_exp_name diff --git a/test/opengl_rendering.py b/test/opengl_rendering.py index c5be1f5bad3..7533e298069 100644 --- a/test/opengl_rendering.py +++ b/test/opengl_rendering.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + """Headless GPU-accelerated OpenGL context creation on Google Colaboratory. Typical usage: # Optional PyOpenGL configuratiopn can be done here. diff --git a/test/smoke_test.py b/test/smoke_test.py index f0db69def86..313c786088c 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + def test_imports(): from torchrl.data import ( PrioritizedReplayBuffer, diff --git a/test/smoke_test_deps.py b/test/smoke_test_deps.py index 56463039bf4..71657cc9c69 100644 --- a/test/smoke_test_deps.py +++ b/test/smoke_test_deps.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import argparse import tempfile diff --git a/test/test_actors.py b/test/test_actors.py index 9fdf8bc3882..0d9de574e0d 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import pytest import torch from torchrl.modules.tensordict_module.actors import ( diff --git a/test/test_loggers.py b/test/test_loggers.py index 1f343a71dc7..ee02c0bbd94 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import argparse import os import os.path diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 6b3b8482705..15e594fa78c 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import os import time diff --git a/torchrl/data/replay_buffers/rb_prototype.py b/torchrl/data/replay_buffers/rb_prototype.py index 8534bba46b1..1817ae7d67a 100644 --- a/torchrl/data/replay_buffers/rb_prototype.py +++ b/torchrl/data/replay_buffers/rb_prototype.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import collections import threading from concurrent.futures import ThreadPoolExecutor diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 93dd50142a3..1dafaa31a98 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from abc import ABC, abstractmethod from typing import Any, Tuple, Union diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index f058dd32f2d..75a6aa5d971 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from abc import ABC, abstractmethod from typing import Any, Sequence diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index a1fbca2079b..529742a328a 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from typing import Dict, Optional, Union import torch diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 347e00fafd3..834836b6110 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from torchrl.envs.libs.gym import GymEnv try: diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 1319b5cf77b..0819e2fe90c 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import dataclasses from typing import Union diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 95a213b96a3..99e99b76326 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from typing import Dict, Optional, Union import numpy as np diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 61923f9dc7d..ab56a37d250 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -52,4 +52,4 @@ ValueOperator, WorldModelWrapper, ) -from .planners import CEMPlanner, MPCPlannerBase # usort:skip +from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 2cf6a1e8eef..eb173efb976 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -10,6 +10,7 @@ from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss from .redq import REDQLoss +from .reinforce import ReinforceLoss from .sac import SACLoss from .utils import ( distance_loss, diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 69d09fc3fe0..9f493525b8b 
100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import math from numbers import Number from typing import Tuple, Union diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index f7330d7d52b..d7a37ac82b7 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from typing import Optional import torch diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 035cdc60b27..507c914d1e7 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -18,5 +18,3 @@ Trainer, UpdateWeights, ) - -# from .loggers import * diff --git a/torchrl/trainers/loggers/__init__.py b/torchrl/trainers/loggers/__init__.py index 87558181125..fc3c27ace75 100644 --- a/torchrl/trainers/loggers/__init__.py +++ b/torchrl/trainers/loggers/__init__.py @@ -7,4 +7,5 @@ from .csv import CSVLogger from .mlflow import MLFlowLogger from .tensorboard import TensorboardLogger +from .utils import generate_exp_name, get_logger from .wandb import WandbLogger diff --git a/torchrl/trainers/loggers/utils.py b/torchrl/trainers/loggers/utils.py index d709a9150c1..476942c5479 100644 --- a/torchrl/trainers/loggers/utils.py +++ b/torchrl/trainers/loggers/utils.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + import os import pathlib import uuid From 94700ef31f97e599e44540b8d352ddcdeeddffa8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 5 Jan 2023 17:23:55 +0000 Subject: [PATCH 3/5] amend --- torchrl/objectives/reinforce.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index d7a37ac82b7..ee7959556c3 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -8,8 +8,8 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl.modules import SafeModule, SafeProbabilisticSequential -from torchrl.objectives import distance_loss +from torchrl.modules.tensordict_module import SafeModule, SafeProbabilisticSequential +from torchrl.objectives.utils import distance_loss from torchrl.objectives.common import LossModule From cfa29188e346ee13ad56098fafc29c4e63dedd8a Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 5 Jan 2023 17:38:08 +0000 Subject: [PATCH 4/5] amend --- docs/source/index.rst | 8 ++++++-- torchrl/objectives/reinforce.py | 2 +- tutorials/sphinx-tutorials/README.rst | 6 ++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 480848ce6fa..98583aa643b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -27,9 +27,11 @@ for :doc:`cost functions `, :ref:`returns Date: Thu, 5 Jan 2023 17:47:14 +0000 Subject: [PATCH 5/5] amend --- docs/source/reference/modules.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index face9fdf0c3..e86046ab9e9 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -27,8 +27,8 @@ TensorDict modules SafeProbabilisticSequential SafeSequential WorldModelWrapper - tensordict_module.common.is_tensordict_compatible - tensordict_module.common.ensure_tensordict_compatible + 
common.is_tensordict_compatible + common.ensure_tensordict_compatible Hooks -----