fix(pytorch): correct epoch-based learning rate decay behavior (#410)
This commit fixes a bug in the epoch-based learning rate decay mechanism. The
schedulers were previously sized for one more decay step than the number of
epochs, so the decay stopped short of the specified final learning rate. With
this fix, the learning rate reaches the intended final value by the last epoch.
rickstaa committed Feb 19, 2024
1 parent 4b1154c commit a8df90f
Showing 2 changed files with 17 additions and 19 deletions.
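
The schedulers in this repository come from its get_lr_scheduler helper, whose implementation is not part of this commit, so the snippet below is only an illustration that assumes a linear decay type; the helper lr_after and all numbers in it are invented for the demonstration. It shows the off-by-one described above: a schedule sized for epochs + 1 decay steps but stepped only epochs times (apparently the old epoch-based behaviour, given the description) stops short of the final learning rate, while one sized for exactly epochs steps reaches it.

import torch


def lr_after(decay_steps, num_step_calls, lr_init=1e-3, lr_final=1e-5):
    """Learning rate left in an optimizer after `num_step_calls` scheduler steps
    of a linear schedule configured to span `decay_steps` decay steps."""
    opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=lr_init)

    def factor(step):  # linear interpolation from 1.0 down to lr_final / lr_init
        return 1.0 - (1.0 - lr_final / lr_init) * min(step, decay_steps) / decay_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, factor)
    for _ in range(num_step_calls):
        opt.step()  # no gradients here; included only to keep the usual step order
        scheduler.step()
    return opt.param_groups[0]["lr"]


epochs = 5
print(lr_after(epochs + 1, epochs))  # old epoch-based sizing: ~1.75e-4, short of 1e-5
print(lr_after(epochs, epochs))      # fixed sizing: 1e-5, the intended final value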
17 changes: 9 additions & 8 deletions stable_learning_control/algos/pytorch/lac/lac.py
@@ -724,8 +724,8 @@ def bound_lr(
             if self._c_optimizer.param_groups[0]["lr"] < lr_c_final:
                 self._c_optimizer.param_groups[0]["lr"] = lr_c_final
         if lr_alpha_final is not None:
-            if self._log_alpha_optimizer.param_groups[0]["lr"] < lr_a_final:
-                self._log_alpha_optimizer.param_groups[0]["lr"] = lr_a_final
+            if self._log_alpha_optimizer.param_groups[0]["lr"] < lr_alpha_final:
+                self._log_alpha_optimizer.param_groups[0]["lr"] = lr_alpha_final
         if lr_labda_final is not None:
             if self._log_labda_optimizer.param_groups[0]["lr"] < lr_labda_final:
                 self._log_labda_optimizer.param_groups[0]["lr"] = lr_labda_final
@@ -1216,32 +1216,33 @@ def lac(
     # Calculate the number of learning rate scheduler steps.
     if lr_decay_ref == "step":
         # NOTE: Decay applied at policy update to improve performance.
-        lr_decay_steps = (total_steps - update_after) / update_every
+        lr_decay_steps = (
+            total_steps - update_after
+        ) / update_every + 1  # NOTE: +1 since we start at the initial learning rate.
     else:
         lr_decay_steps = epochs

     # Setup learning rate schedulers.
-    # NOTE: +1 since we start at the initial learning rate.
     opt_schedulers = {
         "pi": get_lr_scheduler(
-            policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
+            policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps
         ),
         "c": get_lr_scheduler(
-            policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
+            policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps
         ),
         "alpha": get_lr_scheduler(
             policy._log_alpha_optimizer,
             lr_decay_type,
             lr_a,
             lr_a_final,
-            lr_decay_steps + 1,
+            lr_decay_steps,
         ),
         "lambda": get_lr_scheduler(
             policy._log_labda_optimizer,
             lr_decay_type,
             lr_a,
             lr_a_final,
-            lr_decay_steps + 1,
+            lr_decay_steps,
         ),
     }

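What get_lr_scheduler returns is likewise outside this diff. As a rough stand-in, again assuming a linear decay type, make_linear_lr_scheduler below implements the contract the call sites above now rely on: the returned scheduler reaches lr_final after exactly lr_decay_steps calls to step() (one per epoch in epoch-based decay mode), and a bound_lr-style clamp (first hunk above) keeps floating-point drift from pushing the rate below that floor. The names lr_a, lr_a_final, and epochs echo the diff; their values here are arbitrary.

import torch


def make_linear_lr_scheduler(optimizer, lr_start, lr_final, lr_decay_steps):
    """Illustrative stand-in for get_lr_scheduler (linear decay only): lowers the
    learning rate from lr_start to lr_final over lr_decay_steps step() calls and
    holds it there afterwards."""

    def factor(step):
        frac = min(step, lr_decay_steps) / lr_decay_steps
        return 1.0 - (1.0 - lr_final / lr_start) * frac

    return torch.optim.lr_scheduler.LambdaLR(optimizer, factor)


lr_a, lr_a_final, epochs = 1e-4, 1e-9, 100
pi_optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=lr_a)
scheduler = make_linear_lr_scheduler(pi_optimizer, lr_a, lr_a_final, lr_decay_steps=epochs)
for _ in range(epochs):  # one decay step per epoch, hence lr_decay_steps == epochs
    pi_optimizer.step()  # no gradients here; included only to keep the usual step order
    scheduler.step()

# bound_lr-style clamp from the first hunk: keep rounding error from pushing
# the rate below the configured final value.
if pi_optimizer.param_groups[0]["lr"] < lr_a_final:
    pi_optimizer.param_groups[0]["lr"] = lr_a_final

print(pi_optimizer.param_groups[0]["lr"])  # ~1e-9: the configured final value
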
19 changes: 8 additions & 11 deletions stable_learning_control/algos/pytorch/sac/sac.py
@@ -626,8 +626,8 @@ def bound_lr(self, lr_a_final=None, lr_c_final=None, lr_alpha_final=None):
             if self._c_optimizer.param_groups[0]["lr"] < lr_c_final:
                 self._c_optimizer.param_groups[0]["lr"] = lr_c_final
         if lr_alpha_final is not None:
-            if self._log_alpha_optimizer.param_groups[0]["lr"] < lr_a_final:
-                self._log_alpha_optimizer.param_groups[0]["lr"] = lr_a_final
+            if self._log_alpha_optimizer.param_groups[0]["lr"] < lr_alpha_final:
+                self._log_alpha_optimizer.param_groups[0]["lr"] = lr_alpha_final

     def _update_targets(self):
         """Updates the target networks based on a Exponential moving average
@@ -1067,25 +1067,22 @@ def sac(
     # Calculate the number of learning rate scheduler steps.
     if lr_decay_ref == "step":
         # NOTE: Decay applied at policy update to improve performance.
-        lr_decay_steps = (total_steps - update_after) / update_every
+        lr_decay_steps = (
+            total_steps - update_after
+        ) / update_every + 1  # NOTE: +1 since we start at the initial learning rate.
     else:
         lr_decay_steps = epochs

     # Setup learning rate schedulers.
-    # NOTE: +1 since we start at the initial learning rate.
     opt_schedulers = {
         "pi": get_lr_scheduler(
-            policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
+            policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps
         ),
         "c": get_lr_scheduler(
-            policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
+            policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps
         ),
         "alpha": get_lr_scheduler(
-            policy._log_alpha_optimizer,
-            lr_decay_type,
-            lr_a,
-            lr_a_final,
-            lr_decay_steps + 1,
+            policy._log_alpha_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps
         ),
     }

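The sac.py change mirrors the lac.py one: the +1 now lives inside the step-based branch it belongs to, while the epoch-based branch passes epochs through unchanged. With made-up hyperparameter values (none of these numbers come from the repository), the two branches size the schedule as follows:

# Hypothetical hyperparameters, purely to show how each branch sizes the schedule
# after this commit.
total_steps, update_after, update_every, epochs = 50_000, 1_000, 50, 100

# "step" mode: decay is applied at every policy update, plus one extra step
# because the schedule starts from the initial learning rate (per the NOTE).
lr_decay_steps_step_mode = (total_steps - update_after) / update_every + 1  # 981.0

# Epoch mode: one decay step per epoch, so the count is simply `epochs`; the
# extra +1 previously applied here is what kept the decay from reaching lr_final.
lr_decay_steps_epoch_mode = epochs  # 100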
