Skip to content

Commit

Permalink
fix(torch): handle 'update_after' set to zero (#408)
Browse files Browse the repository at this point in the history
This commit addresses a bug in the learning rate decay logic when
'update_after' is set to zero. Previously, the algorithm would
malfunction under these conditions. With this fix, the algorithm
correctly handles 'update_after' being set to zero, ensuring proper
learning rate decay.
  • Loading branch information
rickstaa committed Feb 12, 2024
1 parent 642a193 commit 7999590
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions stable_learning_control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,7 @@ def lac(
- replay_buffer (union[:class:`~stable_learning_control.algos.pytorch.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.pytorch.common.buffers.FiniteHorizonReplayBuffer`]):
The replay buffer used during training.
""" # noqa: E501, D301
update_after = max(1, update_after)  # You cannot update before the first step.
validate_args(**locals())

# Retrieve hyperparameters while filtering out the logger_kwargs.
Expand Down
1 change: 1 addition & 0 deletions stable_learning_control/algos/pytorch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,7 @@ def sac(
- replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]):
The replay buffer used during training.
""" # noqa: E501, D301
update_after = max(1, update_after)  # You cannot update before the first step.
validate_args(**locals())

# Retrieve hyperparameters while filtering out the logger_kwargs.
Expand Down

0 comments on commit 7999590

Please sign in to comment.