In [None]:
from IPython.core.debugger import set_trace
import numpy as np
import pprint
import sys
if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.gridworld import GridworldEnv

In [None]:
# Shared Gridworld environment used by the cells below, plus a pretty-printer (indent=2).
env = GridworldEnv()
pp = pprint.PrettyPrinter(indent=2)

In [None]:
def policy_evaluation(policy, env, discount_factor=1.0, theta=1e-5):
    """
    Evaluate a policy, i.e. compute the state-value function of the given policy.

    Repeatedly sweeps over all states, applying the Bellman expectation backup to
    each one, until no state's value changes by theta or more during a sweep.

    Arguments:
        policy: [S, A] shaped matrix representing the policy. Entry [s, a] is the
            probability of taking action a in state s.
        env: The environment to evaluate the policy in.
            env.P[s][a] -> is a list of transition tuples (probability, next_state, reward, done)
            env.nS -> number of states in the environment
            env.nA -> number of actions in the environment
        discount_factor: Gamma discount factor. It determines how much importance we give to future rewards.
        theta: Evaluation stops once the value function changes by less than theta
            for every state during a full sweep.

    Returns:
        V: A vector of length env.nS representing the value function for each state.
    """
    # Start from an all-zero value function.
    V = np.zeros(env.nS)
    while True:
        # Largest absolute change of any state's value during this sweep.
        delta = 0.0
        for s in range(env.nS):
            # Full backup: expected return over the policy's action probabilities
            # and the environment's transition probabilities, bootstrapping from V.
            v = 0.0
            for a, action_prob in enumerate(policy[s]):
                # `done` is not used: terminal states are assumed to be absorbing with
                # zero reward so their V stays 0. NOTE(review): confirm for envs other
                # than GridworldEnv.
                for prob, next_state, reward, _done in env.P[s][a]:
                    v += action_prob * prob * (reward + discount_factor * V[next_state])
            delta = max(delta, abs(v - V[s]))
            # In-place update: states later in this sweep already see the new value.
            V[s] = v
        if delta < theta:
            return V
         

In [None]:
def policy_iteration(env, policy_eval_func=None, discount_factor=1.0):
    """
    Policy Iteration Algorithm. Iteratively evaluates and greedily improves a policy
    until it stops changing, at which point it is optimal.

    Arguments:
        env: the environment to evaluate the policy in.
            env.P[s][a] -> is a list of transition tuples (probability, next_state, reward, done)
            env.nS -> number of states in the environment
            env.nA -> number of actions in the environment
        policy_eval_func: Function which evaluates a policy and returns the value function.
            Called as policy_eval_func(policy, env, discount_factor). When None, the
            module-level `policy_evaluation` is used (looked up at call time).
        discount_factor: Gamma discount factor. Defaults to 1.0 (undiscounted).

    Returns:
        Tuple (policy, V).
        policy is the final deterministic policy, an [S, A] matrix with one-hot rows.
        V is the value function of that policy, as returned by policy_eval_func.
    """
    if policy_eval_func is None:
        # Resolved lazily so the default does not need to exist at definition time.
        policy_eval_func = policy_evaluation

    def one_step_lookahead(state, V):
        """
        Helper function to calculate the value of every action in a state, using the given value function.
        Args:
            state: the state to consider.
            V: The value function to use as an estimator, vector of length env.nS
        Returns:
            A vector of length env.nA containing the expected value of each action from the given state
        """
        A = np.zeros(env.nA)
        for a in range(env.nA):
            for prob, next_state, reward, done in env.P[state][a]:
                A[a] += prob * (reward + discount_factor * V[next_state])  # Bellman backup
        return A

    # Start with a uniform random policy.
    policy = np.ones([env.nS, env.nA]) / env.nA

    while True:
        # Evaluate the current policy.
        V = policy_eval_func(policy, env, discount_factor)

        policy_stable = True
        for s in range(env.nS):
            # Action the current policy prefers (index 0 for a uniform row).
            chosen_a = np.argmax(policy[s])

            # Find the best action using a one-step lookahead on V.
            action_values = one_step_lookahead(s, V)
            best_a = np.argmax(action_values)

            # Break ties in favour of the current action: if it is already
            # (numerically) as good as the best one, keep it. Otherwise equally good
            # actions can flip-flop between iterations and the loop never terminates.
            if np.isclose(action_values[chosen_a], action_values[best_a]):
                best_a = chosen_a

            # Greedy, deterministic improvement: all probability mass on best_a.
            new_row = np.eye(env.nA)[best_a]
            # Compare whole rows, not just argmax, so replacing a stochastic row (e.g.
            # the initial uniform one) counts as a change; this guarantees the
            # returned V was computed for the returned policy.
            if not np.array_equal(policy[s], new_row):
                policy_stable = False
            policy[s] = new_row

        # If no state changed, the policy is stable and we've found the optimal policy.
        if policy_stable:
            return policy, V


In [None]:
# Run policy iteration with the policy_evaluation routine defined above and no
# discounting (gamma = 1.0). Both arguments are passed explicitly so the call does
# not depend on policy_iteration having defaults for them.
policy, v = policy_iteration(env, policy_evaluation, 1.0)
print("Policy Probability Distribution:")
print(policy)
print("")

print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
# Greedy action index per state, laid out on the grid.
print(np.reshape(np.argmax(policy, axis=1), env.shape))
print("")

print("Value Function:")
print(v)
print("")

print("Reshaped Grid Value Function:")
print(v.reshape(env.shape))
print("")