In [1]:
import numpy as np
import pandas as pd

def prettyprint(table, index_name):
    df = pd.DataFrame(table)
    df.index.name = index_name
    display(df)

### Exercise 4.2

In [23]:
TERMINAL_STATES = set([(0,0), (3,3)])
INVALID_STATES = set([(4, 0), (4, 2), (4, 3)])
VALID_STATES = get_valid_states()
ACTION_TO_VEC = {"left": (0, -1), "right": (0, 1), "up": (-1, 0), "down": (1, 0)}
def get_valid_states():
    states = set([])
    for i in range(5):
        for j in range(4):
            if ((i,j) not in TERMINAL_STATES) and ((i,j) not in INVALID_STATES):
                states.add((i,j))
    return states
                
print(VALID_STATES)    

{(0, 1), (1, 2), (2, 1), (4, 1), (3, 1), (1, 1), (0, 3), (2, 0), (3, 0), (2, 3), (0, 2), (2, 2), (1, 0), (3, 2), (1, 3)}


In [27]:
class V:
    def __init__(self):
        self.v = np.zeros((5,4))
        
    def get_v(self, s):
        if s not in VALID_STATES:
            return 0
        return self.v[s[0], s[1]]
    
    def set_v(self, s, v):
        if s in VALID_STATES:
            self.v[s[0], s[1]] = v
        
    def diff(self, v_prime):
        return np.sum(np.absolute(self.v - v_prime.v))
    
    def copy(self):
        dup = V()
        dup.v = np.array(self.v, copy=True)
        return dup    

In [28]:
def transition(s, a):
    # If already at terminal state return 0 (extension of episodic tasks as continuing tasks)
    if s in TERMINAL_STATES:
        return (s, 0)
    
    a_vec = ACTION_TO_VEC[a]
    s_prime = (s[0] + a_vec[0], s[1] + a_vec[1])
    
    # reaching terminal state from non terminal gets -1
    if s_prime in TERMINAL_STATES:
        return (s_prime, -1)
    
    if s_prime not in VALID_STATES:
        return (s, -1)
    
    return (s_prime, -1) 

In [29]:
error_tolerance = 1e-4
error = 1
v = V()
while error > error_tolerance:
    current_v = v.copy()
    for s in VALID_STATES:
        new_s_val = 0
        for a in ["up", "down", "left", "right"]:
            s_prime, r = transition(s, a)
            new_s_val += (0.25 * (r +  v.get_v(s_prime)))
        v.set_v(s, new_s_val)
    error = v.diff(current_v)
    
prettyprint(np.round(v.v, 1), "v(s)")

Unnamed: 0_level_0,0,1,2,3
v(s),Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.0,-14.7,-20.9,-22.9
1,-15.0,-19.2,-21.1,-20.9
2,-21.8,-21.9,-19.3,-14.7
3,-24.6,-23.4,-15.6,0.0
4,0.0,-27.4,0.0,0.0


As we can see, adding a dead end state 15 causes ripple effect on all states basic on the distance to the added state. The surrounding state values get decreased as the states further from it gets slight increase in value. Based on the table, we have $v_{\pi}(15) = -27.4$

### Exercise 4.3

$\begin{align}
q_{\pi}(s,a) &= E_{\pi}[G_t|s,a] \\
&= E_{\pi}[R_{t+1}+\gamma G_{t+1}|s,a] \\
&= E_{\pi}[R_{t+1}|s,a] +  \gamma  E_{\pi}[G_{t+1}|s,a] \\
&= E_{\pi}[R_{t+1}|s,a] +  \gamma \sum_{s',a'} E_{\pi}[G_{t+1}|s',a']p(s',a'|s,a) \\
&= \sum_{s',r}rp(s',r|s,a) +  \gamma \sum_{s',a'} q_{\pi}(s',a')p(a'|s',s,a)\sum_r p(s',r|s,a) \\
&= \sum_{s',r}p(s',r|s,a) \left[r +  \gamma \sum_{a'} q_{\pi}(s',a')\pi(a'|s')\right] \\
\end{align}$

Then we have $$
q_{t+1}(s,a) = \sum_{s',r}p(s',r|s,a) \left[r +  \gamma \sum_{a'} q_{t}(s',a')\pi(a'|s')\right] $$