In [327]:
import numpy as np
import tensorflow as tf

# Create environment

In [328]:
# Total number of states; the state value function holds one entry per state
number_of_states = 16
# Terminal states are never evaluated or improved; their values are kept at zero
number_of_terminal_states = 2
# Only non-terminal states get rows in the environment tables and in the policy
number_of_non_terminal_states = number_of_states - number_of_terminal_states

In [329]:
# Size of the action axis of the environment tables, the state-action value function and the policy
max_number_of_actions = 4

In [330]:
# Every non-terminal state offers the full set of actions
number_of_actions_per_non_terminal_state = np.full(shape = number_of_non_terminal_states, fill_value = max_number_of_actions)

In [331]:
# Flat vector with one entry per (state, action) pair: each pair has exactly one successor state
number_of_state_action_successor_states = np.full(shape = number_of_states * max_number_of_actions, fill_value = 1)

In [332]:
# Arrange the successor counts as a [state, action] table
number_of_state_action_successor_states = number_of_state_action_successor_states.reshape((number_of_states, max_number_of_actions))

In [333]:
# Largest successor count of any (state, action) pair; sizes the successor axis of the placeholders
max_number_of_state_action_successor_states = number_of_state_action_successor_states.max()

In [334]:
# Successor state index for every (non-terminal state, action) pair.
# Written as one row per state and one column per action, then flattened
# to the 1-D layout that the later reshape cell expects.
state_action_successor_state_indices = np.array([
    [ 1,  0, 14,  4],  # state 0
    [ 2,  1,  0,  5],  # state 1
    [ 2,  2,  1,  6],  # state 2
    [ 4, 14,  3,  7],  # state 3
    [ 5,  0,  3,  8],  # state 4
    [ 6,  1,  4,  9],  # state 5
    [ 6,  2,  5, 10],  # state 6
    [ 8,  3,  7, 11],  # state 7
    [ 9,  4,  7, 12],  # state 8
    [10,  5,  8, 13],  # state 9
    [10,  6,  9, 15],  # state 10
    [12,  7, 11, 11],  # state 11
    [13,  8, 11, 12],  # state 12
    [15,  9, 12, 13],  # state 13
], dtype = np.int64).flatten()

In [335]:
# Flat vector with one transition probability per (non-terminal state, action, successor state) entry.
# Every transition is deterministic, so each probability is 1.0. The successor count comes from
# max_number_of_state_action_successor_states instead of a literal so that it matches the
# successor axis used for the placeholders.
state_action_successor_state_transition_probabilities = np.full(shape = number_of_non_terminal_states * max_number_of_actions * max_number_of_state_action_successor_states, fill_value = 1.0)

In [336]:
# Flat vector with one reward per (non-terminal state, action, successor state) entry.
# Every transition yields a reward of -1.0. The successor count comes from
# max_number_of_state_action_successor_states instead of a literal so that it matches the
# successor axis used for the placeholders.
state_action_successor_state_rewards = np.full(shape = number_of_non_terminal_states * max_number_of_actions * max_number_of_state_action_successor_states, fill_value = -1.0)

In [337]:
# Arrange the three flat environment vectors as [non-terminal state, action, successor state] tables.
# The successor axis uses max_number_of_state_action_successor_states rather than a literal 1 so the
# arrays always agree with the placeholder shapes they are fed into.
state_action_successor_state_shape = (number_of_non_terminal_states, max_number_of_actions, max_number_of_state_action_successor_states)

state_action_successor_state_indices = state_action_successor_state_indices.reshape(state_action_successor_state_shape)
state_action_successor_state_transition_probabilities = state_action_successor_state_transition_probabilities.reshape(state_action_successor_state_shape)
state_action_successor_state_rewards = state_action_successor_state_rewards.reshape(state_action_successor_state_shape)

# Set hyperparameters

In [338]:
# Factor applied to the value of the successor state when backing up values
discounting_factor_gamma = 1.0
# Policy evaluation keeps iterating while its delta is greater than or equal to this threshold
convergence_threshold = 0.001
# Upper bound on the number of evaluation + improvement sweeps in policy iteration
maximum_number_of_sweeps = 30
# Upper bound on the number of evaluation passes within a single policy evaluation
maximum_number_of_policy_evaluations = 20

# Create algorithm

