# Implementing Advantage Actor-Critic (A2C) - 2 pts

In this notebook you will implement the Advantage Actor-Critic algorithm and train it on a batch of Atari 2600 environments running in parallel.

First, we will use the environment wrappers implemented in `atari_wrappers.py`. These wrappers preprocess observations (resize, convert to grayscale, take the max between consecutive frames, skip frames, stack them together, arrange them channels-first for PyTorch and normalize to [0, 1]) as well as rewards. Some of the wrappers also reset the environment and set the `done` flag to `True` whenever the agent loses a life.
The file `env_batch.py` implements the `ParallelEnvBatch` class, which runs multiple environments in parallel. To create an environment we can use the `nature_dqn_env` function.
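To make the preprocessing list above concrete, here is a minimal numpy sketch of what such wrappers do to a pair of raw 210x160 RGB Atari frames. It is not the code in `atari_wrappers.py`: the helpers `preprocess_frame` and `FrameStack` are hypothetical, nearest-neighbour sampling stands in for whatever interpolation the real wrappers use, and frame skipping and reward processing are left out.

```python
import numpy as np
from collections import deque


def preprocess_frame(prev_rgb, cur_rgb, size=84):
    """Hypothetical helper: max over two frames, grayscale, resize, scale to [0, 1]."""
    # Element-wise max over two consecutive frames removes Atari sprite flicker.
    frame = np.maximum(prev_rgb, cur_rgb).astype(np.float32)
    # Luminance-weighted grayscale (ITU-R BT.601 weights): (H, W, 3) -> (H, W).
    gray = frame @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
    # Nearest-neighbour resize by sampling evenly spaced rows and columns.
    rows = np.linspace(0, gray.shape[0] - 1, size).astype(int)
    cols = np.linspace(0, gray.shape[1] - 1, size).astype(int)
    small = gray[np.ix_(rows, cols)]
    # Normalise once here; the model must not divide by 255 again.
    return small / 255.0


class FrameStack:
    """Hypothetical helper: keeps the last k preprocessed frames, channels first."""

    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # Fill the whole stack with the first frame of the episode.
        for _ in range(self.frames.maxlen):
            self.frames.append(first_frame)
        return self.observation()

    def step(self, frame):
        # The oldest frame drops out automatically because of maxlen.
        self.frames.append(frame)
        return self.observation()

    def observation(self):
        return np.stack(list(self.frames), axis=0).astype(np.float32)  # (k, 84, 84)
```

Stacking four consecutive frames is what gives the network access to motion (e.g. the direction of bullets), which a single frame cannot show.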

In [1]:
# !pip install gym[accept-rom-license]==0.22.0
# !pip install ale-py==0.8.1

In [1]:
import numpy as np
from atari_wrappers import nature_dqn_env

nenvs = 8    # change this if you have more than 8 CPUs ;)

env = nature_dqn_env("SpaceInvadersNoFrameskip-v4", nenvs=nenvs, seed=0)

n_actions = env.action_space.spaces[0].n
obs = env.reset()
assert obs.shape == (nenvs, 4, 84, 84)
assert obs.dtype == np.float32


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Next, we need to implement a model that predicts the logits of the policy distribution and the critic value, using a shared backbone. You may use the same architecture as in the DQN task with one modification: instead of a single output layer, it must have two output layers that both take the output of the last hidden layer as input (one for the actor, one for the critic).

Still, a few more changes may help considerably:
* use orthogonal initialization with gain $\sqrt{2}$ and initialize biases with zeros;
* use more filters (e.g. 32-64-64 instead of 16-32-64);
* use two-layer heads for the actor and critic, or add a linear layer to the backbone.

**Danger:** do not divide by 255; the input is already normalized to [0, 1] by our wrappers!
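The input size of the first dense layer follows from the convolution arithmetic. The plain-Python sketch below compares the classic Nature-DQN kernels (8/4, 4/2, 3/1) with the suggested 32-64-64 filters against the kernel sizes used by `AgentNetwork` further down; `flat_features` is a hypothetical helper, and `conv2d_size_out` mirrors the notebook's helper of the same name (no padding, no dilation).

```python
def conv2d_size_out(size, kernel_size, stride):
    # Output length of an unpadded, undilated convolution along one axis.
    return (size - kernel_size) // stride + 1


def flat_features(layers, size=84):
    """Inputs to the first dense layer for a square size x size observation."""
    for out_channels, kernel_size, stride in layers:
        size = conv2d_size_out(size, kernel_size, stride)
    return size * size * out_channels


# (out_channels, kernel_size, stride) for each conv layer
nature_dqn_layers = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
agent_network_layers = [(32, 3, 4), (64, 3, 2), (64, 3, 1)]

print(flat_features(nature_dqn_layers))     # 84 -> 20 -> 9 -> 7, so 7 * 7 * 64 = 3136
print(flat_features(agent_network_layers))  # 84 -> 21 -> 10 -> 8, so 8 * 8 * 64 = 4096
```

The second number matches the `in_features=4096` printed for the trained model below. Note that with kernel size 3 and stride 4 the first-layer windows cover rows 0-2, 4-6, and so on, so every fourth row and column of the input is never seen; this is worth keeping in mind when choosing kernel sizes.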

In [2]:
import torch
from torch import nn


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.net(x)

In [3]:
def conv2d_size_out(size, kernel_size, stride):
    """
    common use case:
    cur_layer_img_w = conv2d_size_out(cur_layer_img_w, kernel_size, stride)
    cur_layer_img_h = conv2d_size_out(cur_layer_img_h, kernel_size, stride)
    to understand the shape for dense layer's input
    """
    return (size - (kernel_size - 1) - 1) // stride  + 1


class Flatten(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x):
        return x.view(x.size(0), -1)

In [4]:
class AgentNetwork(nn.Module):
    def __init__(self, n_actions):
        super().__init__()

        self.n_actions = n_actions

        # Define your network body here. Please make sure agent is fully contained here
        hparams = [
            dict(in_channels=4, out_channels=32, kernel_size=3, stride=4),
            dict(in_channels=32, out_channels=64, kernel_size=3, stride=2),
            dict(in_channels=64, out_channels=64, kernel_size=3, stride=1),
        ]
        self.conv = nn.Sequential(*[ConvBlock(**kwargs) for kwargs in hparams])
        
        width = 84
        height = 84
        for kwargs in hparams:
            width = conv2d_size_out(width, kwargs['kernel_size'], kwargs['stride'])
            height = conv2d_size_out(height, kwargs['kernel_size'], kwargs['stride'])
        
        dense_in_features = width * height * hparams[-1]['out_channels']
        
        self.actor = nn.Sequential(
            Flatten(),
            nn.Linear(in_features=dense_in_features, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=n_actions)
        )

        self.critic = nn.Sequential(
            Flatten(),
            nn.Linear(in_features=dense_in_features, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=1, bias=False)
        )
        
        self.apply(self.init)
        

    def forward(self, state_t):
        """
        takes agent's observation (tensor), returns critic values and policy logits (tensors)
        :param state_t: a batch of 4-frame buffers, shape = [batch_size, 4, h, w]
        """

        x = self.conv(state_t)
        values = self.critic(x).squeeze(dim=-1)
        logits = self.actor(x)

        return values, logits
    
    def init(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.orthogonal_(m.weight, np.sqrt(2))
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

You will also need to define and use a policy that wraps the model. While the model computes logits for all actions, the policy samples actions and computes their log probabilities. `policy.act` should return a **dictionary** of all the arrays needed to interact with the environment and train the model.

**Important**: "actions" will be sent to the environment, so they must be a numpy array or a list, not a PyTorch tensor.

Note: you can add more keys, e.g. it can be convenient to compute entropy right here.
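The quantities a policy returns can be illustrated without PyTorch. The numpy sketch below is hypothetical (`sample_from_logits` is not part of the notebook, and it uses inverse-CDF sampling in place of `torch.multinomial`): it computes a stable log-softmax, the per-state entropy $-\sum_a \pi(a \mid s)\log\pi(a \mid s)$, one sampled action per row, and the log-probability of that action.

```python
import numpy as np


def sample_from_logits(logits, rng):
    """Numpy sketch of per-state actions, log-probs and entropy from logits."""
    # Numerically stable log-softmax: subtract the row-wise max first.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs_all = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    probs = np.exp(log_probs_all)
    # Entropy of each row's categorical distribution.
    entropy = -(probs * log_probs_all).sum(axis=1)
    # Inverse-CDF sampling: the action is the first index whose cumulative prob >= u.
    cumulative = probs.cumsum(axis=1)
    u = rng.random((logits.shape[0], 1))
    actions = (u > cumulative).sum(axis=1)
    actions = np.minimum(actions, logits.shape[1] - 1)  # guard against rounding
    # Pick the log-probability of the sampled action in every row.
    log_probs = log_probs_all[np.arange(len(actions)), actions]
    return {"actions": actions, "log_probs": log_probs, "entropy": entropy}
```

With all-zero logits over the 6 Space Invaders actions the entropy is $\log 6 \approx 1.79$, its maximum, which is why a freshly initialised network reports entropies close to 1.78 in the output below.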

In [5]:
import torch.nn.functional as F


class Policy:
    def __init__(self, model: AgentNetwork):
        self.model = model

    def act(self, inputs):
        '''
        input:
            inputs - numpy array, (batch_size x channels x height x width)
        output: dict containing keys ['actions', 'logits', 'log_probs', 'values', 'entropy']:
            'actions' - selected actions, numpy, (batch_size)
            'logits' - actions logits, tensor, (batch_size x num_actions)
            'log_probs' - log probs of selected actions, tensor, (batch_size)
            'values' - critic estimations, tensor, (batch_size)
            'entropy' - entropy of the policy distribution, tensor, (batch_size)
        '''
        device = next(self.model.parameters()).device
        inputs = torch.from_numpy(inputs).to(device)
        values, logits = self.model(inputs)

        logprobs = F.log_softmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)
        entropy = probs.mul(logprobs).neg().sum(dim=1)

        actions = torch.multinomial(probs, num_samples=1).squeeze(-1)
        ids = torch.arange(len(values), device=values.device)
        logprobs = logprobs[ids, actions]

        return {
            "actions": actions.detach().cpu().numpy(),
            "logits": logits,
            "log_probs": logprobs,
            "values": values,
            "entropy": entropy
        }

In [6]:
a = torch.randn(8, 4, 84, 84).numpy()
agent = AgentNetwork(n_actions=n_actions)
policy = Policy(agent)
policy.act(a)

{'actions': array([5, 3, 5, 4, 4, 0, 0, 4]),
 'logits': tensor([[-0.1037, -0.0054,  0.0194, -0.0670,  0.0138, -0.1633],
         [-0.1906, -0.1316,  0.0053,  0.0645,  0.2488, -0.0135],
         [-0.0799, -0.2727, -0.0153, -0.0436,  0.0280, -0.0983],
         [-0.1188, -0.1467, -0.0109, -0.0309,  0.3081, -0.1928],
         [-0.3331,  0.0728,  0.2201, -0.0444, -0.1000, -0.1762],
         [-0.1401,  0.0501, -0.1286, -0.0133,  0.2314, -0.1620],
         [ 0.0098, -0.0481, -0.1802, -0.0659,  0.0308, -0.0569],
         [ 0.0236, -0.1886, -0.0806, -0.1430,  0.2154,  0.0183]],
        grad_fn=<AddmmBackward0>),
 'log_probs': tensor([-1.9062, -1.7345, -1.8141, -1.4661, -1.8471, -1.9147, -1.7324, -1.5596],
        grad_fn=<IndexBackward0>),
 'values': tensor([-0.1297, -0.1377, -0.0250, -0.0449,  0.1353, -0.0191,  0.1572, -0.1123],
        grad_fn=<SqueezeBackward1>),
 'entropy': tensor([1.7896, 1.7814, 1.7875, 1.7764, 1.7762, 1.7816, 1.7896, 1.7825],
        grad_fn=<SumBackward1>)}

Next we will pass the environment and policy to a runner that collects rollouts from the environment. 
The class is already implemented for you.

In [7]:
from runners import EnvRunner

This runner interacts with the environment for a given number of steps and returns a dictionary with the keys

* 'observations'
* 'rewards'
* 'dones'
* 'actions'
* all other keys that you defined in `Policy`

Under each of these keys there is a Python `list` of interactions with the environment of length $T$, the size of the partial trajectory (the rollout length). Let's have a look at how it works.
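`EnvRunner` is provided, but the structure of what it returns can be illustrated with a toy version. The sketch below is hypothetical: `ToyEnvBatch` and `collect_rollout` are not part of `runners.py`, whose internals may differ. It records the observation each action was chosen from, every key returned by the policy, the rewards and the done flags, and it also returns the latest observation, which is needed later for bootstrapping.

```python
import numpy as np


class ToyEnvBatch:
    """Hypothetical batched env: each of nenvs episodes ends after 3 steps."""

    def __init__(self, nenvs):
        self.t = np.zeros(nenvs, dtype=int)

    def reset(self):
        self.t[:] = 0
        return self.t.astype(np.float32)

    def step(self, actions):
        self.t += 1
        rewards = np.asarray(actions, dtype=np.float32)
        dones = self.t >= 3
        self.t[dones] = 0  # auto-reset finished environments
        return self.t.astype(np.float32), rewards, dones, {}


def collect_rollout(env, act, obs, nsteps):
    """Hypothetical runner: one list entry per step, each of shape (nenvs, ...)."""
    trajectory = {"observations": [], "rewards": [], "dones": []}
    for _ in range(nsteps):
        out = act(obs)  # dict with at least an "actions" key
        trajectory["observations"].append(obs)
        for key, value in out.items():
            trajectory.setdefault(key, []).append(value)
        obs, rewards, dones, _ = env.step(out["actions"])
        trajectory["rewards"].append(rewards)
        trajectory["dones"].append(dones)
    return trajectory, obs  # obs is the latest observation
```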

In [8]:
model = AgentNetwork(n_actions=n_actions)
policy = Policy(model)
runner = EnvRunner(env, policy, nsteps=5)

In [9]:
# generates new rollout
trajectory = runner.get_next()

In [10]:
# what is inside
print(trajectory.keys())

dict_keys(['actions', 'logits', 'log_probs', 'values', 'entropy', 'observations', 'rewards', 'dones'])


In [11]:
# Sanity checks
assert 'logits' in trajectory, "Not found: policy didn't provide logits"
assert 'log_probs' in trajectory, "Not found: policy didn't provide log_probs of selected actions"
assert 'values' in trajectory, "Not found: policy didn't provide critic estimations"
assert trajectory['logits'][0].shape == (nenvs, n_actions), "logits wrong shape"
assert trajectory['log_probs'][0].shape == (nenvs,), "log_probs wrong shape"
assert trajectory['values'][0].shape == (nenvs,), "values wrong shape"

for key in trajectory.keys():
    assert len(trajectory[key]) == 5, \
    f"something went wrong: 5 steps should have been done, got trajectory of length {len(trajectory[key])} for '{key}'"

In [12]:
trajectory['values']

[tensor([-0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643],
        grad_fn=<SqueezeBackward1>),
 tensor([-0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643],
        grad_fn=<SqueezeBackward1>),
 tensor([-0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643, -0.1643],
        grad_fn=<SqueezeBackward1>),
 tensor([-0.1588, -0.1588, -0.1588, -0.1588, -0.1588, -0.1588, -0.1588, -0.1588],
        grad_fn=<SqueezeBackward1>),
 tensor([-0.1680, -0.1680, -0.1680, -0.1680, -0.1680, -0.1680, -0.1680, -0.1680],
        grad_fn=<SqueezeBackward1>)]

Now let's work with this trajectory a bit. To train the critic you will need to compute the value targets. They will also serve as an estimate of $Q$ when training the actor.

Use all rewards available in the rollout and bootstrap from the critic at its end, so the formula for the value targets is simple:

$$
\hat v(s_t) = \sum_{t'=0}^{T - t - 1}\gamma^{t'}r_{t+t'} + \gamma^{T - t} \hat{v}(s_{T}), \qquad t = 0, \dots, T - 1,
$$

where $t$ indexes the steps of the rollout and $s_{T}$ is the latest observation of the environment (the one received after the last step of the rollout).

Any callable can be passed to `EnvRunner` (through its `transforms` argument) to be applied to each partial trajectory after it is collected. Thus, we can implement a `ComputeValueTargets` callable and use it there.

**Do not forget** to use the `trajectory['dones']` flags: when computing the value target for the current step, add the discounted value target of the next step only if the episode did not end at the current step.
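A minimal numpy sketch of this backward recursion, assuming `rewards` and `dones` are stacked into `(T, nenvs)` arrays (a list of per-step arrays also works) and `last_values` holds the critic's estimate at the latest observation; the function name `n_step_value_targets` is hypothetical. Each target is $r_t + \gamma (1 - d_t)\hat v(s_{t+1})$, where $d_t$ is the done flag, so a `done` at step $t$ stops later rewards and the bootstrap value from leaking into earlier targets.

```python
import numpy as np


def n_step_value_targets(rewards, dones, last_values, gamma=0.99):
    """rewards, dones: (T, nenvs); last_values: (nenvs,) critic value at s_T."""
    rewards = np.asarray(rewards, dtype=np.float64)
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)
    targets = np.zeros_like(rewards)
    running = np.asarray(last_values, dtype=np.float64)
    # Walk backwards so each step can reuse the target of the step after it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * not_done[t] * running
        targets[t] = running
    return targets
```

For a single environment with rewards $(1, 0, 2)$, $\gamma = 0.5$ and $\hat v(s_T) = 10$, the targets are $(2.75, 3.5, 7)$; if the episode ends at the second step they become $(1, 0, 7)$.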

**Bonus (+0.5 pts):** implement [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf) instead; use $\lambda \approx 0.95$ or even closer to 1 in your experiments.
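For the bonus, a minimal numpy sketch of GAE: $\delta_t = r_t + \gamma(1-d_t)V(s_{t+1}) - V(s_t)$ and $A_t = \delta_t + \gamma\lambda(1-d_t)A_{t+1}$, with value targets $A_t + V(s_t)$. The function name `gae` and the argument layout are assumptions: `values` are the (detached) critic predictions of shape `(T, nenvs)` from the rollout and `last_values` is the critic's estimate at the latest observation.

```python
import numpy as np


def gae(rewards, dones, values, last_values, gamma=0.99, lam=0.95):
    """rewards, dones, values: (T, nenvs); last_values: (nenvs,).
    Returns (advantages, value_targets), both (T, nenvs)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    last = np.asarray(last_values, dtype=np.float64)[None]
    next_values = np.concatenate([values[1:], last], axis=0)  # V(s_{t+1})
    advantages = np.zeros_like(rewards)
    running = np.zeros_like(rewards[0])
    for t in reversed(range(len(rewards))):
        # One-step TD error; bootstrapping is cut where the episode ended.
        delta = rewards[t] + gamma * not_done[t] * next_values[t] - values[t]
        running = delta + gamma * lam * not_done[t] * running
        advantages[t] = running
    return advantages, advantages + values
```

Setting $\lambda = 1$ reproduces the $n$-step targets from the formula above, while $\lambda = 0$ gives one-step TD targets $r_t + \gamma V(s_{t+1})$; intermediate values trade bias for variance.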

In [13]:
trajectory['rewards']

[array([0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0.])]

In [14]:
trajectory['dones']

[array([False, False, False, False, False, False, False, False]),
 array([False, False, False, False, False, False, False, False]),
 array([False, False, False, False, False, False, False, False]),
 array([False, False, False, False, False, False, False, False]),
 array([False, False, False, False, False, False, False, False])]

In [15]:
class ComputeValueTargets:
    def __init__(self, policy, gamma=0.99):
        self.policy = policy
        self.gamma = gamma

    def __call__(self, trajectory, latest_observation):
        '''
        This method should modify trajectory inplace by adding 
        an item with key 'value_targets' to it
        
        input:
            trajectory - dict from runner
            latest_observation - last state, numpy, (num_envs x channels x width x height)
        '''
        
        rewards = trajectory['rewards'] # (env_steps, num_envs)
        dones = trajectory['dones'] # (env_steps, num_envs)
        env_steps = len(rewards)

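        # Bootstrap from the critic at the latest observation; use self.policy (not the
        # global `policy`) and skip graph construction, since targets are constants.
        with torch.no_grad():
            value_estimate = self.policy.act(latest_observation)['values'].cpu().numpy()  # (num_envs,)

        value_targets = []
        for t in reversed(range(env_steps)):
            # Cut the bootstrap where the episode ended at step t.
            value_estimate = rewards[t] + self.gamma * value_estimate * (1 - dones[t])
            value_targets.append(value_estimate)

        trajectory['value_targets'] = value_targets[::-1]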

After computing the value targets, we will transform the lists of interactions into tensors whose first dimension is `batch_size`, equal to `T * nenvs`.

You need to make sure that after this transformation `"log_probs"`, `"value_targets"`, `"values"` are 1-dimensional PyTorch tensors.
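Merging the axes is a stack followed by a reshape. The numpy sketch below (`merge_time_batch` is a hypothetical helper) shows the resulting row order: row `t * nenvs + e` holds environment `e` at step `t`. Every key must be flattened in this same order so that `log_probs[i]`, `values[i]` and `value_targets[i]` describe the same transition; concatenating the per-step tensors along dimension 0 with `torch.cat` gives the same order.

```python
import numpy as np


def merge_time_batch(per_step):
    """Stack a length-T list of (nenvs, ...) arrays and flatten to (T * nenvs, ...)."""
    stacked = np.stack(per_step, axis=0)            # (T, nenvs, ...)
    return stacked.reshape(-1, *stacked.shape[2:])  # row t * nenvs + e = env e at step t


T, nenvs = 5, 8
values = [t + np.arange(nenvs) / 10 for t in range(T)]  # env e at step t has value t + e/10
observations = [np.zeros((nenvs, 4, 84, 84), dtype=np.float32) for _ in range(T)]
print(merge_time_batch(values).shape)        # (40,)
print(merge_time_batch(observations).shape)  # (40, 4, 84, 84)
```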

In [16]:
for k, v in trajectory.items():
    print(k, v[0].shape)

actions (8,)
logits torch.Size([8, 6])
log_probs torch.Size([8])
values torch.Size([8])
entropy torch.Size([8])
observations (8, 4, 84, 84)
rewards (8,)
dones (8,)


In [17]:
class MergeTimeBatch:
    """Merges first two axes typically representing time and env batch."""
    def __init__(self, device):
        self.device = device

    def __call__(self, trajectory, latest_observation):
        for k, v in trajectory.items():
            if isinstance(v[0], np.ndarray):
                # numpy entries (observations, rewards, ...): stack to (T, nenvs, ...), then flatten
                tensor = torch.from_numpy(np.stack(v, axis=0).astype(np.float32)).to(self.device)
                tensor = tensor.view(-1, *tensor.shape[2:])
            else:
                # tensors from the policy are already (nenvs, ...); concatenating keeps the graph
                tensor = torch.cat(v, dim=0)
            trajectory[k] = tensor

Let's do more sanity checks!

In [18]:
runner = EnvRunner(
    env,
    policy,
    nsteps=5,
    transforms=[
        ComputeValueTargets(policy),
        MergeTimeBatch(device='cuda')
    ]
)

trajectory = runner.get_next()

In [19]:
# More sanity checks
assert 'value_targets' in trajectory, "Value targets not found"
assert trajectory['log_probs'].shape == (5 * nenvs,)
assert trajectory['value_targets'].shape == (5 * nenvs,)
assert trajectory['values'].shape == (5 * nenvs,)

assert trajectory['log_probs'].requires_grad, "Gradients are not available for actor head!"
assert trajectory['values'].requires_grad, "Gradients are not available for critic head!"

In [20]:
for k, v in trajectory.items():
    print(k, v.shape)

actions torch.Size([40])
logits torch.Size([40, 6])
log_probs torch.Size([40])
values torch.Size([40])
entropy torch.Size([40])
observations torch.Size([40, 4, 84, 84])
rewards torch.Size([40])
dones torch.Size([40])
value_targets torch.Size([40])


Now it is time to implement the advantage actor-critic algorithm itself. You can consult the [Mnih et al. 2016](https://arxiv.org/abs/1602.01783) paper and Sergey Levine's lectures ([part 1](https://www.youtube.com/watch?v=Ds1trXd6pos&list=PLkFD6_40KJIwhWJpGazJ9VSj9CFMkb79A&index=5), [part 2](https://www.youtube.com/watch?v=EKqxumCuAAY&list=PLkFD6_40KJIwhWJpGazJ9VSj9CFMkb79A&index=6)).
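For reference, one way to write the objective minimised by the `A2C` class below, with $N = T \cdot n_{\text{envs}}$ transitions per batch, $\hat v(s_i)$ the value target, $V_\theta(s_i)$ the critic's prediction, $\operatorname{sg}[\cdot]$ denoting stop-gradient (`.detach()`), $c_v$ = `value_loss_coef` and $c_e$ = `entropy_coef`:

$$
\mathcal{L}(\theta) = -\frac{1}{N}\sum_{i=1}^{N}\log\pi_\theta(a_i \mid s_i)\,\operatorname{sg}\big[\hat v(s_i) - V_\theta(s_i)\big] + \frac{c_v}{N}\sum_{i=1}^{N}\big(V_\theta(s_i) - \hat v(s_i)\big)^2 - \frac{c_e}{N}\sum_{i=1}^{N}\mathcal{H}\big(\pi_\theta(\cdot \mid s_i)\big)
$$

The first term is the policy-gradient loss with the advantage treated as a constant, the second trains the critic, and the last rewards high entropy to keep the policy exploring.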

In [21]:
from torch.nn.utils import clip_grad_norm_

class A2C:
    def __init__(self, policy: Policy, optimizer: torch.optim.Optimizer, value_loss_coef=0.25, entropy_coef=0.01, max_grad_norm=0.5):
        self.policy = policy
        self.optimizer = optimizer
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
    
    def loss(self, trajectory, write):
        advantage = trajectory['value_targets'] - trajectory['values']
        policy_loss = trajectory['log_probs'].mul(advantage.detach()).mean().neg()
        entropy_loss = trajectory['entropy'].mean().neg()
        critic_loss = F.mse_loss(trajectory['values'], trajectory['value_targets'].detach(), reduction='mean')

        write(
            'losses',
            {
                'policy loss': policy_loss,
                'critic loss': critic_loss,
                'entropy loss': entropy_loss
            }
        )

        write('critic/advantage', advantage.mean())
        write(
            'critic/values',
            {
                'value predictions': trajectory['values'].mean(),
                'value targets': trajectory['value_targets'].mean(),
            }
        )

        write('Episodes/mode_action', torch.mode(trajectory['actions'])[0])
        write('Episodes/actions', trajectory['actions'])

        return policy_loss + self.value_loss_coef * critic_loss + self.entropy_coef * entropy_loss

    def step(self, runner: EnvRunner):
        trajectory = runner.get_next()
        loss = self.loss(trajectory, runner.write)
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = clip_grad_norm_(self.policy.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        runner.write('gradient norm', grad_norm)

Now you can train your model. For optimization we suggest using RMSprop with learning rate 7e-4 (you can also decay it linearly to 0), smoothing constant (`alpha` in PyTorch) equal to 0.99 and epsilon equal to 1e-5.
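If you decay the learning rate linearly, the schedule is a one-liner. The plain-Python helper `linear_decay_lr` below is hypothetical; in PyTorch the same multiplier, `max(1 - update / total_updates, 0)`, can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, calling `scheduler.step()` once after each `optimizer.step()`.

```python
def linear_decay_lr(update, total_updates, initial_lr=7e-4):
    """Learning rate at a given update index, decaying linearly from initial_lr to 0."""
    fraction_left = 1.0 - update / total_updates
    # Clamp at zero in case training runs past the planned number of updates.
    return initial_lr * max(fraction_left, 0.0)
```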

We recommend training for at least 10 million environment steps across all batched environments (this takes ~3 hours on a single GTX 1080 with 8 CPUs). It should be possible to achieve an *average raw reward over the last 100 episodes* (averaged over the last 100 episodes of each environment in the batch) of about 600. **Your goal is to reach 500**.
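The reported metric is computed by the wrappers, but its definition can be made concrete. One plausible reading, sketched below with a hypothetical `RecentEpisodeRewards` tracker: accumulate raw (unclipped) rewards per environment, and when a real game ends (not just a lost life, since the training wrappers set `done` on every life lost) push the episode's return into a per-environment window of 100 episodes; the reported value is the mean over all windows. This is an assumption about how the aggregation works, not the wrappers' actual code.

```python
from collections import deque

import numpy as np


class RecentEpisodeRewards:
    """Hypothetical tracker of raw returns of the last `window` episodes per env."""

    def __init__(self, nenvs, window=100):
        self.running = np.zeros(nenvs)
        self.finished = [deque(maxlen=window) for _ in range(nenvs)]

    def update(self, raw_rewards, game_over):
        # raw_rewards: unclipped rewards; game_over: True only when the whole game ends.
        self.running += raw_rewards
        for env_id in np.flatnonzero(game_over):
            self.finished[env_id].append(self.running[env_id])
            self.running[env_id] = 0.0

    def mean(self):
        returns = [r for env_returns in self.finished for r in env_returns]
        return float(np.mean(returns)) if returns else float("nan")
```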

Notes:
* if your reward is stuck at ~200 for more than 2M steps, there is probably a bug
* if your gradient norm is >10, something has probably gone wrong
* make sure your `entropy loss` is negative and your `critic loss` is positive
* make sure you didn't forget `.detach()` in the losses where it is needed
* `actor loss` should oscillate around zero; do not expect the loss to decrease in RL ;)
* you can experiment with `nsteps` (the rollout length); the standard rollout length is 5 or 10. Note that this parameter determines how many algorithm iterations are required to train on 10M steps (40M frames, since frameskip is applied in preprocessing).
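The arithmetic behind the last note, as a small Python sketch (the helper `updates_needed` is hypothetical): each update consumes `nenvs * nsteps` transitions, so doubling the rollout length halves the number of updates for the same budget. With `nenvs = 8` and `nsteps = 10` this matches the `int(1e7 / nenvs / 10)` iterations in the training cell below; the frameskip of 4 is inferred from the 10M-steps / 40M-frames ratio above.

```python
def updates_needed(total_env_steps, nenvs, nsteps):
    """Number of A2C updates when each update consumes nenvs * nsteps transitions."""
    return total_env_steps // (nenvs * nsteps)


frameskip = 4  # 10M agent steps correspond to 40M emulator frames
for nsteps in (5, 10):
    print(nsteps, updates_needed(10_000_000, nenvs=8, nsteps=nsteps))
# 5 250000
# 10 125000
print(10_000_000 * frameskip)  # 40000000
```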

In [22]:
model = AgentNetwork(n_actions).cuda()
policy = Policy(model)
runner = EnvRunner(
    env,
    policy,
    nsteps=10,
    transforms=[
        ComputeValueTargets(policy),
        MergeTimeBatch(device='cuda')
    ]
)

optimizer = torch.optim.RMSprop(model.parameters(), lr=7e-4, alpha=0.99, eps=1e-5)

a2c = A2C(policy, optimizer)

In [23]:
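# load pretrained weights; eval() makes BatchNorm use its running statistics,
# so call model.train() again before resuming training
model.load_state_dict(torch.load('A2C-2.pth'))
model.eval()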

AgentNetwork(
  (conv): Sequential(
    (0): ConvBlock(
      (net): Sequential(
        (0): Conv2d(4, 32, kernel_size=(3, 3), stride=(4, 4))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
    )
    (1): ConvBlock(
      (net): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
    )
    (2): ConvBlock(
      (net): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
    )
  )
  (actor): Sequential(
    (0): Flatten()
    (1): Linear(in_features=4096, out_features=512, bias=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=512, out_features=6, bias=True)
  )
  (critic): Sequential(

In [24]:
# from tqdm.autonotebook import trange


# n_steps = int(1e7 / nenvs / 10)
# for i_step in trange(n_steps):
#     a2c.step(runner)

In [25]:
# # save your model just in case 
# torch.save(model.state_dict(), "A2C-2.pth")
# torch.save(optimizer.state_dict(), "A2C-2-optim.pth")

In [26]:
env.close()



## Evaluation

In [27]:
env = nature_dqn_env("SpaceInvadersNoFrameskip-v4", nenvs=None, 
                     clip_reward=False, summaries=False, episodic_life=False)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [27]:
def evaluate(env, policy, n_games=1, t_max=10000):
    '''
    Plays n_games and returns rewards
    '''
    rewards = []
    
    for _ in range(n_games):
        s = env.reset()
        
        R = 0
        for _ in range(t_max):
            with torch.no_grad():  # acting only, no gradients needed
                action = policy.act(np.array([s]))["actions"][0]
            
            s, r, done, _ = env.step(action)
            
            R += r
            if done:
                break

        rewards.append(R)
    return np.array(rewards)

In [31]:
# evaluation will take some time!
sessions = evaluate(env, policy, n_games=30)
score = sessions.mean()
print(f"Your score: {score}")

assert score >= 500, "Needs more training?"
print("Well done!")

Your score: 648.0
Well done!


In [29]:
env.close()

## Record

In [28]:
env_monitor = nature_dqn_env("SpaceInvadersNoFrameskip-v4", nenvs=None, monitor=True,
                             clip_reward=False, summaries=False, episodic_life=False)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [31]:
# record sessions
sessions = evaluate(env_monitor, policy, n_games=3)

In [32]:
# rewards for recorded games
sessions

array([ 740.,  570., 1320.])

In [33]:
env_monitor.close()