Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions examples/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,16 @@ def main(cfg: "DictConfig"): # noqa: F821
cfg=cfg,
)

if not cfg.advantage_in_loss:
critic_model = model.get_value_operator()
advantage = TDEstimate(
cfg.gamma,
value_network=critic_model,
average_rewards=True,
gradient_mode=False,
)
advantage = advantage.to(device)
trainer.register_op(
"process_optim_batch",
advantage,
)
critic_model = model.get_value_operator()
advantage = TDEstimate(
cfg.gamma,
value_network=critic_model,
average_rewards=True,
)
trainer.register_op(
"process_optim_batch",
advantage,
)

final_seed = collector.set_seed(cfg.seed)
print(f"init seed: {cfg.seed}, final seed: {final_seed}")
Expand Down
1 change: 0 additions & 1 deletion examples/a2c/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ gamma: 0.99
entropy_coef: 0.01 # Entropy factor for the A2C loss
critic_coef: 0.25 # Critic factor for the A2C loss
critic_loss_function: l2 # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
advantage_in_loss: False # if True, the advantage is computed on the sub-batch

# Trainer
optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data.
Expand Down
1 change: 0 additions & 1 deletion examples/ppo/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,3 @@ loss_function: smooth_l1
batch_transform: 1
entropy_coef: 0.1
default_policy_scale: 1.0
advantage_in_loss: 1
32 changes: 15 additions & 17 deletions examples/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,23 +169,21 @@ def main(cfg: "DictConfig"): # noqa: F821
if cfg.loss == "kl":
trainer.register_op("pre_optim_steps", loss_module.reset)

if not cfg.advantage_in_loss:
critic_model = model.get_value_operator()
advantage = GAE(
cfg.gamma,
cfg.lmbda,
value_network=critic_model,
average_rewards=True,
gradient_mode=False,
)
trainer.register_op(
"process_optim_batch",
advantage,
)
trainer._process_optim_batch_ops = [
trainer._process_optim_batch_ops[-1],
*trainer._process_optim_batch_ops[:-1],
]
critic_model = model.get_value_operator()
advantage = GAE(
cfg.gamma,
cfg.lmbda,
value_network=critic_model,
average_gae=True,
)
trainer.register_op(
"process_optim_batch",
lambda tensordict: advantage(tensordict.to(device)),
)
trainer._process_optim_batch_ops = [
trainer._process_optim_batch_ops[-1],
*trainer._process_optim_batch_ops[:-1],
]

final_seed = collector.set_seed(cfg.seed)
print(f"init seed: {cfg.seed}, final seed: {final_seed}")
Expand Down
Loading