Skip to content

Commit

Permalink
Update baselines: CMA-ES
Browse files Browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Aug 8, 2019
1 parent eb41d18 commit bd8aff8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 75 deletions.
31 changes: 12 additions & 19 deletions baselines/cmaes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from gym.spaces import Discrete
from gym.spaces import Box
from gym.spaces import flatdim

import gym.spaces as spaces
from lagom import BaseAgent
from lagom.utils import pickle_dump
from lagom.utils import tensorify
from lagom.utils import numpify
from lagom.envs.wrappers import get_wrapper
from lagom.networks import Module
from lagom.networks import make_fc
from lagom.networks import CategoricalHead
Expand All @@ -26,7 +21,7 @@ def __init__(self, config, env, device, **kwargs):
self.env = env
self.device = device

self.feature_layers = make_fc(flatdim(env.observation_space), config['nn.sizes'])
self.feature_layers = make_fc(spaces.flatdim(env.observation_space), config['nn.sizes'])
self.layer_norms = nn.ModuleList([nn.LayerNorm(hidden_size) for hidden_size in config['nn.sizes']])
self.to(self.device)

Expand All @@ -42,28 +37,26 @@ def __init__(self, config, env, device, **kwargs):

self.feature_network = MLP(config, env, device, **kwargs)
feature_dim = config['nn.sizes'][-1]
if isinstance(env.action_space, Discrete):
if isinstance(env.action_space, spaces.Discrete):
self.action_head = CategoricalHead(feature_dim, env.action_space.n, device, **kwargs)
elif isinstance(env.action_space, Box):
self.action_head = DiagGaussianHead(feature_dim, flatdim(env.action_space), device, config['agent.std0'], **kwargs)
elif isinstance(env.action_space, spaces.Box):
self.action_head = DiagGaussianHead(feature_dim, spaces.flatdim(env.action_space), device, config['agent.std0'], **kwargs)
self.total_timestep = 0

def choose_action(self, obs, **kwargs):
obs = tensorify(obs, self.device)
out = {}
def choose_action(self, x, **kwargs):
obs = tensorify(x.observation, self.device).unsqueeze(0)
features = self.feature_network(obs)

action_dist = self.action_head(features)
out['entropy'] = action_dist.entropy()
action = action_dist.sample()
out['raw_action'] = numpify(action, self.env.action_space.dtype)
out = {}
out['raw_action'] = numpify(action, self.env.action_space.dtype).squeeze(0)
return out

def learn(self, D, **kwargs):
pass

def checkpoint(self, logdir, num_iter):
self.save(logdir/f'agent_{num_iter}.pth')
obs_env = get_wrapper(self.env, 'VecStandardizeObservation')
if obs_env is not None:
pickle_dump(obj=(obs_env.mean, obs_env.var), f=logdir/f'obs_moments_{num_iter}', ext='.pth')
if 'env.normalize_obs' in self.config and self.config['env.normalize_obs']:
moments = (self.env.obs_moments.mean, self.env.obs_moments.var)
pickle_dump(obj=moments, f=logdir/f'obs_moments_{num_iter}', ext='.pth')
90 changes: 34 additions & 56 deletions baselines/cmaes/experiment.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,20 @@
import os
from pathlib import Path
from itertools import count
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Pool
import time

import gym
from gym.spaces import Box
from gym.wrappers import ClipAction

import numpy as np
import torch
import gym
from lagom import Logger
from lagom import EpisodeRunner
from lagom.transform import describe
from lagom.utils import CloudpickleWrapper # VERY IMPORTANT
from lagom.utils import CloudpickleWrapper
from lagom.utils import pickle_dump
from lagom.utils import tensorify
from lagom.utils import set_global_seeds
from lagom.experiment import Config
from lagom.experiment import Grid
from lagom.experiment import Sample
from lagom.experiment import run_experiment
from lagom.envs import make_vec_env
from lagom.envs.wrappers import VecMonitor
from lagom.envs.wrappers import VecStandardizeObservation
from lagom.envs import TimeStepEnv

