Skip to content

Commit

Permalink
Now can read from a module or parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed Mar 6, 2019
1 parent 9404102 commit 31ad076
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
22 changes: 19 additions & 3 deletions digideep/environment/make_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import gym
# Even though we don't need dm_control to be loaded here, it helps in initializing glfw.
import digideep.environment.dmc2gym

from gym.envs.registration import registry
################################################


Expand Down Expand Up @@ -76,20 +78,34 @@ class MakeEnvironment:
Except :class:`~digideep.environment.common.monitor.Monitor` environment, no environment will be applied on the environment
unless explicitly specified.
"""

registered = False

def __init__(self, session, mode, seed, **params):
self.mode = mode # train/test/eval
self.seed = seed
self.session = session
self.params = params

# Load user-defined modules in which the environment will be registered:
if params["module"]:
# Won't we have several environment registrations by this?
if params["from_module"]:
try:
get_module(params["module"])
get_module(params["from_module"])
except Exception as ex:
logger.fatal("While importing user module:", ex)
exit()
elif (params["from_params"]) and (not MakeEnvironment.registered):
try:
registry.register(**params["register_args"])
MakeEnvironment.registered = True
except Exception as ex:
logger.fatal("While registering from parameters:", ex)
exit()

# After all of these, check if environment is registered in the gym or not.
if not params["name"] in registry.env_specs:
logger.fatal("Environment '" + params["name"] + "' is not registered in the gym registry.")
exit()

def make_env(self, rank, force_no_monitor=False):
import sys # For debugging
Expand Down
5 changes: 3 additions & 2 deletions digideep/params/atari_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

# Asteroids-v4 | AirRaid-v4 | Alien-v4 | VideoPinball-v4
cpanel["model_name"] = 'PongNoFrameskip-v4' # Atari Env
cpanel["env_module"] = ''

# General Parameters
# num_frames = 10e6 # Number of frames to train
Expand Down Expand Up @@ -89,7 +88,9 @@ def gen_params(cpanel):
# Environment
params["env"] = {}
params["env"]["name"] = cpanel["model_name"]
params["env"]["module"] = cpanel["env_module"]

params["env"]["from_module"] = cpanel.get("from_module", '')
params["env"]["from_params"] = cpanel.get("from_params", False)

params["env"]["wrappers"] = {"add_monitor": cpanel["add_monitor"],
"add_time_step": cpanel["add_time_step"],
Expand Down
6 changes: 3 additions & 3 deletions digideep/params/classic_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@

# Acrobot-v1 | CartPole-v1 | MountainCarContinuous-v0
cpanel["model_name"] = 'Pendulum-v0' # Classic Control Env
cpanel["env_module"] = ''
# Other possible modules: roboschool | pybullet_envs

# General Parameters
# num_frames = 10e6 # Number of frames to train
Expand Down Expand Up @@ -91,7 +89,9 @@ def gen_params(cpanel):
# Environment
params["env"] = {}
params["env"]["name"] = cpanel["model_name"]
params["env"]["module"] = cpanel["env_module"]

params["env"]["from_module"] = cpanel.get("from_module", '')
params["env"]["from_params"] = cpanel.get("from_params", False)

params["env"]["wrappers"] = {"add_monitor": cpanel["add_monitor"],
"add_time_step": cpanel["add_time_step"],
Expand Down
7 changes: 4 additions & 3 deletions digideep/params/mujoco_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
# 'HalfCheetah-v2'
# 'DMBenchHumanoidStand-v0' | 'DMBenchCheetahRun-v0'
cpanel["model_name"] = 'Ant-v2' # MuJoCo Env
cpanel["env_module"] = ''
# Other possible modules: roboschool | pybullet_envs

# General Parameters
# num_frames = 10e6 # Number of frames to train
Expand Down Expand Up @@ -92,7 +90,10 @@ def gen_params(cpanel):
# Environment
params["env"] = {}
params["env"]["name"] = cpanel["model_name"]
params["env"]["module"] = cpanel["env_module"]

# Other possible modules: roboschool | pybullet_envs
params["env"]["from_module"] = cpanel.get("from_module", '')
params["env"]["from_params"] = cpanel.get("from_params", False)

params["env"]["wrappers"] = {"add_monitor": cpanel["add_monitor"],
"add_time_step": cpanel["add_time_step"],
Expand Down

0 comments on commit 31ad076

Please sign in to comment.