In [42]:
import numpy as np
import math
import matplotlib.pyplot as plt
from scipy.optimize import fsolve

In [43]:
def eqNonLin(vars):
    x = vars

    k = lambda i, j : 5 * i + j
    fns = []
    for i in range(5):
        for j in range(5): 
            L1 = i - 1 < 0
            L2 = i + 1 > 4
            L3 = j - 1 < 0
            L4 = j + 1 > 4

            if i == 0 and j == 1:
                f = x[i,j] - (0.9*x[k(4,1)]+10)
            elif i == 0 and j == 3:
                f = x[i,j] - (0.9*x[k(2,3)]+5)
            else:
                e1 = (x[i,j]-1 if L1 else 0.9*x[k(i-1,j)])
                e2 = (x[i,j]-1 if L2 else 0.9*x[k(i+1,j)])
                e3 = (x[i,j]-1 if L3 else 0.9*x[k(i,j-1)])
                e4 = (x[i,j]-1 if L4 else 0.9*x[k(i,j+1)])
                f = x[i,j] - max([e1,e2,e3,e4])

            fns.append(f)
    return fns

In [44]:
x =  fsolve(eqNonLin, tuple(np.ones(25)))

In [45]:
np.set_printoptions(precision=1)
print(x.reshape((5,5)))

[[22.  24.4 22.  19.4 17.5]
 [19.8 22.  19.8 17.8 16. ]
 [17.8 19.8 17.8 16.  14.4]
 [16.  17.8 16.  14.4 13. ]
 [14.4 16.  14.4 13.  11.7]]


In [46]:
def eqLin(g):
    A = np.zeros(25*25).reshape((25,25))
    b = np.zeros(25)

    k = lambda i, j : 5 * i + j
    for i in range(5):
        for j in range(5): 
            L1 = i - 1 < 0
            L2 = i + 1 > 4
            L3 = j - 1 < 0
            L4 = j + 1 > 4

            A[i,j, i,j] = 1 - g/4*(L1 + L2 + L3 + L4)
            if not L1:
                A[i,j, k(i-1,j)] = -g/4
            if not L2:
                A[i,j, k(i+1,j)] = -g/4
            if not L3:
                A[i,j, k(i,j-1)] = -g/4
            if not L4:
                A[i,j, k(i,j+1)] = -g/4

            b[i,j] = -(L1 + L2 + L3 + L4)/4

    A[k(0,1),:] = 0
    A[k(0,1), k(0,1)] = 1
    A[k(0,1), k(4,1)] = -g
    b[k(0,1)] = 10

    A[k(0,3),:] = 0
    A[k(0,3), k(0,3)] = 1
    A[k(0,3), k(2,3)] = -g
    b[k(0,3)] = 5

    return A, b

In [47]:
A, b = eqLin(0.9)
x = np.linalg.solve(A, b)

In [48]:
np.set_printoptions(precision=1)
print(x.reshape((5,5)))

[[ 3.3  8.8  4.4  5.3  1.5]
 [ 1.5  3.   2.3  1.9  0.5]
 [ 0.1  0.7  0.7  0.4 -0.4]
 [-1.  -0.4 -0.4 -0.6 -1.2]
 [-1.9 -1.3 -1.2 -1.4 -2. ]]


In [61]:
def randomPolicy(state, actionSpace):
    return np.full(actionSpace.shape, 1 / len(actionSpace))

In [59]:
print(randomPolicy(1, np.ones(4)))
print(np.ones(4).shape)
print(len(np.ones(4)))

[0.2 0.2 0.2 0.2]
(4,)
4


In [160]:
def policyEvalIter(policy, statesShape, gamma, eps):
    v = np.ones(statesShape)
    vOld = 0
    res = 1.1*eps
    iter = 0
    k = lambda i, j : 5 * i + j
    while True:
        res = 0
        for i in range(5):
            for j in range(5): 
                vOld = v[i,j]
                L1 = i - 1 < 0
                L2 = i + 1 > 4
                L3 = j - 1 < 0
                L4 = j + 1 > 4
                actions = randomPolicy(k(i,j), np.ones(4))

                if i == 0 and j == 1:
                    v[i,j] = 10 + gamma*(v[4,1])
                elif i == 0 and j == 3:
                    v[i,j] = 5 + gamma*(v[2,3])
                else:
                    update = 0
                    for index, a in np.ndenumerate(actions):
                        if index[0] == 0:
                            update = update + a * (-1 if L1 else 0 + gamma*(v[i,j] if L1 else v[i-1,j]))
                        elif index[0] == 1:
                            update = update + a * (-1 if L2 else 0 + gamma*(v[i,j] if L2 else v[i+1,j]))
                        elif index[0] == 2:
                            update = update + a * (-1 if L3 else 0 + gamma*(v[i,j] if L3 else v[i,j-1]))
                        elif index[0] == 3:
                            update = update + a * (-1 if L4 else 0 + gamma*(v[i,j] if L4 else v[i,j+1]))
                    v[i,j] = update

                res = max([res, abs(vOld - v[i,j])])
        iter = iter + 1
        if res < eps:
            break
    return v, iter

In [161]:
v, iter = policyEvalIter(randomPolicy, (5,5), 0.9, 0.001)

In [162]:
np.set_printoptions(precision=1, suppress=True)
print(v)
print("iter = {}".format(iter))

[[ 1.8  9.5  3.6  5.5  0.8]
 [ 0.9  3.   2.1  1.9  0.3]
 [ 0.   0.9  0.8  0.6 -0.2]
 [-0.4  0.   0.1 -0.1 -0.5]
 [-0.7 -0.5 -0.5 -0.5 -0.7]]
iter = 16
