## Q-Learning 
----

Solve a simple 3x3 maze with Q-learning.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

def theta_to_pi(theta):
    """Convert a 2-D table of action weights into per-row probabilities.

    Each row of ``theta`` is divided by its NaN-ignoring sum. NaN entries
    (actions that are not available) become 0 in the result.

    Args:
        theta: 2-D array of action weights, NaN where an action is unavailable.

    Returns:
        float64 array with the same shape as ``theta``.
    """
    n_rows, _ = theta.shape  # unpacking also rejects non-2-D input
    row_sums = np.nansum(theta, axis=1).reshape(n_rows, 1)
    probabilities = np.asarray(theta / row_sums, dtype=np.float64)
    return np.nan_to_num(probabilities)

def get_action(s, Q, eps, pi_0):
    """Choose an action index for state ``s`` epsilon-greedily.

    With probability ``eps`` a move is sampled from the distribution
    ``pi_0[s, :]``. Otherwise the move with the largest non-NaN value in
    ``Q[s, :]`` is taken.

    Args:
        s: current state index (row of ``Q`` and ``pi_0``).
        Q: array of action values; NaN entries are ignored by the greedy pick.
        eps: exploration probability.
        pi_0: array of per-state action probabilities used when exploring.

    Returns:
        int action index: 0 = up, 1 = right, 2 = down, 3 = left.
    """
    action_names = ['up', 'right', 'down', 'left']
    index_of = {name: idx for idx, name in enumerate(action_names)}

    explore = np.random.rand() < eps
    chosen = (np.random.choice(action_names, p=pi_0[s, :]) if explore
              else action_names[np.nanargmax(Q[s, :])])
    return index_of[chosen]
    
def get_s_next(s, a):
    """Return the state reached from ``s`` by taking action ``a``.

    States are numbered row by row on a grid three cells wide, so moving up
    or down shifts the index by 3 and moving left or right shifts it by 1.
    Walls are not checked here.

    Args:
        s: current state index.
        a: action index, 0 = up, 1 = right, 2 = down, 3 = left.
    """
    offsets = (-3, +1, +3, -1)  # up, right, down, left
    return s + offsets[a]

### How to update the Q-value

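For a fixed policy, accurate Q-values satisfy the Bellman equation

$$
    Q(s_t, a_t) = R_{t+1} + \gamma Q(s_{t+1}, a_{t+1})
$$

The optimal Q-values, which Q-learning estimates, satisfy the Bellman optimality equation. No expectation is needed because this maze is deterministic:

$$
    Q^*(s_t, a_t) = R_{t+1} + \gamma \max_{a} Q^*(s_{t+1}, a)
$$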

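At every step, Q-learning moves the Q-value toward that target by a fraction $\eta$ (the learning rate) of the TD error:

$$
    Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \eta \left( R_{t+1} + \gamma \max_{a} Q(s_{t+1}, a) - Q(s_t, a_t) \right)
$$

When $s_{t+1}$ is the goal, the episode ends and the target is just $R_{t+1}$.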

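Compare this with the Sarsa update. Sarsa uses the next action $a_{t+1}$ actually chosen by the policy instead of the maximum over actions:

$$
    Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \eta \left( R_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right)
$$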

In [2]:
def update_Q(s, a, r, s_next, a_next, Q, eta, gamma):
    """Apply one Q-learning update to ``Q[s, a]`` in place and return ``Q``.

    Args:
        s: state the action was taken in.
        a: action that was taken.
        r: reward received for the transition.
        s_next: state reached; state 8 is treated as terminal (the goal).
        a_next: not used by Q-learning. Presumably kept so the signature
            matches a Sarsa-style update; confirm before removing.
        Q: 2-D array of action values, NaN for unavailable actions.
            Modified in place.
        eta: learning rate.
        gamma: discount factor.

    Returns:
        The same ``Q`` array object.
    """
    # Use ``==``, not ``is``: identity with an int literal is a SyntaxWarning
    # and is False for numpy integers such as np.int64(8).
    if s_next == 8:
        # Terminal transition: there is no next state to bootstrap from.
        target = r
    else:
        # Bootstrap from the best available (non-NaN) action in s_next.
        target = r + gamma * np.nanmax(Q[s_next, :])
    Q[s, a] = Q[s, a] + eta * (target - Q[s, a])
    return Q

In [3]:
def goal_maze_ret_s_a_Q(Q, eps, eta, gamma, pi):
    """Run one episode from state 0 until state 8, updating ``Q`` each step.

    Args:
        Q: 2-D array of action values, updated via ``update_Q``.
        eps: exploration probability passed to ``get_action``.
        eta: learning rate passed to ``update_Q``.
        gamma: discount factor passed to ``update_Q``.
        pi: per-state action probabilities used by ``get_action`` when exploring.

    Returns:
        ``[s_a_history, Q]``. ``s_a_history`` is a list of ``[state, action]``
        pairs; the final entry is the goal state with action NaN.
    """
    goal = 8
    s = 0  # start state
    a_next = get_action(s, Q, eps, pi)
    s_a_history = [[s, np.nan]]

    while True:
        a = a_next
        s_a_history[-1][1] = a
        s_next = get_s_next(s, a)
        s_a_history.append([s_next, np.nan])

        # Compare with ``==``: ``is`` with an int literal is a SyntaxWarning
        # and fails for numpy integers.
        reached_goal = s_next == goal
        r = 1 if reached_goal else 0
        # The next action is chosen before Q is updated, as in the original loop.
        a_next = np.nan if reached_goal else get_action(s_next, Q, eps, pi)

        Q = update_Q(s, a, r, s_next, a_next, Q, eta, gamma)

        if reached_goal:
            break
        s = s_next

    return [s_a_history, Q]

In [4]:
# Initial parameters.
# theta[s, a] is 1 when action a (up, right, down, left) is allowed in state s
# and NaN where a wall blocks it. The goal S8 is terminal and has no row.
theta = np.array([[np.nan, 1, 1, np.nan],       # s0
                  [np.nan, 1, np.nan, 1],       # s1
                  [np.nan, np.nan, 1, 1],       # s2
                  [1, 1, 1, np.nan],            # s3
                  [np.nan, np.nan, 1, 1],       # s4
                  [1, np.nan, np.nan, np.nan],  # s5
                  [1, np.nan, np.nan, np.nan],  # s6
                  [1, 1, np.nan, np.nan],       # s7
                  ])
n_states, n_actions = theta.shape
# Small random initial Q-values; multiplying by theta keeps NaN on blocked moves.
Q = 0.1 * np.random.rand(n_states, n_actions) * theta
# Uniform distribution over the allowed moves, used when exploring.
pi_0 = theta_to_pi(theta)

eta = 0.1    # learning rate
gamma = 0.9  # discount factor
eps = 0.5    # exploration rate; halved before every episode

# State values V(s) = max_a Q(s, a), recorded after every episode.
V = [np.nanmax(Q, axis=1)]

for episode in range(100):
    eps /= 2
    s_a_history, Q = goal_maze_ret_s_a_Q(Q, eps, eta, gamma, pi_0)
    V.append(np.nanmax(Q, axis=1))

In [12]:
from matplotlib import animation
from IPython.display import HTML
import matplotlib.cm as cm

def draw_maze():
    """Draw the 3x3 maze and return ``(fig, ax, line)``.

    Walls are drawn in red and every cell is labelled S0..S8, numbered row
    by row from the top-left. ``line`` is a green circle marker on the start
    cell S0. The figure is closed before returning, so the caller draws into
    it later (in this notebook, for an animation).
    """
    fig = plt.figure(figsize=(5, 5))
    ax = plt.gca()

    # Wall segments as (xs, ys) pairs.
    walls = (
        ([1, 1], [0, 1]),
        ([1, 2], [2, 2]),
        ([2, 2], [2, 1]),
        ([2, 3], [1, 1]),
    )
    for xs, ys in walls:
        ax.plot(xs, ys, color='r', linewidth=2)

    # State labels at each cell centre.
    for state in range(9):
        row, col = divmod(state, 3)
        ax.text(col + 0.5, 2.5 - row, f'S{state}', size=14, ha='center')
    ax.text(0.5, 2.3, 'START', size=14, ha='center')
    ax.text(2.5, 0.3, 'GOAL', size=14, ha='center')

    ax.set_xlim(0, 3)
    ax.set_ylim(0, 3)

    ax.tick_params(axis='both', which='both', bottom=False, top=False,
                   labelbottom=False, right=False, left=False, labelleft=False)
    line, = ax.plot([0.5], [2.5], marker='o', color='g', markersize=60)
    plt.close(fig)
    return fig, ax, line

In [16]:
# Visualize how the state values V(s) = max_a Q(s, a) change over the episodes.
# Rendering to HTML can still take a little while.

fig, ax, line = draw_maze()

# One square per state, created once. Each frame only recolours them instead
# of adding nine new artists to the axes on every frame.
squares = []
for state in range(9):
    row, col = divmod(state, 3)
    square, = ax.plot([col + 0.5], [2.5 - row], marker='s',
                      color=cm.jet(1.0), markersize=85)
    squares.append(square)


def init2():
    """Hide the agent marker from ``draw_maze`` before the first frame.

    The previous code passed ``init_func=init``, but no ``init`` is defined
    in this notebook.
    """
    line.set_data([], [])
    return (line, *squares)


def animate2(i):
    """Colour each state's square by its value after episode ``i``.

    ``V[i]`` has no entry for the goal S8, so its square keeps ``cm.jet(1.0)``.
    """
    for square, value in zip(squares, V[i]):
        square.set_color(cm.jet(value))
    return tuple(squares)


anim = animation.FuncAnimation(fig, animate2, init_func=init2, frames=len(V),
                               interval=50, repeat=False)
HTML(anim.to_jshtml())