Skip to content

Commit

Permalink
fix(pytorch): correct step-based learning rate decay (#405)
Browse files Browse the repository at this point in the history
This commit addresses an issue with the step-based learning rate decay
mechanism when `lr_decay_ref` is set to 'step'. Previously, the learning
rate was decaying too rapidly due to a bug in the decay logic. This fix
ensures that the learning rate decays at the correct pace as per the
step-based decay configuration.
  • Loading branch information
rickstaa committed Feb 11, 2024
1 parent 99126f7 commit 7d7ac76
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 80 deletions.
68 changes: 59 additions & 9 deletions stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,24 @@ def get_exponential_decay_rate(lr_start, lr_final, steps):
return gamma


def calc_linear_decay_rate(lr_init, lr_final, steps):
r"""Returns the linear decay factor (G) needed to achieve a given final learning
rate at a certain step. This decay factor can for example be used with a
:class:`torch.optim.lr_scheduler.LambdaLR` scheduler. Keep in mind that this
function assumes the following formula for the learning rate decay.
def get_linear_decay_rate(lr_init, lr_final, steps):
r"""Returns a linear decay factor (G) that enables a learning rate to transition
from an initial value (`lr_init`) at step 0 to a final value (`lr_final`) at a
specified step (N). This decay factor is compatible with the
:class:`torch.optim.lr_scheduler.LambdaLR` scheduler. The decay factor is calculated
using the following formula:
.. math::
lr_{terminal} = lr_{init} * (1.0 - G \cdot step)
Args:
lr_init (float): The initial learning rate.
lr_final (float): The final learning rate you want to achieve.
steps (int): The step/epoch at which you want to achieve this learning rate.
steps (int): The number of steps/epochs over which the learning rate should
decay. This is equal to epochs - 1.
Returns:
decimal.Decimal: Linear learning rate decay factor (G)
decimal.Decimal: Linear learning rate decay factor (G).
""" # noqa: W605
return -(
((Decimal(lr_final) / Decimal(lr_init)) - Decimal(1.0)) / Decimal(max(steps, 1))
Expand All @@ -53,7 +55,7 @@ def get_lr_scheduler(optimizer, decaying_lr_type, lr_start, lr_final, steps):
(options are: ``linear`` and ``exponential`` and ``constant``).
lr_start (float): Initial learning rate.
lr_final (float): Final learning rate.
steps (int, optional): Number of steps/epochs used in the training. This
steps (int, optional): Number of steps/epochs used in the training. This
includes the starting step.
Returns:
Expand Down Expand Up @@ -83,7 +85,7 @@ def lr_multiplier_function(step):
return np.longdouble(
Decimal(1.0)
- (
calc_linear_decay_rate(lr_start, lr_final, (steps - 1.0))
get_linear_decay_rate(lr_start, lr_final, (steps - 1.0))
* Decimal(step)
)
)
Expand All @@ -96,3 +98,51 @@ def lr_multiplier_function(step):
return torch.optim.lr_scheduler.LambdaLR(
optimizer, lr_lambda=lambda step: np.longdouble(1.0)
) # Return a constant function.


def estimate_step_learning_rate(
lr_scheduler, lr_start, lr_final, update_after, total_steps, step
):
"""Estimates the learning rate at a given step.
This function estimates the learning rate for a specific training step. It differs
from the `get_last_lr` method of the learning rate scheduler, which returns the
learning rate at the last scheduler step, not necessarily the current training step.
Args:
lr_scheduler (torch.optim.lr_scheduler): The learning rate scheduler.
lr_start (float): The initial learning rate.
update_after (int): The step number after which the learning rate should start
decreasing.
lr_final (float): The final learning rate.
total_steps (int): The total number of steps/epochs in the training process.

Check warning on line 118 in stable_learning_control/algos/pytorch/common/get_lr_scheduler.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] stable_learning_control/algos/pytorch/common/get_lr_scheduler.py#L118 <291>

