# Assignment 8: Policy gradients in Jax

In this assignment, you will implement the REINFORCE algorithm for the inverted pendulum swing-up problem. The pendulum starts in a random position and the goal is to apply torque on the free end to swing it into an upright position, with its center of gravity right above the fixed point. 

Like in the previous assignment, we will use the Jax framework for automatic differentiation. We again recommend looking through the following articles to get started with Jax:

- [Jax Quickstart](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [Training a simple NN with Jax](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
- [Jax vs. NumPy](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)

We also use two additional libraries: optax for optimization and gymnax, which reimplements OpenAI Gym environments in Jax, for simulation. Both packages can be installed via the pip package manager.

## 8.1 Using gymnax (0 points)
The following code sets up a gymnax environment for the inverted pendulum swing-up, where we discretize the action space manually. The code below does not have to be changed, but we recommend that you try to understand the observation and state representations of the pendulum.
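
To help with reading those representations: the Pendulum-v1 observation is $[\cos\theta, \sin\theta, \dot\theta]$, while the environment state (printed in the next cell) stores the angle and angular velocity directly. The sketch below recovers the angle from an observation and reproduces the standard Pendulum-v1 reward $r = -(\theta^2 + 0.1\,\dot\theta^2 + 0.001\,u^2)$, with $\theta$ wrapped to $[-\pi, \pi)$ and $\theta = 0$ meaning upright. The helper names are hypothetical (not part of gymnax), and it assumes gymnax keeps the original Gym reward definition.

```python
import jax.numpy as jnp

def obs_to_angle(obs):
    """Hypothetical helper: recover (theta, theta_dot) from [cos(theta), sin(theta), theta_dot]."""
    theta = jnp.arctan2(obs[..., 1], obs[..., 0])  # angle in (-pi, pi]
    return theta, obs[..., 2]

def angle_normalize(theta):
    """Wrap an angle to [-pi, pi)."""
    return ((theta + jnp.pi) % (2 * jnp.pi)) - jnp.pi

def pendulum_reward(theta, theta_dot, u):
    """Pendulum-v1 style reward: 0 when upright, at rest, with no torque; negative otherwise."""
    return -(angle_normalize(theta) ** 2 + 0.1 * theta_dot ** 2 + 0.001 * u ** 2)

obs = jnp.array([jnp.cos(0.5), jnp.sin(0.5), -1.0])
theta, theta_dot = obs_to_angle(obs)
print(theta, theta_dot)                 # approximately 0.5 and -1.0
print(pendulum_reward(0.0, 0.0, 0.0))   # the maximum reward, 0
```

Because the reward is never positive, an episode return closer to zero means the pendulum spent more time near the upright position.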

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.nn as jnn
from jax.random import PRNGKey
import optax
import gymnax
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
key = PRNGKey(42)
key, key_reset, key_policy, key_step = jr.split(key, 4)

# Create the Pendulum-v1 environment
env_name = "Pendulum-v1"
env, env_params = gymnax.make(env_name)

# Inspect default environment settings
print(env_params)
ts = env_params.dt * jnp.arange(env_params.max_steps_in_episode)  # one time stamp per step; an integer arange avoids floating-point length errors

obs, state = env.reset(key_reset, env_params)
obs, state

In [None]:
action = env.action_space(env_params).sample(key_policy)
action_space = env.action_space(env_params)  # continuous Box space; .low and .high are the torque bounds
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
n_obs, n_state, reward, done

For this assignment, we discretize the action space:

In [None]:
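obs, state = env.reset(key_reset, env_params)
action_list = jnp.array([-1., 0., 1.])  # discrete torque values
num_actions = len(action_list)
action_idx = jr.randint(key_policy, (), 0, num_actions)  # uniformly random action index
action = action_list[action_idx]
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)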

env_params

## 8.2 Setting up a basic neural network (3 points)

Create a basic neural network in Jax, using jax.nn for the activation functions and jax.random to initialise the weights and biases randomly with a mean of zero and a small standard deviation.

Since the network represents the policy, it should take a state as input and output a predicted probability distribution over all discrete actions. Below is a suggested setup for the functions via which this can be done.
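
As an illustration of the two ingredients (zero-mean random initialization and a softmax output layer), here is a minimal sketch. The function names, the layer sizes (3, 32, 3) and the tanh activation are assumptions for illustration only; jnn.relu or other hidden sizes work equally well.

```python
import jax.numpy as jnp
import jax.random as jr
import jax.nn as jnn

def init_mlp_sketch(layer_sizes, key, scale=1e-2):
    """Hypothetical helper: list of (W, b) pairs drawn from N(0, scale^2)."""
    keys = jr.split(key, len(layer_sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        k_w, k_b = jr.split(k)
        W = scale * jr.normal(k_w, (n_out, n_in))
        b = scale * jr.normal(k_b, (n_out,))
        params.append((W, b))
    return params

def mlp_probs_sketch(params, x):
    """Hypothetical forward pass: tanh hidden layers, softmax over the discrete actions."""
    h = x
    for W, b in params[:-1]:
        h = jnp.tanh(W @ h + b)
    W_out, b_out = params[-1]
    logits = W_out @ h + b_out
    return jnn.softmax(logits)

# 3 observation dimensions -> 32 hidden units -> 3 discrete actions
params_demo = init_mlp_sketch((3, 32, 3), jr.PRNGKey(0))
probs = mlp_probs_sketch(params_demo, jnp.array([1.0, 0.0, 0.0]))
print(probs.shape, probs.sum())  # (3,) and a sum of approximately 1
```

With a small scale, all logits start near zero, so the initial policy is close to uniform over the three actions.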

In [None]:
def initialize_mlp(layer_sizes, key:PRNGKey, scale:float=1e-2):
    """
    Inputs:
        layer_sizes (tuple) Sizes of the neural network layers, including the input, hidden, and output layer sizes.
        key (PRNGKey) random key used to initialize the parameters
        scale (float) standard deviation of initial weights and biases

    Return: 
        params (List) List of (weights, biases) tuples - [ (weights_1, biases_1), ..., (weights_n, biases_n) ]
    """
    raise NotImplementedError

def policy(params, x):
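    """ Standard MLP that outputs a probability distribution over the discrete actions (-1, 0, 1).
    
    Inputs:
        params (PyTree) Parameters of the policy network, represented as PyTree. 
        x (D,) input state, where D is the dimensionality of the state observation.

    Return:
        probs (num_actions,) probability of each discrete action.
    """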
    raise NotImplementedError

In addition to the policy network, we need some helper functions to:

- Sample an action from the policy distribution.
- Calculate the log probability of an action under the policy distribution.
- Update the error terms $\delta$ during training (hint: the jax.tree.map() function from question 1 is useful here).
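
The three helpers can be sketched on a toy policy. In this sketch the "network" is a hypothetical linear layer that returns logits rather than probabilities; jr.categorical expects unnormalized log-probabilities, and jax.nn.log_softmax is numerically safer than taking jnp.log of a softmax. The dictionary parameter layout and all names are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

action_list = jnp.array([-1.0, 0.0, 1.0])

def logits_fn(params, x):
    # Hypothetical stand-in for the policy network: one linear layer, no softmax.
    return params["W"] @ x + params["b"]

def sample_action(params, x, key):
    idx = jr.categorical(key, logits_fn(params, x))  # draw an index with P(idx) = softmax(logits)
    return action_list[idx], idx

def log_prob(params, x, idx):
    return jax.nn.log_softmax(logits_fn(params, x))[idx]

def accumulate(delta, grad_theta):
    # Leaf-wise sum of two PyTrees with the same structure.
    return jax.tree.map(lambda d, g: d + g, delta, grad_theta)

params = {"W": jnp.zeros((3, 3)), "b": jnp.zeros(3)}
x = jnp.array([1.0, 0.0, 0.0])
action, idx = sample_action(params, x, jr.PRNGKey(0))
grad = jax.grad(log_prob)(params, x, idx)   # gradient of log pi(a | x) w.r.t. params
delta = jax.tree.map(jnp.zeros_like, params)
delta = accumulate(delta, grad)
```

Since jax.grad returns a PyTree with the same structure as params, the accumulator works unchanged for the list-of-tuples layout used by initialize_mlp.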

In [None]:
def get_action(params, x, key:PRNGKey):
    """  
    Sample an action using the action probabilities predicted by the MLP

    Input:
        params (PyTree) Parameters of the policy network, represented as PyTree. 
        x (D,) input state, where D is the dimensionality of the state observation.
        key (PRNGKey)
     
    Return:
        action (M,) of floats: actions generated according to params, where M is the dimensionality of actions we carry out. 
        action_idx (M,) of int: indices of actions generated according to params.
    """
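    return 0., 0  # placeholder (zero torque, index 0) so that the rollout code below runs before get_action is implemented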
 
def get_log_prob(params, x, action_idx):
    """
    Return the log probability of the action executed by the MLP.

    Input:
        params (PyTree) Parameters of the policy network, represented as PyTree. 
        x (D,) input state, where D is the dimensionality of the state observation.
        action_idx (M,) of int: indices of actions generated according to params.
     
    Return:
        log probability
    """
    raise NotImplementedError

@jax.jit
def update_delta(delta, grad_theta):
    """ 
    Update the parameter update delta with the gradient of the policy.

    Input:
        delta (PyTree) current accumulated error terms
        grad_theta (PyTree) gradient update of the network parameters.
    
    Return:
        updated_delta (PyTree) 
    """
    updated_delta = None
    raise NotImplementedError
    return updated_delta, None # keep the second output (None): update_delta is used as a jax.lax.scan step function in the training loop.

In [None]:
# initialize the policy 
params = None

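## Intermezzo: generating rollouts of the pendulum in parallel with gymnax
The following code generates a single episode of the environment with jax.lax.scan; the training loop below maps it over a batch of random keys with jax.vmap to generate rollouts in parallel. Note that it requires the get_action function from above; until that function is implemented, its placeholder return value is used.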

In [None]:
def rollout(params, env_params, rng_input:PRNGKey, steps_in_episode:int):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jr.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)


    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, rng = state_input
        rng, rng_action, rng_step = jr.split(rng, 3)
        action, action_idx = get_action(params, obs, rng_action)
        next_obs, next_state, reward, done, _ = env.step(
          rng_step, state, action, env_params
        )
        carry = [next_obs, next_state, rng]
        return carry, [obs, state, action, action_idx, reward, next_obs, done]

    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
      policy_step,
      [obs, state, rng_episode],
      (),
      length=steps_in_episode, 
    )
    return scan_out

# Jit-Compiled Episode Rollout
jit_rollout = jax.jit(rollout, static_argnums=3)
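
To see how jax.vmap turns a single scanned episode into a batch, here is a self-contained sketch with a hypothetical 1-D random walk standing in for the pendulum. Each trajectory gets its own random key, so the rows differ, and the output gains a leading batch axis. The number of steps is closed over so that it stays a plain Python int, as lax.scan needs a static length.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def toy_rollout(rng, steps):
    """Hypothetical stand-in for rollout: a 1-D random walk scanned over `steps`."""
    def step(carry, _):
        x, rng = carry
        rng, sub = jr.split(rng)
        x_next = x + jr.normal(sub)
        return (x_next, rng), x_next  # new carry, per-step output
    _, xs = jax.lax.scan(step, (jnp.zeros(()), rng), None, length=steps)
    return xs

batch_keys = jr.split(jr.PRNGKey(0), 8)
# Map over the keys only; `steps` is fixed by the closure.
batched = jax.vmap(lambda k: toy_rollout(k, 200))(batch_keys)
print(batched.shape)  # (8, 200): one row per trajectory, one column per time step
```

The training loop below achieves the same effect with in_axes=(None, None, 0, None), mapping only over the key argument of rollout.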

In [None]:
def visualize_trajectory(params, key):
    obs, state, action, action_idx, reward, next_obs, done = rollout(params, env_params, rng_input=key, steps_in_episode=env_params.max_steps_in_episode)

    fig, ax = plt.subplots(5, 1, figsize=(8, 8))
    # first three plots: the observation components over time
    for d in range(obs.shape[-1]):
        ax[d].plot(ts, obs[:, d], color='C0', label=f'State {d}')
    ax[0].set_title(r'$\cos(\theta)$')
    ax[1].set_title(r'$\sin(\theta)$')
    ax[2].set_title(r'$\dot{\theta}$')

    ax[3].plot(ts, action, color='C1', label='Actions')
    # ax[3].set_ylim((env.action_space().low, env.action_space().high))
    ax[3].set_title('u(t)')
    ax[4].plot(ts, reward, color='C2', label='Rewards')
    ax[4].set_title('r(t)')

    plt.tight_layout()
    plt.show()

In [None]:
key, subkey = jr.split(key)
print('Caption 1: Single rollout using the policy network without training')
visualize_trajectory(params, key=subkey)

## 8.3 REINFORCE without baseline (4 points)

Implement the REINFORCE algorithm from the lecture notes (algorithm 1) to compute the error terms used to update the parameters. Make use of the get_log_prob and update_delta functions implemented earlier.
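
For reference, a common form of the REINFORCE estimator is shown below; Algorithm 1 in the lecture notes may differ in details such as an extra $\gamma^t$ factor or the indexing of rewards. For each trajectory of length $T$, the discounted return is accumulated backwards in time and each step contributes its score function weighted by that return:

$$
G_t = \sum_{k=t}^{T-1} \gamma^{k-t} r_k = r_t + \gamma\, G_{t+1}, \qquad G_T = 0,
$$

$$
\delta = \sum_{t=0}^{T-1} G_t \, \nabla_\theta \log \pi_\theta(a_t \mid x_t).
$$

Moving $\theta$ in the direction of $+\delta$ increases the expected return; with a baseline, $G_t$ is replaced by $G_t - b(t)$. The recursion $G_t = r_t + \gamma G_{t+1}$ is why the scan in loss_REINFORCE runs over the reversed arrays.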

For iterating over the time steps, we recommend using jax.lax.scan. To parallelize the gradient computation across the trajectories in a batch, we recommend jax.vmap. Both can also be written as Python for loops, but scan and vmap typically compile to much faster code under jax.jit.
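
The scan-then-vmap pattern can be seen in isolation on the discounted returns alone. This sketch (the function name is hypothetical) scans backwards over a single trajectory's rewards, carrying $G_{t+1}$ and emitting $G_t$, and then vmaps that function over a batch of trajectories. It does not compute gradients; in loss_REINFORCE the same carry additionally holds $\delta$.

```python
import jax
import jax.numpy as jnp

def discounted_returns(rewards, gamma):
    """G_t = r_t + gamma * G_{t+1}, computed by scanning backwards in time."""
    def step(G, r):
        G = r + gamma * G
        return G, G  # carry the running return and also emit it
    _, G_rev = jax.lax.scan(step, jnp.zeros(()), rewards[::-1])
    return G_rev[::-1]  # flip back so that entry t corresponds to time t

print(discounted_returns(jnp.array([1.0, 1.0, 1.0]), 0.5))  # values: 1.75, 1.5, 1.0

# jax.vmap applies the same function to every trajectory (row) in the batch.
batch_rewards = jnp.ones((4, 3))
batch_G = jax.vmap(discounted_returns, in_axes=(0, None))(batch_rewards, 0.5)
```

As an alternative to reversing by hand, jax.lax.scan accepts reverse=True, which iterates from the last element while keeping the stacked outputs in chronological order.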

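To get you started, the training loop is already defined below, so you only have to implement the REINFORCE loss and choose the training hyperparameters. Keep in mind that optax optimizers assume a minimization problem: they step against the gradients they are given, so the error terms $\delta$ passed to optim.update must be the gradient of a quantity to be minimized (for example, the negative return).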

In [None]:
def loss_REINFORCE(params, obs, action_idx, reward, baseline, gamma:float=0.99):   
    """
    Compute the error term delta using the REINFORCE algorithm

    Inputs:
        params (PyTree) Current parameters of the network
        obs (Array) Batch of observations
        action_idx (Array) Batch of action indices
        reward (Array) Batch of rewards
        baseline (Array) Baseline over time points - not required for question 8.3.
        gamma (float) Discount factor.

    Return:
        delta (PyTree) Error terms of the parameters
        Gt (Array) Batched discounted rewards over time
    """

    def trajectory_gradients(reward, obs, action_idx, baseline, delta): 
        G_init = 0.0  # float, so the scan carry keeps a consistent dtype

        def step(carry, variables):
            G, delta = carry
            r, obs, action_idx, baseline = variables
            
            """
            YOUR CODE HERE
            """

            carry = G, delta
            return carry, G

        #Iterate backwards in time
        variables = (reward[::-1], 
                     obs[::-1], 
                     action_idx[::-1], 
                     baseline[::-1])
        
        """ WRITE YOUR SCAN FUNCTION HERE THAT CALLS step()"""
        (_, delta), Gt = jax.lax.scan(None)
        return delta, Gt[::-1]  # the scan ran backwards in time; restore chronological order

    # Create a parallelized (vmapped) version of trajectory_gradients.
    """ VMAP THE trajectory_gradients() OVER THE BATCH SIZE"""
    parallel_trajectory_gradients = jax.vmap(None)

    # Initialize delta with zeros, compute the per-trajectory deltas in parallel, and sum them over the batch.
    delta = jax.tree.map(lambda t: jnp.zeros(t.shape), params)
    deltas, Gs = parallel_trajectory_gradients(reward, obs, action_idx, baseline, delta)    
    delta, _ = jax.lax.scan(update_delta, delta, deltas)

    return delta, jnp.array(Gs)

# Jit the function for computational efficiency. To print values inside the function while debugging, comment out the line below (or use jax.debug.print inside the jitted function).
loss_REINFORCE = jax.jit(loss_REINFORCE)
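
If you prefer to keep the function jitted while debugging, jax.debug.print prints runtime values from inside compiled code, whereas a plain print only shows an abstract tracer once, at trace time. A toy example (the function is hypothetical):

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x, gamma):
    total = jnp.sum(gamma * x)
    # Printed on every call with the concrete value, even under jit.
    jax.debug.print("total = {t}", t=total)
    return total

out = scaled_sum(jnp.arange(4.0), 0.5)  # 0.5 * (0 + 1 + 2 + 3) = 3.0
```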

In [None]:
# set training parameters
num_iters = None
steps_in_episode = env_params.max_steps_in_episode
lr = None
gamma = None
n_batches = None

optim = optax.adam(learning_rate=lr)
state = optim.init(params)

In [None]:
# Mini-batch random keys to scan over.
key, subkey = jr.split(key) 
iter_keys = jr.split(subkey, num_iters)


# Optimisation step.
def step(carry, key):
    params, opt_state, env_params = carry
    
    # forward pass
    keys = jr.split(key, n_batches)
    parallel_rollout = jax.vmap(rollout, in_axes=(None,None,0,None))
    obs, _, action, action_idx, reward, next_obs, done = parallel_rollout(params, 
                                                         env_params, 
                                                         keys,
                                                         steps_in_episode)
    empty_baseline = jnp.zeros(reward.shape[-1])  # all-zero baseline of shape (T,), i.e. no baseline

    # compute gradients and update model
    delta, _ = loss_REINFORCE(params, obs, action_idx, reward, empty_baseline, gamma)
    updates, opt_state = optim.update(delta, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    carry = new_params, opt_state, env_params
    return carry, jnp.mean(jnp.sum(reward, axis=-1))  # log the mean (undiscounted) episode return of the batch

# Optimisation loop.
(params, _, _), history = jax.lax.scan(step, (params, state, env_params), (iter_keys))

In [None]:
plt.plot(history, label='mean episode return')
plt.xlabel('iteration')
plt.ylabel('mean return per episode')
plt.legend()
plt.show()

In [None]:
key, subkey = jr.split(key)
visualize_trajectory(params, subkey)

## 8.4 Create (time-dependent) baseline (3 points)

Implement a baseline $b(t) = \mathbb{E}[G_t]$ following Algorithm 2 in the lecture notes. Rather than fitting a baseline by minimizing $|G_n - b(x_n)|^2$, you can take the average discounted return as a constant baseline, or compute a time-dependent baseline by averaging the discounted returns $G_t$ over the trajectories in a batch at each time step $t$.

Hint: this should be possible by only altering the training loop function, without altering the REINFORCE loss function. 
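
Both baseline variants reduce to a mean over a batch of discounted returns. The sketch below uses a made-up (n_batches, T) array in place of the Gs returned by loss_REINFORCE and assumes it is in chronological order; if the returns come straight from a backwards scan, reverse them along the time axis first. Where the returns come from (for example, the previous iteration or an extra rollout pass with a zero baseline) is a design choice left open here.

```python
import jax.numpy as jnp

# Made-up discounted returns for 2 trajectories of 3 steps, shape (n_batches, T).
Gs = jnp.array([[3.0, 2.0, 1.0],
                [5.0, 4.0, 1.0]])

constant_baseline = jnp.mean(Gs)       # a single scalar for all time steps
time_baseline = jnp.mean(Gs, axis=0)   # b(t): average over trajectories at each t, shape (T,)

advantages = Gs - time_baseline        # G_t - b(t), broadcast over the batch
print(time_baseline)  # values: 4.0, 3.0, 1.0
```

By construction, the time-dependent advantages average to zero at every time step, which is what lowers the variance of the gradient estimate without changing its expectation.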

Comment on the effect of adding the baseline, and explain how this impacts the training of the network.