Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
df26500
Refactor stats computation in example scripts
romainjln Nov 24, 2022
15b9bbb
Merge branch 'main' into init-stats-models and resolve conflict
romainjln Nov 25, 2022
a8fcee9
lint
romainjln Nov 25, 2022
215e57b
Change logic for computing stats using ObservationNorm in examples
romainjln Nov 28, 2022
6ba6ea7
Lint and fix import
romainjln Nov 28, 2022
48ea215
Merge branch 'main' into init-stats-models
romainjln Nov 28, 2022
09d093b
Removing redundant lines
romainjln Nov 28, 2022
7351654
Fix typo
romainjln Nov 28, 2022
88780fb
Adding test and fixing code
romainjln Nov 28, 2022
93e611a
lint
romainjln Nov 29, 2022
e7be8f8
Refactor generate_stats_from_observation_norms to handle many Observa…
romainjln Nov 30, 2022
d702bfd
Merge branch 'main' into init-stats-models
romainjln Nov 30, 2022
85ace78
Fixing tests
romainjln Nov 30, 2022
fc0c7e5
Add decorator to test. Fix dreamer example
romainjln Dec 1, 2022
106f31c
Adding missing parameter in dreamer script
romainjln Dec 1, 2022
5b697e0
lint
romainjln Dec 1, 2022
160abda
Modify dreamer_utils to match new behavior of make_env_transforms
romainjln Dec 1, 2022
7ff3007
Fixing stats issue in dreamer
romainjln Dec 1, 2022
46935ea
Refactoring example logic to use state_dict of ObservationNorm transform
romainjln Dec 6, 2022
6f28b66
Merge branch 'main' into init-stats-models
romainjln Dec 6, 2022
2a5dbf0
Fix new logic in dreamer_utils
romainjln Dec 6, 2022
17b1b15
Fix dreamer helper function following previous refactoring
romainjln Dec 6, 2022
4908082
Adding tests. Modify docstrings based on feedback
romainjln Dec 7, 2022
12a878a
More modifications of docstrings
romainjln Dec 7, 2022
15273e5
Add more tests for helpers functions
romainjln Dec 7, 2022
90e5374
Merge branch 'pytorch:main' into init-stats-models
romainjln Dec 7, 2022
0b2b915
Adding test for transformed_env_constructor
romainjln Dec 7, 2022
6a247bd
Merge branch 'pytorch:main' into init-stats-models
romainjln Dec 8, 2022
e83922a
Adding more test for init_stats. Remove redundant line of code
romainjln Dec 8, 2022
60ad3c4
Merge branch 'init-stats-models' of github.com:romainjln/rl into init…
romainjln Dec 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions examples/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
get_stats_random_rollout,
initialize_observation_norm_transforms,
parallel_env_constructor,
retrieve_observation_norms_state_dict,
transformed_env_constructor,
)
from torchrl.trainers.helpers.logger import LoggerConfig
Expand Down Expand Up @@ -92,21 +93,24 @@ def main(cfg: "DictConfig"): # noqa: F821
)
video_tag = exp_name if cfg.record_video else ""

stats = None
key, init_env_steps, stats = None, None, None
if not cfg.vecnorm and cfg.norm_stats:
proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
stats = get_stats_random_rollout(
cfg,
proof_env,
key="pixels" if cfg.from_pixels else "observation_vector",
)
# make sure proof_env is closed
proof_env.close()
if not hasattr(cfg, "init_env_steps"):
raise AttributeError("init_env_steps missing from arguments.")
key = "pixels" if cfg.from_pixels else "observation_vector"
init_env_steps = cfg.init_env_steps
stats = {"loc": None, "scale": None}
elif cfg.from_pixels:
stats = {"loc": 0.5, "scale": 0.5}
proof_env = transformed_env_constructor(
cfg=cfg, use_env_creator=False, stats=stats
cfg=cfg,
use_env_creator=False,
stats=stats,
)()
initialize_observation_norm_transforms(
proof_environment=proof_env, num_iter=init_env_steps, key=key
)
_, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0]

model = make_a2c_model(
proof_env,
Expand All @@ -128,7 +132,7 @@ def main(cfg: "DictConfig"): # noqa: F821
proof_env.close()
create_env_fn = parallel_env_constructor(
cfg=cfg,
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
action_dim_gsde=action_dim_gsde,
state_dim_gsde=state_dim_gsde,
)
Expand All @@ -143,7 +147,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cfg,
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
logger=logger,
use_env_creator=False,
)()
Expand Down
34 changes: 19 additions & 15 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
get_stats_random_rollout,
initialize_observation_norm_transforms,
parallel_env_constructor,
retrieve_observation_norms_state_dict,
transformed_env_constructor,
)
from torchrl.trainers.helpers.logger import LoggerConfig
Expand Down Expand Up @@ -105,23 +106,25 @@ def main(cfg: "DictConfig"): # noqa: F821
)
video_tag = exp_name if cfg.record_video else ""

