In [168]:
import numpy as np
import tensorflow as tf

# Create environment

In [41]:
# Sizes of the state space: the total number of states and how many of them
# are terminal. The non-terminal count is derived so the three stay consistent.
num_states, num_terminal_states = 16, 2
num_non_terminal_states = num_states - num_terminal_states

In [42]:
# Number of actions per state; sizes the action axis of every
# state-action array built below.
max_num_actions = 4

In [43]:
# 1-D integer array with one entry per non-terminal state, every entry equal
# to max_num_actions (all non-terminal states offer the same action count).
num_actions_per_non_terminal_state = np.full(
    num_non_terminal_states, max_num_actions, dtype=int)

In [44]:
# Flat integer array of ones, one entry per (state, action) pair: each pair
# has exactly one successor state. Reshaped to 2-D in the next cell.
num_state_action_successor_states = np.full(
    num_states * max_num_actions, 1, dtype=int)

In [45]:
# View the flat successor counts as a (state, action) table.
num_state_action_successor_states = (
    num_state_action_successor_states.reshape(
        (num_states, max_num_actions)))

In [46]:
# Successor-state index table, flattened to 1-D (it is reshaped later).
# Row i lists, for non-terminal state i, the state reached by each of the
# four actions. An entry equal to the row's own index means the action leaves
# the state unchanged. Indices 14 and 15 are the two terminal states.
sp_idx = np.array(
    [[1, 0, 14, 4],      # state 0
     [2, 1, 0, 5],       # state 1
     [2, 2, 1, 6],       # state 2
     [4, 14, 3, 7],      # state 3
     [5, 0, 3, 8],       # state 4
     [6, 1, 4, 9],       # state 5
     [6, 2, 5, 10],      # state 6
     [8, 3, 7, 11],      # state 7
     [9, 4, 7, 12],      # state 8
     [10, 5, 8, 13],     # state 9
     [10, 6, 9, 15],     # state 10
     [12, 7, 11, 11],    # state 11
     [13, 8, 11, 12],    # state 12
     [15, 9, 12, 13]],   # state 13
    dtype=np.int64).ravel()

In [47]:
# Transition probabilities, flat: every (state, action) pair has a single
# successor that is reached with probability 1.0. The trailing "* 1" is the
# number of successor states per pair.
p = np.full(
    num_non_terminal_states * max_num_actions * 1, 1.0, dtype=float)

In [48]:
# Rewards, flat: every transition out of a non-terminal state yields -1.0.
# The trailing "* 1" is the number of successor states per pair.
r = np.full(
    num_non_terminal_states * max_num_actions * 1, -1.0, dtype=float)

In [49]:
# Give the three flat environment arrays a common
# (non-terminal state, action, successor) layout with one successor per pair.
sp_idx = sp_idx.reshape((num_non_terminal_states, max_num_actions, 1))
p = p.reshape((num_non_terminal_states, max_num_actions, 1))
r = r.reshape((num_non_terminal_states, max_num_actions, 1))

# Set hyperparameters

In [179]:
# Discount factor multiplied onto successor state values in the backups.
gamma = 1.0
# The value-estimation loop keeps iterating while delta >= this threshold.
convergence_threshold = 0.001
# Upper bound on the number of value-estimation sweeps.
maximum_num_value_estimations = 20

# Create algorithm

