Skip to content

Commit

Permalink
feat(tf2): add alpha/lambda learning rate customization (#416)
Browse files Browse the repository at this point in the history
This commit enhances user control over the training process by allowing
direct customization of the alpha/lambda learning rates and their decay
rates. Users can now fine-tune these parameters to better suit their
specific training requirements.
  • Loading branch information
rickstaa committed Feb 20, 2024
1 parent 6ab5001 commit 712e94b
Show file tree
Hide file tree
Showing 3 changed files with 356 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_linear_decay_rate(lr_init, lr_final, steps):
lr_init (float): The initial learning rate.
lr_final (float): The final learning rate you want to achieve.
steps (int): The number of steps/epochs over which the learning rate should
decay. This is equal to epochs - 1.
decay. This is equal to epochs - 1.
Returns:
decimal.Decimal: Linear learning rate decay factor (G).
Expand Down
239 changes: 206 additions & 33 deletions stable_learning_control/algos/tf2/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@
"AverageLossPi",
"AverageEntropy",
]
VALID_DECAY_TYPES = ["linear", "exponential", "constant"]
VALID_DECAY_REFERENCES = ["step", "epoch"]
DEFAULT_DECAY_TYPE = "linear"
DEFAULT_DECAY_REFERENCE = "epoch"

# tf.config.run_functions_eagerly(True) # NOTE: Uncomment for debugging.

