diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py
index 92e304a895f..f4ae7a9a842 100644
--- a/torchrl/trainers/helpers/losses.py
+++ b/torchrl/trainers/helpers/losses.py
@@ -222,13 +222,33 @@ def make_ppo_loss(model, cfg) -> PPOLoss:
         )
     else:
         advantage = None
-    loss_module = loss_dict[cfg.loss](
-        actor=actor_model,
-        critic=critic_model,
-        advantage_module=advantage,
-        loss_critic_type=cfg.loss_function,
-        entropy_coef=cfg.entropy_coef,
-    )
+
+    kwargs = {
+        "actor": actor_model,
+        "critic": critic_model,
+        "advantage_module": advantage,
+        "loss_critic_type": cfg.loss_function,
+        "entropy_coef": cfg.entropy_coef,
+    }
+
+    if cfg.loss == "clip":
+        kwargs.update(
+            {
+                "clip_epsilon": cfg.clip_epsilon,
+            }
+        )
+    elif cfg.loss == "kl":
+        kwargs.update(
+            {
+                "dtarg": cfg.dtarg,
+                "beta": cfg.beta,
+                "increment": cfg.increment,
+                "decrement": cfg.decrement,
+                "samples_mc_kl": cfg.samples_mc_kl,
+            }
+        )
+
+    loss_module = loss_dict[cfg.loss](**kwargs)
     return loss_module
 
 
@@ -279,13 +299,37 @@ class PPOLossConfig:
 
     loss: str = "clip"
     # PPO loss class, either clip or kl or base/. Default=clip
+
+    # PPOLoss base parameters:
     gamma: float = 0.99
     # Decay factor for return computation. Default=0.99.
     lmbda: float = 0.95
     # lambda factor in GAE (using 'lambda' as attribute is prohibited in python, hence the misspelling)
+    entropy_bonus: bool = True
+    # Whether or not to add an entropy term to the PPO loss.
     entropy_coef: float = 1e-3
     # Entropy factor for the PPO loss
+    samples_mc_entropy: int = 1
+    # Number of samples to use for a Monte-Carlo estimate if the policy distribution has not closed formula.
     loss_function: str = "smooth_l1"
     # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
     advantage_in_loss: bool = False
-    # if True, the advantage is computed on the sub-batch.,
+    # if True, the advantage is computed on the sub-batch.
+    critic_coef: float = 1.0
+    # Critic loss multiplier when computing the total loss.
+
+    # ClipPPOLoss parameters:
+    clip_epsilon: float = 0.2
+    # weight clipping threshold in the clipped PPO loss equation.
+
+    # KLPENPPOLoss parameters:
+    dtarg: float = 0.01
+    # target KL divergence.
+    beta: float = 1.0
+    # initial KL divergence multiplier.
+    increment: float = 2
+    # how much beta should be incremented if KL > dtarg. Valid range: increment >= 1.0
+    decrement: float = 0.5
+    # how much beta should be decremented if KL < dtarg. Valid range: decrement <= 1.0
+    samples_mc_kl: int = 1
+    # Number of samples to use for a Monte-Carlo estimate of KL if necessary
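
For reference, the kwargs-dispatch pattern that this diff introduces in make_ppo_loss can be exercised in isolation. The sketch below is a simplified, self-contained illustration, not part of the patch: the Dummy* classes, the Cfg dataclass and the make_loss helper are stand-ins invented for this example and are not the real torchrl PPOLoss/ClipPPOLoss/KLPENPPOLoss modules, and only cfg.loss, entropy_coef, clip_epsilon, dtarg and beta are modelled.

from dataclasses import dataclass


class DummyPPOLoss:
    # Stand-in for a base PPO loss: accepts only the shared arguments.
    def __init__(self, actor, critic, entropy_coef=1e-3):
        self.actor, self.critic, self.entropy_coef = actor, critic, entropy_coef


class DummyClipPPOLoss(DummyPPOLoss):
    # Stand-in for the clipped variant: adds clip_epsilon.
    def __init__(self, actor, critic, entropy_coef=1e-3, clip_epsilon=0.2):
        super().__init__(actor, critic, entropy_coef)
        self.clip_epsilon = clip_epsilon


class DummyKLPENPPOLoss(DummyPPOLoss):
    # Stand-in for the KL-penalty variant: adds dtarg and beta.
    def __init__(self, actor, critic, entropy_coef=1e-3, dtarg=0.01, beta=1.0):
        super().__init__(actor, critic, entropy_coef)
        self.dtarg, self.beta = dtarg, beta


@dataclass
class Cfg:
    # Hypothetical config struct mirroring a subset of PPOLossConfig.
    loss: str = "clip"
    entropy_coef: float = 1e-3
    clip_epsilon: float = 0.2
    dtarg: float = 0.01
    beta: float = 1.0


def make_loss(actor, critic, cfg: Cfg):
    loss_dict = {"clip": DummyClipPPOLoss, "kl": DummyKLPENPPOLoss, "base": DummyPPOLoss}
    # Arguments shared by every loss class.
    kwargs = {"actor": actor, "critic": critic, "entropy_coef": cfg.entropy_coef}
    # Forward only the config fields the selected class actually accepts.
    if cfg.loss == "clip":
        kwargs["clip_epsilon"] = cfg.clip_epsilon
    elif cfg.loss == "kl":
        kwargs.update({"dtarg": cfg.dtarg, "beta": cfg.beta})
    return loss_dict[cfg.loss](**kwargs)


loss = make_loss(actor="actor_net", critic="critic_net", cfg=Cfg(loss="kl"))
assert isinstance(loss, DummyKLPENPPOLoss) and loss.beta == 1.0

Because the base class does not accept clip_epsilon or dtarg, passing one flat dict with every config field would raise a TypeError for loss="base"; filtering the kwargs on cfg.loss is what lets a single config struct drive all three loss classes.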