In [40]:
import numpy as np

# Create environment

In [41]:
# Total number of states in the environment, terminal states included.
num_states = 16
# How many of those states are terminal.
num_terminal_states = 2
# Remaining states; this count sizes the first axis of the transition,
# reward, Q(s, a) and policy arrays built in the cells below.
num_non_terminal_states = num_states - num_terminal_states

In [42]:
# Size of the action axis used for the transition, reward, Q(s, a) and policy
# arrays built in the cells below.
max_num_actions = 4

In [43]:
# One entry per non-terminal state, each holding the same action count.
num_actions_per_non_terminal_state = np.full(
    shape=num_non_terminal_states, fill_value=max_num_actions)

In [44]:
# Flat array with one entry per (state, action) pair, every entry equal to 1.
num_state_action_successor_states = np.full(
    shape=num_states * max_num_actions, fill_value=1)

In [45]:
# Arrange the flat counts as a (state, action) table. The target shape is
# passed positionally because the `newshape=` keyword of np.reshape is
# deprecated in recent NumPy releases.
num_state_action_successor_states = np.reshape(
    num_state_action_successor_states,
    (num_states, max_num_actions))

In [46]:
# Successor-state index table, written one row per non-terminal state with one
# column per action, then flattened to 1-D (it is reshaped in a later cell).
# Entries 14 and 15 are the two state indices beyond the non-terminal range.
sp_idx = np.array(
    [
        [1, 0, 14, 4],     # state 0
        [2, 1, 0, 5],      # state 1
        [2, 2, 1, 6],      # state 2
        [4, 14, 3, 7],     # state 3
        [5, 0, 3, 8],      # state 4
        [6, 1, 4, 9],      # state 5
        [6, 2, 5, 10],     # state 6
        [8, 3, 7, 11],     # state 7
        [9, 4, 7, 12],     # state 8
        [10, 5, 8, 13],    # state 9
        [10, 6, 9, 15],    # state 10
        [12, 7, 11, 11],   # state 11
        [13, 8, 11, 12],   # state 12
        [15, 9, 12, 13],   # state 13
    ],
    dtype=np.int64).flatten()

In [47]:
# Flat transition-probability array: every (state, action, successor) entry
# is 1.0. The trailing `* 1` is the number of successors per state-action.
p = np.ones(
    shape=num_non_terminal_states * max_num_actions * 1, dtype=np.float64)

In [48]:
# Flat reward array: every (state, action, successor) entry is -1.0.
# The trailing `* 1` is the number of successors per state-action.
r = np.full(
    shape=num_non_terminal_states * max_num_actions * 1, fill_value=-1.0)

In [49]:
# Give the flat successor, probability and reward arrays their final
# (state, action, successor) layout. Shapes are passed positionally because
# the `newshape=` keyword of np.reshape is deprecated in recent NumPy releases.
sp_idx = np.reshape(
    sp_idx,
    (num_non_terminal_states, max_num_actions, 1))
p = np.reshape(
    p,
    (num_non_terminal_states, max_num_actions, 1))
r = np.reshape(
    r,
    (num_non_terminal_states, max_num_actions, 1))

# Create value function and policy arrays

In [50]:
# V(s): one zero-initialised entry per state (terminal states included).
v = np.full(shape=num_states, fill_value=0.0, dtype=np.float64)
# Q(s, a): one zero-initialised row per non-terminal state, one column per
# action.
q = np.full(
    shape=(num_non_terminal_states, max_num_actions),
    fill_value=0.0,
    dtype=np.float64)

In [51]:
# Flat uniform policy: every (state, action) entry gets probability
# 1 / max_num_actions.
policy = np.full(
    shape=num_non_terminal_states * max_num_actions,
    fill_value=1.0 / max_num_actions)

In [52]:
# Arrange the flat policy as a (state, action) table. The target shape is
# passed positionally because the `newshape=` keyword of np.reshape is
# deprecated in recent NumPy releases.
policy = np.reshape(
    policy,
    (num_non_terminal_states, max_num_actions))

