# Why and When Attention Matters in Reinforcement Learning
> "Using RLlib and AttentionNet to master environments with stateless observations, here, stateless CartPole."

- hide: true
- toc: true
- branch: master
- badges: true
- comments: true
- categories: [python, gym, ray, rllib, tensorflow, machine learning, reinforcement learning, sequence, attention]
- image: images/cartpole.jpg

In reinforcement learning (RL), the RL agent typically selects a suitable action based on the last observation.
In many practical environments, the full state can only be observed partially,
such that important information may be missing when just considering the last observation.
This blog post covers options for dealing with missing and only partially observed state,
e.g., considering a *sequence* of last observations and applying *self-attention* to this sequence.


## Example: The CartPole Gym Environment

As an example, consider the popular [OpenAI Gym CartPole environment](https://gym.openai.com/envs/CartPole-v1/).
Here, the task is to move a cart left or right in order to balance a pole on the cart as long as possible.

![OpenAI Gym CartPole-v1 Environment](attention/cartpole.gif "OpenAI Gym CartPole-v1 Environment")

In the normal [`CartPole-v1` environment](https://gym.openai.com/envs/CartPole-v1/), the RL agent observes four scalar values ([defined here](https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py#L26:L32)):
* The cart position, i.e., where the cart currently is.
* The cart velocity, i.e., how fast the cart is currently moving and in which direction (can be positive or negative).
* The pole angle, i.e., how tilted the pole currently is and in which direction.
* The pole angular velocity, i.e., how fast the pole is currently moving and in which direction.

All four observations are important to decide whether the cart should move left or right.

Now, assume the RL agent only has access to an instant snapshot of the cart and the pole (e.g., through a photo/raw pixels)
and can neither observe cart velocity nor pole angular velocity.
In this case, the RL agent does not know whether the pole is currently swinging to one side and cannot properly balance the pole.
How can the agent deal with this problem of missing state (here, cart and pole velocity)?
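
To make the missing state concrete, here is a minimal sketch of how a full CartPole observation could be reduced to a snapshot-like partial one. The helper `drop_velocities` and the example values are hypothetical and only for illustration; the index layout simply follows the order of the four observations listed above.

```python
from typing import Sequence, Tuple

# Index layout of a full CartPole observation, following the list above.
CART_POSITION, CART_VELOCITY, POLE_ANGLE, POLE_ANGULAR_VELOCITY = range(4)


def drop_velocities(full_obs: Sequence[float]) -> Tuple[float, float]:
    """Keep only what an instant snapshot shows: cart position and pole angle."""
    return (full_obs[CART_POSITION], full_obs[POLE_ANGLE])


# Two made-up states that look identical in a snapshot
# but differ in the direction the pole is swinging.
swinging_left = [0.0, 0.1, 0.05, -0.8]
swinging_right = [0.0, 0.1, 0.05, 0.8]

partial_left = drop_velocities(swinging_left)
partial_right = drop_velocities(swinging_right)
print(partial_left == partial_right)
```

Both states map to the same partial observation, so an agent that only sees the last snapshot cannot tell them apart, although they would call for different actions.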


## Options for Dealing With Missing State

There are different options for dealing with missing state, e.g., missing velocity in the CartPole example:

1. Add the missing state explicitly, e.g., measure and observe velocity. Note that this may require installing extra sensors or may even be infeasible in some scenarios.
2. Ignore the missing state, i.e., just rely on the available, partial observations. Depending on the missing state, this may be problematic and keep the agent from learning.
3. Keep track of a sequence of the last observations. By observing the cart position and pole angle over time, the agent can implicitly derive their velocity. There are different ways to deal with this sequence:
   1. Feed the sequence as is into a standard multi-layer perceptron (MLP)/dense feedforward neural network.
   2. Feed the sequence into a recurrent neural network (RNN), e.g., with long short-term memory (LSTM).
   3. Feed the sequence into a neural network with *self-attention*.
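
The idea behind option 3 can be sketched in a few lines: keep a fixed-length buffer of the last partial observations and, if desired, recover velocity from the difference between consecutive observations. The class `ObservationHistory` is a hypothetical helper, not part of Gym or RLlib, and the step length `TAU = 0.02` seconds is an assumption based on Gym's CartPole implementation.

```python
from collections import deque
from typing import Deque, List, Tuple

TAU = 0.02  # seconds per step; assumed to match Gym's CartPole implementation

PartialObs = Tuple[float, float]  # (cart position, pole angle)


class ObservationHistory:
    """Hypothetical helper that keeps the last `length` partial observations."""

    def __init__(self, length: int, initial_obs: PartialObs):
        # Pad with the first observation so the sequence always has full length.
        self.buffer: Deque[PartialObs] = deque([initial_obs] * length, maxlen=length)

    def add(self, obs: PartialObs) -> None:
        # The oldest observation is dropped automatically (maxlen).
        self.buffer.append(obs)

    def stacked(self) -> List[float]:
        """Flatten the sequence into one vector, e.g., as input for an MLP."""
        return [value for obs in self.buffer for value in obs]

    def estimated_velocities(self) -> Tuple[float, float]:
        """Finite-difference estimate from the two most recent observations."""
        (prev_pos, prev_angle), (pos, angle) = self.buffer[-2], self.buffer[-1]
        return ((pos - prev_pos) / TAU, (angle - prev_angle) / TAU)


history = ObservationHistory(length=3, initial_obs=(0.0, 0.0))
history.add((0.01, -0.002))
history.add((0.03, -0.006))
print(history.stacked())
print(history.estimated_velocities())
```

In practice, the explicit finite difference is usually not needed: a neural network that receives the stacked sequence can learn such differences implicitly.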

In the following, I explain each option in more detail and illustrate them using simple example code.


### Setup

For the examples, I use a PPO RL agent from Ray RLlib with the CartPole environment, described above.

To install these dependencies, run the following code (tested with Python 3.8 on Windows):

In [None]:
#collapse-output
!pip install ray[rllib]==1.8.0
!pip install tensorflow==2.7.0
!pip install seaborn==0.11.2
!pip install gym==0.21.0
!pip install pyglet==1.5.21



Start up Ray and set the number of training iterations,
which is the same for all options (for comparability).
The default PPO config is loaded separately for each option below.

In [None]:
import ray

# adjust num_cpus and num_gpus to your system
# for some reason, num_cpus=2 gets stuck on my system (when trying to train)
ray.init(num_cpus=3, ignore_reinit_error=True)

# stop conditions based on training iterations (each with 4000 train steps)
stop = {"training_iteration": 1}

### Option 1: Explicitly Add Missing State

Sometimes, it is possible to extend the observations and explicitly add important state that was previously unobserved.
In the CartPole example, the cart and pole velocity can simply be "added" by using the default `CartPole-v1` environment.
Here, the cart velocity and pole velocity are already included in the observations.

Note that in many practical scenarios, such "missing" state cannot simply be added and observed.
Instead, it may require installing additional sensors or may even be completely infeasible.

Let's start with the best case, i.e., explicitly including the missing state.

In [None]:
import gym

# the default CartPole env has all 4 observations: position and velocity of both cart and pole
env = gym.make("CartPole-v1")
env.observation_space.shape

In [None]:
#collapse-output

from ray.rllib.agents import ppo

# run PPO on the default CartPole-v1 env
config1 = ppo.DEFAULT_CONFIG.copy()
config1["env"] = "CartPole-v1"

# training takes a while
results1 = ray.tune.run("PPO", config=config1, stop=stop)
print("Option 1: Training finished successfully")

In [None]:
# check results
results1.default_metric = "episode_reward_mean"
results1.default_mode = "max"
# print mean number of time steps the pole was balanced (higher = better)
results1.best_result["episode_reward_mean"]

In [None]:
# plot the last 100 episode rewards
import seaborn as sns

eps_rewards = results1.best_result["hist_stats"]["episode_reward"]
eps = list(range(len(eps_rewards)))
# pass x and y as keyword arguments; positional use is deprecated in seaborn
sns.scatterplot(x=eps, y=eps_rewards)


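### Option 2: Ignore Missing State

In many practical scenarios, the missing state cannot simply be added to the observations.
The simplest alternative is to ignore it and train the agent on the partial observations only.
RLlib ships a `StatelessCartPole` example environment for this purpose:
it hides the cart and pole velocity, so that the agent only observes the cart position and pole angle.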

In [None]:
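from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry

# register the stateless CartPole env, which hides cart and pole velocity
registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

# start from a fresh default PPO config and only swap the environment
config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "StatelessCartPole"
# train; this takes a while
results_stateless = ray.tune.run("PPO", config=config, stop=stop)
print("Training finished successfully")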

In [None]:
# check results
results_stateless.default_metric = "episode_reward_mean"
results_stateless.default_mode = "max"
# print the mean episode reward = episode length --> higher = better
results_stateless.best_result["episode_reward_mean"]

In [None]:
# plot the lengths of the most recent episodes
eps_lengths = results_stateless.best_result["hist_stats"]["episode_lengths"]
eps = list(range(len(eps_lengths)))
sns.scatterplot(x=eps, y=eps_lengths)


### Option 3: Stacked Observations with Attention


> Tip: Also check out the [RLlib example using AttentionNet](https://github.com/ray-project/ray/blob/master/rllib/examples/attention_net.py).
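
To give an intuition for what self-attention does with a sequence of observations, here is a minimal pure-Python sketch of scaled dot-product self-attention. It is not RLlib's attention model: the learned query/key/value projections, multiple heads, and memory are left out, and the function names and example values are made up for illustration.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax."""
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def self_attention(sequence: List[Vector]) -> List[Vector]:
    """Scaled dot-product self-attention with identity projections.

    Real attention layers first apply learned query/key/value projections;
    they are omitted here to keep the sketch short.
    """
    dim = len(sequence[0])
    outputs = []
    for query in sequence:
        # Similarity of this time step to every time step in the sequence.
        scores = [dot(query, key) / math.sqrt(dim) for key in sequence]
        weights = softmax(scores)
        # Weighted average over all time steps (the "values").
        outputs.append([
            sum(w * value[i] for w, value in zip(weights, sequence))
            for i in range(dim)
        ])
    return outputs


# Made-up sequence of three partial observations (cart position, pole angle).
observations = [[0.00, 0.00], [0.01, -0.02], [0.03, -0.06]]
attended = self_attention(observations)
print(attended)
```

Each output row mixes information from all time steps in the sequence, which is what allows a trained attention model to relate the current snapshot to earlier ones.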


In [None]:
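# reset to the default PPO config so that only the model settings differ
config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "StatelessCartPole"
config["model"] = {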
    # Attention net wrapping (for tf) can already use the native keras
    # model versions. For torch, this will have no effect.
    "_use_default_native_models": True,
    "use_attention": True,
    "max_seq_len": 10,
    "attention_num_transformer_units": 1,
    "attention_dim": 32,
    "attention_memory_inference": 10,
    "attention_memory_training": 10,
    "attention_num_heads": 1,
    "attention_head_dim": 32,
    "attention_position_wise_mlp_dim": 32,
}

results_attention = ray.tune.run("PPO", config=config, stop=stop)
print("Training finished successfully")

In [None]:
# check results
results_attention.default_metric = "episode_reward_mean"
results_attention.default_mode = "max"
# print the mean episode reward = episode length --> higher = better
results_attention.best_result["episode_reward_mean"]

In [None]:
# plot the lengths of the most recent episodes
eps_lengths = results_attention.best_result["hist_stats"]["episode_lengths"]
eps = list(range(len(eps_lengths)))
sns.scatterplot(x=eps, y=eps_lengths)

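### Option 3: Stateless but Stacked Frames

Instead of self-attention, the sequence of last observations can also be stacked and passed to a standard MLP,
or fed into an LSTM (options 3.1 and 3.2 in the list above).
A possible starting point in RLlib is the `framestack = True` model setting; this variant is not evaluated here.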
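
For the LSTM variant, a model config along the following lines might serve as a starting point. This is an untested sketch: the keys `use_lstm`, `lstm_cell_size`, and `max_seq_len` come from RLlib's default model config, but the values are illustrative assumptions and the name `lstm_model_config` is made up.

```python
# Hypothetical LSTM counterpart to the attention model config used above.
lstm_model_config = {
    "use_lstm": True,       # wrap the default model with an LSTM
    "lstm_cell_size": 64,   # illustrative value, not tuned
    "max_seq_len": 10,      # same sequence length as in the attention example
}

# Assumed usage, mirroring the other options:
# config = ppo.DEFAULT_CONFIG.copy()
# config["env"] = "StatelessCartPole"
# config["model"] = lstm_model_config
print(lstm_model_config)
```

Training and evaluation would then follow the same pattern as for the other options, using the same `stop` condition for comparability.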