In [339]:
# This function evaluates the value functions given the current policy
def policy_evaluation(number_of_non_terminal_states, state_action_successor_state_indices_tensor, state_action_successor_state_transition_probabilities_tensor_tensor, state_action_successor_state_rewards_tensor, policy_tensor, convergence_threshold, discounting_factor_gamma, maximum_number_of_policy_evaluations, state_value_function_tensor, state_action_value_function_tensor):
    """Iteratively evaluate the value functions of `policy_tensor`.

    Each evaluation pass recomputes every non-terminal state from the state
    values of the previous pass (the per-state work is mapped with
    `tf.map_fn`). Passes repeat until the largest absolute change of any
    state value drops below `convergence_threshold`, or until
    `maximum_number_of_policy_evaluations` passes have run.

    Args:
        number_of_non_terminal_states: Number of states to evaluate; states
            `0 .. number_of_non_terminal_states - 1` are updated.
        state_action_successor_state_indices_tensor: Successor state indices,
            indexed as [state, action, successor state].
        state_action_successor_state_transition_probabilities_tensor_tensor:
            Transition probabilities with the same layout as the indices.
            (The doubled suffix is kept so the signature stays unchanged.)
        state_action_successor_state_rewards_tensor: Rewards with the same
            layout as the indices.
        policy_tensor: Action weights of the policy, indexed as [state, action].
        convergence_threshold: Evaluation stops once delta is below this value.
        discounting_factor_gamma: Factor applied to successor state values.
        maximum_number_of_policy_evaluations: Upper bound on evaluation passes.
        state_value_function_tensor: Initial state values. Entries from index
            `number_of_non_terminal_states` onward are treated as terminal and
            are set to zero by every pass.
        state_action_value_function_tensor: Initial state-action values; fully
            replaced by every pass.

    Returns:
        Tuple `(state_value_function_tensor, state_action_value_function_tensor)`
        holding the evaluated value functions.
    """
    # Use the argument that was passed in rather than a same-named notebook global
    state_action_successor_state_transition_probabilities_tensor = state_action_successor_state_transition_probabilities_tensor_tensor

    # Start above any threshold so that the first evaluation pass always runs
    delta = np.finfo(np.float64).max
    number_of_policy_evaluations = 0

    def while_loop_condition(delta, number_of_policy_evaluations, state_value_function_tensor, state_action_value_function_tensor):
        return tf.logical_and(x = tf.greater_equal(x = delta, y = convergence_threshold), y = tf.less(x = number_of_policy_evaluations, y = maximum_number_of_policy_evaluations))

    def while_loop_body(delta, number_of_policy_evaluations, state_value_function_tensor, state_action_value_function_tensor):
        def value_non_terminal_state_for_loop(state_index):
            # Gather state action successor state slices
            successor_state_indices = tf.gather(params = state_action_successor_state_indices_tensor, indices = state_index)
            transition_probabilities = tf.gather(params = state_action_successor_state_transition_probabilities_tensor, indices = state_index)
            rewards = tf.gather(params = state_action_successor_state_rewards_tensor, indices = state_index)

            # Expected one-step return of every action: sum over successor states of p * (r + gamma * V(s')).
            # Self-transitions need no special case because all values come from the previous pass.
            successor_state_values = tf.gather(params = state_value_function_tensor, indices = successor_state_indices)
            state_action_values = tf.reduce_sum(input_tensor = transition_probabilities * (rewards + discounting_factor_gamma * successor_state_values), axis = 1)

            # State value is the policy-weighted sum of the action values
            state_value = tf.reduce_sum(input_tensor = tf.gather(params = policy_tensor, indices = state_index) * state_action_values)

            return state_value, state_action_values

        # Replace non-terminal state for loop with map function
        non_terminal_state_values_updated, state_action_value_function_tensor_updated = tf.map_fn(
            fn = value_non_terminal_state_for_loop,
            elems = tf.range(number_of_non_terminal_states),
            dtype = (tf.float64, tf.float64))

        # Largest absolute change of any state value in this pass; drives the convergence test
        non_terminal_state_values_previous = state_value_function_tensor[:number_of_non_terminal_states]
        delta_updated = tf.reduce_max(input_tensor = tf.abs(x = non_terminal_state_values_previous - non_terminal_state_values_updated))

        # Concat zero terminal state values back onto state value function
        terminal_state_values = tf.zeros_like(state_value_function_tensor[number_of_non_terminal_states:])
        state_value_function_tensor_updated = tf.concat(values = [non_terminal_state_values_updated, terminal_state_values], axis = 0)

        number_of_policy_evaluations += 1

        return delta_updated, number_of_policy_evaluations, state_value_function_tensor_updated, state_action_value_function_tensor_updated

    delta, number_of_policy_evaluations, state_value_function_tensor, state_action_value_function_tensor = tf.while_loop(cond = while_loop_condition, body = while_loop_body, loop_vars = [delta, number_of_policy_evaluations, state_value_function_tensor, state_action_value_function_tensor])

    return state_value_function_tensor, state_action_value_function_tensor

