Skip to content

Commit

Permalink
Adds custom gym environment support
Browse files Browse the repository at this point in the history
This commit contains a solution to the problems explained in openai#142. It
adds two ways for people to use a custom environment with the Spinningup
package:

 1. People can add the code that initializes the custom gym environment
    in the 'env_config.py' file.
 2. People can use the 'env_pkg' cmd line argument to specify which
    the package should be imported for the custom environment to work.
  • Loading branch information
rickstaa committed Feb 23, 2021
1 parent 038665d commit 1c90283
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 35 deletions.
13 changes: 13 additions & 0 deletions spinup/env_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Module used for adding custom environments.
Example:
This module allows you to add your own custom gym environments to the
Spinningup package. These environments should inherit from the :class:`gym.env`
class. See
`this issue on the openai github <https://github.com/openai/gym/blob/master/docs/creating-environments.md>`_
for more information on how to create custom environments.
.. code-block:: python
import custom environment
""" # noqa: E501
24 changes: 21 additions & 3 deletions spinup/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import spinup
from spinup.user_config import DEFAULT_BACKEND
from spinup.utils.gym_utils import import_gym_env_pkg
from spinup.utils.run_utils import ExperimentGrid
from spinup.utils.serialization_utils import convert_json
import argparse
Expand Down Expand Up @@ -151,6 +152,23 @@ def process(arg):
assert cmd in add_with_backends(MPI_COMPATIBLE_ALGOS), \
friendly_err("This algorithm can't be run with num_cpu > 1.")

# Try to import custom environments
try:
import spinup.env_config # noqa: F401 - Import custom environments
except Exception as e:
raise Exception(
"Something went wrong when trying to import the 'env_config' file."
) from e
env_pkg_err_msg = ""
if "env_pkg" in arg_dict.keys():
try:
import_gym_env_pkg(arg_dict["env_pkg"], frail=False)
except ImportError:
env_pkg_err_msg = (
"\n\t\t* Make sure the package you supplied in the 'env_pkg' contains a "
"a valid gym environment.\n"
)

# Special handling for environment: make sure that env_name is a real,
# registered gym environment.
valid_envs = [e.id for e in list(gym.envs.registry.all())]
Expand All @@ -168,11 +186,11 @@ def process(arg):
* View the complete list of valid Gym environments at
https://gym.openai.com/envs/
"""%env_name)
%s
""" % (env_name, env_pkg_err_msg)
)
assert env_name in valid_envs, err_msg


# Construct and execute the experiment grid.
eg = ExperimentGrid(name=exp_name)
for k,v in arg_dict.items():
Expand Down
46 changes: 46 additions & 0 deletions spinup/utils/gym_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Module that contains utilities that can be used with the
`openai gym package <https://github.com/openai/gym>`_.
"""

import importlib
import sys