In [180]:
# This function estimates the value functions
def value_estimation(
    num_non_terminal_states,
    sp_idx_tensor,
    p,
    r_tensor,
    convergence_threshold,
    gamma,
    maximum_num_value_estimations,
    v_tensor,
    q_tensor):
    """Run synchronous value-iteration sweeps inside a `tf.while_loop`.

    Each sweep recomputes, for every non-terminal state, the action values
    `p * (r + gamma * v[successor])` from the previous sweep's state values
    and sets the state's value to the maximum action value. Sweeping stops
    once the largest absolute change of any state value in a sweep drops
    below `convergence_threshold`, or after `maximum_num_value_estimations`
    sweeps, whichever comes first.

    Args:
        num_non_terminal_states: Python int, number of states to update. They
            occupy the leading entries of `v_tensor`.
        sp_idx_tensor: Integer tensor of successor state indices, indexed as
            [state, action, successor]. The successor axis must have size 1
            because each per-state slice is squeezed on that axis.
        p: Transition probabilities, same layout as `sp_idx_tensor`.
        r_tensor: Rewards, same layout as `sp_idx_tensor`.
        convergence_threshold: Sweeping continues while delta >= this value.
        gamma: Discount factor applied to successor state values.
        maximum_num_value_estimations: Upper bound on the number of sweeps.
        v_tensor: Initial float64 state values for all states.
        q_tensor: Initial float64 action values for the non-terminal states.

    Returns:
        Tuple `(v_tensor, q_tensor)` holding the values after the last sweep.
    """
    # The parameter keeps its original name `p` for callers; bind it to the
    # name used below so the argument is used instead of whatever `p_tensor`
    # happens to exist at module scope.
    p_tensor = p

    # Start above any threshold so the first sweep always runs.
    delta = np.finfo(np.float64).max
    num_value_estimations = 0

    def while_loop_condition(
        delta,
        num_value_estimations,
        v_tensor,
        q_tensor):
        # Keep sweeping while not yet converged and under the sweep budget.
        return tf.logical_and(
            x=tf.greater_equal(x=delta, y=convergence_threshold),
            y=tf.less(
                x=num_value_estimations,
                y=maximum_num_value_estimations))

    def while_loop_body(
        delta,
        num_value_estimations,
        v_tensor,
        q_tensor):
        def value_non_terminal_state_for_loop(state_index):
            # Gather state action successor state slices
            sp_idx_tensor_slice = tf.gather(
                params=sp_idx_tensor,
                indices=state_index)
            p_tensor_slice = tf.gather(
                params=p_tensor,
                indices=state_index)
            r_tensor_slice = tf.gather(
                params=r_tensor,
                indices=state_index)

            # Update state-action value function based on successor states,
            # transition probabilities, and rewards. A self-transition needs
            # no special case: gathering v at the state's own index already
            # yields that state's current value.
            q_tensor_updated = tf.squeeze(
                input=p_tensor_slice * (
                    r_tensor_slice + gamma * tf.gather(
                        params=v_tensor, indices=sp_idx_tensor_slice)),
                axis=1)

            # Update state-value function
            v_tensor_updated = tf.reduce_max(
                input_tensor=q_tensor_updated)

            return (v_tensor_updated,
                    q_tensor_updated)

        # Replace non-terminal state for loop with map function. Every state
        # is backed up from the same previous-sweep v_tensor.
        (v_non_terminal_updated,
         q_tensor_updated) = tf.map_fn(
            fn=value_non_terminal_state_for_loop,
            elems=tf.range(num_non_terminal_states),
            dtype=(tf.float64, tf.float64))

        # Largest absolute change of any state value in this sweep. This is
        # the loop-carried convergence measure tested by the loop condition.
        delta_updated = tf.reduce_max(
            input_tensor=tf.abs(
                x=v_tensor[:num_non_terminal_states] - v_non_terminal_updated))

        # Concat terminal state values back onto state value function.
        # NOTE: num_terminal_states is read from module scope.
        v_tensor_updated = tf.concat(
            values=[v_non_terminal_updated,
                    tf.constant(
                        value=0.0,
                        shape=[num_terminal_states],
                        dtype=tf.float64)],
            axis=0)

        return (delta_updated,
                num_value_estimations + 1,
                v_tensor_updated,
                q_tensor_updated)

    (delta,
     num_value_estimations,
     v_tensor,
     q_tensor) = tf.while_loop(
        cond=while_loop_condition,
        body=while_loop_body,
        loop_vars=[delta,
                   num_value_estimations,
                   v_tensor,
                   q_tensor])

    return v_tensor, q_tensor

