# Notebook on Markov Decision Processes (MDPs)
This notebook explains a discrete MDP defined by a tuple $(\mathcal{S}, \mathcal{A}, p, \gamma)$. Here $\mathcal{S}$ is a finite set of states and $\mathcal{A}$ is a finite set of actions. $p(s', r \mid s, a)$ is the joint probability of the next state $s'$ and reward $r$ given the current state $s$ and action $a$, and $\gamma \in [0, 1]$ is a discount factor.
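To make $p$ concrete, a finite MDP can store, for each state-action pair, a list of (probability, next state, reward) outcomes. Two quantities follow from that list: the next-state marginal $p(s' \mid s, a) = \sum_r p(s', r \mid s, a)$ and the expected reward $r(s, a) = \sum_{s', r} r \, p(s', r \mid s, a)$. The minimal sketch below uses the outcomes of state $s_1$ and action $a_0$ from the two-stage task used later in this notebook. The helper names are illustrative and do not come from any library.

```python
from collections import defaultdict

# Joint distribution p(s', r | s, a) for ONE state-action pair, stored as
# (probability, next_state, reward) outcomes. The numbers are those of
# state s1 and action a0 in the two-stage task used in this notebook.
outcomes = [(0.7, 5, 10.0), (0.3, 5, 0.0)]


def is_normalized(outcomes, tol=1e-9):
    """Probabilities over all (s', r) pairs must sum to one."""
    return abs(sum(p for p, _, _ in outcomes) - 1.0) < tol


def next_state_marginal(outcomes):
    """p(s' | s, a) = sum over r of p(s', r | s, a)."""
    marginal = defaultdict(float)
    for p, s_next, _ in outcomes:
        marginal[s_next] += p
    return dict(marginal)


def expected_reward(outcomes):
    """r(s, a) = sum over (s', r) of r * p(s', r | s, a)."""
    return sum(p * r for p, _, r in outcomes)


print(is_normalized(outcomes))        # True
print(next_state_marginal(outcomes))  # {5: 1.0}
print(expected_reward(outcomes))      # 7.0
```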

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from gym import Env, spaces
from ipywidgets import interact, fixed
%matplotlib inline

# Two-Stage Task
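This notebook uses the two-stage MDP task of
- J. Gläscher, N. Daw, P. Dayan, and J.P. O’Doherty. (2010). [States versus Rewards: Dissociable Neural Prediction Error Signals Underlying Model-Based and Model-Free Reinforcement Learning](https://doi.org/10.1016/j.neuron.2010.04.016). Neuron, vol. 66, no. 4: 585–95.

In the transition table `P` below, $s_0$ is the first-stage state. Action $a_0$ leads to $s_1$ with probability 0.7 and to $s_2$ with probability 0.3. Action $a_1$ leads to $s_3$ or $s_4$ with the same probabilities. In each second-stage state $s_1, \dots, s_4$, both actions end the episode in the terminal state $s_5$ and pay a stochastic reward of 0, 10, or 25.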


In [None]:
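class TwoStageMDPEnv(Env):
    """Two-stage MDP task.

    State 0 is the first stage, states 1-4 form the second stage, and
    state 5 is an absorbing terminal state. P[s][a] is a list of
    (probability, next_state, reward, done) tuples.
    """
    def __init__(self):
        self.nA = 2     # number of actions
        self.nS = 6     # number of states
        self.observation_space = spaces.Discrete(self.nS)
        self.action_space = spaces.Discrete(self.nA)

        # Set the state transitions and rewards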
        self.P = {s: {a: [] for a in range(self.nA)} for s in range(self.nS)}
        self.P[0][0] = [(0.7, 1, 0, False), (0.3, 2, 0, False)]
        self.P[0][1] = [(0.7, 3, 0, False), (0.3, 4, 0, False)]
        self.P[1][0] = [(0.7, 5, 10, True), (0.3, 5, 0, True)]
        self.P[1][1] = [(0.7, 5, 0, True), (0.3, 5, 10, True)]
        self.P[2][0] = [(0.7, 5, 0, True), (0.3, 5, 10, True)]
        self.P[2][1] = [(0.7, 5, 0, True), (0.3, 5, 25, True)]
        self.P[3][0] = [(0.7, 5, 25, True), (0.3, 5, 0, True)]
        self.P[3][1] = [(0.7, 5, 10, True), (0.3, 5, 0, True)]
        self.P[4][0] = [(0.7, 5, 0, True), (0.3, 5, 10, True)]
        self.P[4][1] = [(0.7, 5, 0, True), (0.3, 5, 25, True)]
        self.P[5][0] = [(1.0, 5, 0, True)]
        self.P[5][1] = [(1.0, 5, 0, True)]
        
        self.start_state_index = 0
        self.initial_state_distrib = np.zeros(self.nS)
        self.initial_state_distrib[self.start_state_index] = 1.0
    
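    # reset, step, and render are left as stubs: this notebook only
    # reads the transition table self.P and never simulates the task.
    def reset(self):
        pass
    
    def step(self, action):
        pass
    
    def render(self, mode='console', close=False):
        pass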
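
The class leaves `reset` and `step` as stubs because the notebook only needs `P`. To simulate the task anyway, one option is to sample outcomes directly from `P`. The minimal sketch below does this outside gym's class interface. `sample_step` and `run_episode` are hypothetical helper names, not gym functions. The table is copied from `TwoStageMDPEnv`, and returns are undiscounted ($\gamma = 1$). Averaging many simulated returns under the uniform policy gives a Monte Carlo estimate of $V^\pi(s_0)$. Applying the Bellman equations of the next section to the same policy gives exactly 7.6125.

```python
import numpy as np

# Transition table copied from TwoStageMDPEnv.P:
# P[s][a] = [(probability, next_state, reward, done), ...]
P = {
    0: {0: [(0.7, 1, 0, False), (0.3, 2, 0, False)],
        1: [(0.7, 3, 0, False), (0.3, 4, 0, False)]},
    1: {0: [(0.7, 5, 10, True), (0.3, 5, 0, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 10, True)]},
    2: {0: [(0.7, 5, 0, True), (0.3, 5, 10, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 25, True)]},
    3: {0: [(0.7, 5, 25, True), (0.3, 5, 0, True)],
        1: [(0.7, 5, 10, True), (0.3, 5, 0, True)]},
    4: {0: [(0.7, 5, 0, True), (0.3, 5, 10, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 25, True)]},
    5: {0: [(1.0, 5, 0, True)], 1: [(1.0, 5, 0, True)]},
}


def sample_step(P, state, action, rng):
    """Sample one (next_state, reward, done) outcome of P[state][action]."""
    outcomes = P[state][action]
    i = rng.choice(len(outcomes), p=[o[0] for o in outcomes])
    _, next_state, reward, done = outcomes[i]
    return next_state, reward, done


def run_episode(P, p_a0, rng, start_state=0):
    """Play one episode and return its undiscounted return.

    p_a0[s] is the probability of choosing a0 in state s.
    """
    state, total, done = start_state, 0.0, False
    while not done:
        action = 0 if rng.random() < p_a0[state] else 1
        state, reward, done = sample_step(P, state, action, rng)
        total += reward
    return total


rng = np.random.default_rng(0)
returns = [run_episode(P, [0.5] * 5, rng) for _ in range(20000)]
mc_estimate = np.mean(returns)
print(mc_estimate)  # close to 7.6125, the exact V(s0) under the uniform policy
```

Sampling is noisy, so the estimate only approaches the exact value as the number of episodes grows. The computation in the next section avoids this noise by using the probabilities in `P` directly.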

# Compute the state-values and action-values
When $s$ is a second-stage state ($s_1, \dots, s_4$), every action ends the episode in the terminal state $s_5$, whose value is zero. The action-values and the state-values are therefore computed by
\begin{align}
  Q^\pi (s, a) &= \mathbb{E}_{r \sim p(r \mid s, a)} [r], \\
  V^\pi (s) &= \mathbb{E}_{a \sim \pi(a \mid s)} [ Q^\pi (s, a) ],
\end{align}
where $\pi$ is a policy and $p(r \mid s, a) = \sum_{s'} p(s', r \mid s, a)$ is the marginal distribution of the reward.
When $s = s_0$ (the first stage), they are given by
\begin{align}
  Q^\pi (s, a) &= \mathbb{E}_{(s', r) \sim p(s', r \mid s, a)} [r + \gamma V^\pi(s')], \\
  V^\pi (s) &= \mathbb{E}_{a \sim \pi(a \mid s)} [ Q^\pi (s, a) ].
\end{align}
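As a worked example, take the uniform policy ($\pi(a_0 \mid s) = \pi(a_1 \mid s) = 0.5$ in every state) and $\gamma = 1$, the value used in `compute_value`. The rewards and probabilities are read off the transition table `P`, and every first-stage reward is 0:
\begin{align}
  Q^\pi(s_1, a_0) &= 0.7 \times 10 + 0.3 \times 0 = 7, \qquad Q^\pi(s_1, a_1) = 0.7 \times 0 + 0.3 \times 10 = 3, \\
  V^\pi(s_1) &= 0.5 \times 7 + 0.5 \times 3 = 5, \\
  V^\pi(s_2) &= V^\pi(s_4) = 0.5 \times 3 + 0.5 \times 7.5 = 5.25, \qquad V^\pi(s_3) = 0.5 \times 17.5 + 0.5 \times 7 = 12.25, \\
  Q^\pi(s_0, a_0) &= 0.7 \times (0 + V^\pi(s_1)) + 0.3 \times (0 + V^\pi(s_2)) = 5.075, \\
  Q^\pi(s_0, a_1) &= 0.7 \times (0 + V^\pi(s_3)) + 0.3 \times (0 + V^\pi(s_4)) = 10.15, \\
  V^\pi(s_0) &= 0.5 \times 5.075 + 0.5 \times 10.15 = 7.6125.
\end{align}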
The following function computes the state- and action-values of the two-stage MDP task. 

In [None]:
def compute_value(env, p0=0.5, p1=0.5, p2=0.5, p3=0.5, p4=0.5):
    """Compute the state- and action-values of the two-stage MDP task.
    
    Parameters
    ----------
    env : gym.Env
        Two-stage MDP environment.
    p0, ..., p4 : float
        Probability of selecting a0 in state s0, ..., s4.
    
    Returns
    -------
    V : ndarray
        State-value function (1D array of length nS).
    Q : ndarray
        Action-value function (2D array of shape nS x nA).
    """
    gamma = 1.0  # discount factor
    
    # Stochastic policy (nS x nA array): policy[s, a] = pi(a | s).
    # The terminal state s5 gets a uniform policy; its values are zero anyway.
    # (np.full instead of np.empty, so no uninitialized entries remain.)
    policy = np.full((env.nS, env.nA), 0.5)
    policy[0:5, 0] = [p0, p1, p2, p3, p4]
    policy[0:5, 1] = 1.0 - policy[0:5, 0]
    
    Q = np.zeros((env.nS, env.nA))
    V = np.zeros(env.nS)

    # Second stage (s1, ..., s4) and terminal state s5: every transition
    # ends the episode, so Q(s, a) is the expected immediate reward.
    # See the lecture slide.
    for state in range(1, env.nS):
        for action in range(env.nA):
            for prob, next_state, reward, done in env.P[state][action]:
                Q[state, action] += prob * reward
            V[state] += policy[state, action] * Q[state, action]

    # First stage (s0): Q(s0, a) = E[r + gamma * V(s')].
    # See the lecture slide.
    for action in range(env.nA):
        for prob, next_state, reward, done in env.P[0][action]:
            Q[0, action] += prob * (reward + gamma * V[next_state])
        V[0] += policy[0, action] * Q[0, action]

    return V, Q
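
The same computation can be written in matrix form, which carries over to MDPs with more stages. The minimal sketch below converts a table in the format of `TwoStageMDPEnv.P` into an expected-reward array `R[s, a]` and a continuation array `T[s, a, s']`. Transitions flagged `done=True` are treated as having no future value, which is what the second-stage formula assumes. The code then runs iterative policy evaluation, repeating $V(s) \leftarrow \sum_a \pi(a \mid s) \, [R(s, a) + \gamma \sum_{s'} T(s, a, s') V(s')]$ until $V$ stops changing. `table_to_arrays` and `evaluate_policy` are hypothetical helper names.

```python
import numpy as np

# Transition table copied from TwoStageMDPEnv.P:
# P[s][a] = [(probability, next_state, reward, done), ...]
P = {
    0: {0: [(0.7, 1, 0, False), (0.3, 2, 0, False)],
        1: [(0.7, 3, 0, False), (0.3, 4, 0, False)]},
    1: {0: [(0.7, 5, 10, True), (0.3, 5, 0, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 10, True)]},
    2: {0: [(0.7, 5, 0, True), (0.3, 5, 10, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 25, True)]},
    3: {0: [(0.7, 5, 25, True), (0.3, 5, 0, True)],
        1: [(0.7, 5, 10, True), (0.3, 5, 0, True)]},
    4: {0: [(0.7, 5, 0, True), (0.3, 5, 10, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 25, True)]},
    5: {0: [(1.0, 5, 0, True)], 1: [(1.0, 5, 0, True)]},
}


def table_to_arrays(P, nS, nA):
    """Return R[s, a] (expected reward) and T[s, a, s'] (continuation probability)."""
    R = np.zeros((nS, nA))
    T = np.zeros((nS, nA, nS))
    for s in range(nS):
        for a in range(nA):
            for prob, s_next, reward, done in P[s][a]:
                R[s, a] += prob * reward
                if not done:  # a terminating transition contributes no future value
                    T[s, a, s_next] += prob
    return R, T


def evaluate_policy(R, T, policy, gamma=1.0, tol=1e-10, max_sweeps=1000):
    """Iterative policy evaluation; policy[s, a] = pi(a | s)."""
    V = np.zeros(R.shape[0])
    for _ in range(max_sweeps):
        Q = R + gamma * (T @ V)           # Q[s, a] = R[s, a] + gamma * sum_s' T[s, a, s'] V[s']
        V_new = (policy * Q).sum(axis=1)  # V[s] = sum_a pi(a | s) Q[s, a]
        if np.max(np.abs(V_new - V)) < tol:
            return V_new, Q
        V = V_new
    return V, Q


R, T = table_to_arrays(P, nS=6, nA=2)
V, Q = evaluate_policy(R, T, np.full((6, 2), 0.5))  # uniform policy
print(V)  # approximately 7.6125, 5, 5.25, 12.25, 5.25, 0 for s0, ..., s5
```

With $\gamma = 1$ the loop terminates here because every episode ends after two steps. For tasks with cycles and $\gamma = 1$, iterative policy evaluation is not guaranteed to converge.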

# Results
Here, p0 is the probability of selecting action $a_0$ in state $s_0$, so the probability of selecting $a_1$ in $s_0$ is $1 -$ p0. Similarly, p1, ..., p4 are the probabilities of selecting $a_0$ in $s_1, \dots, s_4$.
The first array shows the state-value function $V^\pi(s)$ for $s_0, \dots, s_5$. The second, 2-D array shows the action-value function $Q^\pi(s, a)$, with one row per state and one column per action.

In [None]:
env = TwoStageMDPEnv()
interact(compute_value, env=fixed(env), p0=(0, 1, 0.1), p1=(0, 1, 0.1), p2=(0, 1, 0.1), p3=(0, 1, 0.1), p4=(0, 1, 0.1));

# Find a policy
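Next, please find a policy that maximizes the state-value function $V^\pi(s)$ for all states. For comparison, the value functions of three reference policies are plotted together with that of your policy:
- Uniform policy: select $a_0$ with probability 0.5 in every state.
- Always left: select $a_0$ with probability 1.0 in every state.
- Optimal policy: select an optimal action in every state (p0, ..., p4 = 0, 1, 0, 1, 0). We will study how to find an optimal policy in the next lecture.
- User-specified policy: select $a_0$ in $s_0, \dots, s_4$ with the probabilities p0, ..., p4 set by the sliders.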
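
For this two-stage task, the "optimal" setting used in the plot can be reproduced by backward induction. First solve the second-stage states by picking the action with the larger $Q$. Then evaluate $s_0$ using those optimal second-stage values. The minimal sketch below assumes $\gamma = 1$ (as in `compute_value`) and a deterministic greedy choice, with ties going to $a_0$. `backward_induction` is a hypothetical helper name. The state ordering `[1, 2, 3, 4, 0]` is specific to this task's layout, where every second-stage transition is terminal.

```python
# Transition table copied from TwoStageMDPEnv.P:
# P[s][a] = [(probability, next_state, reward, done), ...]
P = {
    0: {0: [(0.7, 1, 0, False), (0.3, 2, 0, False)],
        1: [(0.7, 3, 0, False), (0.3, 4, 0, False)]},
    1: {0: [(0.7, 5, 10, True), (0.3, 5, 0, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 10, True)]},
    2: {0: [(0.7, 5, 0, True), (0.3, 5, 10, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 25, True)]},
    3: {0: [(0.7, 5, 25, True), (0.3, 5, 0, True)],
        1: [(0.7, 5, 10, True), (0.3, 5, 0, True)]},
    4: {0: [(0.7, 5, 0, True), (0.3, 5, 10, True)],
        1: [(0.7, 5, 0, True), (0.3, 5, 25, True)]},
    5: {0: [(1.0, 5, 0, True)], 1: [(1.0, 5, 0, True)]},
}


def backward_induction(P, gamma=1.0):
    """Greedy backward induction for the two-stage task.

    Second-stage states are solved before s0, so their optimal values
    are available when the first-stage Q-values are computed.
    """
    V = {5: 0.0}   # terminal state
    p_a0 = {}      # probability of choosing a0 under the greedy policy
    for s in [1, 2, 3, 4, 0]:
        q = []
        for a in (0, 1):
            # Terminal transitions (done=True) add no future value.
            q.append(sum(prob * (r + (0.0 if done else gamma * V[s_next]))
                         for prob, s_next, r, done in P[s][a]))
        best = max(range(2), key=lambda a: q[a])  # ties go to a0
        V[s] = q[best]
        p_a0[s] = 1.0 if best == 0 else 0.0
    return V, p_a0


V_star, p_star = backward_induction(P)
print([p_star[s] for s in range(5)])  # [0.0, 1.0, 0.0, 1.0, 0.0]
print(V_star[0])                      # 14.5
```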

In [None]:
def plotV(p0=0.5, p1=0.5, p2=0.5, p3=0.5, p4=0.5):
    """Plot state-value functions.
    
    Parameters
    ----------
    p0, ..., p4 : float
        Probability of selecting a0 in state s0, ..., s4 (user-specified policy).
    
    """    
    # Uniform policy
    Vuniform, _ = compute_value(env, 0.5, 0.5, 0.5, 0.5, 0.5)
    # Always left
    Vleft, _ = compute_value(env, 1.0, 1.0, 1.0, 1.0, 1.0)
    # Optimal policy
    Voptimal, _ = compute_value(env, 0.0, 1.0, 0.0, 1.0, 0.0)
    # User-specified policy
    Vuser, _ = compute_value(env, p0, p1, p2, p3, p4)
    
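    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(Vuniform, marker='o', label='uniform')
    ax.plot(Vleft, marker='o', label='always left')
    ax.plot(Voptimal, marker='o', label='optimal')
    ax.plot(Vuser, marker='o', label='user')

    ax.set_xticks(range(env.nS))  # one tick per discrete state s0, ..., s5
    ax.set_xlabel('state')
    ax.set_ylabel('V(s)')
    ax.legend()
    plt.show()  # render the figure inside the interact output area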

interact(plotV, p0=(0, 1, 0.1), p1=(0, 1, 0.1), p2=(0, 1, 0.1), p3=(0, 1, 0.1), p4=(0, 1, 0.1));