The problem that this code answers is described in Sutton and Barto, p. 92 (Example 4.1). The goal is to compute the state-value function $V^{\pi}$ of the equiprobable random policy on a 4 by 4 grid state space. This is done by iterative policy evaluation, which is the evaluation step of policy iteration. Positions (1,1) and (4,4) are absorbing (terminal) states; they are `(0,0)` and `(3,3)` in the zero-based code. Policy evaluation solves the Bellman equation for $V^{\pi}$:

$$V^{\pi}(s) = \sum_a \pi(s,a) \sum_{s'} P_{ss'}^a \left[R_{ss'}^a + \gamma V^{\pi}(s')\right]$$

It is important to note a drawback: when policy evaluation is done iteratively, exact convergence to $V^{\pi}$ occurs only in the limit. The question that then presents itself is: do we need to wait that long? This is where value iteration comes in. It truncates policy evaluation to a single sweep and combines it with policy improvement, so each update takes the maximum over actions instead of the $\pi$-weighted sum. (ASIDE: the equations behind value iteration are harder to solve directly. The max over actions makes them a system of nonlinear equations in multiple unknowns. Evaluating a fixed policy, by contrast, means solving a system of LINEAR equations in multiple unknowns. This can all be a moot point, since the equations are solved with iterative sweeps and not with matrix algebra.)
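The aside notes that evaluating a fixed policy amounts to a system of linear equations. As a sketch, and not the referenced MATLAB code, that system can be built and solved directly with `numpy.linalg.solve`. The helper names `step` and `solve_v_pi` are hypothetical. The defaults assume Example 4.1's conventions: reward -1 on every transition, $\gamma = 1$, deterministic moves, and walls that leave the agent in place. Passing `entry_reward=0.0` mimics the notebook's code below, which gives reward 0 for a move into a terminal state.

```python
import numpy as np

# Assumed conventions: 4x4 grid, terminals at (0,0) and (3,3), equiprobable
# random policy, deterministic moves, bumping a wall leaves the agent in place.
SIDE = 4
TERMINALS = {(0, 0), (SIDE - 1, SIDE - 1)}
ACTIONS = [(-1, 0), (1, 0), (0, 1), (0, -1)]  # up, down, right, left
PI = 1.0 / len(ACTIONS)  # pi(s, a) = 0.25 for every action


def step(i, j, di, dj):
    """Hypothetical helper: next cell after a move; bumping a wall stays put."""
    ni, nj = i + di, j + dj
    if 0 <= ni < SIDE and 0 <= nj < SIDE:
        return ni, nj
    return i, j


def solve_v_pi(step_reward=-1.0, entry_reward=-1.0, gamma=1.0):
    """Solve (I - gamma * P_pi) V = r_pi for V^pi on the gridworld."""
    n = SIDE * SIDE
    P = np.zeros((n, n))  # P[s, s'] = sum_a pi(s,a) * P^a_{ss'}
    r = np.zeros(n)       # r[s] = expected one-step reward under pi
    for i in range(SIDE):
        for j in range(SIDE):
            if (i, j) in TERMINALS:
                continue  # all-zero row, so the equation reads V(terminal) = 0
            s = i * SIDE + j
            for di, dj in ACTIONS:
                ni, nj = step(i, j, di, dj)
                P[s, ni * SIDE + nj] += PI
                reward = entry_reward if (ni, nj) in TERMINALS else step_reward
                r[s] += PI * reward
    V = np.linalg.solve(np.eye(n) - gamma * P, r)
    return V.reshape(SIDE, SIDE)


print(solve_v_pi().round(1))                  # book's convention
print(solve_v_pi(entry_reward=0.0).round(1))  # notebook's convention
```

With the defaults this should reproduce the converged grid of the book's Figure 4.1, whose first row is 0, -14, -20, -22. Under the notebook's convention, every return is one unit less negative. An episode of $T$ steps earns $-(T-1)$ instead of $-T$, so each non-terminal value is the book's value plus 1.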

In [65]:
# This cell runs iterative (in-place) policy evaluation for the equiprobable
# random policy. After each sweep it prints delta (the largest value change
# in that sweep), the value matrix rounded to two decimals, and the sweep number.
import numpy as np
gamma = 1

sideL = 4
nGrids = sideL**2

#An array to hold the values of the state-value function
V = np.zeros((sideL,sideL))


#Some parameters for convergence
MAX_N_ITERS = 1000
iterCnt = 0
CONV_TOL = 0.000001
delta = float('inf')  # start above CONV_TOL so the first sweep always runs

#pi(s,a): the equiprobable random policy picks each of the 4 actions with
#probability 0.25 (the moves themselves are deterministic)
pol_pi = 0.25

while(delta > CONV_TOL and iterCnt<=MAX_N_ITERS):
    delta = 0
    for ii in range(0,sideL):
        for jj in range(0,sideL):
            if (ii == 0 and jj == 0) or (ii == sideL-1 and jj == sideL-1 ):
                continue
        
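            v = V[ii,jj]
            v_tmp = 0
            # Each action adds pol_pi*(reward + gamma*V[next state]).
            # Moves into a terminal state get reward 0 here, whereas Sutton and
            # Barto's Example 4.1 uses -1 on every transition, so every
            # non-terminal value converges to one unit above the book's figure.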
    
            #action: UP
            if ii == 0: # state is on the top row, action does not change position
                v_tmp = v_tmp + pol_pi*(-1 + gamma*V[ii,jj])
            elif ii == 1 and jj == 0:
                v_tmp = v_tmp + pol_pi*(0 + gamma*V[ii-1,jj])
            else:
                v_tmp = v_tmp + pol_pi*(-1 + gamma*V[ii-1,jj])
        
            #action: DOWN
            if ii == sideL - 1: #state is on the bottom row, action does not change position
                v_tmp = v_tmp + pol_pi*(-1 + gamma*V[ii,jj])
            elif ii == sideL - 2 and jj == sideL -1:
                v_tmp = v_tmp + pol_pi*(0 + gamma*V[ii+1,jj])
            else:
                v_tmp = v_tmp + pol_pi*(-1 + gamma*V[ii+1,jj])
            
            #action: RIGHT
            if jj == sideL -1: #state is in the rightmost column, action does not change position
                v_tmp = v_tmp + pol_pi*(-1 + gamma*V[ii,jj])
            elif ii == sideL - 1 and jj == sideL - 2:
                v_tmp = v_tmp + pol_pi*(0 + gamma*V[ii,jj+1])
            else:
                v_tmp = v_tmp + pol_pi*(-1 + gamma*V[ii,jj+1])
        
            #action: LEFT
            if jj == 0: #state is in the leftmost column, action does not change position
                v_tmp = v_tmp +pol_pi*(-1 + gamma*V[ii,jj])
            elif jj ==1 and ii == 0:
                v_tmp = v_tmp + pol_pi*(0 + gamma*V[ii,jj-1])
            else:
                v_tmp = v_tmp + pol_pi*(-1 + gamma*V[ii,jj-1])
        
            V[ii,jj] = v_tmp
            delta = max(delta,abs(v-V[ii,jj]))
    print(delta,(V*10).round(1)/10,iterCnt)
    iterCnt+=1


1.8203125 [[ 0.   -0.75 -1.19 -1.3 ]
 [-0.75 -1.38 -1.64 -1.73]
 [-1.19 -1.64 -1.82 -1.64]
 [-1.3  -1.73 -1.64  0.  ]] 0
1.57275390625 [[ 0.   -1.58 -2.43 -2.69]
 [-1.58 -2.61 -3.15 -3.3 ]
 [-2.43 -3.15 -3.39 -2.83]
 [-2.69 -3.3  -2.83  0.  ]] 1
1.39788818359 [[ 0.   -2.4  -3.67 -4.09]
 [-2.4  -3.78 -4.53 -4.69]
 [-3.67 -4.53 -4.68 -3.8 ]
 [-4.09 -4.69 -3.8   0.  ]] 2
1.34772109985 [[ 0.   -3.21 -4.87 -5.43]
 [-3.21 -4.87 -5.78 -5.93]
 [-4.87 -5.78 -5.79 -4.63]
 [-5.43 -5.93 -4.63  0.  ]] 3
1.26940321922 [[ 0.   -3.99 -6.02 -6.7 ]
 [-3.99 -5.88 -6.91 -7.04]
 [-6.02 -6.91 -6.77 -5.36]
 [-6.7  -7.04 -5.36  0.  ]] 4
1.1805472523 [[ 0.   -4.72 -7.09 -7.88]
 [-4.72 -6.81 -7.93 -8.05]
 [-7.09 -7.93 -7.64 -6.01]
 [-7.88 -8.05 -6.01  0.  ]] 5
1.09041144606 [[ 0.   -5.41 -8.08 -8.97]
 [-5.41 -7.67 -8.86 -8.98]
 [-8.08 -8.86 -8.44 -6.61]
 [-8.97 -8.98 -6.61  0.  ]] 6
1.00343448954 [[ 0.   -6.04 -8.99 -9.98]
 [-6.04 -8.45 -9.71 -9.82]
 [-8.99 -9.71 -9.16 -7.15]
 [-9.98 -9.82 -7.15  0.  ]] 7
0.921
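The introduction asks whether we need to wait for iterative policy evaluation to converge. A minimal sketch of value iteration on the same gridworld follows. It is not the notebook's code, and it assumes the book's conventions: reward -1 per move, $\gamma = 1$, and terminals at the corners. The only change from the evaluation sweep is that the $\pi$-weighted average over actions becomes a max. The names `step` and `value_iteration` are hypothetical.

```python
import numpy as np

# Assumed conventions: 4x4 grid, terminals at (0,0) and (3,3),
# reward -1 per move, gamma = 1, deterministic moves.
SIDE = 4
TERMINALS = {(0, 0), (SIDE - 1, SIDE - 1)}
ACTIONS = [(-1, 0), (1, 0), (0, 1), (0, -1)]  # up, down, right, left


def step(i, j, di, dj):
    """Hypothetical helper: next cell after a move; bumping a wall stays put."""
    ni, nj = i + di, j + dj
    if 0 <= ni < SIDE and 0 <= nj < SIDE:
        return ni, nj
    return i, j


def value_iteration(reward=-1.0, gamma=1.0, tol=1e-6, max_sweeps=1000):
    """Return (V, number_of_sweeps) once the largest change drops below tol."""
    V = np.zeros((SIDE, SIDE))
    for sweep in range(1, max_sweeps + 1):
        delta = 0.0
        for i in range(SIDE):
            for j in range(SIDE):
                if (i, j) in TERMINALS:
                    continue
                old = V[i, j]
                # Bellman optimality backup: best action, not the pi-weighted mean
                V[i, j] = max(reward + gamma * V[step(i, j, di, dj)]
                              for di, dj in ACTIONS)
                delta = max(delta, abs(old - V[i, j]))
        if delta < tol:
            return V, sweep
    return V, max_sweeps


V_star, n_sweeps = value_iteration()
print(V_star)
print("sweeps:", n_sweeps)
```

Because moves are deterministic and every step costs 1, the optimal value of each cell is minus the number of steps to the nearer terminal corner. The sweep reaches this in a handful of passes, whereas the evaluation loop above was still changing after eight sweeps.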

In [None]:
#Reference: http://waxworksmath.com/Authors/N_Z/Sutton/Code/Chapter_4/iter_poly_gw_inplace.m
#           http://waxworksmath.com/Authors/N_Z/Sutton/Code/Chapter_4/gam_Script.m