diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index bcf6b5066c3..5480858cbcb 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -108,37 +108,44 @@ jobs:
           REF_TYPE=${{ github.ref_type }}
           REF_NAME=${{ github.ref_name }}

-          # TODO: adopt this behaviour
-          # if [[ "${REF_TYPE}" == branch ]]; then
-          #   TARGET_FOLDER="${REF_NAME}"
-          # elif [[ "${REF_TYPE}" == tag ]]; then
-          #   case "${REF_NAME}" in
-          #     *-rc*)
-          #       echo "Aborting upload since this is an RC tag: ${REF_NAME}"
-          #       exit 0
-          #       ;;
-          #     *)
-          #       # Strip the leading "v" as well as the trailing patch version. For example:
-          #       # 'v0.15.2' -> '0.15'
-          #       TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/')
-          #       ;;
-          #   esac
-          # fi
-          TARGET_FOLDER="./"
+          if [[ "${REF_TYPE}" == branch ]]; then
+            if [[ "${REF_NAME}" == main ]]; then
+              TARGET_FOLDER="${REF_NAME}"
+            # Debug:
+            # else
+            #   TARGET_FOLDER="release-doc"
+            fi
+          elif [[ "${REF_TYPE}" == tag ]]; then
+            case "${REF_NAME}" in
+              *-rc*)
+                echo "Aborting upload since this is an RC tag: ${REF_NAME}"
+                exit 0
+                ;;
+              *)
+                # Strip the leading "v" as well as the trailing patch version. For example:
+                # 'v0.15.2' -> '0.15'
+                TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/')
+                ;;
+            esac
+          fi
+
           echo "Target Folder: ${TARGET_FOLDER}"
