Skip to content

Commit

Permalink
Load environment from external modules
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed Mar 5, 2019
1 parent 04db980 commit 9404102
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 24 deletions.
26 changes: 11 additions & 15 deletions digideep/environment/make_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,17 @@
from digideep.environment.wrappers import VecNormalize
from digideep.environment.wrappers import VecSaveState

from digideep.utility.toolbox import get_module

from digideep.utility.logging import logger
from gym import spaces

################################################
### Importing Environment Packages ###
################################################
import gym

# By importing "dmc2gym" the __init__ file of the package
# will be called and will register all the DM environments.
# Even though we don't need dm_control to be loaded here, it helps in initializing glfw.
import digideep.environment.dmc2gym

try:
import roboschool
except ImportError:
# logger.warn('roboschool is missing.')
pass

try:
import pybullet_envs
except ImportError:
# logger.warn('pybullet_envs is missing.')
pass
################################################


Expand Down Expand Up @@ -95,6 +83,14 @@ def __init__(self, session, mode, seed, **params):
self.session = session
self.params = params

# Load user-defined modules in which the environment will be registered:
if params["module"]:
try:
get_module(params["module"])
except Exception as ex:
logger.fatal("While importing user module:", ex)
exit()

def make_env(self, rank, force_no_monitor=False):
import sys # For debugging
def _f():
Expand Down
4 changes: 3 additions & 1 deletion digideep/params/atari_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

# Asteroids-v4 | AirRaid-v4 | Alien-v4 | VideoPinball-v4
cpanel["model_name"] = 'PongNoFrameskip-v4' # Atari Env
cpanel["env_module"] = ''

# General Parameters
# num_frames = 10e6 # Number of frames to train
Expand Down Expand Up @@ -87,7 +88,8 @@ def gen_params(cpanel):
params = {}
# Environment
params["env"] = {}
params["env"]["name"] = cpanel["model_name"]
params["env"]["name"] = cpanel["model_name"]
params["env"]["module"] = cpanel["env_module"]

params["env"]["wrappers"] = {"add_monitor": cpanel["add_monitor"],
"add_time_step": cpanel["add_time_step"],
Expand Down
10 changes: 5 additions & 5 deletions digideep/params/classic_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@

cpanel = OrderedDict()

cpanel["model_name"] = 'Pendulum-v0' # Classic Control Env
# Acrobot-v1 | CartPole-v1 | MountainCarContinuous-v0

# cpanel["model_name"] = 'DMBenchHumanoidStand-v0' # dm_control wrapped environment
# cpanel["model_name"] = 'DMBenchCheetahRun-v0' # dm_control wrapped environment
cpanel["model_name"] = 'Pendulum-v0' # Classic Control Env
cpanel["env_module"] = ''
# Other possible modules: roboschool | pybullet_envs

# General Parameters
# num_frames = 10e6 # Number of frames to train
Expand Down Expand Up @@ -91,7 +90,8 @@ def gen_params(cpanel):
params = {}
# Environment
params["env"] = {}
params["env"]["name"] = cpanel["model_name"]
params["env"]["name"] = cpanel["model_name"]
params["env"]["module"] = cpanel["env_module"]

params["env"]["wrappers"] = {"add_monitor": cpanel["add_monitor"],
"add_time_step": cpanel["add_time_step"],
Expand Down
5 changes: 4 additions & 1 deletion digideep/params/mujoco_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
# 'HalfCheetah-v2'
# 'DMBenchHumanoidStand-v0' | 'DMBenchCheetahRun-v0'
cpanel["model_name"] = 'Ant-v2' # MuJoCo Env
cpanel["env_module"] = ''
# Other possible modules: roboschool | pybullet_envs

# General Parameters
# num_frames = 10e6 # Number of frames to train
Expand Down Expand Up @@ -89,7 +91,8 @@ def gen_params(cpanel):
params = {}
# Environment
params["env"] = {}
params["env"]["name"] = cpanel["model_name"]
params["env"]["name"] = cpanel["model_name"]
params["env"]["module"] = cpanel["env_module"]

params["env"]["wrappers"] = {"add_monitor": cpanel["add_monitor"],
"add_time_step": cpanel["add_time_step"],
Expand Down
4 changes: 2 additions & 2 deletions digideep/pipeline/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def __init__(self, root_path):
if not self.is_loading:
try:
get_module(self.args["params"])
except:
logger.fatal("Neither a checkpoint nor a valid params file are specified!")
except Exception as ex:
logger.fatal("While importing user-specified params:", ex)
exit()

print(':: The session will be stored in ' + self.state['path_session'])
Expand Down

0 comments on commit 9404102

Please sign in to comment.