torchrl/trainers/helpers/losses.py (52 additions & 8 deletions)
@@ -222,13 +222,33 @@ def make_ppo_loss(model, cfg) -> PPOLoss:
         )
     else:
         advantage = None
-    loss_module = loss_dict[cfg.loss](
-        actor=actor_model,
-        critic=critic_model,
-        advantage_module=advantage,
-        loss_critic_type=cfg.loss_function,
-        entropy_coef=cfg.entropy_coef,
-    )
+
+    kwargs = {
+        "actor": actor_model,
+        "critic": critic_model,
+        "advantage_module": advantage,
+        "loss_critic_type": cfg.loss_function,
+        "entropy_coef": cfg.entropy_coef,
+    }
+
+    if cfg.loss == "clip":
+        kwargs.update(
+            {
+                "clip_epsilon": cfg.clip_epsilon,
+            }
+        )
+    elif cfg.loss == "kl":
+        kwargs.update(
+            {
+                "dtarg": cfg.dtarg,
+                "beta": cfg.beta,
+                "increment": cfg.increment,
+                "decrement": cfg.decrement,
+                "samples_mc_kl": cfg.samples_mc_kl,
+            }
+        )
+
+    loss_module = loss_dict[cfg.loss](**kwargs)
     return loss_module
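The new construction keeps the shared constructor arguments in a single kwargs dict and only branches on the loss-specific ones before indexing loss_dict (which, per the config comment below, presumably maps "clip" and "kl" to ClipPPOLoss and KLPENPPOLoss). A minimal, self-contained sketch of that dispatch pattern with toy stand-in classes, not part of the diff:

# Toy illustration of the **kwargs dispatch used above; ToyClipLoss and ToyKLLoss
# are stand-ins for the torchrl loss classes, not the real ones.
class ToyClipLoss:
    def __init__(self, actor, critic, clip_epsilon):
        self.kind = f"clip(eps={clip_epsilon})"

class ToyKLLoss:
    def __init__(self, actor, critic, dtarg, beta):
        self.kind = f"kl(dtarg={dtarg}, beta={beta})"

toy_loss_dict = {"clip": ToyClipLoss, "kl": ToyKLLoss}

def make_toy_loss(loss, actor, critic):
    kwargs = {"actor": actor, "critic": critic}   # shared arguments
    if loss == "clip":
        kwargs["clip_epsilon"] = 0.2              # clip-specific
    elif loss == "kl":
        kwargs.update(dtarg=0.01, beta=1.0)       # kl-specific
    return toy_loss_dict[loss](**kwargs)

print(make_toy_loss("kl", actor=None, critic=None).kind)  # kl(dtarg=0.01, beta=1.0)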


@@ -279,13 +299,37 @@ class PPOLossConfig:

    loss: str = "clip"
    # PPO loss class, either clip or kl or base/<empty>. Default=clip

    # PPOLoss base parameters:
    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    lmbda: float = 0.95
    # lambda factor in GAE (using 'lambda' as an attribute name is prohibited in Python, hence the misspelling)
    entropy_bonus: bool = True
    # Whether or not to add an entropy term to the PPO loss.
    entropy_coef: float = 1e-3
    # Entropy factor for the PPO loss
    samples_mc_entropy: int = 1
    # Number of samples to use for a Monte-Carlo estimate if the policy distribution has no closed-form entropy.
    loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
    advantage_in_loss: bool = False
    # if True, the advantage is computed on the sub-batch.
    critic_coef: float = 1.0
    # Critic loss multiplier when computing the total loss.

    # ClipPPOLoss parameters:
    clip_epsilon: float = 0.2
    # weight clipping threshold in the clipped PPO loss equation.

    # KLPENPPOLoss parameters:
    dtarg: float = 0.01
    # target KL divergence.
    beta: float = 1.0
    # initial KL divergence multiplier.
    increment: float = 2
    # how much beta should be incremented if KL > dtarg. Valid range: increment >= 1.0
    decrement: float = 0.5
    # how much beta should be decremented if KL < dtarg. Valid range: decrement <= 1.0
    samples_mc_kl: int = 1
    # Number of samples to use for a Monte-Carlo estimate of KL if necessary
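For reference, the gamma and lmbda fields above parameterize the Generalized Advantage Estimation used to compute the advantage signal. A minimal, self-contained sketch of the GAE recursion (illustrative only, not the torchrl GAE module):

# GAE over one trajectory (sketch): delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
# and A_t = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}.
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages

print(gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0, 0, 1]))  # toy check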
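The KLPENPPOLoss fields above describe an adaptive KL penalty: beta scales the KL term, is grown by increment whenever the measured KL exceeds dtarg, and is shrunk by decrement when it falls below it. A hedged, stand-alone sketch of that update rule (torchrl's exact trigger condition may differ):

# Adaptive KL-penalty coefficient, following the description in the comments above
# ("incremented if KL > dtarg", "decremented if KL < dtarg"); sketch only.
def update_beta(beta, kl, dtarg=0.01, increment=2.0, decrement=0.5):
    if kl > dtarg:
        beta *= increment   # increment >= 1.0: penalize the policy update more
    elif kl < dtarg:
        beta *= decrement   # decrement <= 1.0: allow larger policy updates
    return beta

beta = 1.0
for kl in (0.05, 0.02, 0.004):   # toy sequence of observed KL values
    beta = update_beta(beta, kl)
    print(beta)                   # 2.0, 4.0, 2.0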