Expand Down Expand Up @@ -106,6 +110,8 @@ def __init__(
adaptive_temperature=True,
lr_a=1e-4,
lr_c=3e-4,
lr_alpha=1e-4,
lr_labda=3e-4,
device="cpu",
name="LAC",
):
Expand Down Expand Up @@ -194,6 +200,10 @@ def __init__(
``1e-4``.
lr_c (float, optional): Learning rate used for the (lyapunov) critic.
Defaults to ``3e-4``.
lr_alpha (float, optional): Learning rate used for the entropy temperature.
Defaults to ``1e-4``.
lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
multiplier. Defaults to ``3e-4``.
device (str, optional): The device the networks are placed on (options:
``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
Expand Down Expand Up @@ -258,8 +268,8 @@ def __init__(
self._alpha3 = alpha3
self._lr_a = tf.Variable(lr_a, name="Lr_a")
if self._adaptive_temperature:
self._lr_alpha = tf.Variable(lr_a, name="Lr_alpha")
self._lr_lag = tf.Variable(lr_a, name="Lr_lag")
self._lr_alpha = tf.Variable(lr_alpha, name="Lr_alpha")
self._lr_lag = tf.Variable(lr_labda, name="Lr_lag")
self._lr_c = tf.Variable(lr_c, name="Lr_c")
if not isinstance(target_entropy, (float, int)):
self._target_entropy = heuristic_target_entropy(env.action_space)
Expand Down Expand Up @@ -801,10 +811,18 @@ def lac(
adaptive_temperature=True,
lr_a=1e-4,
lr_c=3e-4,
lr_alpha=1e-4,
lr_labda=3e-4,
lr_a_final=1e-10,
lr_c_final=1e-10,
lr_decay_type="linear",
lr_decay_ref="epoch",
lr_alpha_final=1e-10,
lr_labda_final=1e-10,
lr_decay_type=DEFAULT_DECAY_TYPE,
lr_a_decay_type=None,
lr_c_decay_type=None,
lr_alpha_decay_type=None,
lr_labda_decay_type=None,
lr_decay_ref=DEFAULT_DECAY_REFERENCE,
batch_size=256,
replay_size=int(1e6),
seed=None,
Expand Down Expand Up @@ -919,10 +937,33 @@ def lac(
``1e-4``.
lr_c (float, optional): Learning rate used for the (lyapunov) critic. Defaults
to ``3e-4``.
lr_alpha (float, optional): Learning rate used for the entropy temperature.
Defaults to ``1e-4``.
lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
multiplier. Defaults to ``3e-4``.
lr_a_final(float, optional): The final actor learning rate that is achieved
at the end of the training. Defaults to ``1e-10``.
lr_c_final(float, optional): The final critic learning rate that is achieved
at the end of the training. Defaults to ``1e-10``.
lr_alpha_final(float, optional): The final alpha learning rate that is
achieved at the end of the training. Defaults to ``1e-10``.
lr_labda_final(float, optional): The final labda learning rate that is
achieved at the end of the training. Defaults to ``1e-10``.
lr_decay_type (str, optional): The learning rate decay type that is used (options
are: ``linear`` and ``exponential`` and ``constant``). Defaults to
``linear``. Can be overridden by the specific learning rate decay types.
lr_a_decay_type (str, optional): The learning rate decay type that is used for
the actor learning rate (options are: ``linear`` and ``exponential`` and
``constant``). If not specified, the general learning rate decay type is used.
lr_c_decay_type (str, optional): The learning rate decay type that is used for
the critic learning rate (options are: ``linear`` and ``exponential`` and
``constant``). If not specified, the general learning rate decay type is used.
lr_alpha_decay_type (str, optional): The learning rate decay type that is used
for the alpha learning rate (options are: ``linear`` and ``exponential``
and ``constant``). If not specified, the general learning rate decay type is used.
lr_labda_decay_type (str, optional): The learning rate decay type that is used
for the labda learning rate (options are: ``linear`` and ``exponential``
and ``constant``). If not specified, the general learning rate decay type is used.
lr_decay_type (str, optional): The learning rate decay type that is used (
options are: ``linear`` and ``exponential`` and ``constant``). Defaults to
``linear``.
Expand Down Expand Up @@ -1068,35 +1109,69 @@ def lac(
# os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" # Disable for reproducibility.

policy = LAC(
env,
actor_critic,
ac_kwargs,
opt_type,
alpha,
alpha3,
labda,
gamma,
polyak,
target_entropy,
adaptive_temperature,
lr_a,
lr_c,
device,
)

# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
env=env,
actor_critic=actor_critic,
ac_kwargs=ac_kwargs,
opt_type=opt_type,
alpha=alpha,
alpha3=alpha3,
labda=labda,
gamma=gamma,
polyak=polyak,
target_entropy=target_entropy,
adaptive_temperature=adaptive_temperature,
lr_a=lr_a,
lr_c=lr_c,
lr_alpha=lr_alpha,
lr_labda=lr_labda,
device=device,
)

# Parse learning rate decay reference.
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
if lr_decay_ref not in VALID_DECAY_REFERENCES:
options = [f"'{option}'" for option in VALID_DECAY_REFERENCES]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
f"variable has been set to '{DEFAULT_DECAY_REFERENCE}'.",
type="warning",
)
lr_decay_ref = "epoch"
lr_decay_ref = DEFAULT_DECAY_REFERENCE

# Parse learning rate decay types.
lr_decay_type = lr_decay_type.lower()
if lr_decay_type not in VALID_DECAY_TYPES:
options = [f"'{option}'" for option in VALID_DECAY_TYPES]
logger.log(
f"The learning rate decay type was set to '{lr_decay_type}', which is not "
"a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay type has been set to "
f"'{DEFAULT_DECAY_TYPE}'.",
type="warning",
)
lr_decay_type = DEFAULT_DECAY_TYPE
decay_types = {
"actor": lr_a_decay_type.lower() if lr_a_decay_type else None,
"critic": lr_c_decay_type.lower() if lr_c_decay_type else None,
"alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None,
"labda": lr_labda_decay_type.lower() if lr_labda_decay_type else None,
}
for name, decay_type in decay_types.items():
if decay_type is None:
decay_types[name] = lr_decay_type
else:
if decay_type not in VALID_DECAY_TYPES:
logger.log(
f"Invalid {name} learning rate decay type: '{decay_type}'. Using "
f"global learning rate decay type: '{lr_decay_type}' instead.",
type="warning",
)
decay_types[name] = lr_decay_type
lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type, lr_labda_decay_type = (
decay_types.values()
)

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
Expand All @@ -1108,9 +1183,19 @@ def lac(
lr_decay_steps = epochs

# Create learning rate schedulers.
# NOTE: Alpha and labda currently use the same scheduler as the actor.
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
lr_a_init, lr_c_init, lr_alpha_init, lr_labda_init = lr_a, lr_c, lr_alpha, lr_labda
lr_a_scheduler = get_lr_scheduler(
lr_a_decay_type, lr_a_init, lr_a_final, lr_decay_steps
)
lr_c_scheduler = get_lr_scheduler(
lr_c_decay_type, lr_c_init, lr_c_final, lr_decay_steps
)
lr_alpha_scheduler = get_lr_scheduler(
lr_alpha_decay_type, lr_alpha_init, lr_alpha_final, lr_decay_steps
)
lr_labda_scheduler = get_lr_scheduler(
lr_labda_decay_type, lr_labda_init, lr_labda_final, lr_decay_steps
)

# Restore policy if supplied.
if start_policy is not None:
Expand Down Expand Up @@ -1265,8 +1350,17 @@ def lac(
lr_c_now = max(
lr_c_scheduler(n_update + 1), lr_c_final
) # Make sure lr is bounded above final lr.
lr_alpha_now = max(
lr_alpha_scheduler(n_update + 1), lr_alpha_final
) # Make sure lr is bounded above final lr.
lr_labda_now = max(
lr_labda_scheduler(n_update + 1), lr_labda_final
) # Make sure lr is bounded above final lr.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
lr_a=lr_a_now,
lr_c=lr_c_now,
lr_alpha=lr_alpha_now,
lr_labda=lr_labda_now,
)

# SGD batch tb logging.
Expand Down Expand Up @@ -1297,8 +1391,8 @@ def lac(
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_scheduler(progress)
lr_critic = lr_c_scheduler(progress)
lr_alpha = lr_a_scheduler(progress)
lr_labda = lr_a_scheduler(progress)
lr_alpha = lr_alpha_scheduler(progress)
lr_labda = lr_labda_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
Expand Down Expand Up @@ -1398,8 +1492,17 @@ def lac(
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
lr_alpha_now = max(
lr_alpha_scheduler(epoch), lr_alpha_final
) # Make sure lr is bounded above final.
lr_labda_now = max(
lr_labda_scheduler(epoch), lr_labda_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
lr_a=lr_a_now,
lr_c=lr_c_now,
lr_alpha=lr_alpha_now,
lr_labda=lr_labda_now,
)

# Export model to 'SavedModel'
Expand Down Expand Up @@ -1574,6 +1677,18 @@ def lac(
parser.add_argument(
"--lr_c", type=float, default=3e-4, help="critic learning rate (default: 3e-4)"
)
parser.add_argument(
"--lr_alpha",
type=float,
default=1e-4,
help="entropy temperature learning rate (default: 1e-4)",
)
parser.add_argument(
"--lr_labda",
type=float,
default=3e-4,
help="Lyapunov Lagrange multiplier learning rate (default: 3e-4)",
)
parser.add_argument(
"--lr_a_final",
type=float,
Expand All @@ -1586,12 +1701,62 @@ def lac(
default=1e-10,
help="the final critic learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_alpha_final",
type=float,
default=1e-10,
help="the final entropy temperature learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_labda_final",
type=float,
default=1e-10,
help="the final Lyapunov Lagrange multiplier learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_decay_type",
type=str,
default="linear",
help="the learning rate decay type (default: linear)",
)
parser.add_argument(
"--lr_a_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the actor learning rate. "
"If not specified, the general learning rate decay type is used."
),
)
parser.add_argument(
"--lr_c_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the critic learning rate. "
"If not specified, the general learning rate decay type is used."
),
)
parser.add_argument(
"--lr_alpha_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the entropy temperature "
"learning rate. If not specified, the general learning rate decay type is "
"used."
),
)
parser.add_argument(
"--lr_labda_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the Lyapunov Lagrange "
"multiplier learning rate. If not specified, the general learning rate "
"decay type is used."
),
)
parser.add_argument(
"--lr_decay_ref",
type=str,
Expand Down Expand Up @@ -1803,10 +1968,18 @@ def lac(
adaptive_temperature=args.adaptive_temperature,
lr_a=args.lr_a,
lr_c=args.lr_c,
lr_alpha=args.lr_alpha,
lr_labda=args.lr_labda,
lr_a_final=args.lr_a_final,
lr_c_final=args.lr_c_final,
lr_alpha_final=args.lr_alpha_final,
lr_labda_final=args.lr_labda_final,
lr_decay_type=args.lr_decay_type,
lr_decay_ref=args.lr_decay_ref,
lr_a_decay_type=args.lr_a_decay_type,
lr_c_decay_type=args.lr_c_decay_type,
lr_alpha_decay_type=args.lr_alpha_decay_type,
lr_labda_decay_type=args.lr_labda_decay_type,
batch_size=args.batch_size,
replay_size=args.replay_size,
horizon_length=args.horizon_length,
Expand Down

0 comments on commit 712e94b

Please sign in to comment.