From c900f95f5683ebc7224eef27585a01425bca58f0 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 24 Apr 2024 13:45:27 +0100
Subject: [PATCH 01/26] init

---
 .github/workflows/docs.yml | 47 +++++++++++++++++++-------------------
 1 file changed, 23 insertions(+), 24 deletions(-)

diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index bcf6b5066c3..296aef263c9 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -93,7 +93,7 @@ jobs:
   upload:
     needs: build-docs
     if: github.repository == 'pytorch/rl' && github.event_name == 'push' &&
-      ((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag')
+      ((github.ref_type == 'branch' && github.ref_name == 'main') || (github.ref_type == 'branch' && github.ref_name == 'release-doc') || github.ref_type == 'tag')
     permissions:
       contents: write
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
@@ -108,23 +108,22 @@ jobs:
           REF_TYPE=${{ github.ref_type }}
           REF_NAME=${{ github.ref_name }}
 
-          # TODO: adopt this behaviour
-          # if [[ "${REF_TYPE}" == branch ]]; then
-          #   TARGET_FOLDER="${REF_NAME}"
-          # elif [[ "${REF_TYPE}" == tag ]]; then
-          #   case "${REF_NAME}" in
-          #     *-rc*)
-          #       echo "Aborting upload since this is an RC tag: ${REF_NAME}"
-          #       exit 0
-          #       ;;
-          #     *)
-          #       # Strip the leading "v" as well as the trailing patch version.
-          #       # 'v0.15.2' -> '0.15'
-          #       TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/')
-          #       ;;
-          #   esac
-          # fi
-          TARGET_FOLDER="./"
+          if [[ "${REF_TYPE}" == branch ]]; then
+            TARGET_FOLDER="${REF_NAME}"
+          elif [[ "${REF_TYPE}" == tag ]]; then
+            case "${REF_NAME}" in
+              *-rc*)
+                echo "Aborting upload since this is an RC tag: ${REF_NAME}"
+                exit 0
+                ;;
+              *)
+                # Strip the leading "v" as well as the trailing patch version. For example:
+                # 'v0.15.2' -> '0.15'
+                TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/')
+                ;;
+            esac
+          fi
 
           echo "Target Folder: ${TARGET_FOLDER}"
 
           # mkdir -p "${TARGET_FOLDER}"
@@ -133,12 +132,12 @@ jobs:
           rsync -a "${RUNNER_ARTIFACT_DIR}"/ "${TARGET_FOLDER}"
           git add "${TARGET_FOLDER}" || true
 
-          # if [[ "${TARGET_FOLDER}" == main ]]; then
-          #   mkdir -p _static
-          #   rm -rf _static/*
-          #   cp -r "${TARGET_FOLDER}"/_static/* _static
-          #   git add _static || true
-          # fi
+          if [[ "${TARGET_FOLDER}" == main ]]; then
+            mkdir -p _static
+            rm -rf _static/*
+            cp -r "${TARGET_FOLDER}"/_static/* _static
+            git add _static || true
+          fi
 
           git config user.name 'pytorchbot'
           git config user.email 'soumith+bot@pytorch.org'

From 91097a4ca767b4e2fad367431b82bc426ed6c78b Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 24 Apr 2024 13:48:37 +0100
Subject: [PATCH 02/26] amend

---
 docs/source/conf.py   |  2 +-
 docs/source/index.rst | 34 ----------------------------------
 2 files changed, 1 insertion(+), 35 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 060103b48b4..bec034dbd00 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -186,7 +186,7 @@
 )
 generate_knowledge_base_references("../../knowledge_base")
 
-generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial")
+generate_tutorial_references("../../tutorials/sphinx-tutorials-empty/", "tutorial")
 # generate_tutorial_references("../../tutorials/src/", "src")
 generate_tutorial_references("../../tutorials/media/", "media")

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 44b7d406cd2..3646b857319 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -70,52 +70,18 @@ library.
 If you're in a hurry, you can start by :ref:`the last item of the series
 ` and navigate to the previous ones whenever you want to learn more!
 
-.. toctree::
-   :maxdepth: 1
-
-   tutorials/getting-started-0
-   tutorials/getting-started-1
-   tutorials/getting-started-2
-   tutorials/getting-started-3
-   tutorials/getting-started-4
-   tutorials/getting-started-5
-
 Tutorials
 =========
 
 Basics
 ------
 
-.. toctree::
-   :maxdepth: 1
-
-   tutorials/coding_ppo
-   tutorials/pendulum
-   tutorials/torchrl_demo
-
 Intermediate
 ------------
 
-.. toctree::
-   :maxdepth: 1
-
-   tutorials/multiagent_ppo
-   tutorials/torchrl_envs
-   tutorials/pretrained_models
-   tutorials/dqn_with_rnn
-   tutorials/rb_tutorial
-
 Advanced
 --------
 
-..
toctree:: - :maxdepth: 1 - - tutorials/multiagent_competitive_ddpg - tutorials/multi_task - tutorials/coding_ddpg - tutorials/coding_dqn - References ========== From 0a8c49e90027ee6238dc3a3ed528910da0d6e433 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 13:55:29 +0100 Subject: [PATCH 03/26] empty commit From 1a42ba6d095b406b77bba8bf0141fedf33fcd39f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 13:56:55 +0100 Subject: [PATCH 04/26] empty commit --- tutorials/sphinx-tutorials-empty/.gitkeep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tutorials/sphinx-tutorials-empty/.gitkeep diff --git a/tutorials/sphinx-tutorials-empty/.gitkeep b/tutorials/sphinx-tutorials-empty/.gitkeep new file mode 100644 index 00000000000..e69de29bb2d From 7ec6da85b068d856b5978369558a7fcd1198a20d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 13:57:22 +0100 Subject: [PATCH 05/26] empty commit From 3d230699252f00d9ed76b73a95d99b80d5ef1ff1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 14:05:04 +0100 Subject: [PATCH 06/26] amend --- tutorials/sphinx-tutorials-empty/README.rst | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tutorials/sphinx-tutorials-empty/README.rst diff --git a/tutorials/sphinx-tutorials-empty/README.rst b/tutorials/sphinx-tutorials-empty/README.rst new file mode 100644 index 00000000000..7995a1fbb2e --- /dev/null +++ b/tutorials/sphinx-tutorials-empty/README.rst @@ -0,0 +1,4 @@ +README Tutos +============ + +Check the tutorials on torchrl documentation: https://pytorch.org/rl From 4aa5409feabbf9e0f2f9b0f3f3dccddffc2fb45a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 14:17:50 +0100 Subject: [PATCH 07/26] amend --- .github/workflows/docs.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 296aef263c9..3a7f0bf552a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -92,8 +92,6 @@ jobs: upload: needs: build-docs - if: github.repository == 'pytorch/rl' && github.event_name == 'push' && - ((github.ref_type == 'branch' && github.ref_name == 'main') || (github.ref_type == 'branch' && github.ref_name == 'release-doc') || github.ref_type == 'tag') permissions: contents: write uses: pytorch/test-infra/.github/workflows/linux_job.yml@main From 07f4f8207531b3c98510d3cd367c5c99b6bf68a5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 14:29:47 +0100 Subject: [PATCH 08/26] amend --- .github/workflows/docs.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 3a7f0bf552a..20c9369a9e0 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -107,7 +107,11 @@ jobs: REF_NAME=${{ github.ref_name }} if [[ "${REF_TYPE}" == branch ]]; then - TARGET_FOLDER="${REF_NAME}" + if [[ "${REF_NAME}" == main ]]; then + TARGET_FOLDER="${REF_NAME}" + else + TARGET_FOLDER="release-doc" + fi elif [[ "${REF_TYPE}" == tag ]]; then case "${REF_NAME}" in *-rc*) From 8058f06b51ddfde1e2885c5847cb4c31fd757c6e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 14:32:50 +0100 Subject: [PATCH 09/26] amend --- .github/workflows/docs.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 20c9369a9e0..c2e589bd833 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -128,8 +128,9 @@ 
jobs: echo "Target Folder: ${TARGET_FOLDER}" - # mkdir -p "${TARGET_FOLDER}" - # rm -rf "${TARGET_FOLDER}"/* + mkdir -p "${TARGET_FOLDER}" + rm -rf "${TARGET_FOLDER}"/* + echo $(ls "${RUNNER_ARTIFACT_DIR}") rsync -a "${RUNNER_ARTIFACT_DIR}"/ "${TARGET_FOLDER}" git add "${TARGET_FOLDER}" || true From 77bb50f2f77cdf670aa7d48bbd776c9bfd991a0d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 14:59:05 +0100 Subject: [PATCH 10/26] amend --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index c2e589bd833..d48ca78a071 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -135,7 +135,7 @@ jobs: rsync -a "${RUNNER_ARTIFACT_DIR}"/ "${TARGET_FOLDER}" git add "${TARGET_FOLDER}" || true - if [[ "${TARGET_FOLDER}" == main ]]; then + if [[ "${TARGET_FOLDER}" == "main" ]] || [[ "${TARGET_FOLDER}" == "release-doc" ]]; then mkdir -p _static rm -rf _static/* cp -r "${TARGET_FOLDER}"/_static/* _static From ef9eefcf386179f775c4f5d0036e4be2a903ac90 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 16:33:31 +0100 Subject: [PATCH 11/26] amend --- docs/source/conf.py | 2 +- docs/source/index.rst | 34 ++++++++++++++++++ torchrl/envs/batched_envs.py | 1 + torchrl/envs/transforms/transforms.py | 7 ++-- tutorials/sphinx-tutorials-empty/.gitkeep | 0 tutorials/sphinx-tutorials-empty/README.rst | 4 --- tutorials/sphinx-tutorials/coding_ddpg.py | 2 +- tutorials/sphinx-tutorials/coding_dqn.py | 40 ++++++++++++++------- tutorials/sphinx-tutorials/dqn_with_rnn.py | 2 +- tutorials/sphinx-tutorials/multi_task.py | 7 ++-- tutorials/sphinx-tutorials/rb_tutorial.py | 3 +- tutorials/sphinx-tutorials/torchrl_demo.py | 8 +++-- tutorials/sphinx-tutorials/torchrl_envs.py | 31 +++++++++------- 13 files changed, 101 insertions(+), 40 deletions(-) delete mode 100644 tutorials/sphinx-tutorials-empty/.gitkeep delete mode 100644 tutorials/sphinx-tutorials-empty/README.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index bec034dbd00..060103b48b4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -186,7 +186,7 @@ ) generate_knowledge_base_references("../../knowledge_base") -generate_tutorial_references("../../tutorials/sphinx-tutorials-empty/", "tutorial") +generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial") # generate_tutorial_references("../../tutorials/src/", "src") generate_tutorial_references("../../tutorials/media/", "media") diff --git a/docs/source/index.rst b/docs/source/index.rst index 3646b857319..44b7d406cd2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -70,18 +70,52 @@ library. If you're in a hurry, you can start by :ref:`the last item of the series ` and navigate to the previous ones whenever you want to learn more! +.. toctree:: + :maxdepth: 1 + + tutorials/getting-started-0 + tutorials/getting-started-1 + tutorials/getting-started-2 + tutorials/getting-started-3 + tutorials/getting-started-4 + tutorials/getting-started-5 + Tutorials ========= Basics ------ +.. toctree:: + :maxdepth: 1 + + tutorials/coding_ppo + tutorials/pendulum + tutorials/torchrl_demo + Intermediate ------------ +.. toctree:: + :maxdepth: 1 + + tutorials/multiagent_ppo + tutorials/torchrl_envs + tutorials/pretrained_models + tutorials/dqn_with_rnn + tutorials/rb_tutorial + Advanced -------- +.. 
toctree:: + :maxdepth: 1 + + tutorials/multiagent_competitive_ddpg + tutorials/multi_task + tutorials/coding_ddpg + tutorials/coding_dqn + References ========== diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b3026da35ca..602f66bbd97 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1568,6 +1568,7 @@ def _shutdown_workers(self) -> None: if self._verbose: torchrl_logger.info(f"closing {i}") channel.send(("close", None)) + for i in range(self.num_workers): self._events[i].wait(self._timeout) self._events[i].clear() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ccab829d480..083b8f953ba 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3130,8 +3130,11 @@ def unfold_done(done, N): reset_vals = reset_vals[1:] reps.extend([reset_vals[0]] * int(j)) j_ = j - reps = torch.stack(reps) - data = torch.masked_scatter(data, done_mask_expand, reps.reshape(-1)) + if reps: + reps = torch.stack(reps) + data = torch.masked_scatter( + data, done_mask_expand, reps.reshape(-1) + ) if first_val is not None: # Aggregate reset along last dim diff --git a/tutorials/sphinx-tutorials-empty/.gitkeep b/tutorials/sphinx-tutorials-empty/.gitkeep deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tutorials/sphinx-tutorials-empty/README.rst b/tutorials/sphinx-tutorials-empty/README.rst deleted file mode 100644 index 7995a1fbb2e..00000000000 --- a/tutorials/sphinx-tutorials-empty/README.rst +++ /dev/null @@ -1,4 +0,0 @@ -README Tutos -============ - -Check the tutorials on torchrl documentation: https://pytorch.org/rl diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 17166453cba..95b196e4b2e 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -339,7 +339,7 @@ def _loss_value( batch_size=self.target_actor_network_params.batch_size, device=self.target_actor_network_params.device, ) - with target_params.to_module(self.value_estimator): + with target_params.to_module(self.actor_critic): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function` diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index dbce0c29804..0977abfb1aa 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -105,16 +105,18 @@ try: multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: + # If we can't set the method globally we can still run the parallel env with "fork" + # This will fail on windows! 
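
A brief aside on the start-method logic above: "fork" is unavailable on Windows, so a spawn-compatible script must be import-safe. A minimal, self-contained sketch of that pattern (illustrative only, not part of this patch):

    import multiprocessing

    def work(msg: str) -> None:
        # Runs in the child process.
        print(f"child got: {msg}")

    if __name__ == "__main__":
        # set_start_method raises RuntimeError if a method was already set,
        # which is why the tutorial wraps the call in try/except.
        try:
            multiprocessing.set_start_method("spawn")
        except RuntimeError:
            pass
        proc = multiprocessing.Process(target=work, args=("hello",))
        proc.start()
        proc.join()
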
Use "spawn" and put the script within `if __name__ == "__main__"` + mp_context = "fork" pass - # sphinx_gallery_end_ignore import os import uuid import torch from torch import nn -from torchrl.collectors import MultiaSyncDataCollector +from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer from torchrl.envs import ( EnvCreator, @@ -217,20 +219,26 @@ def is_notebook() -> bool: def make_env( parallel=False, obs_norm_sd=None, + num_workers=1, ): if obs_norm_sd is None: obs_norm_sd = {"standard_normal": True} if parallel: + + def maker(): + return GymEnv( + "CartPole-v1", + from_pixels=True, + pixels_only=True, + device=device, + ) + base_env = ParallelEnv( num_workers, - EnvCreator( - lambda: GymEnv( - "CartPole-v1", - from_pixels=True, - pixels_only=True, - device=device, - ) - ), + EnvCreator(maker), + # Don't create a sub-process if we have only one worker + serial_for_single=True, + mp_start_method=mp_context, ) else: base_env = GymEnv( @@ -279,6 +287,7 @@ def get_norm_stats(): # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`). print("state dict of the observation norm:", obs_norm_sd) test_env.close() + del test_env return obs_norm_sd @@ -426,8 +435,15 @@ def get_collector( total_frames, device, ): - cls = MultiaSyncDataCollector - env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors + # We can't use nested child processes with mp_start_method="fork" + if is_fork: + cls = SyncDataCollector + env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers) + else: + cls = MultiaSyncDataCollector + env_arg = [ + make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers) + ] * num_collectors data_collector = cls( env_arg, policy=actor_explore, diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index a2b2b12b562..ce849210a2c 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -412,7 +412,7 @@ # utd = 16 -pbar = tqdm.tqdm(total=1_000_000) +pbar = tqdm.tqdm(total=collector.total_frames) longest = 0 traj_lens = [] diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index 68cb995a1a3..4475fb6a9e4 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -12,6 +12,8 @@ # sphinx_gallery_start_ignore import warnings +from tensordict import LazyStackedTensorDict + warnings.filterwarnings("ignore") from torch import multiprocessing @@ -31,7 +33,6 @@ # sphinx_gallery_end_ignore -import torch from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn @@ -77,9 +78,9 @@ tdreset1 = env1.reset() tdreset2 = env2.reset() -# In TorchRL, stacking is done in a lazy manner: the original tensordicts +# With LazyStackedTensorDict, stacking is done in a lazy manner: the original tensordicts # can still be recovered by indexing the main tensordict -tdreset = torch.stack([tdreset1, tdreset2], 0) +tdreset = LazyStackedTensorDict.lazy_stack([tdreset1, tdreset2], 0) assert tdreset[0] is tdreset1 ############################################################################### diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 3c0bad89e70..4070707de08 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -779,6 +779,7 @@ def assert0(x): ), 
Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), + UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ) rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16) @@ -787,7 +788,7 @@ def assert0(x): ###################################################################### -# Let us sample one element from the buffer. The shape of the transformed +# Let us sample one batch from the buffer. The shape of the transformed # pixel keys should have a length of 4 along the 4th dimension starting from # the end: # diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 3738542e3a7..ae80d542d13 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -440,6 +440,7 @@ base_env = ParallelEnv( 4, lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False), + mp_start_method="fork", # This will break on Windows machines! Remove and decorate with if __name__ == "__main__" ) env = TransformedEnv( base_env, Compose(StepCounter(), ToTensorImage()) @@ -732,13 +733,16 @@ def exec_sequence(params, data): from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector -from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs import EnvCreator, SerialEnv from torchrl.envs.libs.gym import GymEnv ############################################################################### # EnvCreator makes sure that we can send a lambda function from process to process +# We use a SerialEnv for simplicity, but for larger jobs a ParallelEnv would be better suited. -parallel_env = ParallelEnv(3, EnvCreator(lambda: GymEnv("Pendulum-v1"))) +parallel_env = SerialEnv( + 3, EnvCreator(lambda: GymEnv("Pendulum-v1")), mp_start_method="fork" +) create_env_fn = [parallel_env, parallel_env] actor_module = nn.Linear(3, 1) diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 4c792d44b80..3ba7805c09a 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -8,6 +8,7 @@ .. _envs_tuto: """ +import functools ############################################################################## # # Environments play a crucial role in RL settings, often somewhat similar to @@ -38,6 +39,8 @@ # sphinx_gallery_start_ignore import warnings +from tensordict.nn import TensorDictModule + warnings.filterwarnings("ignore") from torch import multiprocessing @@ -133,31 +136,31 @@ torch.manual_seed(0) # make sure that all torch code is also reproductible env.set_seed(0) -tensordict = env.reset() -print(tensordict) +reset_data = env.reset() +print("reset data", reset_data) ############################################################################### # We can now execute a step in the environment. 
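
Before the tutorial's policy-based step below, a quick illustration of random stepping without any policy module (a sketch, assuming a working Gym backend; not part of the tutorial diff):

    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")
    td = env.reset()
    td = env.rand_step(td)    # samples a random action from env.action_spec
    rollout = env.rollout(3)  # three random steps, stacked along a time dimension
    print(rollout["next", "reward"].shape)
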
Since we don't have a policy, # we can just generate a random action: -def policy(tensordict, env=env): - tensordict.set("action", env.action_spec.rand()) - return tensordict +policy = TensorDictModule( + functools.partial(env.action_spec.rand(), env=env), in_keys=[], out_keys=["action"] +) -policy(tensordict) -tensordict_out = env.step(tensordict) +policy(reset_data) +tensordict_out = env.step(reset_data) ############################################################################### # By default, the tensordict returned by ``step`` is the same as the input... -assert tensordict_out is tensordict +assert tensordict_out is reset_data ############################################################################### # ... but with new keys -tensordict +tensordict_out ############################################################################### # What we just did (a random step using ``action_spec.rand()``) can also be @@ -175,13 +178,14 @@ def policy(tensordict, env=env): from torchrl.envs.utils import step_mdp -tensordict.set("some other key", torch.randn(1)) -tensordict_tprime = step_mdp(tensordict) +tensordict_out.set("some other key", torch.randn(1)) +tensordict_tprime = step_mdp(tensordict_out) print(tensordict_tprime) print( ( - tensordict_tprime.get("observation") == tensordict.get(("next", "observation")) + tensordict_tprime.get("observation") + == tensordict_out.get(("next", "observation")) ).all() ) @@ -548,7 +552,8 @@ def env_make(): ############################################################################### -parallel_env.start() +if parallel_env.is_closed: + parallel_env.start() foo_list = parallel_env.foo foo_list # needs to be instantiated, for instance using list From 7e1039e4a93156bb22331d3c31f53c5fc30938ac Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 17:21:50 +0100 Subject: [PATCH 12/26] init --- docs/source/reference/collectors.rst | 13 ++++++------- docs/source/reference/envs.rst | 11 ++++++----- torchrl/data/__init__.py | 2 ++ torchrl/data/rlhf/__init__.py | 2 +- tutorials/sphinx-tutorials/torchrl_envs.py | 1 + 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index 8ad4ecef5ef..74bd058b8f0 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -82,7 +82,7 @@ In each of these cases, the last dimension (``T`` for ``time``) is adapted such that the batch size equals the ``frames_per_batch`` argument passed to the collector. -.. warning:: :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` should not be +.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be used with ``cat_results=0``, as the data will be stacked along the batch dimension with batched environment, or the time dimension for single environments, which can introduce some confusion when swapping one with the other. @@ -91,12 +91,12 @@ collector. better interchangeability between configurations, collector classes and other components. -Whereas :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` +Whereas :class:`~torchrl.collectors.MultiSyncDataCollector` has a dimension corresponding to the number of sub-collectors being run (``B``), -:class:`~torchrl.collectors.collectors.MultiaSyncDataCollector` doesn't. This -is easily understood when considering that :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector` +:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. 
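
To make the extra leading dimension concrete, a small sketch (illustrative; assumes a Gym backend and the ``cat_results="stack"`` option discussed above):

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.utils import RandomPolicy

    env_makers = [lambda: GymEnv("CartPole-v1")] * 2  # B = 2 sub-collectors
    collector = MultiSyncDataCollector(
        env_makers,
        policy=RandomPolicy(GymEnv("CartPole-v1").action_spec),
        frames_per_batch=50,
        total_frames=100,
        cat_results="stack",
    )
    for data in collector:
        # A leading sub-collector dimension appears, e.g. torch.Size([2, 25]).
        print(data.batch_size)
        break
    collector.shutdown()
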
This +is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector` delivers batches of data on a first-come, first-serve basis, whereas -:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` gathers data from +:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from each sub-collector before delivering it. Collectors and replay buffers interoperability @@ -168,7 +168,7 @@ batches written in the buffer won't come from the same source (thereby interrupt Single node data collectors --------------------------- -.. currentmodule:: torchrl.collectors.collectors +.. currentmodule:: torchrl.collectors .. autosummary:: :toctree: generated/ @@ -178,7 +178,6 @@ Single node data collectors SyncDataCollector MultiSyncDataCollector MultiaSyncDataCollector - RandomPolicy aSyncDataCollector diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 5c39c5a1349..5d4b6d0b7b5 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -849,14 +849,15 @@ Helpers :toctree: generated/ :template: rl_template_fun.rst - step_mdp - get_available_libraries - set_exploration_mode #deprecated - set_exploration_type + RandomPolicy + check_env_specs exploration_mode #deprecated exploration_type - check_env_specs + get_available_libraries make_composite_from_td + set_exploration_mode #deprecated + set_exploration_type + step_mdp terminated_or_truncated Domain-specific diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index bc512a585b7..e2c0b97c2fc 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -31,6 +31,8 @@ WriterEnsemble, ) from .rlhf import ( + AdaptiveKLController, + ConstantKLController, create_infinite_iterator, get_dataloader, PairwiseDataset, diff --git a/torchrl/data/rlhf/__init__.py b/torchrl/data/rlhf/__init__.py index 93e232bf0ac..f0db092f2d1 100644 --- a/torchrl/data/rlhf/__init__.py +++ b/torchrl/data/rlhf/__init__.py @@ -11,4 +11,4 @@ ) from .prompt import PromptData, PromptTensorDictTokenizer from .reward import PairwiseDataset, RewardData -from .utils import RolloutFromModel +from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 3ba7805c09a..a3a5f280cc4 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -9,6 +9,7 @@ """ import functools + ############################################################################## # # Environments play a crucial role in RL settings, often somewhat similar to From 9faae5a971a23467e6cfc6112b901444b548588b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 17:27:31 +0100 Subject: [PATCH 13/26] pyglet==1.2.4 --- docs/requirements.txt | 1 + tutorials/sphinx-tutorials/torchrl_demo.py | 2 +- tutorials/sphinx-tutorials/torchrl_envs.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index b3a99cf9e3d..6b7c1c1a513 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -26,3 +26,4 @@ memory_profiler pyrender pytest vmas +pyglet==1.2.4 diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index ae80d542d13..8e3610a0e2c 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -741,7 +741,7 @@ def exec_sequence(params, data): # We use a SerialEnv for simplicity, 
but for larger jobs a ParallelEnv would be better suited. parallel_env = SerialEnv( - 3, EnvCreator(lambda: GymEnv("Pendulum-v1")), mp_start_method="fork" + 3, EnvCreator(lambda: GymEnv("Pendulum-v1")), ) create_env_fn = [parallel_env, parallel_env] diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index a3a5f280cc4..1cf65516d5e 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -146,7 +146,7 @@ policy = TensorDictModule( - functools.partial(env.action_spec.rand(), env=env), in_keys=[], out_keys=["action"] + functools.partial(env.action_spec.rand, env=env), in_keys=[], out_keys=["action"] ) From 46f0032a9241098c24d956a06722cb47b2dda83d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 17:41:41 +0100 Subject: [PATCH 14/26] amend --- docs/requirements.txt | 1 - .../multiagent_competitive_ddpg.py | 15 +++++++-------- tutorials/sphinx-tutorials/multiagent_ppo.py | 7 ------- tutorials/sphinx-tutorials/torchrl_demo.py | 3 ++- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 6b7c1c1a513..b3a99cf9e3d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -26,4 +26,3 @@ memory_profiler pyrender pytest vmas -pyglet==1.2.4 diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py index e04f9a6e3aa..6de4eb2e5a0 100644 --- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py +++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py @@ -4,13 +4,6 @@ =========================================================================== **Author**: `Matteo Bettini `_ -.. note:: - - If you are interested in Multi-Agent Reinforcement Learning (MARL) in - TorchRL, check out - `BenchMARL `__: a benchmarking library where you - can train and compare MARL sota-implementations, tasks, and models using TorchRL! - This tutorial demonstrates how to use PyTorch and TorchRL to solve a Competitive Multi-Agent Reinforcement Learning (MARL) problem. @@ -141,6 +134,12 @@ from tqdm import tqdm +# Check if we're building the doc, in which case disable video rendering +try: + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + ###################################################################### # Define Hyperparameters # ---------------------- @@ -879,7 +878,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase: # logger `video_logger`. Note that this code may require some external dependencies such as torchvision. # -if use_vmas: +if use_vmas and not is_sphinx: # Replace tmpdir with any desired path where the video should be saved with tempfile.TemporaryDirectory() as tmpdir: video_logger = CSVLogger("vmas_logs", tmpdir, video_format="mp4") diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index 0cdb5b70e22..29b9dbf9f9f 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -4,13 +4,6 @@ =============================================================== **Author**: `Matteo Bettini `_ -.. note:: - - If you are interested in Multi-Agent Reinforcement Learning (MARL) in - TorchRL, check out - `BenchMARL `__: a benchmarking library where you - can train and compare MARL sota-implementations, tasks, and models using TorchRL! 
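
Stepping back to the SerialEnv swap a few hunks above, a compact comparison sketch (illustrative):

    from torchrl.envs import ParallelEnv, SerialEnv
    from torchrl.envs.libs.gym import GymEnv

    def make_env():
        return GymEnv("Pendulum-v1")

    # SerialEnv steps its workers in-process: no pickling or start-method
    # constraints, which keeps doc builds and small jobs simple and portable.
    serial = SerialEnv(3, make_env)
    print(serial.reset()["observation"].shape)  # torch.Size([3, 3])

    # ParallelEnv runs one process per worker and pays off for heavier envs.
    parallel = ParallelEnv(3, make_env)
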
- This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to solve a Multi-Agent Reinforcement Learning (MARL) problem. diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 8e3610a0e2c..e83eb974af8 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -741,7 +741,8 @@ def exec_sequence(params, data): # We use a SerialEnv for simplicity, but for larger jobs a ParallelEnv would be better suited. parallel_env = SerialEnv( - 3, EnvCreator(lambda: GymEnv("Pendulum-v1")), + 3, + EnvCreator(lambda: GymEnv("Pendulum-v1")), ) create_env_fn = [parallel_env, parallel_env] From ec68709d1978e3f1c7dcfaa5590feb718c70df4e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 18:03:44 +0100 Subject: [PATCH 15/26] init --- docs/source/reference/modules.rst | 1 - torchrl/data/replay_buffers/storages.py | 18 +- torchrl/data/tensor_specs.py | 20 -- torchrl/envs/common.py | 5 +- torchrl/envs/gym_like.py | 10 - torchrl/envs/transforms/transforms.py | 12 +- torchrl/modules/models/models.py | 157 ------------ torchrl/modules/tensordict_module/actors.py | 18 +- torchrl/modules/tensordict_module/common.py | 16 +- torchrl/objectives/a2c.py | 46 ---- torchrl/objectives/common.py | 6 - torchrl/objectives/dqn.py | 24 +- torchrl/objectives/multiagent/qmixer.py | 12 +- torchrl/objectives/ppo.py | 46 ---- torchrl/objectives/reinforce.py | 46 ---- torchrl/trainers/helpers/losses.py | 43 ---- torchrl/trainers/helpers/models.py | 260 +------------------- torchrl/trainers/trainers.py | 12 +- 18 files changed, 14 insertions(+), 738 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index c42376e4948..c12bba985d6 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -315,7 +315,6 @@ Regular modules MLP ConvNet Conv3dNet - LSTMNet SqueezeLayer Squeeze2dLayer diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index a1ada2eb72e..527152104a9 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1093,11 +1093,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if self.device == "auto": self.device = data.device if self.device.type != "cpu": - warnings.warn( - "Support for Memmap device other than CPU will be deprecated in v0.4.0. 
" - "Using a 'cuda' device may be suboptimal.", - category=DeprecationWarning, - ) + raise RuntimeError("Support for Memmap device other than CPU is deprecated") def max_size_along_dim0(data_shape): if self.ndim > 1: @@ -1128,17 +1124,7 @@ def max_size_along_dim0(data_shape): def get(self, index: Union[int, Sequence[int], slice]) -> Any: result = super().get(index) - - # to be deprecated in v0.4 - def map_device(tensor): - if tensor.device != self.device: - return tensor.to(self.device, non_blocking=False) - return tensor - - if is_tensor_collection(result): - return map_device(result) - else: - return tree_map(map_device, result) + return result class StorageEnsemble(Storage): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 105c13214e0..3228afdfd8b 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -382,22 +382,6 @@ def high(self, value): self.device = value.device self._high = value.cpu() - @property - def minimum(self): - warnings.warn( - f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low in v0.4.0", - category=DeprecationWarning, - ) - return self._low.to(self.device) - - @property - def maximum(self): - warnings.warn( - f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high in v0.4.0", - category=DeprecationWarning, - ) - return self._high.to(self.device) - @low.setter def low(self, value): self.device = value.device @@ -1596,10 +1580,6 @@ class BoundedTensorSpec(TensorSpec): """ # SPEC_HANDLED_FUNCTIONS = {} - DEPRECATED_KWARGS = ( - "The `minimum` and `maximum` keyword arguments are now " - "deprecated in favour of `low` and `high` in v0.4.0." - ) CONFLICTING_KWARGS = ( "The keyword arguments {} and {} conflict. Only one of these can be passed." ) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 8712c74340a..52f42445be6 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -9,7 +9,7 @@ import functools import warnings from copy import deepcopy -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple import numpy as np import torch @@ -202,7 +202,6 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): on that device and it is expected that all inputs and outputs will live on that device. Defaults to ``None``. - dtype (deprecated): dtype of the observations. Will be deprecated in v0.4. batch_size (torch.Size or equivalent, optional): batch-size of the environment. Corresponds to the leading dimension of all the input and output tensordicts the environment reads and writes. Defaults to an empty batch-size. 
@@ -341,7 +340,6 @@ def __init__(
         self,
         *,
         device: DEVICE_TYPING = None,
-        dtype: Optional[Union[torch.dtype, np.dtype]] = None,
         batch_size: Optional[torch.Size] = None,
         run_type_checks: bool = False,
         allow_done_after_reset: bool = False,
@@ -365,7 +363,6 @@ def __init__(
         )
         super().__init__()
-        self.dtype = dtype_map.get(dtype, dtype)
         if "is_closed" not in self.__dir__():
             self.is_closed = True
         if batch_size is not None:
diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py
index b27c1f795a2..9cbec79211d 100644
--- a/torchrl/envs/gym_like.py
+++ b/torchrl/envs/gym_like.py
@@ -529,13 +529,3 @@ def __repr__(self) -> str:
     @property
     def info_dict_reader(self):
         return self._info_dict_reader
-
-    @info_dict_reader.setter
-    def info_dict_reader(self, value: callable):
-        warnings.warn(
-            f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. "
-            f"This method will append a reader to the list of existing readers (if any). "
-            f"Setting info_dict_reader directly will be deprecated in v0.4.0.",
-            category=DeprecationWarning,
-        )
-        self._info_dict_reader.append(value)
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index ccab829d480..acec170c7d0 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -236,10 +236,6 @@ def out_keys_inv(self, value):
         value = [unravel_key(val) for val in value]
         self._out_keys_inv = value
 
-    def reset(self, tensordict):
-        warnings.warn("Transform.reset public method will be derpecated in v0.4.0.")
-        return self._reset(tensordict, tensordict_reset=tensordict)
-
     def _reset(
         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
     ) -> TensorDictBase:
@@ -2836,13 +2832,7 @@ def __init__(
         if padding not in self.ACCEPTED_PADDING:
             raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}")
         if padding == "zeros":
-            warnings.warn(
-                "Padding option 'zeros' will be deprecated in v0.4.0. "
-                "Please use 'constant' padding with padding_value 0 instead.",
-                category=DeprecationWarning,
-            )
-            padding = "constant"
-            padding_value = 0
+            raise RuntimeError("Padding option 'zeros' is deprecated")
         self.padding = padding
         self.padding_value = padding_value
         for in_key in self.in_keys:
diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py
index 8e6fc75e12e..23c229c6524 100644
--- a/torchrl/modules/models/models.py
+++ b/torchrl/modules/models/models.py
@@ -6,7 +6,6 @@
 
 import dataclasses
-import warnings
 from copy import deepcopy
 from numbers import Number
 from typing import Callable, Dict, List, Sequence, Tuple, Type, Union
@@ -1481,162 +1480,6 @@ def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tens
         return value
 
 
-class LSTMNet(nn.Module):
-    """An embedder for an LSTM preceded by an MLP.
-
-    The forward method returns the hidden states of the current state
-    (input hidden states) and the output, as
-    the environment returns the 'observation' and 'next_observation'.
-
-    Because the LSTM kernel only returns the last hidden state, hidden states
-    are padded with zeros such that they have the right size to be stored in a
-    TensorDict of size [batch x time_steps].
-
-    If a 2D tensor is provided as input, it is assumed that it is a batch of data
-    with only one time step. This means that we explicitely assume that users will
-    unsqueeze inputs of a single batch with multiple time steps.
-
-    Args:
-        out_features (int): number of output features.
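
Aside: the LSTMNet class whose removal begins here pointed users to torchrl.modules.LSTMModule. An untested orientation sketch of that replacement (API from memory; worth verifying against the LSTMModule docs):

    import torch
    from tensordict import TensorDict
    from torchrl.modules import LSTMModule

    module = LSTMModule(input_size=4, hidden_size=8, in_key="observation", out_key="embed")
    td = TensorDict(
        {
            "observation": torch.randn(3, 4),
            "is_init": torch.ones(3, 1, dtype=torch.bool),  # episode-start mask
        },
        batch_size=[3],
    )
    td = module(td)
    print(td["embed"].shape)  # torch.Size([3, 8])
    # See also module.make_tensordict_primer() to seed the recurrent states
    # when the module is used inside an environment/collector loop.
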
- lstm_kwargs (dict): the keyword arguments for the - :class:`~torch.nn.LSTM` layer. - mlp_kwargs (dict): the keyword arguments for the - :class:`~torchrl.modules.MLP` layer. - device (torch.device, optional): the device where the module should - be instantiated. - - Keyword Args: - lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that - indeicates where the LSTM class is to be retrieved. The ``"torchrl"`` - backend (:class:`~torchrl.modules.LSTM`) is slower but works with - :func:`~torch.vmap` and should work with :func:`~torch.compile`. - Defaults to ``"torch"``. - - Examples: - >>> batch = 7 - >>> time_steps = 6 - >>> in_features = 4 - >>> out_features = 10 - >>> hidden_size = 5 - >>> net = LSTMNet( - ... out_features, - ... {"input_size": hidden_size, "hidden_size": hidden_size}, - ... {"out_features": hidden_size}, - ... ) - >>> # test single step vs multi-step - >>> x = torch.randn(batch, time_steps, in_features) # >3 dims = multi-step - >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) - >>> x = torch.randn(batch, in_features) # 2 dims = single step - >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) - - """ - - def __init__( - self, - out_features: int, - lstm_kwargs: Dict, - mlp_kwargs: Dict, - device: DEVICE_TYPING | None = None, - *, - lstm_backend: str | None = None, - ) -> None: - warnings.warn( - "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.", - category=DeprecationWarning, - ) - super().__init__() - lstm_kwargs.update({"batch_first": True}) - self.mlp = MLP(device=device, **mlp_kwargs) - if lstm_backend is None: - lstm_backend = "torch" - self.lstm_backend = lstm_backend - if self.lstm_backend == "torch": - LSTM = nn.LSTM - else: - from torchrl.modules.tensordict_module.rnn import LSTM - self.lstm = LSTM(device=device, **lstm_kwargs) - self.linear = nn.LazyLinear(out_features, device=device) - - def _lstm( - self, - input: torch.Tensor, - hidden0_in: torch.Tensor | None = None, - hidden1_in: torch.Tensor | None = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - squeeze0 = False - squeeze1 = False - if input.ndimension() == 1: - squeeze0 = True - input = input.unsqueeze(0).contiguous() - - if input.ndimension() == 2: - squeeze1 = True - input = input.unsqueeze(1).contiguous() - batch, steps = input.shape[:2] - - if hidden1_in is None and hidden0_in is None: - shape = (batch, steps) if not squeeze1 else (batch,) - hidden0_in, hidden1_in = [ - torch.zeros( - *shape, - self.lstm.num_layers, - self.lstm.hidden_size, - device=input.device, - dtype=input.dtype, - ) - for _ in range(2) - ] - elif hidden1_in is None or hidden0_in is None: - raise RuntimeError( - f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" - ) - elif squeeze0: - hidden0_in = hidden0_in.unsqueeze(0) - hidden1_in = hidden1_in.unsqueeze(0) - - # we only need the first hidden state - if not squeeze1: - _hidden0_in = hidden0_in[:, 0] - _hidden1_in = hidden1_in[:, 0] - else: - _hidden0_in = hidden0_in - _hidden1_in = hidden1_in - hidden = ( - _hidden0_in.transpose(-3, -2).contiguous(), - _hidden1_in.transpose(-3, -2).contiguous(), - ) - - y0, hidden = self.lstm(input, hidden) - # dim 0 in hidden is num_layers, but that will conflict with tensordict - hidden = tuple(_h.transpose(0, 1) for _h in hidden) - y = self.linear(y0) - - out = [y, hidden0_in, hidden1_in, *hidden] - if squeeze1: - # squeezes time - out[0] = out[0].squeeze(1) - if not squeeze1: 
- # we pad the hidden states with zero to make tensordict happy - for i in range(3, 5): - out[i] = torch.stack( - [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)] - + [out[i]], - 1, - ) - if squeeze0: - out = [_out.squeeze(0) for _out in out] - return tuple(out) - - def forward( - self, - input: torch.Tensor, - hidden0_in: torch.Tensor | None = None, - hidden1_in: torch.Tensor | None = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - input = self.mlp(input) - return self._lstm(input, hidden0_in, hidden1_in) - - class OnlineDTActor(nn.Module): """Online Decision Transformer Actor class. diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 490d1fcb5ad..870f68b7bf3 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -452,11 +452,7 @@ def __init__( safe: bool = False, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, - ) + raise TypeError("Using specs in action_space is deprecated") action_space, spec = _process_action_space_spec(action_space, spec) self.action_space = action_space self.var_nums = var_nums @@ -929,11 +925,7 @@ def __init__( out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, - ) + raise RuntimeError("Using specs in action_space is deprecated") action_space, _ = _process_action_space_spec(action_space, None) self.qvalue_model = DistributionalQValueModule( action_space=action_space, @@ -1196,11 +1188,7 @@ def __init__( make_log_softmax: bool = True, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, - ) + raise RuntimeError("Using specs in action_space is deprecated") action_space, spec = _process_action_space_spec(action_space, spec) self.action_space = action_space self.action_value_key = action_value_key diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 8dd621c98b2..7ac5d9873e5 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -9,7 +9,6 @@ import inspect import re import warnings -from numbers import Number from typing import Iterable, List, Optional, Type, Union import torch @@ -503,19 +502,8 @@ class DistributionalDQNnet(TensorDictModuleBase): "instead." 
) - def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None): + def __init__(self, *, in_keys=None, out_keys=None): super().__init__() - if DQNet is not None: - warnings.warn( - f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.", - category=DeprecationWarning, - ) - if not ( - not isinstance(DQNet.out_features, Number) - and len(DQNet.out_features) > 1 - ): - raise RuntimeError(self._wrong_out_feature_dims_error) - self.dqn = DQNet if in_keys is None: in_keys = ["action_value"] if out_keys is None: @@ -527,8 +515,6 @@ def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None): def forward(self, tensordict): for in_key, out_key in zip(self.in_keys, self.out_keys): q_values = tensordict.get(in_key) - if self.dqn is not None: - q_values = self.dqn(q_values) if q_values.ndimension() < 2: raise RuntimeError( self._wrong_out_feature_dims_error.format(q_values.shape) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 6487ad0597a..dd5c162f8b0 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import warnings from copy import deepcopy from dataclasses import dataclass from typing import Tuple @@ -328,51 +327,6 @@ def __init__( def functional(self): return self._functional - @property - def actor(self): - warnings.warn( - f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network - - @property - def critic(self): - warnings.warn( - f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network - - @property - def actor_params(self): - warnings.warn( - f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network_params - - @property - def critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network_params - - @property - def target_critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.target_critic_network_params - @property def in_keys(self): keys = [ diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 6b6fd391560..cfe8b793454 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -217,12 +217,6 @@ def convert_to_functional( will carry gradients as expected. """ - if kwargs.pop("funs_to_decorate", None) is not None: - warnings.warn( - "funs_to_decorate is without effect with the new objective API. 
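
With the aliases removed in the A2C hunk above, call sites use the explicit attribute names; a sketch with toy networks (illustrative):

    from torch import nn
    from tensordict.nn import NormalParamExtractor, TensorDictModule
    from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import A2CLoss

    actor_net = nn.Sequential(nn.Linear(4, 4), NormalParamExtractor())
    policy_module = TensorDictModule(
        actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
    )
    actor = ProbabilisticActor(
        policy_module, in_keys=["loc", "scale"], distribution_class=TanhNormal
    )
    critic = ValueOperator(nn.Linear(4, 1), in_keys=["observation"])

    loss = A2CLoss(actor_network=actor, critic_network=critic)
    # The removed aliases (loss.actor, loss.critic, loss.critic_params, ...)
    # are gone; only the explicit names remain:
    policy = loss.actor_network
    params = loss.critic_network_params
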
This " - "warning will be replaced by an error in v0.4.0.", - category=DeprecationWarning, - ) if kwargs: raise TypeError(f"Unrecognised keyword arguments {list(kwargs.keys())}") # To make it robust to device casting, we must register list of diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 505f07f55c0..3e219c9b72e 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -173,23 +173,13 @@ def __init__( value_network: Union[QValueActor, nn.Module], *, loss_function: Optional[str] = "l2", - delay_value: bool = None, + delay_value: bool = True, double_dqn: bool = False, gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, reduction: str = None, ) -> None: - if delay_value is None: - warnings.warn( - f"You did not provide a delay_value argument for {type(self)}. " - "Currently (v0.3) the default for delay_value is `False` but as of " - "v0.4 it will be `True`. Make sure to adapt your code depending " - "on your preferred configuration. " - "To remove this warning, indicate the value of delay_value in your " - "script." - ) - delay_value = False if reduction is None: reduction = "mean" super().__init__() @@ -449,20 +439,10 @@ def __init__( value_network: Union[DistributionalQValueActor, nn.Module], *, gamma: float, - delay_value: bool = None, + delay_value: bool = True, priority_key: str = None, reduction: str = None, ): - if delay_value is None: - warnings.warn( - f"You did not provide a delay_value argument for {type(self)}. " - "Currently (v0.3) the default for delay_value is `False` but as of " - "v0.4 it will be `True`. Make sure to adapt your code depending " - "on your preferred configuration. " - "To remove this warning, indicate the value of delay_value in your " - "script." - ) - delay_value = False if reduction is None: reduction = "mean" super().__init__() diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index fcfcba49ca1..f3994abd1b2 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -189,21 +189,11 @@ def __init__( mixer_network: Union[TensorDictModule, nn.Module], *, loss_function: Optional[str] = "l2", - delay_value: bool = None, + delay_value: bool = True, gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, ) -> None: - if delay_value is None: - warnings.warn( - f"You did not provide a delay_value argument for {type(self)}. " - "Currently (v0.3) the default for delay_value is `False` but as of " - "v0.4 it will be `True`. Make sure to adapt your code depending " - "on your preferred configuration. " - "To remove this warning, indicate the value of delay_value in your " - "script." - ) - delay_value = False super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index a26b90462c6..7264d5d6cbe 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -7,7 +7,6 @@ import contextlib import math -import warnings from copy import deepcopy from dataclasses import dataclass from typing import Tuple @@ -383,51 +382,6 @@ def __init__( def functional(self): return self._functional - @property - def actor(self): - warnings.warn( - f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. 
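
On the DQN change above: with delay_value now defaulting to True, a target network is kept and an updater should be stepped alongside the optimizer. A sketch (illustrative):

    from torch import nn
    from torchrl.modules import QValueActor
    from torchrl.objectives import DQNLoss, SoftUpdate

    value_net = QValueActor(
        nn.Linear(4, 2), in_keys=["observation"], action_space="one_hot"
    )
    # delay_value=True is now the default, so no keyword is needed:
    loss = DQNLoss(value_net, action_space="one_hot")
    updater = SoftUpdate(loss, eps=0.99)
    updater.step()  # call once per optimization step
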
This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network - - @property - def critic(self): - warnings.warn( - f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network - - @property - def actor_params(self): - warnings.warn( - f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network_params - - @property - def critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network_params - - @property - def target_critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.target_critic_network_params - def _set_in_keys(self): keys = [ self.tensor_keys.action, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 96f15e8ab69..aa931b97c13 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import warnings from copy import deepcopy from dataclasses import dataclass @@ -317,51 +316,6 @@ def __init__( def functional(self): return self._functional - @property - def actor(self): - warnings.warn( - f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network - - @property - def critic(self): - warnings.warn( - f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network - - @property - def actor_params(self): - warnings.warn( - f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network_params - - @property - def critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network_params - - @property - def target_critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. 
This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.target_critic_network_params - def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: self._value_estimator.set_keys( diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index a949bea6718..152d7e2891f 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -3,14 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from dataclasses import dataclass from typing import Any, Optional, Tuple -from torchrl.modules import ActorCriticOperator, ActorValueOperator from torchrl.objectives import DistributionalDQNLoss, DQNLoss, HardUpdate, SoftUpdate from torchrl.objectives.common import LossModule -from torchrl.objectives.deprecated import REDQLoss_deprecated from torchrl.objectives.utils import TargetNetUpdater @@ -38,46 +35,6 @@ def make_target_updater( return target_net_updater -def make_redq_loss( - model, cfg -) -> Tuple[REDQLoss_deprecated, Optional[TargetNetUpdater]]: - """Builds the REDQ loss module.""" - warnings.warn( - "This helper function will be deprecated in v0.4. Consider using the local helper in the REDQ example.", - category=DeprecationWarning, - ) - loss_kwargs = {} - if hasattr(cfg, "distributional") and cfg.distributional: - raise NotImplementedError - else: - loss_kwargs.update({"loss_function": cfg.loss_function}) - loss_kwargs.update({"delay_qvalue": cfg.loss == "double"}) - loss_class = REDQLoss_deprecated - if isinstance(model, ActorValueOperator): - actor_model = model.get_policy_operator() - qvalue_model = model.get_value_operator() - elif isinstance(model, ActorCriticOperator): - raise RuntimeError( - "Although REDQ Q-value depends upon selected actions, using the" - "ActorCriticOperator will lead to resampling of the actions when" - "computing the Q-value loss, which we don't want. Please use the" - "ActorValueOperator instead." - ) - else: - actor_model, qvalue_model = model - - loss_module = loss_class( - actor_network=actor_model, - qvalue_network=qvalue_model, - num_qvalue_nets=cfg.num_q_values, - gSDE=cfg.gSDE, - **loss_kwargs, - ) - loss_module.make_value_estimator(gamma=cfg.gamma) - target_net_updater = make_target_updater(cfg, loss_module) - return loss_module, target_net_updater - - def make_dqn_loss(model, cfg) -> Tuple[DQNLoss, Optional[TargetNetUpdater]]: """Builds the DQN loss module.""" loss_kwargs = {} diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 05f566674f2..0a3cea40b36 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -3,16 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
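
For reference alongside the make_target_updater helper kept in losses.py above, the two surviving updater flavors (sketch; ``loss_module`` is a stand-in for any LossModule holding target parameters):

    from torchrl.objectives import HardUpdate, SoftUpdate

    # Polyak averaging: target <- eps * target + (1 - eps) * online.
    updater = SoftUpdate(loss_module, eps=0.995)
    # Or: hard copy of the online weights every N optimization steps.
    # updater = HardUpdate(loss_module, value_network_update_interval=1000)

    updater.step()  # call once per optimization step
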
import itertools -import warnings from dataclasses import dataclass -from typing import Optional, Sequence import torch - from tensordict import set_lazy_legacy from tensordict.nn import InteractionType -from torch import distributions as d, nn - +from torch import nn from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, @@ -25,7 +21,6 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( NoisyLinear, - NormalParamWrapper, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, @@ -37,8 +32,6 @@ TanhDelta, TanhNormal, ) -from torchrl.modules.distributions.continuous import SafeTanhTransform -from torchrl.modules.models.exploration import LazygSDEModule from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -47,19 +40,12 @@ RSSMPrior, RSSMRollout, ) -from torchrl.modules.models.models import ( - DdpgCnnActor, - DdpgCnnQNet, - DuelingCnnDQNet, - DuelingMlpDQNet, - MLP, -) +from torchrl.modules.models.models import DuelingCnnDQNet, DuelingMlpDQNet, MLP from torchrl.modules.tensordict_module import ( Actor, DistributionalQValueActor, QValueActor, ) -from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator from torchrl.modules.tensordict_module.world_models import WorldModelWrapper from torchrl.trainers.helpers import transformed_env_constructor @@ -210,248 +196,6 @@ def make_dqn_actor( return model -def make_redq_model( - proof_environment: EnvBase, - cfg: "DictConfig", # noqa: F821 - device: DEVICE_TYPING = "cpu", - in_keys: Optional[Sequence[str]] = None, - actor_net_kwargs=None, - qvalue_net_kwargs=None, - observation_key=None, - **kwargs, -) -> nn.ModuleList: - """Actor and Q-value model constructor helper function for REDQ. - - Follows default parameters proposed in REDQ original paper: https://openreview.net/pdf?id=AY8zfZm0tDd. - Other configurations can easily be implemented by modifying this function at will. - A single instance of the Q-value model is returned. It will be multiplicated by the loss function. - - Args: - proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec - cfg (DictConfig): contains arguments of the REDQ script - device (torch.device, optional): device on which the model must be cast. Default is "cpu". - in_keys (iterable of strings, optional): observation key to be read by the actor, usually one of - `'observation_vector'` or `'pixels'`. If none is provided, one of these two keys is chosen - based on the `cfg.from_pixels` argument. - actor_net_kwargs (dict, optional): kwargs of the actor MLP. - qvalue_net_kwargs (dict, optional): kwargs of the qvalue MLP. - - Returns: - A nn.ModuleList containing the actor, qvalue operator(s) and the value operator. - - Examples: - >>> from torchrl.trainers.helpers.envs import parser_env_args - >>> from torchrl.trainers.helpers.models import make_redq_model, parser_model_args_continuous - >>> from torchrl.envs.libs.gym import GymEnv - >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose - >>> import hydra - >>> from hydra.core.config_store import ConfigStore - >>> import dataclasses - >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v4"), Compose(DoubleToFloat(["observation"]), - ... CatTensors(["observation"], "observation_vector"))) - >>> device = torch.device("cpu") - >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in - ... (RedqModelConfig, EnvConfig) - ... 
for config_field in dataclasses.fields(config_cls)] - >>> Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) - >>> cs = ConfigStore.instance() - >>> cs.store(name="config", node=Config) - >>> with initialize(config_path=None): - >>> cfg = compose(config_name="config") - >>> model = make_redq_model( - ... proof_environment, - ... device=device, - ... cfg=cfg, - ... ) - >>> actor, qvalue = model - >>> td = proof_environment.reset() - >>> print(actor(td)) - TensorDict( - fields={ - done: Tensor(torch.Size([1]), dtype=torch.bool), - observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), - loc: Tensor(torch.Size([6]), dtype=torch.float32), - scale: Tensor(torch.Size([6]), dtype=torch.float32), - action: Tensor(torch.Size([6]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([1]), dtype=torch.float32)}, - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - >>> print(qvalue(td.clone())) - TensorDict( - fields={ - done: Tensor(torch.Size([1]), dtype=torch.bool), - observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), - loc: Tensor(torch.Size([6]), dtype=torch.float32), - scale: Tensor(torch.Size([6]), dtype=torch.float32), - action: Tensor(torch.Size([6]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([1]), dtype=torch.float32), - state_action_value: Tensor(torch.Size([1]), dtype=torch.float32)}, - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - - """ - warnings.warn( - "This helper function will be deprecated in v0.4. Consider using the local helper in the REDQ example.", - category=DeprecationWarning, - ) - tanh_loc = cfg.tanh_loc - default_policy_scale = cfg.default_policy_scale - gSDE = cfg.gSDE - - action_spec = proof_environment.action_spec - # obs_spec = proof_environment.observation_spec - # if observation_key is not None: - # obs_spec = obs_spec[observation_key] - # else: - # obs_spec_values = list(obs_spec.values()) - # if len(obs_spec_values) > 1: - # raise RuntimeError( - # "There is more than one observation in the spec, REDQ helper " - # "cannot infer automatically which to pick. " - # "Please indicate which key to read via the `observation_key` " - # "keyword in this helper." 
- # ) - # else: - # obs_spec = obs_spec_values[0] - - if actor_net_kwargs is None: - actor_net_kwargs = {} - if qvalue_net_kwargs is None: - qvalue_net_kwargs = {} - - linear_layer_class = torch.nn.Linear if not cfg.noisy else NoisyLinear - - out_features_actor = (2 - gSDE) * action_spec.shape[-1] - if cfg.from_pixels: - if in_keys is None: - in_keys_actor = ["pixels"] - else: - in_keys_actor = in_keys - actor_net_kwargs_default = { - "mlp_net_kwargs": { - "layer_class": linear_layer_class, - "activation_class": ACTIVATIONS[cfg.activation], - }, - "conv_net_kwargs": {"activation_class": ACTIVATIONS[cfg.activation]}, - } - actor_net_kwargs_default.update(actor_net_kwargs) - actor_net = DdpgCnnActor(out_features_actor, **actor_net_kwargs_default) - gSDE_state_key = "hidden" - out_keys_actor = ["param", "hidden"] - - value_net_default_kwargs = { - "mlp_net_kwargs": { - "layer_class": linear_layer_class, - "activation_class": ACTIVATIONS[cfg.activation], - }, - "conv_net_kwargs": {"activation_class": ACTIVATIONS[cfg.activation]}, - } - value_net_default_kwargs.update(qvalue_net_kwargs) - - in_keys_qvalue = ["pixels", "action"] - qvalue_net = DdpgCnnQNet(**value_net_default_kwargs) - else: - if in_keys is None: - in_keys_actor = ["observation_vector"] - else: - in_keys_actor = in_keys - - actor_net_kwargs_default = { - "num_cells": [cfg.actor_cells, cfg.actor_cells], - "out_features": out_features_actor, - "activation_class": ACTIVATIONS[cfg.activation], - } - actor_net_kwargs_default.update(actor_net_kwargs) - actor_net = MLP(**actor_net_kwargs_default) - out_keys_actor = ["param"] - gSDE_state_key = in_keys_actor[0] - - qvalue_net_kwargs_default = { - "num_cells": [cfg.qvalue_cells, cfg.qvalue_cells], - "out_features": 1, - "activation_class": ACTIVATIONS[cfg.activation], - } - qvalue_net_kwargs_default.update(qvalue_net_kwargs) - qvalue_net = MLP( - **qvalue_net_kwargs_default, - ) - in_keys_qvalue = in_keys_actor + ["action"] - - dist_class = TanhNormal - dist_kwargs = { - "min": action_spec.space.low, - "max": action_spec.space.high, - "tanh_loc": tanh_loc, - } - - if not gSDE: - actor_net = NormalParamWrapper( - actor_net, - scale_mapping=f"biased_softplus_{default_policy_scale}", - scale_lb=cfg.scale_lb, - ) - actor_module = SafeModule( - actor_net, - in_keys=in_keys_actor, - out_keys=["loc", "scale"] + out_keys_actor[1:], - ) - - else: - actor_module = SafeModule( - actor_net, - in_keys=in_keys_actor, - out_keys=["action"] + out_keys_actor[1:], # will be overwritten - ) - - if action_spec.domain == "continuous": - min = action_spec.space.low - max = action_spec.space.high - transform = SafeTanhTransform() - if (min != -1).any() or (max != 1).any(): - transform = d.ComposeTransform( - transform, - d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2), - ) - else: - raise RuntimeError("cannot use gSDE with discrete actions") - - actor_module = SafeSequential( - actor_module, - SafeModule( - LazygSDEModule(transform=transform), - in_keys=["action", gSDE_state_key, "_eps_gSDE"], - out_keys=["loc", "scale", "action", "_eps_gSDE"], - ), - ) - - actor = ProbabilisticActor( - spec=action_spec, - in_keys=["loc", "scale"], - module=actor_module, - distribution_class=dist_class, - distribution_kwargs=dist_kwargs, - default_interaction_type=InteractionType.RANDOM, - return_log_prob=True, - ) - qvalue = ValueOperator( - in_keys=in_keys_qvalue, - module=qvalue_net, - ) - model = nn.ModuleList([actor, qvalue]).to(device) - - # init nets - with torch.no_grad(), 
set_exploration_type(ExplorationType.RANDOM):
-        td = proof_environment.fake_tensordict()
-        td = td.unsqueeze(-1)
-        td = td.to(device)
-        for net in model:
-            net(td)
-        del td
-    return model
-
-
 @set_lazy_legacy(False)
 def make_dreamer(
     cfg: "DictConfig",  # noqa: F821
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index 526b3c967e8..ccd9bb23bb3 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -662,23 +662,13 @@ def __init__(
         batch_size: Optional[int] = None,
         memmap: bool = False,
         device: DEVICE_TYPING = "cpu",
-        flatten_tensordicts: bool = None,
+        flatten_tensordicts: bool = False,
         max_dims: Optional[Sequence[int]] = None,
     ) -> None:
         self.replay_buffer = replay_buffer
         self.batch_size = batch_size
         self.memmap = memmap
         self.device = device
-        if flatten_tensordicts is None:
-            warnings.warn(
-                "flatten_tensordicts default value has now changed "
-                "to False for a faster execution. Make sure your "
-                "code is robust to this change. To silence this warning, "
-                "pass flatten_tensordicts= in your code. "
-                "This warning will be removed in v0.4.",
-                category=DeprecationWarning,
-            )
-            flatten_tensordicts = True
         self.flatten_tensordicts = flatten_tensordicts
         self.max_dims = max_dims

From c89c1c4730ecb2a7e666ecb781578a811e8c7f00 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 24 Apr 2024 18:15:46 +0100
Subject: [PATCH 16/26] amend

---
 .../distributed_replay_buffer.py            |   6 +-
 examples/memmap/memmap_speed_distributed.py |   4 +-
 test/_utils_internal.py                     | 157 ++++++++++++++++++
 test/mocking_classes.py                     |   1 -
 test/test_actors.py                         |  16 +-
 test/test_collector.py                      |   3 +-
 torchrl/data/replay_buffers/storages.py     |  14 +-
 torchrl/modules/__init__.py                 |   1 -
 torchrl/modules/models/__init__.py          |   1 -
 torchrl/objectives/value/functional.py      |   7 +-
 10 files changed, 177 insertions(+), 33 deletions(-)

diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py
index 0cb9aaaffbd..c7504fbf8ee 100644
--- a/examples/distributed/replay_buffers/distributed_replay_buffer.py
+++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py
@@ -149,8 +149,10 @@ def _create_and_launch_data_collectors(self) -> None:

 class ReplayBufferNode(RemoteTensorDictReplayBuffer):
-    """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer` means all of it's public methods are remotely invokable using `torch.rpc`.
-    Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures.
+    """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer`
+    means all of its public methods are remotely invokable using `torch.rpc`.
+    Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation
+    cost of MemoryMappedTensors, as well as the ability to specify file storage locations, which can improve the ability to recover from node failures.

     Args:
         capacity (int): the maximum number of elements that can be stored in the replay buffer.
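
[Editor's note, not part of the patch] The hunk above and the memmap example below both move from tensordict's old `MemmapTensor` to `MemoryMappedTensor`. A minimal sketch of the resulting usage; the buffer size and keys are illustrative, not taken from the repo:

    import torch
    from tensordict import MemoryMappedTensor, TensorDict
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

    # direct replacement for the old `MemmapTensor(tensor)` constructor
    mm = MemoryMappedTensor.from_tensor(torch.zeros(3, 3))

    # the storage pattern the ReplayBufferNode docstring recommends:
    # data lives in memory-mapped files rather than process memory
    buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(max_size=10_000))
    buffer.extend(TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4]))
    sample = buffer.sample(2)
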
diff --git a/examples/memmap/memmap_speed_distributed.py b/examples/memmap/memmap_speed_distributed.py
index 61c100e0e4a..ec324e7cc55 100644
--- a/examples/memmap/memmap_speed_distributed.py
+++ b/examples/memmap/memmap_speed_distributed.py
@@ -9,7 +9,7 @@
 import configargparse
 import torch
 import torch.distributed.rpc as rpc
-from tensordict import MemmapTensor
+from tensordict import MemoryMappedTensor

 parser = configargparse.ArgumentParser()
 parser.add_argument("--rank", default=-1, type=int)
@@ -59,7 +59,7 @@ def op_on_tensor(idx):
 # create tensor
 tensor = torch.zeros(10000, 10000)
 if tensortype == "memmap":
-    tensor = MemmapTensor(tensor)
+    tensor = MemoryMappedTensor.from_tensor(tensor)
 elif tensortype == "tensor":
     pass
 else:
diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 6c267768044..b56108d6d91 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -18,6 +18,7 @@
 import torch.cuda

 from tensordict import tensorclass, TensorDict
+from torch import nn
 from torchrl._utils import implement_for, logger as torchrl_logger, seed_generator
 from torchrl.data.utils import CloudpickleWrapper

@@ -498,3 +499,159 @@ def new_func(*args, **kwargs):
             return func(*args, **kwargs)

     return CloudpickleWrapper(new_func)
+
+
+class LSTMNet(nn.Module):
+    """An embedder for an LSTM preceded by an MLP.
+
+    The forward method returns the hidden states of the current state
+    (input hidden states) and the output, as
+    the environment returns the 'observation' and 'next_observation'.
+
+    Because the LSTM kernel only returns the last hidden state, hidden states
+    are padded with zeros such that they have the right size to be stored in a
+    TensorDict of size [batch x time_steps].
+
+    If a 2D tensor is provided as input, it is assumed that it is a batch of data
+    with only one time step. This means that we explicitly assume that users will
+    unsqueeze inputs of a single batch with multiple time steps.
+
+    Args:
+        out_features (int): number of output features.
+        lstm_kwargs (dict): the keyword arguments for the
+            :class:`~torch.nn.LSTM` layer.
+        mlp_kwargs (dict): the keyword arguments for the
+            :class:`~torchrl.modules.MLP` layer.
+        device (torch.device, optional): the device where the module should
+            be instantiated.
+
+    Keyword Args:
+        lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that
+            indicates where the LSTM class is to be retrieved. The ``"torchrl"``
+            backend (:class:`~torchrl.modules.LSTM`) is slower but works with
+            :func:`~torch.vmap` and should work with :func:`~torch.compile`.
+            Defaults to ``"torch"``.
+
+    Examples:
+        >>> batch = 7
+        >>> time_steps = 6
+        >>> in_features = 4
+        >>> out_features = 10
+        >>> hidden_size = 5
+        >>> net = LSTMNet(
+        ...     out_features,
+        ...     {"input_size": hidden_size, "hidden_size": hidden_size},
+        ...     {"out_features": hidden_size},
+        ...
) + >>> # test single step vs multi-step + >>> x = torch.randn(batch, time_steps, in_features) # >3 dims = multi-step + >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) + >>> x = torch.randn(batch, in_features) # 2 dims = single step + >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) + + """ + + def __init__( + self, + out_features: int, + lstm_kwargs, + mlp_kwargs, + device=None, + *, + lstm_backend: str | None = None, + ) -> None: + warnings.warn( + "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.", + category=DeprecationWarning, + ) + super().__init__() + lstm_kwargs.update({"batch_first": True}) + self.mlp = MLP(device=device, **mlp_kwargs) + if lstm_backend is None: + lstm_backend = "torch" + self.lstm_backend = lstm_backend + if self.lstm_backend == "torch": + LSTM = nn.LSTM + else: + from torchrl.modules.tensordict_module.rnn import LSTM + self.lstm = LSTM(device=device, **lstm_kwargs) + self.linear = nn.LazyLinear(out_features, device=device) + + def _lstm( + self, + input: torch.Tensor, + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, + ): + squeeze0 = False + squeeze1 = False + if input.ndimension() == 1: + squeeze0 = True + input = input.unsqueeze(0).contiguous() + + if input.ndimension() == 2: + squeeze1 = True + input = input.unsqueeze(1).contiguous() + batch, steps = input.shape[:2] + + if hidden1_in is None and hidden0_in is None: + shape = (batch, steps) if not squeeze1 else (batch,) + hidden0_in, hidden1_in = [ + torch.zeros( + *shape, + self.lstm.num_layers, + self.lstm.hidden_size, + device=input.device, + dtype=input.dtype, + ) + for _ in range(2) + ] + elif hidden1_in is None or hidden0_in is None: + raise RuntimeError( + f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" + ) + elif squeeze0: + hidden0_in = hidden0_in.unsqueeze(0) + hidden1_in = hidden1_in.unsqueeze(0) + + # we only need the first hidden state + if not squeeze1: + _hidden0_in = hidden0_in[:, 0] + _hidden1_in = hidden1_in[:, 0] + else: + _hidden0_in = hidden0_in + _hidden1_in = hidden1_in + hidden = ( + _hidden0_in.transpose(-3, -2).contiguous(), + _hidden1_in.transpose(-3, -2).contiguous(), + ) + + y0, hidden = self.lstm(input, hidden) + # dim 0 in hidden is num_layers, but that will conflict with tensordict + hidden = tuple(_h.transpose(0, 1) for _h in hidden) + y = self.linear(y0) + + out = [y, hidden0_in, hidden1_in, *hidden] + if squeeze1: + # squeezes time + out[0] = out[0].squeeze(1) + if not squeeze1: + # we pad the hidden states with zero to make tensordict happy + for i in range(3, 5): + out[i] = torch.stack( + [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)] + + [out[i]], + 1, + ) + if squeeze0: + out = [_out.squeeze(0) for _out in out] + return tuple(out) + + def forward( + self, + input: torch.Tensor, + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, + ): + input = self.mlp(input) + return self._lstm(input, hidden0_in, hidden1_in) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index ec9cec7fabd..ef383c51766 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -112,7 +112,6 @@ def __init__( ): super().__init__( device=kwargs.pop("device", "cpu"), - dtype=torch.get_default_dtype(), allow_done_after_reset=kwargs.pop("allow_done_after_reset", False), ) self.set_seed(seed) diff --git a/test/test_actors.py b/test/test_actors.py index ddefcea274c..560566286ae 100644 --- 
a/test/test_actors.py +++ b/test/test_actors.py @@ -63,8 +63,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -86,8 +86,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -130,8 +130,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions out_keys=[("data", "action")], distribution_class=TanhNormal, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -153,8 +153,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions out_keys=[("data", "action")], distribution_class=TanhNormal, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, diff --git a/test/test_collector.py b/test/test_collector.py index 230ff159c28..eee42a11fc2 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -20,6 +20,7 @@ generate_seeds, get_available_devices, get_default_devices, + LSTMNet, PENDULUM_VERSIONED, PONG_VERSIONED, retry, @@ -74,7 +75,7 @@ PARTIAL_MISSING_ERR, RandomPolicy, ) -from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule +from torchrl.modules import Actor, OrnsteinUhlenbeckProcessWrapper, SafeModule # torch.set_default_dtype(torch.double) IS_WINDOWS = sys.platform == "win32" diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 527152104a9..425520f87d9 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -24,7 +24,7 @@ TensorDict, TensorDictBase, ) -from tensordict.memmap import MemmapTensor, MemoryMappedTensor +from tensordict.memmap import MemoryMappedTensor from tensordict.utils import _STRDTYPE2DTYPE from torch import multiprocessing as mp @@ -1287,7 +1287,7 @@ def __repr__(self): # Utils -def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: +def _mem_map_tensor_as_tensor(mem_map_tensor) -> torch.Tensor: if _CKPT_BACKEND == "torchsnapshot" and not _has_ts: raise ImportError( "the checkpointing backend is set to torchsnapshot but the library is not installed. Consider installing the library or switch to another backend. " @@ -1296,16 +1296,6 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: if isinstance(mem_map_tensor, torch.Tensor): # This will account for MemoryMappedTensors return mem_map_tensor - if _CKPT_BACKEND == "torchsnapshot": - # TorchSnapshot doesn't know how to stream MemmapTensor, so we view MemmapTensor - # as a Tensor for saving and loading purposes. This doesn't incur any copy. 
- return tensor_from_memoryview( - dtype=mem_map_tensor.dtype, - shape=list(mem_map_tensor.shape), - mv=memoryview(mem_map_tensor._memmap_array), - ) - elif _CKPT_BACKEND == "torch": - return mem_map_tensor._tensor def _collate_list_tensordict(x): diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index a987e701672..4a3c5e716e8 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -29,7 +29,6 @@ DreamerActor, DTActor, DuelingCnnDQNet, - LSTMNet, MLP, MultiAgentConvNet, MultiAgentMLP, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 7b11cae9515..fb0cc0135b8 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -27,7 +27,6 @@ DTActor, DuelingCnnDQNet, DuelingMlpDQNet, - LSTMNet, MLP, OnlineDTActor, ) diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 082c0ae9e9a..2e3120cf185 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -12,7 +12,7 @@ import torch -from tensordict import MemmapTensor +from tensordict import MemoryMappedTensor __all__ = [ "generalized_advantage_estimate", @@ -59,10 +59,7 @@ def transposed_fun(*args, **kwargs): time_dim = kwargs.pop("time_dim", -2) def transpose_tensor(tensor): - if ( - not isinstance(tensor, (torch.Tensor, MemmapTensor)) - or tensor.numel() <= 1 - ): + if not isinstance(tensor, torch.Tensor) or tensor.numel() <= 1: return tensor, False if time_dim >= 0: timedim = time_dim - tensor.ndim From 27c0d857c9513cf1f70b33e3f1d87c31e93e36c1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 18:16:13 +0100 Subject: [PATCH 17/26] amend --- torchrl/envs/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 52f42445be6..0715cabf29f 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2980,7 +2980,6 @@ class _EnvWrapper(EnvBase): def __init__( self, *args, - dtype: Optional[np.dtype] = None, device: DEVICE_TYPING = NO_DEFAULT, batch_size: Optional[torch.Size] = None, allow_done_after_reset: bool = False, @@ -2998,7 +2997,6 @@ def __init__( device = torch.device("cpu") super().__init__( device=device, - dtype=dtype, batch_size=batch_size, allow_done_after_reset=allow_done_after_reset, ) From b3f06c09ce813f938f79a651e44420c8a07625b1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 20:43:04 +0100 Subject: [PATCH 18/26] amend --- test/_utils_internal.py | 4 - test/mocking_classes.py | 1 - test/test_helpers.py | 119 ----------------------- test/test_modules.py | 135 --------------------------- torchrl/envs/model_based/common.py | 2 - torchrl/envs/model_based/dreamer.py | 3 +- torchrl/trainers/helpers/__init__.py | 4 +- 7 files changed, 3 insertions(+), 265 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index b56108d6d91..4e99843d16f 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -560,10 +560,6 @@ def __init__( *, lstm_backend: str | None = None, ) -> None: - warnings.warn( - "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.", - category=DeprecationWarning, - ) super().__init__() lstm_kwargs.update({"batch_first": True}) self.mlp = MLP(device=device, **mlp_kwargs) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index ef383c51766..75769215ce5 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -925,7 +925,6 @@ def 
__init__( super().__init__( world_model, device=device, - dtype=dtype, batch_size=batch_size, ) self.observation_spec = CompositeSpec( diff --git a/test/test_helpers.py b/test/test_helpers.py index 46036346de5..e9d07767a49 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -50,8 +50,6 @@ DiscreteModelConfig, DreamerConfig, make_dqn_actor, - make_redq_model, - REDQModelConfig, ) TORCH_VERSION = version.parse(torch.__version__) @@ -162,123 +160,6 @@ def test_dqn_maker( proof_environment.close() -@pytest.mark.skipif(not _has_functorch, reason="functorch not installed") -@pytest.mark.skipif(not _has_hydra, reason="No hydra library found") -@pytest.mark.skipif(not _has_gym, reason="No gym library found") -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")]) -@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)]) -@pytest.mark.parametrize("exploration", [ExplorationType.MODE, ExplorationType.RANDOM]) -def test_redq_make(device, from_pixels, gsde, exploration): - if not gsde and exploration != ExplorationType.RANDOM: - pytest.skip("no need to test this setting") - flags = list(from_pixels + gsde) - if gsde and from_pixels: - pytest.skip("gsde and from_pixels are incompatible") - - config_fields = [ - (config_field.name, config_field.type, config_field) - for config_cls in ( - EnvConfig, - REDQModelConfig, - ) - for config_field in dataclasses.fields(config_cls) - ] - - Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) - cs = ConfigStore.instance() - cs.store(name="config", node=Config) - with initialize(version_base="1.1", config_path=None): - cfg = compose(config_name="config", overrides=flags) - - env_maker = ( - ContinuousActionConvMockEnvNumpy - if from_pixels - else ContinuousActionVecMockEnv - ) - env_maker = transformed_env_constructor( - cfg, - use_env_creator=False, - custom_env_maker=env_maker, - stats={"loc": 0.0, "scale": 1.0}, - ) - proof_environment = env_maker() - - model = make_redq_model( - proof_environment, - device=device, - cfg=cfg, - ) - actor, qvalue = model - td = proof_environment.reset().to(device) - with set_exploration_type(exploration): - actor(td) - expected_keys = [ - "done", - "terminated", - "action", - "sample_log_prob", - "loc", - "scale", - "step_count", - "is_init", - ] - if len(gsde): - expected_keys += ["_eps_gSDE"] - if from_pixels: - expected_keys += [ - "hidden", - "pixels", - "pixels_orig", - ] - else: - expected_keys += ["observation_vector", "observation_orig"] - - try: - assert set(td.keys()) == set(expected_keys) - except AssertionError: - proof_environment.close() - raise - - if cfg.gSDE: - tsf_loc = actor.module[0].module[-1].module.transform(td.get("loc")) - if exploration == ExplorationType.RANDOM: - with pytest.raises(AssertionError): - torch.testing.assert_close(td.get("action"), tsf_loc) - else: - torch.testing.assert_close(td.get("action"), tsf_loc) - - qvalue(td) - expected_keys = [ - "done", - "terminated", - "action", - "sample_log_prob", - "state_action_value", - "loc", - "scale", - "step_count", - "is_init", - ] - if len(gsde): - expected_keys += ["_eps_gSDE"] - if from_pixels: - expected_keys += [ - "hidden", - "pixels", - "pixels_orig", - ] - else: - expected_keys += ["observation_vector", "observation_orig"] - try: - assert set(td.keys()) == set(expected_keys) - except AssertionError: - proof_environment.close() - raise - proof_environment.close() - del proof_environment - - 
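# [Editor's note, not part of the patch] The removed test above covered
# make_redq_model, which this series also deletes; the actor/Q-value pair it
# built is expected to be rebuilt locally, e.g. in the REDQ example. A minimal
# sketch of that pattern, with illustrative sizes and keys (not the exact
# helper that was removed):

import torch
from torch import nn
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator

actor_net = nn.Sequential(
    MLP(out_features=2 * 6, num_cells=[256, 256]),  # loc and scale for 6 actions
    NormalParamExtractor(),
)
actor = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation_vector"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
qvalue = ValueOperator(
    MLP(out_features=1, num_cells=[256, 256]),
    in_keys=["observation_vector", "action"],
)
model = nn.ModuleList([actor, qvalue])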
@pytest.mark.parametrize("initial_seed", range(5)) def test_seed_generator(initial_seed): num_seeds = 100 diff --git a/test/test_modules.py b/test/test_modules.py index c9984e178c5..65c2d2613de 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -23,7 +23,6 @@ GRUCell, LSTM, LSTMCell, - LSTMNet, MultiAgentConvNet, MultiAgentMLP, OnlineDTActor, @@ -350,140 +349,6 @@ def test_noisy(layer_class, device, seed=0): torch.testing.assert_close(y1, y2) -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("out_features", [3, 4]) -@pytest.mark.parametrize("hidden_size", [8, 9]) -@pytest.mark.parametrize("num_layers", [1, 2]) -@pytest.mark.parametrize("has_precond_hidden", [True, False]) -def test_lstm_net( - device, - out_features, - hidden_size, - num_layers, - has_precond_hidden, - double_prec_fixture, -): - torch.manual_seed(0) - batch = 5 - time_steps = 6 - in_features = 7 - net = LSTMNet( - out_features, - { - "input_size": hidden_size, - "hidden_size": hidden_size, - "num_layers": num_layers, - }, - {"out_features": hidden_size}, - device=device, - ) - # test single step vs multi-step - x = torch.randn(batch, time_steps, in_features, device=device) - x_unbind = x.unbind(1) - tds_loop = [] - if has_precond_hidden: - hidden0_out0, hidden1_out0 = torch.randn( - 2, batch, time_steps, num_layers, hidden_size, device=device - ) - hidden0_out0[:, 1:] = 0.0 - hidden1_out0[:, 1:] = 0.0 - hidden0_out = hidden0_out0[:, 0] - hidden1_out = hidden1_out0[:, 0] - else: - hidden0_out, hidden1_out = None, None - hidden0_out0, hidden1_out0 = None, None - - for _x in x_unbind: - y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( - _x, hidden0_out, hidden1_out - ) - td = TensorDict( - { - "y": y, - "hidden0_in": hidden0_in, - "hidden1_in": hidden1_in, - "hidden0_out": hidden0_out, - "hidden1_out": hidden1_out, - }, - [batch], - ) - tds_loop.append(td) - tds_loop = torch.stack(tds_loop, 1) - - y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( - x, hidden0_out0, hidden1_out0 - ) - tds_vec = TensorDict( - { - "y": y, - "hidden0_in": hidden0_in, - "hidden1_in": hidden1_in, - "hidden0_out": hidden0_out, - "hidden1_out": hidden1_out, - }, - [batch, time_steps], - ) - torch.testing.assert_close(tds_vec["y"], tds_loop["y"]) - torch.testing.assert_close( - tds_vec["hidden0_out"][:, -1], tds_loop["hidden0_out"][:, -1] - ) - torch.testing.assert_close( - tds_vec["hidden1_out"][:, -1], tds_loop["hidden1_out"][:, -1] - ) - - -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("out_features", [3, 5]) -@pytest.mark.parametrize("hidden_size", [3, 5]) -def test_lstm_net_nobatch(device, out_features, hidden_size): - time_steps = 6 - in_features = 4 - net = LSTMNet( - out_features, - {"input_size": hidden_size, "hidden_size": hidden_size}, - {"out_features": hidden_size}, - device=device, - ) - # test single step vs multi-step - x = torch.randn(time_steps, in_features, device=device) - x_unbind = x.unbind(0) - tds_loop = [] - hidden0_in, hidden1_in, hidden0_out, hidden1_out = [ - None, - ] * 4 - for _x in x_unbind: - y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( - _x, hidden0_out, hidden1_out - ) - td = TensorDict( - { - "y": y, - "hidden0_in": hidden0_in, - "hidden1_in": hidden1_in, - "hidden0_out": hidden0_out, - "hidden1_out": hidden1_out, - }, - [], - ) - tds_loop.append(td) - tds_loop = torch.stack(tds_loop, 0) - - y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x.unsqueeze(0)) - tds_vec = TensorDict( - { - "y": 
y, - "hidden0_in": hidden0_in, - "hidden1_in": hidden1_in, - "hidden0_out": hidden0_out, - "hidden1_out": hidden1_out, - }, - [1, time_steps], - ).squeeze(0) - torch.testing.assert_close(tds_vec["y"], tds_loop["y"]) - torch.testing.assert_close(tds_vec["hidden0_out"][-1], tds_loop["hidden0_out"][-1]) - torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1]) - - @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch_size", [3, 5]) class TestPlanner: diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index c1940f75a8f..19250d72d54 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -117,13 +117,11 @@ def __init__( params: Optional[List[torch.Tensor]] = None, buffers: Optional[List[torch.Tensor]] = None, device: DEVICE_TYPING = "cpu", - dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, ): super(ModelBasedEnvBase, self).__init__( device=device, - dtype=dtype, batch_size=batch_size, run_type_checks=run_type_checks, ) diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index f44c4aa025c..1a9e1898780 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -27,11 +27,10 @@ def __init__( belief_shape: Tuple[int, ...], obs_decoder: TensorDictModule = None, device: DEVICE_TYPING = "cpu", - dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, ): super(DreamerEnv, self).__init__( - world_model, device=device, dtype=dtype, batch_size=batch_size + world_model, device=device, batch_size=batch_size ) self.obs_decoder = obs_decoder self.prior_shape = prior_shape diff --git a/torchrl/trainers/helpers/__init__.py b/torchrl/trainers/helpers/__init__.py index 2f7e65a4069..b09becdc15a 100644 --- a/torchrl/trainers/helpers/__init__.py +++ b/torchrl/trainers/helpers/__init__.py @@ -16,7 +16,7 @@ transformed_env_constructor, ) from .logger import LoggerConfig -from .losses import make_dqn_loss, make_redq_loss, make_target_updater -from .models import make_dqn_actor, make_dreamer, make_redq_model +from .losses import make_dqn_loss, make_target_updater +from .models import make_dqn_actor, make_dreamer from .replay_buffer import make_replay_buffer from .trainers import make_trainer From 48f86df48b536ffe315307ceecd9d89c26e851c7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 20:55:22 +0100 Subject: [PATCH 19/26] amend --- test/_utils_internal.py | 2 ++ test/test_helpers.py | 4 +--- torchrl/data/replay_buffers/storages.py | 8 -------- torchrl/objectives/value/functional.py | 1 - 4 files changed, 3 insertions(+), 12 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 4e99843d16f..b67d20e4055 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -36,6 +36,8 @@ # Specified for test_utils.py __version__ = "0.3" +from torchrl.modules import MLP + def CARTPOLE_VERSIONED(): # load gym diff --git a/test/test_helpers.py b/test/test_helpers.py index e9d07767a49..f468eddf6ed 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -7,11 +7,11 @@ import dataclasses import pathlib import sys - from time import sleep import pytest import torch + from _utils_internal import generate_seeds, get_default_devices from torchrl._utils import timeit @@ -38,8 +38,6 @@ FlattenObservation, TransformedEnv, ) -from torchrl.envs.utils import 
ExplorationType, set_exploration_type -from torchrl.modules.tensordict_module.common import _has_functorch from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.helpers.envs import ( EnvConfig, diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 425520f87d9..8e9e37bab85 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -27,19 +27,11 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.utils import _STRDTYPE2DTYPE from torch import multiprocessing as mp - from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten from torchrl._utils import _CKPT_BACKEND, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES -try: - from torchsnapshot.serialization import tensor_from_memoryview - - _has_ts = True -except ImportError: - _has_ts = False - SINGLE_TENSOR_BUFFER_NAME = os.environ.get( "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_" ) diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 2e3120cf185..d3ad8d93ca4 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -12,7 +12,6 @@ import torch -from tensordict import MemoryMappedTensor __all__ = [ "generalized_advantage_estimate", From 167275b18cb2ef609b240c20abaa8dfa47d4fde6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 20:58:09 +0100 Subject: [PATCH 20/26] amend --- torchrl/data/replay_buffers/storages.py | 5 ----- torchrl/envs/model_based/common.py | 3 +-- torchrl/envs/model_based/dreamer.py | 3 +-- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 8e9e37bab85..a036b39103e 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1280,11 +1280,6 @@ def __repr__(self): # Utils def _mem_map_tensor_as_tensor(mem_map_tensor) -> torch.Tensor: - if _CKPT_BACKEND == "torchsnapshot" and not _has_ts: - raise ImportError( - "the checkpointing backend is set to torchsnapshot but the library is not installed. Consider installing the library or switch to another backend. " - f"Supported backends are {_CKPT_BACKEND.backends}" - ) if isinstance(mem_map_tensor, torch.Tensor): # This will account for MemoryMappedTensors return mem_map_tensor diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 19250d72d54..f6b3f97cd4a 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -5,9 +5,8 @@ import abc import warnings -from typing import List, Optional, Union +from typing import List, Optional -import numpy as np import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 1a9e1898780..5609861c75f 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -3,9 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Optional, Tuple, Union +from typing import Optional, Tuple -import numpy as np import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule From 49b015df5a5413a9344cf725048c8a4d42da6dde Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 21:30:07 +0100 Subject: [PATCH 21/26] amend --- test/test_shared.py | 6 ++--- test/test_transforms.py | 45 +++++++++++++++------------------ torchrl/envs/libs/dm_control.py | 4 +-- torchrl/envs/libs/jumanji.py | 4 +-- 4 files changed, 28 insertions(+), 31 deletions(-) diff --git a/test/test_shared.py b/test/test_shared.py index 912f230e8cf..bc2638269e6 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -64,7 +64,7 @@ def test_shared(self, indexing_method): batch_size=[], ).share_memory_() elif indexing_method == 1: - subtd = td.get_sub_tensordict(0) + subtd = td._get_sub_tensordict(0) elif indexing_method == 2: subtd = td[0] else: @@ -182,14 +182,14 @@ def test_memmap(idx, dtype, large_scale=False): torchrl_logger.info("\nTesting writing to TD") for i in range(2): t0 = time.time() - sub_td_sm = td_sm.get_sub_tensordict(idx) + sub_td_sm = td_sm._get_sub_tensordict(idx) sub_td_sm.update_(td_to_copy) if i == 1: torchrl_logger.info(f"sm td: {time.time() - t0:4.4f} sec") torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a")) t0 = time.time() - sub_td_sm = td_memmap.get_sub_tensordict(idx) + sub_td_sm = td_memmap._get_sub_tensordict(idx) sub_td_sm.update_(td_to_copy) if i == 1: torchrl_logger.info(f"memmap td: {time.time() - t0:4.4f} sec") diff --git a/test/test_transforms.py b/test/test_transforms.py index c9d2fb8c031..4396ea79c41 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -785,7 +785,7 @@ def test_transform_env_clone(self): @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) - @pytest.mark.parametrize("padding", ["zeros", "constant", "same"]) + @pytest.mark.parametrize("padding", ["constant", "same"]) def test_transform_model(self, dim, N, padding): # test equivalence between transforms within an env and within a rb key1 = "observation" @@ -838,7 +838,7 @@ def test_transform_model(self, dim, N, padding): @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) - @pytest.mark.parametrize("padding", ["same", "zeros", "constant"]) + @pytest.mark.parametrize("padding", ["same", "constant"]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, dim, N, padding, rbclass): # test equivalence between transforms within an env and within a rb @@ -870,7 +870,7 @@ def test_transform_rb(self, dim, N, padding, rbclass): @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) - @pytest.mark.parametrize("padding", ["same", "zeros", "constant"]) + @pytest.mark.parametrize("padding", ["same", "constant"]) def test_transform_as_inverse(self, dim, N, padding): # test equivalence between transforms within an env and within a rb in_keys = ["observation", ("next", "observation")] @@ -987,7 +987,7 @@ def test_transform_no_env(self, device, d, batch_size, dim, N): assert v1 is not v2 @pytest.mark.skipif(not _has_gym, reason="gym required for this test") - @pytest.mark.parametrize("padding", ["zeros", "constant", "same"]) + @pytest.mark.parametrize("padding", ["constant", "same"]) @pytest.mark.parametrize("envtype", ["gym", "conv"]) def test_tranform_offline_against_online(self, padding, envtype): torch.manual_seed(0) @@ -1027,10 +1027,7 @@ def 
test_tranform_offline_against_online(self, padding, envtype): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)]) @pytest.mark.parametrize("d", range(2, 3)) - @pytest.mark.parametrize( - "dim", - [-3], - ) + @pytest.mark.parametrize("dim", [-3]) @pytest.mark.parametrize("N", [2, 4]) def test_transform_compose(self, device, d, batch_size, dim, N): key1 = "first key" @@ -4177,11 +4174,11 @@ def test_observationnorm( ) observation_spec = on.transform_observation_spec(observation_spec) if standard_normal: - assert (observation_spec.space.minimum == -loc / scale).all() - assert (observation_spec.space.maximum == (1 - loc) / scale).all() + assert (observation_spec.space.low == -loc / scale).all() + assert (observation_spec.space.high == (1 - loc) / scale).all() else: - assert (observation_spec.space.minimum == loc).all() - assert (observation_spec.space.maximum == scale + loc).all() + assert (observation_spec.space.low == loc).all() + assert (observation_spec.space.high == scale + loc).all() else: observation_spec = CompositeSpec( @@ -5097,9 +5094,9 @@ def test_keys_length_errors(self, in_keys, reset_keys, out_keys, batch=10): f"Could not match the env reset_keys {reset_keys} with the in_keys {in_keys}" ), ): - t.reset(td) + t._reset(td, td.empty()) else: - t.reset(td) + t._reset(td, td.empty()) class TestReward2Go(TransformBase): @@ -6149,8 +6146,8 @@ def test_transform_no_env(self, keys, batch, device): observation_spec ) assert observation_spec.shape == torch.Size([3, 16, 16]) - assert (observation_spec.space.minimum == 0).all() - assert (observation_spec.space.maximum == 1).all() + assert (observation_spec.space.low == 0).all() + assert (observation_spec.space.high == 1).all() else: observation_spec = CompositeSpec( { @@ -6198,8 +6195,8 @@ def test_transform_compose(self, keys, batch, device): observation_spec ) assert observation_spec.shape == torch.Size([3, 16, 16]) - assert (observation_spec.space.minimum == 0).all() - assert (observation_spec.space.maximum == 1).all() + assert (observation_spec.space.low == 0).all() + assert (observation_spec.space.high == 1).all() else: observation_spec = CompositeSpec( { @@ -8039,14 +8036,14 @@ def test_independent_reward_specs_from_shared_env(self): t1_reward_spec = t1.reward_spec t2_reward_spec = t2.reward_spec - assert t1_reward_spec.space.minimum == 0 - assert t1_reward_spec.space.maximum == 4 + assert t1_reward_spec.space.low == 0 + assert t1_reward_spec.space.high == 4 - assert t2_reward_spec.space.minimum == -2 - assert t2_reward_spec.space.maximum == 2 + assert t2_reward_spec.space.low == -2 + assert t2_reward_spec.space.high == 2 - assert base_env.reward_spec.space.minimum == -np.inf - assert base_env.reward_spec.space.maximum == np.inf + assert base_env.reward_spec.space.low == -np.inf + assert base_env.reward_spec.space.high == np.inf def test_allow_done_after_reset(self): base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True) diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 3e1aac917e0..96d392a76c2 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -80,8 +80,8 @@ def _dmcontrol_to_torchrl_spec_transform( shape = torch.Size([1]) return BoundedTensorSpec( shape=shape, - low=spec.minimum, - high=spec.maximum, + low=spec.low, + high=spec.high, dtype=dtype, device=device, ) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 5e7866864cd..3c51acead4d 100644 --- 
a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -67,8 +67,8 @@ def _jumanji_to_torchrl_spec_transform( dtype = numpy_to_torch_dtype_dict[spec.dtype] return BoundedTensorSpec( shape=shape, - low=np.asarray(spec.minimum), - high=np.asarray(spec.maximum), + low=np.asarray(spec.low), + high=np.asarray(spec.high), dtype=dtype, device=device, ) From 1568b2897e5d2436083657c239fe4b394a9b738e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 21:30:54 +0100 Subject: [PATCH 22/26] amend --- torchrl/data/replay_buffers/storages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index a036b39103e..a57c0fc94f2 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -29,7 +29,7 @@ from torch import multiprocessing as mp from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten -from torchrl._utils import _CKPT_BACKEND, implement_for, logger as torchrl_logger +from torchrl._utils import implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES SINGLE_TENSOR_BUFFER_NAME = os.environ.get( From e7988820bd155071aca5de9b14609a4b0b51a0e4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 21:41:26 +0100 Subject: [PATCH 23/26] amend --- tutorials/sphinx-tutorials/coding_dqn.py | 1 + tutorials/sphinx-tutorials/dqn_with_rnn.py | 2 +- tutorials/sphinx-tutorials/getting-started-1.py | 4 +--- tutorials/sphinx-tutorials/getting-started-5.py | 2 +- tutorials/sphinx-tutorials/torchrl_envs.py | 4 +--- 5 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 0977abfb1aa..46d0b992f69 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -104,6 +104,7 @@ try: multiprocessing.set_start_method("spawn" if is_sphinx else "fork") + mp_context = "fork" except RuntimeError: # If we can't set the method globally we can still run the parallel env with "fork" # This will fail on windows! Use "spawn" and put the script within `if __name__ == "__main__"` diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index ce849210a2c..38e53535336 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -307,7 +307,7 @@ # either by passing a string or an action-spec. This allows us to use # Categorical (sometimes called "sparse") encoding or the one-hot version of it. # -qval = QValueModule(action_space=env.action_spec) +qval = QValueModule(action_space=None, spec=env.action_spec) ###################################################################### # .. 
note:: diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index 27df0fafb6e..b9b5901e758 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -273,9 +273,7 @@ policy = TensorDictSequential( value_net, # writes action values in our tensordict - QValueModule( - action_space=env.action_spec - ), # Reads the "action_value" entry by default + QValueModule(spec=env.action_spec), # Reads the "action_value" entry by default ) ################################### diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py index 039e15fa035..7b664c34511 100644 --- a/tutorials/sphinx-tutorials/getting-started-5.py +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -66,7 +66,7 @@ value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64]) value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"]) -policy = Seq(value_net, QValueModule(env.action_spec)) +policy = Seq(value_net, QValueModule(spec=env.action_spec)) exploration_module = EGreedyModule( env.action_spec, annealing_num_steps=100_000, eps_init=0.5 ) diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 1cf65516d5e..5555f788569 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -145,9 +145,7 @@ # we can just generate a random action: -policy = TensorDictModule( - functools.partial(env.action_spec.rand, env=env), in_keys=[], out_keys=["action"] -) +policy = TensorDictModule(env.action_spec.rand, in_keys=[], out_keys=["action"]) policy(reset_data) From b4e20b7d4e1bd120a0905cd1f74873f60a82d033 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 21:41:43 +0100 Subject: [PATCH 24/26] amend --- torchrl/envs/libs/dm_control.py | 4 ++-- torchrl/modules/tensordict_module/actors.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 96d392a76c2..3e1aac917e0 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -80,8 +80,8 @@ def _dmcontrol_to_torchrl_spec_transform( shape = torch.Size([1]) return BoundedTensorSpec( shape=shape, - low=spec.low, - high=spec.high, + low=spec.minimum, + high=spec.maximum, dtype=dtype, device=device, ) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 870f68b7bf3..8561b026f3c 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -443,7 +443,7 @@ class QValueModule(TensorDictModuleBase): def __init__( self, - action_space: Optional[str], + action_space: Optional[str] = None, action_value_key: Optional[NestedKey] = None, action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, From 22bfbdfd80b6ec0bc71696b9472295599b1d8e2c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Apr 2024 11:40:50 +0100 Subject: [PATCH 25/26] amend --- docs/source/reference/trainers.rst | 2 -- tutorials/sphinx-tutorials/torchrl_envs.py | 1 - 2 files changed, 3 deletions(-) diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 821902b2ee2..e253ad7067e 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -195,8 +195,6 @@ Builders make_collector_offpolicy 
make_collector_onpolicy
     make_dqn_loss
-    make_redq_loss
-    make_redq_model
     make_replay_buffer
     make_target_updater
     make_trainer
diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py
index 5555f788569..f6a5518def7 100644
--- a/tutorials/sphinx-tutorials/torchrl_envs.py
+++ b/tutorials/sphinx-tutorials/torchrl_envs.py
@@ -8,7 +8,6 @@
 .. _envs_tuto:

 """
-import functools

 ##############################################################################
 #

From fae6da85a38278959a3b0a63d74cb6ec1ad8dffe Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 25 Apr 2024 13:06:41 +0100
Subject: [PATCH 26/26] amend

---
 .github/workflows/docs.yml | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index d48ca78a071..5480858cbcb 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -92,6 +92,8 @@ jobs:

   upload:
     needs: build-docs
+    if: github.repository == 'pytorch/rl' && github.event_name == 'push' &&
+      ((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag')
     permissions:
       contents: write
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
@@ -109,8 +111,9 @@ jobs:
           if [[ "${REF_TYPE}" == branch ]]; then
             if [[ "${REF_NAME}" == main ]]; then
               TARGET_FOLDER="${REF_NAME}"
-            else
-              TARGET_FOLDER="release-doc"
+            # Debug:
+            # else
+            #   TARGET_FOLDER="release-doc"
             fi
           elif [[ "${REF_TYPE}" == tag ]]; then
             case "${REF_NAME}" in
@@ -135,7 +138,9 @@ jobs:
           rsync -a "${RUNNER_ARTIFACT_DIR}"/ "${TARGET_FOLDER}"
           git add "${TARGET_FOLDER}" || true

-          if [[ "${TARGET_FOLDER}" == "main" ]] || [[ "${TARGET_FOLDER}" == "release-doc" ]]; then
+          # Debug
+          # if [[ "${TARGET_FOLDER}" == "main" ]] || [[ "${TARGET_FOLDER}" == "release-doc" ]]; then
+          if [[ "${TARGET_FOLDER}" == "main" ]]; then
             mkdir -p _static
             rm -rf _static/*
             cp -r "${TARGET_FOLDER}"/_static/* _static
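
[Editor's note, not part of the patch] For readers tracing the upload gating in this final patch: a minimal Python rendering of the gate added by the `if:` condition above. The function is illustrative, not part of the workflow; it models only the conditions visible in this hunk.

    def should_upload(repository: str, event: str, ref_type: str, ref_name: str) -> bool:
        # only pushes to pytorch/rl are considered at all
        if repository != "pytorch/rl" or event != "push":
            return False
        # branches upload only from main; any tag passes the job-level gate
        # (RC tags are filtered out later inside the upload script)
        return (ref_type == "branch" and ref_name == "main") or ref_type == "tag"

    assert should_upload("pytorch/rl", "push", "branch", "main")
    assert not should_upload("pytorch/rl", "push", "branch", "release-doc")
    assert should_upload("pytorch/rl", "push", "tag", "v0.4.0")
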