# Set hyperparameters

In [53]:
# Discount factor multiplied into successor-state values.
gamma = 1.0
# Threshold the value estimation loop compares against the largest change in
# V(s).
convergence_threshold = 0.001
# Intended cap on the number of value estimation sweeps (same name as the
# `maximum_num_value_estimations` parameter of `value_estimation`).
maximum_num_value_estimations = 20

# Create algorithm

In [54]:
# This function estimates the value functions
def value_estimation(
    num_non_terminal_states,
    sp_idx,
    p,
    r,
    convergence_threshold,
    gamma,
    maximum_num_value_estimations,
    v,
    q):
    """Estimates V(s) and Q(s, a) with in-place greedy backup sweeps.

    Each sweep visits every non-terminal state in index order, recomputes
    Q(s, a) from the current V, and sets V(s) to the maximum over actions.
    Sweeps repeat until the largest change in V(s) within a single sweep
    drops below `convergence_threshold` or the sweep cap is reached.

    Args:
        num_non_terminal_states: int, number of non terminal states.
        sp_idx: array[int] indexed as [state, action, successor], state
            indices of new state s' of taking action a from state s.
        p: array[float], same layout as `sp_idx`, transition probability to
            go from state s to s' by taking action a.
        r: array[float], same layout as `sp_idx`, reward from new state s'
            from state s by taking action a.
        convergence_threshold: float, sweeping stops once the largest
            absolute change in V(s) during one sweep is below this value.
        gamma: float, 0 <= gamma <= 1, amount to discount future r.
        maximum_num_value_estimations: int, max number of sweeps.
        v: array[float], estimated value of each state V(s). Updated in
            place.
        q: array[float], estimated value of each state-action pair Q(s, a).
            Updated in place.
    Returns:
        v: array, estimate of state value function V(s) (same object as the
            `v` argument).
        q: array, estimate of state-action value function Q(s, a) (same
            object as the `q` argument).
    """
    delta = np.finfo(np.float64).max
    num_value_estimations = 0

    while (delta >= convergence_threshold and
           num_value_estimations < maximum_num_value_estimations):
        # Track the largest change of THIS sweep only. Carrying the maximum
        # over from earlier sweeps (or from the initial sentinel) would keep
        # delta from ever shrinking, so the convergence test could never pass.
        delta = 0.0

        for i in range(0, num_non_terminal_states):
            # Cache state-value function for state i
            temp_v = v[i]

            # Expected return of every action: sum over the successor axis
            # of p * (r + gamma * V(s')). v[i] has not been overwritten yet,
            # so self-transitions read the cached value automatically.
            q[i, :] = np.sum(
                p[i, :, :] * (r[i, :, :] + gamma * v[sp_idx[i, :, :]]),
                axis=1)

            # Update state-value function
            v[i] = np.max(q[i, :])

            # Update delta for convergence criteria to break while loop
            delta = max(delta, abs(temp_v - v[i]))

        num_value_estimations += 1

    return v, q

In [55]:
# This function greedily selects the policy based on the current value function
def greedy_policy_selection(
    sp_idx,
    p,
    r,
    policy,
    gamma,
    v):
    """Makes the policy greedy with respect to the state value function.

    For every row of `policy`, the one-step lookahead value of each action is
    computed from `v`; the actions that attain the maximum share probability
    equally and all other actions get probability zero.

    Args:
        sp_idx: array[int] indexed as [state, action, successor], state
            indices of new state s' of taking action a from state s.
        p: array[float], same layout as `sp_idx`, transition probability to
            go from state s to s' by taking action a.
        r: array[float], same layout as `sp_idx`, reward from new state s'
            from state s by taking action a.
        policy: array[float] indexed as [state, action], stochastic policy of
            which action a to take in state s. Overwritten in place.
        gamma: float, 0 <= gamma <= 1, amount to discount future r.
        v: array[float], estimated value of each state V(s).
    Returns:
        array, the updated stochastic policy (same object as the `policy`
        argument).
    """
    # Take the state count from the policy table itself so the function does
    # not depend on a notebook-level global.
    for i in range(0, policy.shape[0]):
        # One-step lookahead value of every action: sum over the successor
        # axis of p * (r + gamma * V(s'))
        action_values = np.sum(
            p[i, :, :] * (r[i, :, :] + gamma * v[sp_idx[i, :, :]]),
            axis=1)

        # Flag the actions that attain the maximum value (exact comparison)
        is_greedy = action_values == np.max(action_values)

        # Apportion policy probability across ties equally for state-action
        # pairs that have the same value and zero otherwise
        policy[i, :] = np.where(
            is_greedy, 1.0 / np.count_nonzero(is_greedy), 0.0)

    return policy

