Skip to content

Commit

Permalink
Update baselines: recover plot.ipynb / more generic make_env
Browse files: browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Jul 3, 2019
1 parent 1b8748f commit a912e2f
Show file tree
Hide file tree
Showing 9 changed files with 4,004 additions and 60 deletions.
14 changes: 8 additions & 6 deletions baselines/cem/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,25 @@
})


def make_env(config, seed):
def make_env(config, seed, mode):
assert mode in ['train', 'eval']
def _make_env():
env = gym.make(config['env.id'])
if config['env.clip_action'] and isinstance(env.action_space, Box):
env = ClipAction(env)
return env
env = make_vec_env(_make_env, 1, seed) # single environment
if mode == 'train':
env = VecMonitor(env)
if config['env.standardize_obs']:
env = VecStandardizeObservation(env, clip=5.)
return env


def initializer(config, seed, device):
torch.set_num_threads(1) # VERY IMPORTANT TO AVOID GETTING STUCK
global env
env = make_env(config, seed)
env = VecMonitor(env)
if config['env.standardize_obs']:
env = VecStandardizeObservation(env, clip=5.)
env = make_env(config, seed, 'train')
global agent
agent = Agent(config, env, device)

Expand All @@ -94,7 +96,7 @@ def run(config, seed, device, logdir):
set_global_seeds(seed)

print('Initializing...')
agent = Agent(config, make_env(config, seed), device)
agent = Agent(config, make_env(config, seed, 'eval'), device)
es = CEM([config['train.mu0']]*agent.num_params, config['train.std0'],
{'popsize': config['train.popsize'],
'seed': seed,
Expand Down
14 changes: 8 additions & 6 deletions baselines/cmaes/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,25 @@
})


def make_env(config, seed):
def make_env(config, seed, mode):
assert mode in ['train', 'eval']
def _make_env():
env = gym.make(config['env.id'])
if config['env.clip_action'] and isinstance(env.action_space, Box):
env = ClipAction(env)
return env
env = make_vec_env(_make_env, 1, seed) # single environment
if mode == 'train':
env = VecMonitor(env)
if config['env.standardize_obs']:
env = VecStandardizeObservation(env, clip=5.)
return env


def initializer(config, seed, device):
torch.set_num_threads(1) # VERY IMPORTANT TO AVOID GETTING STUCK
global env
env = make_env(config, seed)
env = VecMonitor(env)
if config['env.standardize_obs']:
env = VecStandardizeObservation(env, clip=5.)
env = make_env(config, seed, 'train')
global agent
agent = Agent(config, env, device)

Expand All @@ -92,7 +94,7 @@ def run(config, seed, device, logdir):
set_global_seeds(seed)

print('Initializing...')
agent = Agent(config, make_env(config, seed), device)
agent = Agent(config, make_env(config, seed, 'eval'), device)
es = CMAES([config['train.mu0']]*agent.num_params, config['train.std0'],
{'popsize': config['train.popsize'],
'seed': seed})
Expand Down
15 changes: 7 additions & 8 deletions baselines/ddpg/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,24 @@
})


def make_env(config, seed):
def make_env(config, seed, mode):
assert mode in ['train', 'eval']
def _make_env():
env = gym.make(config['env.id'])
env = ClipAction(env)
return env
env = make_vec_env(_make_env, 1, seed) # single environment
env = VecMonitor(env)
if mode == 'train':
env = VecStepInfo(env)
return env


def run(config, seed, device, logdir):
set_global_seeds(seed)

env = make_env(config, seed)
env = VecMonitor(env)
env = VecStepInfo(env)

eval_env = make_env(config, seed)
eval_env = VecMonitor(eval_env)

env = make_env(config, seed, 'train')
eval_env = make_env(config, seed, 'eval')
agent = Agent(config, env, device)
replay = ReplayBuffer(env, config['replay.capacity'], device)
engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)
Expand Down
14 changes: 8 additions & 6 deletions baselines/openaies/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,25 @@
})


def make_env(config, seed):
def make_env(config, seed, mode):
assert mode in ['train', 'eval']
def _make_env():
env = gym.make(config['env.id'])
if config['env.clip_action'] and isinstance(env.action_space, Box):
env = ClipAction(env)
return env
env = make_vec_env(_make_env, 1, seed) # single environment
if mode == 'train':
env = VecMonitor(env)
if config['env.standardize_obs']:
env = VecStandardizeObservation(env, clip=5.)
return env


def initializer(config, seed, device):
torch.set_num_threads(1) # VERY IMPORTANT TO AVOID GETTING STUCK
global env
env = make_env(config, seed)
env = VecMonitor(env)
if config['env.standardize_obs']:
env = VecStandardizeObservation(env, clip=5.)
env = make_env(config, seed, 'train')
global agent
agent = Agent(config, env, device)

Expand All @@ -98,7 +100,7 @@ def run(config, seed, device, logdir):
set_global_seeds(seed)

print('Initializing...')
agent = Agent(config, make_env(config, seed), device)
agent = Agent(config, make_env(config, seed, 'eval'), device)
es = OpenAIES([config['train.mu0']]*agent.num_params, config['train.std0'],
{'popsize': config['train.popsize'],
'seed': seed,
Expand Down

0 comments on commit a912e2f

Please sign in to comment.