def import_gym_env_pkg(module_name, frail=True, dry_run=False):
"""Tries to import the custom gym environment package.
Args:
module_name (str): The python module you want to import.
frail (bool, optional): Throw ImportError when tensorflow can not be imported.
Defaults to ``true``.
dry_run (bool, optional): Do not actually import tensorflow if available.
Defaults to ``False``.
Raises:
ImportError: A import error if the package could not be imported.
Returns:
union[tf, bool]: Custom env package if ``dry_run`` is set to ``False``.
Returns a success bool if ``dry_run`` is set to ``True``.
"""
module_name = module_name[0] if isinstance(module_name, list) else module_name
try:
if module_name in sys.modules:
if not dry_run:
return sys.modules[module_name]
else:
return True
elif importlib.util.find_spec(module_name) is not None:
if not dry_run:
return importlib.import_module(module_name)
else:
return True
else:
if frail:
raise ImportError("No module named '{}'.".format(module_name))
return False
except (ImportError, KeyError, AttributeError) as e:
if ImportError:
if not frail:
return False
raise e
86 changes: 54 additions & 32 deletions spinup/utils/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from spinup.utils.logx import colorize
from spinup.utils.mpi_tools import mpi_fork, msg
from spinup.utils.serialization_utils import convert_json
from spinup.utils.gym_utils import import_gym_env_pkg
import base64
from copy import deepcopy
import cloudpickle
Expand All @@ -26,7 +27,7 @@ def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=False):
"""
Sets up the output_dir for a logger and returns a dict for logger kwargs.
If no seed is given and datestamp is false,
If no seed is given and datestamp is false,
::
Expand All @@ -44,8 +45,8 @@ def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=False):
output_dir = data_dir/YY-MM-DD_exp_name/YY-MM-DD_HH-MM-SS_exp_name_s[seed]
You can force datestamp=True by setting ``FORCE_DATESTAMP=True`` in
``spinup/user_config.py``.
You can force datestamp=True by setting ``FORCE_DATESTAMP=True`` in
``spinup/user_config.py``.
Args:
Expand All @@ -70,7 +71,7 @@ def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=False):
# Make base path
ymd_time = time.strftime("%Y-%m-%d_") if datestamp else ''
relpath = ''.join([ymd_time, exp_name])

if seed is not None:
# Make a seed-specific subfolder in the experiment directory.
if datestamp:
Expand All @@ -81,30 +82,30 @@ def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=False):
relpath = osp.join(relpath, subfolder)

data_dir = data_dir or DEFAULT_DATA_DIR
logger_kwargs = dict(output_dir=osp.join(data_dir, relpath),
logger_kwargs = dict(output_dir=osp.join(data_dir, relpath),
exp_name=exp_name)
return logger_kwargs


def call_experiment(exp_name, thunk, seed=0, num_cpu=1, data_dir=None,
def call_experiment(exp_name, thunk, seed=0, num_cpu=1, data_dir=None,
datestamp=False, **kwargs):
"""
Run a function (thunk) with hyperparameters (kwargs), plus configuration.
This wraps a few pieces of functionality which are useful when you want
to run many experiments in sequence, including logger configuration and
splitting into multiple processes for MPI.
splitting into multiple processes for MPI.
There's also a SpinningUp-specific convenience added into executing the
thunk: if ``env_name`` is one of the kwargs passed to call_experiment, it's
assumed that the thunk accepts an argument called ``env_fn``, and that
the ``env_fn`` should make a gym environment with the given ``env_name``.
the ``env_fn`` should make a gym environment with the given ``env_name``.
The way the experiment is actually executed is slightly complicated: the
function is serialized to a string, and then ``run_entrypoint.py`` is
executed in a subprocess call with the serialized string as an argument.
``run_entrypoint.py`` unserializes the function call and executes it.
We choose to do it this way---instead of just calling the function
We choose to do it this way---instead of just calling the function
directly here---to avoid leaking state between successive experiments.
Args:
Expand All @@ -121,7 +122,7 @@ def call_experiment(exp_name, thunk, seed=0, num_cpu=1, data_dir=None,
data_dir (string): Used in configuring the logger, to decide where
to store experiment results. Note: if left as None, data_dir will
default to ``DEFAULT_DATA_DIR`` from ``spinup/user_config.py``.
default to ``DEFAULT_DATA_DIR`` from ``spinup/user_config.py``.
**kwargs: All kwargs to pass to thunk.
Expand Down Expand Up @@ -150,7 +151,28 @@ def call_experiment(exp_name, thunk, seed=0, num_cpu=1, data_dir=None,
def thunk_plus():
# Make 'env_fn' from 'env_name'
if 'env_name' in kwargs:

# Import main gym environments
import gym

# Import custom gym environments
try:
import spinup.env_config # noqa: F401 - Import custom environments
except Exception as e:
raise Exception(
"Something went wrong when trying to import the 'env_config' file."
) from e
if "env_pkg" in kwargs.keys():
env_pkg = kwargs.pop("env_pkg")
try:
import_gym_env_pkg(env_pkg)
except ImportError as e:
import_error_msg = (
"{} Please make sure you supplied a valid package ".format(e)
+ "in the 'env_pkg' input argument."
)
raise ImportError(import_error_msg) from e

env_name = kwargs['env_name']
kwargs['env_fn'] = lambda : gym.make(env_name)
del kwargs['env_name']
Expand All @@ -174,9 +196,9 @@ def thunk_plus():
There appears to have been an error in your experiment.
Check the traceback above to see what actually went wrong. The
Check the traceback above to see what actually went wrong. The
traceback below, included for completeness (but probably not useful
for diagnosing the error), shows the stack leading up to the
for diagnosing the error), shows the stack leading up to the
experiment launch.
""") + '='*DIV_LINE_WIDTH + '\n'*3
Expand Down Expand Up @@ -215,7 +237,7 @@ def all_bools(vals):
return all([isinstance(v,bool) for v in vals])

def valid_str(v):
"""
"""
Convert a value or values to a string which could go in a filepath.
Partly based on `this gist`_.
Expand All @@ -230,7 +252,7 @@ def valid_str(v):
return '-'.join([valid_str(x) for x in v])

# Valid characters are '-', '_', and alphanumeric. Replace invalid chars
# with '-'.
# with '-'.
str_v = str(v).lower()
valid_chars = "-_%s%s" % (string.ascii_letters, string.digits)
str_v = ''.join(c if c in valid_chars else '-' for c in str_v)
Expand Down Expand Up @@ -293,7 +315,7 @@ def print(self):


def _default_shorthand(self, key):
# Create a default shorthand for the key, built from the first
# Create a default shorthand for the key, built from the first
# three letters of each colon-separated part.
# But if the first three letters contains something which isn't
# alphanumeric, shear that off.
Expand All @@ -310,16 +332,16 @@ def add(self, key, vals, shorthand=None, in_name=False):
By default, if a shorthand isn't given, one is automatically generated
from the key using the first three letters of each colon-separated
term. To disable this behavior, change ``DEFAULT_SHORTHAND`` in the
``spinup/user_config.py`` file to ``False``.
``spinup/user_config.py`` file to ``False``.
Args:
key (string): Name of parameter.
vals (value or list of values): Allowed values of parameter.
shorthand (string): Optional, shortened name of parameter. For
shorthand (string): Optional, shortened name of parameter. For
example, maybe the parameter ``steps_per_epoch`` is shortened
to ``steps``.
to ``steps``.
in_name (bool): When constructing variant names, force the
inclusion of this parameter into the name.
Expand All @@ -340,16 +362,16 @@ def variant_name(self, variant):
"""
Given a variant (dict of valid param/value pairs), make an exp_name.
A variant's name is constructed as the grid name (if you've given it
one), plus param names (or shorthands if available) and values
A variant's name is constructed as the grid name (if you've given it
one), plus param names (or shorthands if available) and values
separated by underscores.
Note: if ``seed`` is a parameter, it is not included in the name.
"""

def get_val(v, k):
# Utility method for getting the correct value out of a variant
# given as a nested dict. Assumes that a parameter name, k,
# given as a nested dict. Assumes that a parameter name, k,
# describes a path into the nested dict, such that k='a:b:c'
# corresponds to value=variant['a']['b']['c']. Uses recursion
# to get this.
Expand All @@ -370,7 +392,7 @@ def get_val(v, k):
# Include a parameter in a name if either 1) it can take multiple
# values, or 2) the user specified that it must appear in the name.
# Except, however, when the parameter is 'seed'. Seed is handled
# differently so that runs of the same experiment, with different
# differently so that runs of the same experiment, with different
# seeds, will be grouped by experiment name.
if (len(v)>1 or inn) and not(k=='seed'):

Expand All @@ -382,7 +404,7 @@ def get_val(v, k):
variant_val = get_val(variant, k)

# Append to name
if all_bools(v):
if all_bools(v):
# If this is a param which only takes boolean values,
# only include in the name if it's True for this variant.
var_name += ('_' + param_name) if variant_val else ''
Expand Down Expand Up @@ -438,13 +460,13 @@ def variants(self):
a : 1,
b : 2
}
}
}
}
"""
flat_variants = self._variants(self.keys, self.vals)