In [181]:
# This function greedily selects the policy based on the current value function
def greedy_policy_selection(
    sp_idx_tensor,
    p_tensor,
    r_tensor,
    policy_tensor,
    gamma,
    v_tensor):
    """Build the greedy policy implied by the state values in `v_tensor`.

    For every state along the first axis of `sp_idx_tensor`, the action
    values `p * (r + gamma * v[successor])` are computed and probability is
    split equally among all actions that attain the maximum value; every
    other action gets probability 0.

    The number of states and actions is taken from the input tensors, so the
    function does not depend on module-level size constants.

    Args:
        sp_idx_tensor: Integer tensor of successor state indices, indexed as
            [state, action, successor]. The successor axis must have size 1
            because each per-state slice is squeezed on that axis.
        p_tensor: Transition probabilities, same layout as `sp_idx_tensor`.
        r_tensor: Rewards, same layout as `sp_idx_tensor`.
        policy_tensor: Accepted for interface compatibility; its values are
            not read, the policy is rebuilt from scratch.
        gamma: Discount factor applied to successor state values.
        v_tensor: State values for all states.

    Returns:
        float64 tensor with one row of action probabilities per state.
    """
    def policy_non_terminal_state_for_loop(state_index):
        # Gather state action successor state slices
        sp_idx_tensor_slice = tf.gather(
            params=sp_idx_tensor,
            indices=state_index)
        p_tensor_slice = tf.gather(
            params=p_tensor,
            indices=state_index)
        r_tensor_slice = tf.gather(
            params=r_tensor,
            indices=state_index)

        # One-step lookahead value of every action from this state
        action_values = tf.squeeze(
            input=p_tensor_slice * (r_tensor_slice + gamma * tf.gather(
                params=v_tensor, indices=sp_idx_tensor_slice)),
            axis=1)

        # Mark the actions that attain the max value and count the ties
        max_action_value = tf.reduce_max(input_tensor=action_values)
        is_greedy_action = tf.equal(x=action_values, y=max_action_value)
        num_greedy_actions = tf.count_nonzero(input_tensor=is_greedy_action)

        # Apportion policy probability across ties equally: 1/count for the
        # greedy actions and 0/count == 0 for the rest.
        return (tf.cast(is_greedy_action, dtype=tf.float64)
                / tf.cast(num_greedy_actions, dtype=tf.float64))

    # Replace non-terminal state for loop with map function; one iteration
    # per state along the first axis of the successor table.
    policy_tensor = tf.map_fn(
        fn=policy_non_terminal_state_for_loop,
        elems=tf.range(tf.shape(sp_idx_tensor)[0]),
        dtype=tf.float64)

    return policy_tensor

In [182]:
def value_iteration(
    num_non_terminal_states,
    sp_idx_tensor,
    p_tensor,
    r_tensor,
    policy_tensor,
    convergence_threshold,
    gamma,
    maximum_num_value_iterations,
    v_tensor,
    q_tensor):
    """Estimate the value functions, then derive the greedy policy from them.

    Args:
        num_non_terminal_states: Number of states whose values are updated.
        sp_idx_tensor: Successor state indices per (state, action).
        p_tensor: Transition probabilities, same layout as `sp_idx_tensor`.
        r_tensor: Rewards, same layout as `sp_idx_tensor`.
        policy_tensor: Current policy, forwarded to `greedy_policy_selection`.
        convergence_threshold: Convergence threshold forwarded to
            `value_estimation`.
        gamma: Discount factor.
        maximum_num_value_iterations: Sweep cap forwarded to
            `value_estimation`.
        v_tensor: Initial state values.
        q_tensor: Initial state-action values.

    Returns:
        Tuple `(v_tensor, q_tensor, policy_tensor)`.
    """
    # Value estimation. The rewards passed on are this function's own
    # r_tensor argument, not a module-level array, so the result depends only
    # on what the caller supplies.
    v_tensor, q_tensor = value_estimation(
        num_non_terminal_states,
        sp_idx_tensor,
        p_tensor,
        r_tensor,
        convergence_threshold,
        gamma,
        maximum_num_value_iterations,
        v_tensor,
        q_tensor)

    # Greedy policy selection
    policy_tensor = greedy_policy_selection(
        sp_idx_tensor,
        p_tensor,
        r_tensor,
        policy_tensor,
        gamma,
        v_tensor)

    return (v_tensor,
            q_tensor,
            policy_tensor)