stats = None
key, init_env_steps, stats = None, None, None
if not cfg.vecnorm and cfg.norm_stats:
proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
stats = get_stats_random_rollout(
cfg,
proof_env,
key=("next", "pixels")
if cfg.from_pixels
else ("next", "observation_vector"),
)
# make sure proof_env is closed
proof_env.close()
if not hasattr(cfg, "init_env_steps"):
raise AttributeError("init_env_steps missing from arguments.")
key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector")
init_env_steps = cfg.init_env_steps
stats = {"loc": None, "scale": None}
elif cfg.from_pixels:
stats = {"loc": 0.5, "scale": 0.5}

proof_env = transformed_env_constructor(
cfg=cfg, use_env_creator=False, stats=stats
cfg=cfg,
stats=stats,
use_env_creator=False,
)()
initialize_observation_norm_transforms(
proof_environment=proof_env, num_iter=init_env_steps, key=key
)
_, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0]

model = make_ddpg_actor(
proof_env,
Expand Down Expand Up @@ -154,9 +157,10 @@ def main(cfg: "DictConfig"): # noqa: F821
action_dim_gsde, state_dim_gsde = None, None

proof_env.close()

create_env_fn = parallel_env_constructor(
cfg=cfg,
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
action_dim_gsde=action_dim_gsde,
state_dim_gsde=state_dim_gsde,
)
Expand All @@ -177,7 +181,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cfg,
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
logger=logger,
use_env_creator=False,
)()
Expand Down
33 changes: 18 additions & 15 deletions examples/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
get_stats_random_rollout,
initialize_observation_norm_transforms,
parallel_env_constructor,
retrieve_observation_norms_state_dict,
transformed_env_constructor,
)
from torchrl.trainers.helpers.logger import LoggerConfig
Expand Down Expand Up @@ -95,23 +96,25 @@ def main(cfg: "DictConfig"): # noqa: F821
)
video_tag = exp_name if cfg.record_video else ""

stats = None
key, init_env_steps, stats = None, None, None
if not cfg.vecnorm and cfg.norm_stats:
proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
stats = get_stats_random_rollout(
cfg,
proof_env,
key=("next", "pixels")
if cfg.from_pixels
else ("next", "observation_vector"),
)
# make sure proof_env is closed
proof_env.close()
if not hasattr(cfg, "init_env_steps"):
raise AttributeError("init_env_steps missing from arguments.")
key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector")
init_env_steps = cfg.init_env_steps
stats = {"loc": None, "scale": None}
elif cfg.from_pixels:
stats = {"loc": 0.5, "scale": 0.5}
proof_env = transformed_env_constructor(
cfg=cfg, use_env_creator=False, stats=stats
cfg=cfg,
use_env_creator=False,
stats=stats,
)()
initialize_observation_norm_transforms(
proof_environment=proof_env, num_iter=init_env_steps, key=key
)
_, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0]

model = make_dqn_actor(
proof_environment=proof_env,
cfg=cfg,
Expand All @@ -127,7 +130,7 @@ def main(cfg: "DictConfig"): # noqa: F821
proof_env.close()
create_env_fn = parallel_env_constructor(
cfg=cfg,
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
action_dim_gsde=action_dim_gsde,
state_dim_gsde=state_dim_gsde,
)
Expand All @@ -148,7 +151,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cfg,
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
logger=logger,
)()

Expand Down
55 changes: 34 additions & 21 deletions examples/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@
)
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
get_stats_random_rollout,
initialize_observation_norm_transforms,
retrieve_observation_norms_state_dict,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.models import DreamerConfig, make_dreamer
from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
from torchrl.trainers.helpers.trainers import TrainerConfig
from torchrl.trainers.trainers import Recorder, RewardNormalizer