In [340]:
# This function greedily updates the policy based on the current value function
def policy_improvement(number_of_non_terminal_states, state_action_successor_state_indices_tensor, state_action_successor_state_transition_probabilities_tensor, state_action_successor_state_rewards_tensor, policy_tensor, discounting_factor_gamma, state_value_function_tensor):
    """Make the policy greedy with respect to `state_value_function_tensor`.

    For every non-terminal state the one-step action values are computed from
    the environment tables; the new policy puts equal probability on all
    actions that attain the maximum action value and zero on the others.

    Args:
        number_of_non_terminal_states: Number of states to improve; states
            `0 .. number_of_non_terminal_states - 1` are processed.
        state_action_successor_state_indices_tensor: Successor state indices,
            indexed as [state, action, successor state].
        state_action_successor_state_transition_probabilities_tensor:
            Transition probabilities with the same layout as the indices.
        state_action_successor_state_rewards_tensor: Rewards with the same
            layout as the indices.
        policy_tensor: Current policy, indexed as [state, action]; only used
            to detect whether the policy changed.
        discounting_factor_gamma: Factor applied to successor state values.
        state_value_function_tensor: State values the policy is made greedy
            against.

    Returns:
        Tuple `(policy_stable, policy_tensor)`: a scalar bool tensor that is
        True when no state's policy changed, and the improved policy.
    """
    def policy_non_terminal_state_for_loop(state_index):
        # Gather state action successor state slices
        successor_state_indices = tf.gather(params = state_action_successor_state_indices_tensor, indices = state_index)
        transition_probabilities = tf.gather(params = state_action_successor_state_transition_probabilities_tensor, indices = state_index)
        rewards = tf.gather(params = state_action_successor_state_rewards_tensor, indices = state_index)

        # Cache policy for comparison later
        old_policy = tf.gather(params = policy_tensor, indices = state_index)

        # Expected one-step return of every action: sum over successor states of p * (r + gamma * V(s'))
        successor_state_values = tf.gather(params = state_value_function_tensor, indices = successor_state_indices)
        action_values = tf.reduce_sum(input_tensor = transition_probabilities * (rewards + discounting_factor_gamma * successor_state_values), axis = 1)

        # Mark every action that attains the maximum action value (ties included)
        max_action_value = tf.reduce_max(input_tensor = action_values)
        greedy_action_indicator = tf.cast(tf.equal(x = action_values, y = max_action_value), dtype = tf.float64)

        # Apportion policy probability equally across the tied greedy actions and zero otherwise.
        # The vector is sized from the action values, so no global action count is needed.
        policy_updated = greedy_action_indicator / tf.reduce_sum(input_tensor = greedy_action_indicator)

        # Stable for this state when the new policy equals the old policy
        policy_stable_for_state = tf.reduce_all(input_tensor = tf.equal(x = policy_updated, y = old_policy))

        return policy_stable_for_state, policy_updated

    # Replace non-terminal state for loop with map function
    policy_stable_per_state, policy_tensor_updated = tf.map_fn(
        fn = policy_non_terminal_state_for_loop,
        elems = tf.range(number_of_non_terminal_states),
        dtype = (tf.bool, tf.float64))

    # Reduce policy stable back to a scalar across all non-terminal states
    policy_stable = tf.reduce_all(input_tensor = policy_stable_per_state)

    return policy_stable, policy_tensor_updated

In [341]:
def policy_iteration(number_of_non_terminal_states, state_action_successor_state_indices_tensor, state_action_successor_state_transition_probabilities_tensor, state_action_successor_state_rewards_tensor, policy_tensor, convergence_threshold, discounting_factor_gamma, maximum_number_of_policy_evaluations, state_value_function_tensor, state_action_value_function_tensor, maximum_number_of_sweeps):
    """Alternate policy evaluation and policy improvement inside a `tf.while_loop`.

    A sweep evaluates the current policy and then improves it greedily.
    Sweeps repeat until policy improvement reports a stable policy or
    `maximum_number_of_sweeps` sweeps have run.

    Returns:
        Tuple `(state_value_function_tensor, state_action_value_function_tensor,
        policy_tensor)` as left by the last sweep.
    """
    def keep_sweeping(policy_stable, number_of_sweeps, state_value_function_tensor, state_action_value_function_tensor, policy_tensor):
        # Continue only while the policy is still changing and the sweep budget is not used up
        policy_still_changing = tf.logical_not(policy_stable)
        sweeps_remaining = tf.less(x = number_of_sweeps, y = maximum_number_of_sweeps)
        return tf.logical_and(x = policy_still_changing, y = sweeps_remaining)

    def run_one_sweep(policy_stable, number_of_sweeps, state_value_function_tensor, state_action_value_function_tensor, policy_tensor):
        # Policy evaluation
        evaluated_state_values, evaluated_state_action_values = policy_evaluation(
            number_of_non_terminal_states,
            state_action_successor_state_indices_tensor,
            state_action_successor_state_transition_probabilities_tensor,
            state_action_successor_state_rewards_tensor,
            policy_tensor,
            convergence_threshold,
            discounting_factor_gamma,
            maximum_number_of_policy_evaluations,
            state_value_function_tensor,
            state_action_value_function_tensor)

        # Policy improvement
        improved_policy_stable, improved_policy = policy_improvement(
            number_of_non_terminal_states,
            state_action_successor_state_indices_tensor,
            state_action_successor_state_transition_probabilities_tensor,
            state_action_successor_state_rewards_tensor,
            policy_tensor,
            discounting_factor_gamma,
            evaluated_state_values)

        return improved_policy_stable, number_of_sweeps + 1, evaluated_state_values, evaluated_state_action_values, improved_policy

    # The policy counts as unstable until the first improvement step says otherwise
    initial_loop_vars = [tf.constant(value = False, dtype = tf.bool), 0, state_value_function_tensor, state_action_value_function_tensor, policy_tensor]

    _, _, final_state_values, final_state_action_values, final_policy = tf.while_loop(cond = keep_sweeping, body = run_one_sweep, loop_vars = initial_loop_vars)

    return final_state_values, final_state_action_values, final_policy

