diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index f56e567140f..af06c44c016 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -22,8 +22,9 @@ from torchrl.trainers.helpers.envs import ( correct_for_frame_skip, EnvConfig, - get_stats_random_rollout, + initialize_observation_norm_transforms, parallel_env_constructor, + retrieve_observation_norms_state_dict, transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig @@ -92,21 +93,24 @@ def main(cfg: "DictConfig"): # noqa: F821 ) video_tag = exp_name if cfg.record_video else "" - stats = None + key, init_env_steps, stats = None, None, None if not cfg.vecnorm and cfg.norm_stats: - proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() - stats = get_stats_random_rollout( - cfg, - proof_env, - key="pixels" if cfg.from_pixels else "observation_vector", - ) - # make sure proof_env is closed - proof_env.close() + if not hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + key = "pixels" if cfg.from_pixels else "observation_vector" + init_env_steps = cfg.init_env_steps + stats = {"loc": None, "scale": None} elif cfg.from_pixels: stats = {"loc": 0.5, "scale": 0.5} proof_env = transformed_env_constructor( - cfg=cfg, use_env_creator=False, stats=stats + cfg=cfg, + use_env_creator=False, + stats=stats, )() + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] model = make_a2c_model( proof_env, @@ -128,7 +132,7 @@ def main(cfg: "DictConfig"): # noqa: F821 proof_env.close() create_env_fn = parallel_env_constructor( cfg=cfg, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, action_dim_gsde=action_dim_gsde, state_dim_gsde=state_dim_gsde, ) @@ -143,7 +147,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg, video_tag=video_tag, norm_obs_only=True, - stats=stats, + 
obs_norm_state_dict=obs_norm_state_dict, logger=logger, use_env_creator=False, )() diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index 69bb54fbbd6..b7c2d920574 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -24,8 +24,9 @@ from torchrl.trainers.helpers.envs import ( correct_for_frame_skip, EnvConfig, - get_stats_random_rollout, + initialize_observation_norm_transforms, parallel_env_constructor, + retrieve_observation_norms_state_dict, transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig @@ -105,23 +106,25 @@ def main(cfg: "DictConfig"): # noqa: F821 ) video_tag = exp_name if cfg.record_video else "" - stats = None + key, init_env_steps, stats = None, None, None if not cfg.vecnorm and cfg.norm_stats: - proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() - stats = get_stats_random_rollout( - cfg, - proof_env, - key=("next", "pixels") - if cfg.from_pixels - else ("next", "observation_vector"), - ) - # make sure proof_env is closed - proof_env.close() + if not hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") + init_env_steps = cfg.init_env_steps + stats = {"loc": None, "scale": None} elif cfg.from_pixels: stats = {"loc": 0.5, "scale": 0.5} + proof_env = transformed_env_constructor( - cfg=cfg, use_env_creator=False, stats=stats + cfg=cfg, + stats=stats, + use_env_creator=False, )() + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] model = make_ddpg_actor( proof_env, @@ -154,9 +157,10 @@ def main(cfg: "DictConfig"): # noqa: F821 action_dim_gsde, state_dim_gsde = None, None proof_env.close() + create_env_fn = parallel_env_constructor( cfg=cfg, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, 
action_dim_gsde=action_dim_gsde, state_dim_gsde=state_dim_gsde, ) @@ -177,7 +181,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg, video_tag=video_tag, norm_obs_only=True, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, logger=logger, use_env_creator=False, )() diff --git a/examples/dqn/dqn.py b/examples/dqn/dqn.py index 96a81e533a4..329c032c733 100644 --- a/examples/dqn/dqn.py +++ b/examples/dqn/dqn.py @@ -23,8 +23,9 @@ from torchrl.trainers.helpers.envs import ( correct_for_frame_skip, EnvConfig, - get_stats_random_rollout, + initialize_observation_norm_transforms, parallel_env_constructor, + retrieve_observation_norms_state_dict, transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig @@ -95,23 +96,25 @@ def main(cfg: "DictConfig"): # noqa: F821 ) video_tag = exp_name if cfg.record_video else "" - stats = None + key, init_env_steps, stats = None, None, None if not cfg.vecnorm and cfg.norm_stats: - proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() - stats = get_stats_random_rollout( - cfg, - proof_env, - key=("next", "pixels") - if cfg.from_pixels - else ("next", "observation_vector"), - ) - # make sure proof_env is closed - proof_env.close() + if not hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") + init_env_steps = cfg.init_env_steps + stats = {"loc": None, "scale": None} elif cfg.from_pixels: stats = {"loc": 0.5, "scale": 0.5} proof_env = transformed_env_constructor( - cfg=cfg, use_env_creator=False, stats=stats + cfg=cfg, + use_env_creator=False, + stats=stats, )() + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] + model = make_dqn_actor( proof_environment=proof_env, cfg=cfg, @@ -127,7 +130,7 @@ def main(cfg: "DictConfig"): # 
noqa: F821 proof_env.close() create_env_fn = parallel_env_constructor( cfg=cfg, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, action_dim_gsde=action_dim_gsde, state_dim_gsde=state_dim_gsde, ) @@ -148,7 +151,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg, video_tag=video_tag, norm_obs_only=True, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, logger=logger, )() diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 140d7c83287..fe155173350 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -35,7 +35,8 @@ ) from torchrl.trainers.helpers.envs import ( correct_for_frame_skip, - get_stats_random_rollout, + initialize_observation_norm_transforms, + retrieve_observation_norms_state_dict, ) from torchrl.trainers.helpers.logger import LoggerConfig from torchrl.trainers.helpers.models import DreamerConfig, make_dreamer @@ -43,7 +44,6 @@ from torchrl.trainers.helpers.trainers import TrainerConfig from torchrl.trainers.trainers import Recorder, RewardNormalizer - config_fields = [ (config_field.name, config_field.type, config_field) for config_cls in ( @@ -61,6 +61,13 @@ cs.store(name="config", node=Config) +def retrieve_stats_from_state_dict(obs_norm_state_dict): + return { + "loc": obs_norm_state_dict["loc"], + "scale": obs_norm_state_dict["scale"], + } + + @hydra.main(version_base=None, config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 @@ -113,30 +120,35 @@ def main(cfg: "DictConfig"): # noqa: F821 video_tag = f"Dreamer_{cfg.env_name}_policy_test" if cfg.record_video else "" - stats = None - - # Compute the stats of the observations + key, init_env_steps, stats = None, None, None if not cfg.vecnorm and cfg.norm_stats: - stats = get_stats_random_rollout( - cfg, - proof_environment=transformed_env_constructor(cfg)(), - key=("next", "pixels") - if cfg.from_pixels - else ("next", "observation_vector"), - ) - stats = {k: v.clone() for k, v in stats.items()} + if not 
hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") + init_env_steps = cfg.init_env_steps + stats = {"loc": None, "scale": None} elif cfg.from_pixels: stats = {"loc": 0.5, "scale": 0.5} + proof_env = transformed_env_constructor( + cfg=cfg, use_env_creator=False, stats=stats + )() + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] + proof_env.close() # Create the different components of dreamer world_model, model_based_env, actor_model, value_model, policy = make_dreamer( - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, cfg=cfg, device=device, use_decoder_in_env=True, action_key="action", value_key="state_value", - proof_environment=transformed_env_constructor(cfg)(), + proof_environment=transformed_env_constructor( + cfg, stats={"loc": 0.0, "scale": 1.0} + )(), ) # reward normalization @@ -178,7 +190,7 @@ def main(cfg: "DictConfig"): # noqa: F821 action_dim_gsde, state_dim_gsde = None, None create_env_fn = parallel_env_constructor( cfg=cfg, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, action_dim_gsde=action_dim_gsde, state_dim_gsde=state_dim_gsde, ) @@ -203,11 +215,11 @@ def main(cfg: "DictConfig"): # noqa: F821 frame_skip=cfg.frame_skip, policy_exploration=policy, recorder=make_recorder_env( - cfg, - video_tag, - stats, - logger, - create_env_fn, + cfg=cfg, + video_tag=video_tag, + obs_norm_state_dict=obs_norm_state_dict, + logger=logger, + create_env_fn=create_env_fn, ), record_interval=cfg.record_interval, log_keys=cfg.recorder_log_keys, @@ -371,6 +383,7 @@ def main(cfg: "DictConfig"): # noqa: F821 if j == cfg.optim_steps_per_batch - 1: do_log = False + stats = retrieve_stats_from_state_dict(obs_norm_state_dict) call_record( logger, record, diff --git 
a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 12b2e23842d..dfcd444262d 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -52,6 +52,7 @@ def make_env_transforms( action_dim_gsde, state_dim_gsde, batch_dims=0, + obs_norm_state_dict=None, ): env = TransformedEnv(env) @@ -91,11 +92,17 @@ def make_env_transforms( env.append_transform(FlattenObservation(0)) env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"])) if stats is None: - obs_stats = {"loc": 0.0, "scale": 1.0} + obs_stats = { + "loc": torch.zeros(env.observation_spec["pixels"].shape), + "scale": torch.ones(env.observation_spec["pixels"].shape), + } else: obs_stats = stats obs_stats["standard_normal"] = True - env.append_transform(ObservationNorm(**obs_stats, in_keys=["pixels"])) + obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"]) + if obs_norm_state_dict: + obs_norm.load_state_dict(obs_norm_state_dict) + env.append_transform(obs_norm) if norm_rewards: reward_scaling = 1.0 reward_loc = 0.0 @@ -141,6 +148,7 @@ def transformed_env_constructor( action_dim_gsde: Optional[int] = None, state_dim_gsde: Optional[int] = None, batch_dims: Optional[int] = 0, + obs_norm_state_dict: Optional[dict] = None, ) -> Union[Callable, EnvCreator]: """ Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. @@ -171,6 +179,8 @@ def transformed_env_constructor( batch_dims (int, optional): number of dimensions of a batch of data. If a single env is used, it should be 0 (default). If multiple envs are being transformed in parallel, it should be set to 1 (or the number of dims of the batch). 
+ obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded + into the environment """ def make_transformed_env(**kwargs) -> TransformedEnv: @@ -226,6 +236,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv: action_dim_gsde, state_dim_gsde, batch_dims=batch_dims, + obs_norm_state_dict=obs_norm_state_dict, ) if use_env_creator: @@ -335,12 +346,12 @@ def grad_norm(optimizer: torch.optim.Optimizer): return sum_of_sq.sqrt().detach().item() -def make_recorder_env(cfg, video_tag, stats, logger, create_env_fn): +def make_recorder_env(cfg, video_tag, obs_norm_state_dict, logger, create_env_fn): recorder = transformed_env_constructor( cfg, video_tag=video_tag, norm_obs_only=True, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, logger=logger, use_env_creator=False, )() diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py index c2e63abafe4..145a0dbeb53 100644 --- a/examples/ppo/ppo.py +++ b/examples/ppo/ppo.py @@ -24,8 +24,9 @@ from torchrl.trainers.helpers.envs import ( correct_for_frame_skip, EnvConfig, - get_stats_random_rollout, + initialize_observation_norm_transforms, parallel_env_constructor, + retrieve_observation_norms_state_dict, transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig @@ -94,23 +95,25 @@ def main(cfg: "DictConfig"): # noqa: F821 ) video_tag = exp_name if cfg.record_video else "" - stats = None + key, init_env_steps, stats = None, None, None if not cfg.vecnorm and cfg.norm_stats: - proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() - stats = get_stats_random_rollout( - cfg, - proof_env, - key=("next", "pixels") - if cfg.from_pixels - else ("next", "observation_vector"), - ) - # make sure proof_env is closed - proof_env.close() + if not hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") + init_env_steps = 
cfg.init_env_steps + stats = {"loc": None, "scale": None} elif cfg.from_pixels: stats = {"loc": 0.5, "scale": 0.5} + proof_env = transformed_env_constructor( - cfg=cfg, use_env_creator=False, stats=stats + cfg=cfg, + use_env_creator=False, + stats=stats, )() + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] model = make_ppo_model( proof_env, @@ -132,7 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821 proof_env.close() create_env_fn = parallel_env_constructor( cfg=cfg, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, action_dim_gsde=action_dim_gsde, state_dim_gsde=state_dim_gsde, ) @@ -151,7 +154,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg, video_tag=video_tag, norm_obs_only=True, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, logger=logger, use_env_creator=False, )() diff --git a/examples/redq/redq.py b/examples/redq/redq.py index a470d41e8fe..c7715fa92de 100644 --- a/examples/redq/redq.py +++ b/examples/redq/redq.py @@ -24,8 +24,9 @@ from torchrl.trainers.helpers.envs import ( correct_for_frame_skip, EnvConfig, - get_stats_random_rollout, + initialize_observation_norm_transforms, parallel_env_constructor, + retrieve_observation_norms_state_dict, transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig @@ -106,23 +107,26 @@ def main(cfg: "DictConfig"): # noqa: F821 ) video_tag = exp_name if cfg.record_video else "" - stats = None + key, init_env_steps, stats = None, None, None if not cfg.vecnorm and cfg.norm_stats: - proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() - stats = get_stats_random_rollout( - cfg, - proof_env, - key=("next", "pixels") - if cfg.from_pixels - else ("next", "observation_vector"), - ) - # make sure proof_env is closed - proof_env.close() + if not hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps 
missing from arguments.") + key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") + init_env_steps = cfg.init_env_steps + stats = {"loc": None, "scale": None} elif cfg.from_pixels: stats = {"loc": 0.5, "scale": 0.5} + proof_env = transformed_env_constructor( - cfg=cfg, use_env_creator=False, stats=stats + cfg=cfg, + use_env_creator=False, + stats=stats, )() + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] + model = make_redq_model( proof_env, cfg=cfg, @@ -156,7 +160,7 @@ def main(cfg: "DictConfig"): # noqa: F821 proof_env.close() create_env_fn = parallel_env_constructor( cfg=cfg, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, action_dim_gsde=action_dim_gsde, state_dim_gsde=state_dim_gsde, ) @@ -177,7 +181,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg, video_tag=video_tag, norm_obs_only=True, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, logger=logger, use_env_creator=False, )() diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 9f477748293..ec81f4b4f02 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -24,8 +24,9 @@ from torchrl.trainers.helpers.envs import ( correct_for_frame_skip, EnvConfig, - get_stats_random_rollout, + initialize_observation_norm_transforms, parallel_env_constructor, + retrieve_observation_norms_state_dict, transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig @@ -106,23 +107,26 @@ def main(cfg: "DictConfig"): # noqa: F821 ) video_tag = exp_name if cfg.record_video else "" - stats = None + key, init_env_steps, stats = None, None, None if not cfg.vecnorm and cfg.norm_stats: - proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() - stats = get_stats_random_rollout( - cfg, - proof_env, - key=("next", "pixels") - if cfg.from_pixels - else ("next", "observation_vector"), - ) - # make 
sure proof_env is closed - proof_env.close() + if not hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") + init_env_steps = cfg.init_env_steps + stats = {"loc": None, "scale": None} elif cfg.from_pixels: stats = {"loc": 0.5, "scale": 0.5} + proof_env = transformed_env_constructor( - cfg=cfg, use_env_creator=False, stats=stats + cfg=cfg, + use_env_creator=False, + stats=stats, )() + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] + model = make_sac_model( proof_env, cfg=cfg, @@ -153,7 +157,7 @@ def main(cfg: "DictConfig"): # noqa: F821 proof_env.close() create_env_fn = parallel_env_constructor( cfg=cfg, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, action_dim_gsde=action_dim_gsde, state_dim_gsde=state_dim_gsde, ) @@ -174,7 +178,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg, video_tag=video_tag, norm_obs_only=True, - stats=stats, + obs_norm_state_dict=obs_norm_state_dict, logger=logger, )() diff --git a/test/test_helpers.py b/test/test_helpers.py index 35d84c53e07..ae8f2312b56 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -24,14 +24,25 @@ ContinuousActionVecMockEnv, DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, + MockSerialEnv, ) from packaging import version +from torchrl.data import CompositeSpec, NdBoundedTensorSpec from torchrl.envs.libs.gym import _has_gym -from torchrl.envs.transforms.transforms import _has_tv +from torchrl.envs.transforms import ObservationNorm +from torchrl.envs.transforms.transforms import ( + _has_tv, + FlattenObservation, + TransformedEnv, +) from torchrl.envs.utils import set_exploration_mode from torchrl.modules.tensordict_module.common import _has_functorch from torchrl.trainers.helpers import transformed_env_constructor -from 
torchrl.trainers.helpers.envs import EnvConfig +from torchrl.trainers.helpers.envs import ( + EnvConfig, + initialize_observation_norm_transforms, + retrieve_observation_norms_state_dict, +) from torchrl.trainers.helpers.losses import A2CLossConfig, make_a2c_loss from torchrl.trainers.helpers.models import ( A2CModelConfig, @@ -56,6 +67,7 @@ else: UNSQUEEZE_SINGLETON = False + ## these tests aren't truly unitary but setting up a fake env for the # purpose of building a model with args is a lot of unstable scaffoldings # with unclear benefits @@ -121,10 +133,13 @@ def test_dqn_maker( DiscreteActionConvMockEnvNumpy if from_pixels else DiscreteActionVecMockEnv ) env_maker = transformed_env_constructor( - cfg, use_env_creator=False, custom_env_maker=env_maker + cfg, + use_env_creator=False, + custom_env_maker=env_maker, + stats={"loc": 0.0, "scale": 1.0}, ) proof_environment = env_maker( - categorical_action_encoding=cfg.categorical_action_encoding + categorical_action_encoding=cfg.categorical_action_encoding, ) actor = make_dqn_actor(proof_environment, cfg, device) @@ -184,7 +199,10 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration): else ContinuousActionVecMockEnv ) env_maker = transformed_env_constructor( - cfg, use_env_creator=False, custom_env_maker=env_maker + cfg, + use_env_creator=False, + custom_env_maker=env_maker, + stats={"loc": 0.0, "scale": 1.0}, ) proof_environment = env_maker() actor, value = make_ddpg_actor(proof_environment, device=device, cfg=cfg) @@ -277,7 +295,10 @@ def test_ppo_maker( env_maker = DiscreteActionVecMockEnv env_maker = transformed_env_constructor( - cfg, use_env_creator=False, custom_env_maker=env_maker + cfg, + use_env_creator=False, + custom_env_maker=env_maker, + stats={"loc": 0.0, "scale": 1.0}, ) proof_environment = env_maker() @@ -427,7 +448,10 @@ def test_a2c_maker( env_maker = DiscreteActionVecMockEnv env_maker = transformed_env_constructor( - cfg, use_env_creator=False, custom_env_maker=env_maker + cfg, + 
use_env_creator=False, + custom_env_maker=env_maker, + stats={"loc": 0.0, "scale": 1.0}, ) proof_environment = env_maker() @@ -574,7 +598,10 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration): else ContinuousActionVecMockEnv ) env_maker = transformed_env_constructor( - cfg, use_env_creator=False, custom_env_maker=env_maker + cfg, + use_env_creator=False, + custom_env_maker=env_maker, + stats={"loc": 0.0, "scale": 1.0}, ) proof_environment = env_maker() @@ -701,7 +728,10 @@ def test_redq_make(device, from_pixels, gsde, exploration): else ContinuousActionVecMockEnv ) env_maker = transformed_env_constructor( - cfg, use_env_creator=False, custom_env_maker=env_maker + cfg, + use_env_creator=False, + custom_env_maker=env_maker, + stats={"loc": 0.0, "scale": 1.0}, ) proof_environment = env_maker() @@ -778,7 +808,6 @@ def test_redq_make(device, from_pixels, gsde, exploration): @pytest.mark.parametrize("tanh_loc", [(), ("tanh_loc=True",)]) @pytest.mark.parametrize("exploration", ["random", "mode"]) def test_dreamer_make(device, tanh_loc, exploration, dreamer_constructor_fixture): - transformed_env_constructor = dreamer_constructor_fixture flags = ["from_pixels=True", "catframes=1"] @@ -798,7 +827,10 @@ def test_dreamer_make(device, tanh_loc, exploration, dreamer_constructor_fixture cfg = compose(config_name="config", overrides=flags) env_maker = ContinuousActionConvMockEnvNumpy env_maker = transformed_env_constructor( - cfg, use_env_creator=False, custom_env_maker=env_maker + cfg, + use_env_creator=False, + custom_env_maker=env_maker, + stats={"loc": 0.0, "scale": 1.0}, ) proof_environment = env_maker().to(device) model = make_dreamer( @@ -908,6 +940,133 @@ def test_timeit(): assert val2[2] == n2 +@pytest.mark.skipif(not _has_hydra, reason="No hydra library found") +@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")]) +def test_transformed_env_constructor_with_state_dict(from_pixels): + config_fields = [ + 
(config_field.name, config_field.type, config_field) + for config_cls in ( + EnvConfig, + DreamerConfig, + ) + for config_field in dataclasses.fields(config_cls) + ] + flags = list(from_pixels) + + Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) + cs = ConfigStore.instance() + cs.store(name="config", node=Config) + with initialize(version_base=None, config_path=None): + cfg = compose(config_name="config", overrides=flags) + env_maker = ( + ContinuousActionConvMockEnvNumpy + if from_pixels + else ContinuousActionVecMockEnv + ) + t_env = transformed_env_constructor( + cfg, + use_env_creator=False, + custom_env_maker=env_maker, + )() + idx, state_dict = retrieve_observation_norms_state_dict(t_env)[0] + + obs_transform = transformed_env_constructor( + cfg, + use_env_creator=False, + custom_env_maker=env_maker, + obs_norm_state_dict=state_dict, + )().transform[idx] + torch.testing.assert_close(obs_transform.state_dict(), state_dict) + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("keys", [None, ["observation", "observation_orig"]]) +@pytest.mark.parametrize("composed", [True, False]) +@pytest.mark.parametrize("initialized", [True, False]) +def test_initialize_stats_from_observation_norms(device, keys, composed, initialized): + obs_spec, stat_key = None, None + if keys: + obs_spec = CompositeSpec( + **{ + key: NdBoundedTensorSpec(maximum=1, minimum=1, shape=torch.Size([1])) + for key in keys + } + ) + stat_key = keys[0] + env = ContinuousActionVecMockEnv( + device=device, + observation_spec=obs_spec, + action_spec=NdBoundedTensorSpec( + minimum=1, maximum=2, shape=torch.Size((1,)) + ), + ) + env.out_key = "observation" + else: + env = MockSerialEnv(device=device) + env.set_seed(1) + + t_env = TransformedEnv(env) + stats = {"loc": None, "scale": None} + if initialized: + stats = {"loc": 0.0, "scale": 1.0} + t_env.transform = ObservationNorm(standard_normal=True, **stats) + if composed: + 
t_env.append_transform(ObservationNorm(standard_normal=True, **stats)) + pre_init_state_dict = t_env.transform.state_dict() + initialize_observation_norm_transforms( + proof_environment=t_env, num_iter=100, key=stat_key + ) + post_init_state_dict = t_env.transform.state_dict() + expected_dict_size = 4 if composed else 2 + expected_dict_size = expected_dict_size if not initialized else 0 + + assert len(post_init_state_dict) == len(pre_init_state_dict) + expected_dict_size + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_initialize_stats_from_non_obs_transform(device): + env = MockSerialEnv(device=device) + env.set_seed(1) + + t_env = TransformedEnv(env) + t_env.transform = FlattenObservation(first_dim=0) + pre_init_state_dict = t_env.transform.state_dict() + initialize_observation_norm_transforms(proof_environment=t_env, num_iter=100) + post_init_state_dict = t_env.transform.state_dict() + assert len(post_init_state_dict) == len(pre_init_state_dict) + + +def test_initialize_obs_transform_stats_raise_exception(): + env = ContinuousActionVecMockEnv() + t_env = TransformedEnv(env) + t_env.transform = ObservationNorm() + with pytest.raises( + RuntimeError, match="More than one key exists in the observation_specs" + ): + initialize_observation_norm_transforms(proof_environment=t_env, num_iter=100) + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("composed", [True, False]) +def test_retrieve_observation_norms_state_dict(device, composed): + env = MockSerialEnv(device=device) + env.set_seed(1) + + t_env = TransformedEnv(env) + t_env.transform = ObservationNorm(standard_normal=True) + if composed: + t_env.append_transform(ObservationNorm(standard_normal=True)) + initialize_observation_norm_transforms(proof_environment=t_env, num_iter=100) + state_dicts = retrieve_observation_norms_state_dict(t_env) + expected_state_count = 2 if composed else 1 + expected_idx = [0, 1] if composed else [0] + + assert 
len(state_dicts) == expected_state_count + for idx, state_dict in enumerate(state_dicts): + assert len(state_dict[1]) == 2 + assert state_dict[0] == expected_idx[idx] + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_transforms.py b/test/test_transforms.py index 9442c09bfc8..f76684b8cf1 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -952,6 +952,22 @@ def test_observationnorm_init_stats_multiple_keys_error(self): with pytest.raises(RuntimeError, match=err_msg): transform.init_stats(num_iter=11) + def test_observationnorm_initialization_order_error(self): + base_env = ContinuousActionVecMockEnv() + t_env = TransformedEnv(base_env) + + transform1 = ObservationNorm(in_keys=["next_observation"]) + transform2 = ObservationNorm(in_keys=["next_observation"]) + t_env.append_transform(transform1) + t_env.append_transform(transform2) + + err_msg = ( + "ObservationNorms need to be initialized in the right order." 
+ "Trying to initialize an ObservationNorm while a parent ObservationNorm transform is still uninitialized" + ) + with pytest.raises(RuntimeError, match=err_msg): + transform2.init_stats(num_iter=10, key="observation") + def test_observationnorm_uninitialized_stats_error(self): transform = ObservationNorm(in_keys=["next_observation", "next_pixels"]) @@ -962,6 +978,33 @@ def test_observationnorm_uninitialized_stats_error(self): with pytest.raises(RuntimeError, match=err_msg): transform._apply_transform(torch.Tensor([1])) + @pytest.mark.parametrize("device", get_available_devices()) + def test_observationnorm_infinite_stats_error(self, device): + base_env = ContinuousActionVecMockEnv( + observation_spec=CompositeSpec( + observation=NdBoundedTensorSpec( + minimum=1, maximum=1, shape=torch.Size([1]) + ), + observation_orig=NdBoundedTensorSpec( + minimum=1, maximum=1, shape=torch.Size([1]) + ), + ), + action_spec=NdBoundedTensorSpec( + minimum=1, maximum=1, shape=torch.Size((1,)) + ), + seed=0, + ) + base_env.out_key = "observation" + t_env = TransformedEnv( + base_env, + transform=ObservationNorm(in_keys="observation"), + ) + t_env.append_transform(ObservationNorm(in_keys="observation")) + err_msg = "Non-finite values found in" + with pytest.raises(RuntimeError, match=err_msg): + for transform in t_env.transform: + transform.init_stats(num_iter=100) + def test_catframes_transform_observation_spec(self): N = 4 key1 = "first key" diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ae1e6a8c0fb..4eb5f0f7146 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1391,7 +1391,21 @@ def init_stats( ) key = self.in_keys[0] if key is None else key + def raise_initialization_exception(module): + if ( + isinstance(module, ObservationNorm) + and module.scale is None + and module.loc is None + ): + raise RuntimeError( + "ObservationNorms need to be initialized in the right order." 
+ "Trying to initialize an ObservationNorm " + "while a parent ObservationNorm transform is still uninitialized" + ) + parent = self.parent + parent.apply(raise_initialization_exception) + collected_frames = 0 data = [] while collected_frames < num_iter: @@ -1407,6 +1421,11 @@ def init_stats( loc = loc / scale scale = 1 / scale + if not torch.isfinite(loc).all(): + raise RuntimeError("Non-finite values found in loc") + if not torch.isfinite(scale).all(): + raise RuntimeError("Non-finite values found in scale") + self.register_buffer("loc", loc) self.register_buffer("scale", scale.clamp_min(self.eps)) diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 09311e35bd8..4ba2a3213af 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field as dataclass_field -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union import torch @@ -17,6 +17,7 @@ CatFrames, CatTensors, CenterCrop, + Compose, DoubleToFloat, FiniteTensorDictCheck, GrayScale, @@ -32,7 +33,6 @@ from torchrl.record.recorder import VideoRecorder from torchrl.trainers.loggers import Logger - LIBS = { "gym": GymEnv, "dm_control": DMControlEnv, @@ -84,6 +84,7 @@ def make_env_transforms( action_dim_gsde, state_dim_gsde, batch_dims=0, + obs_norm_state_dict=None, ): """Creates the typical transforms for and env.""" env = TransformedEnv(env) @@ -125,11 +126,17 @@ def make_env_transforms( env.append_transform(FlattenObservation(0)) env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"])) if stats is None: - obs_stats = {"loc": 0.0, "scale": 1.0} + obs_stats = { + "loc": torch.zeros(env.observation_spec["pixels"].shape), + "scale": torch.ones(env.observation_spec["pixels"].shape), + } else: obs_stats = stats obs_stats["standard_normal"] = True - 
env.append_transform(ObservationNorm(**obs_stats, in_keys=["pixels"])) + obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"]) + if obs_norm_state_dict: + obs_norm.load_state_dict(obs_norm_state_dict) + env.append_transform(obs_norm) if norm_rewards: reward_scaling = 1.0 reward_loc = 0.0 @@ -162,12 +169,18 @@ def make_env_transforms( if not vecnorm: if stats is None: - _stats = {"loc": 0.0, "scale": 1.0} + _stats = { + "loc": torch.zeros(env.observation_spec[out_key].shape), + "scale": torch.ones(env.observation_spec[out_key].shape), + } else: _stats = stats - env.append_transform( - ObservationNorm(**_stats, in_keys=[out_key], standard_normal=True) + obs_norm = ObservationNorm( + **_stats, in_keys=[out_key], standard_normal=True ) + if obs_norm_state_dict: + obs_norm.load_state_dict(obs_norm_state_dict) + env.append_transform(obs_norm) else: env.append_transform( VecNorm( @@ -201,6 +214,7 @@ def make_env_transforms( ) env.append_transform(FiniteTensorDictCheck()) + return env @@ -217,6 +231,7 @@ def transformed_env_constructor( action_dim_gsde: Optional[int] = None, state_dim_gsde: Optional[int] = None, batch_dims: Optional[int] = 0, + obs_norm_state_dict: Optional[dict] = None, ) -> Union[Callable, EnvCreator]: """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. @@ -246,6 +261,8 @@ def transformed_env_constructor( batch_dims (int, optional): number of dimensions of a batch of data. If a single env is used, it should be 0 (default). If multiple envs are being transformed in parallel, it should be set to 1 (or the number of dims of the batch). 
+ obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the + environment """ def make_transformed_env(**kwargs) -> TransformedEnv: @@ -306,6 +323,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv: action_dim_gsde, state_dim_gsde, batch_dims=batch_dims, + obs_norm_state_dict=obs_norm_state_dict, ) if use_env_creator: @@ -372,7 +390,7 @@ def get_stats_random_rollout( proof_env_is_none = proof_environment is None if proof_env_is_none: proof_environment = transformed_env_constructor( - cfg=cfg, use_env_creator=False + cfg=cfg, use_env_creator=False, stats={"loc": 0.0, "scale": 1.0} )() print("computing state stats") @@ -428,6 +446,79 @@ def get_stats_random_rollout( return stats +def initialize_observation_norm_transforms( + proof_environment: EnvBase, + num_iter: int = 1000, + key: Union[str, Tuple[str, ...]] = None, +): + """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`. + + If an :obj:`ObservationNorm` already has non-null :obj:`loc` or :obj:`scale`, a call to :obj:`initialize_observation_norm_transforms` will be a no-op. + Similarly, if the transformed environment does not contain any :obj:`ObservationNorm`, a call to this function will have no effect. + If no key is provided but the observations of the :obj:`EnvBase` contain more than one key, an exception will + be raised. + + Args: + proof_environment (EnvBase instance): the environment that will + be used to execute the rollouts and whose ObservationNorm + transforms will be initialized. + num_iter (int): Number of iterations used for initializing the :obj:`ObservationNorms` + key (str, optional): if provided, the stats of this key will be gathered. + If not, it is expected that only one key exists in `env.observation_spec`.
+ + """ + if not isinstance(proof_environment.transform, Compose) and not isinstance( + proof_environment.transform, ObservationNorm + ): + return + + if key is None: + keys = list(proof_environment.base_env.observation_spec.keys()) + key = keys.pop() + if len(keys): + raise RuntimeError( + f"More than one key exists in the observation_specs: {[key] + keys} were found, " + "thus initialize_observation_norm_transforms cannot infer which to compute the stats of." + ) + + if isinstance(proof_environment.transform, Compose): + for transform in proof_environment.transform: + if ( + isinstance(transform, ObservationNorm) + and transform.loc is None + and transform.scale is None + ): + transform.init_stats(num_iter=num_iter, key=key) + elif ( + proof_environment.transform.loc is None + and proof_environment.transform.scale is None + ): + proof_environment.transform.init_stats(num_iter=num_iter, key=key) + + +def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv): + """Traverses the transforms of the environment and retrieves the :obj:`ObservationNorm` state dicts. 
+ + Returns a list of tuples (idx, state_dict) for each :obj:`ObservationNorm` transform in proof_environment + If the environment transforms do not contain any :obj:`ObservationNorm`, returns an empty list + + Args: + proof_environment (TransformedEnv instance): the :obj:`TransformedEnv` to retrieve the :obj:`ObservationNorm` + state dict from + """ + obs_norm_state_dicts = [] + + if isinstance(proof_environment.transform, Compose): + for idx, transform in enumerate(proof_environment.transform): + if isinstance(transform, ObservationNorm): + obs_norm_state_dicts.append((idx, transform.state_dict())) + + if isinstance(proof_environment.transform, ObservationNorm): + obs_norm_state_dicts.append((0, proof_environment.transform.state_dict())) + + return obs_norm_state_dicts + + @dataclass class EnvConfig: """Environment config struct.""" diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index e9e57eadb45..5c5b3b5499e 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -1465,7 +1465,7 @@ def make_dreamer( action_key: str = "action", value_key: str = "state_value", use_decoder_in_env: bool = False, - stats: Optional[dict] = None, + obs_norm_state_dict=None, ) -> nn.ModuleList: """Create Dreamer components. @@ -1480,8 +1480,8 @@ def make_dreamer( Defaults to "state_value". use_decoder_in_env (bool, optional): Whether to use the decoder in the model based dreamer env. Defaults to False. - stats (Optional[dict], optional): Stats to use for normalization. - Defaults to None. + obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform used + when proof_environment is missing. Defaults to None. Returns: nn.TensorDictModel: Dreamer World model.
@@ -1494,7 +1494,7 @@ def make_dreamer( proof_env_is_none = proof_environment is None if proof_env_is_none: proof_environment = transformed_env_constructor( - cfg=cfg, use_env_creator=False, stats=stats + cfg=cfg, use_env_creator=False, obs_norm_state_dict=obs_norm_state_dict )() # Modules