In [2]:
from IPython.core.debugger import set_trace
import numpy as np
import pprint
import sys
if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.gridworld import GridworldEnv

In [None]:
# Gridworld environment instance used by the cells below.
env = GridworldEnv()
# Pretty-printer with a two-space indent for displaying nested results.
pp = pprint.PrettyPrinter(indent=2)

In [3]:
def policy_evaluation(policy, env, discount_factor=1.0, theta=1e-5):
    """
    Evaluate a policy: compute the state-value function obtained by following it.

    The value function is refined by repeated in-place sweeps over all states
    until no state's value changes by `theta` or more within a single sweep.

    Arguments:
        policy: [S, A] shaped matrix representing the policy. Entry [s, a] is the
            probability of taking action a in state s.
        env: The environment to evaluate the policy in.
            env.P[s][a] -> is a list of transition tuples (probability, next_state, reward, done)
            env.nS -> number of states in the environment
            env.nA -> number of actions in the environment
        discount_factor: Gamma discount factor. It determines how much importance we give to future rewards.
        theta: We stop evaluation once the value function change is less than theta for every state.
            Must be strictly positive.

    Returns:
        V: A vector of length env.nS representing the value function for each state.

    Raises:
        ValueError: If theta is not strictly positive. The per-sweep change is
            never negative, so the stopping test could never succeed.
    """
    # `not theta > 0` (rather than `theta <= 0`) also rejects NaN.
    if not theta > 0:
        raise ValueError("theta must be strictly positive, got {!r}".format(theta))

    # Start from an all-zero value function.
    V = np.zeros(env.nS)
    while True:
        # Largest absolute change of any state's value during this sweep.
        delta = 0
        # For each state, perform a "full backup".
        for state in range(env.nS):
            v = 0
            for action, action_prob in enumerate(policy[state]):
                # The `done` flag of each transition is not needed for the backup.
                for prob, next_state, reward, _ in env.P[state][action]:
                    # Bellman expectation update.
                    v += action_prob * prob * (reward + discount_factor * V[next_state])
            delta = max(delta, np.abs(v - V[state]))
            # In-place update: later states in this sweep already see the new value.
            V[state] = v
        if delta < theta:
            break
    return V
         

In [None]:
# Uniform random policy: every action is equally likely in every state.
uniform_prob = 1.0 / env.nA
random_policy = np.full((env.nS, env.nA), uniform_prob)
# Value function of the uniform random policy on the gridworld.
V = policy_evaluation(random_policy, env)