config_fields = [
(config_field.name, config_field.type, config_field)
for config_cls in (
Expand All @@ -61,6 +61,13 @@
cs.store(name="config", node=Config)


def retrieve_stats_from_state_dict(obs_norm_state_dict):
    """Pull the normalization statistics out of an observation-norm state dict.

    Args:
        obs_norm_state_dict: mapping that holds at least the ``"loc"`` and
            ``"scale"`` entries (presumably the ``state_dict`` of an
            ``ObservationNorm`` transform -- confirm against the caller).

    Returns:
        A new dict containing exactly the ``"loc"`` and ``"scale"`` entries,
        in that order. The values are the very same objects stored in the
        input mapping; nothing is copied.

    Raises:
        KeyError: if ``"loc"`` or ``"scale"`` is absent from a plain dict.
    """
    # Only these two entries are forwarded; any other key is ignored.
    wanted = ("loc", "scale")
    return {name: obs_norm_state_dict[name] for name in wanted}


@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821

Expand Down Expand Up @@ -113,30 +120,35 @@ def main(cfg: "DictConfig"): # noqa: F821

video_tag = f"Dreamer_{cfg.env_name}_policy_test" if cfg.record_video else ""

stats = None

# Compute the stats of the observations
key, init_env_steps, stats = None, None, None
if not cfg.vecnorm and cfg.norm_stats:
stats = get_stats_random_rollout(
cfg,
proof_environment=transformed_env_constructor(cfg)(),
key=("next", "pixels")
if cfg.from_pixels
else ("next", "observation_vector"),
)
stats = {k: v.clone() for k, v in stats.items()}
if not hasattr(cfg, "init_env_steps"):
raise AttributeError("init_env_steps missing from arguments.")
key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector")
init_env_steps = cfg.init_env_steps
stats = {"loc": None, "scale": None}
elif cfg.from_pixels:
stats = {"loc": 0.5, "scale": 0.5}
proof_env = transformed_env_constructor(
cfg=cfg, use_env_creator=False, stats=stats
)()
initialize_observation_norm_transforms(
proof_environment=proof_env, num_iter=init_env_steps, key=key
)
_, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0]
proof_env.close()

# Create the different components of dreamer
world_model, model_based_env, actor_model, value_model, policy = make_dreamer(
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
cfg=cfg,
device=device,
use_decoder_in_env=True,
action_key="action",
value_key="state_value",
proof_environment=transformed_env_constructor(cfg)(),
proof_environment=transformed_env_constructor(
cfg, stats={"loc": 0.0, "scale": 1.0}
)(),
)

# reward normalization
Expand Down Expand Up @@ -178,7 +190,7 @@ def main(cfg: "DictConfig"): # noqa: F821
action_dim_gsde, state_dim_gsde = None, None
create_env_fn = parallel_env_constructor(
cfg=cfg,
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
action_dim_gsde=action_dim_gsde,
state_dim_gsde=state_dim_gsde,
)
Expand All @@ -203,11 +215,11 @@ def main(cfg: "DictConfig"): # noqa: F821
frame_skip=cfg.frame_skip,
policy_exploration=policy,
recorder=make_recorder_env(
cfg,
video_tag,
stats,
logger,
create_env_fn,
cfg=cfg,
video_tag=video_tag,
obs_norm_state_dict=obs_norm_state_dict,
logger=logger,
create_env_fn=create_env_fn,
),
record_interval=cfg.record_interval,
log_keys=cfg.recorder_log_keys,
Expand Down Expand Up @@ -371,6 +383,7 @@ def main(cfg: "DictConfig"): # noqa: F821
if j == cfg.optim_steps_per_batch - 1:
do_log = False

stats = retrieve_stats_from_state_dict(obs_norm_state_dict)
call_record(
logger,
record,
Expand Down
19 changes: 15 additions & 4 deletions examples/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def make_env_transforms(
action_dim_gsde,
state_dim_gsde,
batch_dims=0,
obs_norm_state_dict=None,
):
env = TransformedEnv(env)

Expand Down Expand Up @@ -91,11 +92,17 @@ def make_env_transforms(
env.append_transform(FlattenObservation(0))
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"]))
if stats is None:
obs_stats = {"loc": 0.0, "scale": 1.0}
obs_stats = {
"loc": torch.zeros(env.observation_spec["pixels"].shape),
"scale": torch.ones(env.observation_spec["pixels"].shape),
}
else:
obs_stats = stats
obs_stats["standard_normal"] = True
env.append_transform(ObservationNorm(**obs_stats, in_keys=["pixels"]))
obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"])
if obs_norm_state_dict:
obs_norm.load_state_dict(obs_norm_state_dict)
env.append_transform(obs_norm)
if norm_rewards:
reward_scaling = 1.0
reward_loc = 0.0
Expand Down Expand Up @@ -141,6 +148,7 @@ def transformed_env_constructor(
action_dim_gsde: Optional[int] = None,
state_dim_gsde: Optional[int] = None,
batch_dims: Optional[int] = 0,
obs_norm_state_dict: Optional[dict] = None,
) -> Union[Callable, EnvCreator]:
"""
Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
Expand Down Expand Up @@ -171,6 +179,8 @@ def transformed_env_constructor(
batch_dims (int, optional): number of dimensions of a batch of data. If a single env is
used, it should be 0 (default). If multiple envs are being transformed in parallel,
it should be set to 1 (or the number of dims of the batch).
obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded
into the environment
"""

def make_transformed_env(**kwargs) -> TransformedEnv:
Expand Down Expand Up @@ -226,6 +236,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
action_dim_gsde,
state_dim_gsde,
batch_dims=batch_dims,
obs_norm_state_dict=obs_norm_state_dict,
)

if use_env_creator:
Expand Down Expand Up @@ -335,12 +346,12 @@ def grad_norm(optimizer: torch.optim.Optimizer):
return sum_of_sq.sqrt().detach().item()


def make_recorder_env(cfg, video_tag, stats, logger, create_env_fn):
def make_recorder_env(cfg, video_tag, obs_norm_state_dict, logger, create_env_fn):
recorder = transformed_env_constructor(
cfg,
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
obs_norm_state_dict=obs_norm_state_dict,
logger=logger,
use_env_creator=False,
)()
Expand Down
Loading