35 changes: 20 additions & 15 deletions examples/ppo/config.yaml
@@ -1,26 +1,31 @@
env_name: HalfCheetah-v4
env_task: ""
env_library: gym
env_name: cheetah
env_task: run
env_library: dm_control
async_collection: 0
record_video: 1
normalize_rewards_online: 1
normalize_rewards_online_scale: 5
frame_skip: 1
normalize_rewards_online_scale: 1
normalize_rewards_online_decay: 0.0 # a decay of 0.0 normalizes rewards using only the latest collected batch
frame_skip: 2
frames_per_batch: 1000
optim_steps_per_batch: 10
batch_size: 256
optim_steps_per_batch: 20
batch_size: 256 # each training batch will have 256 frames
sub_traj_len: 64 # we want 4 trajectories of length 64 = 256 frames
total_frames: 1000000
lr: 3e-4
lr: 2e-4
from_pixels: 0
#collector_devices: [cuda:1]
collector_devices: [cpu]
env_per_collector: 4
num_workers: 4
lr_scheduler: ""
record_interval: 100
max_frames_per_traj: -1
env_per_collector: 2
num_workers: 2
lr_scheduler: cosine
record_interval: 20
max_frames_per_traj: 1000
weight_decay: 0.0
init_env_steps: 10000
record_frames: 50000
init_env_steps: 1000
record_frames: 1000
loss_function: smooth_l1
batch_transform: 1
entropy_coef: 0.1
default_policy_scale: 1.0
advantage_in_loss: 1
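As a quick orientation, here is a hedged sanity check of the batch arithmetic implied by the new values, assuming the fields are consumed the way the inline comments describe:

```python
# Back-of-the-envelope check; names mirror the fields in config.yaml.
frames_per_batch = 1000          # frames collected per iteration
total_frames = 1_000_000         # total collection budget
optim_steps_per_batch = 20       # optimization steps per collected batch
batch_size = 256                 # frames per training batch
sub_traj_len = 64                # length of each sampled sub-trajectory

collection_iters = total_frames // frames_per_batch            # 1000 collection iterations
trajs_per_train_batch = batch_size // sub_traj_len             # 4 sub-trajectories per batch
total_optim_steps = collection_iters * optim_steps_per_batch   # 20000 gradient updates

print(collection_iters, trajs_per_train_batch, total_optim_steps)
```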
19 changes: 19 additions & 0 deletions examples/ppo/ppo.py
@@ -13,6 +13,7 @@
from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.objectives import GAE
from torchrl.record import VideoRecorder
from torchrl.trainers.helpers.collectors import (
make_collector_onpolicy,
@@ -173,6 +174,24 @@ def main(cfg: "DictConfig"):
if cfg.loss == "kl":
trainer.register_op("pre_optim_steps", loss_module.reset)

if not cfg.advantage_in_loss:
critic_model = model.get_value_operator()
advantage = GAE(
cfg.gamma,
cfg.lmbda,
value_network=critic_model,
average_rewards=True,
gradient_mode=False,
)
trainer.register_op(
"process_optim_batch",
advantage,
)
trainer._process_optim_batch_ops = [
trainer._process_optim_batch_ops[-1],
*trainer._process_optim_batch_ops[:-1],
]

final_seed = collector.set_seed(cfg.seed)
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

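The list rotation above is the interesting part: `register_op` appends, so the freshly registered GAE hook would otherwise run after the sub-sampler that `make_trainer` already placed on `"process_optim_batch"`. Moving the last entry to the front lets GAE see full trajectories before sub-sampling. A minimal sketch of the effect, with placeholder op names rather than the real trainer internals:

```python
# Placeholder op names; the real entries are bound callables on the trainer.
process_optim_batch_ops = ["subsample_batch", "normalize_reward", "compute_gae"]

# Same rotation as in the diff: move the last-registered op to the front.
process_optim_batch_ops = [process_optim_batch_ops[-1], *process_optim_batch_ops[:-1]]
assert process_optim_batch_ops == ["compute_gae", "subsample_batch", "normalize_reward"]
```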
12 changes: 3 additions & 9 deletions torchrl/objectives/costs/ppo.py
@@ -10,7 +10,6 @@
from torch import distributions as d

from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict
from torchrl.envs.utils import step_tensordict
from torchrl.modules import TensorDictModule
from ...modules.tensordict_module import ProbabilisticTensorDictModule

@@ -129,19 +128,14 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
loss_function=self.loss_critic_type,
)
else:
with torch.no_grad():
reward = tensordict.get("reward")
next_td = step_tensordict(tensordict)
next_value = self.critic(
next_td, params=self.critic_params, buffers=self.critic_buffers
).get("state_value")
value_target = reward + next_value * self.gamma
tensordict_select = tensordict.select(*self.critic.in_keys).clone()
advantage = tensordict.get(self.advantage_key)
tensordict_select = tensordict.select(*self.critic.in_keys)
value = self.critic(
tensordict_select,
params=self.critic_params,
buffers=self.critic_buffers,
).get("state_value")
value_target = advantage + value.detach()
loss_value = distance_loss(
value, value_target, loss_function=self.loss_critic_type
)
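The replaced branch swaps a one-step bootstrapped target, reward + gamma * V(s'), for a target built from the precomputed advantage: since GAE estimates A_t ≈ G_t^λ − V(s_t), adding back the detached value recovers the λ-return as a regression target. A small standalone sketch of that identity (shapes and names are illustrative, not the library code):

```python
import torch
import torch.nn.functional as F

value = torch.randn(256, 1, requires_grad=True)    # V(s_t) from the critic
advantage = torch.randn(256, 1)                     # A_t written by the GAE module
value_target = advantage + value.detach()           # ≈ lambda-return, gradient-free
loss_value = F.smooth_l1_loss(value, value_target)  # matches loss_function: smooth_l1
loss_value.backward()                               # gradients flow only through `value`
```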
19 changes: 12 additions & 7 deletions torchrl/trainers/helpers/losses.py
@@ -193,13 +193,16 @@ def make_ppo_loss(model, cfg) -> PPOLoss:
actor_model = model.get_policy_operator()
critic_model = model.get_value_operator()

advantage = GAE(
cfg.gamma,
cfg.lmbda,
value_network=critic_model,
average_rewards=True,
gradient_mode=False,
)
if cfg.advantage_in_loss:
advantage = GAE(
cfg.gamma,
cfg.lmbda,
value_network=critic_model,
average_rewards=True,
gradient_mode=False,
)
else:
advantage = None
loss_module = loss_dict[cfg.loss](
actor=actor_model,
critic=critic_model,
@@ -245,3 +248,5 @@ class PPOLossConfig:
# Entropy factor for the PPO loss
loss_function: str = "smooth_l1"
# loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
advantage_in_loss: bool = False
# if True, the advantage is computed by the loss module on each sampled sub-batch; if False, it is computed by the trainer on the full collected batch before sub-sampling.
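In short, the new `advantage_in_loss` flag picks where GAE runs: inside the loss module (recomputed on every sub-batch) or on the trainer (on the full collected batch, as wired up in examples/ppo/ppo.py above). A hedged sketch of the branch; `build_ppo_loss` and the `advantage_module` keyword are illustrative stand-ins for the real helper and signature:

```python
from torchrl.objectives import GAE

def build_ppo_loss(cfg, actor_model, critic_model, loss_cls):
    if cfg.advantage_in_loss:
        # GAE inside the loss: advantages are recomputed on each sampled sub-batch.
        advantage = GAE(cfg.gamma, cfg.lmbda, value_network=critic_model,
                        average_rewards=True, gradient_mode=False)
    else:
        # GAE on the trainer: the loss receives precomputed advantages instead.
        advantage = None
    return loss_cls(actor=actor_model, critic=critic_model, advantage_module=advantage)
```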
12 changes: 8 additions & 4 deletions torchrl/trainers/helpers/trainers.py
@@ -24,7 +24,6 @@
ReplayBufferTrainer,
LogReward,
RewardNormalizer,
mask_batch,
BatchSubSampler,
UpdateWeights,
Recorder,
@@ -70,7 +69,9 @@ class TrainerConfig:
normalize_rewards_online: bool = False
# Computes the running statistics of the rewards and normalizes them before they are passed to the loss module.
normalize_rewards_online_scale: float = 1.0
# Final value of the normalized rewards.
# Final scale of the normalized rewards.
normalize_rewards_online_decay: float = 0.9999
# Decay of the reward moving average.
sub_traj_len: int = -1
# length of the trajectories that sub-samples must have in online settings.

@@ -218,7 +219,7 @@ def make_trainer(
trainer.register_op("process_optim_batch", rb_trainer.sample)
trainer.register_op("post_loss", rb_trainer.update_priority)
else:
trainer.register_op("batch_process", mask_batch)
# trainer.register_op("batch_process", mask_batch)
trainer.register_op(
"process_optim_batch",
BatchSubSampler(batch_size=cfg.batch_size, sub_traj_len=cfg.sub_traj_len),
@@ -233,7 +234,10 @@
if cfg.normalize_rewards_online:
# if used, the running statistics of the rewards are computed and the
# rewards used for training will be normalized based on these.
reward_normalizer = RewardNormalizer(scale=cfg.normalize_rewards_online_scale)
reward_normalizer = RewardNormalizer(
scale=cfg.normalize_rewards_online_scale,
decay=cfg.normalize_rewards_online_decay,
)
trainer.register_op("batch_process", reward_normalizer.update_reward_stats)
trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward)

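The helper now forwards the new decay knob to `RewardNormalizer`. A hedged sketch of the intended semantics, written as a plain exponential-moving-average normalizer rather than torchrl's actual class: `decay=0.0` (the value in the new config) keeps only the latest batch statistics, while a decay close to 1.0 maintains a long-running average.

```python
import torch

class EMARewardNormalizer:
    """Illustrative stand-in for RewardNormalizer(scale=..., decay=...)."""

    def __init__(self, scale: float = 1.0, decay: float = 0.9999, eps: float = 1e-6):
        self.scale, self.decay, self.eps = scale, decay, eps
        self.mean, self.var = 0.0, 1.0

    def update_reward_stats(self, reward: torch.Tensor) -> None:
        # decay=0.0 -> statistics come entirely from the latest batch
        self.mean = self.decay * self.mean + (1 - self.decay) * reward.mean().item()
        self.var = self.decay * self.var + (1 - self.decay) * reward.var().item()

    def normalize_reward(self, reward: torch.Tensor) -> torch.Tensor:
        return self.scale * (reward - self.mean) / (self.var ** 0.5 + self.eps)

normalizer = EMARewardNormalizer(scale=1.0, decay=0.0)  # mirrors the updated config
rewards = torch.randn(1000)
normalizer.update_reward_stats(rewards)
normalized = normalizer.normalize_reward(rewards)
```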
2 changes: 1 addition & 1 deletion torchrl/trainers/trainers.py
@@ -871,7 +871,7 @@ def __init__(
frame_skip: int,
policy_exploration: TensorDictModule,
recorder: EnvBase,
exploration_mode: str = "mode",
exploration_mode: str = "mean",
log_keys: Optional[List[str]] = None,
out_keys: Optional[Dict[str, str]] = None,
suffix: Optional[str] = None,
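The recorder's default exploration mode moves from "mode" to "mean". For a tanh-squashed Gaussian policy (the usual continuous-control head) the two can differ: the action commonly used as the "mode" is tanh(loc), while the mean of the squashed distribution generally lands elsewhere. A plain-torch illustration of that gap (not torchrl internals):

```python
import torch

loc, scale = torch.tensor(1.5), torch.tensor(1.0)
base = torch.distributions.Normal(loc, scale)
squashed_samples = torch.tanh(base.sample((100_000,)))

print("tanh(loc) ('mode'-style action):", torch.tanh(loc).item())
print("empirical mean of tanh(Normal):", squashed_samples.mean().item())
```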