From 7cc7c6141d8c5e98b4fa232fa911332dbcd7fa99 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 14 Nov 2022 17:56:27 +0100 Subject: [PATCH 01/33] a2c --- a2c/a2c.py | 222 +++++++++++++++++++++++++++ a2c/config.yaml | 76 +++++++++ torchrl/objectives/__init__.py | 1 + torchrl/objectives/a2c.py | 143 +++++++++++++++++ torchrl/trainers/helpers/__init__.py | 1 + torchrl/trainers/helpers/losses.py | 41 +++++ 6 files changed, 484 insertions(+) create mode 100644 a2c/a2c.py create mode 100644 a2c/config.yaml create mode 100644 torchrl/objectives/a2c.py diff --git a/a2c/a2c.py b/a2c/a2c.py new file mode 100644 index 00000000000..54aee3398c8 --- /dev/null +++ b/a2c/a2c.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses +import os +import pathlib +import uuid +from datetime import datetime + +import hydra +import torch.cuda +from hydra.core.config_store import ConfigStore +from torchrl.envs import ParallelEnv, EnvCreator +from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.envs.utils import set_exploration_mode +from torchrl.objectives.value import GAE +from torchrl.record import VideoRecorder +from torchrl.trainers.helpers.collectors import ( + make_collector_onpolicy, + OnPolicyCollectorConfig, +) +from torchrl.trainers.helpers.envs import ( + correct_for_frame_skip, + get_stats_random_rollout, + parallel_env_constructor, + transformed_env_constructor, + EnvConfig, +) +from torchrl.trainers.helpers.logger import LoggerConfig +from torchrl.trainers.helpers.losses import make_a2c_loss, A2CLossConfig +from torchrl.trainers.helpers.models import ( + make_ppo_model as make_a2c_model, + PPOModelConfig as A2CModelConfig, +) +from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig + +config_fields = [ + (config_field.name, config_field.type, config_field) + for config_cls in ( + TrainerConfig, + OnPolicyCollectorConfig, + EnvConfig, + A2CLossConfig, + A2CModelConfig, + LoggerConfig, + ) + for config_field in dataclasses.fields(config_cls) +] + +Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) +cs = ConfigStore.instance() +cs.store(name="config", node=Config) + + +@hydra.main(version_base=None, config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + + cfg = correct_for_frame_skip(cfg) + + if not isinstance(cfg.reward_scaling, float): + cfg.reward_scaling = 1.0 + + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda:0") + ) + + exp_name = "_".join( + [ + "A2C", + cfg.exp_name, + str(uuid.uuid4())[:8], + datetime.now().strftime("%y_%m_%d-%H_%M_%S"), + ] + ) + if cfg.logger == "tensorboard": + from torchrl.trainers.loggers.tensorboard import TensorboardLogger + + logger = TensorboardLogger(log_dir="a2c_logging", exp_name=exp_name) + elif cfg.logger == "csv": + from torchrl.trainers.loggers.csv import CSVLogger + + logger = CSVLogger(log_dir="a2c_logging", exp_name=exp_name) + elif cfg.logger == "wandb": + from torchrl.trainers.loggers.wandb import WandbLogger + + logger = WandbLogger(log_dir="a2c_logging", exp_name=exp_name) + elif cfg.logger == "mlflow": + from torchrl.trainers.loggers.mlflow import MLFlowLogger + + logger = MLFlowLogger( + tracking_uri=pathlib.Path(os.path.abspath("a2c_logging")).as_uri(), + exp_name=exp_name, + ) + video_tag = exp_name if cfg.record_video else "" + + stats 
= None + if not cfg.vecnorm and cfg.norm_stats: + proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() + stats = get_stats_random_rollout( + cfg, + proof_env, + key="next_pixels" if cfg.from_pixels else "next_observation_vector", + ) + # make sure proof_env is closed + proof_env.close() + elif cfg.from_pixels: + stats = {"loc": 0.5, "scale": 0.5} + proof_env = transformed_env_constructor( + cfg=cfg, use_env_creator=False, stats=stats + )() + + model = make_a2c_model( + proof_env, + cfg=cfg, + device=device, + ) + actor_model = model.get_policy_operator() + + loss_module = make_a2c_loss(model, cfg) + if cfg.gSDE: + with torch.no_grad(), set_exploration_mode("random"): + # get dimensions to build the parallel env + proof_td = model(proof_env.reset().to(device)) + action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:] + del proof_td + else: + action_dim_gsde, state_dim_gsde = None, None + + proof_env.close() + create_env_fn = parallel_env_constructor( + cfg=cfg, + stats=stats, + action_dim_gsde=action_dim_gsde, + state_dim_gsde=state_dim_gsde, + ) + + collector = make_collector_onpolicy( + make_env=create_env_fn, + actor_model_explore=actor_model, + cfg=cfg, + # make_env_kwargs=[ + # {"device": device} if device >= 0 else {} + # for device in cfg.env_rendering_devices + # ], + ) + + recorder = transformed_env_constructor( + cfg, + video_tag=video_tag, + norm_obs_only=True, + stats=stats, + logger=logger, + use_env_creator=False, + )() + + # remove video recorder from recorder to have matching state_dict keys + if cfg.record_video: + recorder_rm = TransformedEnv(recorder.base_env) + for transform in recorder.transform: + if not isinstance(transform, VideoRecorder): + recorder_rm.append_transform(transform) + else: + recorder_rm = recorder + + if isinstance(create_env_fn, ParallelEnv): + recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) + create_env_fn.close() + elif isinstance(create_env_fn, EnvCreator): + recorder_rm.load_state_dict(create_env_fn().state_dict()) + else: + recorder_rm.load_state_dict(create_env_fn.state_dict()) + + # reset reward scaling + for t in recorder.transform: + if isinstance(t, RewardScaling): + t.scale.fill_(1.0) + t.loc.fill_(0.0) + + trainer = make_trainer( + collector, + loss_module, + recorder, + None, + actor_model, + None, + logger, + cfg, + ) + if cfg.loss == "kl": + trainer.register_op("pre_optim_steps", loss_module.reset) + + if not cfg.advantage_in_loss: + critic_model = model.get_value_operator() + advantage = GAE( + cfg.gamma, + cfg.lmbda, + value_network=critic_model, + average_rewards=True, + gradient_mode=False, + ) + trainer.register_op( + "process_optim_batch", + advantage, + ) + trainer._process_optim_batch_ops = [ + trainer._process_optim_batch_ops[-1], + *trainer._process_optim_batch_ops[:-1], + ] + + final_seed = collector.set_seed(cfg.seed) + print(f"init seed: {cfg.seed}, final seed: {final_seed}") + + trainer.train() + return (logger.log_dir, trainer._log_dict) + + +if __name__ == "__main__": + main() diff --git a/a2c/config.yaml b/a2c/config.yaml new file mode 100644 index 00000000000..daea44ddb64 --- /dev/null +++ b/a2c/config.yaml @@ -0,0 +1,76 @@ +# Environment +env_library: gym # env_library used for the simulated environment. +env_name: HalfCheetah-v4 name of the environment to be created. Default=Humanoid-v2 +env_task: run # task (if any) for the environment. +from_pixels: False # whether the environment output should be state vector(s) (default) or the pixels. 
+frame_skip: 1 # frame_skip for the environment. +reward_scaling: 1.0 # scale of the reward. +reward_loc: 0.0 # location of the reward. +init_env_steps: 1000 # number of random steps to compute normalizing constants +vecnorm: True # Normalizes the environment observation and reward outputs with the running statistics obtained across processes. +norm_rewards: False # If True, rewards will be normalized on the fly. +norm_stats: True # Deactivates the normalization based on random collection of data. +noops: 0 # number of random steps to do after reset. Default is 0 +catframes: 0 # Number of frames to concatenate through time. Default is 0 (do not use CatFrames). +center_crop: False # center crop size. +grayscale: True # Disables grayscale transform. +max_frames_per_traj: 1000 # Number of steps before a reset of the environment is called (if it has not been flagged as done before). +batch_transform: False # if True, the transforms will be applied to the parallel env, and not to each individual env.\ +image_size: 84 # if True and environment has discrete action space, then it is encoded as categorical values rather than one-hot. +categorical_action_encoding: False + +# Logger +logger: wandb # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv' +record_video: False # whether a video of the task should be rendered during logging. +no_video: True # whether a video of the task should be rendered during logging. +exp_name: "" # experiment name. Used for logging directory. +record_interval: 1000 # number of batch collections in between two collections of validation rollouts. Default=1000. +record_frames: 1000 # number of steps in validation rollouts. " "Default=1000. +recorder_log_keys: ["reward"] # Keys to log in the recorder +offline_logging: True # If True, Wandb will do the logging offline + +# Collector +collector_devices: ["cpu"] # device on which the data collector should store the trajectories to be passed to this script. +pin_memory: False # if True, the data collector will call pin_memory before dispatching tensordicts onto the passing device +init_with_lag: False # if True, the first trajectory will be truncated earlier at a random step. +frames_per_batch: 128 # Number of steps executed in the environment per collection. +total_frames: 1_000_000 # total number of frames collected for training. Does account for frame_skip. +num_workers: 4 # Number of workers used for data collection. +env_per_collector: 4 # Number of environments per collector. If the env_per_collector is in the range: +# 1 < env_per_collector <= num_workers, then the collector runs +# ceil(num_workers/env_per_collector) in parallel and executes the policy steps synchronously +# for each of these parallel wrappers. If env_per_collector=num_workers, no parallel wrapper is created +seed: 42 # seed used for the environment, pytorch and numpy. +exploration_mode: "" # exploration mode of the data collector. +async_collection: False # Whether data collection should be done asynchrously. + +# Model +gSDE: True # if True, exploration is achieved using the gSDE technique. +tanh_loc: False # if True, uses a Tanh-Normal transform for the policy location of the form +# upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions) +default_policy_scale: 1.0 # Default policy scale parameter +distribution: tanh_normal # if True, uses a Tanh-Normal-Tanh distribution for the policy +lstm: False # if True, uses an LSTM for the policy. 
+shared_mapping: False # if True, the first layers of the actor-critic are shared. + +# Objective +entropy_coef: 0.0 # Entropy factor for the A2C loss +critic_coef: 0.4 # Critic factor for the A2C loss +critic_loss_function: smooth_l1 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). +advantage_in_loss: False # if True, the advantage is computed on the sub-batch + +# Trainer +optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data. +optimizer: "adam" # Optimizer to be used. +lr_scheduler: "cosine" # LR scheduler. +selected_keys: null # a list of strings that indicate the data that should be kept from the data collector. +batch_size: 128 # batch size of the TensorDict retrieved from the replay buffer. Default=256. +log_interval: 1 # logging interval, in terms of optimization steps. Default=10000. +lr: 5e-4 # Learning rate used for the optimizer. Default=3e-4. +weight_decay: 0.0 # Weight-decay to be used with the optimizer. Default=0.0. +clip_norm: 1000.0 # value at which the total gradient norm / single derivative should be clipped. Default=1000.0 +clip_grad_norm: False # if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold. +normalize_rewards_online: False # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module. +normalize_rewards_online_scale: 1.0 # Final scale of the normalized rewards. +normalize_rewards_online_decay: 0.9999 # Decay of the reward moving averaging +sub_traj_len: -1 # length of the trajectories that sub-samples must have in online settings. \ No newline at end of file diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 4f533862694..d3b99bb23e3 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -7,6 +7,7 @@ from .ddpg import DDPGLoss from .dqn import DQNLoss, DistributionalDQNLoss from .dreamer import DreamerValueLoss, DreamerActorLoss, DreamerModelLoss +from .a2c import A2CLoss from .ppo import PPOLoss, ClipPPOLoss, KLPENPPOLoss from .redq import REDQLoss from .sac import SACLoss diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py new file mode 100644 index 00000000000..fdd9d836db1 --- /dev/null +++ b/torchrl/objectives/a2c.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings +from typing import Callable, Optional, Tuple + +import torch +from torch import distributions as d + +from tensordict.tensordict import TensorDictBase, TensorDict +from torchrl.modules import TensorDictModule +from torchrl.objectives.utils import distance_loss +from torchrl.modules.tensordict_module import ProbabilisticTensorDictModule +from torchrl.objectives.common import LossModule + + +class A2CLoss(LossModule): + """TorchRL implementation of the A2C loss. + + A2C (Advantage Actor Critic) is a model-free, online RL algorithm that uses parallel rollouts of n steps to + update the policy, relying on the REINFORCE estimator to compute the gradient. It also adds an entropy term to the + objective function to improve exploration. + + For more details regarding A2C, refer to: "Asynchronous Methods for Deep Reinforcment Learning", + https://arxiv.org/abs/1602.01783v2 + + Args: + actor (ProbabilisticTensorDictModule): policy operator. 
+ critic (ValueOperator): value operator. + advantage_key (str): the input tensordict key where the advantage is expected to be written. + default: "advantage" + advantage_diff_key (str): the input tensordict key where advantage_diff is expected to be written. + default: "value_error" + entropy_coef (float): the weight of the entropy loss. + critic_coef (float): the weight of the critic loss. + gamma (scalar): a discount factor for return computation. + loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". + advantage_module (nn.Module): TensorDictModule used to compute the advantage function. + """ + + def __init__( + self, + actor: ProbabilisticTensorDictModule, + critic: TensorDictModule, + advantage_key: str = "advantage", + advantage_diff_key: str = "value_error", + entropy_bonus: bool = True, + samples_mc_entropy: int = 1, + entropy_coef: float = 0.01, + critic_coef: float = 1.0, + gamma: float = 0.99, + loss_critic_type: str = "smooth_l1", + advantage_module: Callable[[TensorDictBase], TensorDictBase] = None, + ): + super().__init__() + self.convert_to_functional(actor, "actor") + self.convert_to_functional(critic, "critic", compare_against=self.actor_params) + self.advantage_key = advantage_key + self.advantage_diff_key = advantage_diff_key + self.samples_mc_entropy = samples_mc_entropy + self.entropy_bonus = entropy_bonus and entropy_coef + self.register_buffer( + "entropy_coef", torch.tensor(entropy_coef, device=self.device) + ) + self.register_buffer( + "critic_coef", torch.tensor(critic_coef, device=self.device) + ) + self.register_buffer("gamma", torch.tensor(gamma, device=self.device)) + self.loss_critic_type = loss_critic_type + self.advantage_module = advantage_module.to(self.device) + + def reset(self) -> None: + pass + + def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: + try: + entropy = dist.entropy() + except NotImplementedError: + x = dist.rsample((self.samples_mc_entropy,)) + entropy = -dist.log_prob(x) + return entropy.unsqueeze(-1) + + def _log_probs( + self, tensordict: TensorDictBase + ) -> Tuple[torch.Tensor, d.Distribution]: + # current log_prob of actions + action = tensordict.get("action") + if action.requires_grad: + raise RuntimeError("tensordict stored action requires grad.") + tensordict_clone = tensordict.select(*self.actor.in_keys).clone() + + dist, *_ = self.actor.get_dist( + tensordict_clone, params=self.actor_params, buffers=self.actor_buffers + ) + log_prob = dist.log_prob(action) + log_prob = log_prob.unsqueeze(-1) + return log_prob, dist + + def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: + if self.advantage_diff_key in tensordict.keys(): + advantage_diff = tensordict.get(self.advantage_diff_key) + if not advantage_diff.requires_grad: + raise RuntimeError( + "value_target retrieved from tensordict does not require grad."
+ ) + loss_value = distance_loss( + advantage_diff, + torch.zeros_like(advantage_diff), + loss_function=self.loss_critic_type, + ) + else: + advantage = tensordict.get(self.advantage_key) + tensordict_select = tensordict.select(*self.critic.in_keys) + value = self.critic( + tensordict_select, + params=self.critic_params, + buffers=self.critic_buffers, + ).get("state_value") + value_target = advantage + value.detach() + loss_value = distance_loss( + value, value_target, loss_function=self.loss_critic_type + ) + return self.critic_coef * loss_value + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + tensordict = self.advantage_module(tensordict) + tensordict = tensordict.clone() + advantage = tensordict.get(self.advantage_key) + log_probs, dist = self._log_probs(tensordict) + loss = - (log_probs * advantage) + td_out = TensorDict({"loss_objective": loss.mean()}, []) + if self.entropy_bonus: + entropy = self.get_entropy_bonus(dist) + td_out.set("entropy", entropy.mean().detach()) # for logging + td_out.set("loss_entropy", -self.entropy_coef * entropy.mean()) + if self.critic_coef: + loss_critic = self.loss_critic(tensordict).mean() + td_out.set("loss_critic", loss_critic.mean()) + return td_out + diff --git a/torchrl/trainers/helpers/__init__.py b/torchrl/trainers/helpers/__init__.py index 3d680a2d515..e6a51e0e13d 100644 --- a/torchrl/trainers/helpers/__init__.py +++ b/torchrl/trainers/helpers/__init__.py @@ -21,6 +21,7 @@ make_dqn_loss, make_ddpg_loss, make_target_updater, + make_a2c_loss, make_ppo_loss, make_redq_loss, ) diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index 7af12ef45e0..2638bfbb932 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -172,6 +172,31 @@ def make_dqn_loss(model, cfg) -> Tuple[DQNLoss, Optional[TargetNetUpdater]]: return loss_module, target_net_updater +def make_a2c_loss(model, cfg) -> A2CLoss: + """Builds the A2C loss module.""" + actor_model = model.get_policy_operator() + critic_model = model.get_value_operator() + + advantage = TDEstimate( + gamma=cfg.gamma, + value_network=critic_model, + average_rewards=True, + gradient_mode=False, + ) + + kwargs = { + "actor": actor_model, + "critic": critic_model, + "loss_critic_type": cfg.critic_loss_function, + "entropy_coef": cfg.entropy_coef, + "advantage_module": advantage + } + + loss_module = A2CLoss(**kwargs) + + return loss_module + + def make_ppo_loss(model, cfg) -> PPOLoss: """Builds the PPO loss module.""" loss_dict = { @@ -228,6 +253,22 @@ class LossConfig: # Target entropy for the policy distribution. Default is None (auto calculated as the `target_entropy = -action_dim`) +@dataclass +class A2CLossConfig: + """A2C Loss config struct.""" + + gamma: float = 0.99 + # Decay factor for return computation. Default=0.99. + entropy_coef: float = 1e-3 + # Entropy factor for the A2C loss + critic_coef: float = 1.0 + # Critic factor for the A2C loss + critic_loss_function: str = "smooth_l1" + # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). + advantage_in_loss: bool = False + # if True, the advantage is computed on the sub-batch. 
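For reference, a minimal usage sketch of the A2CLoss introduced above; actor, critic, advantage and batch are placeholders for the objects built by the helpers in this patch (the policy and value operators, the TDEstimate module created in make_a2c_loss, and a rollout TensorDict from the on-policy collector):

from torchrl.objectives import A2CLoss

# actor / critic: the policy and value operators of the actor-critic model
# advantage: an advantage estimator such as the TDEstimate built in make_a2c_loss
# batch: a TensorDict holding "action", the actor's observation keys and rewards
loss_module = A2CLoss(
    actor=actor,
    critic=critic,
    entropy_coef=1e-3,
    critic_coef=1.0,
    advantage_module=advantage,
)
losses = loss_module(batch)
# the trainer sums every "loss_*" entry of the output before calling backward()
total_loss = losses["loss_objective"] + losses["loss_entropy"] + losses["loss_critic"]
total_loss.backward()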
+ + @dataclass class PPOLossConfig: """PPO Loss config struct.""" From 5eac813260c4b8781339cbf91ab07784c7402a0f Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 14 Nov 2022 17:58:13 +0100 Subject: [PATCH 02/33] a2c --- {a2c => examples/a2c}/a2c.py | 2 +- {a2c => examples/a2c}/config.yaml | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename {a2c => examples/a2c}/a2c.py (98%) rename {a2c => examples/a2c}/config.yaml (100%) diff --git a/a2c/a2c.py b/examples/a2c/a2c.py similarity index 98% rename from a2c/a2c.py rename to examples/a2c/a2c.py index 54aee3398c8..8facd3b8b25 100644 --- a/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -54,7 +54,7 @@ cs.store(name="config", node=Config) -@hydra.main(version_base=None, config_path=".", config_name="config") +@hydra.main(version_base=None, config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 cfg = correct_for_frame_skip(cfg) diff --git a/a2c/config.yaml b/examples/a2c/config.yaml similarity index 100% rename from a2c/config.yaml rename to examples/a2c/config.yaml From 880fed3f8075be48f1842e273445e6c464a513df Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 15 Nov 2022 09:41:00 +0100 Subject: [PATCH 03/33] a2c config --- examples/a2c/a2c.py | 24 +-- examples/a2c/config.yaml | 37 ++-- torchrl/objectives/a2c.py | 7 +- torchrl/trainers/helpers/losses.py | 3 +- torchrl/trainers/helpers/models.py | 308 +++++++++++++++++++++++++++++ 5 files changed, 336 insertions(+), 43 deletions(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index 8facd3b8b25..c41ecdc7c3d 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -15,7 +15,7 @@ from torchrl.envs import ParallelEnv, EnvCreator from torchrl.envs.transforms import RewardScaling, TransformedEnv from torchrl.envs.utils import set_exploration_mode -from torchrl.objectives.value import GAE +from torchrl.objectives.value import TDEstimate from torchrl.record import VideoRecorder from torchrl.trainers.helpers.collectors import ( make_collector_onpolicy, @@ -157,23 +157,6 @@ def main(cfg: "DictConfig"): # noqa: F821 use_env_creator=False, )() - # remove video recorder from recorder to have matching state_dict keys - if cfg.record_video: - recorder_rm = TransformedEnv(recorder.base_env) - for transform in recorder.transform: - if not isinstance(transform, VideoRecorder): - recorder_rm.append_transform(transform) - else: - recorder_rm = recorder - - if isinstance(create_env_fn, ParallelEnv): - recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) - create_env_fn.close() - elif isinstance(create_env_fn, EnvCreator): - recorder_rm.load_state_dict(create_env_fn().state_dict()) - else: - recorder_rm.load_state_dict(create_env_fn.state_dict()) - # reset reward scaling for t in recorder.transform: if isinstance(t, RewardScaling): @@ -190,14 +173,11 @@ def main(cfg: "DictConfig"): # noqa: F821 logger, cfg, ) - if cfg.loss == "kl": - trainer.register_op("pre_optim_steps", loss_module.reset) if not cfg.advantage_in_loss: critic_model = model.get_value_operator() - advantage = GAE( + advantage = TDEstimate( cfg.gamma, - cfg.lmbda, value_network=critic_model, average_rewards=True, gradient_mode=False, diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml index daea44ddb64..5632b4bc531 100644 --- a/examples/a2c/config.yaml +++ b/examples/a2c/config.yaml @@ -1,9 +1,9 @@ # Environment env_library: gym # env_library used for the simulated environment. -env_name: HalfCheetah-v4 name of the environment to be created. 
Default=Humanoid-v2 +env_name: HalfCheetah-v4 # name of the environment to be created. Default=Humanoid-v2 env_task: run # task (if any) for the environment. from_pixels: False # whether the environment output should be state vector(s) (default) or the pixels. -frame_skip: 1 # frame_skip for the environment. +frame_skip: 2 # frame_skip for the environment. reward_scaling: 1.0 # scale of the reward. reward_loc: 0.0 # location of the reward. init_env_steps: 1000 # number of random steps to compute normalizing constants @@ -23,20 +23,20 @@ categorical_action_encoding: False logger: wandb # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv' record_video: False # whether a video of the task should be rendered during logging. no_video: True # whether a video of the task should be rendered during logging. -exp_name: "" # experiment name. Used for logging directory. +exp_name: A2C # experiment name. Used for logging directory. record_interval: 1000 # number of batch collections in between two collections of validation rollouts. Default=1000. record_frames: 1000 # number of steps in validation rollouts. " "Default=1000. recorder_log_keys: ["reward"] # Keys to log in the recorder offline_logging: True # If True, Wandb will do the logging offline # Collector -collector_devices: ["cpu"] # device on which the data collector should store the trajectories to be passed to this script. +collector_devices: [cpu] # device on which the data collector should store the trajectories to be passed to this script. pin_memory: False # if True, the data collector will call pin_memory before dispatching tensordicts onto the passing device init_with_lag: False # if True, the first trajectory will be truncated earlier at a random step. -frames_per_batch: 128 # Number of steps executed in the environment per collection. -total_frames: 1_000_000 # total number of frames collected for training. Does account for frame_skip. -num_workers: 4 # Number of workers used for data collection. -env_per_collector: 4 # Number of environments per collector. If the env_per_collector is in the range: +frames_per_batch: 256 # Number of steps executed in the environment per collection. +total_frames: 2_000_000 # total number of frames collected for training. Does account for frame_skip. +num_workers: 2 # Number of workers used for data collection. +env_per_collector: 2 # Number of environments per collector. If the env_per_collector is in the range: # 1 < env_per_collector <= num_workers, then the collector runs # ceil(num_workers/env_per_collector) in parallel and executes the policy steps synchronously # for each of these parallel wrappers. If env_per_collector=num_workers, no parallel wrapper is created @@ -45,7 +45,7 @@ exploration_mode: "" # exploration mode of the data collector. async_collection: False # Whether data collection should be done asynchrously. # Model -gSDE: True # if True, exploration is achieved using the gSDE technique. +gSDE: False # if True, exploration is achieved using the gSDE technique. tanh_loc: False # if True, uses a Tanh-Normal transform for the policy location of the form # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions) default_policy_scale: 1.0 # Default policy scale parameter @@ -54,7 +54,8 @@ lstm: False # if True, uses an LSTM for the policy. shared_mapping: False # if True, the first layers of the actor-critic are shared. 
# Objective -entropy_coef: 0.0 # Entropy factor for the A2C loss +gamma: 0.99 +entropy_coef: 1e-3 # Entropy factor for the A2C loss critic_coef: 0.4 # Critic factor for the A2C loss critic_loss_function: smooth_l1 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). advantage_in_loss: False # if True, the advantage is computed on the sub-batch @@ -62,15 +63,15 @@ advantage_in_loss: False # if True, the advantage is computed on the sub-batch # Trainer optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data. optimizer: "adam" # Optimizer to be used. -lr_scheduler: "cosine" # LR scheduler. +lr_scheduler: "" # LR scheduler. selected_keys: null # a list of strings that indicate the data that should be kept from the data collector. -batch_size: 128 # batch size of the TensorDict retrieved from the replay buffer. Default=256. +batch_size: 256 # batch size of the TensorDict retrieved from the replay buffer. Default=256. log_interval: 1 # logging interval, in terms of optimization steps. Default=10000. -lr: 5e-4 # Learning rate used for the optimizer. Default=3e-4. +lr: 2e-4 # Learning rate used for the optimizer. Default=3e-4. weight_decay: 0.0 # Weight-decay to be used with the optimizer. Default=0.0. -clip_norm: 1000.0 # value at which the total gradient norm / single derivative should be clipped. Default=1000.0 -clip_grad_norm: False # if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold. -normalize_rewards_online: False # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module. +clip_norm: 0.5 # value at which the total gradient norm / single derivative should be clipped. Default=1000.0 +clip_grad_norm: True # if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold. +normalize_rewards_online: True # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module. normalize_rewards_online_scale: 1.0 # Final scale of the normalized rewards. -normalize_rewards_online_decay: 0.9999 # Decay of the reward moving averaging -sub_traj_len: -1 # length of the trajectories that sub-samples must have in online settings. \ No newline at end of file +normalize_rewards_online_decay: 0.0 # Decay of the reward moving averaging +sub_traj_len: 64 # length of the trajectories that sub-samples must have in online settings. 
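For reference, a minimal sketch of how the defaults above can be composed and overridden through Hydra, mirroring the command-line overrides the example script accepts (assuming it is run from the examples/a2c directory; Hopper-v4 is only an example value):

from hydra import compose, initialize

# equivalent to launching: python a2c.py env_name=Hopper-v4 lr=3e-4 logger=csv
with initialize(config_path="."):
    cfg = compose(
        config_name="config",
        overrides=["env_name=Hopper-v4", "lr=3e-4", "logger=csv"],
    )
print(cfg.env_name, cfg.lr, cfg.logger)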
\ No newline at end of file diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index fdd9d836db1..699b3786a92 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -53,7 +53,7 @@ def __init__( critic_coef: float = 1.0, gamma: float = 0.99, loss_critic_type: str = "smooth_l1", - advantage_module: Callable[[TensorDictBase], TensorDictBase] = None, + advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, ): super().__init__() self.convert_to_functional(actor, "actor") @@ -126,7 +126,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: return self.critic_coef * loss_value def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - tensordict = self.advantage_module(tensordict) + if self.advantage_module is not None: + tensordict = self.advantage_module( + tensordict, + ) tensordict = tensordict.clone() advantage = tensordict.get(self.advantage_key) log_probs, dist = self._log_probs(tensordict) diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index 2638bfbb932..39181b4b158 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -14,6 +14,7 @@ DQNLoss, HardUpdate, KLPENPPOLoss, + A2CLoss, PPOLoss, SACLoss, SoftUpdate, @@ -24,7 +25,7 @@ # from torchrl.objectives.redq import REDQLoss from torchrl.objectives.utils import TargetNetUpdater -from torchrl.objectives.value.advantages import GAE +from torchrl.objectives.value.advantages import GAE, TDEstimate def make_target_updater( diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 3b9d7cd79bf..e8f829d15b5 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -414,6 +414,295 @@ def make_ddpg_actor( return module +def make_a2c_model( + proof_environment: EnvBase, + cfg: "DictConfig", # noqa: F821 + device: DEVICE_TYPING, + in_keys_actor: Optional[Sequence[str]] = None, + observation_key=None, + **kwargs, +) -> ActorValueOperator: + """Actor-value model constructor helper function. + + Currently constructs MLP networks with immutable default arguments as described in "Proximal Policy Optimization + Algorithms", https://arxiv.org/abs/1707.06347 + Other configurations can easily be implemented by modifying this function at will. + + Args: + proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec + cfg (DictConfig): contains arguments of the PPO script + device (torch.device): device on which the model must be cast. + in_keys_actor (iterable of strings, optional): observation key to be read by the actor, usually one of + `'observation_vector'` or `'pixels'`. If none is provided, one of these two keys is chosen based on + the `cfg.from_pixels` argument. + + Returns: + A joined ActorCriticOperator. + + Examples: + >>> from torchrl.trainers.helpers.envs import parser_env_args + >>> from torchrl.trainers.helpers.models import make_ppo_model, parser_model_args_continuous + >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose + >>> import hydra + >>> from hydra.core.config_store import ConfigStore + >>> import dataclasses + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), + ... 
CatTensors(["next_observation"], "next_observation_vector"))) + >>> device = torch.device("cpu") + >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in + ... (PPOModelConfig, EnvConfig) + ... for config_field in dataclasses.fields(config_cls)] + >>> Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) + >>> cs = ConfigStore.instance() + >>> cs.store(name="config", node=Config) + >>> with initialize(config_path=None): + >>> cfg = compose(config_name="config") + >>> actor_value = make_ppo_model( + ... proof_environment, + ... device=device, + ... cfg=cfg, + ... ) + >>> actor = actor_value.get_policy_operator() + >>> value = actor_value.get_value_operator() + >>> td = proof_environment.reset() + >>> print(actor(td.clone())) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + hidden: Tensor(torch.Size([300]), dtype=torch.float32), + loc: Tensor(torch.Size([6]), dtype=torch.float32), + scale: Tensor(torch.Size([6]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32), + sample_log_prob: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(value(td.clone())) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + hidden: Tensor(torch.Size([300]), dtype=torch.float32), + state_value: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + """ + # proof_environment.set_seed(cfg.seed) + specs = proof_environment.specs # TODO: use env.sepcs + action_spec = specs["action_spec"] + + if in_keys_actor is None and proof_environment.from_pixels: + in_keys_actor = ["pixels"] + in_keys_critic = ["pixels"] + elif in_keys_actor is None: + in_keys_actor = ["observation_vector"] + in_keys_critic = ["observation_vector"] + out_keys = ["action"] + + if action_spec.domain == "continuous": + out_features = (2 - cfg.gSDE) * action_spec.shape[-1] + if cfg.distribution == "tanh_normal": + policy_distribution_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": cfg.tanh_loc, + } + policy_distribution_class = TanhNormal + elif cfg.distribution == "truncated_normal": + policy_distribution_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": cfg.tanh_loc, + } + policy_distribution_class = TruncatedNormal + elif action_spec.domain == "discrete": + out_features = action_spec.shape[-1] + policy_distribution_kwargs = {} + policy_distribution_class = OneHotCategorical + else: + raise NotImplementedError( + f"actions with domain {action_spec.domain} are not supported" + ) + + if cfg.shared_mapping: + hidden_features = 300 + if proof_environment.from_pixels: + if in_keys_actor is None: + in_keys_actor = ["pixels"] + common_module = ConvNet( + bias_last_layer=True, + depth=None, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + else: + if cfg.lstm: + raise NotImplementedError( + "lstm not yet compatible with shared mapping for PPO" + ) + common_module = MLP( + num_cells=[ + 400, + ], + out_features=hidden_features, + activate_last_layer=True, + ) + common_operator = TensorDictModule( + spec=None, + module=common_module, + in_keys=in_keys_actor, + out_keys=["hidden"], + ) + + policy_net = MLP( + num_cells=[200], + 
out_features=out_features, + ) + if not cfg.gSDE: + actor_net = NormalParamWrapper( + policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + ) + in_keys = ["hidden"] + actor_module = TensorDictModule( + actor_net, in_keys=in_keys, out_keys=["loc", "scale"] + ) + else: + in_keys = ["hidden"] + gSDE_state_key = "hidden" + actor_module = TensorDictModule( + policy_net, + in_keys=in_keys, + out_keys=["action"], # will be overwritten + ) + + if action_spec.domain == "continuous": + min = action_spec.space.minimum + max = action_spec.space.maximum + transform = SafeTanhTransform() + if (min != -1).any() or (max != 1).any(): + transform = d.ComposeTransform( + transform, + d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2), + ) + else: + raise RuntimeError("cannot use gSDE with discrete actions") + + actor_module = TensorDictSequential( + actor_module, + TensorDictModule( + LazygSDEModule(transform=transform), + in_keys=["action", gSDE_state_key, "_eps_gSD"], + out_keys=["loc", "scale", "action", "_eps_gSDE"], + ), + ) + + policy_operator = ProbabilisticActor( + spec=CompositeSpec(action=action_spec), + module=actor_module, + dist_in_keys=["loc", "scale"], + default_interaction_mode="random", + distribution_class=policy_distribution_class, + distribution_kwargs=policy_distribution_kwargs, + return_log_prob=True, + ) + value_net = MLP( + num_cells=[200], + out_features=1, + ) + value_operator = ValueOperator(value_net, in_keys=["hidden"]) + actor_value = ActorValueOperator( + common_operator=common_operator, + policy_operator=policy_operator, + value_operator=value_operator, + ).to(device) + else: + if cfg.from_pixels: + raise RuntimeError( + "PPO learnt from pixels require the shared_mapping to be set to True." + ) + if cfg.lstm: + policy_net = LSTMNet( + out_features=out_features, + lstm_kwargs={"input_size": 256, "hidden_size": 256}, + mlp_kwargs={"num_cells": [256, 256], "out_features": 256}, + ) + in_keys_actor += ["hidden0", "hidden1"] + out_keys += ["hidden0", "hidden1", "next_hidden0", "next_hidden1"] + else: + policy_net = MLP( + num_cells=[64, 64], + out_features=out_features, + ) + + if not cfg.gSDE: + actor_net = NormalParamWrapper( + policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + ) + actor_module = TensorDictModule( + actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] + ) + else: + in_keys = in_keys_actor + gSDE_state_key = in_keys_actor[0] + actor_module = TensorDictModule( + policy_net, + in_keys=in_keys, + out_keys=["action"], # will be overwritten + ) + + if action_spec.domain == "continuous": + min = action_spec.space.minimum + max = action_spec.space.maximum + transform = SafeTanhTransform() + if (min != -1).any() or (max != 1).any(): + transform = d.ComposeTransform( + transform, + d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2), + ) + else: + raise RuntimeError("cannot use gSDE with discrete actions") + + actor_module = TensorDictSequential( + actor_module, + TensorDictModule( + LazygSDEModule(transform=transform), + in_keys=["action", gSDE_state_key, "_eps_gSDE"], + out_keys=["loc", "scale", "action", "_eps_gSDE"], + ), + ) + + policy_po = ProbabilisticActor( + actor_module, + spec=action_spec, + dist_in_keys=["loc", "scale"], + distribution_class=policy_distribution_class, + distribution_kwargs=policy_distribution_kwargs, + return_log_prob=True, + default_interaction_mode="random", + ) + + value_net = MLP( + num_cells=[64, 64], + out_features=1, + ) + value_po = ValueOperator( + value_net, + 
in_keys=in_keys_critic, + ) + actor_value = ActorCriticWrapper(policy_po, value_po).to(device) + + with torch.no_grad(), set_exploration_mode("random"): + td = proof_environment.rollout(max_steps=1000) + td_device = td.to(device) + td_device = actor_value(td_device) # for init + return actor_value + + def make_ppo_model( proof_environment: EnvBase, cfg: "DictConfig", # noqa: F821 @@ -1539,6 +1828,25 @@ class PPOModelConfig: # if True, the first layers of the actor-critic are shared. +@dataclass +class PPOModelConfig: + """PPO model config struct.""" + + gSDE: bool = False + # if True, exploration is achieved using the gSDE technique. + tanh_loc: bool = False + # if True, uses a Tanh-Normal transform for the policy location of the form + # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions) + default_policy_scale: float = 1.0 + # Default policy scale parameter + distribution: str = "tanh_normal" + # if True, uses a Tanh-Normal-Tanh distribution for the policy + lstm: bool = False + # if True, uses an LSTM for the policy. + shared_mapping: bool = False + # if True, the first layers of the actor-critic are shared. + + @dataclass class SACModelConfig: """SAC model config struct.""" From 4c2436c982c0ea93c2b9674ec099aba02a94da37 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 15 Nov 2022 09:42:18 +0100 Subject: [PATCH 04/33] a2c config --- examples/a2c/a2c.py | 4 ++-- torchrl/trainers/helpers/models.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index c41ecdc7c3d..cf8d692cc91 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -31,8 +31,8 @@ from torchrl.trainers.helpers.logger import LoggerConfig from torchrl.trainers.helpers.losses import make_a2c_loss, A2CLossConfig from torchrl.trainers.helpers.models import ( - make_ppo_model as make_a2c_model, - PPOModelConfig as A2CModelConfig, + make_a2c_model, + A2CModelConfig, ) from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index e8f829d15b5..c62fcad3f40 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -547,7 +547,7 @@ def make_a2c_model( ) common_module = MLP( num_cells=[ - 400, + 64, ], out_features=hidden_features, activate_last_layer=True, @@ -560,7 +560,7 @@ def make_a2c_model( ) policy_net = MLP( - num_cells=[200], + num_cells=[64], out_features=out_features, ) if not cfg.gSDE: @@ -611,7 +611,7 @@ def make_a2c_model( return_log_prob=True, ) value_net = MLP( - num_cells=[200], + num_cells=[64], out_features=1, ) value_operator = ValueOperator(value_net, in_keys=["hidden"]) @@ -628,8 +628,8 @@ def make_a2c_model( if cfg.lstm: policy_net = LSTMNet( out_features=out_features, - lstm_kwargs={"input_size": 256, "hidden_size": 256}, - mlp_kwargs={"num_cells": [256, 256], "out_features": 256}, + lstm_kwargs={"input_size": 64, "hidden_size": 64}, + mlp_kwargs={"num_cells": [64, 64], "out_features": 64}, ) in_keys_actor += ["hidden0", "hidden1"] out_keys += ["hidden0", "hidden1", "next_hidden0", "next_hidden1"] From 2b7743217fab1ab96f1f002d932ad577fafd432d Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 15 Nov 2022 09:44:44 +0100 Subject: [PATCH 05/33] fix imports --- torchrl/trainers/helpers/__init__.py | 1 + torchrl/trainers/helpers/models.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/torchrl/trainers/helpers/__init__.py 
b/torchrl/trainers/helpers/__init__.py index e6a51e0e13d..13a6a554694 100644 --- a/torchrl/trainers/helpers/__init__.py +++ b/torchrl/trainers/helpers/__init__.py @@ -28,6 +28,7 @@ from .models import ( make_dqn_actor, make_ddpg_actor, + make_a2c_model, make_ppo_model, make_sac_model, make_redq_model, diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index c62fcad3f40..962bb0e3047 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -1828,6 +1828,25 @@ class PPOModelConfig: # if True, the first layers of the actor-critic are shared. +@dataclass +class A2CModelConfig: + """PPO model config struct.""" + + gSDE: bool = False + # if True, exploration is achieved using the gSDE technique. + tanh_loc: bool = False + # if True, uses a Tanh-Normal transform for the policy location of the form + # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions) + default_policy_scale: float = 1.0 + # Default policy scale parameter + distribution: str = "tanh_normal" + # if True, uses a Tanh-Normal-Tanh distribution for the policy + lstm: bool = False + # if True, uses an LSTM for the policy. + shared_mapping: bool = False + # if True, the first layers of the actor-critic are shared. + + @dataclass class PPOModelConfig: """PPO model config struct.""" From ea4c3a27b20df3afbdd40d204fd5fcf027162fb4 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 15 Nov 2022 11:54:44 +0100 Subject: [PATCH 06/33] latest --- examples/a2c/a2c.py | 1 + examples/a2c/config.yaml | 14 +++++++------- torchrl/objectives/a2c.py | 4 +++- torchrl/trainers/helpers/losses.py | 15 +++++++++------ 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index cf8d692cc91..be7973099b8 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -182,6 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821 average_rewards=True, gradient_mode=False, ) + advantage = advantage.to(device) trainer.register_op( "process_optim_batch", advantage, diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml index 5632b4bc531..eb039d77d36 100644 --- a/examples/a2c/config.yaml +++ b/examples/a2c/config.yaml @@ -7,9 +7,9 @@ frame_skip: 2 # frame_skip for the environment. reward_scaling: 1.0 # scale of the reward. reward_loc: 0.0 # location of the reward. init_env_steps: 1000 # number of random steps to compute normalizing constants -vecnorm: True # Normalizes the environment observation and reward outputs with the running statistics obtained across processes. +vecnorm: False # Normalizes the environment observation and reward outputs with the running statistics obtained across processes. norm_rewards: False # If True, rewards will be normalized on the fly. -norm_stats: True # Deactivates the normalization based on random collection of data. +norm_stats: False # Deactivates the normalization based on random collection of data. noops: 0 # number of random steps to do after reset. Default is 0 catframes: 0 # Number of frames to concatenate through time. Default is 0 (do not use CatFrames). center_crop: False # center crop size. @@ -24,7 +24,7 @@ logger: wandb # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv' record_video: False # whether a video of the task should be rendered during logging. no_video: True # whether a video of the task should be rendered during logging. exp_name: A2C # experiment name. Used for logging directory. 
-record_interval: 1000 # number of batch collections in between two collections of validation rollouts. Default=1000. +record_interval: 100 # number of batch collections in between two collections of validation rollouts. Default=1000. record_frames: 1000 # number of steps in validation rollouts. " "Default=1000. recorder_log_keys: ["reward"] # Keys to log in the recorder offline_logging: True # If True, Wandb will do the logging offline @@ -55,13 +55,13 @@ shared_mapping: False # if True, the first layers of the actor-critic are shared # Objective gamma: 0.99 -entropy_coef: 1e-3 # Entropy factor for the A2C loss +entropy_coef: 0.1 # Entropy factor for the A2C loss critic_coef: 0.4 # Critic factor for the A2C loss critic_loss_function: smooth_l1 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). advantage_in_loss: False # if True, the advantage is computed on the sub-batch # Trainer -optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data. +optim_steps_per_batch: 12 # Number of optimization steps in between two collection of data. optimizer: "adam" # Optimizer to be used. lr_scheduler: "" # LR scheduler. selected_keys: null # a list of strings that indicate the data that should be kept from the data collector. @@ -69,8 +69,8 @@ batch_size: 256 # batch size of the TensorDict retrieved from the replay buffer. log_interval: 1 # logging interval, in terms of optimization steps. Default=10000. lr: 2e-4 # Learning rate used for the optimizer. Default=3e-4. weight_decay: 0.0 # Weight-decay to be used with the optimizer. Default=0.0. -clip_norm: 0.5 # value at which the total gradient norm / single derivative should be clipped. Default=1000.0 -clip_grad_norm: True # if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold. +clip_norm: 1000 # value at which the total gradient norm / single derivative should be clipped. Default=1000.0 +clip_grad_norm: False # if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold. normalize_rewards_online: True # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module. normalize_rewards_online_scale: 1.0 # Final scale of the normalized rewards. 
normalize_rewards_online_decay: 0.0 # Decay of the reward moving averaging diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 699b3786a92..58b09efcf65 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -70,7 +70,9 @@ def __init__( ) self.register_buffer("gamma", torch.tensor(gamma, device=self.device)) self.loss_critic_type = loss_critic_type - self.advantage_module = advantage_module.to(self.device) + self.advantage_module = advantage_module + if advantage_module: + self.advantage_module = advantage_module.to(self.device) def reset(self) -> None: pass diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index 39181b4b158..cc63bd69746 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -178,12 +178,15 @@ def make_a2c_loss(model, cfg) -> A2CLoss: actor_model = model.get_policy_operator() critic_model = model.get_value_operator() - advantage = TDEstimate( - gamma=cfg.gamma, - value_network=critic_model, - average_rewards=True, - gradient_mode=False, - ) + if cfg.advantage_in_loss: + advantage = TDEstimate( + gamma=cfg.gamma, + value_network=critic_model, + average_rewards=True, + gradient_mode=False, + ) + else: + advantage = None kwargs = { "actor": actor_model, From 5f4d2908e83f5fe11e500491e695a6a494ec844c Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 15 Nov 2022 15:31:46 +0100 Subject: [PATCH 07/33] simplified config --- examples/a2c/a2c.py | 4 --- examples/a2c/config.yaml | 53 +++++++--------------------------------- 2 files changed, 9 insertions(+), 48 deletions(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index be7973099b8..2adc91fdd3b 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -187,10 +187,6 @@ def main(cfg: "DictConfig"): # noqa: F821 "process_optim_batch", advantage, ) - trainer._process_optim_batch_ops = [ - trainer._process_optim_batch_ops[-1], - *trainer._process_optim_batch_ops[:-1], - ] final_seed = collector.set_seed(cfg.seed) print(f"init seed: {cfg.seed}, final seed: {final_seed}") diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml index eb039d77d36..08eeaf23c04 100644 --- a/examples/a2c/config.yaml +++ b/examples/a2c/config.yaml @@ -2,52 +2,21 @@ env_library: gym # env_library used for the simulated environment. env_name: HalfCheetah-v4 # name of the environment to be created. Default=Humanoid-v2 env_task: run # task (if any) for the environment. -from_pixels: False # whether the environment output should be state vector(s) (default) or the pixels. frame_skip: 2 # frame_skip for the environment. -reward_scaling: 1.0 # scale of the reward. -reward_loc: 0.0 # location of the reward. -init_env_steps: 1000 # number of random steps to compute normalizing constants -vecnorm: False # Normalizes the environment observation and reward outputs with the running statistics obtained across processes. -norm_rewards: False # If True, rewards will be normalized on the fly. -norm_stats: False # Deactivates the normalization based on random collection of data. -noops: 0 # number of random steps to do after reset. Default is 0 -catframes: 0 # Number of frames to concatenate through time. Default is 0 (do not use CatFrames). -center_crop: False # center crop size. -grayscale: True # Disables grayscale transform. -max_frames_per_traj: 1000 # Number of steps before a reset of the environment is called (if it has not been flagged as done before). 
-batch_transform: False # if True, the transforms will be applied to the parallel env, and not to each individual env.\ -image_size: 84 # if True and environment has discrete action space, then it is encoded as categorical values rather than one-hot. -categorical_action_encoding: False # Logger logger: wandb # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv' record_video: False # whether a video of the task should be rendered during logging. -no_video: True # whether a video of the task should be rendered during logging. exp_name: A2C # experiment name. Used for logging directory. record_interval: 100 # number of batch collections in between two collections of validation rollouts. Default=1000. -record_frames: 1000 # number of steps in validation rollouts. " "Default=1000. -recorder_log_keys: ["reward"] # Keys to log in the recorder -offline_logging: True # If True, Wandb will do the logging offline # Collector -collector_devices: [cpu] # device on which the data collector should store the trajectories to be passed to this script. -pin_memory: False # if True, the data collector will call pin_memory before dispatching tensordicts onto the passing device -init_with_lag: False # if True, the first trajectory will be truncated earlier at a random step. -frames_per_batch: 256 # Number of steps executed in the environment per collection. +frames_per_batch: 64 # Number of steps executed in the environment per collection. total_frames: 2_000_000 # total number of frames collected for training. Does account for frame_skip. num_workers: 2 # Number of workers used for data collection. env_per_collector: 2 # Number of environments per collector. If the env_per_collector is in the range: -# 1 < env_per_collector <= num_workers, then the collector runs -# ceil(num_workers/env_per_collector) in parallel and executes the policy steps synchronously -# for each of these parallel wrappers. If env_per_collector=num_workers, no parallel wrapper is created -seed: 42 # seed used for the environment, pytorch and numpy. -exploration_mode: "" # exploration mode of the data collector. -async_collection: False # Whether data collection should be done asynchrously. # Model -gSDE: False # if True, exploration is achieved using the gSDE technique. -tanh_loc: False # if True, uses a Tanh-Normal transform for the policy location of the form -# upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions) default_policy_scale: 1.0 # Default policy scale parameter distribution: tanh_normal # if True, uses a Tanh-Normal-Tanh distribution for the policy lstm: False # if True, uses an LSTM for the policy. @@ -55,23 +24,19 @@ shared_mapping: False # if True, the first layers of the actor-critic are shared # Objective gamma: 0.99 -entropy_coef: 0.1 # Entropy factor for the A2C loss -critic_coef: 0.4 # Critic factor for the A2C loss -critic_loss_function: smooth_l1 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). +entropy_coef: 0.01 # Entropy factor for the A2C loss +critic_coef: 0.25 # Critic factor for the A2C loss +critic_loss_function: l2 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). advantage_in_loss: False # if True, the advantage is computed on the sub-batch # Trainer -optim_steps_per_batch: 12 # Number of optimization steps in between two collection of data. -optimizer: "adam" # Optimizer to be used. +optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data. 
+optimizer: adam # Optimizer to be used. lr_scheduler: "" # LR scheduler. -selected_keys: null # a list of strings that indicate the data that should be kept from the data collector. -batch_size: 256 # batch size of the TensorDict retrieved from the replay buffer. Default=256. +batch_size: 64 # batch size of the TensorDict retrieved from the replay buffer. Default=256. log_interval: 1 # logging interval, in terms of optimization steps. Default=10000. -lr: 2e-4 # Learning rate used for the optimizer. Default=3e-4. -weight_decay: 0.0 # Weight-decay to be used with the optimizer. Default=0.0. -clip_norm: 1000 # value at which the total gradient norm / single derivative should be clipped. Default=1000.0 -clip_grad_norm: False # if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold. +lr: 0.0007 # Learning rate used for the optimizer. Default=3e-4. normalize_rewards_online: True # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module. normalize_rewards_online_scale: 1.0 # Final scale of the normalized rewards. normalize_rewards_online_decay: 0.0 # Decay of the reward moving averaging -sub_traj_len: 64 # length of the trajectories that sub-samples must have in online settings. \ No newline at end of file +sub_traj_len: 64 # length of the trajectories that sub-samples must have in online settings. From e08dd5d5a90fe58504be8d03078619ae74a31569 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 15 Nov 2022 16:15:36 +0100 Subject: [PATCH 08/33] simplified config --- examples/a2c/a2c.py | 20 ++++++++------------ examples/a2c/config.yaml | 3 +-- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index 2adc91fdd3b..0a12faa114e 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -142,10 +142,6 @@ def main(cfg: "DictConfig"): # noqa: F821 make_env=create_env_fn, actor_model_explore=actor_model, cfg=cfg, - # make_env_kwargs=[ - # {"device": device} if device >= 0 else {} - # for device in cfg.env_rendering_devices - # ], ) recorder = transformed_env_constructor( @@ -164,14 +160,14 @@ def main(cfg: "DictConfig"): # noqa: F821 t.loc.fill_(0.0) trainer = make_trainer( - collector, - loss_module, - recorder, - None, - actor_model, - None, - logger, - cfg, + collector=collector, + loss_module=loss_module, + recorder=recorder, + target_net_updater=None, + policy_exploration=actor_model, + replay_buffer=None, + logger=logger, + cfg=cfg, ) if not cfg.advantage_in_loss: diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml index 08eeaf23c04..af6c6471b6e 100644 --- a/examples/a2c/config.yaml +++ b/examples/a2c/config.yaml @@ -1,7 +1,6 @@ # Environment env_library: gym # env_library used for the simulated environment. env_name: HalfCheetah-v4 # name of the environment to be created. Default=Humanoid-v2 -env_task: run # task (if any) for the environment. frame_skip: 2 # frame_skip for the environment. # Logger @@ -11,7 +10,7 @@ exp_name: A2C # experiment name. Used for logging directory. record_interval: 100 # number of batch collections in between two collections of validation rollouts. Default=1000. # Collector -frames_per_batch: 64 # Number of steps executed in the environment per collection. +frames_per_batch: 32 # Number of steps executed in the environment per collection. total_frames: 2_000_000 # total number of frames collected for training. Does account for frame_skip. 
num_workers: 2 # Number of workers used for data collection. env_per_collector: 2 # Number of environments per collector. If the env_per_collector is in the range: From 82db357bf17c61efc59567470748df03bc783c6e Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Tue, 15 Nov 2022 17:32:33 +0100 Subject: [PATCH 09/33] Update config.yaml --- examples/a2c/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml index af6c6471b6e..b647dea1632 100644 --- a/examples/a2c/config.yaml +++ b/examples/a2c/config.yaml @@ -10,7 +10,7 @@ exp_name: A2C # experiment name. Used for logging directory. record_interval: 100 # number of batch collections in between two collections of validation rollouts. Default=1000. # Collector -frames_per_batch: 32 # Number of steps executed in the environment per collection. +frames_per_batch: 64 # Number of steps executed in the environment per collection. total_frames: 2_000_000 # total number of frames collected for training. Does account for frame_skip. num_workers: 2 # Number of workers used for data collection. env_per_collector: 2 # Number of environments per collector. If the env_per_collector is in the range: From afca8f4f4b4b05d37420d0d201308b2f2b9c75f8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 15 Nov 2022 13:35:40 +0000 Subject: [PATCH 10/33] [BugFix] Use GitHub for flake8 pre-commit hook (#679) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1afafcca1cc..5f385f466b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: - black == 21.9b0 - usort == 0.6.4 - - repo: https://gitlab.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 3.9.2 hooks: - id: flake8 From ea83339bc73a5316b7754c2cb5ec10d67aac1dcc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 15 Nov 2022 15:33:23 +0000 Subject: [PATCH 11/33] [BugFix] Update to strict select (#675) * init * strict=False * amend * amend --- test/test_rb.py | 44 +++++++++++++++------------ torchrl/collectors/collectors.py | 3 +- torchrl/envs/vec_env.py | 39 ++++++++++++++++++------ torchrl/modules/models/model_based.py | 2 +- 4 files changed, 57 insertions(+), 31 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 7eaab102369..36915407450 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -125,16 +125,18 @@ def test_extend(self, rb_type, sampler, writer, storage, size): found_similar = False for b in rb._storage: if isinstance(b, TensorDictBase): - b = b.exclude("index").select(*set(d.keys()).intersection(b.keys())) - d = d.select(*set(d.keys()).intersection(b.keys())) + keys = set(d.keys()).intersection(b.keys()) + b = b.exclude("index").select(*keys, strict=False) + keys = set(d.keys()).intersection(b.keys()) + d = d.select(*keys, strict=False) value = b == d if isinstance(value, (torch.Tensor, TensorDictBase)): value = value.all() if value: - found_similar = True break - assert found_similar + else: + raise RuntimeError("did not find match") def test_sample(self, rb_type, sampler, writer, storage, size): torch.manual_seed(0) @@ -152,18 +154,18 @@ def test_sample(self, rb_type, sampler, writer, storage, size): for b in data: print(b, d) if isinstance(b, TensorDictBase): - b = b.exclude("index").select(*set(d.keys()).intersection(b.keys())) - d = d.select(*set(d.keys()).intersection(b.keys())) + keys = set(d.keys()).intersection(b.keys()) + b = 
b.exclude("index").select(*keys, strict=False) + keys = set(d.keys()).intersection(b.keys()) + d = d.select(*keys, strict=False) value = b == d if isinstance(value, (torch.Tensor, TensorDictBase)): value = value.all() if value: - found_similar = True break - if not found_similar: - d - assert found_similar, (d, data) + else: + raise RuntimeError("did not find match") def test_index(self, rb_type, sampler, writer, storage, size): torch.manual_seed(0) @@ -394,16 +396,18 @@ def test_extend(self, rbtype, storage, size, prefetch): found_similar = False for b in rb._storage: if isinstance(b, TensorDictBase): - b = b.exclude("index").select(*set(d.keys()).intersection(b.keys())) - d = d.select(*set(d.keys()).intersection(b.keys())) + keys = set(d.keys()).intersection(b.keys()) + b = b.exclude("index").select(*keys, strict=False) + keys = set(d.keys()).intersection(b.keys()) + d = d.select(*keys, strict=False) value = b == d if isinstance(value, (torch.Tensor, TensorDictBase)): value = value.all() if value: - found_similar = True break - assert found_similar + else: + raise RuntimeError("did not find match") def test_sample(self, rbtype, storage, size, prefetch): torch.manual_seed(0) @@ -418,18 +422,18 @@ def test_sample(self, rbtype, storage, size, prefetch): found_similar = False for b in data: if isinstance(b, TensorDictBase): - b = b.exclude("index").select(*set(d.keys()).intersection(b.keys())) - d = d.select(*set(d.keys()).intersection(b.keys())) + keys = set(d.keys()).intersection(b.keys()) + b = b.exclude("index").select(*keys, strict=False) + keys = set(d.keys()).intersection(b.keys()) + d = d.select(*keys, strict=False) value = b == d if isinstance(value, (torch.Tensor, TensorDictBase)): value = value.all() if value: - found_similar = True break - if not found_similar: - d - assert found_similar, (d, data) + else: + raise RuntimeError("did not find matching value") def test_index(self, rbtype, storage, size, prefetch): torch.manual_seed(0) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 0b20b6d3058..325670fbe67 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -536,7 +536,8 @@ def iterator(self) -> Iterator[TensorDictBase]: def _cast_to_policy(self, td: TensorDictBase) -> TensorDictBase: policy_device = self.device if hasattr(self.policy, "in_keys"): - td = td.select(*self.policy.in_keys) + # some keys may be absent -- TensorDictModule is resilient to missing keys + td = td.select(*self.policy.in_keys, strict=False) if self._td_policy is None: self._td_policy = td.to(policy_device) else: diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 2a9b5902f1d..dfec0d28239 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -421,13 +421,14 @@ def _create_td(self) -> None: ) if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( - *self.selected_keys + *self.selected_keys, + strict=False, ) self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) else: shared_tensordict_parent = torch.stack( [ - tensordict.select(*selected_keys).to(self.device) + tensordict.select(*selected_keys, strict=False).to(self.device) for tensordict, selected_keys in zip( shared_tensordict_parent, self.selected_keys ) @@ -573,7 +574,10 @@ def _step( ) -> TensorDict: self._assert_tensordict_shape(tensordict) - tensordict_in = tensordict.select(*self.env_input_keys) + tensordict_in = tensordict.select( + *self.env_input_keys, + strict=False, + ) tensordict_out = [] for i in 
range(self.num_workers): _tensordict_out = self._envs[i].step(tensordict_in[i]) @@ -611,7 +615,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: keys = keys.union(_td.keys()) self.shared_tensordicts[i].update_(_td) - return self.shared_tensordict_parent.select(*keys).clone() + return self.shared_tensordict_parent.select( + *keys, + strict=False, + ).clone() def __getattr__(self, attr: str) -> Any: if attr in self.__dir__(): @@ -740,7 +747,12 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def _step(self, tensordict: TensorDictBase) -> TensorDictBase: self._assert_tensordict_shape(tensordict) - self.shared_tensordict_parent.update_(tensordict.select(*self.env_input_keys)) + self.shared_tensordict_parent.update_( + tensordict.select( + *self.env_input_keys, + strict=False, + ) + ) for i in range(self.num_workers): self.parent_channels[i].send(("step", None)) @@ -756,7 +768,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: keys = keys.union(data) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - return self.shared_tensordict_parent.select(*keys).clone() + return self.shared_tensordict_parent.select( + *keys, + strict=False, + ).clone() @_check_start def _shutdown_workers(self) -> None: @@ -829,7 +844,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # there might be some delay between writing the shared tensordict # and reading the updated value on the main process sleep(0.01) - return self.shared_tensordict_parent.select(*keys).clone() + return self.shared_tensordict_parent.select( + *keys, + strict=False, + ).clone() def __reduce__(self): if not self.is_closed: @@ -979,7 +997,10 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - _td = tensordict.select(*env_input_keys) + _td = tensordict.select( + *env_input_keys, + strict=False, + ) if env.is_done and not allow_step_when_done: raise RuntimeError( f"calling step when env is done, just reset = {just_reset}" @@ -989,7 +1010,7 @@ def _run_worker_pipe_shared_mem( step_keys = set(_td.keys()) - set(env_input_keys) if pin_memory: _td.pin_memory() - tensordict.update_(_td.select(*step_keys)) + tensordict.update_(_td.select(*step_keys, strict=False)) if _td.get("done"): msg = "done" else: diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 3ce1faffbbd..18d12d7fa78 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -200,7 +200,7 @@ def forward(self, tensordict): tensordict_out.append(_tensordict) if t < time_steps - 1: _tensordict = step_mdp( - _tensordict.select(*self.out_keys), keep_other=False + _tensordict.select(*self.out_keys, strict=False), keep_other=False ) _tensordict = update_values[..., t + 1].update(_tensordict) From 0bc21da526a3ee31b4c65b8c844c6406309a35b4 Mon Sep 17 00:00:00 2001 From: Romain Julien Date: Tue, 15 Nov 2022 15:35:06 +0000 Subject: [PATCH 12/33] [Feature] Auto-compute stats for ObservationNorm (#669) * Add auto-compute stats feature for ObservationNorm * Fix issue in ObservNorm init function * Quick refactor of ObservationNorm init method * Minor refactoring and adding more tests for ObservationNorm * lint * docstring * docstring Co-authored-by: vmoens --- test/test_transforms.py | 67 +++++++++++++++++++ torchrl/envs/transforms/transforms.py | 95 ++++++++++++++++++++++++--- 2 files changed, 152 
insertions(+), 10 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8d9c031a8a6..2f6e6de7919 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -905,6 +905,73 @@ def test_observationnorm( assert (observation_spec[key].space.minimum == loc).all() assert (observation_spec[key].space.maximum == scale + loc).all() + @pytest.mark.parametrize( + "keys", [["next_observation"], ["next_observation", "next_pixel"]] + ) + @pytest.mark.parametrize("size", [1, 3]) + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("standard_normal", [True, False]) + def test_observationnorm_init_stats(self, keys, size, device, standard_normal): + base_env = ContinuousActionVecMockEnv( + observation_spec=CompositeSpec( + next_observation=NdBoundedTensorSpec( + minimum=1, maximum=1, shape=torch.Size([size]) + ), + next_observation_orig=NdBoundedTensorSpec( + minimum=1, maximum=1, shape=torch.Size([size]) + ), + ), + action_spec=NdBoundedTensorSpec( + minimum=1, maximum=1, shape=torch.Size((size,)) + ), + seed=0, + ) + base_env.out_key = "observation" + t_env = TransformedEnv( + base_env, + transform=ObservationNorm(in_keys=keys, standard_normal=standard_normal), + ) + if len(keys) > 1: + t_env.transform.init_stats(num_iter=11, key="next_observation") + else: + t_env.transform.init_stats(num_iter=11) + + if standard_normal: + torch.testing.assert_close(t_env.transform.loc, torch.Tensor([1.06] * size)) + torch.testing.assert_close( + t_env.transform.scale, torch.Tensor([0.03316621] * size) + ) + else: + torch.testing.assert_close( + t_env.transform.loc, torch.Tensor([31.960236] * size) + ) + torch.testing.assert_close( + t_env.transform.scale, torch.Tensor([30.151169] * size) + ) + + def test_observationnorm_stats_already_initialized_error(self): + transform = ObservationNorm(in_keys="next_observation", loc=0, scale=1) + + with pytest.raises(RuntimeError, match="Loc/Scale are already initialized"): + transform.init_stats(num_iter=11) + + def test_observationnorm_init_stats_multiple_keys_error(self): + transform = ObservationNorm(in_keys=["next_observation", "next_pixels"]) + + err_msg = "Transform has multiple in_keys but no specific key was passed as an argument" + with pytest.raises(RuntimeError, match=err_msg): + transform.init_stats(num_iter=11) + + def test_observationnorm_uninitialized_stats_error(self): + transform = ObservationNorm(in_keys=["next_observation", "next_pixels"]) + + err_msg = ( + "Loc/Scale have not been initialized. 
Either pass in values in the constructor " + "or call the init_stats method" + ) + with pytest.raises(RuntimeError, match=err_msg): + transform._apply_transform(torch.Tensor([1])) + def test_catframes_transform_observation_spec(self): N = 4 key1 = "first key" diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0e23ef58e49..b1d0f66cd78 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -9,7 +9,7 @@ import multiprocessing as mp from copy import deepcopy, copy from textwrap import indent -from typing import Any, List, Optional, OrderedDict, Sequence, Union +from typing import Any, List, Optional, OrderedDict, Sequence, Union, Tuple from warnings import warn import torch @@ -96,6 +96,9 @@ def __init__( out_keys_inv: Optional[Sequence[str]] = None, ): super().__init__() + if isinstance(in_keys, str): + in_keys = [in_keys] + self.in_keys = in_keys if out_keys is None: out_keys = copy(self.in_keys) @@ -1255,10 +1258,21 @@ class ObservationNorm(ObservationTransform): >>> _ = transform(td) >>> print(torch.isclose(td.get('next_obs').mean(0), ... torch.zeros(3)).all()) - Tensor(True) + tensor(True) >>> print(torch.isclose(td.get('next_obs').std(0), ... torch.ones(3)).all()) - Tensor(True) + tensor(True) + + The normalisation stats can be automatically computed: + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> torch.manual_seed(0) + >>> env = GymEnv("Pendulum-v1") + >>> env = TransformedEnv(env, ObservationNorm(in_keys=["observation"])) + >>> env.set_seed(0) + >>> env.transform.init_stats(100) + >>> print(env.transform.loc, env.transform.scale) + tensor([-1.3752e+01, -6.5087e-03, 2.9294e-03], dtype=torch.float32) tensor([14.9636, 2.5608, 0.6408], dtype=torch.float32) """ @@ -1266,8 +1280,8 @@ class ObservationNorm(ObservationTransform): def __init__( self, - loc: Union[float, torch.Tensor], - scale: Union[float, torch.Tensor], + loc: Optional[float, torch.Tensor] = None, + scale: Optional[float, torch.Tensor] = None, in_keys: Optional[Sequence[str]] = None, # observation_spec_key: =None, standard_normal: bool = False, @@ -1279,18 +1293,79 @@ def __init__( "next_observation_state", ] super().__init__(in_keys=in_keys) - if not isinstance(loc, torch.Tensor): + self.standard_normal = standard_normal + self.eps = 1e-6 + + if loc is not None and not isinstance(loc, torch.Tensor): loc = torch.tensor(loc, dtype=torch.float) - if not isinstance(scale, torch.Tensor): + + if scale is not None and not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) + scale.clamp_min(self.eps) # self.observation_spec_key = observation_spec_key - self.standard_normal = standard_normal self.register_buffer("loc", loc) - eps = 1e-6 - self.register_buffer("scale", scale.clamp_min(eps)) + self.register_buffer("scale", scale) + + def init_stats( + self, + num_iter: int, + reduce_dim: Union[int, Tuple[int]] = 0, + key: Optional[str] = None, + ) -> None: + """Initializes the loc and scale stats of the parent environment. + + Normalization constant should ideally make the observation statistics approach + those of a standard Gaussian distribution. This method computes a location + and scale tensor that will empirically compute the mean and standard + deviation of a Gaussian distribution fitted on data generated randomly with + the parent environment for a given number of steps. + + Args: + num_iter (int): number of random iterations to run in the environment. 
+ reduce_dim (int, optional): dimension to compute the mean and std over. + Defaults to 0. + key (str, optional): if provided, the summary statistics will be + retrieved from that key in the resulting tensordicts. + Otherwise, the first key in :obj:`ObservationNorm.in_keys` will be used. + + """ + if self.loc is not None or self.scale is not None: + raise RuntimeError( + f"Loc/Scale are already initialized: ({self.loc}, {self.scale})" + ) + + if len(self.in_keys) > 1 and key is None: + raise RuntimeError( + "Transform has multiple in_keys but no specific key was passed as an argument" + ) + key = self.in_keys[0] if key is None else key + + parent = self.parent + collected_frames = 0 + data = [] + while collected_frames < num_iter: + tensordict = parent.rollout(max_steps=num_iter) + collected_frames += tensordict.numel() + data.append(tensordict.get(key)) + + data = torch.cat(data, reduce_dim) + loc = data.mean(reduce_dim) + scale = data.std(reduce_dim) + + if not self.standard_normal: + loc = loc / scale + scale = 1 / scale + + self.register_buffer("loc", loc) + self.register_buffer("scale", scale.clamp_min(self.eps)) def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: + if self.loc is None or self.scale is None: + raise RuntimeError( + "Loc/Scale have not been initialized. Either pass in values in the constructor " + "or call the init_stats method" + ) if self.standard_normal: loc = self.loc scale = self.scale From 8765ac9b74c471517eecc93fb5cb6fb5ce0ee7c9 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Tue, 15 Nov 2022 16:57:55 +0100 Subject: [PATCH 13/33] [Doc] _make_collector helper function (#678) --- torchrl/trainers/helpers/collectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index ca4fea8804e..f3e8c7a1ff1 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -175,7 +175,7 @@ def _make_collector( num_env_per_collector: Optional[int] = None, num_collectors: Optional[int] = None, **kwargs, -) -> _MultiDataCollector: +) -> _DataCollector: if env_kwargs is None: env_kwargs = dict() if isinstance(env_fns, list): From 14fbac900d301b45d8527e5f521ab867ac3d3b4f Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Tue, 15 Nov 2022 16:58:18 +0100 Subject: [PATCH 14/33] [Doc] BatchSubSampler class docstrings example (#677) --- torchrl/trainers/trainers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index cf32ff66f16..7c647a35bdd 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -889,7 +889,7 @@ class BatchSubSampler(TrainerHookBase): ... key1: torch.stack([torch.arange(0, 10), torch.arange(10, 20)], 0), ... key2: torch.stack([torch.arange(0, 10), torch.arange(10, 20)], 0), ... }, - ... [13, 10], + ... [2, 10], ... ) >>> trainer.register_op( ... 
"process_optim_batch", From bcdb0bcefcb81c95d51628f1cdace130ae0a3258 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Tue, 15 Nov 2022 16:58:45 +0100 Subject: [PATCH 15/33] [BugFix] PPO objective crashes if advantage_module is None (#676) --- torchrl/objectives/ppo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index a956c4c1ebc..510a4ca2a79 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -80,7 +80,9 @@ def __init__( ) self.register_buffer("gamma", torch.tensor(gamma, device=self.device)) self.loss_critic_type = loss_critic_type - self.advantage_module = advantage_module.to(self.device) + self.advantage_module = advantage_module + if self.advantage_module is not None: + self.advantage_module = advantage_module.to(self.device) def reset(self) -> None: pass From fbb0e9f53dbf4b46cf8b9b8d34c717801b0ad99b Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 15 Nov 2022 16:00:35 +0000 Subject: [PATCH 16/33] Minor: lint --- torchrl/trainers/helpers/collectors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index f3e8c7a1ff1..1d72aab2643 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -10,7 +10,6 @@ from torchrl.collectors.collectors import ( _DataCollector, - _MultiDataCollector, SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector, From 26dcbcf139802731074280e36be03cfaa41ed2fb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Nov 2022 11:20:48 +0000 Subject: [PATCH 17/33] [Refactor] Refactor 'next_' into nested tensordicts (#649) * init * [Feature] Nested composite spec (#654) * [Feature] Move `transform.forward` to `transform.step` (#660) * transform step function * amend * amend * amend * amend * amend * fixing key names * fixing key names * [Refactor] Transform next remove (#661) * Refactor "next_" into ("next", ) (#673) * amend * amend * bugfix * init * strict=False * strict=False * minor * amend * [BugFix] Use GitHub for flake8 pre-commit hook (#679) * amend * [BugFix] Update to strict select (#675) * init * strict=False * amend * amend * [Feature] Auto-compute stats for ObservationNorm (#669) * Add auto-compute stats feature for ObservationNorm * Fix issue in ObservNorm init function * Quick refactor of ObservationNorm init method * Minor refactoring and adding more tests for ObservationNorm * lint * docstring * docstring Co-authored-by: vmoens * amend * amend * lint * bf * bf * amend Co-authored-by: Romain Julien Co-authored-by: Romain Julien --- examples/dreamer/dreamer.py | 4 +- examples/dreamer/dreamer_utils.py | 12 +- test/_utils_internal.py | 32 +- test/mocking_classes.py | 83 ++--- test/test_collector.py | 33 +- test/test_cost.py | 84 +++-- test/test_env.py | 61 ++-- test/test_helpers.py | 34 +- test/test_libs.py | 5 +- test/test_modules.py | 63 ++-- test/test_postprocs.py | 6 +- test/test_tensor_spec.py | 107 +++++++ test/test_transforms.py | 297 +++++++++--------- torchrl/collectors/collectors.py | 39 +-- torchrl/collectors/utils.py | 6 +- torchrl/data/postprocs/postprocs.py | 33 +- torchrl/data/tensor_specs.py | 139 +++++++- torchrl/envs/common.py | 27 +- torchrl/envs/env_creator.py | 38 +-- torchrl/envs/gym_like.py | 7 +- torchrl/envs/libs/dm_control.py | 2 +- torchrl/envs/libs/gym.py | 8 +- torchrl/envs/model_based/common.py | 11 +- torchrl/envs/model_based/dreamer.py | 5 +- torchrl/envs/transforms/r3m.py | 13 +- 
torchrl/envs/transforms/transforms.py | 232 +++++++------- torchrl/envs/transforms/utils.py | 38 +-- torchrl/envs/transforms/vip.py | 23 +- torchrl/envs/utils.py | 23 +- torchrl/envs/vec_env.py | 14 +- torchrl/modules/models/model_based.py | 4 +- torchrl/modules/planners/cem.py | 2 +- torchrl/modules/tensordict_module/common.py | 15 +- torchrl/modules/tensordict_module/sequence.py | 2 +- torchrl/objectives/deprecated.py | 7 +- torchrl/objectives/dreamer.py | 13 +- torchrl/objectives/redq.py | 5 +- torchrl/objectives/value/advantages.py | 9 - torchrl/record/recorder.py | 4 +- torchrl/trainers/helpers/envs.py | 13 +- torchrl/trainers/helpers/models.py | 60 ++-- 41 files changed, 893 insertions(+), 720 deletions(-) diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index b6216f4d757..4e903731219 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -288,9 +288,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ): sampled_tensordict_save = ( sampled_tensordict.select( - "next_pixels", - "next_reco_pixels", - "state", + "next" "state", "belief", )[:4] .detach() diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 3cdc7c13d30..0c69edc5c60 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -93,13 +93,13 @@ def make_env_transforms( if cfg.grayscale: env.append_transform(GrayScale()) env.append_transform(FlattenObservation()) - env.append_transform(CatFrames(N=cfg.catframes, in_keys=["next_pixels"])) + env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"])) if stats is None: obs_stats = {"loc": 0.0, "scale": 1.0} else: obs_stats = stats obs_stats["standard_normal"] = True - env.append_transform(ObservationNorm(**obs_stats, in_keys=["next_pixels"])) + env.append_transform(ObservationNorm(**obs_stats, in_keys=["pixels"])) if norm_rewards: reward_scaling = 1.0 reward_loc = 0.0 @@ -122,8 +122,8 @@ def make_env_transforms( ) default_dict = { - "next_state": NdUnboundedContinuousTensorSpec(cfg.state_dim), - "next_belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim), + "state": NdUnboundedContinuousTensorSpec(cfg.state_dim), + "belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim), } env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -309,7 +309,7 @@ def call_record( true_pixels = recover_pixels(world_model_td["next_pixels"], stats) - reco_pixels = recover_pixels(world_model_td["next_reco_pixels"], stats) + reco_pixels = recover_pixels(world_model_td["next", "reco_pixels"], stats) with autocast(dtype=torch.float16): world_model_td = world_model_td.select("state", "belief", "reward") world_model_td = model_based_env.rollout( @@ -319,7 +319,7 @@ def call_record( tensordict=world_model_td[:, 0], ) imagine_pxls = recover_pixels( - model_based_env.decode_obs(world_model_td)["next_reco_pixels"], + model_based_env.decode_obs(world_model_td)["next", "reco_pixels"], stats, ) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 0ecb446b918..421c61b08a4 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -13,7 +13,6 @@ import torch.cuda from tensordict.tensordict import TensorDictBase from torchrl._utils import seed_generator -from torchrl.data import CompositeSpec from torchrl.envs import EnvBase @@ -62,21 +61,20 @@ def _test_fake_tensordict(env: EnvBase): def _check_dtype(key, value, obs_spec, input_spec): - if key.startswith("next_"): - return - if isinstance(value, TensorDictBase): + if 
isinstance(value, TensorDictBase) and key == "next": for _key, _value in value.items(): - if isinstance(obs_spec, CompositeSpec) and "next_" + key in obs_spec.keys(): - _check_dtype(_key, _value, obs_spec["next_" + key], input_spec=None) - elif isinstance(input_spec, CompositeSpec) and key in input_spec.keys(): - _check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key]) - else: - raise KeyError(f"key '{_key}' is unknown.") + _check_dtype(_key, _value, obs_spec, input_spec=None) + elif isinstance(value, TensorDictBase) and key in obs_spec.keys(): + for _key, _value in value.items(): + _check_dtype(_key, _value, obs_spec=obs_spec[key], input_spec=None) + elif isinstance(value, TensorDictBase) and key in input_spec.keys(): + for _key, _value in value.items(): + _check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key]) else: - if obs_spec is not None and "next_" + key in obs_spec.keys(): + if obs_spec is not None and key in obs_spec.keys(): assert ( - obs_spec["next_" + key].dtype is value.dtype - ), f"{obs_spec['next_' + key].dtype} vs {value.dtype} for {key}" + obs_spec[key].dtype is value.dtype + ), f"{obs_spec[key].dtype} vs {value.dtype} for {key}" elif input_spec is not None and key in input_spec.keys(): assert ( input_spec[key].dtype is value.dtype @@ -112,3 +110,11 @@ def f_retry(*args, **kwargs): return f_retry # true decorator return deco_retry + + +@pytest.fixture +def dtype_fixture(): + dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.double) + yield dtype + torch.set_default_dtype(dtype) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index bffd83f1ec4..e03b857c9de 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -121,7 +121,7 @@ def __new__( action_spec = NdUnboundedContinuousTensorSpec((1,)) if observation_spec is None: observation_spec = CompositeSpec( - next_observation=NdUnboundedContinuousTensorSpec((1,)) + observation=NdUnboundedContinuousTensorSpec((1,)) ) if reward_spec is None: reward_spec = NdUnboundedContinuousTensorSpec((1,)) @@ -152,11 +152,9 @@ def _step(self, tensordict): ) done = self.counter >= self.max_val done = torch.tensor([done], dtype=torch.bool, device=self.device) - return TensorDict( - {"reward": n, "done": done, "next_observation": n.clone()}, [] - ) + return TensorDict({"reward": n, "done": done, "observation": n.clone()}, []) - def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: self.max_val = max(self.counter + 100, self.counter * 2) n = torch.tensor( @@ -164,7 +162,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: ) done = self.counter >= self.max_val done = torch.tensor([done], dtype=torch.bool, device=self.device) - return TensorDict({"done": done, "next_observation": n}, []) + return TensorDict({"done": done, "observation": n}, []) def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: return self.step(tensordict) @@ -192,7 +190,7 @@ def __new__( ) if observation_spec is None: observation_spec = CompositeSpec( - next_observation=NdUnboundedContinuousTensorSpec((1,)) + observation=NdUnboundedContinuousTensorSpec((1,)) ) if reward_spec is None: reward_spec = NdUnboundedContinuousTensorSpec((1,)) @@ -226,7 +224,7 @@ def _step(self, tensordict): ) return TensorDict( - {"reward": n, "done": done, "next_observation": n}, + {"reward": n, "done": done, "observation": n}, tensordict.batch_size, device=self.device, ) @@ 
-247,7 +245,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: done = torch.full(batch_size, done, dtype=torch.bool, device=self.device) return TensorDict( - {"reward": n, "done": done, "next_observation": n}, + {"reward": n, "done": done, "observation": n}, batch_size, device=self.device, ) @@ -287,10 +285,8 @@ def __new__( if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( - next_observation=NdUnboundedContinuousTensorSpec( - shape=torch.Size([size]) - ), - next_observation_orig=NdUnboundedContinuousTensorSpec( + observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])), + observation_orig=NdUnboundedContinuousTensorSpec( shape=torch.Size([size]) ), ) @@ -308,7 +304,7 @@ def __new__( cls._out_key = "observation_orig" input_spec = CompositeSpec( **{ - cls._out_key: observation_spec["next_observation"], + cls._out_key: observation_spec["observation"], "action": action_spec, } ) @@ -325,15 +321,13 @@ def _get_in_obs(self, obs): def _get_out_obs(self, obs): return obs - def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: + def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase: self.counter += 1 state = torch.zeros(self.size) + self.counter if tensordict is None: tensordict = TensorDict({}, self.batch_size, device=self.device) - tensordict = tensordict.select().set( - "next_" + self.out_key, self._get_out_obs(state) - ) - tensordict = tensordict.set("next_" + self._out_key, self._get_out_obs(state)) + tensordict = tensordict.select().set(self.out_key, self._get_out_obs(state)) + tensordict = tensordict.set(self._out_key, self._get_out_obs(state)) tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)) return tensordict @@ -351,8 +345,8 @@ def _step( obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep tensordict = tensordict.select() # empty tensordict - tensordict.set("next_" + self.out_key, self._get_out_obs(obs)) - tensordict.set("next_" + self._out_key, self._get_out_obs(obs)) + tensordict.set(self.out_key, self._get_out_obs(obs)) + tensordict.set(self._out_key, self._get_out_obs(obs)) done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1)) reward = done.any(-1).unsqueeze(-1) @@ -379,10 +373,8 @@ def __new__( if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( - next_observation=NdUnboundedContinuousTensorSpec( - shape=torch.Size([size]) - ), - next_observation_orig=NdUnboundedContinuousTensorSpec( + observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])), + observation_orig=NdUnboundedContinuousTensorSpec( shape=torch.Size([size]) ), ) @@ -395,7 +387,7 @@ def __new__( cls._out_key = "observation_orig" input_spec = CompositeSpec( **{ - cls._out_key: observation_spec["next_observation"], + cls._out_key: observation_spec["observation"], "action": action_spec, } ) @@ -436,8 +428,8 @@ def _step( obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a) tensordict = tensordict.select() # empty tensordict - tensordict.set("next_" + self.out_key, self._get_out_obs(obs)) - tensordict.set("next_" + self._out_key, self._get_out_obs(obs)) + tensordict.set(self.out_key, self._get_out_obs(obs)) + tensordict.set(self._out_key, self._get_out_obs(obs)) done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1)) reward = done.any(-1).unsqueeze(-1) @@ -483,10 +475,8 @@ def __new__( if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - 
next_pixels=NdUnboundedContinuousTensorSpec( - shape=torch.Size([1, 7, 7]) - ), - next_pixels_orig=NdUnboundedContinuousTensorSpec( + pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])), + pixels_orig=NdUnboundedContinuousTensorSpec( shape=torch.Size([1, 7, 7]) ), ) @@ -499,7 +489,7 @@ def __new__( cls._out_key = "pixels_orig" input_spec = CompositeSpec( **{ - cls._out_key: observation_spec["next_pixels_orig"], + cls._out_key: observation_spec["pixels_orig"], "action": action_spec, } ) @@ -537,10 +527,8 @@ def __new__( if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - next_pixels=NdUnboundedContinuousTensorSpec( - shape=torch.Size([7, 7, 3]) - ), - next_pixels_orig=NdUnboundedContinuousTensorSpec( + pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), + pixels_orig=NdUnboundedContinuousTensorSpec( shape=torch.Size([7, 7, 3]) ), ) @@ -555,7 +543,7 @@ def __new__( cls._out_key = "pixels_orig" input_spec = CompositeSpec( **{ - cls._out_key: observation_spec["next_pixels_orig"], + cls._out_key: observation_spec["pixels_orig"], "action": action_spec, } ) @@ -599,10 +587,8 @@ def __new__( if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - next_pixels=NdUnboundedContinuousTensorSpec( - shape=torch.Size(pixel_shape) - ), - next_pixels_orig=NdUnboundedContinuousTensorSpec( + pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size(pixel_shape)), + pixels_orig=NdUnboundedContinuousTensorSpec( shape=torch.Size(pixel_shape) ), ) @@ -615,7 +601,7 @@ def __new__( if input_spec is None: cls._out_key = "pixels_orig" input_spec = CompositeSpec( - **{cls._out_key: observation_spec["next_pixels"], "action": action_spec} + **{cls._out_key: observation_spec["pixels"], "action": action_spec} ) return super().__new__( *args, @@ -650,10 +636,8 @@ def __new__( if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - next_pixels=NdUnboundedContinuousTensorSpec( - shape=torch.Size([7, 7, 3]) - ), - next_pixels_orig=NdUnboundedContinuousTensorSpec( + pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), + pixels_orig=NdUnboundedContinuousTensorSpec( shape=torch.Size([7, 7, 3]) ), ) @@ -714,7 +698,7 @@ def __init__( batch_size=batch_size, ) self.observation_spec = CompositeSpec( - next_hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + hidden_observation=NdUnboundedContinuousTensorSpec((4,)) ) self.input_spec = CompositeSpec( hidden_observation=NdUnboundedContinuousTensorSpec((4,)), @@ -728,9 +712,6 @@ def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: "hidden_observation": self.input_spec["hidden_observation"].rand( self.batch_size ), - "next_hidden_observation": self.observation_spec[ - "next_hidden_observation" - ].rand(self.batch_size), }, batch_size=self.batch_size, device=self.device, diff --git a/test/test_collector.py b/test/test_collector.py index 56e26d203c7..a0b122fd563 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -593,9 +593,8 @@ def env_fn(seed): @pytest.mark.parametrize("collector_class", [SyncDataCollector, aSyncDataCollector]) @pytest.mark.parametrize("env_name", ["conv", "vec"]) def test_traj_len_consistency(num_env, env_name, collector_class, seed=100): - """ - Tests that various frames_per_batch lead to the same results - """ + """Tests that various frames_per_batch lead to the same results.""" + if num_env == 1: def env_fn(seed): @@ -837,7 +836,7 @@ def make_env(): return ContinuousActionVecMockEnv() 
dummy_env = make_env() - obs_spec = dummy_env.observation_spec["next_observation"] + obs_spec = dummy_env.observation_spec["observation"] policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) policy = Actor(policy_module, spec=dummy_env.action_spec) policy_explore = OrnsteinUhlenbeckProcessWrapper(policy) @@ -869,9 +868,9 @@ def make_env(): @pytest.mark.parametrize( "collector_class", [ - SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector, + SyncDataCollector, ], ) @pytest.mark.parametrize("init_random_frames", [0, 50]) @@ -894,7 +893,13 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe policy_kwargs = { "module": net, "in_keys": ["observation", "hidden1", "hidden2"], - "out_keys": ["action", "hidden1", "hidden2", "next_hidden1", "next_hidden2"], + "out_keys": [ + "action", + "hidden1", + "hidden2", + ("next", "hidden1"), + ("next", "hidden2"), + ], } if explicit_spec: hidden_spec = NdUnboundedContinuousTensorSpec((1, hidden_size)) @@ -902,8 +907,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe action=UnboundedContinuousTensorSpec(), hidden1=hidden_spec, hidden2=hidden_spec, - next_hidden1=hidden_spec, - next_hidden2=hidden_spec, + next=CompositeSpec(hidden1=hidden_spec, hidden2=hidden_spec), ) policy = TensorDictModule(**policy_kwargs) @@ -927,23 +931,24 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe collector = collector_class(**collector_kwargs) - keys = [ + keys = { "action", "done", "hidden1", "hidden2", "mask", - "next_hidden1", - "next_hidden2", - "next_observation", + ("next", "hidden1"), + ("next", "hidden2"), + ("next", "observation"), + "next", "observation", "reward", "step_count", "traj_ids", - ] + } b = next(iter(collector)) - assert set(b.keys()) == set(keys) + assert set(b.keys(True)) == keys collector.shutdown() del collector diff --git a/test/test_cost.py b/test/test_cost.py index 887f14cb1b6..af513a2d55e 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -21,7 +21,7 @@ import numpy as np import pytest import torch -from _utils_internal import get_available_devices +from _utils_internal import get_available_devices, dtype_fixture # noqa from mocking_classes import ContinuousActionConvMockEnv # from torchrl.data.postprocs.utils import expand_as_right @@ -94,14 +94,6 @@ from torchrl.objectives.value.utils import _custom_conv1d, _make_gammas_tensor -@pytest.fixture -def dtype_fixture(): - dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.DoubleTensor) - yield dtype - torch.set_default_dtype(dtype) - - class _check_td_steady: def __init__(self, td): self.td_clone = td.clone() @@ -223,7 +215,7 @@ def _create_mock_data_dqn( batch_size=(batch,), source={ "observation": obs, - "next_observation": next_obs, + "next": {"observation": next_obs}, "done": done, "reward": reward, "action": action, @@ -268,7 +260,7 @@ def _create_seq_mock_data_dqn( batch_size=(batch, T), source={ "observation": obs * mask.to(obs.dtype), - "next_observation": next_obs * mask.to(obs.dtype), + "next": {"observation": next_obs * mask.to(obs.dtype)}, "done": done, "mask": mask, "reward": reward * mask.to(obs.dtype), @@ -621,7 +613,7 @@ def _create_mock_data_ddpg( batch_size=(batch,), source={ "observation": obs, - "next_observation": next_obs, + "next": {"observation": next_obs}, "done": done, "reward": reward, "action": action, @@ -650,7 +642,7 @@ def _create_seq_mock_data_ddpg( batch_size=(batch, T), source={ "observation": obs * 
mask.to(obs.dtype), - "next_observation": next_obs * mask.to(obs.dtype), + "next": {"observation": next_obs * mask.to(obs.dtype)}, "done": done, "mask": mask, "reward": reward * mask.to(obs.dtype), @@ -843,7 +835,7 @@ def _create_mock_data_sac( batch_size=(batch,), source={ "observation": obs, - "next_observation": next_obs, + "next": {"observation": next_obs}, "done": done, "reward": reward, "action": action, @@ -872,7 +864,7 @@ def _create_seq_mock_data_sac( batch_size=(batch, T), source={ "observation": obs * mask.to(obs.dtype), - "next_observation": next_obs * mask.to(obs.dtype), + "next": {"observation": next_obs * mask.to(obs.dtype)}, "done": done, "mask": mask, "reward": reward * mask.to(obs.dtype), @@ -1192,7 +1184,7 @@ def _create_mock_data_redq( batch_size=(batch,), source={ "observation": obs, - "next_observation": next_obs, + "next": {"observation": next_obs}, "done": done, "reward": reward, "action": action, @@ -1221,7 +1213,7 @@ def _create_seq_mock_data_redq( batch_size=(batch, T), source={ "observation": obs * mask.to(obs.dtype), - "next_observation": next_obs * mask.to(obs.dtype), + "next": {"observation": next_obs * mask.to(obs.dtype)}, "done": done, "mask": mask, "reward": reward * mask.to(obs.dtype), @@ -1578,7 +1570,7 @@ def _create_mock_data_ppo( batch_size=(batch,), source={ "observation": obs, - "next_observation": next_obs, + "next": {"observation": next_obs}, "done": done, "reward": reward, "action": action, @@ -1610,7 +1602,7 @@ def _create_seq_mock_data_ppo( batch_size=(batch, T), source={ "observation": obs * mask.to(obs.dtype), - "next_observation": next_obs * mask.to(obs.dtype), + "next": {"observation": next_obs * mask.to(obs.dtype)}, "done": done, "mask": mask, "reward": reward * mask.to(obs.dtype), @@ -1822,7 +1814,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): { "reward": torch.randn(batch, 1), "observation": torch.randn(batch, n_obs), - "next_observation": torch.randn(batch, n_obs), + "next": {"observation": torch.randn(batch, n_obs)}, "done": torch.zeros(batch, 1, dtype=torch.bool), "action": torch.randn(batch, n_act), }, @@ -1866,7 +1858,7 @@ def _create_world_model_data( "state": torch.zeros(batch_size, temporal_length, state_dim), "belief": torch.zeros(batch_size, temporal_length, rssm_hidden_dim), "pixels": torch.randn(batch_size, temporal_length, 3, 64, 64), - "next_pixels": torch.randn(batch_size, temporal_length, 3, 64, 64), + "next": {"pixels": torch.randn(batch_size, temporal_length, 3, 64, 64)}, "action": torch.randn(batch_size, temporal_length, 64), "reward": torch.randn(batch_size, temporal_length, 1), "done": torch.zeros(batch_size, temporal_length, dtype=torch.bool), @@ -1904,8 +1896,8 @@ def _create_value_data( def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64])) default_dict = { - "next_state": NdUnboundedContinuousTensorSpec(state_dim), - "next_belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": NdUnboundedContinuousTensorSpec(state_dim), + "belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -1928,19 +1920,19 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ - "next_prior_mean", - "next_prior_std", + ("next", "prior_mean"), + ("next", "prior_std"), "_", - "next_belief", + ("next", 
"belief"), ], ), TensorDictModule( rssm_posterior, - in_keys=["next_belief", "next_encoded_latents"], + in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ - "next_posterior_mean", - "next_posterior_std", - "next_state", + ("next", "posterior_mean"), + ("next", "posterior_std"), + ("next", "state"), ], ), ) @@ -1951,19 +1943,19 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 world_modeler = TensorDictSequential( TensorDictModule( obs_encoder, - in_keys=["next_pixels"], - out_keys=["next_encoded_latents"], + in_keys=[("next", "pixels")], + out_keys=[("next", "encoded_latents")], ), rssm_rollout, TensorDictModule( obs_decoder, - in_keys=["next_state", "next_belief"], - out_keys=["next_reco_pixels"], + in_keys=[("next", "state"), ("next", "belief")], + out_keys=[("next", "reco_pixels")], ), ) reward_module = TensorDictModule( reward_module, - in_keys=["next_state", "next_belief"], + in_keys=[("next", "state"), ("next", "belief")], out_keys=["reward"], ) world_model = WorldModelWrapper(world_modeler, reward_module) @@ -1979,8 +1971,8 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64])) default_dict = { - "next_state": NdUnboundedContinuousTensorSpec(state_dim), - "next_belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": NdUnboundedContinuousTensorSpec(state_dim), + "belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -2002,14 +1994,14 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): out_keys=[ "_", "_", - "next_state", - "next_belief", + "state", + "belief", ], ), ) reward_model = TensorDictModule( reward_module, - in_keys=["next_state", "next_belief"], + in_keys=["state", "belief"], out_keys=["reward"], ) model_based_env = DreamerEnv( @@ -2028,8 +2020,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64])) default_dict = { - "next_state": NdUnboundedContinuousTensorSpec(state_dim), - "next_belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": NdUnboundedContinuousTensorSpec(state_dim), + "belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -2555,12 +2547,11 @@ def test_tdlambda_tensor_gamma(device, gamma, lmbda, N, T): @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("N", [(3,), (7, 3)]) @pytest.mark.parametrize("T", [3, 5, 50]) -def test_vectdlambda_tensor_gamma(device, gamma, lmbda, N, T): +def test_vectdlambda_tensor_gamma(device, gamma, lmbda, N, T, dtype_fixture): # noqa """Tests td_lambda_advantage_estimate against vec_td_lambda_advantage_estimate with gamma being a tensor or a scalar """ - _ = dtype_fixture torch.manual_seed(0) @@ -2599,13 +2590,14 @@ def test_vectdlambda_tensor_gamma(device, gamma, lmbda, N, T): @pytest.mark.parametrize("N", [(3,), (7, 3)]) @pytest.mark.parametrize("T", [50, 3]) @pytest.mark.parametrize("rolling_gamma", [True, False, None]) -def test_vectdlambda_rand_gamma(device, lmbda, N, T, rolling_gamma): +def test_vectdlambda_rand_gamma( + device, lmbda, N, T, 
rolling_gamma, dtype_fixture # noqa +): """Tests td_lambda_advantage_estimate against vec_td_lambda_advantage_estimate with gamma being a random tensor """ torch.manual_seed(0) - _ = dtype_fixture done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) reward = torch.randn(*N, T, 1, device=device) diff --git a/test/test_env.py b/test/test_env.py index c20d2df1d86..aefeb4f36de 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -205,7 +205,7 @@ def test_rollout_predictability(device): ).all() assert ( torch.arange(first + 1, first + 101, device=device) - == td_out.get("next_observation").squeeze() + == td_out.get(("next", "observation")).squeeze() ).all() assert ( torch.arange(first + 1, first + 101, device=device) @@ -248,9 +248,7 @@ def create_env_fn(): return TransformedEnv( GymEnv(env_name, frame_skip=frame_skip, device=device), Compose( - ObservationNorm( - in_keys=["next_observation"], loc=0.5, scale=1.1 - ), + ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ), ) @@ -269,9 +267,7 @@ def t_out(): return ( Compose(*[ToTensorImage(), RewardClipping(0, 0.1)]) if not transformed_in - else Compose( - *[ObservationNorm(in_keys=["next_pixels"], loc=0, scale=1)] - ) + else Compose(*[ObservationNorm(in_keys=["pixels"], loc=0, scale=1)]) ) env0 = TransformedEnv( @@ -291,16 +287,12 @@ def t_out(): def t_out(): return ( Compose( - ObservationNorm( - in_keys=["next_observation"], loc=0.5, scale=1.1 - ), + ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ) if not transformed_in else Compose( - ObservationNorm( - in_keys=["next_observation"], loc=1.0, scale=1.0 - ) + ObservationNorm(in_keys=["observation"], loc=1.0, scale=1.0) ) ) @@ -329,7 +321,7 @@ def test_mb_rollout(self, device, seed=0): TensorDictModule( ActionObsMergeLinear(5, 4), in_keys=["hidden_observation", "action"], - out_keys=["next_hidden_observation"], + out_keys=["hidden_observation"], ), TensorDictModule( nn.Linear(4, 1), @@ -341,10 +333,11 @@ def test_mb_rollout(self, device, seed=0): world_model, device=device, batch_size=torch.Size([10]) ) rollout = mb_env.rollout(max_steps=100) - assert set(rollout.keys()) == set(mb_env.observation_spec.keys()).union( - set(mb_env.input_spec.keys()) - ).union({"reward", "done"}) - assert rollout["next_hidden_observation"].shape == (10, 100, 4) + expected_keys = {("next", key) for key in mb_env.observation_spec.keys()} + expected_keys = expected_keys.union(set(mb_env.input_spec.keys())) + expected_keys = expected_keys.union({"reward", "done", "next"}) + assert set(rollout.keys(True)) == expected_keys + assert rollout[("next", "hidden_observation")].shape == (10, 100, 4) @pytest.mark.parametrize("device", get_available_devices()) def test_mb_env_batch_lock(self, device, seed=0): @@ -354,7 +347,7 @@ def test_mb_env_batch_lock(self, device, seed=0): TensorDictModule( ActionObsMergeLinear(5, 4), in_keys=["hidden_observation", "action"], - out_keys=["next_hidden_observation"], + out_keys=["hidden_observation"], ), TensorDictModule( nn.Linear(4, 1), @@ -450,10 +443,10 @@ def env1_maker(): return TransformedEnv( DMControlEnv("humanoid", "stand"), Compose( - CatTensors(env1_obs_keys, "next_observation_stand", del_keys=False), - CatTensors(env1_obs_keys, "next_observation"), + CatTensors(env1_obs_keys, "observation_stand", del_keys=False), + CatTensors(env1_obs_keys, "observation"), DoubleToFloat( - in_keys=["next_observation_stand", "next_observation"], + in_keys=["observation_stand", "observation"], 
in_keys_inv=["action"], ), ), @@ -463,10 +456,10 @@ def env2_maker(): return TransformedEnv( DMControlEnv("humanoid", "walk"), Compose( - CatTensors(env2_obs_keys, "next_observation_walk", del_keys=False), - CatTensors(env2_obs_keys, "next_observation"), + CatTensors(env2_obs_keys, "observation_walk", del_keys=False), + CatTensors(env2_obs_keys, "observation"), DoubleToFloat( - in_keys=["next_observation_walk", "next_observation"], + in_keys=["observation_walk", "observation"], in_keys_inv=["action"], ), ), @@ -650,7 +643,7 @@ def test_parallel_env_seed( ).contiguous() key = "pixels" if "pixels" in td_serial.keys() else "observation" torch.testing.assert_close( - td_serial[:, 0].get("next_" + key), td_serial[:, 1].get(key) + td_serial[:, 0].get(("next", key)), td_serial[:, 1].get(key) ) out_seed_parallel = env_parallel.set_seed(0, static_seed=static_seed) @@ -664,7 +657,7 @@ def test_parallel_env_seed( max_steps=10, auto_reset=False, tensordict=td0_parallel ).contiguous() torch.testing.assert_close( - td_parallel[:, :-1].get("next_" + key), td_parallel[:, 1:].get(key) + td_parallel[:, :-1].get(("next", key)), td_parallel[:, 1:].get(key) ) assert_allclose_td(td0_serial, td0_parallel) @@ -938,10 +931,10 @@ def test_seed(): rollout2 = env2.rollout(max_steps=30) torch.testing.assert_close( - rollout1["observation"][1:], rollout1["next_observation"][:-1] + rollout1["observation"][1:], rollout1[("next", "observation")][:-1] ) torch.testing.assert_close( - rollout2["observation"][1:], rollout2["next_observation"][:-1] + rollout2["observation"][1:], rollout2[("next", "observation")][:-1] ) torch.testing.assert_close(rollout1["observation"], rollout2["observation"]) @@ -958,7 +951,7 @@ def test_steptensordict( tensordict = TensorDict( { "ledzep": torch.randn(4, 2), - "next_ledzep": torch.randn(4, 2), + "next": {"ledzep": torch.randn(4, 2)}, "reward": torch.randn(4, 1), "done": torch.zeros(4, 1, dtype=torch.bool), "beatles": torch.randn(4, 1), @@ -976,7 +969,7 @@ def test_steptensordict( next_tensordict=next_tensordict, ) assert "ledzep" in out.keys() - assert out["ledzep"] is tensordict["next_ledzep"] + assert out["ledzep"] is tensordict["next", "ledzep"] if keep_other: assert "beatles" in out.keys() assert out["beatles"] is tensordict["beatles"] @@ -1070,7 +1063,7 @@ def test_info_dict_reader(seed=0): tensordict = env.reset() tensordict = env.rand_step(tensordict) - assert env.observation_spec["x_position"].is_in(tensordict["x_position"]) + assert env.observation_spec["x_position"].is_in(tensordict[("next", "x_position")]) env2 = GymWrapper(gym.make("HalfCheetah-v4")) env2.set_info_dict_reader( @@ -1082,7 +1075,9 @@ def test_info_dict_reader(seed=0): tensordict2 = env2.reset() tensordict2 = env2.rand_step(tensordict2) - assert not env2.observation_spec["x_position"].is_in(tensordict2["x_position"]) + assert not env2.observation_spec["x_position"].is_in( + tensordict2[("next", "x_position")] + ) if __name__ == "__main__": diff --git a/test/test_helpers.py b/test/test_helpers.py index df005248ffa..bec3757fce0 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -631,36 +631,40 @@ def test_dreamer_make(device, tanh_loc, exploration, dreamer_constructor_fixture "action", "belief", "done", - "next_belief", - "next_encoded_latents", - "next_pixels", - "next_pixels_orig", - "next_posterior_mean", - "next_posterior_std", - "next_prior_mean", - "next_prior_std", - "next_state", + ("next", "belief"), + ("next", "encoded_latents"), + ("next", "pixels"), + ("next", "pixels_orig"), + ("next", 
"posterior_mean"), + ("next", "posterior_std"), + ("next", "prior_mean"), + ("next", "prior_std"), + ("next", "state"), "pixels", "pixels_orig", "reward", "state", - "next_reco_pixels", + ("next", "reco_pixels"), + "next", } - assert set(out.keys()) == expected_keys + assert set(out.keys(True)) == expected_keys simulated_data = model_based_env.rollout(3) expected_keys = { "action", "belief", "done", - "next_belief", - "next_state", - "pixels", + ("next", "belief"), + ("next", "state"), + ("next", "pixels"), + ("next", "pixels_orig"), "pixels_orig", + "pixels", "reward", "state", + "next", } - assert expected_keys == set(simulated_data.keys()) + assert expected_keys == set(simulated_data.keys(True)) simulated_action = actor_model(model_based_env.reset()) real_action = actor_model(proof_environment.reset()) diff --git a/test/test_libs.py b/test/test_libs.py index 0a176ce5a60..8667b158934 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -286,8 +286,11 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs): ) env = env_lib(*env_args, **env_kwargs) td = env.rollout(max_steps=5) - td0 = td[0] + td0 = td[0].flatten_keys(".") fake_td = env.fake_tensordict() + + fake_td = fake_td.flatten_keys(".") + td = td.flatten_keys(".") assert set(fake_td.keys()) == set(td.keys()) for key in fake_td.keys(): assert fake_td.get(key).shape == td.get(key)[0].shape diff --git a/test/test_modules.py b/test/test_modules.py index eea003a01a0..57cbd4d99a4 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -616,19 +616,19 @@ def test_rssm_rollout( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ - "next_prior_mean", - "next_prior_std", + ("next", "prior_mean"), + ("next", "prior_std"), "_", - "next_belief", + ("next", "belief"), ], ), TensorDictModule( rssm_posterior, - in_keys=["next_belief", "next_encoded_latents"], + in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ - "next_posterior_mean", - "next_posterior_std", - "next_state", + ("next", "posterior_mean"), + ("next", "posterior_std"), + ("next", "state"), ], ), ) @@ -641,9 +641,11 @@ def test_rssm_rollout( tensordict = TensorDict( { "state": state.clone(), - "next_belief": belief.clone(), "action": action.clone(), - "next_encoded_latents": obs_emb.clone(), + "next": { + "encoded_latents": obs_emb.clone(), + "belief": belief.clone(), + }, }, device=device, batch_size=torch.Size([*batch_size, temporal_size]), @@ -652,30 +654,38 @@ def test_rssm_rollout( _ = rssm_rollout(tensordict.clone()) torch.manual_seed(0) rollout = rssm_rollout(tensordict) - assert rollout["next_prior_mean"].shape == ( + assert rollout["next", "prior_mean"].shape == ( *batch_size, temporal_size, deter_size, ) - assert rollout["next_prior_std"].shape == ( + assert rollout["next", "prior_std"].shape == ( *batch_size, temporal_size, deter_size, ) - assert rollout["next_state"].shape == (*batch_size, temporal_size, deter_size) - assert rollout["next_belief"].shape == (*batch_size, temporal_size, stoch_size) - assert rollout["next_posterior_mean"].shape == ( + assert rollout["next", "state"].shape == ( *batch_size, temporal_size, deter_size, ) - assert rollout["next_posterior_std"].shape == ( + assert rollout["next", "belief"].shape == ( + *batch_size, + temporal_size, + stoch_size, + ) + assert rollout["next", "posterior_mean"].shape == ( *batch_size, temporal_size, deter_size, ) - assert torch.all(rollout["next_prior_std"] > 0) - assert torch.all(rollout["next_posterior_std"] > 0) + assert rollout["next", "posterior_std"].shape == ( 
+ *batch_size, + temporal_size, + deter_size, + ) + assert torch.all(rollout["next", "prior_std"] > 0) + assert torch.all(rollout["next", "posterior_std"] > 0) state[..., 1:, :] = 0 belief[..., 1:, :] = 0 @@ -684,9 +694,8 @@ def test_rssm_rollout( tensordict_bis = TensorDict( { "state": state.clone(), - "next_belief": belief.clone(), "action": action.clone(), - "next_encoded_latents": obs_emb.clone(), + "next": {"encoded_latents": obs_emb.clone(), "belief": belief.clone()}, }, device=device, batch_size=torch.Size([*batch_size, temporal_size]), @@ -695,16 +704,18 @@ def test_rssm_rollout( rollout_bis = rssm_rollout(tensordict_bis) assert torch.allclose( - rollout["next_prior_mean"], rollout_bis["next_prior_mean"] - ), (rollout["next_prior_mean"] - rollout_bis["next_prior_mean"]).norm() - assert torch.allclose(rollout["next_prior_std"], rollout_bis["next_prior_std"]) - assert torch.allclose(rollout["next_state"], rollout_bis["next_state"]) - assert torch.allclose(rollout["next_belief"], rollout_bis["next_belief"]) + rollout["next", "prior_mean"], rollout_bis["next", "prior_mean"] + ), (rollout["next", "prior_mean"] - rollout_bis["next", "prior_mean"]).norm() + assert torch.allclose( + rollout["next", "prior_std"], rollout_bis["next", "prior_std"] + ) + assert torch.allclose(rollout["next", "state"], rollout_bis["next", "state"]) + assert torch.allclose(rollout["next", "belief"], rollout_bis["next", "belief"]) assert torch.allclose( - rollout["next_posterior_mean"], rollout_bis["next_posterior_mean"] + rollout["next", "posterior_mean"], rollout_bis["next", "posterior_mean"] ) assert torch.allclose( - rollout["next_posterior_std"], rollout_bis["next_posterior_std"] + rollout["next", "posterior_std"], rollout_bis["next", "posterior_std"] ) diff --git a/test/test_postprocs.py b/test/test_postprocs.py index 1cf01d7467f..d50b74bb08f 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -37,7 +37,7 @@ def test_multistep(n, key, device, T=11): tensordict = TensorDict( source={ key: total_obs[:, :T] * mask.to(torch.float), - "next_" + key: total_obs[:, 1:] * mask.to(torch.float), + "next": {key: total_obs[:, 1:] * mask.to(torch.float)}, "done": done, "reward": torch.randn(1, T, 1, device=device).expand(b, T, 1) * mask.to(torch.float), @@ -60,7 +60,7 @@ def test_multistep(n, key, device, T=11): # assert that done at last step is similar to unterminated traj assert (ms_tensordict.get("gamma")[4] == ms_tensordict.get("gamma")[0]).all() assert ( - ms_tensordict.get("next_" + key)[4] == ms_tensordict.get("next_" + key)[0] + ms_tensordict.get(("next", key))[4] == ms_tensordict.get(("next", key))[0] ).all() assert ( ms_tensordict.get("steps_to_next_obs")[4] @@ -69,7 +69,7 @@ def test_multistep(n, key, device, T=11): # check that next obs is properly replaced, or that it is terminated next_obs = ms_tensordict.get(key)[:, (1 + ms.n_steps_max) :] - true_next_obs = ms_tensordict.get("next_" + key)[:, : -(1 + ms.n_steps_max)] + true_next_obs = ms_tensordict.get(("next", key))[:, : -(1 + ms.n_steps_max)] terminated = ~ms_tensordict.get("nonterminal") assert ((next_obs == true_next_obs) | terminated[:, (1 + ms.n_steps_max) :]).all() diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 3d35bbf3f3d..3ca99d8ac73 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -20,6 +20,7 @@ NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, + _keys_to_empty_composite_spec, ) @@ -378,6 +379,20 @@ def test_type_check(self, is_complete, 
device, dtype): def test_nested_composite_spec(self, is_complete, device, dtype): ts = self._composite_spec(is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + assert set(ts.keys()) == { + "obs", + "act", + ("nested_cp", "obs"), + ("nested_cp", "act"), + } + assert len(ts.keys()) == len(ts.keys(yield_nesting_keys=True)) - 1 + assert set(ts.keys(yield_nesting_keys=True)) == { + "obs", + "act", + ("nested_cp", "obs"), + ("nested_cp", "act"), + "nested_cp", + } td = ts.rand() assert isinstance(td["nested_cp"], TensorDictBase) keys = list(td.keys()) @@ -385,6 +400,98 @@ def test_nested_composite_spec(self, is_complete, device, dtype): if key != "nested_cp": assert key in td["nested_cp"].keys() + def test_nested_composite_spec_index(self, is_complete, device, dtype): + ts = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype) + assert ts["nested_cp"]["nested_cp"] is ts["nested_cp", "nested_cp"] + assert ( + ts["nested_cp"]["nested_cp"]["obs"] is ts["nested_cp", "nested_cp", "obs"] + ) + + def test_nested_composite_spec_rand(self, is_complete, device, dtype): + ts = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype) + r = ts.rand() + assert (r["nested_cp", "nested_cp", "obs"] >= 0).all() + + def test_nested_composite_spec_zero(self, is_complete, device, dtype): + ts = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype) + r = ts.zero() + assert (r["nested_cp", "nested_cp", "obs"] == 0).all() + + def test_nested_composite_spec_setitem(self, is_complete, device, dtype): + ts = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype) + ts["nested_cp", "nested_cp", "obs"] = None + assert ( + ts["nested_cp"]["nested_cp"]["obs"] is ts["nested_cp", "nested_cp", "obs"] + ) + assert ts["nested_cp"]["nested_cp"]["obs"] is None + + def test_nested_composite_spec_update(self, is_complete, device, dtype): + ts = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + td2 = CompositeSpec(new=None) + ts.update(td2) + assert set(ts.keys()) == { + "obs", + "act", + ("nested_cp", "obs"), + ("nested_cp", "act"), + "new", + } + + ts = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device)) + ts.update(td2) + assert set(ts.keys()) == { + "obs", + "act", + ("nested_cp", "obs"), + ("nested_cp", "act"), + ("nested_cp", "new"), + } + + ts = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device)) + ts.update(td2) + assert set(ts.keys()) == { + "obs", + "act", + ("nested_cp", "obs"), + ("nested_cp", "act"), + } + assert ts["nested_cp"]["act"] is None + + ts = self._composite_spec(is_complete, device, dtype) + ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) + td2 = 
CompositeSpec(nested_cp=CompositeSpec(act=None).to(device)) + ts.update(td2) + td2 = CompositeSpec( + nested_cp=CompositeSpec(act=UnboundedContinuousTensorSpec(device)) + ) + ts.update(td2) + assert set(ts.keys()) == { + "obs", + "act", + ("nested_cp", "obs"), + ("nested_cp", "act"), + } + assert ts["nested_cp"]["act"] is not None + + +def test_keys_to_empty_composite_spec(): + keys = [("key1", "out"), ("key1", "in"), "key2", ("key1", "subkey1", "subkey2")] + composite = _keys_to_empty_composite_spec(keys) + assert set(composite.keys()) == set(keys) + class TestEquality: """Tests spec comparison.""" diff --git a/test/test_transforms.py b/test/test_transforms.py index 2f6e6de7919..050d1831fed 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -8,7 +8,7 @@ import numpy as np import pytest import torch -from _utils_internal import get_available_devices, retry +from _utils_internal import get_available_devices, retry, dtype_fixture # noqa from mocking_classes import ( ContinuousActionVecMockEnv, DiscreteActionConvMockEnvNumpy, @@ -90,9 +90,9 @@ def _test_vecnorm_subproc_auto( msg = queue_in.get(timeout=TIMEOUT) assert msg == "all_done" t = env.transform - obs_sum = t._td.get("next_observation_sum").clone() - obs_ssq = t._td.get("next_observation_ssq").clone() - obs_count = t._td.get("next_observation_count").clone() + obs_sum = t._td.get("observation_sum").clone() + obs_ssq = t._td.get("observation_ssq").clone() + obs_count = t._td.get("observation_count").clone() reward_sum = t._td.get("reward_sum").clone() reward_ssq = t._td.get("reward_ssq").clone() reward_count = t._td.get("reward_count").clone() @@ -145,9 +145,9 @@ def test_vecnorm_parallel_auto(self, nprc): td = make_env.state_dict()["_extra_state"]["td"] - obs_sum = td.get("next_observation_sum").clone() - obs_ssq = td.get("next_observation_ssq").clone() - obs_count = td.get("next_observation_count").clone() + obs_sum = td.get("observation_sum").clone() + obs_ssq = td.get("observation_ssq").clone() + obs_count = td.get("observation_count").clone() reward_sum = td.get("reward_sum").clone() reward_ssq = td.get("reward_ssq").clone() reward_count = td.get("reward_count").clone() @@ -254,9 +254,8 @@ def test_parallelenv_vecnorm(self): True, ], ) - def test_vecnorm(self, parallel, thr=0.2, N=200): + def test_vecnorm_rollout(self, parallel, thr=0.2, N=200): self.SEED += 1 - print(self.SEED) torch.manual_seed(self.SEED) if parallel is None: @@ -281,7 +280,7 @@ def test_vecnorm(self, parallel, thr=0.2, N=200): if td.get("done").any(): td = env_t.reset() tds = torch.stack(tds, 0) - obs = tds.get("next_observation") + obs = tds.get(("next", "observation")) obs = obs.view(-1, obs.shape[-1]) mean = obs.mean(0) assert (abs(mean) < thr).all() @@ -317,9 +316,7 @@ def test_added_transforms_are_in_eval_mode(): class TestTransformedEnv: def test_independent_obs_specs_from_shared_env(self): - obs_spec = CompositeSpec( - next_observation=BoundedTensorSpec(minimum=0, maximum=10) - ) + obs_spec = CompositeSpec(observation=BoundedTensorSpec(minimum=0, maximum=10)) base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec) t1 = TransformedEnv(base_env, transform=ObservationNorm(loc=3, scale=2)) t2 = TransformedEnv(base_env, transform=ObservationNorm(loc=1, scale=6)) @@ -327,14 +324,14 @@ def test_independent_obs_specs_from_shared_env(self): t1_obs_spec = t1.observation_spec t2_obs_spec = t2.observation_spec - assert t1_obs_spec["next_observation"].space.minimum == 3 - assert t1_obs_spec["next_observation"].space.maximum == 23 + assert 
t1_obs_spec["observation"].space.minimum == 3 + assert t1_obs_spec["observation"].space.maximum == 23 - assert t2_obs_spec["next_observation"].space.minimum == 1 - assert t2_obs_spec["next_observation"].space.maximum == 61 + assert t2_obs_spec["observation"].space.minimum == 1 + assert t2_obs_spec["observation"].space.maximum == 61 - assert base_env.observation_spec["next_observation"].space.minimum == 0 - assert base_env.observation_spec["next_observation"].space.maximum == 10 + assert base_env.observation_spec["observation"].space.minimum == 0 + assert base_env.observation_spec["observation"].space.maximum == 10 def test_independent_reward_specs_from_shared_env(self): reward_spec = UnboundedContinuousTensorSpec() @@ -400,7 +397,7 @@ class TestTransforms: @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", [["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_available_devices()) def test_resize(self, interpolation, keys, nchannels, batch, device): @@ -427,7 +424,7 @@ def test_resize(self, interpolation, keys, nchannels, batch, device): assert observation_spec.shape == torch.Size([nchannels, 20, 21]) else: observation_spec = CompositeSpec( - **{key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + {key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = resize.transform_observation_spec(observation_spec) for key in keys: @@ -438,7 +435,7 @@ def test_resize(self, interpolation, keys, nchannels, batch, device): @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("h", [None, 21]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", [["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_available_devices()) def test_centercrop(self, keys, h, nchannels, batch, device): @@ -467,7 +464,7 @@ def test_centercrop(self, keys, h, nchannels, batch, device): assert observation_spec.shape == torch.Size([nchannels, 20, h]) else: observation_spec = CompositeSpec( - **{key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + {key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = cc.transform_observation_spec(observation_spec) for key in keys: @@ -478,7 +475,7 @@ def test_centercrop(self, keys, h, nchannels, batch, device): @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", [["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_available_devices()) def test_flatten(self, keys, size, nchannels, batch, device): @@ -507,7 +504,7 @@ def test_flatten(self, keys, size, nchannels, batch, device): assert observation_spec.shape[-3] == expected_size else: observation_spec = CompositeSpec( - **{ + { key: NdBoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) for key in keys } @@ -521,7 +518,7 @@ def test_flatten(self, keys, size, nchannels, batch, device): @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", 
[["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_available_devices()) def test_unsqueeze(self, keys, size, nchannels, batch, device, unsqueeze_dim): @@ -560,7 +557,7 @@ def test_unsqueeze(self, keys, size, nchannels, batch, device, unsqueeze_dim): assert observation_spec.shape == expected_size else: observation_spec = CompositeSpec( - **{ + { key: NdBoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) for key in keys } @@ -574,11 +571,11 @@ def test_unsqueeze(self, keys, size, nchannels, batch, device, unsqueeze_dim): @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", [["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( - "keys_inv", [[], ["action", "some_other_key"], ["next_observation_pixels"]] + "keys_inv", [[], ["action", "some_other_key"], ["observation_pixels"]] ) def test_unsqueeze_inv( self, keys, keys_inv, size, nchannels, batch, device, unsqueeze_dim @@ -612,11 +609,12 @@ def test_unsqueeze_inv( @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", + [[("next", "observation"), "some_other_key"], [("next", "observation_pixels")]], ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( - "keys_inv", [[], ["action", "some_other_key"], ["next_observation_pixels"]] + "keys_inv", [[], ["action", "some_other_key"], [("next", "observation_pixels")]] ) def test_squeeze(self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim): torch.manual_seed(0) @@ -645,11 +643,11 @@ def test_squeeze(self, keys, keys_inv, size, nchannels, batch, device, squeeze_d @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", [["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( - "keys_inv", [[], ["action", "some_other_key"], ["next_observation_pixels"]] + "keys_inv", [[], ["action", "some_other_key"], ["observation_pixels"]] ) def test_squeeze_inv( self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim @@ -681,7 +679,8 @@ def test_squeeze_inv( @pytest.mark.skipif(not _has_tv, reason="no torchvision") @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", + [[("next", "observation"), "some_other_key"], [("next", "observation_pixels")]], ) @pytest.mark.parametrize("device", get_available_devices()) def test_grayscale(self, keys, device): @@ -706,7 +705,7 @@ def test_grayscale(self, keys, device): assert observation_spec.shape == torch.Size([1, 16, 16]) else: observation_spec = CompositeSpec( - **{key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + {key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = gs.transform_observation_spec(observation_spec) for key in keys: @@ -714,7 +713,8 @@ def test_grayscale(self, keys, device): @pytest.mark.parametrize("batch", [[], [1], [3, 2]]) @pytest.mark.parametrize( - "keys", [["next_observation", 
"some_other_key"], ["next_observation_pixels"]] + "keys", + [[("next", "observation"), "some_other_key"], [("next", "observation_pixels")]], ) @pytest.mark.parametrize("device", get_available_devices()) def test_totensorimage(self, keys, batch, device): @@ -747,7 +747,7 @@ def test_totensorimage(self, keys, batch, device): assert (observation_spec.space.maximum == 1).all() else: observation_spec = CompositeSpec( - **{key: NdBoundedTensorSpec(0, 255, (16, 16, 3)) for key in keys} + {key: NdBoundedTensorSpec(0, 255, (16, 16, 3)) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( observation_spec @@ -759,7 +759,8 @@ def test_totensorimage(self, keys, batch, device): @pytest.mark.parametrize("batch", [[], [1], [3, 2]]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", + [["next_observation", "some_other_key"], [("next", "observation_pixels")]], ) @pytest.mark.parametrize("device", get_available_devices()) def test_compose(self, keys, batch, device, nchannels=1, N=4): @@ -788,10 +789,7 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): assert observation_spec.shape == torch.Size([nchannels * N, 16, 16]) else: observation_spec = CompositeSpec( - **{ - key: NdBoundedTensorSpec(0, 255, (nchannels, 16, 16)) - for key in keys - } + {key: NdBoundedTensorSpec(0, 255, (nchannels, 16, 16)) for key in keys} ) observation_spec = compose.transform_observation_spec(observation_spec) for key in keys: @@ -838,7 +836,8 @@ def test_compose_inv(self, keys_inv_1, keys_inv_2, device): @pytest.mark.parametrize("batch", [[], [1], [3, 2]]) @pytest.mark.parametrize( - "keys", [["next_observation", "some_other_key"], ["next_observation_pixels"]] + "keys", + [["next_observation", "some_other_key"], [("next", "observation_pixels")]], ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -889,7 +888,7 @@ def test_observationnorm( else: observation_spec = CompositeSpec( - **{ + { key: NdBoundedTensorSpec(0, 1, (nchannels, 16, 16), device=device) for key in keys } @@ -905,19 +904,17 @@ def test_observationnorm( assert (observation_spec[key].space.minimum == loc).all() assert (observation_spec[key].space.maximum == scale + loc).all() - @pytest.mark.parametrize( - "keys", [["next_observation"], ["next_observation", "next_pixel"]] - ) + @pytest.mark.parametrize("keys", [["observation"], ["observation", "next_pixel"]]) @pytest.mark.parametrize("size", [1, 3]) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("standard_normal", [True, False]) def test_observationnorm_init_stats(self, keys, size, device, standard_normal): base_env = ContinuousActionVecMockEnv( observation_spec=CompositeSpec( - next_observation=NdBoundedTensorSpec( + observation=NdBoundedTensorSpec( minimum=1, maximum=1, shape=torch.Size([size]) ), - next_observation_orig=NdBoundedTensorSpec( + observation_orig=NdBoundedTensorSpec( minimum=1, maximum=1, shape=torch.Size([size]) ), ), @@ -932,22 +929,18 @@ def test_observationnorm_init_stats(self, keys, size, device, standard_normal): transform=ObservationNorm(in_keys=keys, standard_normal=standard_normal), ) if len(keys) > 1: - t_env.transform.init_stats(num_iter=11, key="next_observation") + t_env.transform.init_stats(num_iter=11, key="observation") else: t_env.transform.init_stats(num_iter=11) - if standard_normal: - torch.testing.assert_close(t_env.transform.loc, torch.Tensor([1.06] * size)) - 
torch.testing.assert_close( - t_env.transform.scale, torch.Tensor([0.03316621] * size) - ) - else: - torch.testing.assert_close( - t_env.transform.loc, torch.Tensor([31.960236] * size) - ) - torch.testing.assert_close( - t_env.transform.scale, torch.Tensor([30.151169] * size) - ) + assert t_env.transform.loc.shape == t_env.observation_spec["observation"].shape + assert ( + t_env.transform.scale.shape == t_env.observation_spec["observation"].shape + ) + assert t_env.transform.loc.dtype == t_env.observation_spec["observation"].dtype + assert ( + t_env.transform.loc.device == t_env.observation_spec["observation"].device + ) def test_observationnorm_stats_already_initialized_error(self): transform = ObservationNorm(in_keys="next_observation", loc=0, scale=1) @@ -981,7 +974,7 @@ def test_catframes_transform_observation_spec(self): mins = [0, 0.5] maxes = [0.5, 1] observation_spec = CompositeSpec( - **{ + { key: NdBoundedTensorSpec( space_min, space_max, (1, 3, 3), dtype=torch.double ) @@ -991,7 +984,7 @@ def test_catframes_transform_observation_spec(self): result = cat_frames.transform_observation_spec(observation_spec) observation_spec = CompositeSpec( - **{ + { key: NdBoundedTensorSpec( space_min, space_max, (1, 3, 3), dtype=torch.double ) @@ -1058,15 +1051,15 @@ def test_finitetensordictcheck(self, device): ) ftd(td) td.set("inf", torch.zeros(1, 3).fill_(float("inf"))) - with pytest.raises(ValueError, match="Found non-finite elements"): + with pytest.raises(ValueError, match="Encountered a non-finite tensor"): ftd(td) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( "keys", [ - ["next_observation", "some_other_key"], - ["next_observation_pixels"], + ["observation", "some_other_key"], + ["observation_pixels"], ["action"], ], ) @@ -1115,7 +1108,7 @@ def test_double2float(self, keys, keys_inv, device): else: observation_spec = CompositeSpec( - **{ + { key: NdBoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) for key in keys } @@ -1128,8 +1121,8 @@ def test_double2float(self, keys, keys_inv, device): @pytest.mark.parametrize( "keys", [ - ["next_observation", "next_observation_other"], - ["next_observation_pixels"], + ["observation", "observation_other"], + ["observation_pixels"], ], ) def test_cattensors(self, keys, device): @@ -1165,7 +1158,7 @@ def test_cattensors(self, keys, device): assert observation_spec.shape == torch.Size([1, len(keys) * 4, 32]) else: observation_spec = CompositeSpec( - **{key: NdBoundedTensorSpec(0, 1, (1, 4, 32)) for key in keys} + {key: NdBoundedTensorSpec(0, 1, (1, 4, 32)) for key in keys} ) observation_spec = cattensors.transform_observation_spec(observation_spec) assert observation_spec["observation_out"].shape == torch.Size( @@ -1174,14 +1167,14 @@ def test_cattensors(self, keys, device): @pytest.mark.parametrize("append", [True, False]) def test_cattensors_empty(self, append): - ct = CatTensors(out_key="next_observation_out", dim=-1, del_keys=False) + ct = CatTensors(out_key="observation_out", dim=-1, del_keys=False) if append: mock_env = TransformedEnv(ContinuousActionVecMockEnv()) mock_env.append_transform(ct) else: mock_env = TransformedEnv(ContinuousActionVecMockEnv(), ct) tensordict = mock_env.rollout(3) - assert all(key in tensordict.keys() for key in ["next_observation_out"]) + assert all(key in tensordict.keys() for key in ["observation_out"]) # assert not any(key in tensordict.keys() for key in mock_env.base_env.observation_spec) @pytest.mark.parametrize("random", [True, False]) @@ -1474,8 +1467,8 @@ def 
test_insert(self): class TestR3M: @pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]]) def test_r3m_instantiation(self, model, tensor_pixels_key, device): - in_keys = ["next_pixels"] - out_keys = ["next_vec"] + in_keys = ["pixels"] + out_keys = ["vec"] r3m = R3MTransform( model, in_keys=in_keys, @@ -1492,21 +1485,25 @@ def test_r3m_instantiation(self, model, tensor_pixels_key, device): assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"}) - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + exp_keys = exp_keys.union( + {("next", "vec"), ("next", "pixels_orig"), "action", "reward", "next"} + ) + if tensor_pixels_key: + exp_keys.add(("next", tensor_pixels_key[0])) + assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys transformed_env.close() @pytest.mark.parametrize("stack_images", [True, False]) @pytest.mark.parametrize( "parallel", [ - False, True, + False, ], ) def test_r3m_mult_images(self, model, device, stack_images, parallel): - in_keys = ["next_pixels", "next_pixels2"] - out_keys = ["next_vec"] if stack_images else ["next_vec", "next_vec2"] + in_keys = ["pixels", "pixels2"] + out_keys = ["vec"] if stack_images else ["vec", "vec2"] r3m = R3MTransform( model, in_keys=in_keys, @@ -1517,7 +1514,7 @@ def test_r3m_mult_images(self, model, device, stack_images, parallel): def base_env_constructor(): return TransformedEnv( DiscreteActionConvMockEnvNumpy().to(device), - CatTensors(["next_pixels"], "next_pixels2", del_keys=False), + CatTensors(["pixels"], "pixels2", del_keys=False), ) assert base_env_constructor().device == device @@ -1547,15 +1544,17 @@ def base_env_constructor(): assert set(td.keys()) == exp_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"}) + exp_keys = exp_keys.union( + {("next", "vec"), ("next", "pixels_orig"), "action", "reward", "next"} + ) if not stack_images: - exp_keys = exp_keys.union({"next_vec2"}) - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + exp_keys.add(("next", "vec2")) + assert set(td.keys(True)) == exp_keys, set(td.keys()) - exp_keys transformed_env.close() def test_r3m_parallel(self, model, device): - in_keys = ["next_pixels"] - out_keys = ["next_vec"] + in_keys = ["pixels"] + out_keys = ["vec"] tensor_pixels_key = None r3m = R3MTransform( model, @@ -1571,22 +1570,24 @@ def test_r3m_parallel(self, model, device): exp_keys = {"vec", "done", "pixels_orig"} if tensor_pixels_key: exp_keys.add(tensor_pixels_key) - assert set(td.keys()) == exp_keys + assert set(td.keys(True)) == exp_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"}) - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + exp_keys = exp_keys.union( + {("next", "vec"), ("next", "pixels_orig"), "action", "reward", "next"} + ) + assert set(td.keys(True)) == exp_keys, set(td.keys()) - exp_keys transformed_env.close() del transformed_env @pytest.mark.parametrize("del_keys", [True, False]) @pytest.mark.parametrize( "in_keys", - [["next_pixels"], ["next_pixels_1", "next_pixels_2", "next_pixels_3"]], + [["pixels"], ["pixels_1", "pixels_2", "pixels_3"]], ) @pytest.mark.parametrize( "out_keys", - [["next_r3m_vec"], ["next_r3m_vec_1", "next_r3m_vec_2", "next_r3m_vec_3"]], + [["r3m_vec"], ["r3m_vec_1", "r3m_vec_2", "r3m_vec_3"]], ) def test_r3mnet_transform_observation_spec( 
self, in_keys, out_keys, del_keys, device, model @@ -1594,11 +1595,11 @@ def test_r3mnet_transform_observation_spec( r3m_net = _R3MNet(in_keys, out_keys, model, del_keys) observation_spec = CompositeSpec( - **{key: NdBoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} + {key: NdBoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} ) if del_keys: exp_ts = CompositeSpec( - **{ + { key: NdUnboundedContinuousTensorSpec(r3m_net.outdim, device) for key in out_keys } @@ -1618,7 +1619,7 @@ def test_r3mnet_transform_observation_spec( ts_dict[key] = observation_spec[key] for key in out_keys: ts_dict[key] = NdUnboundedContinuousTensorSpec(r3m_net.outdim, device) - exp_ts = CompositeSpec(**ts_dict) + exp_ts = CompositeSpec(ts_dict) observation_spec_out = r3m_net.transform_observation_spec(observation_spec) @@ -1629,8 +1630,8 @@ def test_r3mnet_transform_observation_spec( @pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]]) def test_r3m_spec_against_real(self, model, tensor_pixels_key, device): - in_keys = ["next_pixels"] - out_keys = ["next_vec"] + in_keys = ["pixels"] + out_keys = ["vec"] r3m = R3MTransform( model, in_keys=in_keys, @@ -1642,11 +1643,10 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device): expected_keys = ( list(transformed_env.input_spec.keys()) + list(transformed_env.observation_spec.keys()) - + [key.strip("next_") for key in transformed_env.observation_spec.keys()] - + ["reward"] - + ["done"] + + [("next", key) for key in transformed_env.observation_spec.keys()] + + ["reward", "done", "next"] ) - assert set(expected_keys) == set(transformed_env.rollout(3).keys()) + assert set(expected_keys) == set(transformed_env.rollout(3).keys(True)) @pytest.mark.skipif(not _has_tv, reason="torchvision not installed") @@ -1655,8 +1655,8 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device): class TestVIP: @pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]]) def test_vip_instantiation(self, model, tensor_pixels_key, device): - in_keys = ["next_pixels"] - out_keys = ["next_vec"] + in_keys = ["pixels"] + out_keys = ["vec"] vip = VIPTransform( model, in_keys=in_keys, @@ -1673,15 +1673,19 @@ def test_vip_instantiation(self, model, tensor_pixels_key, device): assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"}) - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + exp_keys = exp_keys.union( + {("next", "vec"), ("next", "pixels_orig"), "next", "action", "reward"} + ) + if tensor_pixels_key: + exp_keys.add(("next", tensor_pixels_key[0])) + assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys transformed_env.close() @pytest.mark.parametrize("stack_images", [True, False]) @pytest.mark.parametrize("parallel", [True, False]) def test_vip_mult_images(self, model, device, stack_images, parallel): - in_keys = ["next_pixels", "next_pixels2"] - out_keys = ["next_vec"] if stack_images else ["next_vec", "next_vec2"] + in_keys = ["pixels", "pixels2"] + out_keys = ["vec"] if stack_images else ["vec", "vec2"] vip = VIPTransform( model, in_keys=in_keys, @@ -1692,7 +1696,7 @@ def test_vip_mult_images(self, model, device, stack_images, parallel): def base_env_constructor(): return TransformedEnv( DiscreteActionConvMockEnvNumpy().to(device), - CatTensors(["next_pixels"], "next_pixels2", del_keys=False), + CatTensors(["pixels"], "pixels2", del_keys=False), ) assert 
base_env_constructor().device == device @@ -1722,15 +1726,17 @@ def base_env_constructor(): assert set(td.keys()) == exp_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"}) + exp_keys = exp_keys.union( + {("next", "vec"), ("next", "pixels_orig"), "next", "action", "reward"} + ) if not stack_images: - exp_keys = exp_keys.union({"next_vec2"}) - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + exp_keys.add(("next", "vec2")) + assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys transformed_env.close() def test_vip_parallel(self, model, device): - in_keys = ["next_pixels"] - out_keys = ["next_vec"] + in_keys = ["pixels"] + out_keys = ["vec"] tensor_pixels_key = None vip = VIPTransform( model, @@ -1749,14 +1755,17 @@ def test_vip_parallel(self, model, device): assert set(td.keys()) == exp_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"}) - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + exp_keys = exp_keys.union( + {("next", "vec"), ("next", "pixels_orig"), "next", "action", "reward"} + ) + assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys transformed_env.close() del transformed_env - def test_vip_parallel_reward(self, model, device): - in_keys = ["next_pixels"] - out_keys = ["next_vec"] + def test_vip_parallel_reward(self, model, device, dtype_fixture): # noqa + torch.manual_seed(1) + in_keys = ["pixels"] + out_keys = ["vec"] tensor_pixels_key = None vip = VIPRewardTransform( model, @@ -1791,42 +1800,51 @@ def test_vip_parallel_reward(self, model, device): assert set(td.keys()) == exp_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"}) - assert set(td.keys()) == exp_keys, td + exp_keys = exp_keys.union( + {("next", "vec"), ("next", "pixels_orig"), "next", "action", "reward"} + ) + assert set(td.keys(True)) == exp_keys, td + torch.manual_seed(1) tensordict_reset = TensorDict( {"goal_image": torch.randint(0, 255, (4, 7, 7, 3), dtype=torch.uint8)}, [4], device=device, ) td = transformed_env.rollout( - 3, auto_reset=False, tensordict=transformed_env.reset(tensordict_reset) + 5, auto_reset=False, tensordict=transformed_env.reset(tensordict_reset) ) - assert set(td.keys()) == exp_keys, td + assert set(td.keys(True)) == exp_keys, td # test that we do compute the reward we want - cur_embedding = td["next_vec"] + cur_embedding = td["next", "vec"] goal_embedding = td["goal_embedding"] last_embedding = td["vec"] - explicit_reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - ( - -torch.norm(last_embedding - goal_embedding, dim=-1) - ) - torch.testing.assert_close(explicit_reward, td["reward"].squeeze()) + # test that there is only one goal embedding goal = td["goal_embedding"] goal_expand = td["goal_embedding"][:, :1].expand_as(td["goal_embedding"]) torch.testing.assert_close(goal, goal_expand) + torch.testing.assert_close(cur_embedding[:, :-1], last_embedding[:, 1:]) + with pytest.raises(AssertionError): + torch.testing.assert_close(cur_embedding[:, 1:], last_embedding[:, :-1]) + + explicit_reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - ( + -torch.norm(last_embedding - goal_embedding, dim=-1) + ) + torch.testing.assert_close(explicit_reward, td["reward"].squeeze()) + transformed_env.close() del transformed_env @pytest.mark.parametrize("del_keys", [True, False]) @pytest.mark.parametrize( "in_keys", - 
[["next_pixels"], ["next_pixels_1", "next_pixels_2", "next_pixels_3"]], + [["pixels"], ["pixels_1", "pixels_2", "pixels_3"]], ) @pytest.mark.parametrize( "out_keys", - [["next_vip_vec"], ["next_vip_vec_1", "next_vip_vec_2", "next_vip_vec_3"]], + [["vip_vec"], ["vip_vec_1", "vip_vec_2", "vip_vec_3"]], ) def test_vipnet_transform_observation_spec( self, in_keys, out_keys, del_keys, device, model @@ -1834,11 +1852,11 @@ def test_vipnet_transform_observation_spec( vip_net = _VIPNet(in_keys, out_keys, model, del_keys) observation_spec = CompositeSpec( - **{key: NdBoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} + {key: NdBoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} ) if del_keys: exp_ts = CompositeSpec( - **{ + { key: NdUnboundedContinuousTensorSpec(vip_net.outdim, device) for key in out_keys } @@ -1858,7 +1876,7 @@ def test_vipnet_transform_observation_spec( ts_dict[key] = observation_spec[key] for key in out_keys: ts_dict[key] = NdUnboundedContinuousTensorSpec(vip_net.outdim, device) - exp_ts = CompositeSpec(**ts_dict) + exp_ts = CompositeSpec(ts_dict) observation_spec_out = vip_net.transform_observation_spec(observation_spec) @@ -1869,8 +1887,8 @@ def test_vipnet_transform_observation_spec( @pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]]) def test_vip_spec_against_real(self, model, tensor_pixels_key, device): - in_keys = ["next_pixels"] - out_keys = ["next_vec"] + in_keys = ["pixels"] + out_keys = ["vec"] vip = VIPTransform( model, in_keys=in_keys, @@ -1882,11 +1900,10 @@ def test_vip_spec_against_real(self, model, tensor_pixels_key, device): expected_keys = ( list(transformed_env.input_spec.keys()) + list(transformed_env.observation_spec.keys()) - + [key.strip("next_") for key in transformed_env.observation_spec.keys()] - + ["reward"] - + ["done"] + + [("next", key) for key in transformed_env.observation_spec.keys()] + + ["reward", "done", "next"] ) - assert set(expected_keys) == set(transformed_env.rollout(3).keys()) + assert set(expected_keys) == set(transformed_env.rollout(3).keys(True)) @pytest.mark.parametrize("device", get_available_devices()) @@ -1894,7 +1911,7 @@ def test_batch_locked_transformed(device): env = TransformedEnv( MockBatchedLockedEnv(device), Compose( - ObservationNorm(in_keys=["next_observation"], loc=0.5, scale=1.1), + ObservationNorm(in_keys=[("next", "observation")], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ), ) @@ -1918,7 +1935,7 @@ def test_batch_unlocked_transformed(device): env = TransformedEnv( MockBatchedUnLockedEnv(device), Compose( - ObservationNorm(in_keys=["next_observation"], loc=0.5, scale=1.1), + ObservationNorm(in_keys=[("next", "observation")], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ), ) @@ -1938,7 +1955,7 @@ def test_batch_unlocked_with_batch_size_transformed(device): env = TransformedEnv( MockBatchedUnLockedEnv(device, batch_size=torch.Size([2])), Compose( - ObservationNorm(in_keys=["next_observation"], loc=0.5, scale=1.1), + ObservationNorm(in_keys=[("next", "observation")], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ), ) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 325670fbe67..31d6d021554 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -173,9 +173,7 @@ def _get_policy_and_device( ) sig = inspect.signature(policy.forward) next_observation = { - key[5:]: value - for key, value in observation_spec.rand().items() - if key.startswith("next_") + key: value for key, value in 
observation_spec.rand().items() } if set(sig.parameters) == set(next_observation): out_keys = ["action"] @@ -404,44 +402,29 @@ def __init__( ): # if policy spec is non-empty, all the values are not None and the keys # match the out_keys we assume the user has given all relevant information - self._tensordict_out = TensorDict( - { - **env.observation_spec.zero(env.batch_size), - "reward": env.reward_spec.zero(env.batch_size), - "done": torch.zeros( - env.batch_size, dtype=torch.bool, device=env.device - ), - **self.policy.spec.zero(env.batch_size), - }, - env.batch_size, - device=env.device, + self._tensordict_out = ( + env.fake_tensordict().expand(env.batch_size).to_tensordict() ) + self._tensordict_out.update(self.policy.spec.zero(env.batch_size)) + if env.device: + self._tensordict_out = self._tensordict_out.to(env.device) self._tensordict_out = ( self._tensordict_out.unsqueeze(-1) .expand(*env.batch_size, self.frames_per_batch) .to_tensordict() ) - self._tensordict_out = self._tensordict_out.update( - step_mdp(self._tensordict_out) - ) # add "observation" when there is "next_observation" else: # otherwise, we perform a small number of steps with the policy to # determine the relevant keys with which to pre-populate _tensordict_out. # See #505 for additional context. - self._tensordict_out = self.env.rollout( - 3, self.policy, auto_cast_to_device=True - ) - if env.batch_size: - self._tensordict_out = self._tensordict_out[..., :1] - else: - self._tensordict_out = self._tensordict_out[:1] + with torch.no_grad(): + self._tensordict_out = env.fake_tensordict() + self._tensordict_out = self.policy(self._tensordict_out).unsqueeze(-1) self._tensordict_out = ( self._tensordict_out.expand(*env.batch_size, self.frames_per_batch) .to_tensordict() .zero_() - .detach() ) - env.reset() # in addition to outputs of the policy, we add traj_ids and step_count to # _tensordict_out which will be collected during rollout @@ -583,8 +566,8 @@ def _reset_if_necessary(self) -> None: self._tensordict.set("reset_workers", done_or_terminated) else: self._tensordict.zero_() - self.env.reset(self._tensordict) + if self._tensordict.get("done").any(): raise RuntimeError( f"Got {sum(self._tensordict.get('done'))} done envs after reset." @@ -622,7 +605,7 @@ def rollout(self) -> TensorDictBase: td_cast = self._cast_to_policy(self._tensordict) td_cast = self.policy(td_cast) self._cast_to_env(td_cast, self._tensordict) - self.env.step(self._tensordict) + self._tensordict = self.env.step(self._tensordict) step_count = self._tensordict.get("step_count") step_count += 1 diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 64bd2b87235..9edbb81c0e5 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -33,6 +33,9 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: From there, builds a B x T x ... zero-padded tensordict with B batches on max duration T """ + # TODO: incorporate tensordict.split once it's implemented + sep = ".-|-." 
+ rollout_tensordict = rollout_tensordict.flatten_keys(sep) traj_ids = rollout_tensordict.get("traj_ids") ndim = len(rollout_tensordict.batch_size) splits = traj_ids.view(-1) @@ -49,7 +52,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: ) if rollout_tensordict.ndimension() == 1: rollout_tensordict = rollout_tensordict.unsqueeze(0).to_tensordict() - return rollout_tensordict + return rollout_tensordict.unflatten_keys(sep) out_splits = { key: _d.contiguous().view(-1, *_d.shape[ndim:]).split(splits, 0) for key, _d in rollout_tensordict.items() @@ -71,6 +74,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: device=rollout_tensordict.device, batch_size=out_dict["mask"].shape[:-1], ) + td = td.unflatten_keys(sep) if (out_dict["done"].sum(1) > 1).any(): raise RuntimeError("Got more than one done per trajectory") return td diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 0945cc2ec19..98d69dfb07b 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -144,9 +144,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict: TennsorDict instance with Batch x Time-steps x ... dimensions. The TensorDict must contain a "reward" and "done" key. All - keys that start with the "next_" prefix will be shifted by ( - at most) self.n_steps_max frames. The TensorDict will also - be updated with new key-value pairs: + keys that are contained within the "next" nested tensordict + will be shifted by (at most) :obj:`MultiStep.n_steps_max` frames. + The TensorDict will also be updated with new key-value pairs: - gamma: indicating the discount to be used for the next reward; @@ -189,26 +189,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: nonterminal = ~post_terminal[:, :T] steps_to_next_obs = _get_steps_to_next_obs(nonterminal, self.n_steps_max) - selected_td = tensordict.select( - *[ - key - for key in tensordict.keys() - if (key.startswith("next_") or key == "done") - ] - ) + selected_td = tensordict.select("next", "done") - for key, item in selected_td.items(): - tensordict.set_( - key, - _select_and_repeat( - item, - terminal, - post_terminal, - mask, - self.n_steps_max, - ), + def _select_and_repeat_local(item): + return _select_and_repeat( + item, + terminal, + post_terminal, + mask, + self.n_steps_max, ) + selected_td.apply_(_select_and_repeat_local) + tensordict.set("gamma", gamma_masked) tensordict.set("steps_to_next_obs", steps_to_next_obs) tensordict.set("nonterminal", nonterminal) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 66150b1fd38..e1c455fb569 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1052,6 +1052,9 @@ class CompositeSpec(TensorSpec): """A composition of TensorSpecs. Args: + *args: if an unnamed argument is passed, it must be a dictionary with keys + matching the expected keys to be found in the :obj:`CompositeSpec` object. + This is useful to build nested CompositeSpecs with tuple indices. **kwargs (key (str): value (TensorSpec)): dictionary of tensorspecs to be stored. 
Values can be None, in which case is_in will be assumed to be :obj:`True` for the corresponding tensors, and :obj:`project()` will have no @@ -1083,14 +1086,21 @@ class CompositeSpec(TensorSpec): >>> print("random td: ", composite_spec.rand([3,])) random td: TensorDict( fields={ - pixels: Tensor(torch.Size([3, 3, 32, 32]), \ -dtype=torch.float32), - observation_vector: Tensor(torch.Size([3, 33]), \ -dtype=torch.float32)}, + observation_vector: Tensor(torch.Size([3, 33]), dtype=torch.float32), + pixels: Tensor(torch.Size([3, 3, 32, 32]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) + + Examples: + >>> # we can build a nested composite spec using unnamed arguments + >>> print(CompositeSpec({("a", "b"): None, ("a", "c"): None})) + CompositeSpec( + a: CompositeSpec( + b: None, + c: None)) + """ domain: str = "composite" @@ -1100,7 +1110,7 @@ def __new__(cls, *args, **kwargs): cls._device = torch.device("cpu") return super().__new__(cls) - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): self._specs = kwargs if len(kwargs): _device = None @@ -1115,6 +1125,24 @@ def __init__(self, **kwargs): f"All devices of CompositeSpec must match." ) self._device = _device + if len(args): + if not len(kwargs): + self._device = None + if len(args) > 1: + raise RuntimeError( + "Got multiple arguments, when at most one is expected for CompositeSpec." + ) + argdict = args[0] + if not isinstance(argdict, dict): + raise RuntimeError( + f"Expected a dictionary of specs, but got an argument of type {type(argdict)}." + ) + for k, item in argdict.items(): + if item is None: + continue + if self._device is None: + self._device = item.device + self[k] = item @property def device(self) -> DEVICE_TYPING: @@ -1138,11 +1166,26 @@ def device(self, value: DEVICE_TYPING): self._device = value def __getitem__(self, item): + if isinstance(item, tuple) and len(item) > 1: + return self[item[0]][item[1:]] + elif isinstance(item, tuple): + return self[item[0]] + if item in {"shape", "device", "dtype", "space"}: raise AttributeError(f"CompositeSpec has no key {item}") return self._specs[item] def __setitem__(self, key, value): + if isinstance(key, tuple) and len(key) > 1: + if key[0] not in self.keys(True): + self[key[0]] = CompositeSpec() + self[key[0]][key[1:]] = value + return + elif isinstance(key, tuple): + self[key[0]] = value + return + elif not isinstance(key, str): + raise TypeError(f"Got key of type {type(key)} when a string was expected.") if key in {"shape", "device", "dtype", "space"}: raise AttributeError(f"CompositeSpec[{key}] cannot be set") if value is not None and value.device != self.device: @@ -1160,7 +1203,10 @@ def __delitem__(self, key: str) -> None: del self._specs[key] def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: - out = {} + if isinstance(vals, TensorDict): + out = vals.select() # create and empty tensordict similar to vals + else: + out = TensorDict({}, [], _run_checks=False) for key, item in vals.items(): if item is None: raise RuntimeError( @@ -1214,15 +1260,23 @@ def rand(self, shape=None) -> TensorDictBase: shape = torch.Size([]) return TensorDict( { - key: value.rand(shape) - for key, value in self._specs.items() - if value is not None + key: self[key].rand(shape) + for key in self.keys(True) + if isinstance(key, str) and self[key] is not None }, batch_size=shape, ) - def keys(self) -> KeysView: - return self._specs.keys() + def keys(self, yield_nesting_keys: bool = False) -> KeysView: + """Keys of the 
CompositeSpec. + + Args: + yield_nesting_keys (bool, optional): if :obj:`True`, the values returned + will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` + will lead to the keys :obj:`["next", ("next", "obs")]`. Default is :obj:`False`, i.e. + only nested keys will be returned. + """ + return _CompositeSpecKeysView(self, _yield_nesting_keys=yield_nesting_keys) def items(self) -> ItemsView: return self._specs.items() @@ -1253,7 +1307,11 @@ def zero(self, shape=None) -> TensorDictBase: if shape is None: shape = torch.Size([]) return TensorDict( - {key: self[key].zero(shape) for key in self.keys()}, + { + key: self[key].zero(shape) + for key in self.keys(True) + if isinstance(key, str) and self[key] is not None + }, shape, device=self.device, ) @@ -1267,6 +1325,61 @@ def __eq__(self, other): def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: for key, item in dict_or_spec.items(): + if key in self.keys(True) and isinstance(self[key], CompositeSpec): + self[key].update(item) + continue if isinstance(item, TensorSpec) and item.device != self.device: item = deepcopy(item).to(self.device) self[key] = item + + +def _keys_to_empty_composite_spec(keys): + if not len(keys): + return + c = CompositeSpec() + for key in keys: + if isinstance(key, str): + c[key] = None + elif key[0] in c.keys(yield_nesting_keys=True): + if c[key[0]] is None: + # if the value is None we just replace it + c[key[0]] = _keys_to_empty_composite_spec([key[1:]]) + elif isinstance(c[key[0]], CompositeSpec): + # if the value is Composite, we update it + out = _keys_to_empty_composite_spec([key[1:]]) + if out is not None: + c[key[0]].update(out) + else: + raise RuntimeError("Conflicting keys") + else: + c[key[0]] = _keys_to_empty_composite_spec(key[1:]) + return c + + +class _CompositeSpecKeysView: + """Wrapper class that enables richer behaviour of `key in tensordict.keys()`.""" + + def __init__( + self, + composite: CompositeSpec, + nested_keys: bool = True, + _yield_nesting_keys: bool = False, + ): + self.composite = composite + self._yield_nesting_keys = _yield_nesting_keys + self.nested_keys = nested_keys + + def __iter__( + self, + ): + for key, item in self.composite.items(): + if self.nested_keys and isinstance(item, CompositeSpec): + for subkey in item.keys(): + yield (key, *subkey) if isinstance(subkey, tuple) else (key, subkey) + if self._yield_nesting_keys: + yield key + else: + yield key + + def __len__(self): + return len([k for k in self]) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 5e6aaee53bc..58e46579ca9 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -133,7 +133,7 @@ def build_tensordict( Args: next_observation (bool, optional): if False, the observation returned - will be of the current step only (no :obj:`"next_"` key will be present). + will be of the current step only (no :obj:`"next"` nested tensordict will be present). Default is True. log_prob (bool, optional): If True, a log_prob key-value pair will be added to the tensordict. @@ -150,15 +150,11 @@ def build_tensordict( raise RuntimeError("observation_spec is expected to be of Composite type.") else: for (key, item) in self["observation_spec"].items(): - if not key.startswith("next_"): - raise RuntimeError( - f"All observation keys must start with the :obj:`'next_'` prefix. 
Found {key}" - ) observation_placeholder = torch.zeros(item.shape, dtype=item.dtype) if next_observation: - td.set(key, observation_placeholder) + td.update({"next": {key: observation_placeholder}}) td.set( - key[5:], + key, observation_placeholder.clone(), ) @@ -323,6 +319,10 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.is_locked = True # make sure _step does not modify the tensordict tensordict_out = self._step(tensordict) tensordict.is_locked = False + obs_keys = set(self.observation_spec.keys()) + tensordict_out_select = tensordict_out.select(*obs_keys) + tensordict_out = tensordict_out.exclude(*obs_keys) + tensordict_out["next"] = tensordict_out_select if tensordict_out is tensordict: raise RuntimeError( @@ -368,7 +368,6 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: def reset( self, tensordict: Optional[TensorDictBase] = None, - execute_step: bool = True, **kwargs, ) -> TensorDictBase: """Resets the environment. @@ -378,8 +377,6 @@ def reset( Args: tensordict (TensorDictBase, optional): tensordict to be used to contain the resulting new observation. In some cases, this input can also be used to pass argument to the reset function. - execute_step (bool, optional): if True, a :obj:`step_mdp` is executed on the output TensorDict, - hereby removing the :obj:`"next_"` prefixes from the keys. kwargs (optional): other arguments to be passed to the native reset function. @@ -409,13 +406,6 @@ def reset( raise RuntimeError( f"Env {self} was done after reset. This is (currently) not allowed." ) - if execute_step: - tensordict_reset = step_mdp( - tensordict_reset, - exclude_done=False, - exclude_reward=False, # some policies may need reward and action at reset time - exclude_action=False, - ) if tensordict is not None: tensordict.update(tensordict_reset) else: @@ -659,13 +649,12 @@ def fake_tensordict(self) -> TensorDictBase: fake_input = input_spec.zero(self.batch_size) observation_spec = self.observation_spec fake_obs = observation_spec.zero(self.batch_size) - fake_obs_step = step_mdp(fake_obs) reward_spec = self.reward_spec fake_reward = reward_spec.zero(self.batch_size) fake_td = TensorDict( { - **fake_obs_step, **fake_obs, + "next": fake_obs.clone(), **fake_input, "reward": fake_reward, "done": fake_reward.to(torch.bool), diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 768a9f5dd54..8d1a1de5120 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -45,29 +45,29 @@ class EnvCreator: >>> env_creator = EnvCreator(env_fn) >>> >>> def test_env1(env_creator): - >>> env = env_creator() - >>> tensordict = env.reset() - >>> for _ in range(10): - >>> env.rand_step(tensordict) - >>> if env.is_done: - >>> tensordict = env.reset(tensordict) - >>> print("env 1: ", env.transform._td.get("next_observation_count")) + ... env = env_creator() + ... tensordict = env.reset() + ... for _ in range(10): + ... env.rand_step(tensordict) + ... if env.is_done: + ... tensordict = env.reset(tensordict) + ... print("env 1: ", env.transform._td.get(("next", "observation_count"))) >>> >>> def test_env2(env_creator): - >>> env = env_creator() - >>> time.sleep(5) - >>> print("env 2: ", env.transform._td.get("next_observation_count")) + ... env = env_creator() + ... time.sleep(5) + ... 
print("env 2: ", env.transform._td.get(("next", "observation_count"))) >>> >>> if __name__ == "__main__": - >>> ps = [] - >>> p1 = mp.Process(target=test_env1, args=(env_creator,)) - >>> p1.start() - >>> ps.append(p1) - >>> p2 = mp.Process(target=test_env2, args=(env_creator,)) - >>> p2.start() - >>> ps.append(p1) - >>> for p in ps: - >>> p.join() + ... ps = [] + ... p1 = mp.Process(target=test_env1, args=(env_creator,)) + ... p1.start() + ... ps.append(p1) + ... p2 = mp.Process(target=test_env2, args=(env_creator,)) + ... p2.start() + ... ps.append(p1) + ... for p in ps: + ... p.join() env 1: tensor([11.9934]) env 2: tensor([11.9934]) """ diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 66bb8dd9f84..a0c61b484e1 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -107,9 +107,9 @@ class GymLikeEnv(_EnvWrapper): In this implementation, the info output is discarded (but specific keys can be read by updating info_dict_reader, see :obj:`set_info_dict_reader` class method). - By default, the first output is written at the "next_observation" key-value pair in the output tensordict, unless + By default, the first output is written at the "observation" key-value pair in the output tensordict, unless the first output is a dictionary. In that case, each observation output will be put at the corresponding - "next_observation_{key}" location. + :obj:`f"{key}"` location for each :obj:`f"{key}"` of the dictionary. It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. """ @@ -165,7 +165,7 @@ def read_obs( """ if isinstance(observations, dict): - observations = {"next_" + key: value for key, value in observations.items()} + observations = {key: value for key, value in observations.items()} if not isinstance(observations, (TensorDict, dict)): key = list(self.observation_spec.keys())[0] observations = {key: observations} @@ -219,6 +219,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, device=self.device ) + tensordict_out.set("reward", reward) tensordict_out.set("done", done) if self.info_dict_reader is not None and info is not None: diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 0f5d1864d78..6fab27309ea 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -48,7 +48,7 @@ def _dmcontrol_to_torchrl_spec_transform( ) -> TensorSpec: if isinstance(spec, collections.OrderedDict): spec = { - "next_" + k: _dmcontrol_to_torchrl_spec_transform(item, device=device) + k: _dmcontrol_to_torchrl_spec_transform(item, device=device) for k, item in spec.items() } return CompositeSpec(**spec) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 7a3d95357e5..494de15e92a 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -84,7 +84,7 @@ def _gym_to_torchrl_spec_transform( elif isinstance(spec, (Dict,)): spec_out = {} for k in spec.keys(): - spec_out["next_" + k] = _gym_to_torchrl_spec_transform( + spec_out[k] = _gym_to_torchrl_spec_transform( spec[k], device=device, categorical_action_encoding=categorical_action_encoding, @@ -244,11 +244,9 @@ def _make_specs(self, env: "gym.Env") -> None: ) if not isinstance(self.observation_spec, CompositeSpec): if self.from_pixels: - self.observation_spec = CompositeSpec(next_pixels=self.observation_spec) + self.observation_spec = CompositeSpec(pixels=self.observation_spec) else: - 
self.observation_spec = CompositeSpec( - next_observation=self.observation_spec - ) + self.observation_spec = CompositeSpec(observation=self.observation_spec) self.reward_spec = UnboundedContinuousTensorSpec( device=self.device, ) diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index a0f4a4fe7b8..189ecb4bdcf 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -33,7 +33,7 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) ... self.observation_spec = CompositeSpec( - ... next_hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)) ... ) ... self.input_spec = CompositeSpec( ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)), @@ -56,7 +56,7 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): ... TensorDictModule( ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), ... in_keys=["hidden_observation", "action"], - ... out_keys=["next_hidden_observation"], + ... out_keys=["hidden_observation"], ... ), ... TensorDictModule( ... nn.Linear(4, 1), @@ -72,7 +72,12 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): action: Tensor(torch.Size([10, 1]), dtype=torch.float32), done: Tensor(torch.Size([10, 1]), dtype=torch.bool), hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32), - next_hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32), + next: LazyStackedTensorDict( + fields={ + hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False), reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)}, batch_size=torch.Size([10]), device=cpu, diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 1388839f603..9e968eee6f8 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -49,8 +49,8 @@ def set_specs_from_env(self, env: EnvBase): # ), # ) self.input_spec = CompositeSpec( - state=self.observation_spec["next_state"], - belief=self.observation_spec["next_belief"], + state=self.observation_spec["state"], + belief=self.observation_spec["belief"], action=self.action_spec.to(self.device), ) @@ -60,7 +60,6 @@ def _reset(self, tensordict=None, **kwargs) -> TensorDict: td = self.input_spec.rand(shape=batch_size).to(device) td["reward"] = self.reward_spec.rand(shape=batch_size).to(device) td.update(self.observation_spec.rand(shape=batch_size).to(device)) - td = self.step(td) return td def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDict: diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 33710ec28b0..2d7a7ef8df8 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -90,6 +90,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec keys = [key for key in observation_spec._specs.keys() if key in self.in_keys] device = observation_spec[keys[0]].device + dim = observation_spec[keys[0]].shape[:-3] observation_spec = CompositeSpec(**observation_spec) if self.del_keys: @@ -98,7 +99,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec for out_key in self.out_keys: observation_spec[out_key] = NdUnboundedContinuousTensorSpec( 
- shape=torch.Size([self.outdim]), device=device + shape=torch.Size([*dim, self.outdim]), device=device ) return observation_spec @@ -153,7 +154,7 @@ class R3MTransform(Compose): can ensure that the following code snippet works as expected: Examples: - >>> transform = R3MTransform("resenet50", in_keys=["next_pixels"]) + >>> transform = R3MTransform("resenet50", in_keys=["pixels"]) >>> env.append_transform(transform) >>> # the forward method will first call _init which will look at env.observation_spec >>> env.reset() @@ -161,9 +162,9 @@ class R3MTransform(Compose): Args: model_name (str): one of resnet50, resnet34 or resnet18 in_keys (list of str, optional): list of input keys. If left empty, the - "next_pixels" key is assumed. + "pixels" key is assumed. out_keys (list of str, optional): list of output keys. If left empty, - "next_r3m_vec" is assumed. + "r3m_vec" is assumed. size (int, optional): Size of the image to feed to resnet. Defaults to 244. download (bool, optional): if True, the weights will be downloaded using @@ -249,9 +250,9 @@ def _init(self): # R3M if out_keys is None: if stack_images: - out_keys = ["next_r3m_vec"] + out_keys = ["r3m_vec"] else: - out_keys = [f"next_r3m_vec_{i}" for i in range(len(in_keys))] + out_keys = [f"r3m_vec_{i}" for i in range(len(in_keys))] elif stack_images and len(out_keys) != 1: raise ValueError( f"out_key must be of length 1 if stack_images is True. Got out_keys={out_keys}" diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b1d0f66cd78..b3c92d15932 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -9,23 +9,11 @@ import multiprocessing as mp from copy import deepcopy, copy from textwrap import indent -from typing import Any, List, Optional, OrderedDict, Sequence, Union, Tuple -from warnings import warn +from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union import torch -from torch import nn, Tensor - -try: - from torchvision.transforms.functional import center_crop - from torchvision.transforms.functional_tensor import ( - resize, - ) # as of now resize is imported from torchvision - - _has_tv = True -except ImportError: - _has_tv = False - from tensordict.tensordict import TensorDictBase, TensorDict +from torch import nn, Tensor from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -39,11 +27,21 @@ ) from torchrl.envs.common import EnvBase, make_tensordict from torchrl.envs.transforms import functional as F -from torchrl.envs.transforms.utils import FiniteTensor +from torchrl.envs.transforms.utils import check_finite from torchrl.envs.utils import step_mdp -IMAGE_KEYS = ["next_pixels"] +try: + from torchvision.transforms.functional import center_crop + from torchvision.transforms.functional_tensor import ( + resize, + ) # as of now resize is imported from torchvision + + _has_tv = True +except ImportError: + _has_tv = False + +IMAGE_KEYS = ["pixels"] _MAX_NOOPS_TRIALS = 10 @@ -54,7 +52,7 @@ def new_fun(self, observation_spec): for in_key, out_key in zip(self.in_keys, self.out_keys): if in_key in observation_spec.keys(): d[out_key] = function(self, observation_spec[in_key]) - return CompositeSpec(**d) + return CompositeSpec(d) else: return function(self, observation_spec) @@ -138,14 +136,22 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform.""" self._check_inplace() for in_key, out_key in zip(self.in_keys, self.out_keys): - if 
in_key in tensordict.keys(): + if in_key in tensordict.keys(include_nested=True): observation = self._apply_transform(tensordict.get(in_key)) tensordict.set(out_key, observation, inplace=self.inplace) return tensordict def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - self._call(tensordict) + tensordict = self._call(tensordict) return tensordict + # raise NotImplementedError("""`Transform.forward` is currently not implemented (reserved for usage beyond envs). Use `Transform._step` instead.""") + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + # placeholder when we'll move to tensordict['next'] + # tensordict["next"] = self._call(tensordict.get("next")) + out = self._call(tensordict) + # print(out, tensordict, out is tensordict, (out==tensordict).all()) + return out def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: if self.invertible: @@ -156,7 +162,7 @@ def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: self._check_inplace() for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - if in_key in tensordict.keys(): + if in_key in tensordict.keys(include_nested=True): observation = self._inv_apply_transform(tensordict.get(in_key)) tensordict.set(out_key, observation, inplace=self.inplace) return tensordict @@ -423,11 +429,17 @@ def reward_spec(self) -> TensorSpec: def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # selected_keys = [key for key in tensordict.keys() if "action" in key] # tensordict_in = tensordict.select(*selected_keys).clone() - tensordict_in = self.transform.inv(tensordict.clone(recurse=False)) - tensordict_out = self.base_env.step(tensordict_in) + tensordict = tensordict.clone() + tensordict_in = self.transform.inv(tensordict) + tensordict_out = self.base_env._step(tensordict_in) # tensordict should already have been processed by the transforms # for logging purposes - tensordict_out = self.transform(tensordict_out) + tensordict_out = tensordict_out.update( + tensordict.exclude(*tensordict_out.keys()) + ) + next_tensordict = self.transform._step(tensordict_out) + tensordict_out.update(next_tensordict, inplace=False) + return tensordict_out def set_seed(self, seed: int, static_seed: bool = False) -> int: @@ -437,9 +449,7 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int: def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): if tensordict is not None: tensordict = tensordict.clone(recurse=False) - out_tensordict = self.base_env.reset( - tensordict=tensordict, execute_step=False, **kwargs - ) + out_tensordict = self.base_env.reset(tensordict=tensordict, **kwargs) out_tensordict = self.transform.reset(out_tensordict) out_tensordict = self.transform(out_tensordict) return out_tensordict @@ -470,16 +480,16 @@ def is_closed(self) -> bool: def is_closed(self, value: bool): self.base_env.is_closed = value - def is_done_get_fn(self) -> bool: + @property + def is_done(self) -> bool: if self._is_done is None: return self.base_env.is_done return self._is_done.all() - def is_done_set_fn(self, val: torch.Tensor) -> None: + @is_done.setter + def is_done(self, val: torch.Tensor) -> None: self._is_done = val - is_done = property(is_done_get_fn, is_done_set_fn) - def close(self): self.base_env.close() self.is_closed = True @@ -593,9 +603,9 @@ def __init__( ): if in_keys is None: in_keys = [ - "next_observation", - "next_pixels", - "next_observation_state", + "observation", + "pixels", + 
"observation_state", ] super(ObservationTransform, self).__init__(in_keys=in_keys, out_keys=out_keys) @@ -619,11 +629,16 @@ def __init__(self, *transforms: Transform): for t in self.transforms: t.set_parent(self) - def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: for t in self.transforms: tensordict = t(tensordict) return tensordict + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + for t in self.transforms: + tensordict = t._step(tensordict) + return tensordict + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: for t in self.transforms[::-1]: tensordict = t.inv(tensordict) @@ -728,13 +743,13 @@ class ToTensorImage(ObservationTransform): observations. Examples: - >>> transform = ToTensorImage(in_keys=["next_pixels"]) + >>> transform = ToTensorImage(in_keys=["pixels"]) >>> ri = torch.randint(0, 255, (1,1,10,11,3), dtype=torch.uint8) >>> td = TensorDict( - ... {"next_pixels": ri}, + ... {"pixels": ri}, ... [1, 1]) >>> _ = transform(td) - >>> obs = td.get("next_pixels") + >>> obs = td.get("pixels") >>> print(obs.shape, obs.dtype) torch.Size([1, 1, 3, 10, 11]) torch.float32 """ @@ -1112,6 +1127,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: self._unsqueeze_dim = self._unsqueeze_dim_orig + tensordict.ndimension() return super().forward(tensordict) + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + if self._unsqueeze_dim_orig >= 0: + self._unsqueeze_dim = self._unsqueeze_dim_orig + tensordict.ndimension() + return super()._step(tensordict) + def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation = observation.unsqueeze(self.unsqueeze_dim) return observation @@ -1196,6 +1216,11 @@ def squeeze_dim(self): def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return super().inv(tensordict) + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + # placeholder for when we'll move to 'next' indexing for steps + # return super().inv(tensordict["next"]) + return super().inv(tensordict) + def inv(self, tensordict: TensorDictBase) -> TensorDictBase: return super().forward(tensordict) @@ -1249,14 +1274,14 @@ class ObservationNorm(ObservationTransform): Examples: >>> torch.set_default_tensor_type(torch.DoubleTensor) >>> r = torch.randn(100, 3)*torch.randn(3) + torch.randn(3) - >>> td = TensorDict({'next_obs': r}, [100]) + >>> td = TensorDict({'obs': r}, [100]) >>> transform = ObservationNorm( - ... loc = td.get('next_obs').mean(0), - ... scale = td.get('next_obs').std(0), - ... in_keys=["next_obs"], + ... loc = td.get('obs').mean(0), + ... scale = td.get('obs').std(0), + ... in_keys=["obs"], ... standard_normal=True) >>> _ = transform(td) - >>> print(torch.isclose(td.get('next_obs').mean(0), + >>> print(torch.isclose(td.get('obs').mean(0), ... 
torch.zeros(3)).all()) tensor(True) >>> print(torch.isclose(td.get('next_obs').std(0), @@ -1288,9 +1313,9 @@ def __init__( ): if in_keys is None: in_keys = [ - "next_observation", - "next_pixels", - "next_observation_state", + "observation", + "pixels", + "observation_state", ] super().__init__(in_keys=in_keys) self.standard_normal = standard_normal @@ -1301,7 +1326,7 @@ def __init__( if scale is not None and not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) - scale.clamp_min(self.eps) + scale = scale.clamp_min(self.eps) # self.observation_spec_key = observation_spec_key self.register_buffer("loc", loc) @@ -1527,18 +1552,8 @@ def __init__(self): super().__init__(in_keys=[]) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - source = {} - for key, item in tensordict.items(): - try: - source[key] = FiniteTensor(item) - except RuntimeError as err: - if str(err).rfind("FiniteTensor encountered") > -1: - raise ValueError(f"Found non-finite elements in {key}") - else: - raise RuntimeError(str(err)) - - finite_tensordict = TensorDict(batch_size=tensordict.batch_size, source=source) - return finite_tensordict + tensordict.apply(check_finite) + return tensordict class DoubleToFloat(Transform): @@ -1546,10 +1561,10 @@ class DoubleToFloat(Transform): Examples: >>> td = TensorDict( - ... {'next_obs': torch.ones(1, dtype=torch.double)}, []) - >>> transform = DoubleToFloat(in_keys=["next_obs"]) + ... {'obs': torch.ones(1, dtype=torch.double)}, []) + >>> transform = DoubleToFloat(in_keys=["obs"]) >>> _ = transform(td) - >>> print(td.get("next_obs").dtype) + >>> print(td.get("obs").dtype) torch.float32 """ @@ -1605,7 +1620,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec def __repr__(self) -> str: s = ( - f"{self.__class__.__name__}(in_keys={self.in_keys}, out_keys={self.out_keys}," + f"{self.__class__.__name__}(in_keys={self.in_keys}, out_keys={self.out_keys}, " f"in_keys_inv={self.in_keys_inv}, out_keys_inv={self.out_keys_inv})" ) return s @@ -1654,7 +1669,7 @@ class CatTensors(Transform): def __init__( self, in_keys: Optional[Sequence[str]] = None, - out_key: str = "next_observation_vector", + out_key: str = "observation_vector", dim: int = -1, del_keys: bool = True, unsqueeze_if_oor: bool = False, @@ -1667,7 +1682,6 @@ def __init__( ) else: in_keys = sorted(list(in_keys)) - self._check_in_keys(in_keys, out_key) if type(out_key) != str: raise Exception("CatTensors requires out_key to be of type string") # super().__init__(in_keys=in_keys) @@ -1676,15 +1690,6 @@ def __init__( self.del_keys = del_keys self.unsqueeze_if_oor = unsqueeze_if_oor - def _check_in_keys(self, in_keys, out_key): - if not out_key.startswith("next_") and all( - key.startswith("next_") for key in in_keys - ): - warn( - f"It seems that 'next_'-like keys are being concatenated to a non 'next_' key {out_key}. This may result in unwanted behaviours, and the 'next_' flag is missing from the output key." 
- f"Consider renaming the out_key to 'next_{out_key}'" - ) - def _find_in_keys(self): parent = self.parent obs_spec = parent.observation_spec @@ -1692,7 +1697,6 @@ def _find_in_keys(self): for key, value in obs_spec.items(): if len(value.shape) == 1: in_keys.append(key) - self._check_in_keys(in_keys, self.out_keys[0]) return sorted(in_keys) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -1700,7 +1704,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: self.in_keys = self._find_in_keys() self._initialized = True - if all([key in tensordict.keys() for key in self.in_keys]): + if all([key in tensordict.keys(include_nested=True) for key in self.in_keys]): values = [tensordict.get(key) for key in self.in_keys] if self.unsqueeze_if_oor: pos_idx = self.dim > 0 @@ -1722,7 +1726,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: raise Exception( f"CatTensor failed, as it expected input keys =" f" {sorted(list(self.in_keys))} but got a TensorDict with keys" - f" {sorted(list(tensordict.keys()))}" + f" {sorted(list(tensordict.keys(include_nested=True)))}" ) return tensordict @@ -1873,8 +1877,7 @@ def base_env(self): def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Do no-op action for a number of steps in [1, noop_max].""" parent = self.parent - keys = tensordict.keys() - keys = [key for key in keys if not key.startswith("next_")] + # keys = tensordict.keys() noops = ( self.noops if not self.random else torch.randint(self.noops, (1,)).item() ) @@ -1883,7 +1886,8 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: while i < noops: i += 1 - tensordict = parent.rand_step(step_mdp(tensordict)) + tensordict = parent.rand_step(tensordict) + tensordict = step_mdp(tensordict) if parent.is_done: parent.reset() i = 0 @@ -1894,19 +1898,19 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: break if parent.is_done: raise RuntimeError("NoopResetEnv concluded with done environment") - td = step_mdp( - tensordict, exclude_done=False, exclude_reward=True, exclude_action=True - ) + # td = step_mdp( + # tensordict, exclude_done=False, exclude_reward=True, exclude_action=True + # ) - for k in keys: - if k not in td.keys(): - td.set(k, tensordict.get(k)) + # for k in keys: + # if k not in td.keys(): + # td.set(k, tensordict.get(k)) - # replace the next_ prefix - for out_key in parent.observation_spec: - td.rename_key(out_key[5:], out_key) + # # replace the next_ prefix + # for out_key in parent.observation_spec: + # td.rename_key(out_key[5:], out_key) - return td + return tensordict def __repr__(self) -> str: random = self.random @@ -1984,13 +1988,14 @@ def transform_observation_spec( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." ) for key, spec in self.primers.items(): - if key in observation_spec: - raise RuntimeError( - f"The key {key} is already in the observation_spec. This means " - f"that the value reset by TensorDictPrimer will confict with the " - f"value obtained through the call to `env.reset()`. Consider renaming " - f"the {key} key." - ) + # deprecating this with the new "next_" logic where we expect keys to collide + # if key in observation_spec: + # raise RuntimeError( + # f"The key {key} is already in the observation_spec. This means " + # f"that the value reset by TensorDictPrimer will confict with the " + # f"value obtained through the call to `env.reset()`. Consider renaming " + # f"the {key} key." 
+ # ) observation_spec[key] = spec.to(self.device) return observation_spec @@ -2101,7 +2106,7 @@ class VecNorm(Transform): Args: in_keys (iterable of str, optional): keys to be updated. - default: ["next_observation", "reward"] + default: ["observation", "reward"] shared_td (TensorDictBase, optional): A shared tensordict containing the keys of the transform. decay (number, optional): decay rate of the moving average. @@ -2121,9 +2126,9 @@ class VecNorm(Transform): ... _ = env.reset() ... tds += [td] >>> tds = torch.stack(tds, 0) - >>> print((abs(tds.get("next_observation").mean(0))<0.2).all()) + >>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all()) tensor(True) - >>> print((abs(tds.get("next_observation").std(0)-1)<0.2).all()) + >>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all()) tensor(True) """ @@ -2141,7 +2146,7 @@ def __init__( if lock is None: lock = mp.Lock() if in_keys is None: - in_keys = ["next_observation", "reward"] + in_keys = ["observation", "reward"] super().__init__(in_keys) self._td = shared_td if shared_td is not None and not ( @@ -2171,7 +2176,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: self.lock.acquire() for key in self.in_keys: - if key not in tensordict.keys(): + if key not in tensordict.keys(include_nested=True): continue self._init(tensordict, key) # update and standardize @@ -2258,7 +2263,7 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]: @staticmethod def build_td_for_shared_vecnorm( env: EnvBase, - keys_prefix: Optional[Sequence[str]] = None, + keys: Optional[Sequence[str]] = None, memmap: bool = False, ) -> TensorDictBase: """Creates a shared tensordict for normalization across processes. @@ -2266,8 +2271,8 @@ def build_td_for_shared_vecnorm( Args: env (EnvBase): example environment to be used to create the tensordict - keys_prefix (iterable of str, optional): prefix of the keys that - have to be normalized. Default is `["next_", "reward"]` + keys (iterable of str, optional): keys that + have to be normalized. Default is `["next", "reward"]` memmap (bool): if True, the resulting tensordict will be cast into memmory map (using `memmap_()`). Otherwise, the tensordict will be placed in shared memory. @@ -2280,7 +2285,7 @@ def build_td_for_shared_vecnorm( >>> queue = mp.Queue() >>> env = make_env() >>> td_shared = VecNorm.build_td_for_shared_vecnorm(env, - ... ["next_observation", "reward"]) + ... ["next", "reward"]) >>> assert td_shared.is_shared() >>> queue.put(td_shared) >>> # on workers @@ -2288,20 +2293,20 @@ def build_td_for_shared_vecnorm( >>> env = TransformedEnv(make_env(), v) """ - if keys_prefix is None: - keys_prefix = ["next_", "reward"] + raise NotImplementedError("this feature is currently put on hold.") + sep = ".-|-." + if keys is None: + keys = ["next", "reward"] td = make_tensordict(env) - keys = set( - key - for key in td.keys() - if any(key.startswith(_prefix) for _prefix in keys_prefix) - ) + keys = set(key for key in td.keys() if key in keys) td_select = td.select(*keys) + td_select = td_select.flatten_keys(sep) if td.batch_dims: raise RuntimeError( f"VecNorm should be used with non-batched environments. 
" f"Got batch_size={td.batch_size}" ) + keys = list(td_select.keys()) for key in keys: td_select.set(key + "_ssq", td_select.get(key).clone()) td_select.set( @@ -2314,7 +2319,8 @@ def build_td_for_shared_vecnorm( ), ) td_select.rename_key(key, key + "_sum") - td_select.zero_() + td_select.exclude(*keys).zero_() + td_select = td_select.unflatten_keys(sep) if memmap: return td_select.memmap_() return td_select.share_memory_() diff --git a/torchrl/envs/transforms/utils.py b/torchrl/envs/transforms/utils.py index e0ff27ec9b3..c6d573c244d 100644 --- a/torchrl/envs/transforms/utils.py +++ b/torchrl/envs/transforms/utils.py @@ -3,41 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Optional, Tuple import torch -from torch.utils._pytree import tree_map -class FiniteTensor(torch.Tensor): - """A finite tensor. - - If the data contained in this tensor contain non-finite values (nan or inf) - a :obj:`RuntimeError` will be thrown. - - """ - - @staticmethod - def __new__(cls, elem: torch.Tensor, *args, **kwargs): - if not torch.isfinite(elem).all(): - raise RuntimeError("FiniteTensor encountered a non-finite tensor.") - return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) - - def __repr__(self) -> str: - return f"FiniteTensor({super().__repr__()})" - - @classmethod - def __torch_dispatch__( - cls, - func: Callable, - types, - args: Tuple = (), - kwargs: Optional[dict] = None, - ): - # TODO: also explicitly recheck invariants on inplace/out mutation - if kwargs: - raise Exception("Expected empty kwargs") - rs = func(*args) - return tree_map( - lambda e: FiniteTensor(e) if isinstance(e, torch.Tensor) else e, rs - ) +def check_finite(tensor: torch.Tensor): + """Raise an error if a tensor has non-finite elements.""" + if not tensor.isfinite().all(): + raise ValueError("Encountered a non-finite tensor.") diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 8212775d102..8167593f6ef 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -82,6 +82,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec keys = [key for key in observation_spec._specs.keys() if key in self.in_keys] device = observation_spec[keys[0]].device + dim = observation_spec[keys[0]].shape[:-3] observation_spec = CompositeSpec(**observation_spec) if self.del_keys: @@ -90,7 +91,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec for out_key in self.out_keys: observation_spec[out_key] = NdUnboundedContinuousTensorSpec( - shape=torch.Size([self.outdim]), device=device + shape=torch.Size([*dim, self.outdim]), device=device ) return observation_spec @@ -136,9 +137,9 @@ class VIPTransform(Compose): Args: model_name (str): one of resnet50 in_keys (list of str, optional): list of input keys. If left empty, the - "next_pixels" key is assumed. + "pixels" key is assumed. out_keys (list of str, optional): list of output keys. If left empty, - "next_vip_vec" is assumed. + "vip_vec" is assumed. size (int, optional): Size of the image to feed to resnet. Defaults to 244. 
download (bool, optional): if True, the weights will be downloaded using @@ -224,9 +225,9 @@ def _init(self): # VIP if out_keys is None: if stack_images: - out_keys = ["next_vip_vec"] + out_keys = ["vip_vec"] else: - out_keys = [f"next_vip_vec_{i}" for i in range(len(in_keys))] + out_keys = [f"vip_vec_{i}" for i in range(len(in_keys))] elif stack_images and len(out_keys) != 1: raise ValueError( f"out_key must be of length 1 if stack_images is True. Got out_keys={out_keys}" @@ -336,13 +337,13 @@ def _embed_goal(self, tensordict): ) return tensordict - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if "goal_embedding" not in tensordict.keys(): tensordict = self._embed_goal(tensordict) - tensordict = super().forward(tensordict) - cur_embedding = tensordict.get(self.out_keys[0]) - last_embedding_key = self.out_keys[0].split("next_")[1] + last_embedding_key = self.out_keys[0] last_embedding = tensordict.get(last_embedding_key, None) + tensordict = super()._step(tensordict) + cur_embedding = tensordict.get(self.out_keys[0]) if last_embedding is not None: goal_embedding = tensordict["goal_embedding"] reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - ( @@ -350,3 +351,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) tensordict.set("reward", reward) return tensordict + + def forward(self, tensordict): + tensordict = super().forward(tensordict) + return tensordict diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 225a3e394ab..36a5cd8df36 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -24,12 +24,11 @@ def step_mdp( exclude_reward: bool = True, exclude_done: bool = True, exclude_action: bool = True, + _run_check: bool = True, ) -> TensorDictBase: """Creates a new tensordict that reflects a step in time of the input tensordict. - Given a tensordict retrieved after a step, returns another tensordict with all the :obj:`'next_'` prefixes are removed, - i.e. all the :obj:`'next_some_other_string'` keys will be renamed onto :obj:`'some_other_string'` keys. - + Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict. Args: tensordict (TensorDictBase): tensordict with keys to be renamed @@ -47,7 +46,7 @@ def step_mdp( Default is True. Returns: - A new tensordict (or next_tensordict) with the "next_*" keys renamed without the "next_" prefix. + A new tensordict (or next_tensordict) containing the tensors of the t+1 step. 
Examples: This funtion allows for this kind of loop to be used: @@ -79,19 +78,13 @@ def step_mdp( prohibited.add("action") else: other_keys.append("action") - keys = [key for key in tensordict.keys() if key.startswith("next_")] - if len(keys) == 0: - raise RuntimeError( - "There was no key starting with 'next_' in the provided TensorDict: ", - tensordict, - ) - new_keys = [key[5:] for key in keys] - prohibited = prohibited.union(keys).union(new_keys) + + prohibited.add("next") if keep_other: other_keys = [key for key in tensordict.keys() if key not in prohibited] - select_tensordict = tensordict.select(*other_keys, *keys) - for new_key, key in zip(new_keys, keys): - select_tensordict.rename_key(key, new_key, safe=True) + select_tensordict = tensordict.select(*other_keys) + select_tensordict = select_tensordict.update(tensordict.get("next")) + if next_tensordict is not None: return next_tensordict.update(select_tensordict) else: diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index dfec0d28239..b8cdece106b 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -99,7 +99,7 @@ class _BatchedEnv(EnvBase): selected_keys (list of str, optional): keys that have to be returned by the environment. When creating a batch of environment, it might be the case that only some of the keys are to be returned. For instance, if the environment returns 'next_pixels' and 'next_vector', the user might only - be interested in, say, 'next_vector'. By indicating which keys must be returned in the tensordict, + be interested in, say, 'vector'. By indicating which keys must be returned in the tensordict, one can easily control the amount of data occupied in memory (for instance to limit the memory size of a replay buffer) and/or limit the amount of data passed from one process to the other; excluded_keys (list of str, optional): list of keys to be excluded from the returned tensordicts. 
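The hunk above changes `step_mdp` so that, instead of stripping a `next_` prefix, it promotes the nested `"next"` sub-tensordict to the root. A rough functional sketch of that behaviour follows; it is not the library implementation, which also honours the exclude/keep flags and an optional output tensordict:

import torch
from tensordict.tensordict import TensorDict, TensorDictBase

def step_mdp_sketch(tensordict: TensorDictBase) -> TensorDictBase:
    # Drop reward/done/action (the defaults) and the "next" entry itself,
    # then promote the t+1 entries stored under "next" to the root.
    keep = [k for k in tensordict.keys() if k not in ("reward", "done", "action", "next")]
    out = tensordict.select(*keep)
    out.update(tensordict.get("next"))
    return out

td = TensorDict(
    {
        "observation": torch.zeros(4),
        "action": torch.zeros(2),
        "reward": torch.zeros(1),
        "done": torch.zeros(1, dtype=torch.bool),
        "next": TensorDict({"observation": torch.ones(4)}, []),
    },
    [],
)
assert (step_mdp_sketch(td)["observation"] == 1).all()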
@@ -580,7 +580,7 @@ def _step( ) tensordict_out = [] for i in range(self.num_workers): - _tensordict_out = self._envs[i].step(tensordict_in[i]) + _tensordict_out = self._envs[i]._step(tensordict_in[i]) tensordict_out.append(_tensordict_out) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -611,7 +611,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: for i, _env in enumerate(self._envs): if not reset_workers[i]: continue - _td = _env.reset(execute_step=False, **kwargs) + _td = _env._reset(**kwargs) keys = keys.union(_td.keys()) self.shared_tensordicts[i].update_(_td) @@ -980,7 +980,7 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() - _td = env.reset(execute_step=False, **reset_kwargs) + _td = env._reset(**reset_kwargs) if reset_keys is None: reset_keys = set(_td.keys()) if pin_memory: @@ -1005,9 +1005,11 @@ def _run_worker_pipe_shared_mem( raise RuntimeError( f"calling step when env is done, just reset = {just_reset}" ) - _td = env.step(_td) + _td = env._step(_td) if step_keys is None: - step_keys = set(_td.keys()) - set(env_input_keys) + step_keys = set(env.observation_spec.keys()).union( + {"done", "terminated", "reward"} + ) if pin_memory: _td.pin_memory() tensordict.update_(_td.select(*step_keys, strict=False)) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 18d12d7fa78..9b94ed4912b 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -190,11 +190,11 @@ def forward(self, tensordict): update_values = tensordict.exclude(*self.out_keys) for t in range(time_steps): # samples according to p(s_{t+1} | s_t, a_t, b_t) - # ["state", "belief", "action"] -> ["next_prior_mean", "next_prior_std", "_", "next_belief"] + # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] self.rssm_prior(_tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) - # ["next_belief", "next_encoded_latents"] -> ["next_posterior_mean", "next_posterior_std", "next_state"] + # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] self.rssm_posterior(_tensordict) tensordict_out.append(_tensordict) diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 88e19de094c..d11c9ab12fd 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -74,7 +74,7 @@ class CEMPlanner(MPCPlannerBase): ... TensorDictModule( ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), ... in_keys=["hidden_observation", "action"], - ... out_keys=["next_hidden_observation"], + ... out_keys=["hidden_observation"], ... ), ... TensorDictModule( ... 
nn.Linear(4, 1), diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 09625886214..81ca1c33f1b 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -55,13 +55,18 @@ ) -def _check_all_str(list_of_str): - if isinstance(list_of_str, str): +def _check_all_str(list_of_str, first_level=True): + if isinstance(list_of_str, str) and first_level: raise RuntimeError( f"Expected a list of strings but got a string: {list_of_str}" ) - if any(not isinstance(key, str) for key in list_of_str): - raise TypeError(f"Expected a list of strings but got: {list_of_str}") + elif not isinstance(list_of_str, str): + try: + return [_check_all_str(item, False) for item in list_of_str] + except Exception as err: + raise TypeError( + f"Expected a list of strings but got: {list_of_str} that raised the following error: {err}." + ) def _forward_hook_safe_action(module, tensordict_in, tensordict_out): @@ -226,7 +231,7 @@ def __init__( if set(spec.keys()) != set(self.out_keys): raise RuntimeError( - f"spec keys and out_keys do not match, got: {spec.keys()} and {self.out_keys} respectively" + f"spec keys and out_keys do not match, got: {set(spec.keys())} and {set(self.out_keys)} respectively" ) self._spec = spec diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index ab7cedd8d4d..a1f3b96f8f2 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -142,7 +142,7 @@ def __init__( if isinstance(module, TensorDictModule) or hasattr(module, "spec"): spec.update(module.spec) else: - spec.update(CompositeSpec(**{key: None for key in module.out_keys})) + spec.update(CompositeSpec({key: None for key in module.out_keys})) super().__init__( spec=spec, module=nn.ModuleList(list(modules)), diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 62c5fff8b22..21a0bdd0620 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -187,10 +187,7 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: tensordict_save = tensordict obs_keys = self.actor_network.in_keys - next_obs_keys = [key for key in tensordict.keys() if key.startswith("next_")] - tensordict = tensordict.select( - "reward", "done", *next_obs_keys, *obs_keys, "action" - ) + tensordict = tensordict.select("reward", "done", "next", *obs_keys, "action") selected_models_idx = torch.randperm(self.num_qvalue_nets)[ : self.sub_sample_len @@ -228,7 +225,7 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: ) state_value = state_value.min(0)[0] - tensordict.set("next_state_value", state_value) + tensordict.set("next.state_value", state_value) target_value = get_next_state_value( tensordict, gamma=self.gamma, diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 109000e266f..863b34a2c56 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -70,14 +70,14 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: tensordict = self.world_model(tensordict) # compute model loss kl_loss = self.kl_loss( - tensordict.get("next_prior_mean"), - tensordict.get("next_prior_std"), - tensordict.get("next_posterior_mean"), - tensordict.get("next_posterior_std"), + tensordict.get(("next", "prior_mean")), + tensordict.get(("next", "prior_std")), + tensordict.get(("next", "posterior_mean")), + tensordict.get(("next", "posterior_std")), ) 
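For reference, the `kl_loss` call above is built around the closed-form KL divergence between diagonal Gaussians. The helper below is a generic sketch, not the loss module's exact implementation (which may average over dimensions or clamp the result):

import torch

def diag_gaussian_kl(mu_q, std_q, mu_p, std_p):
    """KL(q || p) for diagonal Gaussians, summed over the last dimension."""
    var_q, var_p = std_q.pow(2), std_p.pow(2)
    return (
        torch.log(std_p / std_q) + (var_q + (mu_q - mu_p).pow(2)) / (2 * var_p) - 0.5
    ).sum(-1)

# Identical distributions give zero divergence.
mu, std = torch.zeros(3), torch.ones(3)
print(diag_gaussian_kl(mu, std, mu, std))  # tensor(0.)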
reco_loss = distance_loss( - tensordict.get("next_pixels"), - tensordict.get("next_reco_pixels"), + tensordict.get(("next", "pixels")), + tensordict.get(("next", "reco_pixels")), self.reco_loss, ) if not self.global_average: @@ -171,6 +171,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.reshape(-1) with hold_out_net(self.model_based_env), set_exploration_mode("random"): + tensordict = self.model_based_env.reset(tensordict.clone(recurse=False)) fake_data = self.model_based_env.rollout( max_steps=self.imagination_horizon, policy=self.actor_model, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 06dfe923b1f..2f3c010f7ee 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -146,9 +146,8 @@ def alpha(self): def forward(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = self.actor_network.in_keys - next_obs_keys = [key for key in tensordict.keys() if key.startswith("next_")] tensordict_select = tensordict.select( - "reward", "done", *next_obs_keys, *obs_keys, "action" + "reward", "done", "next", *obs_keys, "action" ) selected_models_idx = torch.randperm(self.num_qvalue_nets)[ : self.sub_sample_len @@ -297,7 +296,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "entropy": -sample_log_prob.mean().detach(), "state_action_value_actor": state_action_value_actor.mean().detach(), "action_log_prob_actor": action_log_prob_actor.mean().detach(), - "next_state_value": next_state_value.mean().detach(), + "next.state_value": next_state_value.mean().detach(), "target_value": target_value.mean().detach(), }, [], diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 92ba11b96ba..279339d9e6a 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -7,15 +7,6 @@ import torch from tensordict.tensordict import TensorDictBase - -# for value, log_policy, reward, entropy in list(zip(values, log_policies, rewards, entropies))[::-1]: -# gae = gae * opt.gamma * opt.tau -# gae = gae + reward + opt.gamma * next_value.detach() - value.detach() -# next_value = value -# actor_loss = actor_loss + log_policy * gae -# R = R * opt.gamma + reward -# critic_loss = critic_loss + (R - value) ** 2 / 2 -# entropy_loss = entropy_loss + entropy from torch import Tensor, nn from torchrl.envs.utils import step_mdp diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 8fa60feedc3..6adfd0ed97f 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -30,7 +30,7 @@ class VideoRecorder(ObservationTransform): should be written. tag (str): the video tag in the logger. in_keys (Sequence[str], optional): keys to be read to produce the video. - Default is :obj:`"next_pixels"`. + Default is :obj:`"pixels"`. skip (int): frame interval in the output video. Default is 2. center_crop (int, optional): value of square center crop. 
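Several of the objective changes above (in deprecated.py and redq.py) write a `"next.state_value"` entry that feeds a bootstrapped one-step target. As a generic reference, assuming standard TD(0) bootstrapping (this is a sketch, not the `get_next_state_value` signature):

import torch

def td0_target(reward, done, next_state_value, gamma=0.99):
    # One-step bootstrapped target: r_t + gamma * (1 - done_t) * V(s_{t+1})
    return reward + gamma * (1.0 - done.float()) * next_state_value

reward = torch.tensor([1.0])
done = torch.tensor([False])
next_value = torch.tensor([0.5])
print(td0_target(reward, done, next_value))  # tensor([1.4950])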
@@ -51,7 +51,7 @@ def __init__( **kwargs, ) -> None: if in_keys is None: - in_keys = ["next_pixels"] + in_keys = ["pixels"] super().__init__(in_keys=in_keys) video_kwargs = {"fps": 6} diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index ded4a4d1152..87bd86ec790 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -124,13 +124,13 @@ def make_env_transforms( if cfg.grayscale: env.append_transform(GrayScale()) env.append_transform(FlattenObservation()) - env.append_transform(CatFrames(N=cfg.catframes, in_keys=["next_pixels"])) + env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"])) if stats is None: obs_stats = {"loc": 0.0, "scale": 1.0} else: obs_stats = stats obs_stats["standard_normal"] = True - env.append_transform(ObservationNorm(**obs_stats, in_keys=["next_pixels"])) + env.append_transform(ObservationNorm(**obs_stats, in_keys=["pixels"])) if norm_rewards: reward_scaling = 1.0 reward_loc = 0.0 @@ -154,12 +154,11 @@ def make_env_transforms( selected_keys = [ key for key in env.observation_spec.keys() - if ("pixels" not in key) - and (key.replace("next_", "") not in env.input_spec.keys()) + if ("pixels" not in key) and (key not in env.input_spec.keys()) ] - # even if there is a single tensor, it'll be renamed in "next_observation_vector" - out_key = "next_observation_vector" + # even if there is a single tensor, it'll be renamed in "observation_vector" + out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) if not vecnorm: @@ -393,7 +392,7 @@ def get_stats_random_rollout( "thus get_stats_random_rollout cannot infer which to compute the stats of." ) - if key == "next_pixels": + if key == "pixels": m = val_stats.mean() s = val_stats.std() else: diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 962bb0e3047..a574e054bd3 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -169,7 +169,7 @@ def make_dqn_actor( "mlp_kwargs_output": {"num_cells": 512, "layer_class": linear_layer_class}, } # automatically infer in key - in_key = list(env_specs["observation_spec"])[0].split("next_")[-1] + in_key = list(env_specs["observation_spec"])[0] out_features = action_spec.shape[0] actor_class = QValueActor @@ -246,8 +246,8 @@ def make_ddpg_actor( >>> import hydra >>> from hydra.core.config_store import ConfigStore >>> import dataclasses - >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), - ... CatTensors(["next_observation"], "next_observation_vector"))) + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["observation"]), + ... CatTensors(["observation"], "observation_vector"))) >>> device = torch.device("cpu") >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in ... (DDPGModelConfig, EnvConfig) @@ -736,8 +736,8 @@ def make_ppo_model( >>> import hydra >>> from hydra.core.config_store import ConfigStore >>> import dataclasses - >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), - ... CatTensors(["next_observation"], "next_observation_vector"))) + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["observation"]), + ... 
CatTensors(["observation"], "observation_vector"))) >>> device = torch.device("cpu") >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in ... (PPOModelConfig, EnvConfig) @@ -921,7 +921,7 @@ def make_ppo_model( mlp_kwargs={"num_cells": [256, 256], "out_features": 256}, ) in_keys_actor += ["hidden0", "hidden1"] - out_keys += ["hidden0", "hidden1", "next_hidden0", "next_hidden1"] + out_keys += ["hidden0", "hidden1", ("next", "hidden0"), ("next", "hidden1")] else: policy_net = MLP( num_cells=[400, 300], @@ -1029,8 +1029,8 @@ def make_sac_model( >>> import hydra >>> from hydra.core.config_store import ConfigStore >>> import dataclasses - >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), - ... CatTensors(["next_observation"], "next_observation_vector"))) + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["observation"]), + ... CatTensors(["observation"], "observation_vector"))) >>> device = torch.device("cpu") >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in ... (SACModelConfig, EnvConfig) @@ -1249,8 +1249,8 @@ def make_redq_model( >>> import hydra >>> from hydra.core.config_store import ConfigStore >>> import dataclasses - >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), - ... CatTensors(["next_observation"], "next_observation_vector"))) + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["observation"]), + ... CatTensors(["observation"], "observation_vector"))) >>> device = torch.device("cpu") >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in ... 
(RedqModelConfig, EnvConfig) @@ -1562,19 +1562,19 @@ def _dreamer_make_world_model( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ - "next_prior_mean", - "next_prior_std", + ("next", "prior_mean"), + ("next", "prior_std"), "_", - "next_belief", + ("next", "belief"), ], ), TensorDictModule( rssm_posterior, - in_keys=["next_belief", "next_encoded_latents"], + in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ - "next_posterior_mean", - "next_posterior_std", - "next_state", + ("next", "posterior_mean"), + ("next", "posterior_std"), + ("next", "state"), ], ), ) @@ -1582,19 +1582,19 @@ def _dreamer_make_world_model( transition_model = TensorDictSequential( TensorDictModule( obs_encoder, - in_keys=["next_pixels"], - out_keys=["next_encoded_latents"], + in_keys=[("next", "pixels")], + out_keys=[("next", "encoded_latents")], ), rssm_rollout, TensorDictModule( obs_decoder, - in_keys=["next_state", "next_belief"], - out_keys=["next_reco_pixels"], + in_keys=[("next", "state"), ("next", "belief")], + out_keys=[("next", "reco_pixels")], ), ) reward_model = TensorDictModule( reward_module, - in_keys=["next_state", "next_belief"], + in_keys=[("next", "state"), ("next", "belief")], out_keys=["reward"], ) world_model = WorldModelWrapper( @@ -1710,7 +1710,7 @@ def _dreamer_make_actor_real( "_", "_", "_", # we don't need the prior state - "next_belief", + ("next", "belief"), ], ), ) @@ -1745,8 +1745,8 @@ def _dreamer_make_mbenv( if use_decoder_in_env: mb_env_obs_decoder = TensorDictModule( obs_decoder, - in_keys=["next_state", "next_belief"], - out_keys=["next_reco_pixels"], + in_keys=[("next", "state"), ("next", "belief")], + out_keys=[("next", "reco_pixels")], ) else: mb_env_obs_decoder = None @@ -1758,14 +1758,14 @@ def _dreamer_make_mbenv( out_keys=[ "_", "_", - "next_state", - "next_belief", + "state", + "belief", ], ), ) reward_model = TensorDictModule( reward_module, - in_keys=["next_state", "next_belief"], + in_keys=["state", "belief"], out_keys=["reward"], ) model_based_env = DreamerEnv( @@ -1781,8 +1781,8 @@ def _dreamer_make_mbenv( model_based_env.set_specs_from_env(proof_environment) model_based_env = TransformedEnv(model_based_env) default_dict = { - "next_state": NdUnboundedContinuousTensorSpec(state_dim), - "next_belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": NdUnboundedContinuousTensorSpec(state_dim), + "belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), # "action": proof_environment.action_spec, } model_based_env.append_transform( From 1479497f76900dfd7ae6fdb61d0eae72f7f2365e Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Thu, 17 Nov 2022 19:23:00 +0100 Subject: [PATCH 18/33] adapted to nested next td --- examples/a2c/a2c.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index 0a12faa114e..3e50a48bbd8 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -103,7 +103,7 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key="next_pixels" if cfg.from_pixels else "next_observation_vector", + key="pixels" if cfg.from_pixels else "observation_vector", ) # make sure proof_env is closed proof_env.close() From 39515df366f6c0edd22681004987ae736e9c4c43 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Nov 2022 11:20:48 +0000 Subject: [PATCH 19/33] [Refactor] Refactor 'next_' into nested tensordicts (#649) * init * [Feature] Nested composite spec (#654) * [Feature] Move `transform.forward` to `transform.step` (#660) 
* transform step function * amend * amend * amend * amend * amend * fixing key names * fixing key names * [Refactor] Transform next remove (#661) * Refactor "next_" into ("next", ) (#673) * amend * amend * bugfix * init * strict=False * strict=False * minor * amend * [BugFix] Use GitHub for flake8 pre-commit hook (#679) * amend * [BugFix] Update to strict select (#675) * init * strict=False * amend * amend * [Feature] Auto-compute stats for ObservationNorm (#669) * Add auto-compute stats feature for ObservationNorm * Fix issue in ObservNorm init function * Quick refactor of ObservationNorm init method * Minor refactoring and adding more tests for ObservationNorm * lint * docstring * docstring Co-authored-by: vmoens * amend * amend * lint * bf * bf * amend Co-authored-by: Romain Julien Co-authored-by: Romain Julien --- torchrl/trainers/helpers/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index a574e054bd3..e2fb2074215 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -447,8 +447,8 @@ def make_a2c_model( >>> import hydra >>> from hydra.core.config_store import ConfigStore >>> import dataclasses - >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), - ... CatTensors(["next_observation"], "next_observation_vector"))) + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["observation"]), + ... CatTensors(["observation"], "observation_vector"))) >>> device = torch.device("cpu") >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in ... (PPOModelConfig, EnvConfig) From 0d28c794627052b40f93bb51b27c8e80fd8129b2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Nov 2022 18:34:05 +0000 Subject: [PATCH 20/33] [Doc] More doc about environments (#683) * amend * amend * amend * amend * amend * amend --- docs/source/reference/envs.rst | 226 ++++++++++++++++++++++++++++----- 1 file changed, 193 insertions(+), 33 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 20704f16bd9..7112d8a59f2 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -3,52 +3,131 @@ torchrl.envs package ==================== +TorchRL offers an API to handle environments of different backends, such as gym, +dm-control, dm-lab, model-based environments as well as custom environments. +The goal is to be able to swap environments in an experiment with little or no effort, +even if these environments are simulated using different libraries. +TorchRL offers some out-of-the-box environment wrappers under :obj:`torchrl.envs.libs`, +which we hope can be easily imitated for other libraries. +The parent class :obj:`EnvBase` is a :obj:`torch.nn.Module` subclass that implements +some typical environment methods using :obj:`TensorDict` as a data organiser. This allows this +class to be generic and to handle an arbitrary number of input and outputs, as well as +nested or batched data structures. + +Each env will have the following attributes: + +- :obj:`env.batch_size`: a :obj:`torch.Size` representing the number of envs batched together. +- :obj:`env.device`: the device where the input and output tensordict are expected to live. 
+ The environment device does not mean that the actual step operations will be computed on device + (this is the responsibility of the backend, with which TorchRL can do little). The device of + an environment just represents the device where the data is to be expected when input to the + environment or retrieved from it. TorchRL takes care of mapping the data to the desired device. + This is especially useful for transforms (see below). For parametric environments (e.g. + model-based environments), the device does represent the hardware that will be used to + compute the operations. +- :obj:`env.observation_spec`: a :obj:`CompositeSpec` object containing all the observation key-spec pairs. +- :obj:`env.input_spec`: a :obj:`CompositeSpec` object containing all the input keys (:obj:`"action"` and others). +- :obj:`env.action_spec`: a :obj:`TensorSpec` object representing the action spec. +- :obj:`env.reward_spec`: a :obj:`TensorSpec` object representing the reward spec. + +Importantly, the environment spec shapes should *not* contain the batch size, e.g. +an environment with :obj:`env.batch_size == torch.Size([4])` should not have +an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])` but simply +:obj:`torch.Size([action_size])`. + +With these, the following methods are implemented: + +- :obj:`env.reset(tensordict)`: a reset method that may (but not necessarily requires to) take + a :obj:`TensorDict` input. It return the first tensordict of a rollout, usually + containing a :obj:`"done"` state and a set of observations. +- :obj:`env.step(tensordict)`: a step method that takes a :obj:`TensorDict` input + containing an input action as well as other inputs (for model-based or stateless + environments, for instance). +- :obj:`env.set_seed(integer)`: a seeding method that will return the next seed + to be used in a multi-env setting. This next seed is deterministically computed + from the preceding one, such that one can seed multiple environments with a different + seed without risking to overlap seeds in consecutive experiments, while still + having reproducible results. +- :obj:`env.rollout(max_steps, policy)`: executes a rollout in the environment for + a maximum number of steps :obj:`max_steps` and using a policy :obj:`policy`. + The policy should be coded using a :obj:`TensorDictModule` (or any other + :obj:`TensorDict`-compatible module). + + .. autosummary:: :toctree: generated/ :template: rl_template.rst EnvBase GymLikeEnv - SerialEnv - ParallelEnv - -Helpers -------- -.. currentmodule:: torchrl.envs.utils - -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - step_mdp - get_available_libraries - set_exploration_mode - exploration_mode - -Domain-specific +Vectorized envs --------------- -.. currentmodule:: torchrl.envs -.. autosummary:: - :toctree: generated/ - :template: rl_template_fun.rst - - ModelBasedEnvBase - model_based.dreamer.DreamerEnv - - -Libraries ---------- -.. currentmodule:: torchrl.envs.libs +Vectorized (or better: parallel) environments is a common feature in Reinforcement Learning +where executing the environment step can be cpu-intensive. +Some libraries such as `gym3 `_ or `EnvPool `_ +offer interfaces to execute batches of environments simultaneously. +While they often offer a very competitive computational advantage, they do not +necessarily scale to the wide variety of environment libraries supported by TorchRL. +Therefore, TorchRL offers its own, generic :obj:`ParallelEnv` class to run multiple +environments in parallel. 
+As this class inherits from :obj:`EnvBase`, it enjoys the exact same API as other environment. +Of course, a :obj:`ParallelEnv` will have a batch size that corresponds to its environment count: + +.. code-block:: + :caption: Parallel environment + + >>> def make_env(): + ... return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0") + >>> env = ParallelEnv(4, make_env) + >>> print(env.batch_size) + torch.Size([4]) + +:obj:`ParallelEnv` allows to retrieve the attributes from its contained environments: +one can simply call: + +.. code-block:: + :caption: Parallel environment attributes + + >>> a, b, c, d = env.g # gets the g-force of the various envs, which we set to 9.81 before + >>> print(a) + 9.81 + +It is also possible to reset some but not all of the environments: + +.. code-block:: + :caption: Parallel environment reset + + >>> tensordict = TensorDict({"reset_workers": [True, False, True, True]}, [4]) + >>> env.reset(tensordict) + TensorDict( + fields={ + done: Tensor(torch.Size([4, 1]), dtype=torch.bool), + pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8), + reset_workers: Tensor(torch.Size([4, 1]), dtype=torch.bool)}, + batch_size=torch.Size([4]), + device=None, + is_shared=True) + + +A note on performance: launching a :obj:`ParallelEnv` can take quite some time +as it requires to launch as many python instances as there are processes. Due to +the time that it takes to run :obj:`import torch` (and other imports), starting the +parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow. +Once the environment is launched, a great speedup should be observed. + +We also offer the :obj:`SerialEnv` class that enjoys the exact same API but is executed +serially. This is mostly useful for testing purposes, when one wants to assess the +behaviour of a :obj:`ParallelEnv` without launching the subprocesses. .. autosummary:: :toctree: generated/ - :template: rl_template_fun.rst + :template: rl_template.rst + + SerialEnv + ParallelEnv - gym.GymEnv - gym.GymWrapper - dm_control.DMControlEnv - dm_control.DMControlWrapper Transforms ---------- @@ -58,6 +137,49 @@ In most cases, the raw output of an environment must be treated before being pas policy or a value operator). To do this, TorchRL provides a set of transforms that aim at reproducing the transform logic of `torch.distributions.Transform` and `torchvision.transforms`. +Transformed environments are build using the :doc:`TransformedEnv` primitive. +Composed transforms are built using the :doc:`Compose` class: + +.. code-block:: + :caption: Transformed environment + + >>> base_env = GymEnv("Pendulum-v1", from_pixels=True, device="cuda:0") + >>> transform = Compose(ToTensorImage(in_keys=["pixels"]), Resize(64, 64, in_keys=["pixels"])) + >>> env = TransformedEnv(base_env, transform) + + +By default, the transformed environment will inherit the device of the +:obj:`base_env` that is passed to it. The transforms will then be executed on that device. +It is now apparent that this can bring a significant speedup depending on the kind of +operations that is to be computed. + +A great advantage of environment wrappers is that one can consult the environment up to that wrapper. +The same can be achieved with TorchRL transformed environments: the :doc:`parent` attribute will +return a new :obj:`TransformedEnv` with all the transforms up to the transform of interest. +Re-using the example above: + +.. 
code-block:: + :caption: Transform parent + + >>> resize_parent = env.transform[-1].parent # returns the same as TransformedEnv(base_env, transform[:-1]) + + +Transformed environment can be used with vectorized environments. +Since each transform uses a :doc:`"in_keys"`/:doc:`"out_keys"` set of keyword argument, it is +also easy to root the transform graph to each component of the observation data (e.g. +pixels or states etc). + +Transforms also have an :doc:`inv` method that is called before +the action is applied in reverse order over the composed transform chain: +this allows to apply transforms to data in the environment before the action is taken +in the environment. The keys to be included in this inverse transform are passed through the +:doc:`"in_keys_inv"` keyword argument: + +.. code-block:: + :caption: Inverse transform + + >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step + .. autosummary:: :toctree: generated/ @@ -88,3 +210,41 @@ logic of `torch.distributions.Transform` and `torchvision.transforms`. TensorDictPrimer R3MTransform VIPTransform + +Helpers +------- +.. currentmodule:: torchrl.envs.utils + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + step_mdp + get_available_libraries + set_exploration_mode + exploration_mode + +Domain-specific +--------------- +.. currentmodule:: torchrl.envs + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + ModelBasedEnvBase + model_based.dreamer.DreamerEnv + + +Libraries +--------- +.. currentmodule:: torchrl.envs.libs + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + gym.GymEnv + gym.GymWrapper + dm_control.DMControlEnv + dm_control.DMControlWrapper From 7c36de67c940151602b4a012c72d8c81317dbb87 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Nov 2022 18:52:50 +0000 Subject: [PATCH 21/33] [Doc] Fix missing tensordict install for doc (#685) --- setup.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c15266610b5..bba67f6c94f 100644 --- a/setup.py +++ b/setup.py @@ -202,7 +202,13 @@ def _main(argv): "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), "clean": clean, }, - install_requires=[pytorch_package_dep, "numpy", "packaging", "cloudpickle"], + install_requires=[ + pytorch_package_dep, + "numpy", + "packaging", + "cloudpickle", + "tensordict-nightly", + ], extras_require={ "atari": [ "gym<=0.24", From 509276f77732aafb661b0ac5a77dbdad82c47333 Mon Sep 17 00:00:00 2001 From: albert bou Date: Fri, 18 Nov 2022 12:56:12 +0100 Subject: [PATCH 22/33] model config fix --- torchrl/trainers/helpers/models.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index e2fb2074215..80d9b33239a 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -1830,26 +1830,7 @@ class PPOModelConfig: @dataclass class A2CModelConfig: - """PPO model config struct.""" - - gSDE: bool = False - # if True, exploration is achieved using the gSDE technique. 
- tanh_loc: bool = False - # if True, uses a Tanh-Normal transform for the policy location of the form - # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions) - default_policy_scale: float = 1.0 - # Default policy scale parameter - distribution: str = "tanh_normal" - # if True, uses a Tanh-Normal-Tanh distribution for the policy - lstm: bool = False - # if True, uses an LSTM for the policy. - shared_mapping: bool = False - # if True, the first layers of the actor-critic are shared. - - -@dataclass -class PPOModelConfig: - """PPO model config struct.""" + """A2C model config struct.""" gSDE: bool = False # if True, exploration is achieved using the gSDE technique. From 0a0fca8a975cf8515217a0ecead5ac0aa2f02b45 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 21 Nov 2022 16:37:03 +0100 Subject: [PATCH 23/33] formatting --- torchrl/objectives/__init__.py | 2 +- torchrl/objectives/a2c.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index d3b99bb23e3..5b502e25dfd 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -3,11 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .a2c import A2CLoss from .common import LossModule from .ddpg import DDPGLoss from .dqn import DQNLoss, DistributionalDQNLoss from .dreamer import DreamerValueLoss, DreamerActorLoss, DreamerModelLoss -from .a2c import A2CLoss from .ppo import PPOLoss, ClipPPOLoss, KLPENPPOLoss from .redq import REDQLoss from .sac import SACLoss diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 58b09efcf65..1ad19fdf4f3 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -8,13 +8,13 @@ from typing import Callable, Optional, Tuple import torch +from tensordict.tensordict import TensorDict, TensorDictBase from torch import distributions as d -from tensordict.tensordict import TensorDictBase, TensorDict from torchrl.modules import TensorDictModule -from torchrl.objectives.utils import distance_loss from torchrl.modules.tensordict_module import ProbabilisticTensorDictModule from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import distance_loss class A2CLoss(LossModule): @@ -135,7 +135,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.clone() advantage = tensordict.get(self.advantage_key) log_probs, dist = self._log_probs(tensordict) - loss = - (log_probs * advantage) + loss = -(log_probs * advantage) td_out = TensorDict({"loss_objective": loss.mean()}, []) if self.entropy_bonus: entropy = self.get_entropy_bonus(dist) @@ -145,4 +145,3 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_critic = self.loss_critic(tensordict).mean() td_out.set("loss_critic", loss_critic.mean()) return td_out - From ac5b8570139083d35e4264a58d2ce86f2d75bd92 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 21 Nov 2022 16:41:31 +0100 Subject: [PATCH 24/33] formatting --- examples/a2c/a2c.py | 13 ++++--------- torchrl/objectives/a2c.py | 2 -- torchrl/trainers/helpers/losses.py | 8 ++++---- torchrl/trainers/helpers/models.py | 26 +++++++++++--------------- 4 files changed, 19 insertions(+), 30 deletions(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index 3e50a48bbd8..f56e567140f 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -12,28 +12,23 @@ import 
hydra import torch.cuda from hydra.core.config_store import ConfigStore -from torchrl.envs import ParallelEnv, EnvCreator -from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.envs.transforms import RewardScaling from torchrl.envs.utils import set_exploration_mode from torchrl.objectives.value import TDEstimate -from torchrl.record import VideoRecorder from torchrl.trainers.helpers.collectors import ( make_collector_onpolicy, OnPolicyCollectorConfig, ) from torchrl.trainers.helpers.envs import ( correct_for_frame_skip, + EnvConfig, get_stats_random_rollout, parallel_env_constructor, transformed_env_constructor, - EnvConfig, ) from torchrl.trainers.helpers.logger import LoggerConfig -from torchrl.trainers.helpers.losses import make_a2c_loss, A2CLossConfig -from torchrl.trainers.helpers.models import ( - make_a2c_model, - A2CModelConfig, -) +from torchrl.trainers.helpers.losses import A2CLossConfig, make_a2c_loss +from torchrl.trainers.helpers.models import A2CModelConfig, make_a2c_model from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig config_fields = [ diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 1ad19fdf4f3..7ba7134fde0 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -3,8 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import math -import warnings from typing import Callable, Optional, Tuple import torch diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index cc63bd69746..92e304a895f 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -4,17 +4,17 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import Optional, Tuple, Any +from typing import Any, Optional, Tuple -from torchrl.modules import ActorValueOperator, ActorCriticOperator +from torchrl.modules import ActorCriticOperator, ActorValueOperator from torchrl.objectives import ( + A2CLoss, ClipPPOLoss, DDPGLoss, DistributionalDQNLoss, DQNLoss, HardUpdate, KLPENPPOLoss, - A2CLoss, PPOLoss, SACLoss, SoftUpdate, @@ -193,7 +193,7 @@ def make_a2c_loss(model, cfg) -> A2CLoss: "critic": critic_model, "loss_critic_type": cfg.critic_loss_function, "entropy_coef": cfg.entropy_coef, - "advantage_module": advantage + "advantage_module": advantage, } loss_module = A2CLoss(**kwargs) diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 80d9b33239a..0f8a93deb80 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -7,24 +7,24 @@ from typing import Optional, Sequence import torch -from torch import nn, distributions as d +from torch import distributions as d, nn from torchrl.data import ( CompositeSpec, - NdUnboundedContinuousTensorSpec, DiscreteTensorSpec, + NdUnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING -from torchrl.envs import TransformedEnv, TensorDictPrimer +from torchrl.envs import TensorDictPrimer, TransformedEnv from torchrl.envs.common import EnvBase from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.utils import set_exploration_mode from torchrl.modules import ( ActorValueOperator, NoisyLinear, - TensorDictModule, - ProbabilisticTensorDictModule, NormalParamWrapper, + ProbabilisticTensorDictModule, + TensorDictModule, TensorDictSequential, ) from torchrl.modules.distributions import ( @@ -34,16 +34,14 @@ TanhNormal, TruncatedNormal, ) -from torchrl.modules.distributions.continuous import ( - SafeTanhTransform, -) +from torchrl.modules.distributions.continuous import SafeTanhTransform from torchrl.modules.models.exploration import LazygSDEModule from torchrl.modules.models.model_based import ( DreamerActor, - ObsEncoder, ObsDecoder, - RSSMPrior, + ObsEncoder, RSSMPosterior, + RSSMPrior, RSSMRollout, ) from torchrl.modules.models.models import ( @@ -53,9 +51,9 @@ DdpgMlpActor, DdpgMlpQNet, DuelingCnnDQNet, + DuelingMlpDQNet, LSTMNet, MLP, - DuelingMlpDQNet, ) from torchrl.modules.tensordict_module import ( Actor, @@ -64,12 +62,10 @@ ) from torchrl.modules.tensordict_module.actors import ( ActorCriticWrapper, - ValueOperator, ProbabilisticActor, + ValueOperator, ) -from torchrl.modules.tensordict_module.world_models import ( - WorldModelWrapper, -) +from torchrl.modules.tensordict_module.world_models import WorldModelWrapper from torchrl.trainers.helpers import transformed_env_constructor DISTRIBUTIONS = { From 34b411f3650ce44dcacfdc9ca38b320739342a14 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 21 Nov 2022 16:52:50 +0100 Subject: [PATCH 25/33] a2c runtime error comment change --- torchrl/objectives/a2c.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 7ba7134fde0..5f1ffe01618 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -89,7 +89,7 @@ def _log_probs( # current log_prob of actions action = tensordict.get("action") if action.requires_grad: - raise RuntimeError("tensordict stored action requires grad.") + raise RuntimeError("tensordict stored action require grad.") tensordict_clone = tensordict.select(*self.actor.in_keys).clone() dist, *_ = 
self.actor.get_dist( @@ -104,7 +104,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: advantage_diff = tensordict.get(self.advantage_diff_key) if not advantage_diff.requires_grad: raise RuntimeError( - "value_target retrieved from tensordict does not requires grad." + "value_target retrieved from tensordict does not require grad." ) loss_value = distance_loss( advantage_diff, From 5068059ad3f1a5cdb535bf14be22b025c1038e4c Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 21 Nov 2022 17:19:39 +0100 Subject: [PATCH 26/33] a2c test --- test/test_cost.py | 241 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 237 insertions(+), 4 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index af513a2d55e..5a771511a89 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -65,6 +65,7 @@ ValueOperator, ) from torchrl.objectives import ( + A2CLoss, ClipPPOLoss, DDPGLoss, DistributionalDQNLoss, @@ -1617,13 +1618,12 @@ def _create_seq_mock_data_ppo( ) return td - @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) @pytest.mark.parametrize("device", get_available_devices()) - def test_ppo(self, loss_class, device, gradient_mode, advantage): + def test_ppo(self, device, gradient_mode, advantage): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_a2c(device=device) actor = self._create_mock_actor(device=device) value = self._create_mock_value(device=device) @@ -1642,7 +1642,7 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage): else: raise NotImplementedError - loss_fn = loss_class( + loss_fn = A2CLoss( actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" ) @@ -1756,6 +1756,239 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): param.grad = None +class TestA2C: + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = NdBoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule( + net, in_keys=["observation"], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=module, + distribution_class=TanhNormal, + dist_in_keys=["loc", "scale"], + spec=CompositeSpec(action=action_spec, loc=None, scale=None), + ) + return actor.to(device) + + def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + module = nn.Linear(obs_dim, 1) + value = ValueOperator( + module=module, + in_keys=["observation"], + ) + return value.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_a2c( + self, batch=2, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "next": {"observation": next_obs}, + "done": done, + "reward": reward, + "action": 
action, + "sample_log_prob": torch.randn_like(action[..., :1]) / 10, + }, + device=device, + ) + return td + + def _create_seq_mock_data_a2c( + self, batch=2, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + params_mean = torch.randn_like(action) / 10 + params_scale = torch.rand_like(action) / 10 + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + "next": {"observation": next_obs * mask.to(obs.dtype)}, + "done": done, + "mask": mask, + "reward": reward * mask.to(obs.dtype), + "action": action * mask.to(obs.dtype), + "sample_log_prob": torch.randn_like(action[..., :1]) + / 10 + * mask.to(obs.dtype), + "loc": params_mean * mask.to(obs.dtype), + "scale": params_scale * mask.to(obs.dtype), + }, + device=device, + ) + return td + + @pytest.mark.parametrize("gradient_mode", (True, False)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) + @pytest.mark.parametrize("device", get_available_devices()) + def test_a2c(self, loss_class, device, gradient_mode, advantage): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_ac2(device=device) + + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + if advantage == "gae": + advantage = GAE( + gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + ) + elif advantage == "td": + advantage = TDEstimate( + gamma=0.9, value_network=value, gradient_mode=gradient_mode + ) + elif advantage == "td_lambda": + advantage = TDLambdaEstimate( + gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + ) + else: + raise NotImplementedError + + loss_fn = A2CLoss( + actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" + ) + + loss = loss_fn(td) + loss_critic = loss["loss_critic"] + loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + loss_critic.backward(retain_graph=True) + # check that grads are independent and non null + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert "actor" in name + assert "critic" not in name + + value.zero_grad() + loss_objective.backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert "actor" not in name + assert "critic" in name + actor.zero_grad() + + @pytest.mark.parametrize("gradient_mode", (True, False)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) + @pytest.mark.parametrize("device", get_available_devices()) + def test_a2c_diff(self, device, gradient_mode, advantage): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_a2c(device=device) + + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + if advantage == "gae": + 
advantage = GAE( + gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + ) + elif advantage == "td": + advantage = TDEstimate( + gamma=0.9, value_network=value, gradient_mode=gradient_mode + ) + elif advantage == "td_lambda": + advantage = TDLambdaEstimate( + gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + ) + else: + raise NotImplementedError + + loss_fn = A2CLoss( + actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" + ) + + floss_fn, params, buffers = make_functional_with_buffers(loss_fn) + + loss = floss_fn(params, buffers, td) + loss_critic = loss["loss_critic"] + loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + loss_critic.backward(retain_graph=True) + # check that grads are independent and non null + named_parameters = loss_fn.named_parameters() + if _has_functorch: + for (name, _), p in zip(named_parameters, params): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert "actor" in name + assert "critic" not in name + else: + for key, p in params.flatten_keys(".").items(): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in key + assert "value" in key or "critic" in key + if p.grad is None: + assert "actor" in key + assert "value" not in key and "critic" not in key + + if _has_functorch: + for param in params: + param.grad = None + else: + for param in params.flatten_keys(".").values(): + param.grad = None + loss_objective.backward() + named_parameters = loss_fn.named_parameters() + if _has_functorch: + for (name, _), p in zip(named_parameters, params): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert "actor" not in name + assert "critic" in name + for param in params: + param.grad = None + else: + for key, p in params.flatten_keys(".").items(): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in key + assert "value" not in key and "critic" not in key + if p.grad is None: + assert "actor" not in key + assert "value" in key or "critic" in key + for param in params.flatten_keys(".").values(): + param.grad = None + + class TestReinforce: @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) From f04b0f9af578fdd7daab32fbd9f5742c47490bc7 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 21 Nov 2022 17:54:33 +0100 Subject: [PATCH 27/33] a2c test --- test/test_cost.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 5a771511a89..3296f3ce099 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1618,12 +1618,13 @@ def _create_seq_mock_data_ppo( ) return td + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) @pytest.mark.parametrize("device", get_available_devices()) - def test_ppo(self, device, gradient_mode, advantage): + def test_ppo(self, loss_class, device, gradient_mode, advantage): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) + td = self._create_seq_mock_data_ppo(device=device) actor = self._create_mock_actor(device=device) value = self._create_mock_value(device=device) @@ -1642,7 +1643,7 @@ def test_ppo(self, device, gradient_mode, advantage): else: raise NotImplementedError - 
loss_fn = A2CLoss( + loss_fn = loss_class( actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" ) From f4b2289ea6e36d3bfdc3e2b454e205a8e4d4431b Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 21 Nov 2022 18:32:42 +0100 Subject: [PATCH 28/33] a2c test --- test/test_cost.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 3296f3ce099..2e45332f2b9 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1785,11 +1785,6 @@ def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): ) return value.to(device) - def _create_mock_distributional_actor( - self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, vmax=5 - ): - raise NotImplementedError - def _create_mock_data_a2c( self, batch=2, obs_dim=3, action_dim=4, atoms=None, device="cpu" ): @@ -1856,9 +1851,9 @@ def _create_seq_mock_data_a2c( @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) @pytest.mark.parametrize("device", get_available_devices()) - def test_a2c(self, loss_class, device, gradient_mode, advantage): + def test_a2c(self, device, gradient_mode, advantage): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ac2(device=device) + td = self._create_seq_mock_data_a2c(device=device) actor = self._create_mock_actor(device=device) value = self._create_mock_value(device=device) From e39d7ffd6c83b8ccc4aae588afce093d2ffd3ca5 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 22 Nov 2022 09:59:57 +0100 Subject: [PATCH 29/33] make a2c model test --- test/test_helpers.py | 127 +++++++++++++++++++++++++++++ torchrl/trainers/helpers/models.py | 4 +- 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index bec3757fce0..53895ba7a80 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -37,9 +37,11 @@ DiscreteModelConfig, make_ddpg_actor, make_dqn_actor, + make_a2c_model, make_ppo_model, make_redq_model, make_sac_model, + A2CModelConfig, PPOModelConfig, REDQModelConfig, SACModelConfig, @@ -355,6 +357,131 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration): del proof_environment +@pytest.mark.skipif(not _has_hydra, reason="No hydra library found") +@pytest.mark.skipif(not _has_gym, reason="No gym library found") +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")]) +@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)]) +@pytest.mark.parametrize("shared_mapping", [tuple(), ("shared_mapping=True",)]) +@pytest.mark.parametrize("exploration", ["random", "mode"]) +def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration): + if not gsde and exploration != "random": + pytest.skip("no need to test this setting") + flags = list(from_pixels + shared_mapping + gsde) + config_fields = [ + (config_field.name, config_field.type, config_field) + for config_cls in ( + EnvConfig, + A2CModelConfig, + ) + for config_field in dataclasses.fields(config_cls) + ] + + Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) + cs = ConfigStore.instance() + cs.store(name="config", node=Config) + + with initialize(version_base=None, config_path=None): + cfg = compose(config_name="config", overrides=flags) + # if gsde and from_pixels: + # pytest.skip("gsde and from_pixels are incompatible") + + env_maker = ( + ContinuousActionConvMockEnvNumpy + if 
from_pixels + else ContinuousActionVecMockEnv + ) + env_maker = transformed_env_constructor( + cfg, use_env_creator=False, custom_env_maker=env_maker + ) + proof_environment = env_maker() + + if cfg.from_pixels and not cfg.shared_mapping: + with pytest.raises( + RuntimeError, + match="A2C learnt from pixels require the shared_mapping to be set to True", + ): + actor_value = make_a2c_model( + proof_environment, + device=device, + cfg=cfg, + ) + return + + actor_value = make_a2c_model( + proof_environment, + device=device, + cfg=cfg, + ) + actor = actor_value.get_policy_operator() + expected_keys = [ + "done", + "pixels" if len(from_pixels) else "observation_vector", + "pixels_orig" if len(from_pixels) else "observation_orig", + "action", + "sample_log_prob", + "loc", + "scale", + ] + if shared_mapping: + expected_keys += ["hidden"] + if len(gsde): + expected_keys += ["_eps_gSDE"] + + td = proof_environment.reset().to(device) + td_clone = td.clone() + with set_exploration_mode(exploration): + if UNSQUEEZE_SINGLETON and not td_clone.ndimension(): + # Linear and conv used to break for non-batched data + actor(td_clone.unsqueeze(0)) + else: + actor(td_clone) + + try: + _assert_keys_match(td_clone, expected_keys) + except AssertionError: + proof_environment.close() + raise + + if cfg.gSDE: + if cfg.shared_mapping: + tsf_loc = actor[-1].module[-1].module.transform(td_clone.get("loc")) + else: + tsf_loc = actor.module[-1].module.transform(td_clone.get("loc")) + + if exploration == "random": + with pytest.raises(AssertionError): + torch.testing.assert_close(td_clone.get("action"), tsf_loc) + else: + torch.testing.assert_close(td_clone.get("action"), tsf_loc) + + value = actor_value.get_value_operator() + expected_keys = [ + "done", + "pixels" if len(from_pixels) else "observation_vector", + "pixels_orig" if len(from_pixels) else "observation_orig", + "state_value", + ] + if shared_mapping: + expected_keys += ["hidden"] + if len(gsde): + expected_keys += ["_eps_gSDE"] + + td_clone = td.clone() + if UNSQUEEZE_SINGLETON and not td_clone.ndimension(): + # Linear and conv used to break for non-batched data + value(td_clone.unsqueeze(0)) + else: + value(td_clone) + try: + _assert_keys_match(td_clone, expected_keys) + except AssertionError: + proof_environment.close() + raise + proof_environment.close() + del proof_environment + + @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") @pytest.mark.skipif(not _has_gym, reason="No gym library found") @pytest.mark.parametrize("device", get_available_devices()) diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 0f8a93deb80..a3bb04af184 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -539,7 +539,7 @@ def make_a2c_model( else: if cfg.lstm: raise NotImplementedError( - "lstm not yet compatible with shared mapping for PPO" + "lstm not yet compatible with shared mapping for A2C" ) common_module = MLP( num_cells=[ @@ -619,7 +619,7 @@ def make_a2c_model( else: if cfg.from_pixels: raise RuntimeError( - "PPO learnt from pixels require the shared_mapping to be set to True." + "A2C learnt from pixels require the shared_mapping to be set to True." 
) if cfg.lstm: policy_net = LSTMNet( From 859ccd9231d7a56cd9c4466299ff8f46ae877a22 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 22 Nov 2022 12:24:14 +0100 Subject: [PATCH 30/33] increase a2c tests coverage --- test/test_cost.py | 48 ++++++++++++++++++++------------------------ test/test_helpers.py | 12 +++++++++++ 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 2e45332f2b9..3d4da8d10ce 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1785,32 +1785,6 @@ def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): ) return value.to(device) - def _create_mock_data_a2c( - self, batch=2, obs_dim=3, action_dim=4, atoms=None, device="cpu" - ): - # create a tensordict - obs = torch.randn(batch, obs_dim, device=device) - next_obs = torch.randn(batch, obs_dim, device=device) - if atoms: - raise NotImplementedError - else: - action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) - reward = torch.randn(batch, 1, device=device) - done = torch.zeros(batch, 1, dtype=torch.bool, device=device) - td = TensorDict( - batch_size=(batch,), - source={ - "observation": obs, - "next": {"observation": next_obs}, - "done": done, - "reward": reward, - "action": action, - "sample_log_prob": torch.randn_like(action[..., :1]) / 10, - }, - device=device, - ) - return td - def _create_seq_mock_data_a2c( self, batch=2, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" ): @@ -1876,6 +1850,25 @@ def test_a2c(self, device, gradient_mode, advantage): actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" ) + # Check error is raised when actions require grads + td["action"].requires_grad = True + with pytest.raises( + RuntimeError, + match="tensordict stored action require grad.", + ): + loss = loss_fn._log_probs(td) + td["action"].requires_grad = False + + # Check error is raised when advantage_diff_key present and does not required grad + td[loss_fn.advantage_diff_key] = torch.randn_like(td["reward"]) + with pytest.raises( + RuntimeError, + match="value_target retrieved from tensordict does not require grad.", + ): + loss = loss_fn.loss_critic(td) + td = td.exclude(loss_fn.advantage_diff_key) + assert loss_fn.advantage_diff_key not in td.keys() + loss = loss_fn(td) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) @@ -1902,6 +1895,9 @@ def test_a2c(self, device, gradient_mode, advantage): assert "critic" in name actor.zero_grad() + # test reset + loss_fn.reset() + @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) @pytest.mark.parametrize("device", get_available_devices()) diff --git a/test/test_helpers.py b/test/test_helpers.py index 53895ba7a80..ed4dd139a07 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -6,6 +6,7 @@ import argparse import dataclasses from time import sleep +from omegaconf import open_dict import pytest import torch @@ -48,6 +49,10 @@ DreamerConfig, make_dreamer, ) +from torchrl.trainers.helpers.losses import ( + make_a2c_loss, + A2CLossConfig, +) TORCH_VERSION = version.parse(torch.__version__) if TORCH_VERSION < version.parse("1.12.0"): @@ -372,6 +377,7 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration): (config_field.name, config_field.type, config_field) for config_cls in ( EnvConfig, + A2CLossConfig, A2CModelConfig, ) for config_field in dataclasses.fields(config_cls) @@ -481,6 +487,12 @@ def 
test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration): proof_environment.close() del proof_environment + with open_dict(cfg): + cfg.advantage_in_loss = True + loss_fn = make_a2c_loss(actor_value, cfg) + cfg.advantage_in_loss = False + loss_fn = make_a2c_loss(actor_value, cfg) + @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") @pytest.mark.skipif(not _has_gym, reason="No gym library found") From 1eaa4c54d1b33c54188efa29c07ff0f3772a5b8b Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 22 Nov 2022 12:30:37 +0100 Subject: [PATCH 31/33] formatting --- examples/a2c/config.yaml | 58 ++++++++++++++++++++-------------------- test/test_cost.py | 23 +++++++--------- test/test_helpers.py | 15 +++++------ 3 files changed, 45 insertions(+), 51 deletions(-) diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml index b647dea1632..d3c4711d063 100644 --- a/examples/a2c/config.yaml +++ b/examples/a2c/config.yaml @@ -1,41 +1,41 @@ # Environment -env_library: gym # env_library used for the simulated environment. -env_name: HalfCheetah-v4 # name of the environment to be created. Default=Humanoid-v2 -frame_skip: 2 # frame_skip for the environment. +env_library: gym # env_library used for the simulated environment. +env_name: HalfCheetah - v4 # name of the environment to be created. Default=Humanoid-v2 +frame_skip: 2 # frame_skip for the environment. # Logger -logger: wandb # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv' -record_video: False # whether a video of the task should be rendered during logging. -exp_name: A2C # experiment name. Used for logging directory. -record_interval: 100 # number of batch collections in between two collections of validation rollouts. Default=1000. +logger: wandb # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv' +record_video: False # whether a video of the task should be rendered during logging. +exp_name: A2C # experiment name. Used for logging directory. +record_interval: 100 # number of batch collections in between two collections of validation rollouts. Default=1000. # Collector -frames_per_batch: 64 # Number of steps executed in the environment per collection. -total_frames: 2_000_000 # total number of frames collected for training. Does account for frame_skip. -num_workers: 2 # Number of workers used for data collection. -env_per_collector: 2 # Number of environments per collector. If the env_per_collector is in the range: +frames_per_batch: 64 # Number of steps executed in the environment per collection. +total_frames: 2_000_000 # total number of frames collected for training. Does account for frame_skip. +num_workers: 2 # Number of workers used for data collection. +env_per_collector: 2 # Number of environments per collector. If the env_per_collector is in the range: # Model -default_policy_scale: 1.0 # Default policy scale parameter -distribution: tanh_normal # if True, uses a Tanh-Normal-Tanh distribution for the policy -lstm: False # if True, uses an LSTM for the policy. -shared_mapping: False # if True, the first layers of the actor-critic are shared. +default_policy_scale: 1.0 # Default policy scale parameter +distribution: tanh_normal # if True, uses a Tanh-Normal-Tanh distribution for the policy +lstm: False # if True, uses an LSTM for the policy. +shared_mapping: False # if True, the first layers of the actor-critic are shared. 
# Objective gamma: 0.99 -entropy_coef: 0.01 # Entropy factor for the A2C loss -critic_coef: 0.25 # Critic factor for the A2C loss -critic_loss_function: l2 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). -advantage_in_loss: False # if True, the advantage is computed on the sub-batch +entropy_coef: 0.01 # Entropy factor for the A2C loss +critic_coef: 0.25 # Critic factor for the A2C loss +critic_loss_function: l2 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). +advantage_in_loss: False # if True, the advantage is computed on the sub-batch # Trainer -optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data. -optimizer: adam # Optimizer to be used. -lr_scheduler: "" # LR scheduler. -batch_size: 64 # batch size of the TensorDict retrieved from the replay buffer. Default=256. -log_interval: 1 # logging interval, in terms of optimization steps. Default=10000. -lr: 0.0007 # Learning rate used for the optimizer. Default=3e-4. -normalize_rewards_online: True # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module. -normalize_rewards_online_scale: 1.0 # Final scale of the normalized rewards. -normalize_rewards_online_decay: 0.0 # Decay of the reward moving averaging -sub_traj_len: 64 # length of the trajectories that sub-samples must have in online settings. +optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data. +optimizer: adam # Optimizer to be used. +lr_scheduler: "" # LR scheduler. +batch_size: 64 # batch size of the TensorDict retrieved from the replay buffer. Default=256. +log_interval: 1 # logging interval, in terms of optimization steps. Default=10000. +lr: 0.0007 # Learning rate used for the optimizer. Default=3e-4. +normalize_rewards_online: True # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module. +normalize_rewards_online_scale: 1.0 # Final scale of the normalized rewards. +normalize_rewards_online_decay: 0.0 # Decay of the reward moving averaging +sub_traj_len: 64 # length of the trajectories that sub-samples must have in online settings. 
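A minimal sketch of how the objective-related entries above typically feed the loss, following the A2CLoss and TDEstimate call patterns used elsewhere in this series; the cfg, actor and critic objects are assumed to come from the hydra config and the model helpers:

    from torchrl.objectives import A2CLoss
    from torchrl.objectives.value import TDEstimate

    # Advantage estimator and A2C loss; keyword names mirror the config keys above
    # and the calls that appear in the tests of this series (sketch only).
    advantage = TDEstimate(gamma=cfg.gamma, value_network=critic)
    loss_module = A2CLoss(
        actor,
        critic,
        advantage_module=advantage,
        gamma=cfg.gamma,                            # 0.99
        entropy_coef=cfg.entropy_coef,              # 0.01
        loss_critic_type=cfg.critic_loss_function,  # "l2"
    )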
diff --git a/test/test_cost.py b/test/test_cost.py index 3d4da8d10ce..0a5b1785128 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -21,40 +21,40 @@ import numpy as np import pytest import torch -from _utils_internal import get_available_devices, dtype_fixture # noqa +from _utils_internal import dtype_fixture, get_available_devices # noqa from mocking_classes import ContinuousActionConvMockEnv # from torchrl.data.postprocs.utils import expand_as_right -from tensordict.tensordict import assert_allclose_td, TensorDictBase, TensorDict +from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict.utils import expand_as_right from torch import autograd, nn from torchrl.data import ( CompositeSpec, + DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, - DiscreteTensorSpec, ) from torchrl.data.postprocs.postprocs import MultiStep from torchrl.envs.model_based.dreamer import DreamerEnv -from torchrl.envs.transforms import TransformedEnv, TensorDictPrimer +from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv from torchrl.modules import ( DistributionalQValueActor, + ProbabilisticTensorDictModule, QValueActor, TensorDictModule, TensorDictSequential, - ProbabilisticTensorDictModule, WorldModelWrapper, ) from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal from torchrl.modules.models.model_based import ( + DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior, RSSMRollout, - DreamerActor, ) from torchrl.modules.models.models import MLP from torchrl.modules.tensordict_module.actors import ( @@ -70,18 +70,15 @@ DDPGLoss, DistributionalDQNLoss, DQNLoss, + DreamerActorLoss, + DreamerModelLoss, + DreamerValueLoss, KLPENPPOLoss, PPOLoss, SACLoss, - DreamerModelLoss, - DreamerActorLoss, - DreamerValueLoss, ) from torchrl.objectives.common import LossModule -from torchrl.objectives.deprecated import ( - DoubleREDQLoss_deprecated, - REDQLoss_deprecated, -) +from torchrl.objectives.deprecated import DoubleREDQLoss_deprecated, REDQLoss_deprecated from torchrl.objectives.redq import REDQLoss from torchrl.objectives.reinforce import ReinforceLoss from torchrl.objectives.utils import HardUpdate, hold_out_net, SoftUpdate diff --git a/test/test_helpers.py b/test/test_helpers.py index ed4dd139a07..73c7e5ee6a7 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -6,11 +6,11 @@ import argparse import dataclasses from time import sleep -from omegaconf import open_dict import pytest import torch from _utils_internal import generate_seeds, get_available_devices +from omegaconf import open_dict from torchrl._utils import timeit try: @@ -33,25 +33,22 @@ from torchrl.modules.tensordict_module.common import _has_functorch from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.helpers.envs import EnvConfig +from torchrl.trainers.helpers.losses import A2CLossConfig, make_a2c_loss from torchrl.trainers.helpers.models import ( + A2CModelConfig, DDPGModelConfig, DiscreteModelConfig, + DreamerConfig, + make_a2c_model, make_ddpg_actor, make_dqn_actor, - make_a2c_model, + make_dreamer, make_ppo_model, make_redq_model, make_sac_model, - A2CModelConfig, PPOModelConfig, REDQModelConfig, SACModelConfig, - DreamerConfig, - make_dreamer, -) -from torchrl.trainers.helpers.losses import ( - make_a2c_loss, - A2CLossConfig, ) TORCH_VERSION = version.parse(torch.__version__) From 7f78bf6249989025a331b74244a17dafc6957301 Mon Sep 
17 00:00:00 2001 From: albert bou Date: Tue, 22 Nov 2022 13:05:50 +0100 Subject: [PATCH 32/33] fix bug a2c testing --- test/test_helpers.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index 73c7e5ee6a7..fd8b05982fd 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -10,7 +10,6 @@ import pytest import torch from _utils_internal import generate_seeds, get_available_devices -from omegaconf import open_dict from torchrl._utils import timeit try: @@ -367,6 +366,7 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration): @pytest.mark.parametrize("shared_mapping", [tuple(), ("shared_mapping=True",)]) @pytest.mark.parametrize("exploration", ["random", "mode"]) def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration): + A2CModelConfig.advantage_in_loss = False if not gsde and exploration != "random": pytest.skip("no need to test this setting") flags = list(from_pixels + shared_mapping + gsde) @@ -484,11 +484,10 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration): proof_environment.close() del proof_environment - with open_dict(cfg): - cfg.advantage_in_loss = True - loss_fn = make_a2c_loss(actor_value, cfg) - cfg.advantage_in_loss = False - loss_fn = make_a2c_loss(actor_value, cfg) + cfg.advantage_in_loss = False + loss_fn = make_a2c_loss(actor_value, cfg) + cfg.advantage_in_loss = True + loss_fn = make_a2c_loss(actor_value, cfg) @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") From 60c273039fdbee390408f0a426ebf92df68b5e02 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 23 Nov 2022 10:08:12 +0100 Subject: [PATCH 33/33] minor fixes --- examples/a2c/config.yaml | 2 +- torchrl/trainers/helpers/models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml index d3c4711d063..f3c05c95c60 100644 --- a/examples/a2c/config.yaml +++ b/examples/a2c/config.yaml @@ -1,6 +1,6 @@ # Environment env_library: gym # env_library used for the simulated environment. -env_name: HalfCheetah - v4 # name of the environment to be created. Default=Humanoid-v2 +env_name: HalfCheetah-v4 # name of the environment to be created. Default=Humanoid-v2 frame_skip: 2 # frame_skip for the environment. # Logger diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index a3bb04af184..f0d21e13997 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -628,7 +628,7 @@ def make_a2c_model( mlp_kwargs={"num_cells": [64, 64], "out_features": 64}, ) in_keys_actor += ["hidden0", "hidden1"] - out_keys += ["hidden0", "hidden1", "next_hidden0", "next_hidden1"] + out_keys += ["hidden0", "hidden1", ("next", "hidden0"), ("next", "hidden1")] else: policy_net = MLP( num_cells=[64, 64],