
Commit
fix: improve tensorflow lazy import (#272)
* fix: improve tensorflow lazy import

* refactor: fix docstrings and import order
rickstaa committed Jul 1, 2023
1 parent e100445 commit 75192a4
Showing 21 changed files with 221 additions and 304 deletions.
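The changes below replace eager `import tensorflow` / `import ray` statements with helpers from `stable_learning_control/utils/import_utils.py` (`lazy_importer`, `tf_installed`, `import_tf`). The helpers' implementations are not part of this diff; what follows is a minimal sketch of how such a guarded import can be written with the standard library, assuming `lazy_importer` returns the module when it is installed and, with `frail=True`, raises when it is not:

    import importlib
    import importlib.util


    def lazy_importer(module_name, frail=False):
        """Hypothetical sketch of the ``lazy_importer`` helper used below;
        the real implementation in import_utils.py may differ.
        """
        if importlib.util.find_spec(module_name) is None:
            if frail:
                raise ImportError(f"The '{module_name}' package is not installed.")
            return None  # Caller can check for None before using the module.
        return importlib.import_module(module_name)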
32 changes: 16 additions & 16 deletions docs/source/usage/running.rst
@@ -132,9 +132,9 @@ to see a readout of the docstring.

.. parsed-literal::
-python -m stable_learning_control.run SAC --env Walker2d-v2 --exp_name walker --act torch.nn.ELU
+python -m stable_learning_control.run SAC --env Walker2d-v2 --exp_name walker --act torch.nn.ReLU
-sets ``torch.nn.ELU`` as the activation function. (TensorFlow equivalent: run ``sac_tf`` with ``--act tf.nn.relu``.)
+sets ``torch.nn.ReLU`` as the activation function. (TensorFlow equivalent: run ``sac_tf`` with ``--act tf.nn.relu``.)

.. admonition:: You Should Know

@@ -218,39 +218,39 @@ Some algorithm arguments are relatively long, and we enabled shortcuts for them:

.. option:: --hid, --ac_kwargs:hidden_sizes

-:obj:`:obj:`list of ints``. Sets the sizes of the hidden layers in the neural networks of both the actor and critic.
+:obj:`list of ints`. Sets the sizes of the hidden layers in the neural networks of both the actor and critic.

.. option:: --hid_a, --ac_kwargs:hidden_sizes:actor

-:obj:`:obj:`list of ints``. Sets the sizes of the hidden layers in the neural networks of the actor.
+:obj:`list of ints`. Sets the sizes of the hidden layers in the neural networks of the actor.

.. option:: --hid_c, --ac_kwargs:hidden_sizes:critic

-:obj:`:obj:`list of ints``. Sets the sizes of the hidden layers in the neural networks of the critic.
+:obj:`list of ints`. Sets the sizes of the hidden layers in the neural networks of the critic.

.. option:: --act, --ac_kwargs:activation

-:obj:`tf op`. The activation function for the neural networks in the actor and critic.
+:mod:`torch.nn` or :mod:`tf.nn`. The activation function for the neural networks in the actor and critic.

.. option:: --act_out, --ac_kwargs:output_activation

-:obj:`tf op`. The activation function for the neural networks in the actor and critic.
+:mod:`torch.nn` or :mod:`tf.nn`. The output activation function for the neural networks in the actor and critic.

.. option:: --act_a, --ac_kwargs:activation:actor

-:obj:`tf op`. The activation function for the neural networks in the actor.
+:mod:`torch.nn` or :mod:`tf.nn`. The activation function for the neural networks in the actor.

.. option:: --act_c, --ac_kwargs:activation:critic

-:obj:`tf op`. The activation function for the neural networks in the critic.
+:mod:`torch.nn` or :mod:`tf.nn`. The activation function for the neural networks in the critic.

.. option:: --act_out_a, --ac_kwargs:output_activation:actor

-:obj:`tf op`. The activation function for the output activation function of the actor.
+:mod:`torch.nn` or :mod:`tf.nn`. The output activation function of the actor.

.. option:: --act_out_c, --ac_kwargs:output_activation:critic

-:obj:`tf op`. The activation function for the output activation function of the critic.
+:mod:`torch.nn` or :mod:`tf.nn`. The output activation function of the critic.

These flags are valid for all current SLC algorithms.
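For instance, a hypothetical command combining these shortcuts with the run pattern shown earlier on this page (the list syntax for ``--hid`` follows the Spinning Up convention this package derives from, so treat it as illustrative):

    python -m stable_learning_control.run SAC --env Walker2d-v2 --exp_name walker --hid "[128,128]" --act torch.nn.ReLU

This would set two hidden layers of 128 units and ReLU activations for both the actor and the critic.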

@@ -539,14 +539,14 @@ Consider the example in ``stable_learning_control/examples/pytorch/sac_ray_hyper
:language: python
:linenos:
:lines: 16-
-:emphasize-lines: 19-34, 47, 54-66, 72-87
+:emphasize-lines: 17-32, 45, 52-64, 70-85

(An equivalent TensorFlow example is available in ``stable_learning_control/examples/tf2/sac_ray_hyper_parameter_tuning.py``.)

-In this example, on lines ``19-34`` we first create a small wrapper function that ensures that the Ray Tuner serves the
-hyperparameters in the SLC algorithm's format. Then, on line ``47``, we set the starting point for several
-hyperparameters used in the hyperparameter search. Next, on lines ``54-66``, we define the hyperparameter search space.
-Lastly, we start the hyperparameter search using the :meth:`tune.run` method on lines ``72-87``.
+In this example, on lines ``17-32`` we first create a small wrapper function that ensures that the Ray Tuner serves the
+hyperparameters in the SLC algorithm's format. Then, on line ``45``, we set the starting point for several
+hyperparameters used in the hyperparameter search. Next, on lines ``52-64``, we define the hyperparameter search space.
+Lastly, we start the hyperparameter search using the :meth:`tune.run` method on lines ``70-85``.
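The four-step recipe just described maps onto Ray Tune's old suggest API (the one these example files import) roughly as follows. This is a minimal, self-contained sketch with a toy objective standing in for the SLC algorithm; the metric name ``mean_ep_ret`` and the parameter ranges are assumptions, not the example file's exact contents:

    from hyperopt import hp
    from ray import tune
    from ray.tune.suggest.hyperopt import HyperOptSearch


    def train_toy(config):
        # Step 1: a wrapper that receives Ray Tune's config dict. A real
        # wrapper would forward these values to the SLC algorithm instead
        # of computing this toy score.
        score = -((config["gamma"] - 0.99) ** 2) - (config["lr_a"] - 1e-4) ** 2
        tune.report(mean_ep_ret=score)  # assumed metric name


    # Step 2: starting point(s) for the search.
    initial_params = [{"gamma": 0.995, "lr_a": 1e-4}]

    # Step 3: the hyperparameter search space (hyperopt format).
    search_space = {
        "gamma": hp.uniform("gamma", 0.9, 0.999),
        "lr_a": hp.loguniform("lr_a", -10, -2),
    }

    # Step 4: start the search.
    tune.run(
        train_toy,
        search_alg=HyperOptSearch(
            search_space,
            metric="mean_ep_ret",
            mode="max",
            points_to_evaluate=initial_params,
        ),
        num_samples=20,
    )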

The Ray tuner will search for the best hyperparameter combination when running the script. While doing so, it will print
the results both to the ``std_out`` and a Tensorboard file. You can check these Tensorboard logs using the
12 changes: 5 additions & 7 deletions examples/pytorch/lac_ray_hyper_parameter_tuning.py
@@ -18,17 +18,15 @@

import gymnasium as gym
import numpy as np
-import ray
import stable_gym # Imports the in this example used environment # noqa: F401
-from hyperopt import hp
-from ray import tune
-from ray.tune.schedulers import ASHAScheduler
-from ray.tune.suggest.hyperopt import HyperOptSearch

# Import the algorithm we want to tune.
from stable_learning_control.algos.pytorch.lac import lac
+from stable_learning_control.utils.import_utils import lazy_importer

+ray = lazy_importer(module_name="ray", frail=True)
+from hyperopt import hp  # noqa: E402
+from ray import tune  # noqa: E402
+from ray.tune.schedulers import ASHAScheduler  # noqa: E402
+from ray.tune.suggest.hyperopt import HyperOptSearch  # noqa: E402


def train_lac(config):
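The same import reordering is applied to the three sibling example files below. Assuming `lazy_importer` behaves like the sketch near the top of this page, the practical effect is that a missing Ray installation now fails with one clear error before any `from ray import ...` line runs, while the `# noqa: E402` markers silence flake8's module-level-import-not-at-top warning that the reordering would otherwise trigger:

    # Hypothetical illustration of the failure mode the new pattern enables.
    from stable_learning_control.utils.import_utils import lazy_importer

    try:
        ray = lazy_importer(module_name="ray", frail=True)
    except ImportError as err:
        # With a plain ``import ray`` the user got a bare ModuleNotFoundError;
        # a helper can attach an actionable installation hint instead.
        raise SystemExit(f"This example requires Ray: {err}")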
12 changes: 5 additions & 7 deletions examples/pytorch/sac_ray_hyper_parameter_tuning.py
@@ -18,17 +18,15 @@

import gymnasium as gym
import numpy as np
-import ray
import stable_gym # Imports the in this example used environment # noqa: F401
-from hyperopt import hp
-from ray import tune
-from ray.tune.schedulers import ASHAScheduler
-from ray.tune.suggest.hyperopt import HyperOptSearch

# Import the algorithm we want to tune.
from stable_learning_control.algos.pytorch.sac import sac
+from stable_learning_control.utils.import_utils import lazy_importer

+ray = lazy_importer(module_name="ray", frail=True)
+from hyperopt import hp  # noqa: E402
+from ray import tune  # noqa: E402
+from ray.tune.schedulers import ASHAScheduler  # noqa: E402
+from ray.tune.suggest.hyperopt import HyperOptSearch  # noqa: E402


def train_sac(config):
12 changes: 5 additions & 7 deletions examples/tf2/lac_ray_hyper_parameter_tuning.py
@@ -18,17 +18,15 @@

import gymnasium as gym
import numpy as np
-import ray
import stable_gym # Imports the in this example used environment # noqa: F401
-from hyperopt import hp
-from ray import tune
-from ray.tune.schedulers import ASHAScheduler
-from ray.tune.suggest.hyperopt import HyperOptSearch

# Import the algorithm we want to tune.
from stable_learning_control.algos.tf2.lac import lac
+from stable_learning_control.utils.import_utils import lazy_importer

+ray = lazy_importer(module_name="ray", frail=True)
+from hyperopt import hp  # noqa: E402
+from ray import tune  # noqa: E402
+from ray.tune.schedulers import ASHAScheduler  # noqa: E402
+from ray.tune.suggest.hyperopt import HyperOptSearch  # noqa: E402


def train_lac(config):
12 changes: 5 additions & 7 deletions examples/tf2/sac_ray_hyper_parameter_tuning.py
@@ -18,17 +18,15 @@

import gymnasium as gym
import numpy as np
-import ray
import stable_gym # Imports the in this example used environment # noqa: F401
-from hyperopt import hp
-from ray import tune
-from ray.tune.schedulers import ASHAScheduler
-from ray.tune.suggest.hyperopt import HyperOptSearch

# Import the algorithm we want to tune.
from stable_learning_control.algos.tf2.sac import sac
+from stable_learning_control.utils.import_utils import lazy_importer

+ray = lazy_importer(module_name="ray", frail=True)
+from hyperopt import hp  # noqa: E402
+from ray import tune  # noqa: E402
+from ray.tune.schedulers import ASHAScheduler  # noqa: E402
+from ray.tune.suggest.hyperopt import HyperOptSearch  # noqa: E402


def train_sac(config):
114 changes: 0 additions & 114 deletions sandbox/test_algorithm_seeding.py

This file was deleted.

3 changes: 2 additions & 1 deletion setup.cfg
@@ -3,6 +3,7 @@ max-line-length = 89
extend-ignore = E203
exclude =
docs/source/conf.py,
-sandbox
+sandbox,
+build
per-file-ignores =
__init__.py: F401, E501
12 changes: 6 additions & 6 deletions stable_learning_control/__init__.py
@@ -1,13 +1,13 @@
"""Module that initializes the stable_learning_control package."""
-# Make module version available.
-from .version import __version__  # noqa: F401
-from .version import __version_tuple__  # noqa: F401

# Put algorithms in main namespace.
from stable_learning_control.algos.pytorch.lac.lac import lac as lac_pytorch
from stable_learning_control.algos.pytorch.sac.sac import sac as sac_pytorch
-from stable_learning_control.utils.import_utils import import_tf
+from stable_learning_control.utils.import_utils import tf_installed

+# Make module version available.
+from .version import __version__  # noqa: F401
+from .version import __version_tuple__  # noqa: F401

-if import_tf(dry_run=True, frail=False):
+if tf_installed():
from stable_learning_control.algos.tf2.lac.lac import lac as lac_tf2
from stable_learning_control.algos.tf2.sac.sac import sac as sac_tf2
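Here ``tf_installed()`` replaces the old ``import_tf(dry_run=True, frail=False)`` availability probe. Its body is not shown in this diff; a minimal sketch, assuming it only checks for the package without importing it, would be:

    import importlib.util


    def tf_installed():
        """Return True if TensorFlow is available (hypothetical sketch; the
        real helper lives in stable_learning_control/utils/import_utils.py).
        """
        return importlib.util.find_spec("tensorflow") is not None

A ``find_spec``-based check keeps the package import cheap for PyTorch-only users, since TensorFlow itself is only loaded once an algorithm from ``algos.tf2`` is actually requested.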
4 changes: 2 additions & 2 deletions stable_learning_control/algos/__init__.py
@@ -7,8 +7,8 @@
"""
from stable_learning_control.algos.pytorch.lac.lac import LAC as LAC_pytorch
from stable_learning_control.algos.pytorch.sac.sac import SAC as SAC_pytorch
-from stable_learning_control.utils.import_utils import import_tf
+from stable_learning_control.utils.import_utils import tf_installed

-if import_tf(dry_run=True, frail=False):
+if tf_installed():
from stable_learning_control.algos.tf2.lac.lac import LAC as LAC_tf
from stable_learning_control.algos.tf2.sac.sac import SAC as SAC_tf
11 changes: 10 additions & 1 deletion stable_learning_control/algos/common/helpers.py
@@ -11,7 +11,7 @@
)
from stable_learning_control.utils.import_utils import import_tf

-tf = import_tf(frail=False)
+tf = import_tf(frail=False)  # Suppress import warning.
tensorflow = tf


@@ -77,6 +77,15 @@ def get_activation_function(activation_fn_name, backend="torch"):
    elif len(activation_fn_name.split(".")) == 2:
        if activation_fn_name.split(".")[0] == "nn":
            activation_fn_name = backend_prefix[0] + "." + activation_fn_name
+    else:
+        if activation_fn_name.split(".")[0] not in backend_prefix:
+            raise ValueError(
+                "'{}' is not a valid '{}' activation function.".format(
+                    activation_fn_name, backend_prefix[0]
+                )
+            )
+
+    # Import activation function.
    try:
        return getattr(
            importlib.import_module(".".join(activation_fn_name.split(".")[:-1])),
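Based on the branch logic visible in this hunk (the start of ``get_activation_function`` lies outside it), resolution behaves roughly like this hedged sketch:

    from stable_learning_control.algos.common.helpers import get_activation_function

    # Bare "nn.<name>" strings get the backend prefix prepended before import.
    get_activation_function("nn.ReLU", backend="torch")  # -> torch.nn.ReLU
    # Fully qualified names must match the requested backend; otherwise the
    # new else-branch raises the ValueError added in this commit.
    get_activation_function("tf.nn.relu", backend="torch")  # raises ValueError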
@@ -3,7 +3,7 @@

from stable_learning_control.utils.import_utils import import_tf

-tf = import_tf()
+tf = import_tf()  # Throw custom warning if tf is not installed.


def get_lr_scheduler(decaying_lr_type, lr_start, lr_final, steps):
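Only the signature of ``get_lr_scheduler(decaying_lr_type, lr_start, lr_final, steps)`` is visible in this hunk. A plausible sketch of such a factory in TF2, assuming "linear" and "exponential" are among the supported decay types (the real helper may differ), is:

    import tensorflow as tf


    def get_lr_scheduler(decaying_lr_type, lr_start, lr_final, steps):
        # Hypothetical sketch: map a decay-type string onto a Keras schedule.
        if decaying_lr_type.lower() == "linear":
            return tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=lr_start,
                decay_steps=steps,
                end_learning_rate=lr_final,
                power=1.0,  # power=1 turns the polynomial decay into a line
            )
        if decaying_lr_type.lower() == "exponential":
            return tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=lr_start,
                decay_steps=steps,
                decay_rate=lr_final / lr_start,
            )
        raise ValueError(f"Unknown decaying_lr_type '{decaying_lr_type}'.")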
13 changes: 13 additions & 0 deletions stable_learning_control/algos/tf2/common/helpers.py
@@ -85,3 +85,16 @@ def clamp(data, min_bound, max_bound):
        boundaries.
    """
    return (data + 1.0) * (max_bound - min_bound) / 2 + min_bound


+def full_model_summary(model):
+    """Prints a full summary of all the layers of a TensorFlow model.
+
+    Args:
+        model (:tf:`keras.layers`): The model to print the full summary of.
+    """
+    if hasattr(model, "layers"):
+        model.summary()
+        print("\n\n")
+        for layer in model.layers:
+            full_model_summary(layer)
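A quick usage sketch for the new helper on a nested Keras model (names and shapes are illustrative):

    import tensorflow as tf

    # A model containing a sub-model, so the recursion has something to
    # descend into.
    inner = tf.keras.Sequential(
        [tf.keras.layers.Dense(32, activation="relu", input_shape=(8,))],
        name="inner_block",
    )
    outer = tf.keras.Sequential([inner, tf.keras.layers.Dense(1)], name="outer")
    outer.build(input_shape=(None, 8))

    full_model_summary(outer)  # prints outer's summary, then inner_block's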
