In [69]:
import numpy as np
import pandas as pd

def prettyprint(table, index_name):
    """Render ``table`` as a DataFrame whose index axis is labelled ``index_name``.

    ``table`` is anything ``pd.DataFrame`` accepts (in this notebook, 2-D numpy
    arrays). The frame is shown through IPython's ``display`` so it renders as
    an HTML table.
    """
    labelled = pd.DataFrame(table).rename_axis(index_name)
    display(labelled)

### Exercise 4.1
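The task is undiscounted ($\gamma=1$), and the equiprobable random policy picks each action with probability $\pi(a'|s') = 0.25$. The transitions are deterministic: each $(s,a)$ leads to a single successor $s'$, and the agent stays in place if the move would leave the grid. The Bellman equation for $q_\pi$ therefore reduces to
$$ q_\pi(s,a) = r(s,a) + \sum_{a'} 0.25\, q_\pi(s',a'), $$
where $r(s,a) = -1$ for every non-terminal $s$, and $q_\pi(s',\cdot) = 0$ when $s'$ is terminal.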

In [84]:
# The two opposite corner cells of the 4x4 grid are terminal.
TERMINAL_STATES = {(0, 0), (3, 3)}
# Action name -> (row delta, col delta); "up" decreases the row index.
ACTION_TO_VEC = dict(left=(0, -1), right=(0, 1), up=(-1, 0), down=(1, 0))
ACTIONS = ACTION_TO_VEC.keys()

def get_valid_states(size=4, terminal_states=None):
    """Return the non-terminal cells of a ``size`` x ``size`` gridworld.

    Args:
        size: side length of the square grid (default 4, the grid used here).
        terminal_states: cells to exclude; defaults to the module-level
            ``TERMINAL_STATES`` so existing ``get_valid_states()`` calls are
            unchanged.

    Returns:
        A set of ``(row, col)`` tuples for every in-grid, non-terminal cell.
    """
    if terminal_states is None:
        terminal_states = TERMINAL_STATES
    return {
        (i, j)
        for i in range(size)
        for j in range(size)
        if (i, j) not in terminal_states
    }
                
VALID_STATES = get_valid_states()
# Every in-grid cell except the terminals should be a valid (non-terminal) state.
assert len(VALID_STATES) == 4 * 4 - len(TERMINAL_STATES)
assert VALID_STATES.isdisjoint(TERMINAL_STATES)
sorted(VALID_STATES)  # sorted so the listing reads in row-major order

{(0, 1), (1, 2), (2, 1), (3, 1), (1, 1), (0, 3), (2, 0), (3, 0), (2, 3), (0, 2), (2, 2), (1, 0), (3, 2), (1, 3)}


In [85]:
class Q:
    """Tabular action-value function q(s, a) for the 4x4 gridworld.

    Values live in ``self.q`` indexed ``[row, col, action position]``. States
    not in ``VALID_STATES`` (terminal cells) read as 0 and ignore writes.
    """

    def __init__(self):
        self.q = np.zeros((4, 4, 4))  # i, j, a
        # Action name -> position along the last axis of self.q.
        self.a_to_pos = {"up": 0, "down": 1, "right": 2, "left": 3}

    def get_q(self, s, a):
        """Return q(s, a), or 0 when ``s`` is not in ``VALID_STATES``."""
        if s not in VALID_STATES:
            return 0
        return self.q[s[0], s[1], self.a_to_pos[a]]

    def set_q(self, s, a, v):
        """Store q(s, a) = v; writes for states outside ``VALID_STATES`` are ignored."""
        if s in VALID_STATES:
            self.q[s[0], s[1], self.a_to_pos[a]] = v

    def diff(self, q_prime):
        """Sum of absolute entry-wise differences between two Q tables.

        This is the same L1 metric ``V.diff`` uses, so the q- and v-evaluation
        loops stop under the same convergence criterion.
        """
        return np.sum(np.absolute(self.q - q_prime.q))

    def copy(self):
        """Return an independent copy of this table (the array is duplicated)."""
        dup = Q()
        dup.q = self.q.copy()
        return dup
    
class V:
    """Tabular state-value function v(s) over the 4x4 grid.

    ``self.v`` is indexed ``[row, col]``. States outside ``VALID_STATES`` read
    as 0, and ``set_v`` drops writes to them.
    """

    def __init__(self):
        self.v = np.zeros((4, 4))  # [row, col]

    def get_v(self, s):
        """Look up v(s); anything not in VALID_STATES reads as 0."""
        if s in VALID_STATES:
            row, col = s
            return self.v[row, col]
        return 0

    def set_v(self, s, v):
        """Assign v(s) = v, but only for states in VALID_STATES."""
        if s not in VALID_STATES:
            return
        row, col = s
        self.v[row, col] = v

    def diff(self, v_prime):
        """L1 distance (sum of absolute differences) to another V table."""
        delta = self.v - v_prime.v
        return np.abs(delta).sum()

    def copy(self):
        """Return an independent duplicate of this value table."""
        clone = V()
        clone.v = self.v.copy()
        return clone

def transition(s, a):
    """Deterministic gridworld step: return ``(next_state, reward)`` for action ``a`` in ``s``.

    Terminal states are absorbing with reward 0 (treating the episodic task as
    a continuing one). Every move from a non-terminal state costs -1; a move
    that would leave the grid keeps the agent where it is.
    """
    if s in TERMINAL_STATES:
        return (s, 0)

    di, dj = ACTION_TO_VEC[a]
    target = (s[0] + di, s[1] + dj)
    # Terminal and non-terminal in-grid cells are both legal destinations;
    # anything else is off the grid, so the agent stays put.
    landed = target if (target in TERMINAL_STATES or target in VALID_STATES) else s
    return (landed, -1)

# Iterative policy evaluation of the equiprobable random policy (gamma = 1).
# Both sweeps update their tables in place, so states visited later in a sweep
# already see values written earlier in that same sweep.
error_tolerance = 1e-4
ACTION_PROB = 1 / len(ACTIONS)  # pi(a|s) under the equiprobable random policy

# q_pi: q(s, a) = r + sum_a' pi(a'|s') q(s', a')
error = 1
q = Q()
while error > error_tolerance:
    current_q = q.copy()
    for s in VALID_STATES:
        for a in ACTIONS:
            s_prime, r = transition(s, a)
            new_s_a_val = r
            for a_prime in ACTIONS:
                new_s_a_val += ACTION_PROB * q.get_q(s_prime, a_prime)
            q.set_q(s, a, new_s_a_val)
    error = q.diff(current_q)

# One table per action. Take the slice index from Q's own action mapping rather
# than hard-coding it, and label the tables as q (action values).
for a, pos in q.a_to_pos.items():
    prettyprint(np.round(q.q[:, :, pos], 1), f"q(s, {a})")

# v_pi: v(s) = sum_a pi(a|s) [r + v(s')]
error = 1
v = V()
while error > error_tolerance:
    current_v = v.copy()
    for s in VALID_STATES:
        new_s_val = 0
        for a in ACTIONS:
            s_prime, r = transition(s, a)
            new_s_val += ACTION_PROB * (r + v.get_v(s_prime))
        v.set_v(s, new_s_val)
    error = v.diff(current_v)

# Sanity check: with a uniform policy, v_pi(s) is the mean of q_pi(s, .) over
# actions (terminal cells are never written, so both tables hold 0 there).
np.testing.assert_allclose(v.v, q.q.mean(axis=2), atol=1e-2)

prettyprint(np.round(v.v, 1), "v(s)")

Unnamed: 0_level_0,0,1,2,3
"p(s, up)",Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.0,-15.0,-21.0,-23.0
1,-1.0,-15.0,-21.0,-23.0
2,-15.0,-19.0,-21.0,-21.0
3,-21.0,-21.0,-19.0,0.0


Unnamed: 0_level_0,0,1,2,3
"p(s, down)",Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.0,-19.0,-21.0,-21.0
1,-21.0,-21.0,-19.0,-15.0
2,-23.0,-21.0,-15.0,-1.0
3,-23.0,-21.0,-15.0,0.0


Unnamed: 0_level_0,0,1,2,3
"p(s, right)",Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.0,-21.0,-23.0,-23.0
1,-19.0,-21.0,-21.0,-21.0
2,-21.0,-19.0,-15.0,-15.0
3,-21.0,-15.0,-1.0,0.0


Unnamed: 0_level_0,0,1,2,3
"p(s, left)",Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.0,-1.0,-15.0,-21.0
1,-15.0,-15.0,-19.0,-21.0
2,-21.0,-21.0,-21.0,-19.0
3,-23.0,-23.0,-21.0,0.0


Unnamed: 0_level_0,0,1,2,3
v(s),Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.0,-14.0,-20.0,-22.0
1,-14.0,-18.0,-20.0,-20.0
2,-20.0,-20.0,-18.0,-14.0
3,-22.0,-20.0,-14.0,0.0


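From the tables above, $q_{\pi}(11, \text{down}) = -1$ and $q_{\pi}(7, \text{down}) = -15$. State 11 is grid cell $(2,3)$, so moving down enters the terminal cell $(3,3)$ and collects a single reward of $-1$. State 7 is cell $(1,3)$, so moving down costs $-1$ and lands in state 11, giving $-1 + v_\pi(11) = -1 - 14 = -15$.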