
Commit

Merge pull request #172 from zuoxingdong/step_info_trajectory
update CEM
zuoxingdong committed May 8, 2019
2 parents ba9f9c3 + 8cb8689 commit 17b375f
Showing 2 changed files with 15 additions and 12 deletions.
7 changes: 4 additions & 3 deletions baselines/cem/agent.py
@@ -9,6 +9,8 @@

from lagom import BaseAgent
from lagom.utils import pickle_dump
+from lagom.utils import tensorify
+from lagom.utils import numpify
from lagom.envs import flatdim
from lagom.envs.wrappers import get_wrapper
from lagom.networks import Module
@@ -47,15 +49,14 @@ def __init__(self, config, env, device, **kwargs):
self.total_timestep = 0

def choose_action(self, obs, **kwargs):
-if not torch.is_tensor(obs):
-    obs = torch.from_numpy(np.asarray(obs)).float().to(self.device)
+obs = tensorify(obs, self.device)
out = {}
features = self.feature_network(obs)

action_dist = self.action_head(features)
out['entropy'] = action_dist.entropy()
action = action_dist.sample()
-out['raw_action'] = action.detach().cpu().numpy()
+out['raw_action'] = numpify(action, 'float')
return out

def learn(self, D, **kwargs):
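The agent-side change swaps the manual torch.from_numpy(...) / .detach().cpu().numpy() conversions for lagom's tensorify and numpify helpers. As a rough mental model only (an assumption for illustration, not lagom's actual implementation), helpers covering just the conversions used in this diff could look like:

    import numpy as np
    import torch

    def tensorify(x, device):
        # Sketch: return a float tensor on `device`; pass real tensors through.
        if torch.is_tensor(x):
            return x.to(device)
        return torch.from_numpy(np.asarray(x, dtype=np.float32)).to(device)

    def numpify(x, dtype=None):
        # Sketch: detach a tensor and bring it back as a numpy array of `dtype`.
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        return np.asarray(x, dtype=dtype)

With helpers along these lines, tensorify(obs, self.device) and numpify(action, 'float') in choose_action behave like the removed lines, with the conversion logic kept in one place.
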
20 changes: 11 additions & 9 deletions baselines/cem/experiment.py
@@ -14,6 +14,7 @@
from lagom.transform import describe
from lagom.utils import CloudpickleWrapper # VERY IMPORTANT
from lagom.utils import pickle_dump
+from lagom.utils import tensorify
from lagom.utils import set_global_seeds
from lagom.experiment import Config
from lagom.experiment import Grid
@@ -30,9 +31,7 @@


config = Config(
-{'cuda': False,
- 'log.dir': 'logs/default',
- 'log.freq': 10,
+{'log.freq': 10,
'checkpoint.num': 3,

'env.id': Grid(['HalfCheetah-v3', 'Hopper-v3', 'Walker2d-v3', 'Swimmer-v3']),
@@ -42,7 +41,7 @@

# only for continuous control
'env.clip_action': True, # clip action within valid bound before step()
-'agent.std0': 0.5, # initial std
+'agent.std0': 0.6, # initial std

'train.generations': int(1e3), # total number of ES generations
'train.popsize': 64,
@@ -78,7 +77,7 @@ def initializer(config, seed, device):


def fitness(param):
-agent.from_vec(torch.from_numpy(param).float())
+agent.from_vec(tensorify(param, 'cpu'))
R = []
H = []
with torch.no_grad():
@@ -94,9 +93,8 @@ def fitness(param):
return np.mean(R), np.mean(H)


-def run(config, seed, device):
+def run(config, seed, device, logdir):
set_global_seeds(seed)
-logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)

print('Initializing...')
agent = Agent(config, make_env(config, seed), device)
@@ -125,7 +123,7 @@ def run(config, seed, device):
if generation == 0 or (generation+1)%config['log.freq'] == 0:
logger.dump(keys=None, index=0, indent=0, border='-'*50)
if (generation+1) >= int(config['train.generations']*(checkpoint_count/(config['checkpoint.num'] - 1))):
-agent.from_vec(torch.from_numpy(es.result.xbest).float())
+agent.from_vec(tensorify(es.result.xbest, 'cpu'))
agent.checkpoint(logdir, generation+1)
checkpoint_count += 1
pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
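
With the values in this config (train.generations = 1000, checkpoint.num = 3), the checkpoint condition above fires when generation+1 first reaches 1000*(0/2) = 0, 1000*(1/2) = 500, and 1000*(2/2) = 1000, i.e. roughly at the first generation, the midpoint, and the final generation. A quick check of the thresholds, using only the numbers from this config:

    generations, num_checkpoints = 1000, 3
    thresholds = [int(generations * k / (num_checkpoints - 1)) for k in range(num_checkpoints)]
    print(thresholds)  # [0, 500, 1000] -> checkpoints saved at generation+1 = 1, 500, 1000
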
@@ -136,4 +134,8 @@ def run(config, seed, device):
run_experiment(run=run,
config=config,
seeds=[1770966829, 1500925526, 2054191100],
-num_worker=5)
+log_dir='logs/default',
+max_workers=5,
+chunksize=1,
+use_gpu=False,
+gpu_ids=None)
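
The commit also moves 'cuda' and 'log.dir' out of the Config dict: device selection and the log directory are now handled by run_experiment (use_gpu/gpu_ids and log_dir), and run() receives a ready-made logdir instead of building it from config['log.dir']. As a loose sketch of that division of labor, a hypothetical serial stand-in based only on the removed logdir line above (not lagom's actual run_experiment), the runner side might do something like:

    from pathlib import Path

    def run_experiment_sketch(run, config, seeds, log_dir, use_gpu=False):
        # Hypothetical: pick a device and build one log directory per (config ID, seed),
        # mirroring the removed `logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)`.
        device = 'cuda' if use_gpu else 'cpu'
        for seed in seeds:
            logdir = Path(log_dir) / str(config['ID']) / str(seed)
            logdir.mkdir(parents=True, exist_ok=True)
            run(config, seed, device, logdir)

The real runner presumably also fans work out across processes (max_workers, chunksize) and expands the Grid values in the Config into separate configurations, which this sketch ignores.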
