fix(tf2): fix critical tf2 gradient update bug (#322)
This commit fixes a critical bug in the tf2 algorithms that prevented the `log_std` and `mu` layers
from being updated during training. The bug occurred because the `trainable_variables` attribute was
accessed before the `SquasedGaussianActor` was built, so those layers' weights were not yet included
in it. It also improves the target network update and network summary code.
rickstaa committed Aug 11, 2023
1 parent 89ef5a2 commit dfc239b
Showing 7 changed files with 85 additions and 99 deletions.
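The underlying issue is a general property of subclassed `tf.keras.Model`s: layer weights are only created once the model is built (by an explicit `build()` call or a first forward pass), so a `trainable_variables` list captured before that point is empty and `apply_gradients` silently updates nothing. A minimal sketch of this failure mode, using an illustrative `TinyActor` class that is not part of the repository:

import tensorflow as tf


class TinyActor(tf.keras.Model):
    """Toy stand-in for an actor with a single `mu` output layer."""

    def __init__(self):
        super().__init__()
        self.mu = tf.keras.layers.Dense(1)

    def call(self, obs):
        return self.mu(obs)


actor = TinyActor()
print(len(actor.trainable_variables))  # 0 -> a list captured here misses `mu` entirely.

actor(tf.zeros((1, 3)))  # A dummy forward pass builds the model; the commit instead
                         # calls `self.build(...)` in the constructors (see below).
print(len(actor.trainable_variables))  # 2 -> the kernel and bias of `mu` are now tracked.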
72 changes: 28 additions & 44 deletions stable_learning_control/algos/tf2/lac/lac.py
@@ -1,4 +1,4 @@
"""Lyapunov Actor-Critic algorithm
"""Lyapunov (soft) Actor-Critic (LAC) algorithm.
This module contains a TensorFlow 2.x implementation of the LAC algorithm of
`Han et al. 2020 <https://arxiv.org/abs/2004.14288>`_.
@@ -25,7 +25,6 @@
from tensorflow.keras.optimizers import Adam

from stable_learning_control.algos.common.buffers import ReplayBuffer
from stable_learning_control.common.helpers import get_env_id
from stable_learning_control.algos.common.helpers import heuristic_target_entropy
from stable_learning_control.algos.tf2.common.get_lr_scheduler import get_lr_scheduler
from stable_learning_control.algos.tf2.common.helpers import (
@@ -36,7 +35,7 @@
from stable_learning_control.algos.tf2.policies.lyapunov_actor_critic import (
LyapunovActorCritic,
)
from stable_learning_control.common.helpers import combine_shapes
from stable_learning_control.common.helpers import combine_shapes, get_env_id
from stable_learning_control.utils.eval_utils import test_agent
from stable_learning_control.utils.gym_utils import is_discrete_space, is_gym_env
from stable_learning_control.utils.log_utils.helpers import (
@@ -47,13 +46,12 @@
from stable_learning_control.utils.safer_eval_util import safer_eval
from stable_learning_control.utils.serialization_utils import save_to_json


# Script settings.
SCALE_LAMBDA_MIN_MAX = (
0.0,
1.0,
) # Range of lambda lagrance multiplier.
SCALE_ALPHA_MIN_MAX = (0.0, np.inf) # Range of alpha lagrance multiplier.
) # Range of lambda Lagrange multiplier.
SCALE_ALPHA_MIN_MAX = (0.0, np.inf) # Range of alpha Lagrange multiplier.
STD_OUT_LOG_VARS_DEFAULT = [
"Epoch",
"TotalEnvInteracts",
@@ -78,8 +76,8 @@ class LAC(tf.keras.Model):
Attributes:
ac (tf.Module): The (lyapunov) actor critic module.
ac_ (tf.Module): The (lyapunov) target actor critic module.
log_alpha (tf.Variable): The temperature lagrance multiplier.
log_labda (tf.Variable): The Lyapunov lagrance multiplier.
log_alpha (tf.Variable): The temperature Lagrange multiplier.
log_labda (tf.Variable): The Lyapunov Lagrange multiplier.
target_entropy (int): The target entropy.
device (str): The device the networks are placed on (``cpu`` or ``gpu``).
Defaults to ``cpu``.
@@ -165,7 +163,7 @@ def __init__(
``0.99``.
alpha3 (float, optional): The Lyapunov constraint error boundary. Defaults
to ``0.2``.
labda (float, optional): The Lyapunov lagrance multiplier. Defaults to
labda (float, optional): The Lyapunov Lagrange multiplier. Defaults to
``0.99``.
gamma (float, optional): Discount factor. (Always between 0 and 1.).
Defaults to ``0.99``.
@@ -202,7 +200,6 @@ def __init__(
self._setup_kwargs = {
k: v for k, v in locals().items() if k not in ["self", "__class__", "env"]
}
self._was_build = False

# Validate gymnasium env.
# NOTE: The current implementation only works with continuous spaces.
@@ -367,7 +364,7 @@ def update(self, data):
# Get current Lyapunov value.
l1 = self.ac.L([o, a])

# Calculate Lyapunov *CRITIC* error
# Calculate L-critic MSE loss against Bellman backup.
# NOTE: The 0.5 multiplication factor was added to make the derivation
# cleaner and can be safely removed without influencing the
# minimization. We kept it here for consistency.
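A short aside on the note above: scaling a mean-squared error by 0.5 only rescales its gradient, so it cannot change the minimiser. A hedged illustration with made-up numbers (not the repository's loss code):

import tensorflow as tf

l_error = tf.Variable([0.3, -0.7])  # Stand-in for L(s, a) minus its Bellman backup.
with tf.GradientTape(persistent=True) as tape:
    loss_half = 0.5 * tf.reduce_mean(tf.square(l_error))
    loss_full = tf.reduce_mean(tf.square(l_error))
g_half = tape.gradient(loss_half, l_error)
g_full = tape.gradient(loss_full, l_error)
print((g_full / g_half).numpy())  # [2. 2.] -> same direction, only the scale differs.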
@@ -496,7 +493,7 @@ def restore(self, path, restore_lagrance_multipliers=False):
path (str): The path where the model :attr:`state_dict` of the policy is
found.
restore_lagrance_multipliers (bool, optional): Whether you want to restore
the lagrance multipliers. By fault ``False``.
the Lagrange multipliers. By default ``False``.
Raises:
Exception: Raises an exception if something goes wrong during loading.
@@ -526,18 +523,18 @@ def restore(self, path, restore_lagrance_multipliers=False):
f"Something went wrong when trying to load model '{latest}'."
) from e

# Make sure learning rates (and lagrance multipliers) are not restored
# Make sure learning rates (and Lagrange multipliers) are not restored.
self._lr_a.assign(lr_a)
self._lr_alpha.assign(lr_alpha)
self._lr_lag.assign(lr_lag)
self._lr_c.assign(lr_c)
if not restore_lagrance_multipliers:
self.log_alpha.assign(log_alpha_init)
self.log_labda.assign(log_labda_init)
log_to_std_out("Restoring lagrance multipliers.", type="info")
log_to_std_out("Restoring Lagrance multipliers.", type="info")
else:
log_to_std_out(
"Keeping lagrance multipliers at their initial value.", type="info"
"Keeping Lagrance multipliers at their initial value.", type="info"
)

def export(self, path):
@@ -563,34 +560,25 @@ def export(self, path):

def build(self):
"""Function that can be used to build the full model structure such that it can
be visualized using the `tf.keras.Model.summary()`.
be visualized using the `tf.keras.Model.summary()`. This is done by calling the
build method of the parent class with the correct input shape.
.. note::
This is done by calling the build methods of the submodules.
"""
obs_dummy = tf.random.uniform(
combine_shapes(1, self._obs_dim), dtype=tf.float32
)
act_dummy = tf.random.uniform(
combine_shapes(1, self._act_dim), dtype=tf.float32
)
self.ac([obs_dummy, act_dummy])
self.ac_targ([obs_dummy, act_dummy])
super().build(input_shape=combine_shapes(1, self._obs_dim))
self(obs_dummy)
self._was_build = True
super().build(combine_shapes(None, self._obs_dim))

def summary(self):
"""Small wrapper around the :meth:`tf.keras.Model.summary()` method used to
apply a custom format to the model summary.
"""
if not self._was_build: # Ensure the model is build.
if not self.built: # Ensure the model is built.
self.build()
super().summary()

def full_summary(self):
"""Prints a full summary of all the layers of the TensorFlow model"""
if not self._was_build: # Ensure the model is build.
if not self.built: # Ensure the model is built.
self.build()
full_model_summary(self)
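As a side note on the `summary()` and `full_summary()` changes above: `tf.keras.Model` already tracks its build state in the built-in `built` attribute, which is why the hand-rolled `_was_build` flag could be dropped. A small sketch of that behaviour, again with an illustrative class rather than repository code:

import tensorflow as tf


class TinyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.out = tf.keras.layers.Dense(1)

    def call(self, obs):
        return self.out(obs)


model = TinyModel()
print(model.built)       # False -> calling summary() here would raise an error.

model(tf.zeros((1, 3)))  # Any forward pass (or an explicit build()) flips the flag.
print(model.built)       # True
model.summary()          # Now lists the layer and its parameter counts.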

@@ -599,13 +587,13 @@ def set_learning_rates(self, lr_a=None, lr_c=None, lr_alpha=None, lr_labda=None)
Args:
lr_a (float, optional): The learning rate of the actor optimizer. Defaults
to None.
to ``None``.
lr_c (float, optional): The learning rate of the (Lyapunov) Critic. Defaults
to None.
to ``None``.
lr_alpha (float, optional): The learning rate of the temperature optimizer.
Defaults to None.
Defaults to ``None``.
lr_labda (float, optional): The learning rate of the Lyapunov Lagrange
multiplier optimizer. Defaults to None.
multiplier optimizer. Defaults to ``None``.
"""
if lr_a:
self._pi_optimizer.lr.assign(lr_a)
@@ -620,20 +608,16 @@ def set_learning_rates(self, lr_a=None, lr_c=None, lr_alpha=None, lr_labda=None)
@tf.function
def _init_targets(self):
"""Updates the target network weights to the main network weights."""
for pi_main, pi_targ in zip(self.ac.pi.variables, self.ac_targ.pi.variables):
pi_targ.assign(pi_main)
for c_main, c_targ in zip(self.ac.L.variables, self.ac_targ.L.variables):
c_targ.assign(c_main)
for ac_main, ac_targ in zip(self.ac.variables, self.ac_targ.variables):
ac_targ.assign(ac_main)

@tf.function
def _update_targets(self):
"""Updates the target networks based on a Exponential moving average
(Polyak averaging).
"""
for pi_main, pi_targ in zip(self.ac.pi.variables, self.ac_targ.pi.variables):
pi_targ.assign(self._polyak * pi_targ + (1 - self._polyak) * pi_main)
for c_main, c_targ in zip(self.ac.L.variables, self.ac_targ.L.variables):
c_targ.assign(self._polyak * c_targ + (1 - self._polyak) * c_main)
for ac_main, ac_targ in zip(self.ac.variables, self.ac_targ.variables):
ac_targ.assign(self._polyak * ac_targ + (1 - self._polyak) * ac_main)

@property
def alpha(self):
@@ -756,7 +740,7 @@ def lac(
start_policy=None,
export=False,
):
"""Trains the lac algorithm in a given environment.
"""Trains the LAC algorithm in a given environment.
Args:
env_fn: A function which creates a copy of the environment.
@@ -834,7 +818,7 @@ def lac(
``0.99``.
alpha3 (float, optional): The Lyapunov constraint error boundary. Defaults
to ``0.2``.
labda (float, optional): The Lyapunov lagrance multiplier. Defaults to
labda (float, optional): The Lyapunov Lagrange multiplier. Defaults to
``0.99``.
gamma (float, optional): Discount factor. (Always between 0 and 1.).
Defaults to ``0.99``.
@@ -1416,7 +1400,7 @@ def lac(
"--labda",
type=float,
default=0.99,
help="the Lyapunov lagrance multiplier (default: 0.99)",
help="the Lyapunov Lagrance multiplier (default: 0.99)",
)
parser.add_argument(
"--gamma", type=float, default=0.99, help="discount factor (default: 0.99)"
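The `_init_targets` and `_update_targets` hunks above simplify the target-network synchronisation by iterating over the whole actor-critic's `variables` instead of hand-picking the `pi` and `L` sub-modules, so no nested layer can be skipped. A minimal, hedged sketch of the same pattern with illustrative function names (not the repository's code); it assumes both models are built and expose their variables in the same order:

import tensorflow as tf


def hard_update(main: tf.keras.Model, target: tf.keras.Model):
    """Copy every weight of `main` into `target` (done once, at initialisation)."""
    for v_main, v_targ in zip(main.variables, target.variables):
        v_targ.assign(v_main)


def polyak_update(main: tf.keras.Model, target: tf.keras.Model, polyak: float = 0.995):
    """Exponential moving average: target <- polyak * target + (1 - polyak) * main."""
    for v_main, v_targ in zip(main.variables, target.variables):
        v_targ.assign(polyak * v_targ + (1.0 - polyak) * v_main)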
@@ -93,6 +93,9 @@ def __init__(
name=name + "/log_std",
)

# Build the model to initialise the (trainable) variables.
self.build((None, obs_dim))

@tf.function
def call(self, obs, deterministic=False, with_logprob=True):
"""Perform forward pass through the network.
@@ -124,7 +127,7 @@ def call(self, obs, deterministic=False, with_logprob=True):

# Pre-squash distribution and sample
if deterministic:
pi_action = mu # determinestic action used at test time.
pi_action = mu # deterministic action used at test time.
else:
# Sample from the normal distribution and calculate the action.
batch_size = tf.shape(input=obs)[0]
@@ -43,6 +43,9 @@ def __init__(
[obs_dim + act_dim] + list(hidden_sizes), activation, activation, name=name
)

# Build the model to initialise the (trainable) variables.
self.build((None, obs_dim + act_dim))

@tf.function
def call(self, inputs):
"""Perform forward pass through the network.
@@ -49,6 +49,9 @@ def __init__(
name=name,
)

# Build the model to initialise the (trainable) variables.
self.build((None, obs_dim + act_dim))

@tf.function
def call(self, inputs):
"""Perform forward pass through the network.
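The three constructor hunks above apply the same fix: calling `self.build((None, ...))` at the end of `__init__` makes Keras create the layer weights immediately, so any `trainable_variables` list taken right after construction already contains them. A hedged sketch of the idea with an illustrative critic class (not the repository's actor or critic):

import tensorflow as tf


class TinyCritic(tf.keras.Model):
    """Toy critic over concatenated observation-action inputs."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.q = tf.keras.layers.Dense(1)
        # Build immediately so the weights exist before any optimizer sees the model.
        self.build((None, obs_dim + act_dim))

    def call(self, inputs):
        return self.q(inputs)


critic = TinyCritic(obs_dim=3, act_dim=2)
print(len(critic.trainable_variables))  # 2 -> ready for gradient updates right away.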
@@ -1,6 +1,4 @@
"""
Lyapunov actor critic policy
============================
"""Lyapunov (soft) actor critic policy.
This module contains a TensorFlow 2.x implementation of the Lyapunov Actor Critic policy
of `Han et al. 2020 <https://arxiv.org/abs/2004.14288>`_.
@@ -99,6 +97,13 @@ def __init__(
activation=activation["critic"],
)

# Perform one forward pass to initialise the networks.
# NOTE: Done because TF doesn't support multiple positional arguments when using
# the tf.function decorator, and autograph doesn't support list unpacking.
obs_dummy = tf.random.uniform((1, obs_dim), dtype=tf.float32)
act_dummy = tf.random.uniform((1, act_dim), dtype=tf.float32)
self([obs_dummy, act_dummy])

@tf.function
def call(self, inputs, deterministic=False, with_logprob=True):
"""Performs a forward pass through all the networks (Actor and L critic).
@@ -146,8 +151,7 @@ def act(self, obs, deterministic=False):
stochastic policy. Defaults to ``False``.
Returns:
numpy.ndarray: The action from the current state given the current
policy.
numpy.ndarray: The action from the current state given the current policy.
"""
# Make sure the batch dimension is present (Required by tf.keras.layers.Dense)
if obs.shape.ndims == 1:
11 changes: 8 additions & 3 deletions stable_learning_control/algos/tf2/policies/soft_actor_critic.py
@@ -1,6 +1,4 @@
"""
Soft actor critic policy
========================
"""Soft actor critic policy.
This module contains a TensorFlow 2.x implementation of the Soft Actor Critic policy of
`Haarnoja et al. 2019 <https://arxiv.org/abs/1812.05905>`_.
@@ -92,6 +90,13 @@ def __init__(
name="q_critic_2",
)

# Perform one forward pass to initialise the networks.
# NOTE: Done because TF doesn't support multiple positional arguments when using
# the tf.function decorator, and autograph doesn't support list unpacking.
obs_dummy = tf.random.uniform((1, obs_dim), dtype=tf.float32)
act_dummy = tf.random.uniform((1, act_dim), dtype=tf.float32)
self([obs_dummy, act_dummy])

@tf.function
def call(self, inputs, deterministic=False, with_logprob=True):
"""Performs a forward pass through all the networks (Actor, Q critic 1 and Q
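The two policy classes above take a slightly different route: because their `call` methods are wrapped in `tf.function` and expect a list of inputs, the constructors run one dummy forward pass instead of calling `self.build()`. A minimal, hedged sketch of that pattern with made-up dimensions (not the repository's `LyapunovActorCritic` or `SoftActorCritic`):

import tensorflow as tf


class TinyActorCritic(tf.keras.Model):
    """Toy actor-critic whose call() takes [obs, act] as a single list argument."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.pi = tf.keras.layers.Dense(act_dim)
        self.q = tf.keras.layers.Dense(1)
        # One dummy forward pass builds every sub-layer so all weights are trackable.
        obs_dummy = tf.random.uniform((1, obs_dim), dtype=tf.float32)
        act_dummy = tf.random.uniform((1, act_dim), dtype=tf.float32)
        self([obs_dummy, act_dummy])

    @tf.function
    def call(self, inputs):
        obs, act = inputs[0], inputs[1]  # Index instead of unpacking (autograph-friendly).
        return self.pi(obs), self.q(tf.concat([obs, act], axis=-1))


ac = TinyActorCritic(obs_dim=3, act_dim=2)
print(len(ac.trainable_variables))  # 4 -> both sub-layers are built and trackable.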
