# Reinforcement Learning as Probabilistic Inference


**Author**: [Lisa Lee](https://leelisa.com/)

This Colab notebook is based on Sergey Levine's tutorial ([Levine 2018](https://arxiv.org/abs/1805.00909)) that formalizes reinforcement learning (RL) as probabilistic inference. We apply this framework to a simple chain environment. We implement the corresponding graphical model for our task, then implement a standard sum-product inference algorithm to infer an optimal policy. We compare the learned soft Q-function $Q(s_t,a_t) = \log \beta_t(s_t,a_t)$ and policy $p(a_t \mid s_t, O_{t:T}) = \frac{\beta_t(s_t, a_t)}{\beta_t(s_t)}$ with that of a classic model-free RL algorithm, by evaluating the policies in the environment.

**Acknowledgements**: Thanks to Maruan Al-Shedivat, Ben Eysenbach, and Emilio Parisotto for providing helpful feedback.

**Note**: Please be careful not to publish solutions online, as this homework problem may be reused for future iterations of the [10-708 Probabilistic Graphical Models](https://sailinglab.github.io/pgm-spring-2019/) course.


## I. Markov Decision Process  <a name="mdp"></a>


### 1.1: Definition <a name="mdp-definition"></a>

Consider a simple chain environment consisting of $N$ states:

```
0 <-> 1 <-> ... <-> N-1
```

Suppose the agent always starts at state 0, and its task is to reach the goal state $N-1$. We formalize this task as a **Markov Decision Process** (MDP) consisting of a tuple $(\mathcal{S}, \mathcal{A}, P_0, \mathcal{T}, r)$, where:

* $\mathcal{S} = \{0,\ldots, N-1\}$ are the agent's possible **states**, indicating the agent's current location.
* $\mathcal{A} = \{\diamond, \leftarrow, \rightarrow\}$ are the agent's possible **actions**. The directional actions $\leftarrow, \rightarrow$ correspond to moving left or right one state, respectively. The action $\diamond$ corresponds to 'stay', i.e., staying in the same state.
* $P_0: \mathcal{S} \rightarrow [0, 1]$ is the agent's **initial state distribution**, where we assume the agent always starts in state 0:
$$
P_0(s) = \begin{cases}
1 & \text{if }s = 0\\
0 & \text{otherwise}
\end{cases}
$$
* $\mathcal{T}(s_{t+1} \mid s_t, a_t)$ is the **transition dynamics**: the probability that taking action $a_t$ in state $s_t$ leads to the next state $s_{t+1}$. Assume there is an "action fail" probability $\epsilon \in [0, 1]$ such that, at each time step $t$, with probability $\epsilon$ the agent stays in the same cell regardless of the action it takes. Otherwise, a directional action moves the agent one state in the chosen direction (or leaves it in place if it would move off either end of the chain). Note that the action $\diamond$ has deterministic transition dynamics: $\mathcal{T}(s' \mid s, \diamond)=\mathbb{1}(s' = s)$ for all $s, s' \in \mathcal{S}$.

* $r(s_t,a_t) : \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}$ is a **reward** function that provides a supervision signal for taking action $a_t$ in state $s_t$, where
$$
r(s_t,a_t) = \begin{cases}
0 & \text{if }s_t = N-1 \\
0 & \text{if }s_t = N-2\text{ and }a_t = \rightarrow \\
-1 & \text{otherwise}
\end{cases}
$$
In other words, the agent receives 0 reward for taking an action such that its next state $s_{t+1}$ might be the goal, and -1 otherwise.

Assume a finite-horizon MDP, and let $T$ be the episodic time horizon. Each episode in the MDP proceeds as follows:

1. The agent spawns in an initial state according to $s_1 \sim P_0(s)$.
2. At each time step $t$, the agent observes its current cell state $s_t \in \mathcal{S}$, and takes an action $a_t \sim \pi(a \mid s_t)$. The next state is determined by the transition dynamics $s_{t+1} \sim \mathcal{T}(s' \mid s_t, a_t)$. The agent also receives a reward $r_t = r(s_t, a_t)$.
3. The episode terminates once the agent reaches the goal, or if $t \geq T$. (If the episode terminates at time step $t < T$, you can assume that $r_{t'} = 0$, $s_{t'} = s_t$, and $a_{t'} = \diamond$ for $t' \in \{t+1, \ldots, T\}$.)

Here, $\pi(a \mid s):\mathcal{S} \times \mathcal{A} \rightarrow [0, 1]$ is the agent's **policy** which defines a probability distribution over actions $a \in \mathcal{A}$ given the current state $s$. A standard RL policy search problem aims to find an **optimal** policy that maximizes the expected cumulative reward:

$$
\pi^* = \underset{\pi}{\arg\max}  \mathbb{E}_{ \substack{
    s_1 \sim P_0(S) \\
    a_t \sim \pi(A \mid s_t) \\
    s_{t+1} \sim \mathcal{T}(S \mid s_t, a_t)
} }  \left[ \sum_{t=1}^{T} r(s_t, a_t) \right]
$$

where the expectation is taken under the policy's distribution over **trajectories** $\tau = \{s_1, a_1, \ldots, s_T, a_T \}$ given by

$$
p_\pi(\tau) = p_\pi(s_1, a_1, \ldots, s_T, a_T)
= P_0(s_1) \prod_{t=1}^{T} \pi(a_t \mid s_t) \mathcal{T}(s_{t+1} \mid s_t, a_t)
$$
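
As a rough illustration of the factorization above, the sketch below accumulates the log-probability of a short trajectory term by term. The helper name `trajectory_log_prob` and the 2-state MDP are hypothetical and chosen only for this example; they are not part of the chain environment defined in the next section.

```python
import numpy as np

def trajectory_log_prob(states, actions, P0, policy, T):
    """Returns log p_pi(tau) for states s_1..s_{K+1} and actions a_1..a_K."""
    log_prob = np.log(P0[states[0]])                      # log P_0(s_1)
    for t, action in enumerate(actions):
        state, next_state = states[t], states[t + 1]
        log_prob += np.log(policy[state, action])         # log pi(a_t | s_t)
        log_prob += np.log(T[state, action, next_state])  # log T(s_{t+1} | s_t, a_t)
    return log_prob

# Hypothetical 2-state MDP: action 0 = stay, action 1 = right (fails w.p. 0.2).
P0 = np.array([1.0, 0.0])
policy = np.full((2, 2), 0.5)            # policy[s, a] = pi(a | s)
T = np.array([[[1.0, 0.0], [0.2, 0.8]],  # T[s, a, s'] = T(s' | s, a)
              [[0.0, 1.0], [0.0, 1.0]]])

# Trajectory: start in state 0, 'right' fails once, then 'right' succeeds.
log_p = trajectory_log_prob([0, 0, 1], [1, 1], P0, policy, T)
# Analytically: 1 * (0.5 * 0.2) * (0.5 * 0.8) = 0.04
print(np.exp(log_p))
```

Working in log-space keeps the product of many small factors numerically manageable, which is the same design choice the message-passing code makes later on.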

### 1.2: Gym Environment <a name="gym-env"></a>

Below, we provide an implementation of the simple chain environment, following the [OpenAI Gym API](https://gym.openai.com/docs/). It uses the following default environment parameters, which you may assume for the remainder of the tutorial:
* Number of states is $N=|\mathcal{S}| = 5$.
* The "action fail" probability of the transition dynamics $\mathcal{T}(s_{t+1} \mid s_t, a_t)$ is $\epsilon = 0.2$.
* The maximum episode length is 50.

In [0]:
import numpy as np
import gym
from gym import spaces

class ChainEnv(gym.Env):
  def __init__(self, num_states=5, action_fail_prob=0.2, max_episode_length=50):
    self.action_space = spaces.Discrete(3)  # no-action, left, right
    self.state_space = spaces.Discrete(num_states)
    self.goal_state = num_states - 1
    
    self.max_episode_length = max_episode_length

    # self.transition_dynamics[s, a, S] = T(S|s,a).
    self.transition_dynamics = np.zeros(
        (self.state_space.n, self.action_space.n, self.state_space.n))
    for s in range(self.state_space.n):
      # Action 0: no-action
      self.transition_dynamics[s, 0, s] = 1
      
      # Action 1: left
      next_state = max(s - 1, 0)
      self.transition_dynamics[s, 1, next_state] += 1 - action_fail_prob
      self.transition_dynamics[s, 1, s] += action_fail_prob

      # Action 2: right
      next_state = min(s + 1, self.state_space.n - 1)
      self.transition_dynamics[s, 2, next_state] += 1 - action_fail_prob
      self.transition_dynamics[s, 2, s] += action_fail_prob

    # self.reward[s, a] = r(s, a)
    self.reward = - np.ones((self.state_space.n, self.action_space.n))
    self.reward[self.state_space.n - 2, 2] = 0  # r(N-2,->) = 0
    self.reward[self.state_space.n - 1] = 0     # r(N-1, a) = 0 for any action a

    # Reset MDP to the beginning of an episode.
    self.reset()

  def reset(self):
    self.t = 0      # Time step counter
    self.state = 0  # Agent always starts in state 0.
    return self.state

  def _sample_transition(self, state, action):
    next_state_prob = self.transition_dynamics[state, action]
    next_state = np.argmax(np.random.multinomial(1, next_state_prob))
    return next_state

  def step(self, action):
    self.t += 1
    rew = self.reward[self.state, action]
    self.state = self._sample_transition(self.state, action)
    done = self.state == self.goal_state or self.t >= self.max_episode_length
    
    # Add any useful diagnostic info here for debugging.
    info = {}

    return self.state, rew, done, info

### 1.3: Instantiation <a name="env-instantiation"></a>

We instantiate ChainEnv, and print some important instance variables:

* `env.state_space` is the state space $\mathcal{S}$.
* `env.action_space` is the action space $\mathcal{A}$.
* `env.transition_dynamics[s, a, S]` is the transition dynamics probability $\mathcal{T}(S \mid s, a)$ that taking action $a$ in state $s$ results in the next state $S$.
* `env.reward[s, a]` is the  reward matrix $r(s, a)$.


Note that we represent the reward function $r(s_t, a_t)$ as a $|\mathcal{S}| \times |\mathcal{A}|$ matrix, and the transition dynamics function $\mathcal{T}(s_{t+1} \mid s_t, a_t)$ as a $|\mathcal{S}| \times |\mathcal{A}| \times |\mathcal{S}|$ matrix. We can do this since the state and action spaces are finite.

In [0]:
env = ChainEnv()
print('env.state_space.n  = |S| =', env.state_space.n)
print('env.action_space.n = |A| =', env.action_space.n)

# Print the distribution over next states for a given state and action.
s = 0
a = 2
print('env.transition_dynamics[{0}, {1}, S] = T(S|s={0}, a={1}) = {2}'.format(
    s, a, env.transition_dynamics[s, a, :]))

# Print the state transition matrix T(s->S') for a given action.
a = 1
print('env.transition_dynamics[s, {0}, S] = T(S|s, a={0}) = \n{1}'.format(
    a, env.transition_dynamics[:, a, :]))

print('env.reward[s, a] = \n{}'.format(env.reward))

### 1.4: Evaluating a Policy <a name="evaluating-policy"></a>

We represent a policy $\pi: \mathcal{S} \times \mathcal{A} \rightarrow [0, 1]$ as a $|\mathcal{S}| \times |\mathcal{A}|$ matrix, where `policy[s, a]`$=\pi(a \mid s)$. Recall that a policy defines a probability distribution over actions, so each row of `policy` should sum to 1:
$$
1 = \sum_{a \in \mathcal{A}} \pi(a \mid s) \qquad \text{for any }s \in \mathcal{S}.
$$
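
A quick way to check this constraint in code is sketched below. The helper `is_valid_policy` is a hypothetical name introduced here for illustration; it is not used elsewhere in the notebook.

```python
import numpy as np

def is_valid_policy(policy, atol=1e-8):
    """Checks that every row of policy[s, a] is a probability distribution."""
    nonnegative = np.all(policy >= 0)
    rows_sum_to_one = np.allclose(policy.sum(axis=1), 1.0, atol=atol)
    return bool(nonnegative and rows_sum_to_one)

# A uniform policy over 3 actions in 5 states passes the check.
example_policy = np.full((5, 3), 1.0 / 3)
print(is_valid_policy(example_policy))

# An all-zeros matrix does not, since its rows sum to 0.
print(is_valid_policy(np.zeros((5, 3))))
```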

Below, we provide a function which runs a given `policy` for 100 episodes, and prints the mean/std of the cumulative reward $\mathbb{E}\left[ \sum_{t=1}^T r(s_t, a_t) \right]$ and the state visitation frequencies.

In [0]:
from collections import Counter

def evaluate(env, policy, num_episodes=100):
  """
  Evaluates the given policy by running it for given num_episodes.

  Args:
    env (ChainEnv)    : The environment.
    policy (np.array) : |S| x |A| matrix where policy[s, a] = pi(a|s).
    num_episodes (int): Number of episodes to simulate.
  """
  print('Running the following policy[s, a] for {} episodes:\n{}'.format(
      num_episodes, policy))

  states_counter = Counter()
  returns = []
  for _ in range(num_episodes):
    state = env.reset()
    states_counter[state] += 1

    done = False
    total_rew = 0
    while not done:
      # Sample action from the policy for the current state.
      action_probs = policy[state]
      sample_onehot = np.random.multinomial(1, action_probs)
      action = np.argmax(sample_onehot)
      
      # Execute action, and observe the next state and reward.
      state, rew, done, info = env.step(action)

      total_rew += rew
      states_counter[state] += 1

    # Store cumulative reward at the end of each episode.
    returns.append(total_rew)

  # Compute statistics for episodic returns (cumulative rewards).
  avg_return = np.mean(returns)
  std_return = np.std(returns)

  # Compute normalized state visitation counts.
  state_visitation = np.zeros(env.state_space.n)
  total_count = sum(states_counter.values())
  for state, count in states_counter.items():
    state_visitation[state] = count / total_count

  print('Average Total Reward: {}, std: {}'.format(avg_return, std_return))
  print('State visitation frequency:', state_visitation)
  print()

#### 1.4.1: Uniform Policy

The code below evaluates a uniformly random policy $\pi(a \mid s) = \frac{1}{|\mathcal{A}|}$ which chooses each action uniformly at random.

In [0]:
uniform_policy = np.full((env.state_space.n, env.action_space.n),
                         1. / env.action_space.n)
evaluate(env, uniform_policy)

### 1.5: Reinforcement Learning <a name="rl"></a>

Below, we provide an implementation of **Q-Learning**, a classic model-free reinforcement learning algorithm that performs a simple value iteration update using the weighted average of the old value and the new information:

$$
Q(s_t,a_t) \leftarrow (1 - \alpha) Q(s_t, a_t) + \alpha (r_t + \gamma \max_a Q(s_{t+1}, a))
$$

Here, $\alpha \in (0, 1]$ is the learning rate and $\gamma \in [0, 1]$ is the discount factor, which we assume to be fixed hyperparameters.

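Q-Learning uses an **epsilon-greedy strategy**: At each time step, it chooses a random action $a_t \in \mathcal{A}$ with probability $\epsilon$, and the greedy action $a_t = \arg\max_a Q(s_t, a)$ otherwise. (This exploration rate $\epsilon$ is a separate hyperparameter from the "action fail" probability $\epsilon$ of the environment in Section 1.1.)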
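
To make the update rule concrete, here is a single worked update on a small Q-table. The helper name `q_update` and the numbers are illustrative assumptions, not values produced by the training code in this notebook.

```python
import numpy as np

def q_update(q_table, state, action, rew, next_state, alpha, gamma):
    """Applies one Q-learning update in place and returns the new value."""
    target = rew + gamma * np.max(q_table[next_state])
    q_table[state, action] = (1 - alpha) * q_table[state, action] + alpha * target
    return q_table[state, action]

# Illustrative Q-table with 5 states and 3 actions (stay, left, right).
q = np.zeros((5, 3))
q[1] = [-2.0, -3.0, -1.0]

# Take 'right' (action 2) in state 0, receive reward -1, and land in state 1.
new_value = q_update(q, state=0, action=2, rew=-1.0, next_state=1,
                     alpha=0.5, gamma=0.9)
# target = -1 + 0.9 * max(-2, -3, -1) = -1.9
# new    = 0.5 * 0 + 0.5 * (-1.9)     = -0.95
print(new_value)
```

With $\alpha = 1$, as in the implementation below, the old value is discarded entirely and $Q(s_t, a_t)$ is set to the target.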



In [0]:
def train_qlearning(env, num_epochs=10):
  # Hyperparameters
  alpha = 1.0
  gamma = 0.9
  epsilon = 0.1
  
  q_table = np.zeros([env.state_space.n, env.action_space.n])
  
  # Train Q-Learning.
  for _ in range(num_epochs):
    state = env.reset()
    done = False
    while not done:
      # Take the next action.
      if np.random.uniform(0, 1) < epsilon:
          action = env.action_space.sample()  # Explore action space
      else:
          action = np.argmax(q_table[state])  # Exploit learned values
      next_state, rew, done, info = env.step(action)

      old_value = q_table[state, action]
      next_max = np.max(q_table[next_state])
      new_value = (1 - alpha) * old_value + alpha * (rew + gamma * next_max)
      q_table[state, action] = new_value

      state = next_state
  
  # Compute policy pi(a|s) = argmax_a Q(s, a).
  argmax_actions = np.argmax(q_table, axis=1)
  policy = np.zeros([env.state_space.n, env.action_space.n])
  for state, action in enumerate(argmax_actions):
    policy[state, action] = 1

  return q_table, policy

#### Exercise 1.5.1: Q-Learning
Train Q-learning, then evaluate the learned greedy policy $\pi_\text{greedy}(a \mid s) := \mathbb{1}(a = \arg\max_{a'} Q(s, a'))$ on the environment.

* Describe the policy $\pi_\text{greedy}$ in words -- how does it behave?
* How does $\pi_\text{greedy}$ compare to the uniform policy in Section 1.4.1 in terms of average total reward?

In [0]:
q_table, q_policy = train_qlearning(env)
print('Q(s,a):\n{}\n'.format(q_table))
evaluate(env, q_policy)

## II. Reinforcement Learning as Probabilistic Inference <a name="probabilistic-inference"></a>

In this section,  we formalize reinforcement learning as probabilistic inference, following the tutorial from [Levine 2018](https://arxiv.org/pdf/1805.00909.pdf).

### 2.1: Differences from the Original Tutorial <a name="differences"></a>

In all sections of this Colab tutorial (except Exercise 2.3.1), we do **not** restrict the action prior $p(a_t \mid s_t)$ to be uniform, so the joint optimality-action distribution
$$
\begin{aligned}
p(O_t = 1, a_t \mid s_t)
&= p(O_t = 1 \mid s_t, a_t) p(a_t \mid s_t)\\
&= \exp(r(s_t, a_t)) p(a_t \mid s_t)
\end{aligned}
$$
is not necessarily proportional to $p(O_t = 1 \mid s_t, a_t) := \exp\{ r(s_t, a_t) \} $.  We also define the backward state-action message as
$$
\beta_t(s_t, a_t) := p(O_{t:T}, a_t \mid s_t),
$$
whereas [Levine 2018](https://arxiv.org/abs/1805.00909) defines it as $\beta_t(s_t, a_t) = p(O_{t:T} \mid s_t, a_t)$. This results in slightly different derivations for the backward message update equations and the optimal policy $p(a_t \mid s_t, O_{t:T})$, which you will derive in Section 2.4. Thus, be careful when copying derivations from the original tutorial.

In Section 2.3, you will prove that any non-uniform action prior $p(a_t \mid s_t)$ can be incorporated into the reward function. We will later experiment with the Message-Passing algorithm using different action priors $p(a_t \mid s_t)$ for a fixed reward function $r(s_t, a_t)$, which is equivalent to using a uniform action prior with different reward functions.


### 2.2: Graphical Model for Control <a name="pgm"></a>

We define a graphical model that allows us to embed control into the framework of PGMs.

A task can be defined by a **reward** function $r(s_t, a_t)$. However, a graphical model has no notion of rewards or costs, so we introduce a binary random variable for **optimality** where $O_t = 1$ indicates that time step $t$ is optimal.

![graphical model](https://leelisa.com/blog/assets/2019-02-06-max-entropy-rl/rl-pgm.png)

We choose the distribution over $O_t$ to be

$$
\begin{aligned}
p(O_t = 1 \mid s_t, a_t) &:= \exp(r(s_t, a_t)) \\
p(O_t = 1, a_t \mid s_t)
&= p(O_t = 1 \mid s_t, a_t) p(a_t \mid s_t) \\
&= \exp(r(s_t, a_t)) p(a_t \mid s_t)
\end{aligned}
$$

where $p(a_t \mid s_t)$ is an **action prior**. This leads to a natural posterior distribution over actions when we condition on $O_t = 1$ for all $t \in [T]$:

$$
\begin{aligned}
p(\tau \mid o_{1:T})
\propto p(\tau, o_{1:T})
&= P_0(s_1) \prod_{t=1}^T p(O_t = 1, a_t \mid s_t) \mathcal{T}(s_{t+1} \mid s_t, a_t) \\
&= P_0(s_1) \prod_{t=1}^T \exp(r(s_t, a_t)) p(a_t \mid s_t) \mathcal{T}(s_{t+1} \mid s_t, a_t) \\
&=\underbrace{ \left[ P_0(s_1) \prod_{t=1}^T \mathcal{T}(s_{t+1} \mid s_t, a_t) p(a_t \mid s_t) \right] }_{\text{dynamics}} \exp  \underbrace{ \left( \sum_{t=1}^T r(s_t, a_t)  \right)  }_{\text{total reward}}
\end{aligned}
$$

Observe that, if we know the dynamics, then we can **infer a likely optimal trajectory** using
$$
\tau^* = \underset{\tau}{\arg\max} \; p(\tau, o_{1:T}).
$$

For the remainder of this tutorial, we use $O_{t:T}$ to denote $O_{t:T}=1$ for conciseness.
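
The joint $p(\tau, o_{1:T})$ above can be evaluated for a concrete trajectory by summing the log of each factor. The sketch below does this on a hypothetical 2-state MDP; the helper name `joint_log_prob` and all arrays are assumptions made for illustration only.

```python
import numpy as np

def joint_log_prob(states, actions, P0, action_prior, T, reward):
    """Returns log p(tau, O_{1:T} = 1) for the given trajectory."""
    log_joint = np.log(P0[states[0]])
    for t, action in enumerate(actions):
        state, next_state = states[t], states[t + 1]
        log_joint += reward[state, action]                 # log p(O_t = 1 | s_t, a_t)
        log_joint += np.log(action_prior[state, action])   # log p(a_t | s_t)
        log_joint += np.log(T[state, action, next_state])  # log T(s_{t+1} | s_t, a_t)
    return log_joint

# Hypothetical 2-state MDP: action 0 = stay, action 1 = right (fails w.p. 0.2).
P0 = np.array([1.0, 0.0])
action_prior = np.full((2, 2), 0.5)
T = np.array([[[1.0, 0.0], [0.2, 0.8]],
              [[0.0, 1.0], [0.0, 1.0]]])
reward = np.array([[-1.0, 0.0],   # state 0: 'stay' costs -1, 'right' costs 0
                   [0.0, 0.0]])   # state 1 (goal): all actions cost 0

# Trajectory A: stay once, then move right.  Trajectory B: move right, then stay.
log_a = joint_log_prob([0, 0, 1], [0, 1], P0, action_prior, T, reward)
log_b = joint_log_prob([0, 1, 1], [1, 0], P0, action_prior, T, reward)
print(log_a, log_b)
```

Trajectory B has the higher joint probability here because it avoids the $-1$ reward, which is the sense in which conditioning on optimality favors high-reward trajectories.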

### 2.3: Reward vs. Action Prior <a name="reward-action-prior"></a>

Note that the joint distribution  $p(\tau, o_{1:T})$ has an extra factor $p(a_t \mid s_t)$ compared to Eq. (4) in [Levine 2018](https://arxiv.org/abs/1805.00909), due to the fact that
$$
\begin{equation}
p(O_t = 1, a_t \mid s_t) = p(O_t = 1 \mid s_t, a_t) p(a_t \mid s_t)
\end{equation}.
$$
The omission of this factor is a typo in the original tutorial, but the math ends up being OK (by changing some equalities '$=$' to proportionalities '$\propto$') if we assume a **uniform action prior** $p(a_t \mid s_t) = \frac{1}{|\mathcal{A}|}$, because then $p(O_t = 1, a_t \mid s_t) \propto p(O_t = 1 \mid s_t, a_t)$.

[Levine 2018](https://arxiv.org/abs/1805.00909) assumes a uniform action prior, and argues that this assumption does not introduce any loss of generality. In the following exercise, you will show that any non-uniform action prior $p(a_t \mid s_t)$ can be incorporated into $p(O_t \mid s_t, a_t) := \exp\{ r_1(s_t, a_t) \}$ via the reward function $r_1(s_t, a_t)$.

#### Exercise 2.3.1: Non-Uniform Action Priors

Let $r(s_t, a_t)$ and $p(a_t \mid s_t)$ be any given reward function and action prior, respectively.  Show that there exists some reward function $r_1(s_t, a_t)$ such that the posterior distribution $p(\tau \mid o_{1:T})$ is equal for the following combinations of reward function and action prior:
1. The reward function $r(s_t, a_t)$ and action prior $p(a_t \mid s_t)$.
2. The reward function $r_1(s_t, a_t)$ and a uniform action prior.

Write down the expression for $r_1(s_t, a_t)$ in terms of $r(s_t, a_t)$ and $p(a_t \mid s_t)$.

### 2.4: Message-Passing Derivations <a name="message-passing-derivations"></a>

In this section, we will derive a standard sum-product inference algorithm to **infer the optimal policy** $p(a_t \mid s_t, O_{t:T})$. Define the following **backward messages** for $t \in \{T, \ldots, 1\}$:
$$
\begin{aligned}
\beta_t(s_t, a_t) &:= p(O_{t:T}, a_t \mid s_t) \\
\beta_t(s_t) &:= p(O_{t:T} \mid s_t)
\end{aligned}
$$

#### Exercise 2.4.1: Derivation of $\beta_t(s_t)$ Update

Show that the backward messages satisfy the following update equation for $\beta_t(s_t)$:

$$
\begin{equation}
\beta_t(s_t)
= \sum_{a_t \in \mathcal{A}} \beta_t(s_t, a_t)
\tag{1}
\end{equation}
$$


#### Exercise 2.4.2: Derivation of $\beta_t(s_t, a_t)$ Update

Show that the backward messages satisfy the following update equation for $\beta_t(s_t, a_t)$:
$$
\begin{equation}
\beta_t(s_t, a_t)
= \sum_{s_{t+1} \in \mathcal{S}} \beta_{t+1}(s_{t+1}) \mathcal{T}(s_{t+1} \mid s_t, a_t) p(O_t, a_t \mid s_t)
\tag{2}
\end{equation}
$$

#### Exercise 2.4.3: Derivation of the Optimal Policy

Show that the optimal policy $p(a_t \mid s_t, O_{t:T})$ satisfies
$$
\begin{equation}
p(a_t \mid s_t, O_{t:T})
= \frac{ \beta_t(s_t, a_t) }{ \beta_t(s_t) }
\tag{3}
\end{equation}
$$


### 2.5: Message-Passing Implementation <a name="message-passing-implementation"></a>

In this section, we provide code for running the message-passing algorithm.


#### 2.5.1: Message-Passing Algorithm

Below, we provide an implementation of the recursive message-passing algorithm for computing the backward messages $\beta_t(s_t, a_t)$ and $\beta_t(s_t)$. The algorithm starts from the last time step $t=T$,
$$
\beta_T(s_T, a_T) = p(O_T, a_T \mid s_T),
$$
and proceeds backwards through time to $t=1$, computing the following updates at each time step:

1. Update $\beta_{t}(s_{t}, a_{t})$ using Eq. (2).

2. Update $\beta_t(s_t)$ using Eq. (1).


Then the algorithm infers the optimal policy $p(a \mid s, O_{1:T})$ 
using Eq. (3).
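
Before the log-space implementation, the recursion can be sketched directly in probability space. This is a minimal illustration under stated assumptions: the function name `backward_messages` and the 2-state MDP are hypothetical, and multiplying raw probabilities like this can underflow for long horizons, which is one reason to prefer the log-space version.

```python
import numpy as np

def backward_messages(T, reward, action_prior, horizon):
    """Runs Eqs. (1)-(3) in probability space for the given horizon.

    T[s, a, s'] = T(s'|s, a), reward[s, a] = r(s, a),
    action_prior[s, a] = p(a|s).
    """
    opt_action = np.exp(reward) * action_prior  # p(O_t, a_t | s_t)
    beta_sa = opt_action.copy()                 # beta_T(s, a)
    beta_s = beta_sa.sum(axis=1)                # Eq. (1) at t = T
    for _ in range(horizon - 1):
        # Eq. (2): (sum_{s'} beta_{t+1}(s') T(s'|s, a)) * p(O_t, a_t | s_t)
        beta_sa = (T @ beta_s) * opt_action
        # Eq. (1): beta_t(s) = sum_a beta_t(s, a)
        beta_s = beta_sa.sum(axis=1)
    policy = beta_sa / beta_s[:, None]          # Eq. (3)
    return beta_sa, beta_s, policy

# Hypothetical 2-state MDP: action 0 = stay, action 1 = right (fails w.p. 0.2).
T = np.array([[[1.0, 0.0], [0.2, 0.8]],
              [[0.0, 1.0], [0.0, 1.0]]])
reward = np.array([[-1.0, 0.0], [0.0, 0.0]])
action_prior = np.full((2, 2), 0.5)

beta_sa, beta_s, policy = backward_messages(T, reward, action_prior, horizon=2)
print(policy)
```

Each row of the resulting `policy` sums to 1, and in state 0 it places more mass on 'right' than on 'stay'.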


In [0]:
def log_sum_exp(terms):
    """
    Uses log-sum-exp trick to compute
      log(sum_i exp(terms[i])) = t* + log(sum_i exp(terms[i] - t*))
    where t* = max(terms).
    """
    max_term = np.max(terms)
    diff = np.exp(terms - max_term)
    result = max_term + np.log(np.sum(diff))
    return result
  
class MessagePassing():
  def __init__(self, env, action_prior=None):
    """
    Args:
      env: ChainEnv object
      action_prior: Action prior p(a|s), which is a matrix of shape |S| x |A|.
                    If None, we use a uniform action prior: p(a|s) = 1/|A|.
    """
    self.env = env

    # self.log_action_prior[s, a] = log p(a|s)
    if action_prior is None:
      # Uniform action prior: p(a|s) = 1/|A|.
      self.log_action_prior = np.full(
          (self.env.state_space.n, self.env.action_space.n),
          np.log(1. / self.env.action_space.n))
    else:
      assert (action_prior.shape[0] == self.env.state_space.n and
              action_prior.shape[1] == self.env.action_space.n)
      eps = 1e-10  # avoids log(0) for actions with zero prior probability
      self.log_action_prior = (np.log(action_prior + eps) -
                               np.log(1 + eps * self.env.action_space.n))

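    # self.log_transition_dynamics[s, a, S] = log T(S|s,a)
    # Impossible transitions have log-probability -inf, which is intended, so
    # we silence NumPy's divide-by-zero warning for log(0).
    with np.errstate(divide='ignore'):
      self.log_transition_dynamics = np.log(env.transition_dynamics)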
    
    # self.log_optimality_dist[s, a] = log p(O|s,a) = r(s,a)
    self.log_optimality_dist = env.reward

    # self.log_optimality_action_dist[s, a] = log p(O,a|s)
    #                                       = log p(O|s,a) + log p(a|s).
    self.log_optimality_action_dist = (self.log_optimality_dist +
                                       self.log_action_prior)

    # self.log_state_action_message[s, a] = log beta(s, a).
    # It is initialized to the log optimality distribution, log p(O_T,a_T|s_T).
    self.log_state_action_message = (self.log_optimality_dist +
                                     self.log_action_prior)

    # self.log_state_message[s] = log beta(s).
    self.log_state_message = self._compute_state_message_update()

    # self.policy[s, a] = p(a|s,O) = beta(s,a) / beta(s).
    self.policy = self.compute_policy()

  def _compute_state_message_update(self):
    """
    Computes the state-message update in Eq. (1) in log-space:
      log beta(s) = log(sum_a exp(log beta(s,a))).
    """
    log_state_message = np.zeros(self.env.state_space.n)
    for s in range(self.env.state_space.n):
      log_state_message[s] = log_sum_exp(self.log_state_action_message[s, :])
    return log_state_message

  def _compute_state_action_message_update(self):
    """
    Computes the state-action-message update in Eq. (2) in log-space:
      log beta(s,a) = log(sum_S exp(log beta(S) + log T(S|s,a) + log p(O,a|s)))
    """
    log_state_action_message = np.zeros(
        (self.env.state_space.n, self.env.action_space.n))
    for s in range(self.env.state_space.n):
      for a in range(self.env.action_space.n):
        terms = (self.log_state_message +
                 self.log_transition_dynamics[s, a] +
                 self.log_optimality_action_dist[s, a])
        log_state_action_message[s, a] = log_sum_exp(terms)
    return log_state_action_message

  def update_messages(self):
    """
    Performs a single step of the backward message-passing algorithm.
    """
    # beta(s,a) = \sum_S beta(S) * p(S|s,a) * p(O,a|s)
    self.log_state_action_message = self._compute_state_action_message_update()
    
    # beta(s) = \sum_a beta(s,a)
    self.log_state_message = self._compute_state_message_update()

    return self.log_state_action_message, self.log_state_message

  def compute_policy(self):
    """
    Computes policy using the given backward messages:
      p(a|s,O) = beta(s,a) / beta(s).
    """
    log_policy = np.zeros((self.env.state_space.n, self.env.action_space.n))
    for s in range(self.env.state_space.n):
      log_policy[s] = (self.log_state_action_message[s] -
                       self.log_state_message[s])
    policy = np.exp(log_policy)
    return policy

  def get_log_messages(self):
    """
    Returns the soft value function and soft Q-function:
      V(s)   := log beta(s)
      Q(s,a) := log beta(s,a).
    """
    return self.log_state_message, self.log_state_action_message

#### 2.5.2: Script to run Message-Passing

We provide a function that runs message passing for $T$ steps using the given action prior $p(a_t \mid s_t)$. It prints the provided action prior, and the learned soft value function and soft Q-function:
$$
\begin{aligned}
V(s) &:= \log \beta_1(s) \\
Q(s,a) &:= \log \beta_1(s, a)
\end{aligned}
$$

Finally, it evaluates the policy on ChainEnv and prints the results.
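
Given these definitions, Eq. (1) and Eq. (3) relate $V$, $Q$, and the policy directly in log-space. The sketch below shows that relationship on a made-up $Q$ matrix; the values are illustrative only and `np.logaddexp.reduce` is used here as one convenient way to take a log-sum-exp over actions.

```python
import numpy as np

# Made-up soft Q-values Q[s, a] = log beta(s, a) for 2 states and 2 actions.
Q = np.log(np.array([[0.1, 0.3],
                     [0.2, 0.2]]))

# Eq. (1) in log-space: V(s) = log sum_a exp(Q(s, a)).
V = np.logaddexp.reduce(Q, axis=1)

# Eq. (3) in log-space: p(a | s, O) = exp(Q(s, a) - V(s)).
soft_policy = np.exp(Q - V[:, None])

print(np.exp(V))      # beta(s): approximately [0.4, 0.4]
print(soft_policy)    # approximately [[0.25, 0.75], [0.5, 0.5]]
```

Unlike the greedy Q-learning policy, this policy is stochastic: actions with lower $Q$ still receive nonzero probability.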

In [0]:
def run_message_passing(env, action_prior, T=50):
  mp = MessagePassing(env, action_prior=action_prior)
  for _ in range(T):
    state_action_message, state_messages = mp.update_messages()
  
  # Print the soft value function V(s) and soft Q-function Q(s,a).
  V, Q = mp.get_log_messages()

  # Evaluate the learned policy for num_episodes.
  policy = mp.compute_policy()
  
  print('Action prior p(a|s):\n{}\n'.format(action_prior))
  print('V(s):\n{}\n'.format(V))
  print('Q(s,a):\n{}\n'.format(Q))
  evaluate(env, policy=policy)

### 2.6: Message-Passing Experiments <a name="message-passing-experiments"></a>

In this section, we will experiment with running Message-Passing using different action priors $p(a_t \mid s_t)$.



#### Exercise 2.6.1: Uniform Action Prior

Run the message-passing algorithm using a **uniform** action prior $p(a_t \mid s_t) = \frac{1}{|\mathcal{A}|}$.

* How does the learned policy compare to the Q-Learning policy $\pi_\text{greedy}$ from Exercise 1.5.1 in terms of behavior and average total reward?

In [0]:
# Run message-passing using a uniform action prior.
uniform_policy = np.full((env.state_space.n, env.action_space.n),
                         1. / env.action_space.n)
run_message_passing(env, action_prior=uniform_policy)

#### Exercise 2.6.2: Soft Action Prior

The following function `construct_policy()` returns a policy $\pi(a \mid s; \phi)$ such that, regardless of the state, it takes the 'right' action with probability $\phi$, and each of the other actions with equal probability. That is:
$$
\pi(a \mid s;\phi) = \begin{cases}
\phi & \text{if }a=\rightarrow\\
\frac{1 - \phi}{|\mathcal{A}| - 1}& \text{otherwise}
\end{cases}
$$

Run the message-passing algorithm using a "soft" action prior $\pi(a \mid s;\phi)$ for $\phi = 0.5$.

* How does the learned policy compare to the one from Exercise 2.6.1 (using uniform action prior) in terms of behavior and average total reward?
* (True or False) If $\phi > \frac{1}{|\mathcal{A}|}$, then using the action prior $\pi(a \mid s ; \phi)$ instead of a uniform action prior is equivalent to changing the reward function $r(s_t, a_t)$ such that the agent receives relatively greater reward for taking the action $a_t = \rightarrow$ in any state, and less reward otherwise.

In [0]:
def construct_policy(env, phi=0.5):
  """
  Returns a policy such that, regardless of the state, it takes the 'right'
  action with probability phi, and each other action with equal probability.
  """
  right_action = 2
  policy = np.zeros((env.state_space.n, env.action_space.n))
  for s in range(env.state_space.n):
    policy[s, right_action] = phi
    for a in range(env.action_space.n):
      if a != right_action:
        policy[s, a] = (1. - phi) / (env.action_space.n - 1)
  return policy

# A "soft" policy that chooses 'right' with probability 0.5.
pi = construct_policy(env, phi=0.5)
run_message_passing(env, action_prior=pi)

#### Exercise 2.6.3: Hard Action Prior

Run the message-passing algorithm using a "hard" action prior $\pi(a \mid s;\phi)$ for $\phi = 1.0$.

* How does the learned policy compare to the Q-Learning policy $\pi_\text{greedy}$ from Exercise 1.5.1 in terms of behavior and average total reward?

In [0]:
# A "hard" policy that chooses 'right' with probability 1.0.
pi = construct_policy(env, phi=1.0)
run_message_passing(env, action_prior=pi)

## III. Concluding Remarks <a name="concluding-remarks"></a>

### 3.1: Q-Learning vs. Message-Passing <a name="q-learning-vs-message-passing"></a>

In this section, we will address how Q-Learning and Message-Passing differ, and when one might be preferred over the other.

#### Exercise 3.1.1: Unknown Transition Dynamics

Suppose we don't know the transition dynamics $\mathcal{T}(s_{t+1} \mid s_t, a_t)$.

1. Can you learn the optimal policy via Q-learning?

2. Can you learn the optimal policy via Message-Passing?


#### Exercise 3.1.2: Equivalence

Is it the case that the optimal message-passing policy can be equivalent to the one discovered by Q-learning? If yes, under which conditions? If no, why not?