def unflatten_var(var):
"""
"""
Build the full nested dict version of var, based on key names.
"""
new_var = dict()
Expand Down Expand Up @@ -482,11 +504,11 @@ def run(self, thunk, num_cpu=1, data_dir=None, datestamp=False):
Run each variant in the grid with function 'thunk'.
Note: 'thunk' must be either a callable function, or a string. If it is
a string, it must be the name of a parameter whose values are all
a string, it must be the name of a parameter whose values are all
callable functions.
Uses ``call_experiment`` to actually launch each experiment, and gives
each variant a name using ``self.variant_name()``.
each variant a name using ``self.variant_name()``.
Maintenance note: the args for ExperimentGrid.run should track closely
to the args for call_experiment. However, ``seed`` is omitted because
Expand All @@ -503,7 +525,7 @@ def run(self, thunk, num_cpu=1, data_dir=None, datestamp=False):
var_names = set([self.variant_name(var) for var in variants])
var_names = sorted(list(var_names))
line = '='*DIV_LINE_WIDTH
preparing = colorize('Preparing to run the following experiments...',
preparing = colorize('Preparing to run the following experiments...',
color='green', bold=True)
joined_var_names = '\n'.join(var_names)
announcement = f"\n{preparing}\n\n{joined_var_names}\n\n{line}"
Expand All @@ -520,8 +542,8 @@ def run(self, thunk, num_cpu=1, data_dir=None, datestamp=False):
"""), color='cyan', bold=True)+line
print(delay_msg)
wait, steps = WAIT_BEFORE_LAUNCH, 100
prog_bar = trange(steps, desc='Launching in...',
leave=False, ncols=DIV_LINE_WIDTH,
prog_bar = trange(steps, desc='Launching in...',
leave=False, ncols=DIV_LINE_WIDTH,
mininterval=0.25,
bar_format='{desc}: {bar}| {remaining} {elapsed}')
for _ in prog_bar:
Expand All @@ -534,15 +556,15 @@ def run(self, thunk, num_cpu=1, data_dir=None, datestamp=False):
# Figure out what the thunk is.
if isinstance(thunk, str):
# Assume one of the variant parameters has the same
# name as the string you passed for thunk, and that
# name as the string you passed for thunk, and that
# variant[thunk] is a valid callable function.
thunk_ = var[thunk]
del var[thunk]
else:
# Assume thunk is given as a function.
thunk_ = thunk

call_experiment(exp_name, thunk_, num_cpu=num_cpu,
call_experiment(exp_name, thunk_, num_cpu=num_cpu,
data_dir=data_dir, datestamp=datestamp, **var)


Expand Down

0 comments on commit 1c90283

Please sign in to comment.