from lagom import CMAES
from baselines.cmaes.agent import Agent
Expand All @@ -33,65 +24,51 @@
{'log.freq': 10,
'checkpoint.num': 3,

'env.id': Grid(['HalfCheetah-v3', 'Hopper-v3', 'Walker2d-v3', 'Swimmer-v3']),
'env.standardize_obs': False,
'env.id': Grid(['Acrobot-v1', 'BipedalWalker-v2', 'Pendulum-v0', 'LunarLanderContinuous-v2']),

'nn.sizes': [64, 64],

# only for continuous control
'env.clip_action': True, # clip action within valid bound before step()
'agent.std0': 0.6, # initial std

'train.generations': int(1e3), # total number of ES generations
'train.popsize': 64,
'train.generations': 500, # total number of ES generations
'train.popsize': 32,
'train.worker_chunksize': 4, # must be divisible by popsize
'train.mu0': 0.0,
'train.std0': 1.0,

})


def make_env(config, seed, mode):
assert mode in ['train', 'eval']
def _make_env():
env = gym.make(config['env.id'])
if config['env.clip_action'] and isinstance(env.action_space, Box):
env = ClipAction(env)
return env
env = make_vec_env(_make_env, 1, seed) # single environment
if mode == 'train':
env = VecMonitor(env)
if config['env.standardize_obs']:
env = VecStandardizeObservation(env, clip=5.)
env = gym.make(config['env.id'])
env.seed(seed)
env.observation_space.seed(seed)
env.action_space.seed(seed)
if config['env.clip_action'] and isinstance(env.action_space, gym.spaces.Box):
env = gym.wrappers.ClipAction(env) # TODO: use tanh to squash policy output when RescaleAction wrapper merged in gym
env = TimeStepEnv(env)
return env
def initializer(config, seed, device):


def fitness(data):
torch.set_num_threads(1) # VERY IMPORTANT TO AVOID GETTING STUCK
global env
config, seed, device, param = data
env = make_env(config, seed, 'train')
global agent
agent = Agent(config, env, device)


def fitness(param):
agent.from_vec(tensorify(param, 'cpu'))
R = []
H = []
runner = EpisodeRunner()
with torch.no_grad():
for i in range(10):
observation = env.reset()
for t in range(env.spec.max_episode_steps):
action = agent.choose_action(observation)['raw_action']
observation, reward, done, info = env.step(action)
if done[0]:
R.append(info[0]['episode']['return'])
H.append(info[0]['episode']['horizon'])
break
return np.mean(R), np.mean(H)

D = runner(agent, env, 10)
R = np.mean([sum(traj.rewards) for traj in D])
H = np.mean([traj.T for traj in D])
return R, H


def run(config, seed, device, logdir):
set_global_seeds(seed)
torch.set_num_threads(1) # VERY IMPORTANT TO AVOID GETTING STUCK

print('Initializing...')
agent = Agent(config, make_env(config, seed, 'eval'), device)
Expand All @@ -100,17 +77,18 @@ def run(config, seed, device, logdir):
'seed': seed})
train_logs = []
checkpoint_count = 0
with ProcessPoolExecutor(max_workers=config['train.popsize'], initializer=initializer, initargs=(config, seed, device)) as executor:
with Pool(processes=config['train.popsize']//config['train.worker_chunksize']) as pool:
print('Finish initialization. Training starts...')
for generation in range(config['train.generations']):
start_time = time.perf_counter()
t0 = time.perf_counter()
solutions = es.ask()
out = list(executor.map(fitness, solutions, chunksize=2))
data = [(config, seed, device, solution) for solution in solutions]
out = pool.map(CloudpickleWrapper(fitness), data, chunksize=config['train.worker_chunksize'])
Rs, Hs = zip(*out)
es.tell(solutions, [-R for R in Rs])
logger = Logger()
logger('generation', generation+1)
logger('num_seconds', round(time.perf_counter() - start_time, 1))
logger('num_seconds', round(time.perf_counter() - t0, 1))
logger('Returns', describe(Rs, axis=-1, repr_indent=1, repr_prefix='\n'))
logger('Horizons', describe(Hs, axis=-1, repr_indent=1, repr_prefix='\n'))
logger('fbest', es.result.fbest)
Expand All @@ -123,14 +101,14 @@ def run(config, seed, device, logdir):
checkpoint_count += 1
pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
return None


if __name__ == '__main__':
run_experiment(run=run,
config=config,
seeds=[1770966829, 1500925526, 2054191100],
log_dir='logs/default',
max_workers=None, # no parallelization
max_workers=12, # tune to fulfill computation power
chunksize=1,
use_gpu=False,
gpu_ids=None)

0 comments on commit bd8aff8

Please sign in to comment.