Skip to content

Commit

Permalink
fix: fix several env/policy load bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
rickstaa committed Feb 17, 2022
1 parent 5ab830a commit dddd4d8
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 69 deletions.
1 change: 0 additions & 1 deletion TODOs.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
## Main

* \[ ] Fix environment test problems.
* \[ ] Fix save\_info.json not found in load\_pytorch\_policy when in parent folder.

## Docs

Expand Down
49 changes: 49 additions & 0 deletions bayesian_learning_control/control/common/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Module containing several control related exceptions."""


class EnvLoadError(Exception):
"""Custom exception that is raised when the saved environment could not be loaded.
Attributes:
log_message (str): The full log message.
details (dict): Dictionary containing extra Exception information.
"""

def __init__(self, message="", log_message="", **details):
"""Initializes the EePoseLookupError exception object.
Args:
message (str, optional): Exception message specifying whether the exception
occurred. Defaults to ``""``.
log_message (str, optional): Full log message. Defaults to ``""``.
details (dict): Additional dictionary that can be used to supply the user
with more details about why the exception occurred.
"""
super().__init__(message)

self.log_message = log_message
self.details = details


class PolicyLoadError(Exception):
"""Custom exception that is raised when the saved policy could not be loaded.
Attributes:
log_message (str): The full log message.
details (dict): Dictionary containing extra Exception information.
"""

def __init__(self, message="", log_message="", **details):
"""Initializes the EePoseLookupError exception object.
Args:
message (str, optional): Exception message specifying whether the exception
occurred. Defaults to ``""``.
log_message (str, optional): Full log message. Defaults to ``""``.
details (dict): Additional dictionary that can be used to supply the user
with more details about why the exception occurred.
"""
super().__init__(message)

self.log_message = log_message
self.details = details
88 changes: 59 additions & 29 deletions bayesian_learning_control/control/utils/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
import os.path as osp
import time
from pathlib import Path
import re

import joblib
import torch
from bayesian_learning_control.control.common.exceptions import (
EnvLoadError,
PolicyLoadError,
)
from bayesian_learning_control.utils.import_utils import import_tf
from bayesian_learning_control.utils.log_utils import EpochLogger, log_to_std_out
from bayesian_learning_control.utils.log_utils.helpers import friendly_err
Expand Down Expand Up @@ -60,38 +65,50 @@ def _retrieve_model_folder(fpath):
- backend (:obj:`str`): The inferred backend. Options are ``tf`` and
``torch``.
"""
data_folders = glob.glob(fpath + r"/*_save")
if ("tf2_save" in fpath and "torch_save" in fpath) or len(data_folders) > 1:
data_folders = (
glob.glob(fpath + r"/*_save")
if not bool(re.search(r"/*_save", fpath))
else glob.glob(fpath)
)
if any(["tf2_save" in item for item in data_folders]) and any(
["torch_save" in item for item in data_folders]
):
raise IOError(
friendly_err(
"Policy could not be loaded as the model as the specified model folder "
f"Policy could not be loaded since the specified model folder "
f"'{fpath}' seems to be corrupted. It contains both a 'torch_save' and "
"'tf2_save' folder. Please check your model path (fpath) and try again."
)
)
elif "tf2_save" in fpath:
model_path = os.sep.join(
fpath.split(os.sep)[: fpath.split(os.sep).index("tf2_save") + 1]
elif (
len([item for item in data_folders if "tf2_save" in item]) > 1
or len([item for item in data_folders if "torch_save" in item]) > 1
):
raise IOError(
friendly_err(
"Policy could not be loaded since the specified model folder '{}' "
"seems to be corrupted. It contains multiple '{}' folders. Please "
"check your model path (fpath) and try again.".format(
fpath,
"tf2_save"
if any(["tf2_save" in item for item in data_folders])
else "torch_save",
)
)
)
elif any(["tf2_save" in item for item in data_folders]):
model_path = [item for item in data_folders if "tf2_save" in item][0]
return model_path, "tf"
elif "torch_save" in fpath:
model_path = os.sep.join(
fpath.split(os.sep)[: fpath.split(os.sep).index("torch_save") + 1]
)
elif any(["torch_save" in item for item in data_folders]):
model_path = [item for item in data_folders if "torch_save" in item][0]
return model_path, "torch"
else: # Check
if len(data_folders) == 0:
raise FileNotFoundError(
friendly_err(
f"No model was found inside the supplied model path '{fpath}'. "
"Please check your model path (fpath) and try again."
)
)
else:
return (
data_folders[0],
("torch" if "torch_save" in data_folders[0] else "tf"),
else:
raise FileNotFoundError(
friendly_err(
f"No model was found inside the supplied model path '{fpath}'. "
"Please check your model path (fpath) and try again."
)
)


def load_policy_and_env(fpath, itr="last"):
Expand All @@ -106,8 +123,10 @@ def load_policy_and_env(fpath, itr="last"):
Raises:
FileNotFoundError: Thrown when the fpath does not exist.
Exception: Thrown when something else goes wrong while loading the policy or
EnvLoadError: Thrown when something went wrong trying to load the saved
environment.
PolicyLoadError: Thrown when something went wrong trying to load the saved
policy.
Returns:
(tuple): tuple containing:
Expand Down Expand Up @@ -139,7 +158,7 @@ def load_policy_and_env(fpath, itr="last"):
state = joblib.load(Path(fpath).parent.joinpath("vars.pkl"))
env = state["env"]
except Exception as e:
raise Exception(
raise EnvLoadError(
friendly_err(
(
"Environment not found!\n\n It looks like the environment wasn't "
Expand All @@ -151,11 +170,21 @@ def load_policy_and_env(fpath, itr="last"):
) from e

# load the get_action function
if backend == "tf":
policy = load_tf_policy(fpath, itr, env)
else:
policy = load_pytorch_policy(fpath, itr, env)

try:
if backend == "tf":
policy = load_tf_policy(fpath, itr, env)
else:
policy = load_pytorch_policy(fpath, itr, env)
except Exception as e:
raise PolicyLoadError(
friendly_err(
(
"Policy not found!\n\n It looks like the policy wasn't "
"successfully saved. :( \n\n Check out the documentation page on "
"the Test Policy utility for how to handle this situation."
)
)
) from e
return env, policy


Expand Down Expand Up @@ -206,6 +235,7 @@ def load_pytorch_policy(fpath, itr="last", env=None):
torch.nn.Module: The policy.
"""

fpath, _ = _retrieve_model_folder(fpath)
if itr != "last":
fpath = _retrieve_iter_folder(fpath, itr)
model_file = Path(fpath).joinpath(
Expand Down
47 changes: 47 additions & 0 deletions examples/manual_env_policy_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""A small script which shows how to manually load a saved environment and policy when
the CLI fails.
"""

import gym
import ros_gazebo_gym # noqa: F401
from bayesian_learning_control.control.utils.test_policy import (
load_policy_and_env,
load_pytorch_policy,
load_tf_policy,
run_policy,
)

AGENT_TYPE = "torch" # The type of agent that was trained. Options: 'tf2' and 'torch'.
AGENT_FOLDER = "/home/ricks/Development/work/bayesian-learning-control/data/2022-02-17_staa_lac_panda_reach/2022-02-17_09-35-31-staa_lac_panda_reach_s25" # noqa: E501

if __name__ == "__main__":
# NOTE: STEP 1a: Try to load the policy and environment
try:
env, policy = load_policy_and_env(AGENT_FOLDER)
except Exception:
# NOTE: STEP: 1b: If step 1 fails recreate the environment and load the
# Pytorch/TF2 agent separately.

# Create the environment
# NOTE: Here the 'FlattenObservation' wrapper is used to make sure the alg works
# with dictionary based observation spaces.
env = gym.make("PandaReach-v1")
env = gym.wrappers.FlattenObservation(env)

# Load the policy
if AGENT_TYPE.lower() == "tf2":
policy = load_tf_policy(AGENT_FOLDER, itr="last", env=env) # Load TF2 agent
else:
policy = load_pytorch_policy(
AGENT_FOLDER, itr="last", env=env
) # Load Pytorch agent

# Step 2: Try to run the policy on the environment
try:
run_policy(env, policy)
except Exception:
raise Exception(
"Something went wrong while trying to run the inference. Please check the "
"'AGENT_FOLDER' and try again. If the problem persists please open a issue "
"on https://github.com/rickstaa/bayesian-learning-control/issues."
)
10 changes: 5 additions & 5 deletions experiments/staa_et_al_2022/lac_panda_reach_debug.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ gamma: 0.995
alpha3: 0.2
alpha: 0.99
labda: 0.99
epochs: 300
max_ep_len: 100
steps_per_epoch: 200000
epochs: 2
max_ep_len: 250
steps_per_epoch: 1000
update_every: 1000
update_after: 200000
update_after: 1000
steps_per_update: 1000
num_test_episodes: 2
batch_size: 256
replay_size: 1e6
save_freq: 100
save_freq: 1
save_checkpoints: True

# Env parameters
Expand Down
34 changes: 0 additions & 34 deletions sandbox/test_saved_env.py

This file was deleted.

0 comments on commit dddd4d8

Please sign in to comment.