Commit d10e8ce (parent a912e2f): 17 changed files with 916 additions and 1 deletion.
@@ -16,3 +16,5 @@ lagom.metric: Metrics
.. autofunction:: td0_error

.. autofunction:: gae

.. autofunction:: vtrace
@@ -7,3 +7,5 @@
from .td import td0_error

from .gae import gae

from .vtrace import vtrace
@@ -0,0 +1,33 @@
import numpy as np

from lagom.utils import numpify

from .td import td0_error


def vtrace(behavior_logprobs, target_logprobs, gamma, Rs, Vs, last_V, reach_terminal, clip_rho=1.0, clip_pg_rho=1.0):
    behavior_logprobs = numpify(behavior_logprobs, np.float32)
    target_logprobs = numpify(target_logprobs, np.float32)
    Rs = numpify(Rs, np.float32)
    Vs = numpify(Vs, np.float32)
    last_V = numpify(last_V, np.float32)
    assert all([item.ndim == 1 for item in [behavior_logprobs, target_logprobs, Rs, Vs]])
    assert np.isscalar(gamma)

    rhos = np.exp(target_logprobs - behavior_logprobs)
    clipped_rhos = np.minimum(clip_rho, rhos)
    cs = np.minimum(1.0, rhos)
    deltas = clipped_rhos*td0_error(gamma, Rs, Vs, last_V, reach_terminal)

    vs_minus_V = []
    total = 0.0
    for delta_t, c_t in zip(deltas[::-1], cs[::-1]):
        total = delta_t + gamma*c_t*total
        vs_minus_V.append(total)
    vs_minus_V = np.asarray(vs_minus_V)[::-1]

    vs = vs_minus_V + Vs
    vs_next = np.append(vs[1:], (1. - reach_terminal)*last_V)
    clipped_pg_rhos = np.minimum(clip_pg_rho, rhos)
    As = clipped_pg_rhos*(Rs + gamma*vs_next - Vs)
    return vs, As
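The new `vtrace` function computes V-trace value targets `vs` and policy-gradient advantages `As` for a single 1-D trajectory, following the IMPALA paper: the truncated importance ratios `rhos` weight the TD errors, and the backward loop accumulates the recursion `vs[t] - Vs[t] = deltas[t] + gamma*cs[t]*(vs[t+1] - Vs[t+1])`. Below is a minimal usage sketch with made-up numbers; it assumes the `lagom.metric` export added in the `__init__` hunk above and that `numpify` accepts a plain Python float for the bootstrap value `last_V`.

```python
import numpy as np

from lagom.metric import vtrace

# Made-up 5-step trajectory collected under a behavior policy.
behavior_logprobs = np.log([0.5, 0.4, 0.6, 0.5, 0.3])   # log mu(a_t | x_t)
target_logprobs = np.log([0.6, 0.4, 0.5, 0.5, 0.4])     # log pi(a_t | x_t) under the learner
Rs = np.array([1., 0., 0., 1., 1.])                     # rewards r_t
Vs = np.array([0.5, 0.4, 0.6, 0.3, 0.2])                # value estimates V(x_t)
last_V = 0.1                                            # bootstrap value for the state after the last step
reach_terminal = False                                  # trajectory was truncated, not terminated

vs, As = vtrace(behavior_logprobs, target_logprobs, gamma=0.99, Rs=Rs, Vs=Vs,
                last_V=last_V, reach_terminal=reach_terminal)
print(vs.shape, As.shape)  # both (5,)
```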
@@ -0,0 +1,18 @@
# Importance Weighted Actor-Learner Architectures (IMPALA)

This is an implementation of the [IMPALA](https://arxiv.org/abs/1802.01561) algorithm.

# Usage

Run the following command to start parallelized training:

```bash
python experiment.py
```

You can modify [experiment.py](./experiment.py) to quickly set up different configurations.

# Results

## MLP policy
<img src='logs/default/result.png' width='100%'>
Empty file.
@@ -0,0 +1,149 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from gym.spaces import Discrete
from gym.spaces import Box
from gym.spaces import flatdim

from lagom import BaseAgent
from lagom.utils import pickle_dump
from lagom.utils import tensorify
from lagom.utils import numpify
from lagom.envs.wrappers import get_wrapper
from lagom.networks import Module
from lagom.networks import make_fc
from lagom.networks import ortho_init
from lagom.networks import CategoricalHead
from lagom.networks import DiagGaussianHead
from lagom.networks import linear_lr_scheduler
from lagom.metric import vtrace
from lagom.transform import explained_variance as ev
from lagom.transform import describe


class MLP(Module):
    def __init__(self, config, env, device, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.env = env
        self.device = device

        self.feature_layers = make_fc(flatdim(env.observation_space), config['nn.sizes'])
        for layer in self.feature_layers:
            ortho_init(layer, nonlinearity='relu', constant_bias=0.0)
        self.layer_norms = nn.ModuleList([nn.LayerNorm(hidden_size) for hidden_size in config['nn.sizes']])

        self.to(self.device)

    def forward(self, x):
        for layer, layer_norm in zip(self.feature_layers, self.layer_norms):
            x = layer_norm(F.relu(layer(x)))
        return x


class Agent(BaseAgent):
    def __init__(self, config, env, device, **kwargs):
        super().__init__(config, env, device, **kwargs)

        feature_dim = config['nn.sizes'][-1]
        self.feature_network = MLP(config, env, device, **kwargs)
        if isinstance(env.action_space, Discrete):
            self.action_head = CategoricalHead(feature_dim, env.action_space.n, device, **kwargs)
        elif isinstance(env.action_space, Box):
            self.action_head = DiagGaussianHead(feature_dim, flatdim(env.action_space), device, config['agent.std0'], **kwargs)
        self.V_head = nn.Linear(feature_dim, 1)
        ortho_init(self.V_head, weight_scale=1.0, constant_bias=0.0)
        self.V_head = self.V_head.to(device)  # reproducible between CPU/GPU, ortho_init behaves differently

        self.register_buffer('total_timestep', torch.tensor(0))
        #self.total_timestep = 0

        self.optimizer = optim.Adam(self.parameters(), lr=config['agent.lr'])
        if config['agent.use_lr_scheduler']:
            self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep'], min_lr=1e-8)
        self.gamma = config['agent.gamma']
        self.clip_rho = config['agent.clip_rho']
        self.clip_pg_rho = config['agent.clip_pg_rho']

    def choose_action(self, obs, **kwargs):
        obs = tensorify(obs, self.device)
        out = {}
        features = self.feature_network(obs)

        action_dist = self.action_head(features)
        out['action_dist'] = action_dist
        out['entropy'] = action_dist.entropy()

        action = action_dist.sample()
        out['action'] = action
        out['raw_action'] = numpify(action, self.env.action_space.dtype)
        out['action_logprob'] = action_dist.log_prob(action.detach())

        V = self.V_head(features)
        out['V'] = V
        return out

    def learn(self, D, **kwargs):
        # Compute all metrics, D: list of Trajectory
        Ts = [len(traj) for traj in D]
        behavior_logprobs = [torch.cat(traj.get_all_info('action_logprob')) for traj in D]
        out_agent = self.choose_action(np.concatenate([traj.numpy_observations[:-1] for traj in D], 0))
        logprobs = out_agent['action_logprob'].squeeze()
        entropies = out_agent['entropy'].squeeze()
        Vs = out_agent['V'].squeeze()
        with torch.no_grad():
            last_observations = tensorify(np.concatenate([traj.last_observation for traj in D], 0), self.device)
            last_Vs = self.V_head(self.feature_network(last_observations)).squeeze(-1)

        vs, As = [], []
        for traj, behavior_logprob, logprob, V, last_V in zip(D, behavior_logprobs, logprobs.detach().cpu().split(Ts),
                                                              Vs.detach().cpu().split(Ts), last_Vs):
            v, A = vtrace(behavior_logprob, logprob, self.gamma, traj.rewards, V, last_V,
                          traj.reach_terminal, self.clip_rho, self.clip_pg_rho)
            vs.append(v)
            As.append(A)

        # Metrics -> Tensor, device
        vs, As = map(lambda x: tensorify(np.concatenate(x).copy(), self.device), [vs, As])
        if self.config['agent.standardize_adv']:
            As = (As - As.mean())/(As.std() + 1e-8)

        assert all([x.ndimension() == 1 for x in [logprobs, entropies, Vs, vs, As]])

        # Loss
        policy_loss = -logprobs*As
        entropy_loss = -entropies
        value_loss = F.mse_loss(Vs, vs, reduction='none')

        loss = policy_loss + self.config['agent.value_coef']*value_loss + self.config['agent.entropy_coef']*entropy_loss
        loss = loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(self.parameters(), self.config['agent.max_grad_norm'])
        self.optimizer.step()
        if self.config['agent.use_lr_scheduler']:
            self.lr_scheduler.step(self.total_timestep)

        self.total_timestep += sum([len(traj) for traj in D])
        out = {}
        if self.config['agent.use_lr_scheduler']:
            out['current_lr'] = self.lr_scheduler.get_lr()
        out['loss'] = loss.item()
        out['grad_norm'] = grad_norm
        out['policy_loss'] = policy_loss.mean().item()
        out['entropy_loss'] = entropy_loss.mean().item()
        out['policy_entropy'] = -entropy_loss.mean().item()
        out['value_loss'] = value_loss.mean().item()
        out['V'] = describe(numpify(Vs, 'float').squeeze(), axis=-1, repr_indent=1, repr_prefix='\n')
        out['explained_variance'] = ev(y_true=numpify(vs, 'float'), y_pred=numpify(Vs, 'float'))
        return out

    def checkpoint(self, logdir, num_iter):
        self.save(logdir/f'agent_{num_iter}.pth')
        obs_env = get_wrapper(self.env, 'VecStandardizeObservation')
        if obs_env is not None:
            pickle_dump(obj=(obs_env.mean, obs_env.var), f=logdir/f'obs_moments_{num_iter}', ext='.pth')
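The agent pulls its hyperparameters from `config` by key. For reference, a hypothetical configuration covering every key referenced in this file might look like the sketch below; the values are illustrative placeholders, not the settings used by this commit's `experiment.py`.

```python
# Illustrative values only; the real experiment.py in this commit defines its own settings.
config = {
    'nn.sizes': [64, 64],            # hidden layer sizes of the MLP feature network
    'agent.std0': 0.6,               # initial std for DiagGaussianHead (continuous actions)
    'agent.lr': 1e-3,                # Adam learning rate
    'agent.use_lr_scheduler': True,  # linearly anneal the lr over train.timestep steps
    'train.timestep': int(1e6),      # total environment steps for the lr schedule
    'agent.gamma': 0.99,             # discount factor
    'agent.clip_rho': 1.0,           # rho clipping for the V-trace targets
    'agent.clip_pg_rho': 1.0,        # rho clipping for the policy-gradient advantages
    'agent.standardize_adv': True,   # normalize advantages to zero mean, unit std
    'agent.value_coef': 0.5,         # weight of the value loss
    'agent.entropy_coef': 0.01,      # weight of the entropy bonus
    'agent.max_grad_norm': 0.5,      # gradient clipping threshold
}
```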
@@ -0,0 +1,44 @@
from time import perf_counter
from itertools import chain

import numpy as np
import torch

from lagom import Logger
from lagom import BaseEngine
from lagom.transform import describe
from lagom.utils import color_str
from lagom.envs.wrappers import get_wrapper


class Engine(BaseEngine):
    def train(self, n=None, **kwargs):
        self.agent.train()
        start_time = perf_counter()

        D = kwargs['D']
        out_agent = self.agent.learn(D)

        logger = Logger()
        logger('train_iteration', n+1)
        logger('num_seconds', round(perf_counter() - start_time, 1))
        [logger(key, value) for key, value in out_agent.items()]
        logger('num_trajectories', len(D))
        logger('num_timesteps', sum([len(traj) for traj in D]))
        logger('accumulated_trained_timesteps', self.agent.total_timestep)
        G = [traj.numpy_rewards.sum() for traj in D]
        logger('return', describe(G, axis=-1, repr_indent=1, repr_prefix='\n'))

        infos = [info for info in chain.from_iterable([traj.infos for traj in D]) if 'episode' in info]
        online_returns = [info['episode']['return'] for info in infos]
        online_horizons = [info['episode']['horizon'] for info in infos]
        logger('online_return', describe(online_returns, axis=-1, repr_indent=1, repr_prefix='\n'))
        logger('online_horizon', describe(online_horizons, axis=-1, repr_indent=1, repr_prefix='\n'))

        monitor_env = get_wrapper(self.env, 'VecMonitor')
        logger('running_return', describe(monitor_env.return_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
        logger('running_horizon', describe(monitor_env.horizon_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
        return logger

    def eval(self, n=None, **kwargs):
        pass
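`Engine.train` expects the collected trajectories to arrive through `kwargs` as `D`, a list of `Trajectory` objects, and returns the populated `Logger`. A rough sketch of the outer loop is shown below; `make_env_agent` and `collect_trajectories` are hypothetical placeholders (the real setup and rollout collection live in `experiment.py`, which is not shown here), and it assumes `BaseEngine` stores the `agent` and `env` keyword arguments as attributes.

```python
# Sketch only: make_env_agent and collect_trajectories are hypothetical placeholders;
# this commit's experiment.py is responsible for the real environment/agent setup
# and for collecting trajectories with a runner.
env, agent = make_env_agent(config)
engine = Engine(config, agent=agent, env=env)     # assumption: BaseEngine keeps agent/env as attributes

for n in range(num_iterations):                   # num_iterations chosen by the experiment setup
    D = collect_trajectories(agent, env)          # list of Trajectory, one batch per learn() call
    train_logger = engine.train(n, D=D)           # returns the Logger populated above
```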