trailing whitespace
Raw output
./stable_learning_control/algos/pytorch/common/get_lr_scheduler.py:118:85: W291 trailing whitespace
Excludes the initial step.
step (int): The current step number. Excludes the initial step.
Returns:
float: The learning rate at the given step.
"""
if step < update_after:
return lr_start
else:
adjusted_step = step - update_after
adjusted_total_steps = total_steps - update_after
if isinstance(lr_scheduler, torch.optim.lr_scheduler.LambdaLR):
decay_rate = get_linear_decay_rate(lr_start, lr_final, adjusted_total_steps)
lr = float(
Decimal(lr_start) * (Decimal(1.0) - decay_rate * Decimal(adjusted_step))
)
elif isinstance(lr_scheduler, torch.optim.lr_scheduler.ExponentialLR):
decay_rate = get_exponential_decay_rate(
lr_start, lr_final, adjusted_total_steps
)
lr = float(
Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step))
)
else:
supported_schedulers = ["LambdaLR", "ExponentialLR"]
raise ValueError(
f"The learning rate scheduler is not supported for this function. "
f"Supported schedulers are: {', '.join(supported_schedulers)}"
)
return max(lr, lr_final)
147 changes: 107 additions & 40 deletions stable_learning_control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from stable_learning_control.algos.pytorch.common.get_lr_scheduler import (
get_lr_scheduler,
estimate_step_learning_rate,
)
from stable_learning_control.algos.pytorch.common.helpers import (
count_vars,
Expand Down Expand Up @@ -1111,7 +1112,7 @@ def lac(
actor_critic = LyapunovActorCritic if actor_critic is None else actor_critic

# Ensure the environment is correctly seeded.
# NOTE: Done here since we donote:n't want to seed on every env.reset() call.
# NOTE: Done here since we don't want to seed on every env.reset() call.
if seed is not None:
env.np_random, _ = seeding.np_random(seed)
env.action_space.seed(seed)
Expand Down Expand Up @@ -1197,29 +1198,51 @@ def lac(
logger.log("Network structure:\n", type="info")
logger.log(policy.ac, end="\n\n")

# Create learning rate schedulers.
opt_schedulers = []
lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
pi_opt_scheduler = get_lr_scheduler(
policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var
)
opt_schedulers.append(pi_opt_scheduler)
alpha_opt_scheduler = get_lr_scheduler(
policy._log_alpha_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var
)
opt_schedulers.append(alpha_opt_scheduler)
c_opt_scheduler = get_lr_scheduler(
policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var
)
opt_schedulers.append(c_opt_scheduler)
labda_opt_scheduler = get_lr_scheduler(
policy._log_labda_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_decay_ref_var,
)
opt_schedulers.append(labda_opt_scheduler)
# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
type="warning",
)
lr_decay_ref = "epoch"

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
else:
lr_decay_steps = epochs

# Setup learning rate schedulers.
# NOTE: +1 since we start at the initial learning rate.
opt_schedulers = {
"pi": get_lr_scheduler(
policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
),
"c": get_lr_scheduler(
policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
),
"alpha": get_lr_scheduler(
policy._log_alpha_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_decay_steps + 1,
),
"lambda": get_lr_scheduler(
policy._log_labda_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_decay_steps + 1,
),
}

logger.setup_pytorch_saver(policy)

Expand Down Expand Up @@ -1253,6 +1276,7 @@ def lac(
"Entropy",
]
if use_tensorboard:
# NOTE: TensorBoard counts from 0.
logger.log_to_tb(
"Lr_a",
policy._pi_optimizer.param_groups[0]["lr"],
Expand Down Expand Up @@ -1321,16 +1345,18 @@ def lac(
logger.store(**update_diagnostics) # Log diagnostics.

# Step based learning rate decay.
if lr_decay_ref.lower() == "step":
for scheduler in opt_schedulers:
if lr_decay_ref == "step":
for scheduler in opt_schedulers.values():
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
) # Make sure lr is bounded above the final lr.

# SGD batch tb logging.
if use_tensorboard and not tb_low_log_freq:
logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
logger.log_to_tb(
keys=diag_tb_log_list, global_step=t
) # NOTE: TensorBoard counts from 0.

# End of epoch handling (Save model, test performance and log data)
if (t + 1) % steps_per_epoch == 0:
Expand All @@ -1349,17 +1375,50 @@ def lac(
extend=True,
)

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
for scheduler in opt_schedulers:
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
) # Make sure lr is bounded above the final lr.
# Retrieve current learning rates.
if lr_decay_ref == "step":
# NOTE: Estimate since 'step' decay is applied at policy update.
lr_actor = estimate_step_learning_rate(
opt_schedulers["pi"],
lr_a,
lr_a_final,
update_after,
total_steps,
t + 1,
)
lr_critic = estimate_step_learning_rate(
opt_schedulers["c"],
lr_c,
lr_c_final,
update_after,
total_steps,
t + 1,
)
lr_alpha = estimate_step_learning_rate(
opt_schedulers["alpha"],
lr_a,
lr_a_final,
update_after,
total_steps,
t + 1,
)
lr_labda = estimate_step_learning_rate(
opt_schedulers["lambda"],
lr_a,
lr_a_final,
update_after,
total_steps,
t + 1,
)
else:
lr_actor = policy._pi_optimizer.param_groups[0]["lr"]
lr_critic = policy._c_optimizer.param_groups[0]["lr"]
lr_alpha = policy._log_alpha_optimizer.param_groups[0]["lr"]
lr_labda = policy._log_labda_optimizer.param_groups[0]["lr"]

# Log info about epoch.
logger.log_tabular("Epoch", epoch)
logger.log_tabular("TotalEnvInteracts", t)
logger.log_tabular("TotalEnvInteracts", t + 1)
logger.log_tabular(
"EpRet",
with_min_and_max=True,
Expand All @@ -1379,25 +1438,25 @@ def lac(
)
logger.log_tabular(
"Lr_a",
policy._pi_optimizer.param_groups[0]["lr"],
lr_actor,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_c",
policy._c_optimizer.param_groups[0]["lr"],
lr_critic,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_alpha",
policy._log_alpha_optimizer.param_groups[0]["lr"],
lr_alpha,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_labda",
policy._log_labda_optimizer.param_groups[0]["lr"],
lr_labda,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
Expand Down Expand Up @@ -1440,7 +1499,15 @@ def lac(
tb_write=(use_tensorboard and tb_low_log_freq),
)
logger.log_tabular("Time", time.time() - start_time)
logger.dump_tabular(global_step=t)
logger.dump_tabular(global_step=t) # NOTE: TensorBoard counts from 0.

# Epoch based learning rate decay.
if lr_decay_ref != "step":
for scheduler in opt_schedulers.values():
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
) # Make sure lr is bounded above the final lr.

# Export model to 'TorchScript'
if export:
Expand Down

0 comments on commit 7d7ac76

Please sign in to comment.