In [1]:
import numpy as np

In [2]:
# Problem dimensions.
nS = 10 * 10  # number of states (presumably a 10x10 grid -- TODO confirm)
nA = 4        # number of actions

# Accumulators filled while reading the sample data:
#   N[s, a, sp]  -- how many times taking action a in state s led to state sp
#   Rho[s, a]    -- sum of the rewards observed for the pair (s, a)
N = np.zeros((nS, nA, nS))
Rho = np.zeros((nS, nA))

# State-value estimates, one entry per state, starting at zero.
U = np.zeros(nS)

In [3]:
# Stream the sample file row by row, accumulating transition counts in N
# and summed rewards in Rho.
fpath = './data/small.csv'
with open(fpath, 'r') as f:
    for record in f:
        if record.startswith('s,a,r,sp'):
            # header row carries no sample
            continue
        s_txt, a_txt, r_txt, sp_txt = record.split(',')
        # Every index read from the file is shifted down by one before it
        # is used to address the zero-based arrays.
        s_idx = int(s_txt) - 1
        a_idx = int(a_txt) - 1
        sp_idx = int(sp_txt) - 1
        N[s_idx, a_idx, sp_idx] += 1
        Rho[s_idx, a_idx] += int(r_txt)

In [4]:
# Number of recorded visits to each (s, a) pair: N(s,a) = sum over s' of N(s'|s,a)
visits = N.sum(axis=2)

# Estimate of transition probabilities T(s'|s,a) ~ N(s'|s,a) / N(s,a)
# A pair that was never visited would otherwise evaluate 0/0 = NaN, and that
# NaN would spread through every later Q/U update. The guarded division
# leaves such rows at zero instead; visited pairs are unaffected.
T = np.divide(
    N,
    visits[:, :, np.newaxis],
    out=np.zeros_like(N),
    where=visits[:, :, np.newaxis] > 0,
)

# Estimate of reward function R(s,a) ~ Rho(s,a) / N(s,a)
# Same guard: unvisited pairs get a reward estimate of zero rather than NaN.
R = np.divide(Rho, visits, out=np.zeros_like(Rho), where=visits > 0)

In [39]:
# Calculate action-value function Q(s,a) (vectorized)
gamma = 0.9  # discount factor

# Restart from a zero value function so that this cell gives the same result
# every time it is run, rather than continuing from whatever U was left in
# the kernel by a previous execution of this or another cell.
U = np.zeros(nS)
Q = np.zeros((nS, nA))
for i in range(100):
    # Bellman backup for all (s, a) at once: T.dot(U) contracts the last
    # (next-state) axis of T with U, giving the expected next-state value.
    Q = R + gamma * T.dot(U)
    # Greedy state value: best action value in each state.
    U = Q.max(axis=1)


In [9]:
# Calculate action-value function Q(s,a) (non-vectorized)
gamma = 0.9  # discount factor

# Restart from a zero value function so that this cell gives the same result
# every time it is run, rather than continuing from whatever U was left in
# the kernel by a previous execution of this or another cell.
U = np.zeros(nS)
Q = np.zeros((nS, nA))
for i in range(100):
    for s in range(nS):
        for a in range(nA):
            # Bellman backup for one (s, a): estimated reward plus the
            # discounted expected value of the next state.
            Q[s, a] = R[s, a] + gamma * T[s, a].dot(U)
    # U is refreshed only after the full sweep, so every backup within one
    # sweep uses the same value function.
    U = Q.max(axis=1)


In [10]:
# Derive the greedy policy from Q and save it, one action per line.
fpath = './data/small.policy'
policy = np.argmax(Q, axis=1)
with open(fpath, 'w') as f:
    # argmax yields zero-based action indices; one is added to each before
    # it is written to the file.
    f.writelines(str(policy[state] + 1) + '\n' for state in range(nS))

In [58]:
# Show the policy as ten rows of ten consecutive states each.
for start in range(0, 100, 10):
    print(policy[start:start + 10])

# Show the state values in the same layout, formatted to one decimal place.
for start in range(0, 100, 10):
    row_values = U[start:start + 10]
    print([f'{value:.1f}' for value in row_values])

[1 1 1 1 2 2 2 2 2 0]
[1 1 1 1 1 2 2 2 2 0]
[1 1 1 1 1 1 2 2 0 0]
[1 1 1 1 1 1 3 0 0 0]
[1 1 1 1 1 1 3 3 0 0]
[1 1 1 1 3 3 3 3 3 0]
[1 1 1 3 1 3 3 3 0 3]
[1 2 2 3 3 3 3 3 3 3]
[1 1 0 3 3 3 3 3 3 3]
[3 3 3 3 3 3 3 3 3 3]
['6.3', '7.1', '8.5', '10.0', '11.6', '13.4', '15.7', '13.7', '12.1', '10.9']
['6.8', '8.0', '9.6', '11.5', '13.6', '16.3', '19.3', '16.4', '14.1', '12.0']
['7.3', '8.8', '10.6', '13.2', '16.2', '20.0', '24.8', '20.4', '16.7', '13.6']
['7.8', '9.4', '11.9', '15.0', '18.7', '24.1', '32.2', '24.6', '19.4', '15.3']
['7.6', '9.1', '11.0', '13.5', '16.7', '20.2', '24.5', '20.4', '16.6', '13.6']
['6.9', '8.0', '9.5', '11.5', '13.9', '16.5', '18.5', '16.3', '13.7', '11.8']
['6.2', '7.0', '8.1', '9.6', '11.2', '13.5', '14.5', '13.2', '11.2', '9.7']
['6.8', '7.9', '7.3', '8.0', '9.2', '10.8', '11.2', '10.8', '9.1', '8.0']
['7.9', '10.2', '8.1', '7.0', '7.7', '8.6', '9.1', '8.8', '7.6', '6.7']
['7.1', '7.8', '6.8', '6.2', '6.6', '7.2', '7.5', '7.3', '6.5', '5.9']
