Skip to content

Commit

Permalink
fix(tf2): correct off-by-one error in learning rate decay calculation (#415)
Browse files Browse the repository at this point in the history

This commit resolves an issue that led to incorrect learning rate decay.
The root cause was an off-by-one error in the step count, which skewed
the decay calculation. With this fix, the learning rate now decays
accurately according to the specified schedule.
  • Loading branch information
rickstaa committed Feb 19, 2024
1 parent 27964fe commit 6ab5001
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 29 deletions.
22 changes: 7 additions & 15 deletions stable_learning_control/algos/tf2/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,9 @@ def lac(
# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
lr_decay_steps = (
total_steps - update_after
) / update_every + 1 # NOTE: +1 since we start at the initial learning rate.
else:
lr_decay_steps = epochs

Expand All @@ -1110,16 +1112,6 @@ def lac(
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)

# Create step based learning rate schedulers.
# NOTE: Used to estimate the learning rate at each step.
if lr_decay_ref == "step":
lr_a_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
)
lr_c_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
)

# Restore policy if supplied.
if start_policy is not None:
logger.log(f"Restoring model from '{start_policy}'.", type="info")
Expand Down Expand Up @@ -1303,10 +1295,10 @@ def lac(
# Retrieve current learning rates.
if lr_decay_ref == "step":
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)
lr_critic = lr_c_step_scheduler(progress)
lr_alpha = lr_a_step_scheduler(progress)
lr_labda = lr_a_step_scheduler(progress)
lr_actor = lr_a_scheduler(progress)
lr_critic = lr_c_scheduler(progress)
lr_alpha = lr_a_scheduler(progress)
lr_labda = lr_a_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
Expand Down
20 changes: 6 additions & 14 deletions stable_learning_control/algos/tf2/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,9 @@ def sac(
# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
lr_decay_steps = (
total_steps - update_after
) / update_every + 1 # NOTE: +1 since we start at the initial learning rate.
else:
lr_decay_steps = epochs

Expand All @@ -977,16 +979,6 @@ def sac(
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)

# Create step based learning rate schedulers.
# NOTE: Used to estimate the learning rate at each step.
if lr_decay_ref == "step":
lr_a_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
)
lr_c_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
)

# Restore policy if supplied.
if start_policy is not None:
logger.log(f"Restoring model from '{start_policy}'.", type="info")
Expand Down Expand Up @@ -1138,9 +1130,9 @@ def sac(
# Retrieve current learning rates.
if lr_decay_ref == "step":
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)
lr_critic = lr_c_step_scheduler(progress)
lr_alpha = lr_a_step_scheduler(progress)
lr_actor = lr_a_scheduler(progress)
lr_critic = lr_c_scheduler(progress)
lr_alpha = lr_a_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
Expand Down

0 comments on commit 6ab5001

Please sign in to comment.