# Run algorithm

In [342]:
# Read in environment: the three tables are fed at run time and share one [state, action, successor state] shape
placeholder_shape = [number_of_non_terminal_states, max_number_of_actions, max_number_of_state_action_successor_states]
state_action_successor_state_indices_tensor = tf.placeholder(dtype = tf.int64, shape = placeholder_shape)
state_action_successor_state_transition_probabilities_tensor = tf.placeholder(dtype = tf.float64, shape = placeholder_shape)
state_action_successor_state_rewards_tensor = tf.placeholder(dtype = tf.float64, shape = placeholder_shape)

# Create value functions, all starting at zero
state_value_function_tensor = tf.zeros(shape = [number_of_states], dtype = tf.float64)
state_action_value_function_tensor = tf.zeros(shape = [number_of_non_terminal_states, max_number_of_actions], dtype = tf.float64)

# Create policy: every action of every non-terminal state starts with equal probability
policy_tensor = tf.constant(value = 1.0 / max_number_of_actions, shape = [number_of_non_terminal_states, max_number_of_actions], dtype = tf.float64)

# Create algorithm
algorithm = policy_iteration(
    number_of_non_terminal_states,
    state_action_successor_state_indices_tensor,
    state_action_successor_state_transition_probabilities_tensor,
    state_action_successor_state_rewards_tensor,
    policy_tensor,
    convergence_threshold,
    discounting_factor_gamma,
    maximum_number_of_policy_evaluations,
    state_value_function_tensor,
    state_action_value_function_tensor,
    maximum_number_of_sweeps)

# Bind the NumPy environment tables to their placeholders
environment_feed = {
    state_action_successor_state_indices_tensor: state_action_successor_state_indices,
    state_action_successor_state_transition_probabilities_tensor: state_action_successor_state_transition_probabilities,
    state_action_successor_state_rewards_tensor: state_action_successor_state_rewards,
}

# Run graph
with tf.Session() as sess:
    state_value_function, state_action_value_function, policy = sess.run(fetches = algorithm, feed_dict = environment_feed)

for heading, result in (
        ("\nFinal state value function", state_value_function),
        ("\nFinal state-action value function", state_action_value_function),
        ("\nFinal policy", policy)):
    print(heading)
    print(result)


Final state value function
[-1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.  0.]

Final state-action value function
[[-3. -2. -1. -3.]
 [-4. -3. -2. -4.]
 [-4. -4. -3. -3.]
 [-3. -1. -2. -3.]
 [-4. -2. -2. -4.]
 [-3. -3. -3. -3.]
 [-3. -4. -4. -2.]
 [-4. -2. -3. -4.]
 [-3. -3. -3. -3.]
 [-2. -4. -4. -2.]
 [-2. -3. -3. -1.]
 [-3. -3. -4. -4.]
 [-2. -4. -4. -3.]
 [-1. -3. -3. -2.]]

Final policy
[[0.   0.   1.   0.  ]
 [0.   0.   1.   0.  ]
 [0.   0.   0.5  0.5 ]
 [0.   1.   0.   0.  ]
 [0.   0.5  0.5  0.  ]
 [0.25 0.25 0.25 0.25]
 [0.   0.   0.   1.  ]
 [0.   1.   0.   0.  ]
 [0.25 0.25 0.25 0.25]
 [0.5  0.   0.   0.5 ]
 [0.   0.   0.   1.  ]
 [0.5  0.5  0.   0.  ]
 [1.   0.   0.   0.  ]
 [1.   0.   0.   0.  ]]
