# Project: Beat Flappy Bird

You may be familiar with the game [Flappy Bird](https://flappybird.io/). It is very simple: a bird moves at constant speed along the x axis and, at each step, you can either push it up or let it fall. The goal of the game is to go as far as possible.

The goal of this project is as follows: design and train an agent that achieves the best possible score at Flappy Bird!

## Members

- Souhaiel BEN SALEM
- Charbel ABI HANA
- Adrian GARNIER ARTIÑANO
- Israfel SALAZAR REYES

In [None]:
#@title Installations  { form-width: "30%" }

# This is just for the purpose of this colab. Please do not share a ssh
# private key in real life, it is a really unsafe practice.
GITHUB_PRIVATE_KEY = """-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
QyNTUxOQAAACD5ow+qHLZLVosHfeGcGeJKQgwUlPYgoFliCEsshiFhXwAAALCn99V2p/fV
dgAAAAtzc2gtZWQyNTUxOQAAACD5ow+qHLZLVosHfeGcGeJKQgwUlPYgoFliCEsshiFhXw
AAAECJ+OOLQqiwINexx26mmQt6FL5xXYHRf9Jv2UzahlW0avmjD6octktWiwd94ZwZ4kpC
DBSU9iCgWWIISyyGIWFfAAAAKm1yaXZpZXJlQG1yaXZpZXJlLW1hY2Jvb2twcm8ucm9hbS
5pbnRlcm5hbAECAw==
-----END OPENSSH PRIVATE KEY-----
"""

# Create the directory if it doesn't exist.
! mkdir -p /root/.ssh
# Write the key
with open("/root/.ssh/id_ed25519", "w") as f:
  f.write(GITHUB_PRIVATE_KEY)
# Add github.com to our known hosts
! ssh-keyscan -t ed25519 github.com >> ~/.ssh/known_hosts
# Restrict the key permissions, or else SSH will complain.
! chmod go-rwx /root/.ssh/id_ed25519

# Clone and install the RL Games repository
! if [ -d "rl_games" ]; then echo "rl_games directory exists."; else git clone git@github.com:Molugan/rl_games.git; fi
! cd rl_games ; git pull;  pip install .

# Other dependencies
# If you just want to play with the environment and do not intend to use either
# jax or haiku, you can comment out this part.
!pip install dm-acme[jax]
!pip install dm-acme[tf]
!pip install dm-haiku
!pip install chex
!pip install optax
!pip install jax[cuda]

## The environment

We will use the Flappy Bird environment defined in the deep_rl package. Let's have a closer look at it.


In [None]:
from deep_rl.environments.flappy_bird import FlappyBird

env = FlappyBird(
        gravity=0.05,
        force_push=0.1,
        vx=0.05,
        prob_new_bar=1,
        invictus_mode=False,
        max_height_bar=0.5,
    )

print(env.help)

For example let's interact with it a little bit.

In [None]:
rows, cols = env.min_res
print(f"We should use at least {rows} rows and {cols} columns when rendering the environment")

obs_reset = env.reset()
print("First observation when resetting the environment:")
print(obs_reset)
print()

print("Now, let's perform a few steps\n")

print("Step 1: we let the bird fall")
obs, reward, done = env.step(0)
print(f"Observation: {obs}")
print(f"Reward: {reward}")
print(f"Game over: {done}")
print()

print("Step 2: we push the bird up")
obs, reward, done = env.step(1)
print(f"Observation: {obs}")
print(f"Reward: {reward}")
print(f"Game over: {done}")
print()

print("Step 3: we push the bird up again")
obs, reward, done = env.step(1)
print(f"Observation: {obs}")
print(f"Reward: {reward}")
print(f"Game over: {done}")
print()

print("Step 4: we push the bird up again")
obs, reward, done = env.step(1)
print(f"Observation: {obs}")
print(f"Reward: {reward}")
print(f"Game over: {done}")
print()

To simplify typing a bit, the deep_rl package implements a new type `FlappyObs` which corresponds to a state of the flappy bird environment.

In [None]:
from typing import List, Tuple

BarObs = Tuple[float, float, float, bool]
BirdObs = Tuple[float, float, float]
FlappyObs = Tuple[BirdObs, List[BarObs]]

## Baseline

We provide you with a simple baseline: the `StableAgent` which does nothing more than keeping the bird stable.

In [None]:
from deep_rl.environments.flappy_bird import FlappyObs


class StableAgent:
    """An agent which just keeps the bird stable."""

    def __init__(self, target_y: float = 0.5):
        self._target_y = target_y

    def sample_action(
        self,
        observation: FlappyObs,
        evaluation: bool,
    ) -> int:
        _, y_bird, v_y_bird = observation[0]

        if y_bird <= self._target_y and v_y_bird <= 0:
            return 1
        else:
            return 0

Let's see how a single run works in practice with this agent.

In [None]:
from IPython.display import clear_output
from deep_rl.terminal_renderer import BashRenderer
from deep_rl.episode_runner import run_episode
from deep_rl.project_values import PROJECT_FLAPPY_BIRD_ENV

# We are going to render the environment !
ROWS = 30
COLS = 60
# Because ipython sucks, I have not found a cleaner option to add
# the refresher function
renderer = BashRenderer(ROWS,
                        COLS,
                        clear_fn = lambda: clear_output(wait=True))

# Flappy bird environment
env = PROJECT_FLAPPY_BIRD_ENV

# Our agent
agent = StableAgent()

# We run a single episode, with rendering, over a maximum of 100 steps
run_episode(env,
            agent,
            max_steps=100,
            renderer = renderer,
            time_between_frame=0.1)

Without rendering now, let's see the average reward we can get over 100 episodes with this agent.

In [None]:
from deep_rl.project_values import PROJECT_FLAPPY_BIRD_ENV
from deep_rl.episode_runner import run_episode

# Flappy bird environment
env = PROJECT_FLAPPY_BIRD_ENV

# Our agent
agent = StableAgent()

N_EPISODES = 100

reward = 0
for _ in range(N_EPISODES):
    reward += run_episode(env, agent, max_steps=1000, renderer=None)

reward /= N_EPISODES

print(f"Average reward over {N_EPISODES} episodes: {reward}")

And now, we need to do much better.

## Let's get to work !

We will attempt to design and train an agent that achieves the best possible score on Flappy Bird. Here are the constraints:
- if you choose a deep learning algorithm, you must use JAX and Haiku. PyTorch is not allowed for this project.
- our agent should converge in less than an hour.
- our agent must maximize the reward obtained over 100 episodes with a maximum of 1000 steps per episode.

On top of that, we will plot and analyse the relevant curves showing the evolution of our training loop.

### Agent's API

Our agent will implement a method, `sample_action`, which takes two arguments as input, the observed state and whether or not the agent is in evaluation mode, and picks the action to perform. Apart from that, we can add any other method we want to our model.
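
As a minimal sketch of this interface, the hypothetical `RandomAgent` below (it is not part of the deep_rl package) implements `sample_action` with the same signature as the `StableAgent` baseline. We assume two actions, 0 (let the bird fall) and 1 (push it up), as in the interaction example above; the observation type is deliberately left loose.

```python
import random
from typing import Any


class RandomAgent:
    """Hypothetical agent that ignores the observation (illustration only)."""

    N_ACTIONS = 2  # assumed: 0 lets the bird fall, 1 pushes it up

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)

    def sample_action(self, observation: Any, evaluation: bool) -> int:
        # A real agent would use `observation`, and would typically explore
        # only when `evaluation` is False.
        return self._rng.randrange(self.N_ACTIONS)


agent = RandomAgent(seed=0)
actions = [agent.sample_action(observation=None, evaluation=True) for _ in range(5)]
print(actions)
```

Any object exposing this method can be passed to the episode runner used with the baseline.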

In this project, the method that proved able to solve this environment is DQN (Deep Q-Network). We start from the DQN implementation developed in our deep RL class and extend it. The main additions to the base DQN model implemented in class are:
- A scheduler for the epsilon-greedy policy, where $\epsilon$ decays linearly with the number of episodes and is clipped at $\epsilon_{final}$: $$
\epsilon = \max\left(\epsilon_{final}, \: \epsilon_{start} - \frac{episode_{i}}{\epsilon_{decayLastFrame}}\right)$$
- A scheduler for the target network parameter updates: the target parameters are updated only during episodes whose index is a multiple of $sync_{target}$.
- A pre-processing stage on the observations generated by the environment, in order to provide a rich representation of the state.
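
The epsilon schedule listed above can be sketched in a few lines of plain Python. The helper name `linear_epsilon` is ours; the formula mirrors the one used in `sample_action` of the `DQN` class, and the numbers are those of the training configuration of this notebook ($\epsilon_{start} = 0.99$, $\epsilon_{final} = 0.01$, decay over 1200 episodes).

```python
def linear_epsilon(episode: int, start: float, final: float, decay_last_frame: int) -> float:
    """Linearly decayed epsilon, clipped from below at `final`."""
    return max(final, start - episode / decay_last_frame)


for episode in (0, 300, 600, 1200, 2000):
    eps = linear_epsilon(episode, start=0.99, final=0.01, decay_last_frame=1200)
    print(f"episode {episode:4d}: epsilon = {eps:.2f}")
```

With these values the floor $\epsilon_{final}$ is reached after $(0.99 - 0.01) \times 1200 = 1176$ episodes, slightly before the last training episode.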

In addition to the base DQN, we train the following:
- Double DQN
- Dueling DQN
- Double Dueling DQN

We obtain the best results with the **Double Dueling DQN** (check below for training and evaluation).

In [None]:
import dataclasses
import time

import jax
import chex
import optax

import numpy as onp
import jax.numpy as jnp
import haiku as hk
import matplotlib.pyplot as plt
import seaborn as sns

from dataclasses import dataclass, is_dataclass
from typing import List, Tuple
from operator import itemgetter

from deep_rl.project_values import PROJECT_FLAPPY_BIRD_ENV
from IPython.display import clear_output
from deep_rl.terminal_renderer import BashRenderer
from deep_rl.episode_runner import run_episode


sns.set_style('darkgrid')

In [None]:
@chex.dataclass
class Transition:
    state_t: chex.Array
    action_t: chex.Array
    reward_t: chex.Array
    done_t: chex.Array
    state_tp1: chex.Array

@dataclass
class EpisodeTrainingStatus:
    episode_number: int
    reward: float
    training_time: float

@chex.dataclass
class LearnerState:
    online_params: hk.Params
    target_params: hk.Params
    opt_state: optax.OptState

@dataclass
class Epsilon:
    epsilon_decay_last_frame: int
    epsilon_start: float
    epsilon_final: float

In [None]:
class ReplayBuffer:
    """Fixed-size buffer to store transition tuples."""

    def __init__(self, buffer_capacity: int):
        """Initialize a ReplayBuffer object.
        Args:
            buffer_capacity (int): maximal number of tuples to store at once
        """
        self._memory = list()
        self._maxlen = buffer_capacity

    @property
    def size(self) -> int:
        # Return the current number of elements in the buffer.
        return len(self._memory)

    def add(
        self,
        state_t: chex.Array,
        action_t: chex.Array,
        reward_t: chex.Array,
        done_t: chex.Array,
        state_tp1: chex.Array,
    ) -> None:
        """Add a new transition to memory."""

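        # Evict the oldest transition once the buffer is full, so that it never
        # holds more than `buffer_capacity` elements.
        if self.size >= self._maxlen:
            self._memory.pop(0)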

        self._memory.append(
            Transition(
                state_t=state_t,
                action_t=action_t,
                reward_t=reward_t,
                done_t=done_t,
                state_tp1=state_tp1,
            )
        )

    def sample(self) -> Transition:
        """Randomly sample a transition from memory."""
        assert self._memory, "replay buffer is unfilled"
        index = onp.random.randint(self.size)
        return self._memory[index]

class BatchedReplayBuffer(ReplayBuffer):
    def sample_batch(self, batch_size) -> Transition:
        """Randomly sample a batch of experiences from memory."""
        assert (
            len(self._memory) >= batch_size
        ), "Insufficient number of transitions in replay buffer"

        samples = [self.sample() for _ in range(batch_size)]
        kwargs = dict()
        for attr in ["state_t", "action_t", "reward_t", "done_t", "state_tp1"]:
            kwargs[attr] = onp.array([getattr(s, attr) for s in samples])
        return Transition(**kwargs)
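
As a side note, the first-in-first-out eviction of the buffer above can also be delegated to `collections.deque` with a `maxlen`. The class below is a hypothetical alternative sketch, not the buffer used in our experiments; for brevity it stores plain tuples with scalar states instead of `Transition` objects.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

# (state_t, action_t, reward_t, done_t, state_tp1), with scalar states for brevity
TransitionTuple = Tuple[float, int, float, bool, float]


class DequeReplayBuffer:
    """Hypothetical FIFO buffer: the oldest transition is dropped when full."""

    def __init__(self, buffer_capacity: int, seed: int = 0) -> None:
        self._memory: Deque[TransitionTuple] = deque(maxlen=buffer_capacity)
        self._rng = random.Random(seed)

    @property
    def size(self) -> int:
        return len(self._memory)

    def add(self, transition: TransitionTuple) -> None:
        # A deque with `maxlen` evicts its oldest element by itself.
        self._memory.append(transition)

    def sample_batch(self, batch_size: int) -> List[TransitionTuple]:
        # Uniform sampling with replacement, like repeated calls to `sample()`.
        assert self.size >= batch_size, "Insufficient number of transitions"
        return [self._rng.choice(self._memory) for _ in range(batch_size)]


buffer = DequeReplayBuffer(buffer_capacity=3)
for t in range(5):
    buffer.add((float(t), 0, 1.0, False, float(t + 1)))

batch = buffer.sample_batch(3)
print(buffer.size, sorted({transition[0] for transition in batch}))
```

Only the three most recent transitions (states 2, 3 and 4) can be sampled once five have been added.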

### Base DQN
The basic DQN model is explained in detail in our project report. In short, it requires four components: a function approximator for the $Q$-values (our neural network), a replay buffer with uniform sampling, a target network updated according to the scheduler mentioned previously, and an epsilon-greedy strategy (with its own scheduler) for sampling actions. The loss for a transition $(s, a, r, s')$ is:
$$
\mathcal{L} = (Q(s, a) - y)^2, \quad \text{where} \quad
y = \begin{cases} r & \text{if the episode has ended} \\ r + \gamma\max_{a'\in A}{\hat{Q}(s', a')} & \text{otherwise} \end{cases}
$$
Here $\hat{Q}$ denotes the target network.
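
To make the target $y$ concrete, here is a small sketch in plain Python for a single transition. The function names and the numerical values are hypothetical; only the discount $\gamma = 0.8$ is taken from the training configurations of this notebook.

```python
from typing import List


def td_target(reward: float, done: bool, next_q_values: List[float], gamma: float) -> float:
    """y = r if the episode ended, else r + gamma * max_a' Q_target(s', a')."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


def squared_td_error(q_sa: float, target: float) -> float:
    """Loss of a single transition."""
    return (q_sa - target) ** 2


gamma = 0.8
next_q_values = [0.5, 2.0]  # target-network Q-values of the next state (made up)

y_running = td_target(reward=1.0, done=False, next_q_values=next_q_values, gamma=gamma)
y_terminal = td_target(reward=-1.0, done=True, next_q_values=next_q_values, gamma=gamma)
print(round(y_running, 3), y_terminal, round(squared_td_error(2.0, y_running), 3))
```

The batched implementation in `loss_fn` reaches the same result by multiplying the next-state $Q$-values by `1 - done` before taking the maximum.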

We use a simple convolutional neural network: 2 `Conv2D` layers extract features from the input, and an `MLP` head (one hidden layer of 64 units followed by the output layer) produces the $Q$-value of each action in the action space.
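
The spatial sizes inside this network can be checked by hand. The sketch below assumes the usual rule for an unpadded (`"VALID"`) convolution, $\lfloor (n - k) / s \rfloor + 1$ along each axis, and an input of shape $(1 + FOV, 4)$ with $FOV = 4$ as described in the next subsection; the helper name is ours.

```python
def valid_conv_out(size: int, kernel: int, stride: int) -> int:
    """Output length of a VALID (unpadded) convolution along one axis."""
    return (size - kernel) // stride + 1


fov = 4
height, width = 1 + fov, 4  # one bird row plus `fov` bar rows, 4 features each

shapes = []
for channels in (32, 64):  # the two Conv2D layers, kernel 2x2 and stride 2
    height = valid_conv_out(height, kernel=2, stride=2)
    width = valid_conv_out(width, kernel=2, stride=2)
    shapes.append((height, width, channels))

flat_features = height * width * 64
print(shapes, flat_features)
```

If this arithmetic holds, the flattened feature vector has 64 entries. Note also that with 5 input rows and a stride of 2, the first layer only covers rows 0 to 3, so the last bar row would fall outside every kernel window.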

In [None]:
def dqn_flappy_network(x: chex.Array, n_actions: int):
    x = x[..., None]
    model = hk.Sequential(
    [
        hk.Conv2D(32, kernel_shape=[2, 2], stride=2, padding="VALID"),
        jax.nn.relu,
        hk.Conv2D(64, kernel_shape=[2, 2], stride=2, padding="VALID"),
        jax.nn.relu,
        hk.Flatten(),
        hk.nets.MLP([64, n_actions])
    ])
    
    return model(x)

#### Input Observation Transformation

We transform our observation from our environment to obtain the shape $(1+FOV, 4)$ where $FOV$ is a hyperparameter which essentially represents the number of bars visible by our agent. We add $1$ to the $FOV$ because we have one row for the agent state:
$$
[x, y, v_x, v_y]
$$
In addition, each bar has the following features:
$$
[presence, h_{distance}, v_{distance}, top]
$$
Where:
- $presence$: `bool` $(1/0)$, denotes if a bar is present or "seen" by the agent.
- $h_{distance}$: `float` $[-0.5, 0.5]$, denotes the horizontal distance measured relative to the agent.
- $v_{distance}$: `float` $[0, 0.5]$, denotes the vertical distance measured relative to the agent.
- $top$: `bool` $(1/0)$, denotes whether the seen bar is in the top or bottom position.

The input state is always padded to obtain the mentioned shape to ensure a constant input to the network.
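
The layout of the state can be illustrated with plain Python lists. This is a simplified sketch of the transformation, under the assumption that a bar is `(x_min, x_max, height, top)` and the bird is `(x, y, v_y)`; it leaves out the filtering and sorting of bars, and the function name and numbers are hypothetical.

```python
from typing import List, Tuple

BarObs = Tuple[float, float, float, bool]  # assumed: (x_min, x_max, height, top)
BirdObs = Tuple[float, float, float]       # assumed: (x, y, v_y)


def build_state_sketch(
    bird: BirdObs, bars: List[BarObs], fov: int, v_x: float = 0.05
) -> List[List[float]]:
    """Returns 1 + fov rows of 4 features: the bird row, then one row per bar."""
    x, y, v_y = bird
    rows = [[x, y, v_x, v_y]]
    for x_min, _x_max, height, top in bars[:fov]:
        # Vertical position of the free edge of the bar (top bars hang down).
        edge_y = (1.0 - height) if top else height
        rows.append([1.0, x_min - x, edge_y - y, float(top)])
    # Pad with all-zero rows (presence = 0) up to the constant shape.
    while len(rows) < 1 + fov:
        rows.append([0.0, 0.0, 0.0, 0.0])
    return rows


state = build_state_sketch(
    bird=(0.25, 0.5, -0.125),
    bars=[(0.75, 0.875, 0.25, False)],
    fov=2,
)
for row in state:
    print(row)
```

Here the single visible bar yields the row `[1.0, 0.5, -0.25, 0.0]`, and the remaining row is zero padding.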

In [None]:
class DQN:
    def __init__(
        self,
        env: FlappyBird,
        net: hk.Module,
        epsilon: Epsilon,
        gamma: float,
        learning_rate: float,
        buffer_capacity: int,
        min_buffer_capacity: int,
        batch_size: int,
        target_ema: float,
        seed: int = 0,
        fov:int = 4,
        sync_target: int = 100
    ) -> None:
        """Initializes the DQN agent.

        Args:
          env: Flappy Bird environment.
          net: function building the Q-network (called inside hk.transform).
          epsilon: parameters of the epsilon-greedy exploration schedule.
          gamma: discount factor
          learning_rate: learning rate of the online network
          buffer_capacity: capacity of the replay buffer
          min_buffer_capacity: min buffer size before picking batches from the
            replay buffer to update the online network
          batch_size: batch size when updating the online network
          target_ema: weight when updating the target network.
          seed: seed of the random generator.
          fov: number of bars kept in the pre-processed state.
          sync_target: the target network is updated only during episodes whose
            index is a multiple of this value.
        """
        self._env = env
        self._learning_rate = learning_rate
        self._gamma = gamma
        self._batch_size = batch_size
        self._target_ema = target_ema
        self._Na = env.N_ACTIONS
        self.fov = fov
        self.net = net
        self.epsilon = epsilon
        
        # Define the neural network for this agent
        self._init, self._apply = hk.without_apply_rng(hk.transform(self._hk_qfunction))
        # Jit the forward pass of the neural network for better performances
        self.apply = jax.jit(self._apply)

        # Also jit the update function
        #self._update_fn = jax.jit(self._update_fn)
        # Initialize the network's parameters
        self._rng = jax.random.PRNGKey(seed)
        self._rng, init_rng = jax.random.split(self._rng)
        self._learner_state = self._init_state(init_rng)

        # Initialize the replay buffer
        self._min_buffer_capacity = min_buffer_capacity
        self._buffer = BatchedReplayBuffer(buffer_capacity)

        # Build a variable to store the last state observed by the agent
        self._state = None
        
        # Keep number of episodes stored for timed target updates and epsilon updates
        self.episode = 0        
        self.sync_target = sync_target
        
    def _optimizer(self) -> optax.GradientTransformation:
        return optax.adam(learning_rate=self._learning_rate)

    def _hk_qfunction(self, state: chex.Array) -> chex.Array:
        s = state
        return self.net(s, self._Na)

    def _init_state(self, rng: chex.PRNGKey) -> LearnerState:
        """Initialize the online parameters, the target parameters and the
        optimizer's state."""
        dummy_step = pre_process_obs(self._env.reset(), self.fov)[None]

        online_params = self._init(rng, dummy_step)
        target_params = online_params
        opt_state = self._optimizer().init(online_params)

        return LearnerState(
            online_params=online_params,
            target_params=target_params,
            opt_state=opt_state,
        )

    def _update_fn(
        self,
        state: LearnerState,
        batch: Transition,
    ) -> Tuple[chex.Array, LearnerState]:
        """Get the next learner state given the current batch of transitions.

        Args:
          state: learner state before update.
          batch: batch of experiences (st, at, rt, done_t, stp1)
        Returns:
          loss, learner state after update
        """
        # Compute gradients
        loss, gradients = jax.value_and_grad(self.loss_fn)(
            state.online_params,
            state.target_params,
            batch.state_t,
            batch.action_t,
            batch.reward_t,
            batch.done_t,
            batch.state_tp1,
        )

        # Apply gradients
        updates, new_opt_state = self._optimizer().update(gradients, state.opt_state)
        new_online_params = optax.apply_updates(state.online_params, updates)

        # Update target network params (only during episodes whose index is a
        # multiple of sync_target) as:
        # target_params <- ema * target_params + (1 - ema) * online_params
        if self.episode % self.sync_target == 0:
            new_target_params = jax.tree_util.tree_map(
                lambda x, y: x + (1 - self._target_ema) * (y - x),
                state.target_params,
                new_online_params,
            )
        else:
            new_target_params = state.target_params
            
        return loss, LearnerState(
            online_params=new_online_params,
            target_params=new_target_params,
            opt_state=new_opt_state,
        )

    def loss_fn(
        self,
        online_params: hk.Params,
        target_params: hk.Params,
        state_t: chex.Array,
        action_t: chex.Array,
        reward_t: chex.Array,
        done_t: chex.Array,
        state_tp1: chex.Array,
    ) -> chex.Array:
        """Computes the Q-learning loss

        Args:
          online_params: parameters of the online network
          target_params: parameters of the target network
          state_t: batch of observations at time t
          action_t: batch of actions performed at time t
          reward_t: batch of rewards obtained at time t
          done_t: batch of end of episode status at time t
          state_tp1: batch of states at time t+1
        Returns:
          The Q-learning loss.
        """
        # Step one: compute the target Q-value for state t+1
        q_tp1 = self._apply(target_params, state_tp1)

        # We do not want to consider the Q-value of states that are done !
        # For these states, q(t+1) = 0
        q_tp1 = (1.0 - done_t[..., None]) * q_tp1

        # Now deduce the value of the target cumulative reward
        y_t = reward_t + self._gamma * jnp.max(q_tp1, axis=1)  # Shape B

        # Compute the online Q-value for state t
        q_t = self._apply(online_params, state_t)  # Shape B , Na

        # Ok, but we only want the Q value for the actions that have actually
        # been played
        q_at = jax.vmap(lambda idx, q: q[idx])(action_t, q_t)

        # Compute the square error
        error = (q_at - y_t) ** 2

        # Deduce the loss
        return jnp.mean(error)

    def sample_action(
        self,
        state,
        evaluation: bool
    ) -> int:
        """Picks the next action using an epsilon greedy policy.

        Args:
          state: observed state.
          evaluation: if True the agent is acting in evaluation mode (which means
            it only acts according to the best policy it knows.)
        """
        # Linearly decayed epsilon, clipped from below at epsilon_final.
        epsilon = max(self.epsilon.epsilon_final, self.epsilon.epsilon_start -
                      self.episode / self.epsilon.epsilon_decay_last_frame)

        if not evaluation and onp.random.uniform() < epsilon:
            return onp.random.randint(self._Na)

        # Raw environment observations are tuples: convert them to the array
        # representation expected by the network. The greedy action is returned
        # in both cases (pre-processed or not).
        if isinstance(state, tuple):
            state = pre_process_obs(state, self.fov)
        q_values = self._apply(self._learner_state.online_params, state[None])
        return int(onp.argmax(q_values))

    def observe(
        self,
        action_t: chex.Array,
        reward_t: chex.Array,
        done_t: chex.Array,
        state_tp1: chex.Array,
    ) -> None:
        
        if isinstance(state_tp1, tuple):
            state_tp1 = pre_process_obs(state_tp1, self.fov)
        self._buffer.add(self._state, action_t, reward_t, done_t, state_tp1)
        self._state = state_tp1

        # We update the agent if and only if we have enough transitions stored
        # in memory.
        if self._buffer.size >= self._min_buffer_capacity:
            batch = self._buffer.sample_batch(self._batch_size)
            loss, self._learner_state = self._update_fn(self._learner_state, batch)
            return loss
        return 0.0
    
    def first_observe(self, state: chex.Array) -> None:
        self._state = pre_process_obs(state, self.fov)
        self.episode += 1
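
The target-network update performed in `_update_fn` can be illustrated on a toy parameter dictionary. This is a sketch in plain Python, with a `dict` of floats standing in for the Haiku parameter tree; the function names and the values of the parameters are hypothetical.

```python
from typing import Dict

Params = Dict[str, float]  # stand-in for a Haiku parameter tree


def soft_update(target: Params, online: Params, ema: float) -> Params:
    """target <- ema * target + (1 - ema) * online, parameter by parameter."""
    return {name: ema * target[name] + (1.0 - ema) * online[name] for name in target}


def maybe_sync(target: Params, online: Params, episode: int, sync_target: int, ema: float) -> Params:
    """Applies the soft update only when `episode` is a multiple of `sync_target`."""
    if episode % sync_target == 0:
        return soft_update(target, online, ema)
    return target


target = {"w": 0.0}
online = {"w": 1.0}

synced = maybe_sync(target, online, episode=100, sync_target=100, ema=0.9)
skipped = maybe_sync(target, online, episode=101, sync_target=100, ema=0.9)
print(round(synced["w"], 3), skipped["w"])
```

With `sync_target=1`, as in the training configurations below, the condition is always true and the target network is softly updated at every training step.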
        

### Double DQN

In the Double DQN extension, the researchers from DeepMind showed in [Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/pdf/1509.06461.pdf) that the base DQN tends to overestimate the $Q$-values, which harms training and can result in a suboptimal policy. The overestimation comes from the max operation in the Bellman target:
$$
Q(s_t, a_t) = r_t + \gamma\max_{a}{Q'(s_{t+1}, a)}
$$
where $Q'(s_{t+1}, a)$ are the $Q$-values computed by the target network. The authors propose choosing the action for the next state with the trained (online) network, while taking its $Q$-value from the target network. The new expression for the target $Q$-values is:
$$
Q(s_t, a_t) = r_t + \gamma Q'(s_{t+1}, \arg\max_{a}Q(s_{t+1}, a))
$$
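
The difference between the two targets is easy to see on a single transition. In the sketch below (plain Python, hypothetical names and made-up $Q$-values, with $\gamma = 0.5$ chosen only to keep the arithmetic exact), the online and target networks disagree on the best next action.

```python
from typing import List


def argmax(values: List[float]) -> int:
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


def dqn_target(reward: float, done: bool, q_target_next: List[float], gamma: float) -> float:
    """Vanilla DQN: the target network both selects and evaluates the action."""
    return reward if done else reward + gamma * max(q_target_next)


def double_dqn_target(
    reward: float,
    done: bool,
    q_online_next: List[float],
    q_target_next: List[float],
    gamma: float,
) -> float:
    """Double DQN: the online network selects, the target network evaluates."""
    if done:
        return reward
    best_action = argmax(q_online_next)
    return reward + gamma * q_target_next[best_action]


q_online_next = [1.0, 0.5]   # the online network prefers action 0
q_target_next = [0.25, 2.0]  # the target network prefers action 1

y_dqn = dqn_target(0.0, False, q_target_next, gamma=0.5)
y_double = double_dqn_target(0.0, False, q_online_next, q_target_next, gamma=0.5)
print(y_dqn, y_double)
```

The vanilla target uses the largest target value (giving 1.0), whereas the double target evaluates the action picked by the online network (giving 0.125), which is how the overestimation is reduced.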

The only change is in the loss function, which overrides the one of the base DQN as shown below.

In [None]:
class DoubleDQN(DQN):
    def __init__(self,
        env: FlappyBird,
        net: hk.Module,
        epsilon: Epsilon,
        gamma: float,
        learning_rate: float,
        buffer_capacity: int,
        min_buffer_capacity: int,
        batch_size: int,
        target_ema: float,
        seed: int = 0,
        fov:int = 4,
        sync_target: int = 100):
        
        super(DoubleDQN, self).__init__(
        env,
        net,
        epsilon,
        gamma,
        learning_rate,
        buffer_capacity,
        min_buffer_capacity,
        batch_size,
        target_ema,
        seed,
        fov,
        sync_target)
    
    def loss_fn(
        self,
        online_params: hk.Params,
        target_params: hk.Params,
        state_t: chex.Array,
        action_t: chex.Array,
        reward_t: chex.Array,
        done_t: chex.Array,
        state_tp1: chex.Array,
    ) -> chex.Array:

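        # Select the next action with the online network...
        a_tp1 = jnp.argmax(jax.lax.stop_gradient(self._apply(online_params, state_tp1)), axis=-1)
        # ...and evaluate it with the target network.
        q_tp1 = self._apply(target_params, state_tp1)[jnp.arange(a_tp1.shape[0]), a_tp1]  # Shape B

        # We do not want to consider the Q-value of states that are done !
        # For these states, q(t+1) = 0. Both arrays have shape B here, so no
        # extra axis is added: done_t[..., None] would broadcast to shape (B, B).
        q_tp1 = (1.0 - done_t) * q_tp1

        # Now deduce the value of the target cumulative reward
        y_t = reward_t + self._gamma * q_tp1  # Shape B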
        
        
        
        # Compute the online Q-value for state t
        q_t = self._apply(online_params, state_t)  # Shape B , Na

        # Ok, but we only want the Q value for the actions that have actually
        # been played
        q_at = jax.vmap(lambda idx, q: q[idx])(action_t, q_t)

        # Compute the square error
        error = (q_at - y_t) ** 2

        # Deduce the loss
        return jnp.mean(error)  

### Dueling DQN

For the Dueling DQN extension, we define an architecture that separates the value of a state, $V(s)$, from the advantage of each state-action pair, $A(s, a)$; the network still estimates $Q(s, a) = V(s) + A(s, a)$. The convolutional features of the input are processed by two parallel heads, one predicting $V(s)$ and the other $A(s, a)$, and the two outputs are then added. More specifically, to ensure correct and stable learning, we subtract the mean advantage over the $N$ actions: $Q(s, a) = V(s) + A(s, a) - \frac{1}{N}\sum_{k}{A(s, k)}$. All the other parts of the DQN training process are kept the same. [Source](https://arxiv.org/pdf/1511.06581.pdf).
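
The aggregation formula can be checked on a single state with plain Python. The function name and the numbers are hypothetical; the point of the sketch is that subtracting the mean advantage makes the output insensitive to a constant shift of the advantage head.

```python
from typing import List


def dueling_q_values(value: float, advantages: List[float]) -> List[float]:
    """Q(s, a) = V(s) + A(s, a) - mean_k A(s, k) for a single state."""
    mean_advantage = sum(advantages) / len(advantages)
    return [value + advantage - mean_advantage for advantage in advantages]


q_values = dueling_q_values(value=1.0, advantages=[0.5, -0.5])
# Shifting every advantage by the same constant leaves the Q-values unchanged.
q_values_shifted = dueling_q_values(value=1.0, advantages=[2.5, 1.5])
print(q_values, q_values_shifted)
```

Both calls give the same $Q$-values, `[1.5, 0.5]`, and their mean equals $V(s)$.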

In [None]:
def dueling_dqn_flappy_network(x: chex.Array, n_actions: int):
    x = x[..., None]
    conv = hk.Sequential(
    [
        hk.Conv2D(32, kernel_shape=[2, 2], stride=2, padding="VALID"),
        jax.nn.relu,
        hk.Conv2D(64, kernel_shape=[2, 2], stride=2, padding="VALID"),
        jax.nn.relu,
        hk.Flatten()
        
    ])
    
    fc_adv = hk.Sequential(
    [
        #hk.nets.MLP([1*1*64, 128]),
        #jax.nn.relu,
        hk.nets.MLP([1*1*64, n_actions]),
        
    ])
    
    fc_val = hk.Sequential(
    [
        #hk.nets.MLP([1*1*64, 128]),
        #jax.nn.relu,
        hk.nets.MLP([1*1*64, 1]),
    ])
    
    def adv_val(x: chex.Array):
        conv_out = conv(x)
        return fc_adv(conv_out), fc_val(conv_out)
    
    adv, val = adv_val(x)
    # Subtract the mean advantage of each state (over the action axis only).
    return val + (adv - adv.mean(axis=-1, keepdims=True))

### Double Dueling DQN

The Double Dueling DQN extension combines the dueling network architecture with the double DQN loss. This is supposed to achieve the best of both worlds; it is set up in the configuration of the training section below.

## Environment

You must use the following flappy bird environment from the deep_rl package.


In [None]:
from deep_rl.project_values import PROJECT_FLAPPY_BIRD_ENV

### Training loop

You can use the following training loop to train your agent. Do not hesitate to play with the different parameters or even modify the code if you think you have a better option.

In [None]:
MAX_TIME_TRAINING = 3600 * 2

def run_episode_no_rendering(
    env,
    agent,
    evaluation: bool,
    max_steps: int,
) -> float:
    """Runs a single episode.

    Args:
      env: environment to consider.
      agent: agent to run.
      evaluation: if False, will train the agent.
      max_steps: number of steps after which the episode should be stopped
        no matter what.
    Returns:
      The total reward accumulated over the episode.
    """

    observation = env.reset()
    agent.first_observe(state=observation)
    tot_reward = 0

    for i in range(max_steps):
        action = agent.sample_action(observation, evaluation)
        observation, reward, end_game = env.step(action)

        if not evaluation:
            agent.observe(action, reward, end_game, observation)
        tot_reward += reward

        if end_game:
            break

    return tot_reward


def train_agent(
    env,
    agent,
    num_episodes: int,
    num_eval_episodes: int,
    eval_every_N: int,
    max_steps_episode: int,
    max_time_training: float = MAX_TIME_TRAINING,
) -> List[EpisodeTrainingStatus]:
    """Train your agent on the given environment.

    Args:
      env: environment to consider.
      agent: agent to train.
      num_episodes: number of episodes to run for training.
      num_eval_episodes: number of episodes averaged at each evaluation.
      eval_every_N: frequency (in episodes) at which the agent is evaluated.
      max_steps_episode: maximal number of steps per episode.
      max_time_training: maximal duration of the training loop (in seconds).
    Returns:
      The list of evaluation statuses (episode number, average reward and
      time elapsed since the start of training), one per evaluation.
    """

    all_status = []
    print(f"Episode number:\t| Average reward on {num_eval_episodes} eval episodes")
    print("------------------------------------------------------")

    start_time = time.time()

    for episode in range(num_episodes):
        run_episode_no_rendering(
            env, agent, evaluation=False, max_steps=max_steps_episode
        )
        if episode % eval_every_N == 0:
            reward = 0
            d_time = time.time() - start_time
            for _ in range(num_eval_episodes):
                reward += run_episode(
                    env, agent, evaluation=True, max_steps=max_steps_episode
                )
            reward /= num_eval_episodes
            print(f"\t{episode}\t|\t{reward}")
            all_status.append(
                EpisodeTrainingStatus(
                    episode_number=episode, reward=reward, training_time=d_time
                )
            )

            if d_time > max_time_training:
                break

    return all_status
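
To smoke-test the agent/environment protocol used by this loop without the real environment, one can rely on tiny stand-ins. Everything below is hypothetical: `CountdownEnv` and `RecordingAgent` are stubs of our own, and `run_episode_sketch` is a local copy of the episode logic above so that the block is self-contained.

```python
from typing import Tuple


class CountdownEnv:
    """Hypothetical stub: reward 1 per step, the episode ends after `horizon` steps."""

    def __init__(self, horizon: int) -> None:
        self._horizon = horizon
        self._t = 0

    def reset(self) -> int:
        self._t = 0
        return self._t

    def step(self, action: int) -> Tuple[int, float, bool]:
        self._t += 1
        return self._t, 1.0, self._t >= self._horizon


class RecordingAgent:
    """Hypothetical stub: always pushes and counts the calls it receives."""

    def __init__(self) -> None:
        self.episodes = 0
        self.transitions = 0

    def first_observe(self, state) -> None:
        self.episodes += 1

    def sample_action(self, state, evaluation: bool) -> int:
        return 1

    def observe(self, action_t, reward_t, done_t, state_tp1) -> None:
        self.transitions += 1


def run_episode_sketch(env, agent, evaluation: bool, max_steps: int) -> float:
    observation = env.reset()
    agent.first_observe(state=observation)
    total_reward = 0.0
    for _ in range(max_steps):
        action = agent.sample_action(observation, evaluation)
        observation, reward, done = env.step(action)
        if not evaluation:
            agent.observe(action, reward, done, observation)  # training only
        total_reward += reward
        if done:
            break
    return total_reward


env, agent = CountdownEnv(horizon=5), RecordingAgent()
train_return = run_episode_sketch(env, agent, evaluation=False, max_steps=100)
eval_return = run_episode_sketch(env, agent, evaluation=True, max_steps=3)
print(train_return, eval_return, agent.episodes, agent.transitions)
```

The stub agent only receives transitions in training mode, while `first_observe` is called at the start of every episode run through this function, including evaluation ones.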

In [None]:
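def pre_process_obs(obs: FlappyObs, FOV: int, sort=True):
    """Converts a raw observation into an array of shape (1 + FOV, 4).

    Bars located behind the bird are dropped, the remaining ones are optionally
    sorted by increasing x_min, and at most FOV of them are kept.
    """
    def sort_bars(bars: List[BarObs]):
        return sorted(bars, key=itemgetter(0))
    
    def build_state(obs: FlappyObs):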
        """
        Function that converts default observation into custom state
        """
        agent, bars = obs

        n_features = 4
        new_bars = jnp.zeros(shape=(FOV, n_features))
        
        if len(bars):
            if len(bars) > FOV:
                bars = bars[:FOV]
                
            for i, bar in enumerate(bars):
                x_min, x_max, height, top = bar
                h_bar = x_min - agent[0]
                if top:
                    v_bar = (1 - height) - agent[1]
                else:
                    v_bar = height - agent[1]
                new_bar = jnp.array([1., h_bar, v_bar, float(int(top))])
                new_bars = new_bars.at[i].set(new_bar)
                
        # The horizontal speed v_x is hard-coded to 0.05 in the agent row.
        new_agent = jnp.array([[agent[0], agent[1], 0.05, agent[2]]])

        return jnp.concatenate([new_agent, new_bars])

    bird, bars = obs
    bird_x, bird_y, bird_vy = bird
    
    new_bars = []
        
    if len(bars):
        for bar in bars:
            x_min, x_max, h, top = bar
            # remove bars which are behind the agent
            if x_max >= bird_x:
                new_bars.append(bar)
        if sort:
            # sort bars
            new_bars = sort_bars(new_bars)
        
    new_obs: FlappyObs = (bird, new_bars)
    
    state = build_state(new_obs)
    return state

### Trainings

In this section, we only show the best configurations obtained during our hyperparameter sweeps. The sweeps and their general effects on training and overall performance are discussed in the project report. The results shown in the report are included in the project files as ".npy" files, each of shape $(3, 100)$, where the first dimension holds:
- index 0: array of rewards
- index 1: array of episodes
- index 2: array of training times

We include a section in the code to load the numpy files and visualize the results, or to directly use the ones computed at runtime.
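
The row layout of these arrays can be illustrated with plain lists. The sketch below is an assumption-laden stand-in: it redefines a minimal `EpisodeTrainingStatus`, uses a hypothetical helper `status_to_rows`, and works on two made-up evaluation points instead of a real ".npy" file.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EpisodeTrainingStatus:
    episode_number: int
    reward: float
    training_time: float


def status_to_rows(status: List[EpisodeTrainingStatus]) -> List[List[float]]:
    """Row 0: rewards, row 1: episode numbers, row 2: training times."""
    rewards = [float(s.reward) for s in status]
    episodes = [float(s.episode_number) for s in status]
    times = [float(s.training_time) for s in status]
    return [rewards, episodes, times]


# Made-up evaluation history (two evaluation points).
history = [
    EpisodeTrainingStatus(episode_number=0, reward=1.5, training_time=10.0),
    EpisodeTrainingStatus(episode_number=50, reward=4.0, training_time=95.0),
]

rewards, episodes, times = status_to_rows(history)
best = max(range(len(rewards)), key=rewards.__getitem__)
print(f"best average reward {rewards[best]} at episode {int(episodes[best])}")
```

Unpacking along the first dimension in this order gives the three curves needed to plot the reward against either the episode number or the training time.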

In [None]:
NUM_EPISODES = 1200  # can be 1500 for ~2 hrs; 1000 finishes in under 2 hrs (~1 hr 30 min) per model.
NUM_EVAL_EPISODES = 50
EVAL_EVERY_N = 50
MAX_STEPS_EPISODE = 100
MAX_TIME_TRAINING = 3600 * 2

In [None]:
epsilon = Epsilon(epsilon_decay_last_frame= NUM_EPISODES, epsilon_start=0.99, epsilon_final=0.01)

#### DQN - Training

In [None]:
agent_1 = DQN(
    env,
    net=dqn_flappy_network,
    epsilon=epsilon,
    gamma=0.8,
    learning_rate=1e-4,
    buffer_capacity=1000,
    min_buffer_capacity=32,
    batch_size=32,
    target_ema=0.9,
    sync_target=1,
    fov=4
)
all_status_dqn_1 = train_agent(
    env=env,
    agent=agent_1,
    num_episodes=NUM_EPISODES,
    num_eval_episodes=NUM_EVAL_EPISODES,
    eval_every_N=EVAL_EVERY_N,
    max_steps_episode=MAX_STEPS_EPISODE,
    max_time_training=MAX_TIME_TRAINING,
)

#### Dueling DQN - Training

In [None]:
agent_dueling_1 = DQN(
    env,
    net=dueling_dqn_flappy_network,
    epsilon=epsilon,
    gamma=0.8,
    learning_rate=3e-4,
    buffer_capacity=1000,
    min_buffer_capacity=32,
    batch_size=32,
    target_ema=0.9,
    sync_target=1,
    fov=4
)
all_status_dueling_1 = train_agent(
    env=env,
    agent=agent_dueling_1,
    num_episodes=NUM_EPISODES,
    num_eval_episodes=NUM_EVAL_EPISODES,
    eval_every_N=EVAL_EVERY_N,
    max_steps_episode=MAX_STEPS_EPISODE,
    max_time_training=MAX_TIME_TRAINING,
)

#### Double DQN - Training

In [None]:
agent_double_1 = DoubleDQN(
    env,
    net=dqn_flappy_network,
    epsilon=epsilon,
    gamma=0.8,
    learning_rate=3e-4,
    buffer_capacity=1000,
    min_buffer_capacity=32,
    batch_size=32,
    target_ema=0.9,
    sync_target=1,
    fov=4
)
all_status_double_1 = train_agent(
    env=env,
    agent=agent_double_1,
    num_episodes=NUM_EPISODES,
    num_eval_episodes=NUM_EVAL_EPISODES,
    eval_every_N=EVAL_EVERY_N,
    max_steps_episode=MAX_STEPS_EPISODE,
    max_time_training=MAX_TIME_TRAINING,
)
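
The difference between the DQN and Double DQN targets can be illustrated on toy numbers. This is a generic sketch of the two target formulas, not the `DoubleDQN` class defined earlier: the function names and the `done` mask convention are assumptions.

```python
import numpy as onp  # same alias as the rest of the notebook


def dqn_target(reward, done, q_target_next, gamma=0.8):
    """Standard DQN: the target network both selects and evaluates the action."""
    return reward + gamma * (1.0 - done) * q_target_next.max(axis=-1)


def double_dqn_target(reward, done, q_online_next, q_target_next, gamma=0.8):
    """Double DQN: the online network selects the action,
    the target network evaluates it."""
    best_action = q_online_next.argmax(axis=-1)
    chosen = onp.take_along_axis(q_target_next, best_action[:, None], axis=-1)[:, 0]
    return reward + gamma * (1.0 - done) * chosen


reward = onp.array([1.0, 1.0])
done = onp.array([0.0, 1.0])  # the second transition is terminal
q_online_next = onp.array([[0.5, 2.0], [1.0, 0.0]])
q_target_next = onp.array([[3.0, 1.0], [4.0, 5.0]])

y_dqn = dqn_target(reward, done, q_target_next)
y_double = double_dqn_target(reward, done, q_online_next, q_target_next)
print(y_dqn, y_double)
```

On the first transition the standard target bootstraps from the largest target-network value (3.0), while the Double DQN target uses the value of the action preferred by the online network (1.0), which is how it reduces overestimation.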

#### Double Dueling DQN - Training (Best Results)

In [None]:
agent_double_dueling_1 = DoubleDQN(
    env,
    net=dueling_dqn_flappy_network,
    epsilon=epsilon,
    gamma=0.8,
    learning_rate=3e-4,
    buffer_capacity=1000,
    min_buffer_capacity=32,
    batch_size=32,
    target_ema=0.9,
    sync_target=1,
    fov=4
)
all_status_double_dueling = train_agent(
    env=env,
    agent=agent_double_dueling_1,
    num_episodes=NUM_EPISODES,
    num_eval_episodes=NUM_EVAL_EPISODES,
    eval_every_N=EVAL_EVERY_N,
    max_steps_episode=MAX_STEPS_EPISODE,
    max_time_training=MAX_TIME_TRAINING,
)

### Visualizations + Results

In this section, we show the visuals used in the paper. We show the methods used to generate the results and also provide the original NumPy arrays behind the plots. The arrays can either be taken from memory (the current runs) or loaded from files on the local machine (or Drive).

In [None]:
def get_arr_from_status(status: List[EpisodeTrainingStatus]):
    r_arr = onp.zeros((1, len(status)))
    ep_arr = onp.zeros((1, len(status)))
    t_arr = onp.zeros((1, len(status)))

    for i, ep_status in enumerate(status):
        r_arr[0, i] = ep_status.reward
        ep_arr[0, i] = ep_status.episode_number
        t_arr[0, i] = ep_status.training_time

    status_arr = onp.concatenate([r_arr, ep_arr, t_arr])
    return status_arr

In [None]:
# uncomment to plot and visualize current runs metrics - else load saved numpy files (original data)
#dqn_arr = get_arr_from_status(all_status_dqn_1)
#dueling_arr = get_arr_from_status(all_status_dueling_1)
#double_arr = get_arr_from_status(all_status_double_1)
#dueling_double_arr = get_arr_from_status(all_status_double_dueling)

In [None]:
#onp.save("all_status_dqn_1.npy", dqn_arr)
#onp.save("all_status_dueling_1.npy", dueling_arr)
#onp.save("all_status_double_1.npy", double_arr)
#onp.save("all_status_dueling_double_1.npy", dueling_double_arr)

In [None]:
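def get_rolling_vals(x: onp.ndarray, window_width: int = 6) -> Tuple[onp.ndarray, onp.ndarray, onp.ndarray]:
    """Return the rolling mean of x and the mean +/- one rolling standard deviation."""
    # Rolling mean over full windows only (output length: x.size - window_width + 1)
    ma_vec = onp.convolve(x, onp.ones((window_width,)) / window_width, mode='valid')

    # Overlapping windows as a strided 2D view of x, one window per row (no copy)
    nrows = x.size - window_width + 1
    n = x.strides[0]
    x2D = onp.lib.stride_tricks.as_strided(x, shape=(nrows, window_width), strides=(n, n))
    mstd_vec = onp.std(x2D, axis=1)

    high = ma_vec + mstd_vec
    low = ma_vec - mstd_vec

    return ma_vec, high, low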
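
As a small worked example of the rolling statistics computed by `get_rolling_vals`, the sketch below obtains the same rolling mean and standard deviation with `sliding_window_view` (available in NumPy 1.20 and later). The helper `rolling_mean_std` is a hypothetical alternative for illustration, not a replacement used elsewhere in this notebook.

```python
import numpy as onp  # same alias as the rest of the notebook
from numpy.lib.stride_tricks import sliding_window_view


def rolling_mean_std(x, window_width=6):
    """Rolling mean and (population) standard deviation over full windows."""
    windows = sliding_window_view(x, window_width)  # shape: (n - w + 1, w)
    return windows.mean(axis=1), windows.std(axis=1)


x = onp.array([1.0, 2.0, 3.0, 4.0, 5.0])
mean, std = rolling_mean_std(x, window_width=3)

# Windows are [1, 2, 3], [2, 3, 4], [3, 4, 5]: means 2, 3, 4, each with std sqrt(2/3).
print(mean, std)
```

The output has `x.size - window_width + 1` points, so the plotted curves are slightly shorter than the raw reward arrays.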

In [None]:
def plot_results(arrays: List[onp.array], labels: List[str], rolling_window:int = 6, show_std:bool = True):
    
    fig = plt.figure(figsize=(11, 7))
    ax1 = fig.add_subplot(111)

    for i, array in enumerate(arrays):
        arr_ma, arr_hi, arr_lo = get_rolling_vals(array, window_width=rolling_window)

        ax1.plot(arr_ma, label=labels[i])
        if show_std:
            ax1.fill_between(onp.arange(arr_ma.shape[0]), arr_hi, arr_lo, alpha=0.1)

    
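    # One tick every EVAL_EVERY_N data points, based on the length of the last array.
    # Note: the factor 10 in the labels assumes one data point per 10 episodes;
    # adjust it if the evaluation interval differs.
    ticks = onp.arange(0, array.shape[0]+EVAL_EVERY_N, EVAL_EVERY_N)
    ax1.set_xticks(ticks)
    ax1.set_xticklabels([str(10*x) for x in ticks])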
    ax1.set_xlabel("Episodes")
    ax1.set_ylabel("Reward")
    ax1.set_title("Moving Average Reward of DQN + Extensions Over Nb of Episodes")
    ax1.legend()
    plt.show()

In [None]:
# load numpy arrays containing metrics from original runs
dqn_arr = onp.load("all_status_dqn_1.npy")
dueling_arr = onp.load("all_status_dueling_1.npy") 
double_arr = onp.load("all_status_double_1.npy")
dueling_double_arr = onp.load("all_status_dueling_double_1.npy")

In [None]:
arrays = [arr[0] for arr in [dqn_arr, dueling_arr, double_arr, dueling_double_arr]]
labels = ["DQN", "DuelingDQN", "DoubleDQN", "DuelingDoubleDQN"]

plot_results(arrays, labels, rolling_window=10, show_std=False)

In [None]:
def evaluate_model(env, agent, n_episodes, max_steps):
    rewards = onp.zeros(shape=(n_episodes,))
    for i in range(n_episodes):
        rewards[i] = run_episode(env, agent, max_steps=max_steps, renderer=None)

    means = onp.mean(rewards)
    stds = onp.std(rewards)
    
    print(
        f"Average reward over {n_episodes} episodes: {means:.3f} with standard deviation: {stds:.3f}"
    )
    return rewards

In [None]:
print("Model: Base DQN")
_ = evaluate_model(env, agent_1, 100, 1000)

print("Model: Dueling DQN")
_ = evaluate_model(env, agent_dueling_1, 100, 1000)

print("Model: Double DQN")
_ = evaluate_model(env, agent_double_1, 100, 1000)

In [None]:
print("Model: Dueling Double DQN")
_ = evaluate_model(env, agent_double_dueling_1, 100, 1000)

In [None]:
print("Model: Stable Agent")
stable_rewards = evaluate_model(env, StableAgent(), 100, 1000)

### Best Model Evaluation

In [None]:
# We are going to render the environment !
ROWS = 30
COLS = 60
renderer = BashRenderer(ROWS,
                        COLS,
                        clear_fn= lambda: clear_output(wait=True))


# We run a single episode, with rendering, over a maximum of 1000 steps
run_episode(PROJECT_FLAPPY_BIRD_ENV,
            agent_double_dueling_1,
            max_steps= 1000,
            renderer= renderer,
            time_between_frame= 0.1, evaluation=True)

In [None]:
arrays = [dqn_arr[0], dueling_double_arr[0], stable_rewards]
labels = ["DQN", "DuelingDoubleDQN", "StableAgent"]

plot_results(arrays, labels, rolling_window=10, show_std=True)

From the results shown, we were able to train a Double Dueling DQN model in under 2 hours and obtain an average reward of $25$, thereby solving our environment (confirmed by the epic "winner winner chicken dinner" hidden message). We trained for 1000 episodes, with an evaluation every 10 episodes. The reward still has a fairly large standard deviation, even though we select actions with the greedy policy during evaluation. This may be because some generated environments are very hard, and it may also indicate that the network needs more training time.
In addition, we have not explored further extensions of the DQN method, including **Prioritized Replay Buffer**, **Noisy DQN** and **Rainbow DQN**, which combines most of the extensions shown and mentioned.