# Table of contents
1. [Input features](#inputs)
   1. [Augmentations](#augmentations)
2. [Rewards](#rewards)
3. [Network architecture](#architecture)
4. [PPO algorithm explained](#ppo)
5. [Final touches](#final_touches)
6. [Training](#training)
7. [Save and show](#save_and_show)

In [None]:
!pip install kaggle_environments==1.7.10

In [None]:
# basic imports
import numpy as np
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
from random import shuffle
from copy import deepcopy
from tqdm.notebook import tqdm

# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.distributions import Categorical
from torch.optim import Adam

# hungry-geese imports
from kaggle_environments import make, evaluate

# Hungry-geese environment: its `configuration` is read by the feature code below
# and the same object is reused by the training rollouts.
env = make("hungry_geese", debug=False)

## Input features <a name="inputs"></a>

Input state can be generated in very different ways. Here we are going to use the following procedure:

Inputs have 12 channels, each one is a 7 x 11 (rows x columns) matrix.

* The first channel consists of ones in the cells occupied by the agent (head, body and tail) and zeros elsewhere.
* The next three channels are the same, but for the opponents (their order is randomized during training).
* In the fifth channel there are ones in the cells corresponding to all agents' heads.
* In the sixth channel - the same, but with agents' bodies.
* In the seventh channel - the same, but with agents' tails.
* In the eighth channel - the same, but with agents' previous head locations.
* In the ninth channel - ones in food cells.
* Each of the last 3 channels represents a single number, broadcast to the whole matrix:
  * (t % 40) / 40
  * t / 200
  * 1 if (t + 1) is divisible by 40 and 0 otherwise,
  
  where t is observation step.
  
Each channel is rolled so that the agent's head is located at the center of the matrix.

In [None]:
def get_position_from_index(index, columns):
    """Convert a flat, row-major cell index into ``(row, col)``.

    ``columns`` is the board width.
    """
    return divmod(index, columns)

def get_index_from_position(row, col, columns):
    """Convert ``(row, col)`` into a flat, row-major cell index on a board ``columns`` wide."""
    flat_index = col + columns * row
    return flat_index

def find_new_head_position(head_row, head_col, action, rows, columns):
    """Return the ``(row, col)`` reached by moving one cell from the head.

    Actions are encoded 0 = NORTH, 1 = EAST, 2 = SOUTH and anything else =
    WEST. The board wraps around on both axes.
    """
    if action == 0:  # north: one row up, wrapping to the bottom row
        return (head_row + rows - 1) % rows, head_col
    if action == 1:  # east: one column right, wrapping to column 0
        return head_row, (head_col + 1) % columns
    if action == 2:  # south: one row down, wrapping to row 0
        return (head_row + 1) % rows, head_col
    # west: one column left, wrapping to the last column
    return head_row, (head_col + columns - 1) % columns

def shift_head(head_id, action, rows=7, columns=11):
    """Return the flat index of the cell next to ``head_id`` in direction ``action``.

    Uses the same action encoding and wrap-around as
    ``find_new_head_position``. The default board is 7 x 11.
    """
    start = get_position_from_index(head_id, columns)
    end_row, end_col = find_new_head_position(*start, action, rows, columns)
    return get_index_from_position(end_row, end_col, columns)

def get_previous_head(ids, last_action, rows, columns):
    """Return the cell the goose head occupied one step earlier.

    For a goose with at least two segments, this is simply its second segment.
    For a single-segment goose, the head is moved one cell opposite to
    ``last_action``; ``(a + 2) % 4`` flips NORTH<->SOUTH and EAST<->WEST.
    """
    if len(ids) <= 1:
        reverse_action = (last_action + 2) % 4
        return shift_head(ids[0], reverse_action, rows, columns)
    return ids[1]

def ids2locations(ids, prev_head, step, rows, columns):
    """Rasterize one goose into four flat occupancy planes.

    Args:
        ids: cell indices of the goose, head first (empty if the goose is dead).
        prev_head: cell index of the head on the previous step. ``None`` or an
            out-of-range value (e.g. the ``-1`` placeholder) means "unknown"
            and leaves plane 3 empty.
        step: current observation step. On step 0 there is no previous head.
        rows, columns: board size.

    Returns:
        ``np.ndarray`` of shape ``(4, rows * columns)`` with ones marking the
        head (plane 0), body (plane 1), tail (plane 2) and previous head (plane 3).
    """
    n_cells = rows * columns
    state = np.zeros((4, n_cells))
    if len(ids) == 0:
        return state
    state[0, ids[0]] = 1  # goose head
    if len(ids) > 1:
        state[1, ids[1:-1]] = 1  # goose body (empty slice for a 2-segment goose)
        state[2, ids[-1]] = 1  # goose tail
    # Only mark a known previous head: -1 would silently wrap to the last
    # cell, and None would broadcast the assignment over the whole plane.
    if step != 0 and prev_head is not None and 0 <= prev_head < n_cells:
        state[3, prev_head] = 1  # goose head one step before
    return state

def get_features(observation, config, prev_heads):
    """Encode an observation as a ``(12, rows, columns)`` float tensor.

    Channel layout, from the player's point of view:
        0     : the player's goose (head + body + tail)
        1-3   : the other geese with the same encoding (slot 0 and the
                player's slot are swapped so the player always comes first)
        4-7   : heads / bodies / tails / previous heads of all geese
        8     : food
        9     : (step % hunger_rate) / hunger_rate, broadcast
        10    : step / episodeSteps, broadcast
        11    : 1.0 if (step + 1) % hunger_rate == 0 else 0.0, broadcast
    The whole tensor is rolled so that the player's head sits at cell
    ``(rows // 2, columns // 2)``.

    Args:
        observation: mapping with 'geese', 'index', 'step' and 'food'.
        config: mapping with 'rows', 'columns', 'hunger_rate', 'episodeSteps'.
        prev_heads: previous head cell of each goose (only used when step != 0).

    NOTE(review): channels 0-3 are filled for exactly four geese.
    """
    n_rows, n_cols = config['rows'], config['columns']
    all_geese = observation['geese']
    me = observation['index']
    t = observation['step']

    # (goose, plane, cell) grid with planes head / body / tail / previous head
    occupancy = np.zeros((len(all_geese), 4, n_rows * n_cols))
    for goose_id, segments in enumerate(all_geese):
        occupancy[goose_id] = ids2locations(segments, prev_heads[goose_id], t, n_rows, n_cols)
    if me != 0:
        # put the player's goose into slot 0 (fancy indexing copies the RHS first)
        occupancy[[0, me]] = occupancy[[me, 0]]

    planes = np.zeros((12, n_rows * n_cols))
    for plane in range(4):
        planes[plane] = occupancy[plane, :3].sum(axis=0)
        planes[plane + 4] = occupancy[:, plane].sum(axis=0)
    hunger_rate = config['hunger_rate']
    planes[8, observation['food']] = 1
    planes[9, :] = (t % hunger_rate) / hunger_rate
    planes[10, :] = t / config['episodeSteps']
    planes[11, :] = float((t + 1) % hunger_rate == 0)

    grid = torch.Tensor(planes).reshape(-1, n_rows, n_cols)
    # centre the board on the player's head
    head_row, head_col = divmod(all_geese[me][0], n_cols)
    shifts = (n_rows // 2 - head_row, n_cols // 2 - head_col)
    return torch.roll(grid, shifts, dims=(-2, -1))

An example of input state:

In [None]:
def plot_features(features):
    """Draw the 12 feature channels as heatmaps on a 3 x 4 grid.

    ``features[c]`` must be a 2-D array for channel ``c``, e.g. the numpy
    version of the tensor produced by ``get_features``.
    """
    _, axes = plt.subplots(3, 4, figsize=(20, 10))
    # axes.flat walks the grid row by row, i.e. channel = row * 4 + col
    for channel, ax in enumerate(axes.flat):
        sns.heatmap(features[channel], ax=ax, cmap='Blues',
                    vmin=0, vmax=1, linewidth=2, linecolor='black', cbar=False)

def get_example_features():
    """Return the features of a fixed, hand-written mid-game position (step 104, player 0)."""
    example_obs = {
        'step': 104,
        'index': 0,
        'geese': [
            [46, 47, 36, 37, 48, 59, 58, 69],
            [5, 71, 72, 6, 7, 73, 62, 61, 50, 51, 52, 63, 64, 53, 54],
            [12, 11, 21, 20, 19, 8, 74, 75, 76, 65, 55, 56, 67, 1],
            [23, 22, 32, 31, 30, 29, 28, 17, 16, 27, 26, 15, 14, 13, 24],
        ],
        'food': [45, 66],
    }
    # head cell of each goose on the previous step
    last_heads = [47, 71, 11, 22]
    return get_features(example_obs, env.configuration, last_heads)

# Build the feature tensor for the hand-written example position and plot every channel.
features = get_example_features()
plot_features(features.cpu().detach().numpy())

### Augmentations <a name="augmentations"></a>

Each state will go through the following augmentations:

* Random horizontal flip
* Random vertical flip
* Random opponents indices shuffle

In [None]:
def augment(batch):
    """Randomly flip boards and shuffle opponent channels, in place.

    Each sample independently gets (with probability 0.5 each) a left-right
    flip, which swaps EAST (1) and WEST (3), and an up-down flip, which swaps
    NORTH (0) and SOUTH (2). The three opponent channels (1-3) are then
    permuted at random. ``batch['states']`` and ``batch['actions']`` are
    modified in place and the same ``batch`` is returned.
    """
    states = batch['states']
    actions = batch['actions']
    n_samples = len(states)

    # left-right mirror: 1 -> 3, 3 -> 1, 0 and 2 unchanged
    mirror = np.random.rand(n_samples) < 0.5
    states[mirror] = states[mirror].flip(-1)
    actions[mirror] = torch.where(actions[mirror] > 0, 4 - actions[mirror], 0)

    # up-down mirror: 0 -> 2, 2 -> 0, 1 and 3 unchanged
    mirror = np.random.rand(n_samples) < 0.5
    states[mirror] = states[mirror].flip(-2)
    actions[mirror] = torch.where(actions[mirror] < 3, 2 - actions[mirror], 3)

    # random reordering of the opponent channels
    orders = list(itertools.permutations([0, 1, 2]))
    picks = np.random.randint(6, size=n_samples)
    for row, pick in enumerate(picks):
        target = torch.tensor(orders[pick])
        opponents = torch.zeros(3, states.shape[2], states.shape[3])
        opponents[target] = states[row, 1:4]
        states[row, 1:4] = opponents
    return batch

## Rewards <a name="rewards"></a>

We will use two types of rewards. First reward $R_1$ is given to agent in the end of the game: -1 for getting the 4th place, -0.75 for the 3rd, -0.25 for 2nd and +1 for reaching the 1st place. Second reward $R_2$ is related to food: agent gets +0.1 for eating and -1 if it died from hunger.

In [None]:
def get_rank(obs, prev_obs):
    """Return the player's final place (1 = best, 4 = worst).

    A surviving player is ranked by its length in the final observation (ties
    count against it). A dead player is placed after all survivors, and among
    the dead it is ranked by the lengths in the penultimate observation.
    """
    final_geese = obs['geese']
    me = obs['index']
    my_len = len(final_geese[me])
    alive = [goose_id for goose_id, goose in enumerate(final_geese) if len(goose) > 0]
    if me in alive:
        return sum(1 for goose in final_geese if len(goose) >= my_len)

    last_geese = prev_obs['geese']
    my_last_len = len(last_geese[prev_obs['index']])
    dead_not_shorter = sum(
        1 for goose_id, goose in enumerate(last_geese)
        if goose_id not in alive and len(goose) >= my_last_len
    )
    return len(alive) + dead_not_shorter
    
def get_rewards(env_reward, obs, prev_obs, done):
    """Return the ``(placement_reward, food_reward)`` pair for one transition.

    At the end of the game, the first value is 1 / -0.25 / -0.75 / -1 for
    1st-4th place. The second value is -1 if the goose presumably starved:
    it had a single segment on the step before a hunger milestone, with the
    hunger rate hard-coded to 40. During the game, the first value is 0 and
    the second is ``max(0.1 * (env_reward - 1), 0)``; on step 0 the
    environment reward is reduced by one first.
    """
    last_geese = prev_obs['geese']
    me = prev_obs['index']
    t = prev_obs['step']
    if not done:
        if t == 0:
            env_reward -= 1  # somehow initial step is a special case
        return (0, max(0.1 * (env_reward - 1), 0))
    placement = get_rank(obs, prev_obs)
    placement_reward = (1, -0.25, -0.75, -1)[placement - 1]
    starved = ((t + 1) % 40 == 0) and (len(last_geese[me]) == 1)
    return (placement_reward, -1 if starved else 0)

## Network architecture <a name="architecture"></a>

We are going to train a neural network, which is structured in the following way:


                                -------> Actor ---> Log-probabilities (log π)
                              /
    State (s) ---> Encoder --- --------> Critic-1 ---> Value-1 (V-1)
                              \
                                -------> Critic-2 ---> Value-2 (V-2)


Each critic head predicts state-value $V_\theta(s)$ (estimate of discounted return from this point onwards), where $\theta$ stands for neural net parameters. Actor updates policy parameters for $\pi_\theta$, in the direction suggested by Critics.

I won't go into details here, let the code speak for itself. The architecture of encoder is a relatively simple ResNet with Squeeze-Excitation blocks and Swish activation functions instead of commonly used ReLU.

In [None]:
class SEBlock(nn.Module):
    """Squeeze-Excitation Block.

    Average-pools each channel, passes the result through a two-layer
    bottleneck MLP (``dim -> dim // reduction_ratio -> dim``) with SiLU and a
    sigmoid, and rescales every channel of the input by the resulting gate.
    """

    def __init__(self, dim, reduction_ratio=4):
        super().__init__()
        hidden = dim // reduction_ratio
        self.f1 = nn.Linear(dim, hidden)
        self.f2 = nn.Linear(hidden, dim)

    def forward(self, x):
        squeezed = x.mean(axis=(-1, -2))  # global average pool over H, W
        gate = torch.sigmoid(self.f2(F.silu(self.f1(squeezed))))
        return x * gate[..., None, None]

class BasicBlock(nn.Module):
    """Basic Residual Block.

    Two pre-activation stages (SiLU -> BatchNorm -> 3x3 conv), followed by
    squeeze-excitation gating, added to a shortcut. With ``downscale=True``
    the channel count doubles and the first conv uses stride 2, so the
    spatial size becomes ceil(H / 2) x ceil(W / 2). The shortcut is then a
    strided 1x1 conv.
    """

    def __init__(self, dim, downscale=False):
        super().__init__()
        out_dim = 2 * dim if downscale else dim
        stride = 2 if downscale else 1
        self.conv1 = nn.Conv2d(dim, out_dim, 3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1)
        self.bnorm1 = nn.BatchNorm2d(dim)
        self.bnorm2 = nn.BatchNorm2d(out_dim)
        # the shortcut must match the main branch's channels and resolution
        self.proj = nn.Conv2d(dim, out_dim, 1, stride=2) if downscale else nn.Identity()
        self.se = SEBlock(out_dim)

    def forward(self, x):
        h = x
        for conv, norm in ((self.conv1, self.bnorm1), (self.conv2, self.bnorm2)):
            h = conv(norm(F.silu(h)))
        return self.se(h) + self.proj(x)
    
class BottleneckBlock(nn.Module):
    """Bottleneck Residual Block.

    Three pre-activation stages (SiLU -> BatchNorm -> conv) with kernel sizes
    1x1, 3x3 and 1x1, followed by squeeze-excitation gating, added to a
    shortcut. With ``downscale=True`` the 3x3 conv doubles the channels with
    stride 2, and the shortcut is a strided 1x1 conv.
    """

    def __init__(self, dim, downscale=False):
        super().__init__()
        out_dim = 2 * dim if downscale else dim
        stride = 2 if downscale else 1
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, out_dim, 3, stride=stride, padding=1)
        self.conv3 = nn.Conv2d(out_dim, out_dim, 1)
        self.bnorm1 = nn.BatchNorm2d(dim)
        self.bnorm2 = nn.BatchNorm2d(dim)
        self.bnorm3 = nn.BatchNorm2d(out_dim)
        # the shortcut must match the main branch's channels and resolution
        self.proj = nn.Conv2d(dim, out_dim, 1, stride=2) if downscale else nn.Identity()
        self.se = SEBlock(out_dim)

    def forward(self, x):
        h = x
        stages = ((self.conv1, self.bnorm1), (self.conv2, self.bnorm2), (self.conv3, self.bnorm3))
        for conv, norm in stages:
            h = conv(norm(F.silu(h)))
        return self.se(h) + self.proj(x)

class ResLayers(nn.Module):
    """Sequential Residual Layers.

    ``depth - 1`` blocks that keep the shape, followed by one block built
    with ``downscale=True``. ``block`` is called as ``block(dim, downscale=...)``.
    """

    def __init__(self, block, dim, depth):
        super().__init__()
        stack = [block(dim, downscale=False) for _ in range(depth - 1)]
        stack.append(block(dim, downscale=True))
        self.blocks = nn.ModuleList(stack)

    def forward(self, x):
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
    
class Encoder(nn.Module):
    """Res-Net Encoder.

    A 1x1 "gate" conv maps the 12 input channels to ``dim_in`` channels. Its
    circular padding of (3, 5) grows a 7 x 11 board to 13 x 21 with wrapped
    content. Three residual stages follow, each ending with a downscaling
    block, so the output has ``8 * dim_in`` channels.
    """

    def __init__(self, dim_in, depths):
        super().__init__()
        self.gate = nn.Conv2d(12, dim_in, 1, padding=(3, 5), padding_mode='circular')
        # (block type, channel multiplier) for each of the three stages
        stage_specs = ((BasicBlock, 1), (BasicBlock, 2), (BottleneckBlock, 4))
        self.layers = nn.ModuleList(
            [ResLayers(block, mult * dim_in, depths[stage])
             for stage, (block, mult) in enumerate(stage_specs)]
        )

    def forward(self, x):
        out = self.gate(x)
        for stage in self.layers:
            out = stage(out)
        return out

class Actor(nn.Module):
    """Actor Head.

    A small conv trunk, global average pooling and a linear layer produce
    log-probabilities over the 4 moves (shape ``(batch, 4)``).
    """

    def __init__(self, dim_in, head_dim):
        super().__init__()
        trunk = [
            nn.Conv2d(dim_in, dim_in, 3, padding=1),
            nn.SiLU(inplace=True),
            nn.BatchNorm2d(dim_in),
            nn.Conv2d(dim_in, head_dim, 1),
            nn.SiLU(inplace=True),
        ]
        self.compr = nn.Sequential(*trunk)
        self.fc = nn.Linear(head_dim, 4)

    def forward(self, state):
        pooled = self.compr(state).mean(axis=(-1, -2))  # global average pool
        return F.log_softmax(self.fc(pooled), dim=1)
    
class Critic(nn.Module):
    """Critic Head.

    Uses the same conv trunk and pooling as the actor, followed by a
    weight-normalized linear layer and ``tanh``. Each state value therefore
    lies in [-1, 1] (shape ``(batch, 1)``).
    """

    def __init__(self, dim_in, head_dim):
        super().__init__()
        trunk = [
            nn.Conv2d(dim_in, dim_in, 3, padding=1),
            nn.SiLU(inplace=True),
            nn.BatchNorm2d(dim_in),
            nn.Conv2d(dim_in, head_dim, 1),
            nn.SiLU(inplace=True),
        ]
        self.compr = nn.Sequential(*trunk)
        self.fc = nn.utils.weight_norm(nn.Linear(head_dim, 1))

    def forward(self, state):
        pooled = self.compr(state).mean(axis=(-1, -2))  # global average pool
        return torch.tanh(self.fc(pooled))

class GNet(nn.Module):
    """G-Net: a shared ResNet encoder feeding one actor head and two critic heads.

    ``forward`` returns ``(log_probs, (value_1, value_2))``.
    """

    def __init__(self):
        super().__init__()
        # hyperparameters
        dim_in, head_dim, depths = 32, 16, (2, 2, 2)
        # each of the 3 encoder stages doubles the channels: 32 -> 64 -> 128 -> 256
        latent_dim = 8 * dim_in
        # modules (registration order kept: encoder, actor, critic1, critic2)
        self.encoder = Encoder(dim_in, depths)
        self.actor = Actor(latent_dim, head_dim)
        self.critic1 = Critic(latent_dim, head_dim)
        self.critic2 = Critic(latent_dim, head_dim)

    def forward(self, state):
        latent = self.encoder(state)
        log_probs = self.actor(latent)
        values = (self.critic1(latent), self.critic2(latent))
        return log_probs, values

## PPO algorithm explained <a name="ppo"></a>

Firstly, we will define policy gradient loss:

$$\mathcal{L}^{\text{PG}}(\theta) = \mathbb{E}[\log \pi_\theta(a|s) \hat{A}_\theta(s,a) ],$$

where first term $\log \pi_\theta(a|s)$ are log-probabilities from the output of policy network (actor head), and the second one is an estimate of `advantage function`, the relative value of selected action $a$. The value of $\hat{A}_\theta(s,a)$ is equal to `return` (or `discounted reward`) minus `baseline estimate`. Return at given time $t$ is calculated as follows:

$$ V_{\text{target}}(t) = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \dots = \sum_{k=0}^\infty \gamma^k R_{t+k+1},$$

where $R_t$ is the reward at timestep $t$. The baseline estimate is the output of the value network $V_\theta(s)$. Therefore,

$$ \hat{A}_\theta(t) = V_{\text{target}}(t) - V_\theta(s_t). $$

There also exists a generalized version of advantage estimation, that we are going to use:

$$\hat{A}_\theta(t) = \delta_t + (\gamma \lambda) \delta_{t+1} + \dots = \sum_{k=0}^\infty (\gamma \lambda)^k \delta_{t+k},$$
$$ \text{where } \delta_t = R_{t+1} + \gamma V_\theta(s_{t+1}) - V_\theta(s_t),$$

which reduces to the previous equation when $\lambda = 1$.

Now, when $\hat{A}_\theta$ is positive, meaning that the action agent took resulted in a better than average return, we will increase probabilities of selecting it again in the future. On the other hand, if an advantage was negative, we will reduce the likelihood of selected actions.

In [None]:
def inv_discount_cumsum(array, discount_factor):
    """Discounted cumulative sum from the end, along dim 0.

    ``out[t] = array[t] + discount_factor * out[t + 1]`` with
    ``out[-1] = array[-1]``; e.g. the discounted return for every step.
    """
    backwards = torch.flip(array, dims=[0])
    running = [backwards[0]]
    for item in backwards[1:]:
        running.append(discount_factor * running[-1] + item)
    return torch.flip(torch.stack(running), dims=[0])

def get_advantages_and_returns(rewards, values, gamma, lam):
    """Compute GAE advantages and discounted returns for one episode.

    Args:
        rewards: per-step rewards.
        values: per-step value estimates (same length as ``rewards``).
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        ``(advantages, returns)`` as Python lists. The value after the last
        step is taken to be 0 (the episode has ended).
    """
    reward_t = torch.tensor(rewards, dtype=torch.float)
    value_t = torch.tensor(values + [0.])  # bootstrap with 0 after the terminal step
    td_errors = reward_t + gamma * value_t[1:] - value_t[:-1]
    advantages = inv_discount_cumsum(td_errors, gamma * lam)
    returns = inv_discount_cumsum(reward_t, gamma)
    return advantages.cpu().detach().tolist(), returns.cpu().detach().tolist()

However, as PPO-paper quotes:

`While it is appealing to perform multiple steps of optimization on this loss using the same trajectory, doing so is not well-justified, and empirically it often leads to destructively large policy updates.`

In other words, we have to impose a constraint that won't allow our new policy to move too far away from the old one. Let's denote the probability ratio between the new and old policies as

$$r(\theta) = \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{\text{old}}}(a|s)}. $$

Then, take a look at our new `surrogate` objective function:

$$\mathcal{L}^{\text{CPI}}(\theta) = \mathbb{E}[r(\theta) \hat{A}_\theta(s,a)].$$

It can be shown that, at $\theta = \theta_{\text{old}}$, the gradient of $\mathcal{L}^{\text{CPI}}(\theta)$ coincides with the vanilla policy gradient, but I'll bravely skip the proof. Now we would like to insert the aforementioned constraint into this loss function. The main objective proposed in the PPO paper is the following:

$$J^{\text{CLIP}}(\theta) = \mathbb{E}[\min (r(\theta) \hat{A}_{\theta_{\text{old}}}(s,a), \text{clip}(r(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_{\theta_{\text{old}}}(s,a))],$$

where $\epsilon$ is a `clip ratio` hyperparameter. The first term inside the $\min$ function, $r(\theta) \hat{A}_{\theta_{\text{old}}}(s,a)$, is the normal policy gradient objective. The second one is its clipped version, which prevents us from destroying our current policy based on a single estimate, because the value of $\hat{A}_{\theta_{\text{old}}}(s,a)$ is noisy (as it is based on an output of our network).

When applying PPO to a network architecture that shares parameters between the policy and value functions, the clipped objective is augmented with a value-estimation error term and an entropy bonus that encourages sufficient exploration. The final loss then becomes:

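$$\mathcal{L}(\theta) = \mathbb{E}[-J^{\text{CLIP}}(\theta) + c  (V_\theta(s) - V_{\text{target}})^2 - c_{\text{ent}}  H(s, \pi_{\theta}(\cdot))], $$

where $c$ and $c_{\text{ent}}$ are both hyperparameter constants. Since our network has two critic heads, the value term is applied to each of them with its own coefficient ($c_1$ and $c_2$ in the code).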

In [None]:
def compute_losses(net, data, c1, c2, c_ent, clip_ratio=0.2):
    """Compute the PPO loss terms for one mini-batch.

    Args:
        net: network returning ``(log_probs, (value_1, value_2))``.
        data: batch dict with 'states', 'actions', 'log-p' (log-probability of
            the taken action under the policy that collected the data) and
            'adv-1', 'adv-2', 'ret-1', 'ret-2' for the two reward streams.
        c1, c2: weights of the squared-error losses of critic 1 and critic 2.
        c_ent: weight of the entropy bonus.
        clip_ratio: PPO clipping epsilon.

    Returns:
        dict with 'actor' (clipped surrogate loss), 'critic' (tuple of the two
        weighted value losses) and 'entropy' (``c_ent`` times the negative
        mean policy entropy). The caller sums all terms into the total loss.
    """
    # Use the network's own device (CUDA in this notebook) instead of a
    # hard-coded .cuda(), so the losses can also be computed on CPU.
    device = next(net.parameters()).device
    states = data['states'].to(device)
    actions = data['actions'].to(device)
    logp_old = data['log-p'].to(device)
    returns = [data[f'ret-{i}'].float().to(device) for i in range(1, 3)]
    # both reward streams drive the same policy, so their advantages are summed
    advs = data['adv-1'].float().to(device) + data['adv-2'].float().to(device)
    # network outputs
    logp_dist, (values_1, values_2) = net(states)
    # log-probability of each taken action (vectorized, no Python loop)
    logp = logp_dist.gather(1, actions.long().unsqueeze(1)).squeeze(1)
    # clipped surrogate objective
    ratio = torch.exp(logp - logp_old)
    clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * advs
    actor_loss = -(torch.min(ratio * advs, clip_adv)).mean()
    # critic losses
    critic_loss_1 = ((values_1.squeeze() - returns[0]) ** 2).mean()
    critic_loss_2 = ((values_2.squeeze() - returns[1]) ** 2).mean()
    # Entropy bonus: logp_dist already holds log-probabilities, so pass them
    # as logits instead of exponentiating and taking the log again.
    entropy = Categorical(logits=logp_dist).entropy()
    entropy = torch.where(torch.isnan(entropy), torch.zeros_like(entropy), entropy)  # drop NaNs if any
    entropy_loss = -entropy.mean()
    return {'actor': actor_loss,
            'critic': (c1 * critic_loss_1,
                       c2 * critic_loss_2),
            'entropy': c_ent * entropy_loss}

## Final touches <a name="final_touches"></a>

Define G-Net dataset for training and Reinforcement Learning Agent which takes network as input and uses it to make actions.

In [None]:
class GeeseDataset(Dataset):
    """G-Net Dataset over the column-wise rollout buffers.

    ``buffers`` maps each field name to a sequence with one entry per sample.
    ``dataset[i]`` returns a dict holding the i-th entry of every field.
    """

    def __init__(self, buffers):
        self.buffers = buffers

    def __len__(self):
        # 'states' is used as the reference length for all fields
        return len(self.buffers['states'])

    def __getitem__(self, idx):
        return dict((name, column[idx]) for name, column in self.buffers.items())

In [None]:
class RLAgent:
    """Hungry-geese agent that picks moves with a G-Net-style policy network.

    Instances are called with ``(observation, configuration)`` and return one
    of 'NORTH', 'EAST', 'SOUTH', 'WEST'. ``net`` must have at least one
    parameter (its device is used for inference) and return
    ``(log_probs, (value_1, value_2))`` for a batch of states.
    """

    def __init__(self, net, stochastic):
        # previous head cell of each goose; -1 = unknown (game start / dead goose)
        self.prev_heads = [-1, -1, -1, -1]
        self.net = net
        # True: sample from the policy; False: take the most likely action
        self.stochastic = stochastic

    def raw_outputs(self, state):
        """Run the network on a single (unbatched) state.

        Returns:
            ``(action, log_prob_of_action, (value_1, value_2))``. ``action`` is
            an index into (NORTH, EAST, SOUTH, WEST).
        """
        # Follow the network's device instead of assuming CUDA, so the agent
        # also runs on CPU-only machines.
        device = next(self.net.parameters()).device
        with torch.no_grad():
            logits, (v1, v2) = self.net(state.to(device).unsqueeze(0))
            logits = logits.squeeze(0)
            v1 = v1.squeeze(0)
            v2 = v2.squeeze(0)
            if self.stochastic:
                # sample an action from the policy's probabilities
                probs = torch.exp(logits).cpu().numpy()
                action = np.random.choice(range(4), p=probs)
            else:
                action = np.argmax(logits.cpu().numpy())
            return action, logits[action], (v1, v2)

    def __call__(self, observation, configuration):
        if observation['step'] == 0:
            # new game: forget heads from any previous episode
            self.prev_heads = [-1, -1, -1, -1]
        state = get_features(observation, configuration, self.prev_heads)
        action, _, _ = self.raw_outputs(state)
        # remember current heads so the next call can fill the previous-head channel
        self.prev_heads = [goose[0] if len(goose) > 0 else -1 for goose in observation['geese']]
        return ['NORTH', 'EAST', 'SOUTH', 'WEST'][action]

## Training <a name="training"></a>

Now we will write a couple of functions that make our RL agent play against copies of itself and collect the trajectories used for training. At the start of every training episode (before collecting new data), we evaluate the deterministic agent against 3 greedy agents and record its performance.

In [None]:
def rollout(player, env, players, buffers, gammas, lambdas):
    """Play one game with ``player`` and append its trajectory to ``buffers``.

    Args:
        player: agent exposing ``raw_outputs(state) -> (action, log_p, (v1, v2))``;
            it chooses the moves sent through ``trainer.step``.
        env: hungry-geese environment providing ``train`` and ``configuration``.
        players: agent list for ``env.train``. The caller uses ``None`` for the
            slot driven by ``player`` (kaggle ``env.train`` convention). The
            list is shuffled in place, so that slot's position is random.
        buffers: dict of lists. 'states', 'actions' and 'log-p' get one entry
            per step; 'adv-1', 'ret-1', 'adv-2' and 'ret-2' are extended with
            the per-step advantages / returns of the two reward streams.
        gammas, lambdas: dicts keyed by '1' / '2' with the discount and GAE
            lambda of each reward stream.
    """
    rewards = {str(i + 1): [] for i in range(2)}
    values  = {str(i + 1): [] for i in range(2)}
    # shuffle players indices
    shuffle(players)
    trainer = env.train(players)
    observation = trainer.reset()
    prev_obs = observation
    done = False
    prev_heads = [None for _ in range(4)]
    # start rollout
    while not done:
        # record each living goose's head from the previous observation
        # (on the first step prev_obs is the current observation; the
        # previous-head channel is not filled on step 0 anyway)
        for i, g in enumerate(observation['geese']):
            if len(g) > 0:
                prev_heads[i] = prev_obs['geese'][i][0]
        prev_obs = observation
        # transform observation to state
        state = get_features(observation, env.configuration, prev_heads)
        # make a move
        action, logp, v = player.raw_outputs(state)
        # observe
        observation, reward, done, _ = trainer.step(['NORTH', 'EAST', 'SOUTH', 'WEST'][action])
        # data -> buffers
        buffers['states'].append(state)
        buffers['actions'].append(action)
        buffers['log-p'].append(logp.cpu().detach())
        # save rewards and values (one entry per reward stream / critic head)
        r = get_rewards(reward, observation, prev_obs, done)
        for i in range(2):
            rewards[str(i+1)].append(r[i])
            values[str(i+1)].append(v[i])
    # save advantages and returns (computed once the whole episode is known)
    for key in ['1', '2']:
        advs, rets = get_advantages_and_returns(rewards[key], values[key], gammas[key], lambdas[key])
        # add them to buffer
        buffers['adv-' + key] += advs
        buffers['ret-' + key] += rets

In [None]:
def runner(net, env, samples_threshold, gammas, lambdas, progress_bar=False):
    """Play self-play games until at least ``samples_threshold`` steps are collected.

    A stochastic copy of ``net`` plays against three deterministic copies. At
    least one game is always played. Returns the filled data buffers, a dict
    of lists keyed by 'states', 'actions', 'log-p', 'adv-1', 'ret-1',
    'adv-2' and 'ret-2'.
    """
    data_buffers = {key: [] for key in
                    ('states', 'actions', 'log-p', 'adv-1', 'ret-1', 'adv-2', 'ret-2')}
    collected = 0
    bar = tqdm(total=samples_threshold, desc='Collecting Samples', leave=False) if progress_bar else None
    learner = RLAgent(net, stochastic=True)
    opponents = [RLAgent(net, stochastic=False) for _ in range(3)]
    while True:
        # a fresh list each game, since rollout shuffles it in place
        rollout(learner, env, players=[None] + opponents, buffers=data_buffers,
                gammas=gammas, lambdas=lambdas)
        total = len(data_buffers['states'])
        if bar is not None:
            bar.update(total - collected)
        collected = total
        if collected >= samples_threshold:
            break
    if bar is not None:
        bar.close()
    return data_buffers

In [None]:
def train(net, optimizer,
          n_episodes=25,
          batch_size=256,
          samples_threshold=10000,
          n_ppo_epochs=25):
    """Alternate evaluation, self-play data collection and PPO updates.

    Each episode:
      1. evaluates the deterministic agent over 30 games against 3 'greedy' agents;
      2. collects at least ``samples_threshold`` self-play steps with ``runner``;
      3. runs ``n_ppo_epochs`` passes of augmented mini-batches through
         ``compute_losses`` and the optimizer.

    Returns:
        dict with lists 'Score' (mean reward of our agent) and 'Rank' (mean
        place, 1 = best), one entry per episode.

    NOTE(review): ``losses_hist`` is filled but never returned, and its 'lr'
    list is never written; it is only useful when inspected in a debugger.
    """
    losses_hist = {'clip': [], 'value-1': [], 'value-2': [], 'ent': [], 'lr': []}
    win_rates = {'Score': [], 'Rank': []}
    gammas  = {'1': 0.8, '2': 0.8}
    lambdas = {'1': 0.7, '2': 0.7}
    print('-Start Training')
    for episode in tqdm(range(n_episodes), desc='Episode', leave=False):
        # update statistics (network in eval mode for evaluation and data collection)
        net.eval()
        player = RLAgent(net, stochastic=False)
        scores = evaluate("hungry_geese", [player] + ['greedy'] * 3, num_episodes=30)
        win_rates['Score'].append(np.mean([r[0] for r in scores]))
        # rank = number of agents whose reward is >= ours (including ourselves)
        win_rates['Rank'].append(np.mean([sum(r[0] <= r_ for r_ in r if r_ is not None) for r in scores]))
        # collect data
        buffers = runner(net, env, samples_threshold, gammas=gammas, lambdas=lambdas)
        # perform training
        net.train()
        dataset = GeeseDataset(buffers)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        for epoch in range(n_ppo_epochs):
            for batch in dataloader:
                losses = compute_losses(net, augment(batch),
                                        c1=1, c2=1, c_ent=0.01)
                # NOTE: `loss` aliases losses['actor'], so the `+=` below updates that
                # tensor in place; its value is logged before any addition.
                loss = losses['actor']
                losses_hist['clip'].append(losses['actor'].item())
                for i in range(2):
                    loss += losses['critic'][i]
                    losses_hist[f'value-{i+1}'].append(losses['critic'][i].item())
                loss += losses['entropy']
                losses_hist['ent'].append(losses['entropy'].item())
                loss.backward()
                nn.utils.clip_grad_norm_(net.parameters(), max_norm=1)
                optimizer.step()
                optimizer.zero_grad()
    return win_rates

In [None]:
# Build G-Net on the GPU, create its Adam optimizer and run the full training loop.
# `win_rates` holds the per-episode evaluation metrics returned by `train`.
net = GNet().cuda()
optimizer = Adam(net.parameters(), lr=1e-5)

win_rates = train(net, optimizer)

## Save and show <a name="save_and_show"></a>

In [None]:
import os

# Create the checkpoint folder (no error if it already exists, unlike `!mkdir`)
# and save the trained weights.
os.makedirs('checkpoint', exist_ok=True)
torch.save(net.state_dict(), 'checkpoint/g.net')

In [None]:
# Plot the evaluation metrics recorded once per training episode: mean score
# (left axis) and mean rank (right axis) against 3 greedy agents.
t = np.arange(len(win_rates['Score']))
fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('Episode')  # one evaluation point per training episode, not env timesteps
ax1.set_ylabel('Score', color=color)
ax1.plot(t, win_rates['Score'], color=color)
ax1.tick_params(axis='y', labelcolor=color)
ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

color = 'tab:blue'
ax2.set_ylabel('Rank', color=color)  # we already handled the x-label with ax1
ax2.plot(t, win_rates['Rank'], color=color)
ax2.tick_params(axis='y', labelcolor=color)
fig.tight_layout()  # otherwise the right y-label is slightly clipped
plt.title('Performance of G-net vs 3 Greedy Agents')
plt.show()

In [None]:
# Play and render one game in which all four geese are controlled by the
# trained network with deterministic (arg-max) actions.
env.reset()
net.eval()
env.run([RLAgent(net, stochastic=False),
         RLAgent(net, stochastic=False),
         RLAgent(net, stochastic=False),
         RLAgent(net, stochastic=False)])
env.render(mode="ipython", width=500, height=400)