Merged

34 commits
7cc7c61  a2c (albertbou92, Nov 14, 2022)
5eac813  a2c (albertbou92, Nov 14, 2022)
880fed3  a2c config (albertbou92, Nov 15, 2022)
4c2436c  a2c config (albertbou92, Nov 15, 2022)
2b77432  fix imports (albertbou92, Nov 15, 2022)
ea4c3a2  latest (albertbou92, Nov 15, 2022)
5f4d290  simplified config (albertbou92, Nov 15, 2022)
e08dd5d  simplified config (albertbou92, Nov 15, 2022)
82db357  Update config.yaml (albertbou92, Nov 15, 2022)
afca8f4  [BugFix] Use GitHub for flake8 pre-commit hook (#679) (vmoens, Nov 15, 2022)
ea83339  [BugFix] Update to strict select (#675) (vmoens, Nov 15, 2022)
0bc21da  [Feature] Auto-compute stats for ObservationNorm (#669) (romainjln, Nov 15, 2022)
8765ac9  [Doc] _make_collector helper function (#678) (albertbou92, Nov 15, 2022)
14fbac9  [Doc] BatchSubSampler class docstrings example (#677) (albertbou92, Nov 15, 2022)
bcdb0bc  [BugFix] PPO objective crashes if advantage_module is None (#676) (albertbou92, Nov 15, 2022)
fbb0e9f  Minor: lint (vmoens, Nov 15, 2022)
26dcbcf  [Refactor] Refactor 'next_' into nested tensordicts (#649) (vmoens, Nov 16, 2022)
1479497  adapted to nested next td (albertbou92, Nov 17, 2022)
39515df  [Refactor] Refactor 'next_' into nested tensordicts (#649) (vmoens, Nov 16, 2022)
0d28c79  [Doc] More doc about environments (#683) (vmoens, Nov 17, 2022)
7c36de6  [Doc] Fix missing tensordict install for doc (#685) (vmoens, Nov 17, 2022)
1d8dc7b  Merge branch 'main' into a2c (albertbou92, Nov 18, 2022)
509276f  model config fix (albertbou92, Nov 18, 2022)
0a0fca8  formatting (albertbou92, Nov 21, 2022)
ac5b857  formatting (albertbou92, Nov 21, 2022)
34b411f  a2c runtime error comment change (albertbou92, Nov 21, 2022)
5068059  a2c test (albertbou92, Nov 21, 2022)
f04b0f9  a2c test (albertbou92, Nov 21, 2022)
f4b2289  a2c test (albertbou92, Nov 21, 2022)
e39d7ff  make a2c model test (albertbou92, Nov 22, 2022)
859ccd9  increase a2c tests coverage (albertbou92, Nov 22, 2022)
1eaa4c5  formatting (albertbou92, Nov 22, 2022)
7f78bf6  fix bug a2c testing (albertbou92, Nov 22, 2022)
60c2730  minor fixes (albertbou92, Nov 23, 2022)
190 changes: 190 additions & 0 deletions examples/a2c/a2c.py
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import os
import pathlib
import uuid
from datetime import datetime

import hydra
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs.transforms import RewardScaling
from torchrl.envs.utils import set_exploration_mode
from torchrl.objectives.value import TDEstimate
from torchrl.trainers.helpers.collectors import (
    make_collector_onpolicy,
    OnPolicyCollectorConfig,
)
from torchrl.trainers.helpers.envs import (
    correct_for_frame_skip,
    EnvConfig,
    get_stats_random_rollout,
    parallel_env_constructor,
    transformed_env_constructor,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import A2CLossConfig, make_a2c_loss
from torchrl.trainers.helpers.models import A2CModelConfig, make_a2c_model
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

config_fields = [
    (config_field.name, config_field.type, config_field)
    for config_cls in (
        TrainerConfig,
        OnPolicyCollectorConfig,
        EnvConfig,
        A2CLossConfig,
        A2CModelConfig,
        LoggerConfig,
    )
    for config_field in dataclasses.fields(config_cls)
]

Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
cs = ConfigStore.instance()
cs.store(name="config", node=Config)


@hydra.main(version_base=None, config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821

    cfg = correct_for_frame_skip(cfg)

    if not isinstance(cfg.reward_scaling, float):
        cfg.reward_scaling = 1.0

    device = (
        torch.device("cpu")
        if torch.cuda.device_count() == 0
        else torch.device("cuda:0")
    )

    exp_name = "_".join(
        [
            "A2C",
            cfg.exp_name,
            str(uuid.uuid4())[:8],
            datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
        ]
    )
    if cfg.logger == "tensorboard":
        from torchrl.trainers.loggers.tensorboard import TensorboardLogger

        logger = TensorboardLogger(log_dir="a2c_logging", exp_name=exp_name)
    elif cfg.logger == "csv":
        from torchrl.trainers.loggers.csv import CSVLogger

        logger = CSVLogger(log_dir="a2c_logging", exp_name=exp_name)
    elif cfg.logger == "wandb":
        from torchrl.trainers.loggers.wandb import WandbLogger

        logger = WandbLogger(log_dir="a2c_logging", exp_name=exp_name)
    elif cfg.logger == "mlflow":
        from torchrl.trainers.loggers.mlflow import MLFlowLogger

        logger = MLFlowLogger(
            tracking_uri=pathlib.Path(os.path.abspath("a2c_logging")).as_uri(),
            exp_name=exp_name,
        )
    video_tag = exp_name if cfg.record_video else ""

    stats = None
    if not cfg.vecnorm and cfg.norm_stats:
        proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
        stats = get_stats_random_rollout(
            cfg,
            proof_env,
            key="pixels" if cfg.from_pixels else "observation_vector",
        )
        # make sure proof_env is closed
        proof_env.close()
    elif cfg.from_pixels:
        stats = {"loc": 0.5, "scale": 0.5}
    proof_env = transformed_env_constructor(
        cfg=cfg, use_env_creator=False, stats=stats
    )()

    model = make_a2c_model(
        proof_env,
        cfg=cfg,
        device=device,
    )
    actor_model = model.get_policy_operator()

    loss_module = make_a2c_loss(model, cfg)
    if cfg.gSDE:
        with torch.no_grad(), set_exploration_mode("random"):
            # get dimensions to build the parallel env
            proof_td = model(proof_env.reset().to(device))
        action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
        del proof_td
    else:
        action_dim_gsde, state_dim_gsde = None, None

    proof_env.close()
    create_env_fn = parallel_env_constructor(
        cfg=cfg,
        stats=stats,
        action_dim_gsde=action_dim_gsde,
        state_dim_gsde=state_dim_gsde,
    )

    collector = make_collector_onpolicy(
        make_env=create_env_fn,
        actor_model_explore=actor_model,
        cfg=cfg,
    )

    recorder = transformed_env_constructor(
        cfg,
        video_tag=video_tag,
        norm_obs_only=True,
        stats=stats,
        logger=logger,
        use_env_creator=False,
    )()

    # reset reward scaling
    for t in recorder.transform:
        if isinstance(t, RewardScaling):
            t.scale.fill_(1.0)
            t.loc.fill_(0.0)

    trainer = make_trainer(
        collector=collector,
        loss_module=loss_module,
        recorder=recorder,
        target_net_updater=None,
        policy_exploration=actor_model,
        replay_buffer=None,
        logger=logger,
        cfg=cfg,
    )

    if not cfg.advantage_in_loss:
        critic_model = model.get_value_operator()
        advantage = TDEstimate(
            cfg.gamma,
            value_network=critic_model,
            average_rewards=True,
            gradient_mode=False,
        )
        advantage = advantage.to(device)
        trainer.register_op(
            "process_optim_batch",
            advantage,
        )

    final_seed = collector.set_seed(cfg.seed)
    print(f"init seed: {cfg.seed}, final seed: {final_seed}")

    trainer.train()
    return (logger.log_dir, trainer._log_dict)


if __name__ == "__main__":
    main()
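Note on the advantage computation: when advantage_in_loss is False, the script registers a TDEstimate module on the trainer's "process_optim_batch" hook, so a TD(0)-style advantage is written into each batch right before the optimization step. The snippet below is a minimal sketch of that quantity on plain tensors; it deliberately bypasses the TorchRL API, and the tensor names and shapes are assumptions made only for this illustration.

# Minimal sketch (not the TorchRL implementation) of a TD(0) advantage,
# the quantity TDEstimate supplies to the A2C loss in this example.
import torch


def td0_advantage(reward, done, value, next_value, gamma=0.99):
    # Bootstrap target: r + gamma * V(s') for non-terminal steps, r otherwise.
    target = reward + gamma * next_value * (~done).float()
    # Advantage = target - V(s); the A2C actor term is weighted by this value.
    return target - value


reward = torch.randn(64, 1)  # per-step rewards from a collected batch
done = torch.zeros(64, 1, dtype=torch.bool)  # episode-termination flags
value = torch.randn(64, 1)  # critic estimate V(s)
next_value = torch.randn(64, 1)  # critic estimate V(s')
advantage = td0_advantage(reward, done, value, next_value, gamma=0.99)
print(advantage.shape)  # torch.Size([64, 1])

Consistent with gradient_mode=False in the script, a computation like this would run without gradient tracking, so no gradient flows back through the advantage itself; the critic is trained only through the loss module's critic_coef term.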
41 changes: 41 additions & 0 deletions examples/a2c/config.yaml
@@ -0,0 +1,41 @@
# Environment
env_library: gym # env_library used for the simulated environment.
env_name: HalfCheetah-v4 # name of the environment to be created. Default=Humanoid-v2
frame_skip: 2 # frame_skip for the environment.

# Logger
logger: wandb # logger type to be used. One of 'tensorboard', 'wandb', 'csv' or 'mlflow'
record_video: False # whether a video of the task should be rendered during logging.
exp_name: A2C # experiment name. Used for logging directory.
record_interval: 100 # number of batch collections in between two collections of validation rollouts. Default=1000.

# Collector
frames_per_batch: 64 # Number of steps executed in the environment per collection.
total_frames: 2_000_000 # total number of frames collected for training. Does account for frame_skip.
num_workers: 2 # Number of workers used for data collection.
env_per_collector: 2 # Number of environments per collector.

# Model
default_policy_scale: 1.0 # Default policy scale parameter
distribution: tanh_normal # distribution used for the policy; 'tanh_normal' selects a Tanh-Normal-Tanh distribution
lstm: False # if True, uses an LSTM for the policy.
shared_mapping: False # if True, the first layers of the actor-critic are shared.

# Objective
gamma: 0.99
entropy_coef: 0.01 # Entropy factor for the A2C loss
critic_coef: 0.25 # Critic factor for the A2C loss
critic_loss_function: l2 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
advantage_in_loss: False # if True, the advantage is computed on the sub-batch inside the loss module; otherwise a TDEstimate module computes it before each optimization step

# Trainer
optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data.
optimizer: adam # Optimizer to be used.
lr_scheduler: "" # LR scheduler.
batch_size: 64 # batch size of the TensorDict retrieved from the replay buffer. Default=256.
log_interval: 1 # logging interval, in terms of optimization steps. Default=10000.
lr: 0.0007 # Learning rate used for the optimizer. Default=3e-4.
normalize_rewards_online: True # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module.
normalize_rewards_online_scale: 1.0 # Final scale of the normalized rewards.
normalize_rewards_online_decay: 0.0 # Decay of the reward moving average
sub_traj_len: 64 # length of the trajectories that sub-samples must have in online settings.
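The YAML above is read by the @hydra.main(config_path="", config_name="config") entry point in a2c.py and layered over the structured Config dataclass registered with Hydra's ConfigStore, so every key corresponds to a field of one of the helper config dataclasses (EnvConfig, OnPolicyCollectorConfig, A2CLossConfig, A2CModelConfig, TrainerConfig, LoggerConfig). As a rough illustration of that layering, the sketch below uses OmegaConf (the library underlying Hydra) directly; the toy LossDefaults dataclass and its default values are hypothetical, and only the merge order mirrors what the example does.

# Rough, self-contained illustration of structured-config + YAML + CLI overrides.
# The dataclass and its defaults are hypothetical; only the merge order mirrors Hydra.
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class LossDefaults:
    gamma: float = 0.99
    entropy_coef: float = 1e-4  # hypothetical default, overridden by the YAML values below
    critic_coef: float = 1.0  # hypothetical default, overridden by the YAML values below


base = OmegaConf.structured(LossDefaults)
yaml_overrides = OmegaConf.create({"entropy_coef": 0.01, "critic_coef": 0.25})
cli_overrides = OmegaConf.from_dotlist(["gamma=0.98"])  # like `python a2c.py gamma=0.98`
cfg = OmegaConf.merge(base, yaml_overrides, cli_overrides)
print(cfg.gamma, cfg.entropy_coef, cfg.critic_coef)  # 0.98 0.01 0.25

In the same way, launching the example with Hydra command-line overrides (for instance logger=csv or lr=0.001) replaces the corresponding keys of this file at run time.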