-          # mkdir -p "${TARGET_FOLDER}"
-          # rm -rf "${TARGET_FOLDER}"/*
+          mkdir -p "${TARGET_FOLDER}"
+          rm -rf "${TARGET_FOLDER}"/*
+          echo $(ls "${RUNNER_ARTIFACT_DIR}")
           rsync -a "${RUNNER_ARTIFACT_DIR}"/ "${TARGET_FOLDER}"
           git add "${TARGET_FOLDER}" || true

-          # if [[ "${TARGET_FOLDER}" == main ]]; then
-          #   mkdir -p _static
-          #   rm -rf _static/*
-          #   cp -r "${TARGET_FOLDER}"/_static/* _static
-          #   git add _static || true
-          # fi
+          # Debug
+          # if [[ "${TARGET_FOLDER}" == "main" ]] || [[ "${TARGET_FOLDER}" == "release-doc" ]]; then
+          if [[ "${TARGET_FOLDER}" == "main" ]] ; then
+            mkdir -p _static
+            rm -rf _static/*
+            cp -r "${TARGET_FOLDER}"/_static/* _static
+            git add _static || true
+          fi

           git config user.name 'pytorchbot'
           git config user.email 'soumith+bot@pytorch.org'
diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst
index 8ad4ecef5ef..74bd058b8f0 100644
--- a/docs/source/reference/collectors.rst
+++ b/docs/source/reference/collectors.rst
@@ -82,7 +82,7 @@ In each of these cases, the last dimension (``T`` for ``time``) is adapted such
 that the batch size equals the ``frames_per_batch`` argument passed to the
 collector.

-.. warning:: :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` should not be
+.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be
   used with ``cat_results=0``, as the data will be stacked along the batch
   dimension with batched environment, or the time dimension for single environments,
   which can introduce some confusion when swapping one with the other.
@@ -91,12 +91,12 @@ collector.
   better interchangeability between configurations, collector classes and other
   components.

-Whereas :class:`~torchrl.collectors.collectors.MultiSyncDataCollector`
+Whereas :class:`~torchrl.collectors.MultiSyncDataCollector`
 has a dimension corresponding to the number of sub-collectors being run (``B``),
-:class:`~torchrl.collectors.collectors.MultiaSyncDataCollector` doesn't. This
-is easily understood when considering that :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`
+:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. This
+is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector`
 delivers batches of data on a first-come, first-serve basis, whereas
-:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` gathers data from
+:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from
 each sub-collector before delivering it.

 Collectors and replay buffers interoperability
 ----------------------------------------------
@@ -168,7 +168,7 @@ batches written in the buffer won't come from the same source (thereby interrupt
 Single node data collectors
 ---------------------------

-.. currentmodule:: torchrl.collectors.collectors
+.. currentmodule:: torchrl.collectors

 .. autosummary::
     :toctree: generated/
@@ -178,7 +178,6 @@ Single node data collectors

     SyncDataCollector
     MultiSyncDataCollector
     MultiaSyncDataCollector
-    RandomPolicy
     aSyncDataCollector
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 5c39c5a1349..5d4b6d0b7b5 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -849,14 +849,15 @@ Helpers
     :toctree: generated/
     :template: rl_template_fun.rst

-    step_mdp
-    get_available_libraries
-    set_exploration_mode #deprecated
-    set_exploration_type
+    RandomPolicy
+    check_env_specs
     exploration_mode #deprecated
     exploration_type
-    check_env_specs
+    get_available_libraries
     make_composite_from_td
+    set_exploration_mode #deprecated
+    set_exploration_type
+    step_mdp
     terminated_or_truncated

 Domain-specific
diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst
index 821902b2ee2..e253ad7067e 100644
--- a/docs/source/reference/trainers.rst
+++ b/docs/source/reference/trainers.rst
@@ -195,8 +195,6 @@ Builders
     make_collector_offpolicy
     make_collector_onpolicy
     make_dqn_loss
-    make_redq_loss
-    make_redq_model
     make_replay_buffer
     make_target_updater
     make_trainer
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index bc512a585b7..e2c0b97c2fc 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -31,6 +31,8 @@
     WriterEnsemble,
 )
 from .rlhf import (
+    AdaptiveKLController,
+    ConstantKLController,
     create_infinite_iterator,
     get_dataloader,
     PairwiseDataset,
diff --git a/torchrl/data/rlhf/__init__.py b/torchrl/data/rlhf/__init__.py
index 93e232bf0ac..f0db092f2d1 100644
--- a/torchrl/data/rlhf/__init__.py
+++ b/torchrl/data/rlhf/__init__.py
@@ -11,4 +11,4 @@
 )
 from .prompt import PromptData, PromptTensorDictTokenizer
 from .reward import PairwiseDataset, RewardData
-from .utils import RolloutFromModel
+from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index b3026da35ca..602f66bbd97 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -1568,6 +1568,7 @@ def _shutdown_workers(self) -> None:
             if self._verbose:
                 torchrl_logger.info(f"closing {i}")
             channel.send(("close", None))
+
         for i in range(self.num_workers):
             self._events[i].wait(self._timeout)
             self._events[i].clear()
diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py
index 5e7866864cd..3c51acead4d 100644
--- a/torchrl/envs/libs/jumanji.py
+++ b/torchrl/envs/libs/jumanji.py
@@ -67,8 +67,8 @@ def _jumanji_to_torchrl_spec_transform(
         dtype = numpy_to_torch_dtype_dict[spec.dtype]
         return BoundedTensorSpec(
             shape=shape,
-            low=np.asarray(spec.minimum),
-            high=np.asarray(spec.maximum),
+            low=np.asarray(spec.low),
+            high=np.asarray(spec.high),
             dtype=dtype,
             device=device,
         )
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index db742f31181..c6583349948 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -3120,8 +3120,11 @@ def unfold_done(done, N):
                 reset_vals = reset_vals[1:]
                 reps.extend([reset_vals[0]] * int(j))
                 j_ = j
-            reps = torch.stack(reps)
-            data = torch.masked_scatter(data, done_mask_expand, reps.reshape(-1))
+            if reps:
+                reps = torch.stack(reps)
+                data = torch.masked_scatter(
+                    data, done_mask_expand, reps.reshape(-1)
+                )

         if first_val is not None:
             # Aggregate reset along last dim
diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py
index 17166453cba..95b196e4b2e 100644
--- a/tutorials/sphinx-tutorials/coding_ddpg.py
+++ b/tutorials/sphinx-tutorials/coding_ddpg.py
@@ -339,7 +339,7 @@ def _loss_value(
             batch_size=self.target_actor_network_params.batch_size,
             device=self.target_actor_network_params.device,
         )
-        with target_params.to_module(self.value_estimator):
+        with target_params.to_module(self.actor_critic):
             target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

         # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py
index dbce0c29804..46d0b992f69 100644
--- a/tutorials/sphinx-tutorials/coding_dqn.py
+++ b/tutorials/sphinx-tutorials/coding_dqn.py
@@ -104,17 +104,20 @@
 try:
     multiprocessing.set_start_method("spawn" if is_sphinx else "fork")
+    mp_context = "fork"
 except RuntimeError:
+    # If we can't set the method globally, we can still run the parallel env with "fork"
+    # This will fail on Windows! Use "spawn" and put the script within `if __name__ == "__main__"`
+    mp_context = "fork"
     pass
-
 # sphinx_gallery_end_ignore

 import os
 import uuid

 import torch
 from torch import nn
-from torchrl.collectors import MultiaSyncDataCollector
+from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
 from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
 from torchrl.envs import (
     EnvCreator,
@@ -217,20 +220,26 @@ def is_notebook() -> bool:
 def make_env(
     parallel=False,
     obs_norm_sd=None,
+    num_workers=1,
 ):
     if obs_norm_sd is None:
         obs_norm_sd = {"standard_normal": True}
     if parallel:
+
+        def maker():
+            return GymEnv(
+                "CartPole-v1",
+                from_pixels=True,
+                pixels_only=True,
+                device=device,
+            )
+
         base_env = ParallelEnv(
             num_workers,
-            EnvCreator(
-                lambda: GymEnv(
-                    "CartPole-v1",
-                    from_pixels=True,
-                    pixels_only=True,
-                    device=device,
-                )
-            ),
+            EnvCreator(maker),
+            # Don't create a sub-process if we have only one worker
+            serial_for_single=True,
+            mp_start_method=mp_context,
         )
     else:
         base_env = GymEnv(
@@ -279,6 +288,7 @@ def get_norm_stats():
     # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
print("state dict of the observation norm:", obs_norm_sd) test_env.close() + del test_env return obs_norm_sd @@ -426,8 +436,15 @@ def get_collector( total_frames, device, ): - cls = MultiaSyncDataCollector - env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors + # We can't use nested child processes with mp_start_method="fork" + if is_fork: + cls = SyncDataCollector + env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers) + else: + cls = MultiaSyncDataCollector + env_arg = [ + make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers) + ] * num_collectors data_collector = cls( env_arg, policy=actor_explore, diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index a2b2b12b562..38e53535336 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -307,7 +307,7 @@ # either by passing a string or an action-spec. This allows us to use # Categorical (sometimes called "sparse") encoding or the one-hot version of it. # -qval = QValueModule(action_space=env.action_spec) +qval = QValueModule(action_space=None, spec=env.action_spec) ###################################################################### # .. note:: @@ -412,7 +412,7 @@ # utd = 16 -pbar = tqdm.tqdm(total=1_000_000) +pbar = tqdm.tqdm(total=collector.total_frames) longest = 0 traj_lens = [] diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index 27df0fafb6e..b9b5901e758 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -273,9 +273,7 @@ policy = TensorDictSequential( value_net, # writes action values in our tensordict - QValueModule( - action_space=env.action_spec - ), # Reads the "action_value" entry by default + QValueModule(spec=env.action_spec), # Reads the "action_value" entry by default ) ################################### diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py index 039e15fa035..7b664c34511 100644 --- a/tutorials/sphinx-tutorials/getting-started-5.py +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -66,7 +66,7 @@ value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64]) value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"]) -policy = Seq(value_net, QValueModule(env.action_spec)) +policy = Seq(value_net, QValueModule(spec=env.action_spec)) exploration_module = EGreedyModule( env.action_spec, annealing_num_steps=100_000, eps_init=0.5 ) diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index 68cb995a1a3..4475fb6a9e4 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -12,6 +12,8 @@ # sphinx_gallery_start_ignore import warnings +from tensordict import LazyStackedTensorDict + warnings.filterwarnings("ignore") from torch import multiprocessing @@ -31,7 +33,6 @@ # sphinx_gallery_end_ignore -import torch from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn @@ -77,9 +78,9 @@ tdreset1 = env1.reset() tdreset2 = env2.reset() -# In TorchRL, stacking is done in a lazy manner: the original tensordicts +# With LazyStackedTensorDict, stacking is done in a lazy manner: the original tensordicts # can still be recovered by indexing the main tensordict -tdreset = torch.stack([tdreset1, tdreset2], 0) +tdreset = 
 assert tdreset[0] is tdreset1

 ###############################################################################
diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
index e04f9a6e3aa..6de4eb2e5a0 100644
--- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
+++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
@@ -4,13 +4,6 @@
 ===========================================================================
 **Author**: `Matteo Bettini `_

-.. note::
-
-    If you are interested in Multi-Agent Reinforcement Learning (MARL) in
-    TorchRL, check out
-    `BenchMARL `__: a benchmarking library where you
-    can train and compare MARL sota-implementations, tasks, and models using TorchRL!
-
 This tutorial demonstrates how to use PyTorch and TorchRL to solve a
 Competitive Multi-Agent Reinforcement Learning (MARL) problem.

@@ -141,6 +134,12 @@
 from tqdm import tqdm

+# Check if we're building the doc, in which case disable video rendering
+try:
+    is_sphinx = __sphinx_build__
+except NameError:
+    is_sphinx = False
+
 ######################################################################
 # Define Hyperparameters
 # ----------------------
@@ -879,7 +878,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
 # logger `video_logger`. Note that this code may require some external dependencies such as torchvision.
 #

-if use_vmas:
+if use_vmas and not is_sphinx:
     # Replace tmpdir with any desired path where the video should be saved
     with tempfile.TemporaryDirectory() as tmpdir:
         video_logger = CSVLogger("vmas_logs", tmpdir, video_format="mp4")
diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py
index 0cdb5b70e22..29b9dbf9f9f 100644
--- a/tutorials/sphinx-tutorials/multiagent_ppo.py
+++ b/tutorials/sphinx-tutorials/multiagent_ppo.py
@@ -4,13 +4,6 @@
 ===============================================================
 **Author**: `Matteo Bettini `_

-.. note::
-
-    If you are interested in Multi-Agent Reinforcement Learning (MARL) in
-    TorchRL, check out
-    `BenchMARL `__: a benchmarking library where you
-    can train and compare MARL sota-implementations, tasks, and models using TorchRL!
-
 This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to solve a
 Multi-Agent Reinforcement Learning (MARL) problem.
diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py
index 3c0bad89e70..4070707de08 100644
--- a/tutorials/sphinx-tutorials/rb_tutorial.py
+++ b/tutorials/sphinx-tutorials/rb_tutorial.py
@@ -779,6 +779,7 @@ def assert0(x):
     ),
     Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
     GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
+    UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
     CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
 )
 rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16)
@@ -787,7 +788,7 @@ def assert0(x):


 ######################################################################
-# Let us sample one element from the buffer. The shape of the transformed
+# Let us sample one batch from the buffer. The shape of the transformed
 # pixel keys should have a length of 4 along the 4th dimension starting from
 # the end:
 #
diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py
index ccd036c546a..b8b5e448b41 100644
--- a/tutorials/sphinx-tutorials/torchrl_demo.py
+++ b/tutorials/sphinx-tutorials/torchrl_demo.py
@@ -438,6 +438,7 @@
 base_env = ParallelEnv(
     4,
     lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False),
+    mp_start_method="fork",  # This will break on Windows machines! Remove and decorate with if __name__ == "__main__"
 )
 env = TransformedEnv(
     base_env, Compose(StepCounter(), ToTensorImage())
 )
@@ -730,13 +731,17 @@ def exec_sequence(params, data):

 from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
-from torchrl.envs import EnvCreator, ParallelEnv
+from torchrl.envs import EnvCreator, SerialEnv
 from torchrl.envs.libs.gym import GymEnv

 ###############################################################################
 # EnvCreator makes sure that we can send a lambda function from process to process
+# We use a SerialEnv for simplicity, but for larger jobs a ParallelEnv would be better suited.

-parallel_env = ParallelEnv(3, EnvCreator(lambda: GymEnv("Pendulum-v1")))
+parallel_env = SerialEnv(
+    3,
+    EnvCreator(lambda: GymEnv("Pendulum-v1")),
+)
 create_env_fn = [parallel_env, parallel_env]

 actor_module = nn.Linear(3, 1)
diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py
index 4c792d44b80..f6a5518def7 100644
--- a/tutorials/sphinx-tutorials/torchrl_envs.py
+++ b/tutorials/sphinx-tutorials/torchrl_envs.py
@@ -8,6 +8,7 @@
 .. _envs_tuto:

 """
+
 ##############################################################################
 #
 # Environments play a crucial role in RL settings, often somewhat similar to
@@ -38,6 +39,8 @@
 # sphinx_gallery_start_ignore
 import warnings

+from tensordict.nn import TensorDictModule
+
 warnings.filterwarnings("ignore")

 from torch import multiprocessing
@@ -133,31 +136,29 @@
 torch.manual_seed(0)  # make sure that all torch code is also reproductible
 env.set_seed(0)

-tensordict = env.reset()
-print(tensordict)
+reset_data = env.reset()
+print("reset data", reset_data)

 ###############################################################################
 # We can now execute a step in the environment. Since we don't have a policy,
 # we can just generate a random action:

-def policy(tensordict, env=env):
-    tensordict.set("action", env.action_spec.rand())
-    return tensordict
+policy = TensorDictModule(env.action_spec.rand, in_keys=[], out_keys=["action"])

-policy(tensordict)
-tensordict_out = env.step(tensordict)
+policy(reset_data)
+tensordict_out = env.step(reset_data)

 ###############################################################################
 # By default, the tensordict returned by ``step`` is the same as the input...

-assert tensordict_out is tensordict
+assert tensordict_out is reset_data

 ###############################################################################
 # ... but with new keys

-tensordict
+tensordict_out

 ###############################################################################
 # What we just did (a random step using ``action_spec.rand()``) can also be
@@ -175,13 +176,14 @@ def policy(tensordict, env=env):

 from torchrl.envs.utils import step_mdp

-tensordict.set("some other key", torch.randn(1))
-tensordict_tprime = step_mdp(tensordict)
+tensordict_out.set("some other key", torch.randn(1))
+tensordict_tprime = step_mdp(tensordict_out)

 print(tensordict_tprime)
 print(
     (
-        tensordict_tprime.get("observation") == tensordict.get(("next", "observation"))
+        tensordict_tprime.get("observation")
+        == tensordict_out.get(("next", "observation"))
     ).all()
 )

@@ -548,7 +550,8 @@ def env_make():

 ###############################################################################

-parallel_env.start()
+if parallel_env.is_closed:
+    parallel_env.start()
 foo_list = parallel_env.foo
 foo_list  # needs to be instantiated, for instance using list