Skip to content

Commit

Permalink
Accepting extra parameters for environments
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed May 29, 2019
1 parent 10f77f5 commit ae8062a
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 11 deletions.
16 changes: 12 additions & 4 deletions digideep/environment/dmc2gym/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ def __init__(self, task, task_kwargs=None, environment_kwargs=None, visualize_re
self.visualize_reward = visualize_reward
EzPickle.__init__(self, task, task_kwargs=task_kwargs, environment_kwargs=environment_kwargs, visualize_reward=visualize_reward)

def __call__(self):
def __call__(self, **extra_env_kwargs):
"""
Returns:
:obj:`dm_control.rl.control.Environment`: The ``dm_control`` environment.
"""
task_kwargs_subs = self.task_kwargs or {}
if self.environment_kwargs is not None:

if extra_env_kwargs:
task_kwargs_subs.update(extra_env_kwargs)

if (self.environment_kwargs is not None) or not (self.environment_kwargs == {}):
task_kwargs_subs = task_kwargs_subs.copy()
task_kwargs_subs['environment_kwargs'] = self.environment_kwargs
dmcenv = self.task(**task_kwargs_subs)
Expand All @@ -55,19 +59,23 @@ class EnvCreatorSuite(EzPickle):
environment_kwargs (dict): The keywords that will pass to the environment maker function.
visualize_reward (bool): Whether to visualize rewards in the viewer or not.
"""
def __init__(self, domain_name, task_name, task_kwargs=None, environment_kwargs=None, visualize_reward=False):
    """Store the parameters used later to load a ``dm_control`` suite task.

    Args:
        domain_name (str): Passed as ``domain_name`` to ``suite.load``.
        task_name (str): Passed as ``task_name`` to ``suite.load``.
        task_kwargs (dict): Keyword arguments for the task. ``None`` (the
            default) is replaced by a new, empty dict owned by this instance.
        environment_kwargs (dict): The keywords that will pass to the
            environment maker function.
        visualize_reward (bool): Whether to visualize rewards in the viewer
            or not.
    """
    self.domain_name = domain_name
    self.task_name = task_name
    # ``__call__`` merges extra keyword arguments into ``self.task_kwargs``
    # in place, so every instance needs its own dict. A ``{}`` default in the
    # signature would be one shared object, and extras passed to one creator
    # would leak into every other creator built without ``task_kwargs``.
    self.task_kwargs = {} if task_kwargs is None else task_kwargs
    self.environment_kwargs = environment_kwargs
    self.visualize_reward = visualize_reward
    # Forward the normalised dict (not the raw argument) so EzPickle holds
    # the same object that ``__call__`` later updates.
    EzPickle.__init__(self, domain_name, task_name, task_kwargs=self.task_kwargs, environment_kwargs=environment_kwargs, visualize_reward=visualize_reward)

def __call__(self):
def __call__(self, **extra_env_kwargs):
"""
Returns:
:obj:`dm_control.rl.control.Environment`: The ``dm_control`` environment.
"""
if extra_env_kwargs:
self.task_kwargs.update(extra_env_kwargs)

dmcenv = suite.load(domain_name=self.domain_name,
task_name=self.task_name,
task_kwargs=self.task_kwargs,
Expand Down
4 changes: 2 additions & 2 deletions digideep/environment/dmc2gym/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class DmControlWrapper(Env, EzPickle):
A callable object can delay the creation of the environment until the time we need it.
flat_observation (bool): Whether to flatten the observation dict or not.
"""
def __init__(self, dmcenv_creator, flat_observation=False, observation_key="agent"):
self.dmcenv = dmcenv_creator()
def __init__(self, dmcenv_creator, flat_observation=False, observation_key="agent", **extra_env_kwargs):
self.dmcenv = dmcenv_creator(**extra_env_kwargs)
self._flat_observation = flat_observation
self._observation_key = observation_key
# NOTE: We do not use the following to flatten observation to have more control over flattening and extracting "info".
Expand Down
3 changes: 2 additions & 1 deletion digideep/environment/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def __init__(self, session, agents=None, **params):
self.params = params

# Create models
extra_env_kwargs = self.params.get("extra_env_kwargs", {})
menv = MakeEnvironment(session, mode=self.params["mode"], seed=self.params["seed"], **self.params["env"])
self.envs = menv.create_envs(num_workers=self.params["num_workers"])
self.envs = menv.create_envs(num_workers=self.params["num_workers"], extra_env_kwargs=extra_env_kwargs)

self.state = {}
self.state["steps"] = 0
Expand Down
9 changes: 5 additions & 4 deletions digideep/environment/make_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ def __init__(self, session, mode, seed, **params):
logger.fatal("Environment '" + params["name"] + "' is not registered in the gym registry.")
exit()

def make_env(self, rank, force_no_monitor=False):
def make_env(self, rank, force_no_monitor=False, extra_env_kwargs={}):
import sys # For debugging
def _f():
env = gym.make(self.params["name"])
# The header of gym.make(.): `def make(id, **kwargs)`
env = gym.make(self.params["name"], **extra_env_kwargs)
env.seed(self.seed + rank)

## Atari environment wrappers
Expand Down Expand Up @@ -159,8 +160,8 @@ def _f():
return env
return _f

def create_envs(self, num_workers=1, force_no_monitor=False):
envs = [self.make_env(rank=idx, force_no_monitor=force_no_monitor) for idx in range(num_workers)]
def create_envs(self, num_workers=1, force_no_monitor=False, extra_env_kwargs={}):
envs = [self.make_env(rank=idx, force_no_monitor=force_no_monitor, extra_env_kwargs=extra_env_kwargs) for idx in range(num_workers)]

## NOTE: We do not use DummyVecEnvs when num_workers==1 to avoid running glfw.init() on the Main process.
if self.mode == "eval":
Expand Down
3 changes: 3 additions & 0 deletions digideep/params/atari_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def gen_params(cpanel):
params["explorer"]["train"]["render"] = False
params["explorer"]["train"]["render_delay"] = 0
params["explorer"]["train"]["seed"] = cpanel["seed"] # + 3500
params["explorer"]["train"]["extra_env_kwargs"] = {}

params["explorer"]["test"] = {}
params["explorer"]["test"]["mode"] = "test"
Expand All @@ -273,6 +274,7 @@ def gen_params(cpanel):
params["explorer"]["test"]["render"] = False
params["explorer"]["test"]["render_delay"] = 0
params["explorer"]["test"]["seed"] = cpanel["seed"] + 100 # We want to make the seed of test environments different from training.
params["explorer"]["test"]["extra_env_kwargs"] = {}

params["explorer"]["eval"] = {}
params["explorer"]["eval"]["mode"] = "eval"
Expand All @@ -285,6 +287,7 @@ def gen_params(cpanel):
params["explorer"]["eval"]["render"] = True
params["explorer"]["eval"]["render_delay"] = 0
params["explorer"]["eval"]["seed"] = cpanel["seed"] + 101 # We want to make the seed of eval environment different from test/train.
params["explorer"]["eval"]["extra_env_kwargs"] = {}
##############################################

return params
Expand Down
3 changes: 3 additions & 0 deletions digideep/params/classic_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def gen_params(cpanel):
params["explorer"]["train"]["render"] = False
params["explorer"]["train"]["render_delay"] = 0
params["explorer"]["train"]["seed"] = cpanel["seed"] # + 3500
params["explorer"]["train"]["extra_env_kwargs"] = {}

params["explorer"]["test"] = {}
params["explorer"]["test"]["mode"] = "test"
Expand All @@ -269,6 +270,7 @@ def gen_params(cpanel):
params["explorer"]["test"]["render"] = False
params["explorer"]["test"]["render_delay"] = 0
params["explorer"]["test"]["seed"] = cpanel["seed"] + 100 # We want to make the seed of test environments different from training.
params["explorer"]["test"]["extra_env_kwargs"] = {}

params["explorer"]["eval"] = {}
params["explorer"]["eval"]["mode"] = "eval"
Expand All @@ -281,6 +283,7 @@ def gen_params(cpanel):
params["explorer"]["eval"]["render"] = True
params["explorer"]["eval"]["render_delay"] = 0
params["explorer"]["eval"]["seed"] = cpanel["seed"] + 101 # We want to make the seed of eval environment different from test/train.
params["explorer"]["eval"]["extra_env_kwargs"] = {}
##############################################

return params
Expand Down
3 changes: 3 additions & 0 deletions digideep/params/mujoco_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def gen_params(cpanel):
params["explorer"]["train"]["render"] = False
params["explorer"]["train"]["render_delay"] = 0
params["explorer"]["train"]["seed"] = cpanel["seed"] # + 3500
params["explorer"]["train"]["extra_env_kwargs"] = {}

params["explorer"]["test"] = {}
params["explorer"]["test"]["mode"] = "test"
Expand All @@ -261,6 +262,7 @@ def gen_params(cpanel):
params["explorer"]["test"]["render"] = False
params["explorer"]["test"]["render_delay"] = 0
params["explorer"]["test"]["seed"] = cpanel["seed"] + 100 # We want to make the seed of test environments different from training.
params["explorer"]["test"]["extra_env_kwargs"] = {}

params["explorer"]["eval"] = {}
params["explorer"]["eval"]["mode"] = "eval"
Expand All @@ -273,6 +275,7 @@ def gen_params(cpanel):
params["explorer"]["eval"]["render"] = True
params["explorer"]["eval"]["render_delay"] = 0
params["explorer"]["eval"]["seed"] = cpanel["seed"] + 101 # We want to make the seed of eval environment different from test/train.
params["explorer"]["eval"]["extra_env_kwargs"] = {}
##############################################

return params
Expand Down
3 changes: 3 additions & 0 deletions digideep/params/sac_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def gen_params(cpanel):
params["explorer"]["train"]["render"] = False
params["explorer"]["train"]["render_delay"] = 0
params["explorer"]["train"]["seed"] = cpanel["seed"] # + 3500
params["explorer"]["train"]["extra_env_kwargs"] = {}

params["explorer"]["test"] = {}
params["explorer"]["test"]["mode"] = "test"
Expand All @@ -280,6 +281,7 @@ def gen_params(cpanel):
params["explorer"]["test"]["render"] = False
params["explorer"]["test"]["render_delay"] = 0
params["explorer"]["test"]["seed"] = cpanel["seed"] + 100 # We want to make the seed of test environments different from training.
params["explorer"]["test"]["extra_env_kwargs"] = {}

params["explorer"]["eval"] = {}
params["explorer"]["eval"]["mode"] = "eval"
Expand All @@ -292,6 +294,7 @@ def gen_params(cpanel):
params["explorer"]["eval"]["render"] = True
params["explorer"]["eval"]["render_delay"] = 0
params["explorer"]["eval"]["seed"] = cpanel["seed"] + 101 # We want to make the seed of eval environment different from test/train.
params["explorer"]["eval"]["extra_env_kwargs"] = {}
##############################################

return params
Expand Down

0 comments on commit ae8062a

Please sign in to comment.