Add V-trace metric computation
zuoxingdong committed Jul 4, 2019
1 parent a912e2f commit d10e8ce
Showing 17 changed files with 916 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/metric.rst
@@ -16,3 +16,5 @@ lagom.metric: Metrics
.. autofunction:: td0_error

.. autofunction:: gae

.. autofunction:: vtrace
4 changes: 3 additions & 1 deletion lagom/envs/make_vec_env.py
@@ -1,6 +1,7 @@
from functools import partial # argument-free functions

from lagom.utils import Seeder
from lagom.utils import CloudpickleWrapper

from .vec_env import VecEnv

@@ -37,5 +38,6 @@ def f(seed):
        return env
    
    # Use partial to generate a list of argument-free make_env, each with different seed
    list_make_env = [partial(f, seed=seed) for seed in seeds]
    # a partial over the local function f is not picklable with the standard pickle, so wrap it in CloudpickleWrapper
    list_make_env = [CloudpickleWrapper(partial(f, seed=seed)) for seed in seeds]
    return VecEnv(list_make_env)
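As a standalone illustration of the comment above (not part of the commit): the standard `pickle` module cannot serialize a `functools.partial` over a locally defined function like `f`, while `cloudpickle` — which wrappers of this kind typically rely on — can. A minimal sketch:

```python
import pickle
from functools import partial

import cloudpickle  # third-party dependency, assumed available


def outer():
    def f(seed):  # local closure, analogous to f inside make_vec_env
        return seed
    return partial(f, seed=0)


g = outer()
try:
    pickle.dumps(g)  # fails: pickle cannot locate the local function f by name
except (AttributeError, pickle.PicklingError) as e:
    print('pickle failed:', e)

payload = cloudpickle.dumps(g)  # cloudpickle serializes the closure by value
print(pickle.loads(payload)())  # -> 0
```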
2 changes: 2 additions & 0 deletions lagom/metric/__init__.py
@@ -7,3 +7,5 @@
from .td import td0_error

from .gae import gae

from .vtrace import vtrace
33 changes: 33 additions & 0 deletions lagom/metric/vtrace.py
@@ -0,0 +1,33 @@
import numpy as np

from lagom.utils import numpify

from .td import td0_error


def vtrace(behavior_logprobs, target_logprobs, gamma, Rs, Vs, last_V, reach_terminal, clip_rho=1.0, clip_pg_rho=1.0):
    r"""Compute V-trace value targets and policy-gradient advantages (IMPALA, Espeholt et al., 2018)
    for a single trajectory, given rewards, value estimates and behavior/target action log-probabilities.
    """
    behavior_logprobs = numpify(behavior_logprobs, np.float32)
    target_logprobs = numpify(target_logprobs, np.float32)
    Rs = numpify(Rs, np.float32)
    Vs = numpify(Vs, np.float32)
    last_V = numpify(last_V, np.float32)
    assert all([item.ndim == 1 for item in [behavior_logprobs, target_logprobs, Rs, Vs]])
    assert np.isscalar(gamma)
    
    # importance sampling ratios pi(a|s)/mu(a|s), truncated at clip_rho (rho-bar) and 1.0 (c-bar)
    rhos = np.exp(target_logprobs - behavior_logprobs)
    clipped_rhos = np.minimum(clip_rho, rhos)
    cs = np.minimum(1.0, rhos)
    deltas = clipped_rhos*td0_error(gamma, Rs, Vs, last_V, reach_terminal)
    
    # backward recursion: v_t - V(s_t) = delta_t + gamma*c_t*(v_{t+1} - V(s_{t+1}))
    vs_minus_V = []
    total = 0.0
    for delta_t, c_t in zip(deltas[::-1], cs[::-1]):
        total = delta_t + gamma*c_t*total
        vs_minus_V.append(total)
    vs_minus_V = np.asarray(vs_minus_V)[::-1]
    
    vs = vs_minus_V + Vs
    vs_next = np.append(vs[1:], (1. - reach_terminal)*last_V)
    clipped_pg_rhos = np.minimum(clip_pg_rho, rhos)
    As = clipped_pg_rhos*(Rs + gamma*vs_next - Vs)
    return vs, As
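For orientation, a minimal standalone sketch of calling `vtrace` on a short dummy trajectory (not part of the commit; the numbers are made up):

```python
import numpy as np

from lagom.metric import vtrace  # exported by this commit

# a 4-step trajectory that does not reach a terminal state
behavior_logprobs = np.log([0.5, 0.4, 0.6, 0.5])  # log mu(a_t|s_t) from the behavior (actor) policy
target_logprobs = np.log([0.6, 0.3, 0.7, 0.5])    # log pi(a_t|s_t) from the target (learner) policy
Rs = np.array([1.0, 0.0, 0.5, 1.0])               # rewards
Vs = np.array([0.8, 0.7, 0.9, 0.6])               # value estimates V(s_t)
last_V = 0.5                                      # bootstrap value V(s_T)

vs, As = vtrace(behavior_logprobs, target_logprobs, 0.99, Rs, Vs, last_V,
                reach_terminal=False, clip_rho=1.0, clip_pg_rho=1.0)
print(vs.shape, As.shape)  # both (4,): V-trace value targets and policy-gradient advantages
```

In `agent.py` below, `learn()` calls `vtrace` per trajectory in exactly this way, after splitting the batched log-probabilities and values back into per-trajectory pieces.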
18 changes: 18 additions & 0 deletions legacy/impala/README.md
@@ -0,0 +1,18 @@
# Importance Weighted Actor-Learner Architectures (IMPALA)

This is an implementation of the [IMPALA](https://arxiv.org/abs/1802.01561) algorithm.

## Usage

Run the following command to start parallelized training:

```bash
python experiment.py
```

Modify [experiment.py](./experiment.py) to quickly set up different configurations (see the configuration sketch below).

## Results

### MLP policy
<img src='logs/default/result.png' width='100%'>
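As a rough orientation (not part of this commit), the configuration keys read by `agent.py` in this commit suggest an experiment config along the following lines; the values shown are illustrative assumptions, not the defaults from `experiment.py`:

```python
config = {
    'nn.sizes': [64, 64],             # hidden layer sizes of the MLP feature network
    'agent.lr': 1e-3,                 # Adam learning rate
    'agent.use_lr_scheduler': False,  # if True, linearly anneal lr over train.timestep
    'agent.gamma': 0.99,              # discount factor
    'agent.clip_rho': 1.0,            # V-trace rho-bar clipping
    'agent.clip_pg_rho': 1.0,         # V-trace policy-gradient rho clipping
    'agent.standardize_adv': True,    # standardize advantages before the policy loss
    'agent.value_coef': 0.5,          # weight of the value loss
    'agent.entropy_coef': 0.01,       # weight of the entropy bonus
    'agent.max_grad_norm': 0.5,       # gradient clipping threshold
    'agent.std0': 0.6,                # initial std for DiagGaussianHead (continuous actions)
    'train.timestep': int(1e6),       # total training timesteps
}
```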
Empty file added legacy/impala/__init__.py
Empty file.
149 changes: 149 additions & 0 deletions legacy/impala/agent.py
@@ -0,0 +1,149 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from gym.spaces import Discrete
from gym.spaces import Box
from gym.spaces import flatdim

from lagom import BaseAgent
from lagom.utils import pickle_dump
from lagom.utils import tensorify
from lagom.utils import numpify
from lagom.envs.wrappers import get_wrapper
from lagom.networks import Module
from lagom.networks import make_fc
from lagom.networks import ortho_init
from lagom.networks import CategoricalHead
from lagom.networks import DiagGaussianHead
from lagom.networks import linear_lr_scheduler
from lagom.metric import vtrace
from lagom.transform import explained_variance as ev
from lagom.transform import describe


class MLP(Module):
    def __init__(self, config, env, device, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.env = env
        self.device = device
        
        self.feature_layers = make_fc(flatdim(env.observation_space), config['nn.sizes'])
        for layer in self.feature_layers:
            ortho_init(layer, nonlinearity='relu', constant_bias=0.0)
        self.layer_norms = nn.ModuleList([nn.LayerNorm(hidden_size) for hidden_size in config['nn.sizes']])
        
        self.to(self.device)
        
    def forward(self, x):
        for layer, layer_norm in zip(self.feature_layers, self.layer_norms):
            x = layer_norm(F.relu(layer(x)))
        return x


class Agent(BaseAgent):
    def __init__(self, config, env, device, **kwargs):
        super().__init__(config, env, device, **kwargs)
        
        feature_dim = config['nn.sizes'][-1]
        self.feature_network = MLP(config, env, device, **kwargs)
        if isinstance(env.action_space, Discrete):
            self.action_head = CategoricalHead(feature_dim, env.action_space.n, device, **kwargs)
        elif isinstance(env.action_space, Box):
            self.action_head = DiagGaussianHead(feature_dim, flatdim(env.action_space), device, config['agent.std0'], **kwargs)
        self.V_head = nn.Linear(feature_dim, 1)
        ortho_init(self.V_head, weight_scale=1.0, constant_bias=0.0)
        self.V_head = self.V_head.to(device)  # reproducible between CPU/GPU, ortho_init behaves differently
        
        self.register_buffer('total_timestep', torch.tensor(0))
        #self.total_timestep = 0
        
        self.optimizer = optim.Adam(self.parameters(), lr=config['agent.lr'])
        if config['agent.use_lr_scheduler']:
            self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep'], min_lr=1e-8)
        self.gamma = config['agent.gamma']
        self.clip_rho = config['agent.clip_rho']
        self.clip_pg_rho = config['agent.clip_pg_rho']
        
    def choose_action(self, obs, **kwargs):
        obs = tensorify(obs, self.device)
        out = {}
        features = self.feature_network(obs)
        
        action_dist = self.action_head(features)
        out['action_dist'] = action_dist
        out['entropy'] = action_dist.entropy()
        
        action = action_dist.sample()
        out['action'] = action
        out['raw_action'] = numpify(action, self.env.action_space.dtype)
        out['action_logprob'] = action_dist.log_prob(action.detach())
        
        V = self.V_head(features)
        out['V'] = V
        return out

    def learn(self, D, **kwargs):
        # Compute all metrics, D: list of Trajectory
        Ts = [len(traj) for traj in D]
        behavior_logprobs = [torch.cat(traj.get_all_info('action_logprob')) for traj in D]
        out_agent = self.choose_action(np.concatenate([traj.numpy_observations[:-1] for traj in D], 0))
        logprobs = out_agent['action_logprob'].squeeze()
        entropies = out_agent['entropy'].squeeze()
        Vs = out_agent['V'].squeeze()
        with torch.no_grad():
            last_observations = tensorify(np.concatenate([traj.last_observation for traj in D], 0), self.device)
            last_Vs = self.V_head(self.feature_network(last_observations)).squeeze(-1)
        
        vs, As = [], []
        for traj, behavior_logprob, logprob, V, last_V in zip(D, behavior_logprobs, logprobs.detach().cpu().split(Ts),
                                                              Vs.detach().cpu().split(Ts), last_Vs):
            v, A = vtrace(behavior_logprob, logprob, self.gamma, traj.rewards, V, last_V,
                          traj.reach_terminal, self.clip_rho, self.clip_pg_rho)
            vs.append(v)
            As.append(A)
        
        # Metrics -> Tensor, device
        vs, As = map(lambda x: tensorify(np.concatenate(x).copy(), self.device), [vs, As])
        if self.config['agent.standardize_adv']:
            As = (As - As.mean())/(As.std() + 1e-8)
        
        assert all([x.ndimension() == 1 for x in [logprobs, entropies, Vs, vs, As]])
        
        # Loss
        policy_loss = -logprobs*As
        entropy_loss = -entropies
        value_loss = F.mse_loss(Vs, vs, reduction='none')
        
        loss = policy_loss + self.config['agent.value_coef']*value_loss + self.config['agent.entropy_coef']*entropy_loss
        loss = loss.mean()
        
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(self.parameters(), self.config['agent.max_grad_norm'])
        self.optimizer.step()
        if self.config['agent.use_lr_scheduler']:
            self.lr_scheduler.step(self.total_timestep)
        
        self.total_timestep += sum([len(traj) for traj in D])
        out = {}
        if self.config['agent.use_lr_scheduler']:
            out['current_lr'] = self.lr_scheduler.get_lr()
        out['loss'] = loss.item()
        out['grad_norm'] = grad_norm
        out['policy_loss'] = policy_loss.mean().item()
        out['entropy_loss'] = entropy_loss.mean().item()
        out['policy_entropy'] = -entropy_loss.mean().item()
        out['value_loss'] = value_loss.mean().item()
        out['V'] = describe(numpify(Vs, 'float').squeeze(), axis=-1, repr_indent=1, repr_prefix='\n')
        out['explained_variance'] = ev(y_true=numpify(vs, 'float'), y_pred=numpify(Vs, 'float'))
        return out
        
    def checkpoint(self, logdir, num_iter):
        self.save(logdir/f'agent_{num_iter}.pth')
        obs_env = get_wrapper(self.env, 'VecStandardizeObservation')
        if obs_env is not None:
            pickle_dump(obj=(obs_env.mean, obs_env.var), f=logdir/f'obs_moments_{num_iter}', ext='.pth')
44 changes: 44 additions & 0 deletions legacy/impala/engine.py
@@ -0,0 +1,44 @@
from time import perf_counter
from itertools import chain

import numpy as np
import torch

from lagom import Logger
from lagom import BaseEngine
from lagom.transform import describe
from lagom.utils import color_str
from lagom.envs.wrappers import get_wrapper


class Engine(BaseEngine):
    def train(self, n=None, **kwargs):
        self.agent.train()
        start_time = perf_counter()
        
        D = kwargs['D']
        out_agent = self.agent.learn(D)
        
        logger = Logger()
        logger('train_iteration', n+1)
        logger('num_seconds', round(perf_counter() - start_time, 1))
        [logger(key, value) for key, value in out_agent.items()]
        logger('num_trajectories', len(D))
        logger('num_timesteps', sum([len(traj) for traj in D]))
        logger('accumulated_trained_timesteps', self.agent.total_timestep)
        G = [traj.numpy_rewards.sum() for traj in D]
        logger('return', describe(G, axis=-1, repr_indent=1, repr_prefix='\n'))
        
        infos = [info for info in chain.from_iterable([traj.infos for traj in D]) if 'episode' in info]
        online_returns = [info['episode']['return'] for info in infos]
        online_horizons = [info['episode']['horizon'] for info in infos]
        logger('online_return', describe(online_returns, axis=-1, repr_indent=1, repr_prefix='\n'))
        logger('online_horizon', describe(online_horizons, axis=-1, repr_indent=1, repr_prefix='\n'))
        
        monitor_env = get_wrapper(self.env, 'VecMonitor')
        logger('running_return', describe(monitor_env.return_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
        logger('running_horizon', describe(monitor_env.horizon_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
        return logger
        
    def eval(self, n=None, **kwargs):
        pass
