# Actor Critic Methods

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/riccardoberta/machine-learning/blob/master/06-deep-reinforcement-learning/05-actor-critic-methods.ipynb)

We explore a combined class of methods that learn both policies and value functions. These methods are referred to as **actor-critic** because the policy, which selects actions, can be seen as an actor, and the value function, which evaluates policies, can be seen as a critic. Actor-critic methods often perform better than value-based or policy-based methods alone on many of the deep reinforcement learning benchmarks. 

1. [Asynchronous Advantage Actor-Critic (A3C)](#Asynchronous-Advantage-Actor-Critic-(A3C))
2. [Generalized advantage estimation (GAE)](#Generalized-advantage-estimation-(GAE))

## Asynchronous Advantage Actor-Critic (A3C)

One of the main sources of variance in DRL algorithms is how correlated and non-stationary online samples are. In value-based methods, we use a replay buffer to uniformly sample mini-batches. Unfortunately, this is limited to off-policy methods, because on-policy agents cannot reuse data generated by previous policies. In other words, every optimization step requires a fresh batch of on-policy experiences. Instead of using a replay buffer, what we can do in on-policy methods (like the policy-based methods) is have **multiple workers** generating experience in parallel and asynchronously updating the policy and value function. Having multiple workers generate experience on multiple instances of the environment in parallel decorrelates the data used for training and reduces the variance of the algorithm.

<img src="./images/multiple-workers.png" width="700">

The **asynchronous advantage actor-critic (A3C)** agent uses concurrent actors to generate a broad set of experience samples in parallel. It also uses **n-step returns with bootstrapping** to learn the policy and value function.
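As a minimal sketch of an n-step return with bootstrapping (plain Python, independent of the agent code; `n_step_returns` is a hypothetical helper name), the function below computes the return from every step of a partial trajectory, seeding a backward recursion with the critic's estimate of the state where the rollout stopped:

```python
def n_step_returns(rewards, bootstrap_value, gamma):
    """Discounted return from every step of a partial trajectory.

    rewards: R_t, ..., R_{t+n-1} collected by one worker
    bootstrap_value: critic estimate V(S_{t+n}), or 0.0 if S_{t+n} is terminal
    """
    returns = []
    running = bootstrap_value
    # walk backwards: G_k = R_k + gamma * G_{k+1}, seeded with the bootstrap value
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


# three rewards of 1.0, then bootstrap from a critic estimate of 10.0
print(n_step_returns([1.0, 1.0, 1.0], bootstrap_value=10.0, gamma=0.5))  # -> [3.0, 4.0, 6.0]
```

The agent below computes the same quantities differently: it appends the bootstrap value to `rewards` and takes dot products with discount vectors built by `np.logspace`. The backward recursion is an equivalent formulation that needs only O(n) operations.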

We can write the complete A3C agent in Python:

In [2]:
import numpy as np
import random
import time
from itertools import count
import torch
import torch.multiprocessing as mp
    
class A3C():
    def __init__(self, policy_model_fn, policy_model_max_grad_norm, policy_optimizer_fn, policy_optimizer_lr,
                 value_model_fn, value_model_max_grad_norm, value_optimizer_fn, value_optimizer_lr, 
                 entropy_loss_weight, max_n_steps, n_workers):
                
        self.policy_model_fn = policy_model_fn
        self.policy_model_max_grad_norm = policy_model_max_grad_norm
        self.policy_optimizer_fn = policy_optimizer_fn
        self.policy_optimizer_lr = policy_optimizer_lr
        
        self.value_model_fn = value_model_fn
        self.value_model_max_grad_norm = value_model_max_grad_norm
        self.value_optimizer_fn = value_optimizer_fn
        self.value_optimizer_lr = value_optimizer_lr
        
        self.entropy_loss_weight = entropy_loss_weight
        self.max_n_steps = max_n_steps
        self.n_workers = n_workers

    # the work function that each worker process runs;
    # the rank parameter is used as an ID for the worker
    def work(self, rank):
        last_debug_time = float('-inf')
        self.stats['n_active_workers'].add_(1)
        
        # create a unique seed per worker: we want diverse experiences
        local_seed = self.seed + rank
        
        # we create a uniquely seeded environment for each worker
        env = self.make_env_fn(seed=local_seed)
        torch.manual_seed(local_seed)
        np.random.seed(local_seed)
        random.seed(local_seed)

        # create a local policy model and initialize its weights with the weights
        # of the shared policy network. The shared network allows us to synchronize
        # the agents periodically.
        nS, nA = env.observation_space.shape[0], env.action_space.n
        local_policy_model = self.policy_model_fn(nS, nA)
        local_policy_model.load_state_dict(self.shared_policy_model.state_dict())
        
        # do the same thing with the value model
        local_value_model = self.value_model_fn(nS)
        local_value_model.load_state_dict(self.shared_value_model.state_dict())

        # start the training loop, until the worker is signaled to get out of it
        global_episode_idx = self.stats['episode'].add_(1).item() - 1
        while not self.get_out_signal: 
            episode_start = time.time()
            
            # reset the environment, and set the is_terminal flag to false
            state, is_terminal = env.reset(), False
            
            # use n-step returns for training the policy and value functions
            n_steps_start, total_episode_rewards = 0, 0
            total_episode_steps, total_episode_exploration = 0, 0
            logpas, entropies, rewards, values = [], [], [], []

            # the episode loop 
            for step in count(start=1):
                
                # collect a step of experience
                state, reward, is_terminal, is_truncated, is_exploratory = self.interaction_step(
                    state, env, local_policy_model, local_value_model, 
                    logpas, entropies, rewards, values)

                total_episode_steps += 1
                total_episode_rewards += reward
                total_episode_exploration += int(is_exploratory)
                
                # collect n-steps maximum. If we hit a terminal state, we stop there
                if is_terminal or step - n_steps_start == self.max_n_steps:

                    # check if the time wrapper was triggered or this is a true terminal state
                    is_failure = is_terminal and not is_truncated
                    
                    # if it’s a failure, then the value of the next state is 0; otherwise, we bootstrap
                    next_value = 0 if is_failure else local_value_model(state).detach().item()
                    
                    # Look! we are sneaky here: appending the next_value to the rewards, the optimization 
                    # code from VPG remains largely the same
                    rewards.append(next_value)
                    
                    # optimize the model
                    self.optimize_model(logpas, entropies, rewards, values, 
                                        local_policy_model, local_value_model)
                    
                    logpas, entropies, rewards, values = [], [], [], []
                    n_steps_start = step

                if is_terminal:
                    break

            # save global stats
            episode_elapsed = time.time() - episode_start
            evaluation_score, _ = self.evaluate(local_policy_model, env)
            
            self.stats['episode_elapsed'][global_episode_idx].add_(episode_elapsed)
            self.stats['episode_timestep'][global_episode_idx].add_(total_episode_steps)
            self.stats['episode_reward'][global_episode_idx].add_(total_episode_rewards)
            self.stats['episode_exploration'][global_episode_idx].add_(total_episode_exploration/total_episode_steps)
            self.stats['evaluation_scores'][global_episode_idx].add_(evaluation_score)

            mean_100_train_reward = self.stats['episode_reward'][:global_episode_idx+1][-100:].mean().item()
            std_100_train_reward = self.stats['episode_reward'][:global_episode_idx+1][-100:].std().item()
            mean_100_eval_reward = self.stats['evaluation_scores'][:global_episode_idx+1][-100:].mean().item()
            std_100_eval_reward = self.stats['evaluation_scores'][:global_episode_idx+1][-100:].std().item()
            
            global_n_steps = self.stats['episode_timestep'][:global_episode_idx+1].sum().item()
            global_training_elapsed = self.stats['episode_elapsed'][:global_episode_idx+1].sum().item()
            
            total_elapsed = time.time() - self.training_start
            
            self.stats['result'][global_episode_idx][0].add_(global_n_steps)
            self.stats['result'][global_episode_idx][1].add_(mean_100_train_reward)
            self.stats['result'][global_episode_idx][2].add_(mean_100_eval_reward)
            self.stats['result'][global_episode_idx][3].add_(global_training_elapsed)
            self.stats['result'][global_episode_idx][4].add_(total_elapsed)

            debug_message = 'episode {:04}, steps {:06}, '
            debug_message += 'avg score {:05.1f}\u00B1{:05.1f}, '
            debug_message = debug_message.format(global_episode_idx, global_n_steps, mean_100_train_reward, std_100_train_reward)
            if rank == 0:
                print(debug_message, end='\r', flush=True)
                if time.time() - last_debug_time >= 60:
                    print(debug_message, flush=True)
                    last_debug_time = time.time()

            with self.get_out_lock:
                potential_next_global_episode_idx = self.stats['episode'].item()
                self.reached_goal_mean_reward.add_(mean_100_eval_reward >= self.goal_mean_100_reward)
                self.reached_max_minutes.add_(time.time() - self.training_start >= self.max_minutes * 60)
                self.reached_max_episodes.add_(potential_next_global_episode_idx >= self.max_episodes)
                if self.reached_max_episodes or \
                   self.reached_max_minutes or \
                   self.reached_goal_mean_reward:
                    self.get_out_signal.add_(1)
                    break
                    
                # else go work on another episode
                global_episode_idx = self.stats['episode'].add_(1).item() - 1

        while rank == 0 and self.stats['n_active_workers'].item() > 1:
            pass

        if rank == 0:
            if self.reached_max_minutes: print(u'--> reached_max_minutes')
            if self.reached_max_episodes: print(u'--> reached_max_episodes')
            if self.reached_goal_mean_reward: print(u'--> reached_goal_mean_reward')

        env.close() 
        del env
        
        self.stats['n_active_workers'].sub_(1)

    def optimize_model(self, logpas, entropies, rewards, values, local_policy_model, local_value_model):
        
        # get the length of the rewards. Remember: rewards includes the bootstrapping value.
        T = len(rewards)
        
        # calculate all discounts up to n+1.
        discounts = np.logspace(0, T, num=T, base=self.gamma, endpoint=False)
        
        # calculate the n-step predicted return for every step
        returns = np.array([np.sum(discounts[:T-t] * rewards[t:]) for t in range(T)])
        
        # remove the extra elements and format the variables as expected
        discounts = torch.FloatTensor(discounts[:-1]).unsqueeze(1)
        returns = torch.FloatTensor(returns[:-1]).unsqueeze(1)

        logpas = torch.cat(logpas)
        entropies = torch.cat(entropies)
        values = torch.cat(values)

        # calculate the value errors as the predicted return minus the estimated values
        value_error = returns - values
        
        # calculate the loss
        policy_loss = -(discounts * value_error.detach() * logpas).mean()
        entropy_loss = -entropies.mean()
        loss = policy_loss + self.entropy_loss_weight * entropy_loss
        
        # zero the gradients of the shared policy optimizer and of the local model
        # (so that local gradients do not accumulate across updates), then backpropagate the loss
        self.shared_policy_optimizer.zero_grad()
        local_policy_model.zero_grad()
        loss.backward()
        
        # clip the gradient magnitude
        torch.nn.utils.clip_grad_norm_(local_policy_model.parameters(), self.policy_model_max_grad_norm)
        
        # iterate over all local and shared policy network parameters
        # and hand every local gradient to the shared model
        for param, shared_param in zip(local_policy_model.parameters(), self.shared_policy_model.parameters()):
            shared_param._grad = param.grad
        
        # once the gradients are copied into the shared optimizer, we run an optimization step
        self.shared_policy_optimizer.step()
        
        # we load the shared model into the local model
        local_policy_model.load_state_dict(self.shared_policy_model.state_dict())

        # we do the same thing but with the state-value network
        value_loss = value_error.pow(2).mul(0.5).mean()
        self.shared_value_optimizer.zero_grad()
        local_value_model.zero_grad()
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(local_value_model.parameters(), self.value_model_max_grad_norm)
        for param, shared_param in zip(local_value_model.parameters(), self.shared_value_model.parameters()):
            shared_param._grad = param.grad
        self.shared_value_optimizer.step()
        local_value_model.load_state_dict(self.shared_value_model.state_dict())

    @staticmethod
    def interaction_step(state, env, local_policy_model, local_value_model,
                         logpas, entropies, rewards, values):
        action, is_exploratory, logpa, entropy = local_policy_model.full_pass(state)
        new_state, reward, is_terminal, info = env.step(action)
        is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated']

        logpas.append(logpa)
        entropies.append(entropy)
        rewards.append(reward)
        values.append(local_value_model(state))

        return new_state, reward, is_terminal, is_truncated, is_exploratory
    
    def train(self, make_env_fn, seed, gamma, 
              max_minutes, max_episodes, goal_mean_100_reward):
        
        self.make_env_fn = make_env_fn
        self.seed = seed
        self.gamma = gamma
        self.max_minutes = max_minutes
        self.max_episodes = max_episodes
        self.goal_mean_100_reward = goal_mean_100_reward

        env = self.make_env_fn(seed=self.seed)
        nS, nA = env.observation_space.shape[0], env.action_space.n
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        self.stats = {}
        self.stats['episode'] = torch.zeros(1, dtype=torch.int).share_memory_()
        self.stats['result'] = torch.zeros([max_episodes, 5]).share_memory_()
        self.stats['evaluation_scores'] = torch.zeros([max_episodes]).share_memory_()
        self.stats['episode_reward'] = torch.zeros([max_episodes]).share_memory_()
        self.stats['episode_timestep'] = torch.zeros([max_episodes], dtype=torch.int).share_memory_()
        self.stats['episode_exploration'] = torch.zeros([max_episodes]).share_memory_()
        self.stats['episode_elapsed'] = torch.zeros([max_episodes]).share_memory_()
        self.stats['n_active_workers'] = torch.zeros(1, dtype=torch.int).share_memory_()

        self.shared_policy_model = self.policy_model_fn(nS, nA).share_memory()
        self.shared_policy_optimizer = self.policy_optimizer_fn(self.shared_policy_model, self.policy_optimizer_lr)
        self.shared_value_model = self.value_model_fn(nS).share_memory()
        self.shared_value_optimizer = self.value_optimizer_fn(self.shared_value_model, self.value_optimizer_lr)

        self.get_out_lock = mp.Lock()
        self.get_out_signal = torch.zeros(1, dtype=torch.int).share_memory_()
        self.reached_max_minutes = torch.zeros(1, dtype=torch.int).share_memory_() 
        self.reached_max_episodes = torch.zeros(1, dtype=torch.int).share_memory_() 
        self.reached_goal_mean_reward  = torch.zeros(1, dtype=torch.int).share_memory_() 
        self.training_start = time.time()
        
        workers = [mp.Process(target=self.work, args=(rank,)) for rank in range(self.n_workers)]  
        [w.start() for w in workers]
        [w.join() for w in workers]
        wallclock_time = time.time() - self.training_start

        final_eval_score, score_std = self.evaluate(self.shared_policy_model, env, n_episodes=100)
        env.close()
        del env

        final_episode = self.stats['episode'].item()
        training_time = self.stats['episode_elapsed'][:final_episode+1].sum().item()

        print('Training complete.')
        print('Final evaluation score {:.2f}\u00B1{:.2f} in {:.2f}s training time,'
              ' {:.2f}s wall-clock time.\n'.format(
                  final_eval_score, score_std, training_time, wallclock_time))

        self.stats['result'] = self.stats['result'].numpy()
        self.stats['result'][final_episode:, ...] = np.nan
        
        return self.stats['result'], final_eval_score, training_time, wallclock_time

    def evaluate(self, eval_policy_model, eval_env, n_episodes=1, greedy=True):
        rs = []
        for _ in range(n_episodes):
            s, d = eval_env.reset(), False
            rs.append(0)
            for _ in count():
                if greedy:
                    a = eval_policy_model.select_greedy_action(s)
                else: 
                    a = eval_policy_model.select_action(s)
                s, r, d, _ = eval_env.step(a)
                rs[-1] += r
                if d: break
        return np.mean(rs), np.std(rs)



Notice that we append the value of the next state to the reward sequence (0 if that state is a true terminal state, the critic's estimate otherwise). Previously, we used full returns for our advantage estimates:

$\begin{align}
A(S_t,A_t;\phi) = G_t - V(S_t;\phi)
\end{align}$

Now, the `rewards` variable contains all rewards from the partial trajectory plus the state-value estimate of the last state. We can also see this as having the partial return (the sequence of rewards) and the predicted remaining return (a single-number estimate) in the same place. This is an **n-step return**: the agent goes out for n steps collecting rewards and then **bootstraps** from the value of the nth state (or stops earlier if it lands on a terminal state, whichever comes first):

$\begin{align}
A(S_t,A_t;\phi) = R_t + \gamma R_{t+1} + \dots + \gamma^{n-1} R_{t+n-1} + \gamma^n V(S_{t+n};\phi) - V(S_t;\phi)
\end{align}$
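To make the bootstrapping rule concrete, here is a small sketch (hypothetical helper `n_step_advantage`, plain Python) that mirrors the `is_failure = is_terminal and not is_truncated` logic in `work`: a true terminal state contributes a value of 0, while a time-limit truncation or the n-step cutoff bootstraps from the critic:

```python
def n_step_advantage(rewards, v_start, v_next, gamma, is_failure):
    """n-step advantage A(S_t, A_t) for the first step of a partial trajectory.

    rewards: R_t, ..., R_{t+n-1}
    v_start: critic estimate V(S_t)
    v_next: critic estimate V(S_{t+n}); ignored when S_{t+n} is a true terminal state
    is_failure: True for a real terminal state, False for a time-limit
                truncation or for simply reaching the n-step limit
    """
    # a true terminal state has zero value; otherwise we bootstrap from the critic
    bootstrap = 0.0 if is_failure else v_next
    n = len(rewards)
    partial_return = sum(gamma ** k * r for k, r in enumerate(rewards))
    return partial_return + gamma ** n * bootstrap - v_start


print(n_step_advantage([1.0, 1.0], v_start=1.5, v_next=4.0, gamma=0.5, is_failure=False))
print(n_step_advantage([1.0, 1.0], v_start=1.5, v_next=4.0, gamma=0.5, is_failure=True))
```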

We now use this n-step advantage estimate for updating the action probabilities:

$\begin{align}
L_\pi(\theta) = -\frac{1}{N}\sum\limits_{t=0}^{N-1}\left[A(S_t,A_t;\phi) \log \pi(A_t|S_t;\theta) + \beta H(\pi(S_t;\theta))\right]
\end{align}$
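The following plain-Python sketch evaluates $L_\pi$ for a categorical policy parameterized by logits, as `FCR` does. It is an illustration under stated assumptions: there is no autograd, the advantages are plain numbers (matching the agent, which treats them as constants via `value_error.detach()`), and it omits the $\gamma^t$ factor that `optimize_model` additionally multiplies in through `discounts`. The helper names are hypothetical.

```python
import math


def log_softmax(logits):
    # numerically stable log-softmax over a list of logits
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def policy_loss(batch_logits, actions, advantages, beta):
    """L_pi = -(1/N) * sum_t [A_t * log pi(A_t|S_t) + beta * H(pi(S_t))]"""
    total = 0.0
    for logits, action, adv in zip(batch_logits, actions, advantages):
        logp = log_softmax(logits)
        # entropy of the categorical distribution: -sum_a pi(a) * log pi(a)
        entropy = -sum(math.exp(lp) * lp for lp in logp)
        total += adv * logp[action] + beta * entropy
    return -total / len(actions)


# two actions with equal logits: log pi = -log 2 and H = log 2
print(policy_loss([[0.0, 0.0]], actions=[0], advantages=[1.0], beta=0.0))
print(policy_loss([[0.0, 0.0]], actions=[0], advantages=[1.0], beta=1.0))
```

A positive advantage makes the loss decrease as $\log \pi(A_t|S_t)$ grows, and the entropy term lowers the loss for more uniform policies, which is how $\beta$ encourages exploration.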

We also use the n-step return to improve the value function estimate. Notice the **bootstrapping here**. This is what makes the algorithm an **actor-critic method**:

$\begin{align}
L_v(\phi) = \frac{1}{N}\sum\limits_{t=0}^{N-1}\left[\left(R_t + \gamma R_{t+1} + \dots + \gamma^{n-1} R_{t+n-1} + \gamma^n V(S_{t+n};\phi) - V(S_t;\phi)\right)^2\right]
\end{align}$
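`optimize_model` scales the squared error by 0.5 (`value_error.pow(2).mul(0.5).mean()`). This only rescales $L_v$, but it makes the gradient with respect to each estimate exactly the negative n-step error divided by $N$. A plain-Python sketch with a finite-difference check (helper names are hypothetical):

```python
def value_loss(returns, values):
    """Half the mean squared error between n-step returns and value estimates."""
    n = len(returns)
    return 0.5 * sum((g - v) ** 2 for g, v in zip(returns, values)) / n


def value_loss_grad(returns, values):
    """Analytic gradient dL/dV_t = -(G_t - V_t) / N; the 0.5 cancels the square's 2."""
    n = len(returns)
    return [-(g - v) / n for g, v in zip(returns, values)]


returns, values = [3.0, 4.0], [1.0, 4.0]
print(value_loss(returns, values))  # 0.5 * (2**2 + 0**2) / 2
print(value_loss_grad(returns, values))

# forward finite-difference check of the gradient for the first estimate
h = 1e-6
bumped = [values[0] + h, values[1]]
numeric = (value_loss(returns, bumped) - value_loss(returns, values)) / h
print(numeric)
```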


One of the most critical aspects of A3C is that its network updates are asynchronous and lock-free. A shared model would normally call for a locking mechanism to prevent workers from overwriting each other's updates. Instead, A3C uses an update style called **Hogwild!**, in which workers access the shared model without locks, accepting that they may occasionally overwrite each other's work. Niu et al. show that, when the optimization problem is sparse (most gradient updates modify only a small part of the parameters), this scheme achieves a nearly optimal rate of convergence, and that it outperforms alternative schemes that use locking by an order of magnitude: [F. Niu et al. **"HOGWILD!: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent"**, NIPS 2011](https://arxiv.org/abs/1106.5730?context=cs).
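The lock-free pattern can be illustrated with a toy sketch (not the A3C code): several threads run SGD on a shared parameter list with no lock, and each update touches a single coordinate, which mimics the sparse setting of the Hogwild! analysis. Python threads are serialized by the GIL, so this shows the unsynchronized read-modify-write pattern rather than a speedup; A3C itself uses processes and shared-memory tensors. `hogwild_sgd` and its quadratic objective are hypothetical.

```python
import random
import threading


def hogwild_sgd(targets, n_workers=4, steps_per_worker=2000, lr=0.1, seed=0):
    """Lock-free parallel SGD on f(w) = sum_i (w_i - targets_i)^2 / 2."""
    w = [0.0] * len(targets)  # shared parameters, no lock around them

    def worker(rank):
        rng = random.Random(seed + rank)  # each worker has its own random stream
        for _ in range(steps_per_worker):
            i = rng.randrange(len(w))  # sample one coordinate (a sparse update)
            grad = w[i] - targets[i]   # gradient of the i-th term
            w[i] -= lr * grad          # unsynchronized write; updates may be lost

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return w


print(hogwild_sgd([1.0, -2.0, 3.0, 0.5]))
```

Occasional lost or stale updates only slow convergence slightly here, because every update still moves a coordinate toward its target.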

We need Adam and RMSprop optimizers that keep their internal state in shared memory:

In [3]:
import torch

class SharedAdam(torch.optim.Adam):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        
        # pre-initialize the optimizer state and call share_memory_ on every tensor
        # that must be shared across workers. The step counter is a tensor (recent
        # PyTorch versions of Adam require this), so Adam's own in-place increment
        # updates a single counter shared by all workers and no custom step() is needed
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(()).share_memory_()
                state['exp_avg'] = torch.zeros_like(p.data).share_memory_()
                state['exp_avg_sq'] = torch.zeros_like(p.data).share_memory_()
                if amsgrad: state['max_exp_avg_sq'] = torch.zeros_like(p.data).share_memory_()

In [4]:
class SharedRMSprop(torch.optim.RMSprop):
    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
        super(SharedRMSprop, self).__init__(
            params, lr=lr, alpha=alpha, 
            eps=eps, weight_decay=weight_decay, 
            momentum=momentum, centered=centered)
        # same approach as SharedAdam: shared tensors for the step counter and
        # for every running statistic that RMSprop keeps per parameter
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(()).share_memory_()
                state['square_avg'] = torch.zeros_like(p.data).share_memory_()
                if momentum > 0:
                    state['momentum_buffer'] = torch.zeros_like(p.data).share_memory_()
                if centered:
                    state['grad_avg'] = torch.zeros_like(p.data).share_memory_()
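Both shared optimizers keep their per-parameter statistics in shared memory. To see why the step counter matters as much as the moment estimates, here is a scalar Adam update in plain Python (a hypothetical helper following the standard Adam update rule, not PyTorch's implementation): the bias corrections $1-\beta_1^{k}$ and $1-\beta_2^{k}$ depend on the step count $k$, so workers that share the moments should also share the counter. Passing the same `state` dictionary from two "workers" stands in for shared memory.

```python
import math


def adam_update(param, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the updated scalar parameter after one Adam step.

    `state` holds 'step', 'exp_avg' and 'exp_avg_sq' and is modified in place.
    """
    state['step'] += 1
    state['exp_avg'] = beta1 * state['exp_avg'] + (1 - beta1) * grad
    state['exp_avg_sq'] = beta2 * state['exp_avg_sq'] + (1 - beta2) * grad * grad
    # both bias corrections use the (shared) step count
    m_hat = state['exp_avg'] / (1 - beta1 ** state['step'])
    v_hat = state['exp_avg_sq'] / (1 - beta2 ** state['step'])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


# one state dictionary used by two "workers", standing in for shared memory
shared_state = {'step': 0, 'exp_avg': 0.0, 'exp_avg_sq': 0.0}
param = 1.0
param = adam_update(param, 4.0, shared_state, lr=0.01)  # worker 0
param = adam_update(param, 4.0, shared_state, lr=0.01)  # worker 1 continues from step 1
print(shared_state['step'], param)
```

With a constant gradient, the bias-corrected update has magnitude close to `lr` at every step, so the parameter moves from 1.0 to roughly 0.98 after two steps.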

We can try the algorithm on the cart-pole environment. We need two network architectures: a policy network (`FCR`) and a state-value network (`FCV`):

In [5]:
import torch.nn as nn
import torch.nn.functional as F

class FCR(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dims=(32,32), activation_fc=F.relu):
        super(FCR, self).__init__()
        self.activation_fc = activation_fc

        self.input_layer = nn.Linear(input_dim, hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims)-1):
            hidden_layer = nn.Linear(hidden_dims[i], hidden_dims[i+1])
            self.hidden_layers.append(hidden_layer)
        self.output_layer = nn.Linear(hidden_dims[-1], output_dim)
        
    def forward(self, state):
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float32)
            x = x.unsqueeze(0)
        x = self.activation_fc(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            x = self.activation_fc(hidden_layer(x))
        return self.output_layer(x)

    def full_pass(self, state):
        logits = self.forward(state)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        logpa = dist.log_prob(action).unsqueeze(-1)
        entropy = dist.entropy().unsqueeze(-1)
        is_exploratory = action != np.argmax(logits.detach().numpy())
        return action.item(), is_exploratory.item(), logpa, entropy

    def select_action(self, state):
        logits = self.forward(state)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        return action.item()

    def select_greedy_action(self, state):
        logits = self.forward(state)
        return np.argmax(logits.detach().numpy())
    

In [6]:
class FCV(nn.Module):
    def __init__(self, input_dim, hidden_dims=(32,32), activation_fc=F.relu):
        super(FCV, self).__init__()
        self.activation_fc = activation_fc

        self.input_layer = nn.Linear(input_dim, hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims)-1):
            hidden_layer = nn.Linear(hidden_dims[i], hidden_dims[i+1])
            self.hidden_layers.append(hidden_layer)
            
        self.output_layer = nn.Linear(hidden_dims[-1], 1)

    def forward(self, state):
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float32)
            x = x.unsqueeze(0)
        x = self.activation_fc(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            x = self.activation_fc(hidden_layer(x))
        return self.output_layer(x)

In [7]:
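import gym

# this notebook uses the classic gym API (before gym 0.26): env.seed(), reset() returning
# only the observation, and step() returning (state, reward, done, info)
def make_env(seed=None):
    env = gym.make('CartPole-v1')
    if seed is not None: env.seed(seed)
    return env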

In [8]:
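# under the 'spawn' start method mp.Process pickles its target (self.work, and with it
# the lambdas below), which fails; where available, use 'fork' so workers inherit the agent
if 'fork' in mp.get_all_start_methods():
    mp.set_start_method('fork', force=True)

a3c_results = []
best_a3c_agent, best_a3c_eval_score = None, float('-inf')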
   
gamma = 1.00
max_minutes = 20
max_episodes = 10000
goal_mean_100_reward = 475

for seed in (12, 34, 56, 78, 90):
    
    policy_model_fn = lambda nS, nA: FCR(nS, nA, hidden_dims=(128,64))
    policy_model_max_grad_norm = 1
    policy_optimizer_fn = lambda net, lr: SharedAdam(net.parameters(), lr=lr)
    policy_optimizer_lr = 0.0005

    value_model_fn = lambda nS: FCV(nS, hidden_dims=(256,128))
    value_model_max_grad_norm = float('inf')
    value_optimizer_fn = lambda net, lr: SharedRMSprop(net.parameters(), lr=lr)
    value_optimizer_lr = 0.0007

    entropy_loss_weight = 0.001

    max_n_steps = 50
    n_workers = 1
    
    agent = A3C(policy_model_fn, policy_model_max_grad_norm, policy_optimizer_fn, policy_optimizer_lr,
                value_model_fn, value_model_max_grad_norm, value_optimizer_fn, value_optimizer_lr,
                entropy_loss_weight, max_n_steps, n_workers)
 
    result, final_eval_score, training_time, wallclock_time = agent.train(make_env, seed, gamma, max_minutes, max_episodes, goal_mean_100_reward)
    
    a3c_results.append(result)
    if final_eval_score > best_a3c_eval_score:
        best_a3c_eval_score = final_eval_score
        best_a3c_agent = agent
        
a3c_results = np.array(a3c_results)

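Running this cell produced `PicklingError: Can't pickle <function <lambda> at 0x7fda09314310>: attribute lookup <lambda> on __main__ failed`. With the `spawn` start method (the default on macOS and Windows), `mp.Process` pickles its target, the bound method `self.work`, together with the agent and the lambda factories it stores, and lambdas cannot be pickled. With the `fork` start method, the workers inherit the agent from the parent process instead.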

## Generalized advantage estimation (GAE)

A3C uses n-step returns to reduce the variance of the targets. Still, there is a more robust method that combines multiple n-step bootstrapping targets into a single target: the $\lambda$-target. **Generalized advantage estimation (GAE)** is analogous to the $\lambda$-target in TD($\lambda$), but for advantages: [John Schulman et al. **"High-dimensional Continuous Control Using Generalized Advantage Estimation"**, ICLR 2016](https://arxiv.org/abs/1506.02438)
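As a sketch of the $\lambda$-target (plain Python, hypothetical helper `lambda_returns`), it can be computed for a finite trajectory with the recursion $G^\lambda_t = R_t + \gamma\left[(1-\lambda)V(S_{t+1}) + \lambda G^\lambda_{t+1}\right]$, bootstrapping fully from the critic at the end of the rollout. $\lambda=0$ recovers the one-step TD target and $\lambda=1$ the longest available n-step return:

```python
def lambda_returns(rewards, next_values, gamma, lam):
    """lambda-returns for every step of a finite partial trajectory.

    rewards[k]: R_{t+k}
    next_values[k]: V(S_{t+k+1}); use 0.0 for a true terminal state
    """
    returns = [0.0] * len(rewards)
    # at the end of the rollout we can only bootstrap from the critic
    g = next_values[-1]
    for k in reversed(range(len(rewards))):
        # G_k = R_k + gamma * ((1 - lam) * V(S_{k+1}) + lam * G_{k+1})
        g = rewards[k] + gamma * ((1 - lam) * next_values[k] + lam * g)
        returns[k] = g
    return returns


rewards, next_values = [1.0, 1.0, 1.0], [2.0, 2.0, 0.0]
print(lambda_returns(rewards, next_values, gamma=1.0, lam=0.0))  # one-step TD targets
print(lambda_returns(rewards, next_values, gamma=1.0, lam=1.0))  # longest n-step returns
```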

GAE is not an agent on its own, but a way of estimating targets for the advantage function that most actor-critic methods can leverage. More specifically, GAE uses an exponentially weighted combination of n-step action-advantage function targets, which can substantially reduce the variance of policy-gradient estimates at the cost of some bias. Consider the n-step advantage estimates:

$\begin{align}
A^1(S_t,A_t;\phi) = R_t + \gamma V(S_{t+1};\phi) - V(S_t;\phi)
\end{align}$

$\begin{align}
A^2(S_t,A_t;\phi) = R_t + \gamma R_{t+1} + \gamma^2 V(S_{t+2};\phi) - V(S_t;\phi)
\end{align}$

$\begin{align}
A^3(S_t,A_t;\phi) = R_t + \gamma R_{t+1} + \gamma^2 R_{t+2} + \gamma^3 V(S_{t+3};\phi) - V(S_t;\phi)
\end{align}$

...

$\begin{align}
A^n(S_t,A_t;\phi) = R_t + \gamma R_{t+1} + \dots + \gamma^{n-1} R_{t+n-1} + \gamma^n V(S_{t+n};\phi) - V(S_t;\phi)
\end{align}$

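which we can mix, with exponentially decaying weights $(1-\lambda)\left(A^1 + \lambda A^2 + \lambda^2 A^3 + \dots\right)$, into an estimate analogous to TD($\lambda$) but for advantages. Writing $\delta_t = R_t + \gamma V(S_{t+1};\phi) - V(S_t;\phi)$ for the one-step TD error, the mixture simplifies to:

$\begin{align}
A^{\text{GAE}(\gamma,\lambda)}(S_t,A_t;\phi) = \sum\limits_{l=0}^{\infty}{(\gamma \lambda)^l \delta_{t+l}}
\end{align}$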

$\lambda=0$ gives the one-step advantage estimate $\delta_t$, and $\lambda=1$ gives the infinite-step (Monte Carlo) advantage estimate $\sum_{l=0}^{\infty}\gamma^l R_{t+l} - V(S_t;\phi)$.
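A minimal plain-Python sketch of GAE for one partial trajectory follows; `gae` is a hypothetical helper. Instead of the forward sums with `np.logspace` used in the `GAE` class, it applies the equivalent backward recursion $A_t = \delta_t + \gamma\lambda A_{t+1}$, which needs O(T) operations and truncates the infinite sum at the end of the rollout:

```python
def gae(rewards, values, gamma, lam):
    """GAE(gamma, lambda) advantages for one partial trajectory.

    rewards: R_t, ..., R_{t+T-1}
    values:  V(S_t), ..., V(S_{t+T}); the last entry is the bootstrap value
             (0.0 if the trajectory ended in a true terminal state)
    """
    T = len(rewards)
    if len(values) != T + 1:
        raise ValueError("values must contain one more entry than rewards")
    # one-step TD errors: delta_k = R_k + gamma * V(S_{k+1}) - V(S_k)
    deltas = [rewards[k] + gamma * values[k + 1] - values[k] for k in range(T)]
    # backward recursion A_k = delta_k + gamma * lam * A_{k+1}, i.e. the sum of
    # (gamma * lam)^l * delta_{k+l}, truncated at the end of the rollout
    advantages = [0.0] * T
    running = 0.0
    for k in reversed(range(T)):
        running = deltas[k] + gamma * lam * running
        advantages[k] = running
    return advantages


rewards, values = [1.0, 1.0, 1.0], [0.5, 0.5, 0.5, 0.0]
print(gae(rewards, values, gamma=1.0, lam=0.0))  # the TD errors themselves
print(gae(rewards, values, gamma=1.0, lam=1.0))  # Monte Carlo returns minus V(S_t)
```

Adding $V(S_t)$ back to each estimate yields the $\lambda$-return for the same $\gamma$ and $\lambda$.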

We can implement a GAE agent in Python; it keeps the A3C structure and changes how the policy targets are computed:

In [66]:
class GAE():
    def __init__(self, policy_model_fn, policy_model_max_grad_norm, policy_optimizer_fn, policy_optimizer_lr,
                 value_model_fn, value_model_max_grad_norm, value_optimizer_fn, value_optimizer_lr,
                 entropy_loss_weight, max_n_steps, n_workers, tau):
        
        # notice that lambda is a reserved word in Python, so the GAE lambda is often referred to as tau
        
        self.policy_model_fn = policy_model_fn
        self.policy_model_max_grad_norm = policy_model_max_grad_norm
        self.policy_optimizer_fn = policy_optimizer_fn
        self.policy_optimizer_lr = policy_optimizer_lr

        self.value_model_fn = value_model_fn
        self.value_model_max_grad_norm = value_model_max_grad_norm
        self.value_optimizer_fn = value_optimizer_fn
        self.value_optimizer_lr = value_optimizer_lr

        self.entropy_loss_weight = entropy_loss_weight

        self.max_n_steps = max_n_steps
        self.n_workers = n_workers
        self.tau = tau

    def optimize_model(self, logpas, entropies, rewards, values, 
                       local_policy_model, local_value_model):
        
        # create the discounted returns, the way we did with A3C
        T = len(rewards)
        discounts = np.logspace(0, T, num=T, base=self.gamma, endpoint=False)
        returns = np.array([np.sum(discounts[:T-t] * rewards[t:]) for t in range(T)])

        logpas = torch.cat(logpas)
        entropies = torch.cat(entropies)
        values = torch.cat(values)

        # create an array with all the state values and an array with the (gamma*lambda)^t discounts
        np_values = values.view(-1).detach().numpy()
        tau_discounts = np.logspace(0, T-1, num=T-1, base=self.gamma*self.tau, endpoint=False)
        
        # create an array of TD errors: R_t + gamma * value_{t+1} - value_t, for t=0 to T-2
        advs = np.array(rewards[:-1]) + self.gamma * np_values[1:] - np_values[:-1]
        
        # create the GAEs, by multiplying the tau discounts times the TD errors
        gaes = np.array([np.sum(tau_discounts[:T-1-t] * advs[t:]) for t in range(T-1)])

        values = values[:-1,...]
        discounts = torch.FloatTensor(discounts[:-1]).unsqueeze(1)
        returns = torch.FloatTensor(returns[:-1]).unsqueeze(1)
        gaes = torch.FloatTensor(gaes).unsqueeze(1)

        # now use the gaes to calculate the policy loss
        # And proceed as before
        policy_loss = -(discounts * gaes.detach() * logpas).mean()
        entropy_loss = -entropies.mean()
        loss = policy_loss + self.entropy_loss_weight * entropy_loss
        # zero the local gradients too, so they do not accumulate across updates
        self.shared_policy_optimizer.zero_grad()
        local_policy_model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(local_policy_model.parameters(), self.policy_model_max_grad_norm)
        for param, shared_param in zip(local_policy_model.parameters(), self.shared_policy_model.parameters()):
            shared_param._grad = param.grad
        self.shared_policy_optimizer.step()
        local_policy_model.load_state_dict(self.shared_policy_model.state_dict())

        value_error = returns - values
        value_loss = value_error.pow(2).mul(0.5).mean()
        self.shared_value_optimizer.zero_grad()
        local_value_model.zero_grad()
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(local_value_model.parameters(), self.value_model_max_grad_norm)
        for param, shared_param in zip(local_value_model.parameters(), self.shared_value_model.parameters()):
            shared_param._grad = param.grad
        self.shared_value_optimizer.step()
        local_value_model.load_state_dict(self.shared_value_model.state_dict())

    @staticmethod
    def interaction_step(state, env, local_policy_model, local_value_model,
                         logpas, entropies, rewards, values):
        action, is_exploratory, logpa, entropy = local_policy_model.full_pass(state)
        new_state, reward, is_terminal, info = env.step(action)
        is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated']

        logpas.append(logpa)
        entropies.append(entropy)
        rewards.append(reward)
        values.append(local_value_model(state))

        return new_state, reward, is_terminal, is_truncated, is_exploratory

    def work(self, rank):
        last_debug_time = float('-inf')
        self.stats['n_active_workers'].add_(1)
        
        local_seed = self.seed + rank
        env = self.make_env_fn(seed=local_seed)
        torch.manual_seed(local_seed) 
        np.random.seed(local_seed)
        random.seed(local_seed)

        nS, nA = env.observation_space.shape[0], env.action_space.n
        
        local_policy_model = self.policy_model_fn(nS, nA)
        local_policy_model.load_state_dict(self.shared_policy_model.state_dict())
        local_value_model = self.value_model_fn(nS)
        local_value_model.load_state_dict(self.shared_value_model.state_dict())

        global_episode_idx = self.stats['episode'].add_(1).item() - 1
        while not self.get_out_signal:            
            episode_start = time.time()
            state, is_terminal = env.reset(), False
            
            # collect n_steps rollout
            n_steps_start, total_episode_rewards = 0, 0
            total_episode_steps, total_episode_exploration = 0, 0
            logpas, entropies, rewards, values = [], [], [], []

            for step in count(start=1):
                state, reward, is_terminal, is_truncated, is_exploratory = self.interaction_step(
                    state, env, local_policy_model, local_value_model, 
                    logpas, entropies, rewards, values)

                total_episode_steps += 1
                total_episode_rewards += reward
                total_episode_exploration += int(is_exploratory)
                
                if is_terminal or step - n_steps_start == self.max_n_steps:
                    is_failure = is_terminal and not is_truncated
                    next_value = 0 if is_failure else local_value_model(state).detach().item()
                    rewards.append(next_value)
                    values.append(torch.FloatTensor([[next_value,],]))

                    self.optimize_model(logpas, entropies, rewards, values, 
                                        local_policy_model, local_value_model)
                    logpas, entropies, rewards, values = [], [], [], []
                    n_steps_start = step
                
                if is_terminal:
                    break

            # save global stats
            episode_elapsed = time.time() - episode_start
            evaluation_score, _ = self.evaluate(local_policy_model, env)

            self.stats['episode_elapsed'][global_episode_idx].add_(episode_elapsed)
            self.stats['episode_timestep'][global_episode_idx].add_(total_episode_steps)
            self.stats['episode_reward'][global_episode_idx].add_(total_episode_rewards)
            self.stats['episode_exploration'][global_episode_idx].add_(total_episode_exploration/total_episode_steps)
            self.stats['evaluation_scores'][global_episode_idx].add_(evaluation_score)

            mean_10_reward = self.stats['episode_reward'][:global_episode_idx+1][-10:].mean().item()
            mean_100_reward = self.stats['episode_reward'][:global_episode_idx+1][-100:].mean().item()
            mean_100_eval_score = self.stats['evaluation_scores'][:global_episode_idx+1][-100:].mean().item()
            mean_100_exp_rat = self.stats['episode_exploration'][:global_episode_idx+1][-100:].mean().item()
            std_10_reward = self.stats['episode_reward'][:global_episode_idx+1][-10:].std().item()
            std_100_reward = self.stats['episode_reward'][:global_episode_idx+1][-100:].std().item()
            std_100_eval_score = self.stats['evaluation_scores'][:global_episode_idx+1][-100:].std().item()
            std_100_exp_rat = self.stats['episode_exploration'][:global_episode_idx+1][-100:].std().item()
            # the std of a single sample is NaN, and NaN != NaN, so reset it to 0
            if std_10_reward != std_10_reward: std_10_reward = 0
            if std_100_reward != std_100_reward: std_100_reward = 0
            if std_100_eval_score != std_100_eval_score: std_100_eval_score = 0
            if std_100_exp_rat != std_100_exp_rat: std_100_exp_rat = 0
            global_n_steps = self.stats['episode_timestep'][:global_episode_idx+1].sum().item()
            global_training_elapsed = self.stats['episode_elapsed'][:global_episode_idx+1].sum().item()
            wallclock_elapsed = time.time() - self.training_start
            
            self.stats['result'][global_episode_idx][0].add_(global_n_steps)
            self.stats['result'][global_episode_idx][1].add_(mean_100_reward)
            self.stats['result'][global_episode_idx][2].add_(mean_100_eval_score)
            self.stats['result'][global_episode_idx][3].add_(global_training_elapsed)
            self.stats['result'][global_episode_idx][4].add_(wallclock_elapsed)

            elapsed_str = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.training_start))
            debug_message = 'el {}, ep {:04}, ts {:06}, '
            debug_message += 'ar 10 {:05.1f}\u00B1{:05.1f}, '
            debug_message += '100 {:05.1f}\u00B1{:05.1f}, '
            debug_message += 'ex 100 {:02.1f}\u00B1{:02.1f}, '
            debug_message += 'ev {:05.1f}\u00B1{:05.1f}'
            debug_message = debug_message.format(
                elapsed_str, global_episode_idx, global_n_steps, mean_10_reward, std_10_reward, 
                mean_100_reward, std_100_reward, mean_100_exp_rat, std_100_exp_rat,
                mean_100_eval_score, std_100_eval_score)

            if rank == 0:
                print(debug_message, end='\r', flush=True)
                if time.time() - last_debug_time >= 60:
                    print(ERASE_LINE + debug_message, flush=True)
                    last_debug_time = time.time()

            with self.get_out_lock:
                potential_next_global_episode_idx = self.stats['episode'].item()
                self.reached_goal_mean_reward.add_(mean_100_eval_score >= self.goal_mean_100_reward)
                self.reached_max_minutes.add_(time.time() - self.training_start >= self.max_minutes * 60)
                self.reached_max_episodes.add_(potential_next_global_episode_idx >= self.max_episodes)
                if self.reached_max_episodes or \
                   self.reached_max_minutes or \
                   self.reached_goal_mean_reward:
                    self.get_out_signal.add_(1)
                    break
                # else go work on another episode
                global_episode_idx = self.stats['episode'].add_(1).item() - 1

        # rank 0 waits for the other workers to finish before printing the final summary
        while rank == 0 and self.stats['n_active_workers'].item() > 1:
            time.sleep(0.1)

        if rank == 0:
            print(debug_message)
            if self.reached_max_minutes: print(u'--> reached_max_minutes')
            if self.reached_max_episodes: print(u'--> reached_max_episodes')
            if self.reached_goal_mean_reward: print(u'--> reached_goal_mean_reward \u2713')

        env.close() ; del env
        self.stats['n_active_workers'].sub_(1)


    def train(self, make_env_fn, seed, gamma, max_minutes, max_episodes, goal_mean_100_reward):
        self.make_env_fn = make_env_fn
        self.seed = seed
        self.gamma = gamma
        self.max_minutes = max_minutes
        self.max_episodes = max_episodes
        self.goal_mean_100_reward = goal_mean_100_reward

        env = self.make_env_fn(seed=self.seed)
        nS, nA = env.observation_space.shape[0], env.action_space.n
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        self.stats = {}
        self.stats['episode'] = torch.zeros(1, dtype=torch.int).share_memory_()
        self.stats['result'] = torch.zeros([max_episodes, 5]).share_memory_()
        self.stats['evaluation_scores'] = torch.zeros([max_episodes]).share_memory_()
        self.stats['episode_reward'] = torch.zeros([max_episodes]).share_memory_()
        self.stats['episode_timestep'] = torch.zeros([max_episodes], dtype=torch.int).share_memory_()
        self.stats['episode_exploration'] = torch.zeros([max_episodes]).share_memory_()
        self.stats['episode_elapsed'] = torch.zeros([max_episodes]).share_memory_()
        self.stats['n_active_workers'] = torch.zeros(1, dtype=torch.int).share_memory_()

        self.shared_policy_model = self.policy_model_fn(nS, nA).share_memory()
        self.shared_policy_optimizer = self.policy_optimizer_fn(self.shared_policy_model, self.policy_optimizer_lr)
        self.shared_value_model = self.value_model_fn(nS).share_memory()
        self.shared_value_optimizer = self.value_optimizer_fn(self.shared_value_model, self.value_optimizer_lr)

        self.get_out_lock = mp.Lock()
        self.get_out_signal = torch.zeros(1, dtype=torch.int).share_memory_()
        self.reached_max_minutes = torch.zeros(1, dtype=torch.int).share_memory_() 
        self.reached_max_episodes = torch.zeros(1, dtype=torch.int).share_memory_() 
        self.reached_goal_mean_reward  = torch.zeros(1, dtype=torch.int).share_memory_() 
        self.training_start = time.time()
        workers = [mp.Process(target=self.work, args=(rank,)) for rank in range(self.n_workers)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        wallclock_time = time.time() - self.training_start

        final_eval_score, score_std = self.evaluate(self.shared_policy_model, env, n_episodes=100)
        env.close() ; del env

        final_episode = self.stats['episode'].item()
        training_time = self.stats['episode_elapsed'][:final_episode+1].sum().item()

        print('Training complete.')
        print('Final evaluation score {:.2f}\u00B1{:.2f} in {:.2f}s training time,'
              ' {:.2f}s wall-clock time.\n'.format(
                  final_eval_score, score_std, training_time, wallclock_time))

        self.stats['result'] = self.stats['result'].numpy()
        self.stats['result'][final_episode:, ...] = np.nan
        return self.stats['result'], final_eval_score, training_time, wallclock_time

    def evaluate(self, eval_policy_model, eval_env, n_episodes=1, greedy=True):
        rs = []
        for _ in range(n_episodes):
            s, d = eval_env.reset(), False
            rs.append(0)
            for _ in count():
                if greedy:
                    a = eval_policy_model.select_greedy_action(s)
                else: 
                    a = eval_policy_model.select_action(s)
                s, r, d, _ = eval_env.step(a)
                rs[-1] += r
                if d: break
        return np.mean(rs), np.std(rs)
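The GAE computation in `optimize_model` builds an explicit discount vector with `np.logspace` and sums a slice of it for every time step, which costs O(T²). A minimal sketch, assuming the same conventions as the class above (the last entry of both `rewards` and `values` is the bootstrap value, 0 for a true terminal state), computes the same advantages with a single backward pass using A_t = δ_t + γτ A_{t+1}. The function names `compute_gae` and `compute_gae_logspace` are hypothetical helpers, not part of the notebook.

```python
import numpy as np


def compute_gae(rewards, values, gamma=0.99, tau=0.95):
    """GAE via backward recursion (sketch).

    rewards: r_0, ..., r_{T-2} followed by the bootstrap value
             (0 for a true terminal state), as in optimize_model.
    values:  V(s_0), ..., V(s_{T-2}) followed by the same bootstrap value.
    Returns the T-1 advantage estimates A_0, ..., A_{T-2}.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    # one-step TD errors: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = rewards[:-1] + gamma * values[1:] - values[:-1]
    gaes = np.zeros_like(deltas)
    running = 0.0
    # A_t = delta_t + gamma * tau * A_{t+1}, accumulated from the last step
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * tau * running
        gaes[t] = running
    return gaes


def compute_gae_logspace(rewards, values, gamma=0.99, tau=0.95):
    """The explicit-discount formulation used in optimize_model, for comparison."""
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    T = len(rewards)
    tau_discounts = np.logspace(0, T-1, num=T-1, base=gamma*tau, endpoint=False)
    advs = rewards[:-1] + gamma * values[1:] - values[:-1]
    return np.array([np.sum(tau_discounts[:T-1-t] * advs[t:]) for t in range(T-1)])


rewards = [1.0, 1.0, 1.0, 0.5]  # r_0, r_1, r_2, then the bootstrap value V(s_3)
values = [0.8, 0.9, 0.7, 0.5]   # V(s_0), V(s_1), V(s_2), then the same bootstrap value
print(np.allclose(compute_gae(rewards, values), compute_gae_logspace(rewards, values)))  # True
```

With `tau=0` the estimate reduces to the one-step TD errors, and with `tau=1` it becomes the discounted n-step return minus V(s_t), the advantage used by A3C; values in between trade bias for variance.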



In [67]:
gae_results = []
best_gae_agent, best_gae_eval_score = None, float('-inf')

gamma = 0.99
max_minutes = 10
max_episodes = 10000
goal_mean_100_reward = 475

for seed in (12, 34, 56, 78, 90):
    
    policy_model_fn = lambda nS, nA: FCR(nS, nA, hidden_dims=(128,64))
    policy_model_max_grad_norm = 1
    policy_optimizer_fn = lambda net, lr: SharedAdam(net.parameters(), lr=lr)
    policy_optimizer_lr = 0.0005

    value_model_fn = lambda nS: FCV(nS, hidden_dims=(256,128))
    value_model_max_grad_norm = float('inf')
    value_optimizer_fn = lambda net, lr: SharedRMSprop(net.parameters(), lr=lr)
    value_optimizer_lr = 0.0007

    entropy_loss_weight = 0.001

    max_n_steps = 50
    n_workers = 8
    tau = 0.95

    agent = GAE(policy_model_fn, policy_model_max_grad_norm, policy_optimizer_fn, policy_optimizer_lr,
                value_model_fn, value_model_max_grad_norm, value_optimizer_fn, value_optimizer_lr, 
                entropy_loss_weight, max_n_steps, n_workers, tau)

    result, final_eval_score, training_time, wallclock_time = agent.train(make_env, seed, gamma, 
                                                                          max_minutes, max_episodes, 
                                                                          goal_mean_100_reward)
  
    gae_results.append(result)
    if final_eval_score > best_gae_eval_score:
        best_gae_eval_score = final_eval_score
        best_gae_agent = agent
gae_results = np.array(gae_results)


## Advantage Actor-Critic (A2C)

Advantage actor-critic (A2C) is the synchronous version of A3C. In A3C, the workers update the shared network Hogwild!-style, that is, without locks, so updates computed from stale parameters can overwrite each other. This can be chaotic, yet introducing a lock mechanism considerably lowers A3C performance. In A2C, we move the workers from the agent down to the environment: instead of having multiple actor-learners, we have multiple actors with a single learner.

<img src="./images/synchronous-model.png" width="600">
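To make "moving the workers down to the environment" concrete, here is a minimal sketch, assuming a toy environment with the same old Gym-style `reset()`/`step()` interface used throughout this notebook. The classes `CountdownEnv` and `SyncVecEnv` are hypothetical stand-ins for a real vectorized-environment wrapper: the single learner sends one batch of actions, all environments step in lockstep, finished episodes reset automatically, and the learner receives stacked arrays it can use for one synchronous update.

```python
import numpy as np


class CountdownEnv:
    """Hypothetical toy environment: the episode ends after `length` steps,
    and the reward is 1 when action == 1, else 0."""

    def __init__(self, length):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return np.array([float(self.length)])

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.length
        state = np.array([float(self.length - self.t)])
        return state, reward, done, {}


class SyncVecEnv:
    """Steps several environments in lockstep for a single learner (sketch)."""

    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return np.stack([env.reset() for env in self.envs])

    def step(self, actions):
        states, rewards, dones = [], [], []
        for env, action in zip(self.envs, actions):
            state, reward, done, _ = env.step(action)
            if done:
                # auto-reset so every slot always holds a live episode;
                # the learner must use `dones` to stop bootstrapping here
                state = env.reset()
            states.append(state)
            rewards.append(reward)
            dones.append(done)
        return np.stack(states), np.array(rewards), np.array(dones)


# two environments with different episode lengths (n=n avoids late binding)
venv = SyncVecEnv([lambda n=n: CountdownEnv(n) for n in (2, 3)])
states = venv.reset()                       # shape (2, 1)
states, rewards, dones = venv.step([1, 0])  # one batched step for both envs
```

Because every environment advances together, the learner always sees a batch of transitions generated by the same policy parameters, which is what removes the stale-gradient problem of A3C.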



We can also use a single neural network for both the policy and the value function. Sharing a model can be particularly beneficial when learning from images, because the compute-intensive feature extraction is then performed once and reused by both outputs. However, model sharing can be challenging due to the potentially different scales of the policy and value function updates.
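A minimal PyTorch sketch of such a shared model, assuming a discrete action space: one body extracts features, a policy head outputs action logits and a value head outputs V(s). The class `SharedActorCritic`, the function `a2c_loss` and the weights `value_loss_weight` and `entropy_loss_weight` are hypothetical and illustrative, not tuned values from this notebook; the weights are where the different scales of the policy and value losses would be rebalanced.

```python
import torch
import torch.nn as nn


class SharedActorCritic(nn.Module):
    """Hypothetical shared-body actor-critic: one feature extractor,
    a policy head (action logits) and a value head (state value)."""

    def __init__(self, n_inputs, n_actions, hidden_dim=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(n_inputs, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, states):
        features = self.body(states)  # computed once, used by both heads
        return self.policy_head(features), self.value_head(features)


def a2c_loss(model, states, actions, returns,
             value_loss_weight=0.5, entropy_loss_weight=0.001):
    """Single combined loss; the weights rebalance the policy and value terms."""
    logits, values = model(states)
    dist = torch.distributions.Categorical(logits=logits)
    values = values.squeeze(-1)
    # detach so the policy loss does not push gradients through the value head
    advantages = (returns - values).detach()
    policy_loss = -(advantages * dist.log_prob(actions)).mean()
    value_loss = 0.5 * (returns - values).pow(2).mean()
    entropy_loss = -dist.entropy().mean()
    return policy_loss + value_loss_weight * value_loss + entropy_loss_weight * entropy_loss


torch.manual_seed(0)
model = SharedActorCritic(n_inputs=4, n_actions=2)
states = torch.randn(8, 4)
actions = torch.randint(0, 2, (8,))
returns = torch.randn(8)
loss = a2c_loss(model, states, actions, returns)
loss.backward()  # gradients from both heads flow into model.body
```

Since both losses backpropagate into the same body, a value loss that is much larger than the policy loss can dominate the shared features; that is why the value term is usually scaled down relative to the policy term.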

In [4]:
import multiprocess as mp

class test_multi:
    def __init__(self, n_workers):
        self.n_workers = n_workers

    def work(self, rank):
        # each worker process simply prints its rank
        print(rank)

    def train(self):
        print('start')
        # create a lock as train() above does, to check that this module provides it
        self.get_out_lock = mp.Lock()
        workers = [mp.Process(target=self.work, args=(rank,)) for rank in range(self.n_workers)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        print('end')

In [5]:
test = test_multi(n_workers=4)
test.train()

start
01

2
3
end
