fix: fix several policy loading problems
This commit fixes several bugs that prevented singleton gym environments
from being loaded. It further updates the documentation on manually
loading environments.
rickstaa committed Feb 17, 2022
1 parent 8c93b61 commit 51a664e
Showing 14 changed files with 115 additions and 58 deletions.
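
For context on the main fix: a singleton environment, i.e. one that is instantiated directly instead of being registered and created through gym.make, has its spec attribute set to None, so reading env.unwrapped.spec.id raises an AttributeError. A minimal sketch of the failure mode (MyEnv is a hypothetical environment, not part of this repository):

import gym

class MyEnv(gym.Env):
    # Hypothetical singleton environment that was never registered with gym.
    observation_space = gym.spaces.Box(-1.0, 1.0, shape=(1,))
    action_space = gym.spaces.Box(-1.0, 1.0, shape=(1,))

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, {}

env = MyEnv()    # Created directly instead of via gym.make.
print(env.spec)  # None, so env.unwrapped.spec.id raises an AttributeError.

The hasattr guards added below handle this case because hasattr(None, "id") returns False.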
11 changes: 0 additions & 11 deletions TODOs.md

This file was deleted.

16 changes: 12 additions & 4 deletions bayesian_learning_control/control/algos/pytorch/lac/lac.py
@@ -224,10 +224,18 @@ def __init__( # noqa: C901
"need this."
)

log_to_std_out(
"You are using the {} environment.".format(env.unwrapped.spec.id),
type="info",
)
if hasattr(env.unwrapped.spec, "id"):
log_to_std_out(
"You are using the '{}' environment.".format(env.unwrapped.spec.id),
type="info",
)
else:
log_to_std_out(
"You are using the '{}' environment.".format(
type(env.unwrapped).__name__
),
type="info",
)
log_to_std_out("You are using the LAC algorithm.", type="info")
log_to_std_out(
"This agent is {}.".format(
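
This guard recurs verbatim in the three algorithm files below (pytorch/sac/sac.py, tf2/lac/lac.py, and tf2/sac/sac.py). Distilled into a standalone helper, the pattern is roughly the following sketch (env_name is an illustrative name, not part of the library API):

def env_name(env):
    # Registered environments expose a spec with an id. Singleton
    # environments created directly have spec set to None, and
    # hasattr(None, "id") is False, so fall back to the class name.
    if hasattr(env.unwrapped.spec, "id"):
        return env.unwrapped.spec.id
    return type(env.unwrapped).__name__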
16 changes: 12 additions & 4 deletions bayesian_learning_control/control/algos/pytorch/sac/sac.py
@@ -216,10 +216,18 @@ def __init__( # noqa: C901
"need this."
)

log_to_std_out(
"You are using the {} environment.".format(env.unwrapped.spec.id),
type="info",
)
if hasattr(env.unwrapped.spec, "id"):
log_to_std_out(
"You are using the '{}' environment.".format(env.unwrapped.spec.id),
type="info",
)
else:
log_to_std_out(
"You are using the '{}' environment.".format(
type(env.unwrapped).__name__
),
type="info",
)
log_to_std_out("You are using the SAC algorithm.", type="info")
log_to_std_out(
"This agent is {}.".format(
16 changes: 12 additions & 4 deletions bayesian_learning_control/control/algos/tf2/lac/lac.py
@@ -225,10 +225,18 @@ def __init__( # noqa: C901
"need this."
)

log_to_std_out(
"You are using the {} environment.".format(env.unwrapped.spec.id),
type="info",
)
if hasattr(env.unwrapped.spec, "id"):
log_to_std_out(
"You are using the '{}' environment.".format(env.unwrapped.spec.id),
type="info",
)
else:
log_to_std_out(
"You are using the '{}' environment.".format(
type(env.unwrapped).__name__
),
type="info",
)
log_to_std_out("You are using the LAC algorithm.", type="info")
log_to_std_out(
"This agent is {}.".format(
16 changes: 12 additions & 4 deletions bayesian_learning_control/control/algos/tf2/sac/sac.py
@@ -217,10 +217,18 @@ def __init__( # noqa: C901
"need this."
)

log_to_std_out(
"You are using the {} environment.".format(env.unwrapped.spec.id),
type="info",
)
if hasattr(env.unwrapped.spec, "id"):
log_to_std_out(
"You are using the '{}' environment.".format(env.unwrapped.spec.id),
type="info",
)
else:
log_to_std_out(
"You are using the '{}' environment.".format(
type(env.unwrapped).__name__
),
type="info",
)
log_to_std_out("You are using the SAC algorithm.", type="info")
log_to_std_out(
"This agent is {}.".format(
12 changes: 9 additions & 3 deletions bayesian_learning_control/control/utils/eval_robustness.py
@@ -126,10 +126,16 @@ def run_disturbed_policy( # noqa: C901
""" # noqa: E501
# Validate environment, environment disturber
assert env is not None, friendly_err(
"Environment not found!\n\n It looks like the environment wasn't saved, "
+ "and we can't run the agent in it. :( \n\n Check out the documentation "
+ "page on the robustness evaluation utility for how to handle this situation."
"Environment not found!\n\n It looks like the environment wasn't saved, and we "
"can't run the agent in it. :( \n\n Check out the documentation page on the "
"page on the robustness evaluation utility for how to handle this situation."
)
assert policy is not None, friendly_err(
"Policy not found!\n\n It looks like the policy could not be loaded. :( \n\n "
"Check out the documentation page on the robustness evaluation utility for how "
"to handle this situation."
)

disturber_implemented, missing_objects = _disturber_implemented(env)
if not disturber_implemented:
missing_keys = [key for key, item in missing_objects.items() if len(item) >= 1]
29 changes: 17 additions & 12 deletions bayesian_learning_control/control/utils/test_policy.py
@@ -133,7 +133,6 @@ def load_policy_and_env(fpath, itr="last"):
- env (:obj:`gym.env`): The gym environment.
- get_action (:obj:`func`): The policy get_action function.
"""
if not os.path.isdir(fpath):
raise FileNotFoundError(
@@ -172,29 +171,30 @@ def load_policy_and_env(fpath, itr="last"):
# load the get_action function
try:
if backend == "tf":
policy = load_tf_policy(fpath, itr, env)
policy = load_tf_policy(fpath, env=env, itr=itr)
else:
policy = load_pytorch_policy(fpath, itr, env)
policy = load_pytorch_policy(fpath, env=env, itr=itr)
except Exception as e:
raise PolicyLoadError(
friendly_err(
(
"Policy not found!\n\n It looks like the policy wasn't "
"Policy could not be loaded!\n\n It looks like the policy wasn't "
"successfully saved. :( \n\n Check out the documentation page on "
"the Test Policy utility for how to handle this situation."
)
)
) from e

return env, policy


def load_tf_policy(fpath, itr="last", env=None):
def load_tf_policy(fpath, env, itr="last"):
"""Load a tensorflow policy saved with Bayesian Learning Control Logger.
Args:
fpath (str): The path where the model is found.
itr (str, optional): The current policy iteration. Defaults to "last".
env (:obj:`gym.env`): The gym environment in which you want to test the policy.
itr (str, optional): The current policy iteration. Defaults to "last".
Returns:
tf.keras.Model: The policy.
@@ -219,17 +219,16 @@ def load_tf_policy(fpath, itr="last", env=None):
latest = tf.train.latest_checkpoint(model_path) # Restore latest checkpoint
model.load_weights(latest)

# return model.get_action
return model


def load_pytorch_policy(fpath, itr="last", env=None):
def load_pytorch_policy(fpath, env, itr="last"):
"""Load a pytorch policy saved with Bayesian Learning Control Logger.
Args:
fpath (str): The path where the model is found.
itr (str, optional): The current policy iteration. Defaults to "last".
env (:obj:`gym.env`): The gym environment in which you want to test the policy.
itr (str, optional): The current policy iteration. Defaults to "last".
Returns:
torch.nn.Module: The policy.
@@ -255,6 +254,7 @@ def load_pytorch_policy(fpath, itr="last", env=None):
ac_kwargs = {}
model = getattr(torch_algos, save_info["alg_name"])(env=env, **ac_kwargs)
model.load_state_dict(model_data) # Restore model parameters

return model


@@ -275,9 +275,14 @@ def run_policy(
Defaults to ``True``.
"""
assert env is not None, friendly_err(
"Environment not found!\n\n It looks like the environment wasn't saved, "
+ "and we can't run the agent in it. :( \n\n Check out the documentation "
+ "page on the Test Policy utility for how to handle this situation."
"Environment not found!\n\n It looks like the environment wasn't saved, and we "
"can't run the agent in it. :( \n\n Check out the documentation page on the "
"Test Policy utility for how to handle this situation."
)
assert policy is not None, friendly_err(
"Policy not found!\n\n It looks like the policy could not be loaded. :( \n\n "
"Check out the documentation page on the Test Policy utility for how to "
"handle this situation."
)

logger = EpochLogger(verbose_fmt="table")
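
Note why the updated call sites use keyword arguments: env now precedes itr in both loader signatures, so an old-style positional call would silently bind the iteration string to env. A short sketch of the hazard (fpath and env are placeholders):

# Old signature: load_tf_policy(fpath, itr="last", env=None)
# New signature: load_tf_policy(fpath, env, itr="last")
policy = load_tf_policy(fpath, "last")               # Now wrong: "last" is bound to env.
policy = load_tf_policy(fpath, env=env, itr="last")  # Keyword arguments stay correct.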
15 changes: 11 additions & 4 deletions docs/source/control/control_utils.rst
@@ -129,6 +129,7 @@ There are a few flags for options:
by SAC, but you should set the deterministic flag to watch the deterministic mean policy (the correct evaluation policy for SAC). This flag is not used
for any other algorithms.

.. _test_policy_env_not_found:

Environment Not Found Error
---------------------------
@@ -154,16 +155,19 @@ In this case, watching your agent perform is slightly more of a pain but not imp
.. code-block::
>>> import gym
>>> from bayesian_learning_control.control.utils.test_policy import load_policy_and_env, run_policy
>>> from bayesian_learning_control.control.utils.test_policy import load_pytorch_policy, run_policy
>>> import your_env
>>> _, policy = load_policy_and_env('/path/to/output_directory')
>>> env = gym.make('<YOUR_ENV_NAME>')
>>> policy = load_pytorch_policy("/path/to/output_directory", env=env)
>>> run_policy(env, policy)
Logging data to /tmp/experiments/1536150702/progress.txt
Episode 0 EpRet -163.830 EpLen 93
Episode 1 EpRet -346.164 EpLen 99
...
If you want to load a tensorflow agent, please replace :meth:`~bayesian_learning_control.control.utils.test_policy.load_pytorch_policy` with
:meth:`~bayesian_learning_control.control.utils.test_policy.load_tf_policy`. An example script for manually loading policies can be found in the
``examples`` folder (i.e. :blc:`manual_env_policy_inference.py <blob/main/examples/manual_env_policy_inference.py>`).

Using Trained Value Functions
-----------------------------
@@ -293,11 +297,11 @@ In this case, evaluating the robustness is slightly more of a pain but not impos
.. code-block::
>>> import gym
>>> from bayesian_learning_control.control.utils.test_policy import load_policy_and_env
>>> from bayesian_learning_control.control.utils.test_policy import load_pytorch_policy, run_policy
>>> from bayesian_learning_control.control.utils.eval_robustness import run_disturbed_policy, plot_robustness_results
>>> import your_env
>>> _, policy = load_policy_and_env('/path/to/output_directory')
>>> env = gym.make('<YOUR_ENV_NAME>')
>>> policy = load_pytorch_policy("/path/to/output_directory", env=env)
>>> run_results_df = run_disturbed_policy(env, policy, disturbance_type="<TYPE_YOU_WANT_TO_USE>")
>>> plot_robustness_results(run_results_df)
INFO: Logging data to /tmp/experiments/1616515040/eval_statistics.csv
Expand All @@ -310,6 +314,9 @@ In this case, evaluating the robustness is slightly more of a pain but not impos
Episode 2 EpRet 330.313 EpLen 800 Died False
...
If you want to load a tensorflow agent, please replace :meth:`~bayesian_learning_control.control.utils.test_policy.load_pytorch_policy` with
:meth:`~bayesian_learning_control.control.utils.test_policy.load_tf_policy`.

ExperimentGrid utility
======================

Expand Down
2 changes: 1 addition & 1 deletion docs/source/control/eval_robustness.rst
@@ -371,5 +371,5 @@ modified disturber. You can then choose this new disturbance using the ``-d_type
Manual robustness evaluation
============================

A script version of the eval robustness tool can be found in the ``examples`` folder (i.e. ``eval_robustness.py``). This script can be used when you want to perform some quick tests without implementing a disturber
A script version of the eval robustness tool can be found in the ``examples`` folder (i.e. :blc:`eval_robustness.py <blob/main/examples/eval_robustness.py>`). This script can be used when you want to perform some quick tests without implementing a disturber
class for your given environment.
8 changes: 6 additions & 2 deletions docs/source/control/saving_and_loading.rst
@@ -158,7 +158,11 @@ is successfully saved alongside the agent, it's a cinch to watch the trained age
.. seealso::

For more information on how to use this utility see the :ref:`test_policy <test_policy>` documentation or the code :ref:`api`.
For more information on how to use this utility see the :ref:`test_policy` documentation or the code :ref:`api`.

.. warning::
It could be that automatic loading of the policy/environment is not possible, in which case you will receive an ``environment could not be loaded`` error.
Please see the :ref:`test_policy_env_not_found` section on how to fix these errors.

.. _manual_policy_testing:

@@ -216,7 +220,7 @@ In this example, observe that
Additionally, each algorithm also contains a :obj:`~bayesian_learning_control.control.algos.pytorch.lac.LAC.restore` method which serves as a
wrapper around the :obj:`torch.load` and :obj:`torch.nn.Module.load_state_dict` methods.
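
Given that description, manually restoring a PyTorch agent comes down to roughly the following sketch (the checkpoint path and environment id are placeholders; the exact save-file layout may differ):

import gym
import torch
from bayesian_learning_control.control.algos.pytorch.lac import LAC

env = gym.make("<YOUR_ENV_NAME>")  # The environment the agent was trained on.
policy = LAC(env=env)  # Rebuild the agent with a matching architecture.
model_data = torch.load("/path/to/output_directory/torch_save/model.pt")  # Placeholder path.
policy.load_state_dict(model_data)  # Restore the saved parameters.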

Load Tensorflow Policy
Load Tensorflow Policy
~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: python
2 changes: 1 addition & 1 deletion docs/source/dev/doc_dev.rst
@@ -83,5 +83,5 @@ push the documentation to the :blc:`main <tree/main>` branch and run the

.. warning::

Please make sure you are on the `main`_ branch while building the documentation. Otherwise,
Please make sure you are on the :blc:`main <tree/main>` branch while building the documentation. Otherwise,
errors will greet you.
15 changes: 14 additions & 1 deletion examples/eval_robustness.py
@@ -6,6 +6,7 @@
import os
import time
from pathlib import Path
import sys

import matplotlib.pyplot as plt
import numpy as np
@@ -118,7 +119,19 @@ def noise_disturbance(mean, std):
args = parser.parse_args()

# Load policy and environment
env, policy = load_policy_and_env(args.fpath, args.itr if args.itr >= 0 else "last")
try:
env, policy = load_policy_and_env(
args.fpath, args.itr if args.itr >= 0 else "last"
)
except Exception:
log_to_std_out(
(
"Environment and policy could not be loaded. Please check the 'fpath' "
"and try again."
),
type="error",
)
sys.exit(1)  # Nonzero exit status signals failure.

# Remove action clipping if present
if hasattr(env.unwrapped, "_clipped_action"):
12 changes: 6 additions & 6 deletions examples/manual_env_policy_inference.py
@@ -3,7 +3,7 @@
"""

import gym
import ros_gazebo_gym # noqa: F401
import ros_gazebo_gym # Imports the environment used in this example # noqa: F401
from bayesian_learning_control.control.utils.test_policy import (
load_policy_and_env,
load_pytorch_policy,
@@ -12,15 +12,15 @@
)

AGENT_TYPE = "torch" # The type of agent that was trained. Options: 'tf2' and 'torch'.
AGENT_FOLDER = "/home/ricks/Development/work/bayesian-learning-control/data/2022-02-17_staa_lac_panda_reach/2022-02-17_09-35-31-staa_lac_panda_reach_s25" # noqa: E501
AGENT_FOLDER = "/home/ricks/Development/work/bayesian-learning-control/data/2022-02-17_staa_lac_panda_reach/2022-02-17_13-21-04-staa_lac_panda_reach_s25" # noqa: E501

if __name__ == "__main__":
# NOTE: STEP 1a: Try to load the policy and environment
try:
env, policy = load_policy_and_env(AGENT_FOLDER)
except Exception:
# NOTE: STEP: 1b: If step 1 fails recreate the environment and load the
# Pytorch/TF2 agent separately.
# NOTE: STEP 1b: If step 1 fails, recreate the environment and load the Pytorch/
# TF2 agent separately.

# Create the environment
# NOTE: Here the 'FlattenObservation' wrapper is used to make sure the alg works
@@ -33,10 +33,10 @@
policy = load_tf_policy(AGENT_FOLDER, itr="last", env=env) # Load TF2 agent
else:
policy = load_pytorch_policy(
AGENT_FOLDER, itr="last", env=env
AGENT_FOLDER, env=env, itr="last"
) # Load Pytorch agent

# Step 2: Try to run the policy on the environment
# NOTE: Step 2: Try to run the policy on the environment.
try:
run_policy(env, policy)
except Exception:
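
The environment-creation lines hidden in the hunk above wrap the recreated environment in gym's FlattenObservation wrapper, as the NOTE comment indicates. A sketch of that step (the environment id is a placeholder):

env = gym.wrappers.FlattenObservation(gym.make("<YOUR_ENV_NAME>"))  # Flatten the observation space for the algorithm.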
3 changes: 2 additions & 1 deletion setup.cfg
@@ -11,9 +11,9 @@ keywords = rl, openai gym
license = Rick Staa copyright
license_file = LICENSE
classifiers =
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Natural Language :: English",
"Topic :: Scientific/Engineering"

@@ -70,6 +70,7 @@ docs =
sphinx-autobuild
sphinx-rtd-theme
myst_parser
pygments >=2.11.2
dev =
%(tf)s
%(tuning)s
