fix(pytorch): ensure correct application of constant learning rate (#411)

This commit fixes an issue with the constant learning rate scheduler: previously,
the constant learning rate was not applied as expected when selected by the user.
With this fix, the scheduler correctly maintains a constant learning rate
throughout training.
rickstaa committed Feb 19, 2024
1 parent a8df90f commit 2b3693e
Showing 2 changed files with 40 additions and 27 deletions.
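A possible reading of the underlying issue, judging from the diff below: the constant schedule was previously returned as a plain torch.optim.lr_scheduler.LambdaLR, which type-based checks elsewhere in the module (e.g. in estimate_step_learning_rate) could not tell apart from the linearly decaying scheduler that is also built on LambdaLR, so a decaying rate could be estimated even for the constant case. The dedicated ConstantLRScheduler subclass makes the constant case distinguishable. A minimal sketch of that ambiguity, using only stock PyTorch APIs (illustrative only, not code from the commit):

import torch

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# A "constant" schedule expressed as a plain LambdaLR with a 1.0 multiplier is
# indistinguishable by type from other LambdaLR-based (e.g. linearly decaying) schedules.
constant = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
print(isinstance(constant, torch.optim.lr_scheduler.LambdaLR))  # True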
65 changes: 39 additions & 26 deletions stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
@@ -5,6 +5,20 @@
 import numpy as np
 import torch
 
+import torch.optim
+
+
+class ConstantLRScheduler(torch.optim.lr_scheduler.LambdaLR):
+    """A learning rate scheduler that keeps the learning rate constant."""
+
+    def __init__(self, optimizer):
+        """Initialize the constant learning rate scheduler.
+
+        Args:
+            optimizer (:class:`torch.optim.Optimizer`): The wrapped optimizer.
+        """
+        super().__init__(optimizer, lr_lambda=lambda step: np.longdouble(1.0))
+
 
 def get_exponential_decay_rate(lr_start, lr_final, steps):
     """Calculates the exponential decay rate needed to go from a initial learning rate
@@ -56,7 +70,7 @@ def get_lr_scheduler(optimizer, decaying_lr_type, lr_start, lr_final, steps):
         lr_start (float): Initial learning rate.
         lr_final (float): Final learning rate.
         steps (int, optional): Number of steps/epochs used in the training. This
-            includes the starting step.
+            includes the starting step/epoch.
 
     Returns:
         :obj:`torch.optim.lr_scheduler`: A learning rate scheduler object.
@@ -95,9 +109,7 @@ def lr_multiplier_function(step):
         )
         return lr_scheduler
     else:
-        return torch.optim.lr_scheduler.LambdaLR(
-            optimizer, lr_lambda=lambda step: np.longdouble(1.0)
-        )  # Return a constant function.
+        return ConstantLRScheduler(optimizer)
 
 
 def estimate_step_learning_rate(
@@ -122,27 +134,28 @@ def estimate_step_learning_rate(
     Returns:
         float: The learning rate at the given step.
     """
-    if step < update_after:
+    if step < update_after or isinstance(lr_scheduler, ConstantLRScheduler):
         return lr_start
+
+    # Estimate the learning rate at a given step for the lt_scheduler type.
+    adjusted_step = step - update_after
+    adjusted_total_steps = total_steps - update_after
+    if isinstance(lr_scheduler, torch.optim.lr_scheduler.LambdaLR):
+        decay_rate = get_linear_decay_rate(lr_start, lr_final, adjusted_total_steps)
+        lr = float(
+            Decimal(lr_start) * (Decimal(1.0) - decay_rate * Decimal(adjusted_step))
+        )
+    elif isinstance(lr_scheduler, torch.optim.lr_scheduler.ExponentialLR):
+        decay_rate = get_exponential_decay_rate(
+            lr_start, lr_final, adjusted_total_steps
+        )
+        lr = float(
+            Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step))
+        )
     else:
-        adjusted_step = step - update_after
-        adjusted_total_steps = total_steps - update_after
-        if isinstance(lr_scheduler, torch.optim.lr_scheduler.LambdaLR):
-            decay_rate = get_linear_decay_rate(lr_start, lr_final, adjusted_total_steps)
-            lr = float(
-                Decimal(lr_start) * (Decimal(1.0) - decay_rate * Decimal(adjusted_step))
-            )
-        elif isinstance(lr_scheduler, torch.optim.lr_scheduler.ExponentialLR):
-            decay_rate = get_exponential_decay_rate(
-                lr_start, lr_final, adjusted_total_steps
-            )
-            lr = float(
-                Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step))
-            )
-        else:
-            supported_schedulers = ["LambdaLR", "ExponentialLR"]
-            raise ValueError(
-                f"The learning rate scheduler is not supported for this function. "
-                f"Supported schedulers are: {', '.join(supported_schedulers)}"
-            )
-        return max(lr, lr_final)
+        supported_schedulers = ["LambdaLR", "ExponentialLR"]
+        raise ValueError(
+            f"The learning rate scheduler is not supported for this function. "
+            f"Supported schedulers are: {', '.join(supported_schedulers)}"
+        )
+    return max(lr, lr_final)
@@ -16,7 +16,7 @@ def get_lr_scheduler(decaying_lr_type, lr_start, lr_final, steps):
         lr_start (float): Initial learning rate.
         lr_final (float): Final learning rate.
         steps (int, optional): Number of steps/epochs used in the training. This
-            includes the starting step.
+            includes the starting step/epoch.
 
     Returns:
         tensorflow.keras.optimizers.schedules.LearningRateSchedule: A learning rate
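For reference, a minimal usage sketch of the new constant scheduler (illustrative only, not part of the commit); it assumes the stable_learning_control package is installed so the class can be imported from the module changed above:

import torch

from stable_learning_control.algos.pytorch.common.get_lr_scheduler import (
    ConstantLRScheduler,
)

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = ConstantLRScheduler(optimizer)

for _ in range(5):
    optimizer.step()  # normally preceded by a backward pass
    scheduler.step()
    # The LambdaLR multiplier is always 1.0, so the learning rate stays at 1e-3.
    print(scheduler.get_last_lr())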