# Run algorithm

In [183]:
# Build the value-iteration graph, feed it the NumPy environment arrays, and
# print the resulting value functions and policy.
with tf.Session() as sess:
    # Read in environment
    # Each placeholder takes its shape from the array that is fed into it
    # below, so the declared shape always matches the feed and the cell needs
    # no separately defined successor-state count.
    sp_idx_tensor = tf.placeholder(
        dtype=tf.int64,
        shape=sp_idx.shape)
    p_tensor = tf.placeholder(
        dtype=tf.float64,
        shape=p.shape)
    r_tensor = tf.placeholder(
        dtype=tf.float64,
        shape=r.shape)

    # Create value functions (all zeros to start)
    v_tensor = tf.zeros(
        shape=num_states, dtype=tf.float64)
    q_tensor = tf.zeros(
        shape=[num_non_terminal_states, max_num_actions],
        dtype = tf.float64)

    # Create policy: uniform probability over the actions of every
    # non-terminal state
    policy_tensor = tf.tile(
        input=[tf.constant(
            value = 1.0 / max_num_actions, dtype = tf.float64)],
        multiples=[num_non_terminal_states * max_num_actions])
    policy_tensor = tf.reshape(
        tensor=policy_tensor,
        shape=[num_non_terminal_states, max_num_actions])

    # Create algorithm
    algorithm = value_iteration(
        num_non_terminal_states,
        sp_idx_tensor,
        p_tensor,
        r_tensor,
        policy_tensor,
        convergence_threshold,
        gamma,
        maximum_num_value_estimations,
        v_tensor,
        q_tensor)

    # Run graph
    (v,
     q,
     policy) = sess.run(
        fetches=algorithm,
        feed_dict={
            sp_idx_tensor: sp_idx, 
            p_tensor: p, 
            r_tensor: r
        }
    )

print("\nFinal state value function")
print(v)
print("\nFinal state-action value function")
print(q)
print("\nFinal policy")
print(policy)


Final state value function
[-1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.  0.]

Final state-action value function
[[-3. -2. -1. -3.]
 [-4. -3. -2. -4.]
 [-4. -4. -3. -3.]
 [-3. -1. -2. -3.]
 [-4. -2. -2. -4.]
 [-3. -3. -3. -3.]
 [-3. -4. -4. -2.]
 [-4. -2. -3. -4.]
 [-3. -3. -3. -3.]
 [-2. -4. -4. -2.]
 [-2. -3. -3. -1.]
 [-3. -3. -4. -4.]
 [-2. -4. -4. -3.]
 [-1. -3. -3. -2.]]

Final policy
[[0.   0.   1.   0.  ]
 [0.   0.   1.   0.  ]
 [0.   0.   0.5  0.5 ]
 [0.   1.   0.   0.  ]
 [0.   0.5  0.5  0.  ]
 [0.25 0.25 0.25 0.25]
 [0.   0.   0.   1.  ]
 [0.   1.   0.   0.  ]
 [0.25 0.25 0.25 0.25]
 [0.5  0.   0.   0.5 ]
 [0.   0.   0.   1.  ]
 [0.5  0.5  0.   0.  ]
 [1.   0.   0.   0.  ]
 [1.   0.   0.   0.  ]]