In [56]:
def value_iteration(
    num_non_terminal_states,
    sp_idx,
    p,
    r,
    policy,
    convergence_threshold,
    gamma,
    maximum_num_value_iterations,
    v,
    q):
    """Runs value estimation, then one greedy policy selection step.

    Args:
        num_non_terminal_states: int, number of non terminal states.
        sp_idx: array[int], state indices of new state s' of taking action a
            from state s.
        p: array[float], transition probability to go from state s to s' by
            taking action a.
        r: array[float], reward from new state s' from state s by taking
            action a.
        policy: array[float], stochastic policy of which action a to take in
            state s.
        convergence_threshold: float, threshold on the change in the value
            function, forwarded to `value_estimation`.
        gamma: float, 0 <= gamma <= 1, amount to discount future r.
        maximum_num_value_iterations: int, max number of iterations,
            forwarded to `value_estimation` as its
            `maximum_num_value_estimations` argument.
        v: array[float], estimated value of each state V(s).
        q: array[float], estimated value of each state-action pair Q(s, a).
    Returns:
        v: array, estimate of state value function V(s).
        q: array, estimate of state-action value function Q(s, a).
        policy: array, stochastic policy of which action a to take in
            state s, as returned by `greedy_policy_selection`.
    """
    # Step 1: estimate V(s) and Q(s, a)
    v, q = value_estimation(
        num_non_terminal_states=num_non_terminal_states,
        sp_idx=sp_idx,
        p=p,
        r=r,
        convergence_threshold=convergence_threshold,
        gamma=gamma,
        maximum_num_value_estimations=maximum_num_value_iterations,
        v=v,
        q=q)

    # Step 2: pick the greedy policy for the estimated V(s)
    greedy_policy = greedy_policy_selection(
        sp_idx=sp_idx,
        p=p,
        r=r,
        policy=policy,
        gamma=gamma,
        v=v)

    return v, q, greedy_policy

# Run algorithm

In [57]:
print("\nInitial state value function")
print(v)

print("\nInitial state-action value function")
print(q)

print("\nInitial policy")
print(policy)

# Run value iteration
# The hyperparameter cell defines the iteration cap under the name
# `maximum_num_value_estimations`; no `maximum_num_value_iterations` variable
# exists in this notebook, so that name would raise NameError on a fresh
# kernel. The cap is passed in the `maximum_num_value_iterations` position.
v, q, policy = value_iteration(
    num_non_terminal_states,
    sp_idx,
    p,
    r,
    policy,
    convergence_threshold,
    gamma,
    maximum_num_value_estimations,
    v,
    q)

# Print final results
print("\nFinal state value function")
print(v)
print("\nFinal state-action value function")
print(q)
print("\nFinal policy")
print(policy)


Initial state value function
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

Initial state-action value function
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]

Initial policy
[[0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]]

Final state value function
[-1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.  0.]

Final state-action value function
[[-3. -2. -1. -3.]
 [-4. -3. -2. -4.]
 [-4. -4. -3. -3.]
 [-3. -1. -2. -3.]
 [-4. -2. -2. -4.]
 [-3. -3. -3. -3.]
 [-3. -4. -4. -2.]
 [-4. -2. -3. -4.]
 [-3. -3. -3. -3.]
 [-2. -4. -4. -2.]
 [-2. -3. -3. 