In [4]:
import numpy as np

In [8]:
a = np.array([[0.75,-0.5,-0],[0,0.75,-0.5],[-0.5,0,0.75]])
b = np.array([[12],[-3],[6]])
np.linalg.inv(a).dot(b)

array([[24.],
       [12.],
       [24.]])

In [3]:
cov = []
for i in range(8):
    cov_temp = np.random.normal(5,0.3,size=64)
    cov.append(np.diag(cov_temp))

cov[0].shape

(64, 64)

In [2]:
A, M, factor = 4, 81, 0.9925

prob_a1_sparse = np.loadtxt('prob_a1.txt')
prob_a2_sparse = np.loadtxt('prob_a2.txt')
prob_a3_sparse = np.loadtxt('prob_a3.txt')
prob_a4_sparse = np.loadtxt('prob_a4.txt')
rewards = np.loadtxt('rewards.txt')

In [3]:
def reconstruct(sparse):
    res = np.zeros(M ** 2)
    index = ((sparse[:,0] - 1) * M + sparse[:,1]).astype(int)
    res[index-1] = sparse[:,2]
    return res.reshape(M,M)

In [4]:
prob_a1 = reconstruct(prob_a1_sparse)
prob_a2 = reconstruct(prob_a2_sparse)
prob_a3 = reconstruct(prob_a3_sparse)
prob_a4 = reconstruct(prob_a4_sparse)
prob = np.array([prob_a1,prob_a2,prob_a3,prob_a4])

In [5]:
pi = np.random.randint(4,size = 81)

def value_function(P_pi,R):
    return (np.linalg.inv(np.eye(M) - factor * P_pi)).dot(R)

def choose(prob):
    res = np.zeros((M,M))
    for i in range(M):
        res[i,:] += prob[pi[i]][i,:]
    return res

def greedy(V_pi):
    temp = (prob.dot(V_pi[:,np.newaxis])).reshape(4,81).T
    return np.argmax(temp,axis = 1)


In [6]:
pi_old = np.zeros(81)
iteration_p = 0
while (np.array_equal(pi_old,pi) == False):
    iteration_p += 1
    pi_old = pi
    P_pi = choose(prob)
    V_pi = value_function(P_pi,rewards)
    pi = greedy(V_pi)
    if iteration_p >= 100:
        break

In [7]:
target = [3,11,12,15,16,17,20,22,23,24,26,29,30,31,34,35
          ,39,43,48,52,53,56,57,58,59,60,61,62,66,70,71]
target = np.array(target) - 1

policy_iteration = pi[target]
print(policy_iteration)

[2 2 1 3 3 2 2 3 3 0 2 3 3 0 2 1 0 2 0 2 2 3 3 3 3 3 2 2 0 2 1]


In [8]:
def action(policy):
    res = []
    for n in policy:
        if n == 0:
            res.append('left')
        if n == 1:
            res.append('up')
        if n == 2:
            res.append('right')
        else:
            res.append('down')
    
    return res

policy_order = action(policy_iteration)
for i in range(len(target)):
    print(target[i] + 1, policy_order[i])

3 right
11 right
12 up
15 down
16 down
17 down
20 right
22 right
23 down
24 down
26 left
29 down
30 right
31 down
34 down
35 left
39 down
43 right
48 up
52 down
53 left
56 down
57 right
58 left
59 down
60 right
61 right
62 down
66 down
70 down
71 down


In [9]:
P_best = choose(prob)
V_best = value_function(P_best,rewards)

for i in range(len(target)):
    print(target[i] + 1, V_best[target[i]])

3 100.70098072748911
11 102.37526440102091
12 101.52364514898127
15 109.48993453646305
16 110.40903296181362
17 111.3358466339684
20 103.23462341601049
22 106.77826755022934
23 107.67462642880356
24 108.57848711681841
26 112.27044031794428
29 104.10121204279733
30 104.97507555494721
31 105.88853590955101
34 114.16322950263663
35 113.21287932200798
39 103.78140737394392
43 115.1215572691303
48 90.98537960093466
52 116.087929588253
53 122.02491241481366
56 81.39949278128714
57 93.67165583314662
58 95.17285726464925
59 108.3426193434063
60 109.58365071834504
61 123.64307020769661
62 123.18223909953842
66 81.39949278128717
70 125.24978943555789
71 124.2073856333965


Value iteration:

In [10]:
V_0 = np.zeros(M) - 1
V = np.zeros(M)

In [11]:
iteration_v = 0
while (np.linalg.norm(V_0-V) >= 0.001):
    iteration_v += 1
    V_old = V
    temp = (prob.dot(V_old[:,np.newaxis])).reshape(4,81).T
    V = np.max(temp,axis = 1) * factor + rewards
    if iteration_v >= 1000:
        break

In [12]:
pi = np.argmax(temp,axis = 1)
policy_order_b = action(pi[target])

for i in range(len(target)):
    print(target[i] + 1, policy_order_b[i])

3 right
11 right
12 up
15 down
16 down
17 down
20 right
22 right
23 down
24 down
26 left
29 down
30 right
31 down
34 down
35 left
39 down
43 right
48 up
52 down
53 left
56 down
57 right
58 left
59 down
60 right
61 right
62 down
66 down
70 down
71 down


In [13]:
for i in range(len(target)):
    print(target[i] + 1, V[target[i]])

3 100.6369274095141
11 102.3112110830459
12 101.45959183100628
15 109.42587593135482
16 110.3449743567053
17 111.27178802886007
20 103.17057009803548
22 106.71420894898309
23 107.61056782379963
24 108.51442851171294
26 112.20638171283598
29 104.03715872482236
30 104.91102223697223
31 105.82447744733823
34 114.0991708975283
35 113.14882071689965
39 103.71754439276856
43 115.05749866402195
48 90.92855908134533
52 116.02387098314469
53 121.95711509767986
56 81.352631874468
57 93.6183868085393
58 95.11933784164746
59 108.28232511529859
60 109.52317338739665
61 123.57544378730873
62 123.11434073613364
66 81.352631874468
70 125.18188071890775
71 124.13947718920015
