Update baselines: SAC
zuoxingdong committed Aug 7, 2019
1 parent 23482e6 commit 1203bec
Showing 5 changed files with 99 additions and 143 deletions.
4 changes: 1 addition & 3 deletions baselines/sac/README.md
@@ -13,6 +13,4 @@ python experiment.py
One could modify [experiment.py](./experiment.py) to quickly set up different configurations.

# Results

## MLP policy
<img src='logs/default/result.png' width='100%'>
<img src='https://i.imgur.com/NsWSs4E.png' width='100%'>
79 changes: 35 additions & 44 deletions baselines/sac/agent.py
@@ -5,26 +5,23 @@
import torch.optim as optim
from torch.distributions import Independent
from torch.distributions import Normal
from torch.distributions import Transform
from torch.distributions import TransformedDistribution
from torch.distributions import Transform
from torch.distributions import constraints

from gym.spaces import flatdim
from lagom import BaseAgent
from lagom.transform import describe
from lagom.utils import pickle_dump
from lagom.utils import tensorify
from lagom.utils import numpify
from lagom.networks import Module
from lagom.networks import make_fc
from lagom.networks import ortho_init
from lagom.transform import describe


# TODO: import from PyTorch when PR merged: https://github.com/pytorch/pytorch/pull/19785
class TanhTransform(Transform):
r"""
Transform via the mapping :math:`y = \tanh(x)`.
"""
r"""Transform via the mapping :math:`y = \tanh(x)`."""
domain = constraints.real
codomain = constraints.interval(-1.0, 1.0)
bijective = True
@@ -52,6 +49,7 @@ def log_abs_det_jacobian(self, x, y):
class Actor(Module):
LOGSTD_MAX = 2
LOGSTD_MIN = -20

def __init__(self, config, env, device, **kwargs):
super().__init__(**kwargs)
self.config = config
@@ -127,7 +125,6 @@ def __init__(self, config, env, device, **kwargs):
self.critic = Critic(config, env, device, **kwargs)
self.critic_target = Critic(config, env, device, **kwargs)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_target.eval()
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=config['agent.critic.lr'])

self.target_entropy = -float(flatdim(env.action_space))
@@ -137,6 +134,7 @@
self.optimizer_zero_grad = lambda: [opt.zero_grad() for opt in [self.actor_optimizer,
self.critic_optimizer,
self.log_alpha_optimizer]]
self.total_timestep = 0

@property
def alpha(self):
@@ -147,54 +145,46 @@ def polyak_update_target(self):
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(p*target_param.data + (1 - p)*param.data)

def choose_action(self, obs, **kwargs):
obs = tensorify(obs, self.device)
def choose_action(self, x, **kwargs):
obs = tensorify(x.observation, self.device).unsqueeze(0)
with torch.no_grad():
if kwargs['mode'] == 'train':
action = numpify(self.actor(obs).sample(), 'float')
elif kwargs['mode'] == 'eval':
action = numpify(torch.tanh(self.actor.mean_forward(obs)), 'float')
out = {}
if kwargs['mode'] == 'train':
dist = self.actor(obs)
action = dist.rsample()
out['action'] = action
out['action_logprob'] = dist.log_prob(action)
elif kwargs['mode'] == 'stochastic':
with torch.no_grad():
out['action'] = numpify(self.actor(obs).sample(), 'float')
elif kwargs['mode'] == 'eval':
with torch.no_grad():
out['action'] = numpify(torch.tanh(self.actor.mean_forward(obs)), 'float')
else:
raise NotImplementedError
out['raw_action'] = action.squeeze(0)
return out

def learn(self, D, **kwargs):
replay = kwargs['replay']
episode_length = kwargs['episode_length']
out = {}
out['actor_loss'] = []
out['critic_loss'] = []
out['alpha_loss'] = []
T = kwargs['T']
list_actor_loss = []
list_critic_loss = []
list_alpha_loss = []
Q1_vals = []
Q2_vals = []
logprob_vals = []
for i in range(episode_length):
for i in range(T):
observations, actions, rewards, next_observations, masks = replay.sample(self.config['replay.batch_size'])

Qs1, Qs2 = self.critic(observations, actions)
with torch.no_grad():
out_actor = self.choose_action(next_observations, mode='train')
next_actions = out_actor['action']
next_actions_logprob = out_actor['action_logprob'].unsqueeze(-1)
action_dist = self.actor(next_observations)
next_actions = action_dist.rsample()
next_actions_logprob = action_dist.log_prob(next_actions).unsqueeze(-1)
next_Qs1, next_Qs2 = self.critic_target(next_observations, next_actions)
next_Qs = torch.min(next_Qs1, next_Qs2) - self.alpha.detach()*next_actions_logprob
Q_targets = rewards + self.config['agent.gamma']*masks*next_Qs
critic_loss = F.mse_loss(Qs1, Q_targets.detach()) + F.mse_loss(Qs2, Q_targets.detach())
targets = rewards + self.config['agent.gamma']*masks*next_Qs
critic_loss = F.mse_loss(Qs1, targets.detach()) + F.mse_loss(Qs2, targets.detach())
self.optimizer_zero_grad()
critic_loss.backward()
critic_grad_norm = nn.utils.clip_grad_norm_(self.critic.parameters(), self.config['agent.max_grad_norm'])
self.critic_optimizer.step()

out_actor = self.choose_action(observations, mode='train')
policy_actions = out_actor['action']
policy_actions_logprob = out_actor['action_logprob'].unsqueeze(-1)
action_dist = self.actor(observations)
policy_actions = action_dist.rsample()
policy_actions_logprob = action_dist.log_prob(policy_actions).unsqueeze(-1)
actor_Qs1, actor_Qs2 = self.critic(observations, policy_actions)
actor_Qs = torch.min(actor_Qs1, actor_Qs2)
actor_loss = torch.mean(self.alpha.detach()*policy_actions_logprob - actor_Qs)
@@ -209,25 +199,26 @@ def learn(self, D, **kwargs):
self.log_alpha_optimizer.step()

self.polyak_update_target()

out['critic_loss'].append(critic_loss)
out['actor_loss'].append(actor_loss)
out['alpha_loss'].append(alpha_loss)
list_actor_loss.append(actor_loss)
list_critic_loss.append(critic_loss)
list_alpha_loss.append(alpha_loss)
Q1_vals.append(Qs1)
Q2_vals.append(Qs2)
logprob_vals.append(policy_actions_logprob)
out['actor_loss'] = torch.tensor(out['actor_loss']).mean().item()
self.total_timestep += T

out = {}
out['actor_loss'] = torch.tensor(list_actor_loss).mean(0).item()
out['actor_grad_norm'] = actor_grad_norm
out['critic_loss'] = torch.tensor(out['critic_loss']).mean().item()
out['critic_loss'] = torch.tensor(list_critic_loss).mean(0).item()
out['critic_grad_norm'] = critic_grad_norm
describe_it = lambda x: describe(numpify(torch.cat(x), 'float').squeeze(), axis=-1, repr_indent=1, repr_prefix='\n')
out['Q1'] = describe_it(Q1_vals)
out['Q2'] = describe_it(Q2_vals)
out['logprob'] = describe_it(logprob_vals)
out['alpha_loss'] = torch.tensor(out['alpha_loss']).mean().item()
out['alpha_loss'] = torch.tensor(list_alpha_loss).mean(0).item()
out['alpha'] = self.alpha.item()
return out

def checkpoint(self, logdir, num_iter):
self.save(logdir/f'agent_{num_iter}.pth')
# TODO: save normalization moments
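
For reference, a condensed sketch of the update that agent.py performs on each gradient step. The critic and actor losses mirror the diff above; the temperature (alpha) loss is written in the standard SAC form and is an assumption here, since that hunk is collapsed. `sac_update` and the `batch` tuple are illustrative names, not part of the commit.

import torch
import torch.nn.functional as F

def sac_update(actor, critic, critic_target, log_alpha, gamma, batch):
    """One SAC gradient step; returns the three losses for the caller to backprop."""
    obs, actions, rewards, next_obs, masks = batch
    alpha = log_alpha.exp()

    # Critic: TD target uses the minimum of the twin target Qs plus an entropy bonus.
    with torch.no_grad():
        next_dist = actor(next_obs)
        next_actions = next_dist.rsample()
        next_logprob = next_dist.log_prob(next_actions).unsqueeze(-1)
        next_q1, next_q2 = critic_target(next_obs, next_actions)
        next_q = torch.min(next_q1, next_q2) - alpha * next_logprob
        targets = rewards + gamma * masks * next_q
    q1, q2 = critic(obs, actions)
    critic_loss = F.mse_loss(q1, targets) + F.mse_loss(q2, targets)

    # Actor: maximize the min twin Q of reparameterized actions, minus the entropy penalty.
    dist = actor(obs)
    new_actions = dist.rsample()
    logprob = dist.log_prob(new_actions).unsqueeze(-1)
    q1_pi, q2_pi = critic(obs, new_actions)
    actor_loss = (alpha.detach() * logprob - torch.min(q1_pi, q2_pi)).mean()

    # Temperature: drive policy entropy toward -dim(action_space), as target_entropy does above.
    target_entropy = -float(actions.shape[-1])
    alpha_loss = -(log_alpha * (logprob + target_entropy).detach()).mean()
    return critic_loss, actor_loss, alpha_loss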
110 changes: 39 additions & 71 deletions baselines/sac/engine.py
@@ -1,92 +1,60 @@
from time import perf_counter
import time
from itertools import count

import numpy as np
import torch

from lagom import Logger
from lagom import BaseEngine
from lagom.transform import describe
from lagom.utils import color_str
from lagom.envs.wrappers import get_wrapper


class Engine(BaseEngine):
def train(self, n=None, **kwargs):
train_logs = []
eval_logs = []
eval_togo = 0
dump_togo = 0
num_episode = 0
train_logs, eval_logs = [], []
checkpoint_count = 0
observation, _ = self.env.reset()
for i in count():
if i >= self.config['train.timestep']:
for iteration in count():
if self.agent.total_timestep >= self.config['train.timestep']:
break
if i < self.config['replay.init_size']:
action = [self.env.action_space.sample()]
t0 = time.perf_counter()

if iteration < self.config['replay.init_trial']:
[traj] = self.runner(self.random_agent, self.env, 1)
else:
action = self.agent.choose_action(observation, mode='stochastic')['action']
next_observation, reward, step_info = self.env.step(action)
eval_togo += 1
dump_togo += 1
if step_info[0].last: # [0] due to single environment
start_time = perf_counter()
self.replay.add(observation[0], action[0], reward[0], step_info[0]['last_observation'], step_info[0].terminal)
[traj] = self.runner(self.agent, self.env, 1, mode='train')
self.replay.add(traj)
# Number of gradient updates = collected episode length
out_agent = self.agent.learn(D=None, replay=self.replay, T=traj.T)

logger = Logger()
logger('train_iteration', iteration+1)
logger('num_seconds', round(time.perf_counter() - t0, 1))
[logger(key, value) for key, value in out_agent.items()]
logger('episode_return', sum(traj.rewards))
logger('episode_horizon', traj.T)
logger('accumulated_trained_timesteps', self.agent.total_timestep)
train_logs.append(logger.logs)
if iteration == 0 or (iteration+1) % self.config['log.freq'] == 0:
logger.dump(keys=None, index=0, indent=0, border='-'*50)
if self.agent.total_timestep >= int(self.config['train.timestep']*(checkpoint_count/(self.config['checkpoint.num'] - 1))):
self.agent.checkpoint(self.logdir, iteration + 1)
checkpoint_count += 1

# updates in the end of episode, for each time step
out_agent = self.agent.learn(D=None, replay=self.replay, episode_length=step_info[0]['episode']['horizon'])
num_episode += 1
if (i+1) >= int(self.config['train.timestep']*(checkpoint_count/(self.config['checkpoint.num'] - 1))):
self.agent.checkpoint(self.logdir, num_episode)
checkpoint_count += 1
logger = Logger()
logger('num_seconds', round(perf_counter() - start_time, 1))
logger('accumulated_trained_timesteps', i + 1)
logger('accumulated_trained_episodes', num_episode)
[logger(key, value) for key, value in out_agent.items()]
logger('episode_return', step_info[0]['episode']['return'])
logger('episode_horizon', step_info[0]['episode']['horizon'])
train_logs.append(logger.logs)
if dump_togo >= self.config['log.freq']:
dump_togo %= self.config['log.freq']
logger.dump(keys=None, index=0, indent=0, border='-'*50)
if eval_togo >= self.config['eval.freq']:
eval_togo %= self.config['eval.freq']
eval_logs.append(self.eval(accumulated_trained_timesteps=(i+1),
accumulated_trained_episodes=num_episode))
else:
self.replay.add(observation[0], action[0], reward[0], next_observation[0], step_info[0].terminal)
observation = next_observation
if checkpoint_count < self.config['checkpoint.num']:
self.agent.checkpoint(self.logdir, num_episode)
checkpoint_count += 1
if self.agent.total_timestep >= int(self.config['train.timestep']*(len(eval_logs)/(self.config['eval.num'] - 1))):
eval_logs.append(self.eval(n=len(eval_logs)))
return train_logs, eval_logs

def eval(self, n=None, **kwargs):
start_time = perf_counter()
returns = []
horizons = []
for _ in range(self.config['eval.num_episode']):
observation = self.eval_env.reset()
for _ in range(self.eval_env.spec.max_episode_steps):
with torch.no_grad():
action = self.agent.choose_action(observation, mode='eval')['action']
next_observation, reward, done, info = self.eval_env.step(action)
if done[0]: # [0] single environment
returns.append(info[0]['episode']['return'])
horizons.append(info[0]['episode']['horizon'])
break
observation = next_observation
logger = Logger()
logger('num_seconds', round(perf_counter() - start_time, 1))
logger('accumulated_trained_timesteps', kwargs['accumulated_trained_timesteps'])
logger('accumulated_trained_episodes', kwargs['accumulated_trained_episodes'])
logger('online_return', describe(returns, axis=-1, repr_indent=1, repr_prefix='\n'))
logger('online_horizon', describe(horizons, axis=-1, repr_indent=1, repr_prefix='\n'))
t0 = time.perf_counter()
with torch.no_grad():
D = self.runner(self.agent, self.eval_env, 10, mode='eval')

monitor_env = get_wrapper(self.eval_env, 'VecMonitor')
logger('running_return', describe(monitor_env.return_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
logger('running_horizon', describe(monitor_env.horizon_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
logger = Logger()
logger('eval_iteration', n+1)
logger('num_seconds', round(time.perf_counter() - t0, 1))
logger('accumulated_trained_timesteps', self.agent.total_timestep)
logger('online_return', describe([sum(traj.rewards) for traj in D], axis=-1, repr_indent=1, repr_prefix='\n'))
logger('online_horizon', describe([traj.T for traj in D], axis=-1, repr_indent=1, repr_prefix='\n'))
logger('running_return', describe(self.eval_env.return_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
logger('running_horizon', describe(self.eval_env.horizon_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
logger.dump(keys=None, index=0, indent=0, border=color_str('+'*50, color='green'))
return logger.logs
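
The restructured train() above replaces the per-step interaction loop with trajectory-based collection. A stripped-down sketch of that control flow (logging, checkpointing, and eval scheduling omitted; `runner`, `replay`, and `agent` stand for the lagom objects used in the diff):

from itertools import count

def train_loop(agent, random_agent, env, runner, replay, config):
    train_logs = []
    for iteration in count():
        if agent.total_timestep >= config['train.timestep']:
            break
        # Warm up the replay buffer with a few purely random rollouts.
        if iteration < config['replay.init_trial']:
            [traj] = runner(random_agent, env, 1)
        else:
            [traj] = runner(agent, env, 1, mode='train')
        replay.add(traj)
        # One gradient update per collected time step.
        out_agent = agent.learn(D=None, replay=replay, T=traj.T)
        train_logs.append(out_agent)
    return train_logs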
42 changes: 19 additions & 23 deletions baselines/sac/experiment.py
@@ -1,27 +1,24 @@
import os
from pathlib import Path

import gym

from lagom import EpisodeRunner
from lagom import RandomAgent
from lagom.utils import pickle_dump
from lagom.utils import set_global_seeds
from lagom.experiment import Config
from lagom.experiment import Grid
from lagom.experiment import Sample
from lagom.experiment import Condition
from lagom.experiment import run_experiment
from lagom.envs import make_vec_env
from lagom.envs import RecordEpisodeStatistics
from lagom.envs import TimeStepEnv
from lagom.envs.wrappers import NormalizeAction
from lagom.envs.wrappers import VecMonitor
from lagom.envs.wrappers import VecStepInfo

from baselines.sac.agent import Agent
from baselines.sac.engine import Engine
from baselines.sac.replay_buffer import ReplayBuffer


config = Config(
{'log.freq': 1000, # every n timesteps
{'log.freq': 10,
'checkpoint.num': 3,

'env.id': Grid(['HalfCheetah-v3', 'Hopper-v3', 'Walker2d-v3', 'Swimmer-v3']),
@@ -36,27 +33,24 @@
'agent.max_grad_norm': 999999, # grad clipping by norm

'replay.capacity': 1000000,
# number of time steps to take uniform actions initially
'replay.init_size': Condition(lambda x: 1000 if x['env.id'] in ['Hopper-v3', 'Walker2d-v3'] else 10000),
'replay.init_trial': 10, # number of random rollouts initially
'replay.batch_size': 256,

'train.timestep': int(1e6), # total number of training (environmental) timesteps
'eval.freq': 5000,
'eval.num_episode': 10

'eval.num': 200
})


def make_env(config, seed, mode):
assert mode in ['train', 'eval']
def _make_env():
env = gym.make(config['env.id'])
env = NormalizeAction(env)
return env
env = make_vec_env(_make_env, 1, seed) # single environment
env = VecMonitor(env)
if mode == 'train':
env = VecStepInfo(env)
env = gym.make(config['env.id'])
env.seed(seed)
env.observation_space.seed(seed)
env.action_space.seed(seed)
env = NormalizeAction(env) # TODO: use gym new wrapper RescaleAction when it's merged
if mode == 'eval':
env = RecordEpisodeStatistics(env, deque_size=100)
env = TimeStepEnv(env)
return env


@@ -65,9 +59,11 @@ def run(config, seed, device, logdir):

env = make_env(config, seed, 'train')
eval_env = make_env(config, seed, 'eval')
random_agent = RandomAgent(config, env, device)
agent = Agent(config, env, device)
runner = EpisodeRunner()
replay = ReplayBuffer(env, config['replay.capacity'], device)
engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)
engine = Engine(config, agent=agent, random_agent=random_agent, env=env, eval_env=eval_env, runner=runner, replay=replay, logdir=logdir)

train_logs, eval_logs = engine.train()
pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
@@ -78,7 +74,7 @@ def run(config, seed, device, logdir):
if __name__ == '__main__':
run_experiment(run=run,
config=config,
seeds=[4153361530, 3503522377, 2876994566, 172236777, 3949341511, 849059707],
seeds=[4153361530, 3503522377, 2876994566, 172236777, 3949341511],
log_dir='logs/default',
max_workers=os.cpu_count(),
chunksize=1,
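
The NormalizeAction wrapper in make_env (see the TODO pointing at gym's RescaleAction) is what lets the tanh-squashed policy act in [-1, 1] regardless of the environment's native bounds. An illustrative stand-in for that behavior, not the lagom implementation:

import gym
import numpy as np

class RescaleToEnvBounds(gym.ActionWrapper):
    """Map actions from [-1, 1] to the wrapped environment's Box bounds."""
    def action(self, action):
        low, high = self.env.action_space.low, self.env.action_space.high
        action = np.clip(action, -1.0, 1.0)
        return low + 0.5 * (action + 1.0) * (high - low)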
