In [741]:
from random import random, randint, sample, seed
from numpy import argmax, zeros, ones
from tabulate import tabulate

In [742]:
def update(state, state_new, action, Q_table, Reward, alpha, gamma):
    """Apply one tabular Q-learning update to ``Q_table`` in place and return it.

    Q[s][a] <- Q[s][a] + alpha * (Reward[s'] + gamma * max(Q[s']) - Q[s][a])

    Parameters
    ----------
    state, state_new : int
        Cell index before and after the move.
    action : int
        Index into ``Q_table[state]`` of the action that was taken.
    Q_table : list of lists
        Action values per cell; the entry ``Q_table[state][action]`` is overwritten.
    Reward : sequence
        Reward received for entering each cell.
    alpha, gamma : float
        Learning rate and discount factor.

    Returns
    -------
    The same ``Q_table`` object that was passed in.
    """
    current = Q_table[state][action]
    td_target = Reward[state_new] + gamma * max(Q_table[state_new])
    Q_table[state][action] = current + alpha * (td_target - current)
    return Q_table

In [743]:
def non_zero(input):
    """Return the positions of the non-zero entries of ``input``, in ascending order."""
    return [idx for idx, value in enumerate(input) if value != 0]

In [744]:
def move(epsilon, state, Q_table, Reward, alpha, gamma, Moves, dim):
    """Take one epsilon-greedy step from ``state`` and apply the Q-learning update.

    Only actions whose entry in ``Moves[state]`` is non-zero are considered.
    With probability ``epsilon`` one of them is picked uniformly at random;
    otherwise the one with the largest Q-value is taken (numpy ``argmax``
    returns the first one on ties). Actions map to cell offsets on a
    row-major grid ``dim[1]`` cells wide: 0 = Left (-1), 1 = Right (+1),
    2 = Up (-dim[1]), 3 = Down (+dim[1]).

    Returns
    -------
    (state_new, Q_table)
        The cell reached and the table returned by ``update``.
    """
    legal_actions = non_zero(Moves[state])
    legal_values = [Q_table[state][a] for a in legal_actions]

    explore = random() < epsilon
    if explore:
        action = sample(legal_actions, 1)[0]
    else:
        action = legal_actions[argmax(legal_values)]

    width = dim[1]
    offsets = {0: -1, 1: 1, 2: -width, 3: width}
    # Unknown actions leave the agent where it is, as in the original if/elif chain.
    state_new = state + offsets.get(action, 0)

    Q_table = update(state, state_new, action, Q_table, Reward, alpha, gamma)
    return state_new, Q_table

In [745]:
def _render_board(state, Reward, dim, n_cells):
    """Return the grid as text: one ``[c]`` per cell, a newline after each row.

    ``X`` marks ``state``, ``L`` a cell with reward < -1, ``W`` a cell with
    reward > -1, and a blank an ordinary (reward == -1) cell. Rows are
    ``dim[1]`` cells wide.
    """
    width = dim[1]
    pieces = []
    for j in range(n_cells):
        if j == state:
            mark = 'X'
        elif Reward[j] < -1:
            mark = 'L'
        elif Reward[j] > -1:
            mark = 'W'
        else:
            mark = ' '
        pieces.append('[' + mark + ']')
        if j % width == width - 1:
            pieces.append('\n')
    return ''.join(pieces)


def learnEpisode(epsilon, start_state, Q_table, Reward, alpha, gamma, Moves, dim, prt,
                 max_steps=None):
    """Run one episode from ``start_state`` until a terminal cell is entered.

    A cell is terminal when its reward is not -1. Each step calls ``move``,
    which also applies the Q-learning update.

    Parameters
    ----------
    epsilon : float
        Exploration probability passed to ``move``.
    start_state : int
        Cell the agent starts in.
    Q_table, Reward, alpha, gamma, Moves, dim
        Forwarded unchanged to ``move``.
    prt : bool
        If true, print the board before the first step and after every step.
    max_steps : int or None
        Optional cap on the number of steps. ``None`` (the default) runs until
        a terminal cell is reached. The cap guards against endless episodes,
        e.g. when ``alpha == 0`` and the greedy policy cycles.

    Returns
    -------
    (response, Q_table)
        ``response`` is 1 if the episode ended in a cell with reward > -1,
        -1 if it ended in a cell with reward < -1, and 0 only when
        ``max_steps`` was reached first.
    """
    state = start_state
    n_cells = len(Q_table)

    if prt:
        # The board already ends with '\n'; the extra '\n' plus print's own
        # newline leaves a blank line between consecutive boards.
        print(_render_board(state, Reward, dim, n_cells) + '\n')

    response = 0
    steps = 0
    while response == 0:
        if max_steps is not None and steps >= max_steps:
            break

        state, Q_table = move(epsilon, state, Q_table, Reward, alpha, gamma, Moves, dim)
        steps += 1

        if prt:
            print(_render_board(state, Reward, dim, n_cells) + '\n')

        if Reward[state] > -1:
            response = 1
        elif Reward[state] < -1:
            response = -1

    return response, Q_table

In [751]:
def _build_moves(n_rows, n_cols):
    """Return the per-cell action mask [Left, Right, Up, Down]; 0 marks a move that would leave the grid."""
    mask = []
    for cell in range(n_rows * n_cols):
        row, col = divmod(cell, n_cols)
        mask.append([int(col > 0), int(col < n_cols - 1),
                     int(row > 0), int(row < n_rows - 1)])
    return mask


def learn(Q_table, Reward, alpha, gamma, Moves, n_epochs, dim, prt,
          epsilon=1, epsilon_decay=0.99):
    """Train a tabular Q-learning agent on a ``dim[0] x dim[1]`` grid world.

    Each epoch runs one episode (``learnEpisode``) from every non-terminal
    cell, i.e. every cell whose reward is exactly -1. After every episode the
    cumulative win rate is printed. An episode counts as a win when it ends in
    a cell with reward > -1. ``epsilon`` is multiplied by ``epsilon_decay``
    after every epoch.

    Parameters
    ----------
    Q_table : list of lists or None
        Initial action values, one row of 4 values (Left, Right, Up, Down) per
        cell. ``None`` starts from all zeros. A copy is trained, so the
        caller's table is never modified.
    Reward : sequence
        Reward for entering each cell, row-major; length must be dim[0]*dim[1].
    alpha, gamma : float
        Learning rate and discount factor.
    Moves : list of lists or None
        Per-cell action mask (non-zero = allowed). ``None`` derives it from
        ``dim`` so that moves off the grid are disallowed.
    n_epochs : int
        Number of passes over all starting cells.
    dim : sequence of two ints
        (rows, columns) of the grid.
    prt : bool
        Print the board after every step.
    epsilon : float, optional
        Initial exploration rate (default 1).
    epsilon_decay : float, optional
        Multiplicative decay applied to epsilon after each epoch (default 0.99).

    Returns
    -------
    list of lists
        The trained Q-table.

    Raises
    ------
    ValueError
        If ``len(Reward) != dim[0] * dim[1]``.
    """
    n_rows, n_cols = dim[0], dim[1]
    n_cells = n_rows * n_cols
    if n_cells != len(Reward):
        raise ValueError("Length of Reward table and dim of matrix don't match")

    if Q_table is None:
        Q_table = [list(zeros(4, dtype=int)) for _ in range(n_cells)]
    else:
        # Copy so training never mutates the caller's table.
        Q_table = [list(row) for row in Q_table]
    if Moves is None:
        Moves = _build_moves(n_rows, n_cols)

    # Cells with reward exactly -1 are ordinary; every other cell ends an episode.
    starting_states = [s for s in range(n_cells) if Reward[s] == -1]

    wins = 0
    episodes = 0
    for epoch in range(n_epochs):
        for start_state in starting_states:
            response, Q_table = learnEpisode(epsilon, start_state, Q_table, Reward,
                                             alpha, gamma, Moves, dim, prt)
            episodes += 1
            if response == 1:
                wins += 1
            # Cumulative win rate over all episodes run so far.
            print('epoch: {}, acc: {:.2f}%'.format(epoch + 1, wins / episodes * 100) +
                  '\n\n==========================================================')

        epsilon = epsilon * epsilon_decay

    return Q_table

In [752]:
# Grid world: dim = [rows, cols]; Reward[i] is the reward for entering cell i (row-major).
# -1 marks an ordinary cell; any other value ends the episode (> -1 win, < -1 loss).
dim = [4, 4]
Reward = [-1, -1, -1, 5, -1, -10, -1, -1, -1, -1, -10, -1, -10, -1, 10, -1]
gamma = 0.9     # discount factor
alpha = 0.1     # learning rate
epsilon = 1     # initial exploration rate (not passed to learn() below)
n_epochs = 1000
prt = False     # print the board after every step

# Explicit None so the training cell does not depend on Q_table / Moves
# left over from earlier kernel state; learn() builds fresh tables.
Q_table = None
Moves = None

# Seed the `random` module used for exploration so Restart & Run All is reproducible.
SEED = 0
seed(SEED)

assert len(Reward) == dim[0] * dim[1], "Reward must have one entry per grid cell"

In [753]:
# Pass None for Q_table and Moves so learn() builds fresh tables instead of
# depending on variables that are never defined in this notebook.
Q = learn(None, Reward, alpha, gamma, None, n_epochs, dim, prt)

epoch: 1, acc: 0.00%

epoch: 1, acc: 0.00%

epoch: 1, acc: 0.00%

epoch: 1, acc: 0.00%

epoch: 1, acc: 0.00%

epoch: 1, acc: 16.67%

epoch: 1, acc: 14.29%

epoch: 1, acc: 12.50%

epoch: 1, acc: 11.11%

epoch: 1, acc: 20.00%

epoch: 1, acc: 18.18%

epoch: 2, acc: 16.67%

epoch: 2, acc: 15.38%

epoch: 2, acc: 21.43%

epoch: 2, acc: 20.00%

epoch: 2, acc: 25.00%

epoch: 2, acc: 29.41%

epoch: 2, acc: 27.78%

epoch: 2, acc: 26.32%

epoch: 2, acc: 30.00%

epoch: 2, acc: 33.33%

epoch: 2, acc: 36.36%

epoch: 3, acc: 39.13%

epoch: 3, acc: 37.50%

epoch: 3, acc: 36.00%

epoch: 3, acc: 34.62%

epoch: 3, acc: 33.33%

epoch: 3, acc: 35.71%

epoch: 3, acc: 34.48%

epoch: 3, acc: 33.33%

epoch: 3, acc: 35.48%

epoch: 3, acc: 34.38%

epoch: 3, acc: 36.36%

epoch: 4, acc: 35.29%

epoch: 4, acc: 34.29%

epoch: 4, acc: 36.11%

epoch: 4, acc: 35.14%

epoch: 4, acc: 34.21%

epoch: 4, acc: 35.90%

epoch: 4, acc: 35.00%

epoch: 4, acc: 34.15%

epoch: 4, acc: 33.33%

epoch: 4, acc: 34.88%

epoch: 4, acc: 3

epoch: 140, acc: 63.00%

epoch: 140, acc: 63.02%

epoch: 140, acc: 62.98%

epoch: 140, acc: 63.00%

epoch: 140, acc: 62.96%

epoch: 140, acc: 62.99%

epoch: 141, acc: 62.95%

epoch: 141, acc: 62.97%

epoch: 141, acc: 62.99%

epoch: 141, acc: 63.02%

epoch: 141, acc: 63.04%

epoch: 141, acc: 63.00%

epoch: 141, acc: 63.03%

epoch: 141, acc: 63.05%

epoch: 141, acc: 63.07%

epoch: 141, acc: 63.03%

epoch: 141, acc: 63.06%

epoch: 142, acc: 63.02%

epoch: 142, acc: 63.04%

epoch: 142, acc: 63.06%

epoch: 142, acc: 63.02%

epoch: 142, acc: 62.98%

epoch: 142, acc: 63.01%

epoch: 142, acc: 63.03%

epoch: 142, acc: 63.05%

epoch: 142, acc: 63.08%

epoch: 142, acc: 63.10%

epoch: 142, acc: 63.12%

epoch: 143, acc: 63.15%

epoch: 143, acc: 63.17%

epoch: 143, acc: 63.19%

epoch: 143, acc: 63.22%

epoch: 143, acc: 63.24%

epoch: 143, acc: 63.20%

epoch: 143, acc: 63.16%

epoch: 143, acc: 63.18%

epoch: 143, acc: 63.21%

epoch: 143, acc: 63.23%

epoch: 143, acc: 63.25%

epoch: 144, acc: 63.28%



epoch: 266, acc: 75.83%

epoch: 266, acc: 75.84%

epoch: 266, acc: 75.85%

epoch: 266, acc: 75.86%

epoch: 266, acc: 75.86%

epoch: 266, acc: 75.87%

epoch: 266, acc: 75.88%

epoch: 266, acc: 75.89%

epoch: 266, acc: 75.90%

epoch: 266, acc: 75.91%

epoch: 267, acc: 75.91%

epoch: 267, acc: 75.92%

epoch: 267, acc: 75.93%

epoch: 267, acc: 75.94%

epoch: 267, acc: 75.95%

epoch: 267, acc: 75.95%

epoch: 267, acc: 75.96%

epoch: 267, acc: 75.97%

epoch: 267, acc: 75.98%

epoch: 267, acc: 75.99%

epoch: 267, acc: 76.00%

epoch: 268, acc: 76.00%

epoch: 268, acc: 76.01%

epoch: 268, acc: 76.02%

epoch: 268, acc: 76.03%

epoch: 268, acc: 76.04%

epoch: 268, acc: 76.04%

epoch: 268, acc: 76.05%

epoch: 268, acc: 76.06%

epoch: 268, acc: 76.07%

epoch: 268, acc: 76.08%

epoch: 268, acc: 76.09%

epoch: 269, acc: 76.09%

epoch: 269, acc: 76.10%

epoch: 269, acc: 76.11%

epoch: 269, acc: 76.12%

epoch: 269, acc: 76.13%

epoch: 269, acc: 76.13%

epoch: 269, acc: 76.14%

epoch: 269, acc: 76.15%



epoch: 392, acc: 82.53%

epoch: 392, acc: 82.53%

epoch: 392, acc: 82.54%

epoch: 392, acc: 82.54%

epoch: 392, acc: 82.54%

epoch: 392, acc: 82.55%

epoch: 392, acc: 82.55%

epoch: 392, acc: 82.56%

epoch: 392, acc: 82.56%

epoch: 393, acc: 82.56%

epoch: 393, acc: 82.57%

epoch: 393, acc: 82.57%

epoch: 393, acc: 82.58%

epoch: 393, acc: 82.58%

epoch: 393, acc: 82.58%

epoch: 393, acc: 82.59%

epoch: 393, acc: 82.59%

epoch: 393, acc: 82.60%

epoch: 393, acc: 82.60%

epoch: 393, acc: 82.60%

epoch: 394, acc: 82.61%

epoch: 394, acc: 82.61%

epoch: 394, acc: 82.62%

epoch: 394, acc: 82.62%

epoch: 394, acc: 82.62%

epoch: 394, acc: 82.63%

epoch: 394, acc: 82.63%

epoch: 394, acc: 82.64%

epoch: 394, acc: 82.64%

epoch: 394, acc: 82.64%

epoch: 394, acc: 82.65%

epoch: 395, acc: 82.65%

epoch: 395, acc: 82.66%

epoch: 395, acc: 82.66%

epoch: 395, acc: 82.66%

epoch: 395, acc: 82.67%

epoch: 395, acc: 82.67%

epoch: 395, acc: 82.68%

epoch: 395, acc: 82.68%

epoch: 395, acc: 82.68%



epoch: 516, acc: 86.59%

epoch: 516, acc: 86.59%

epoch: 516, acc: 86.59%

epoch: 516, acc: 86.59%

epoch: 517, acc: 86.60%

epoch: 517, acc: 86.60%

epoch: 517, acc: 86.60%

epoch: 517, acc: 86.60%

epoch: 517, acc: 86.60%

epoch: 517, acc: 86.61%

epoch: 517, acc: 86.61%

epoch: 517, acc: 86.61%

epoch: 517, acc: 86.61%

epoch: 517, acc: 86.62%

epoch: 517, acc: 86.62%

epoch: 518, acc: 86.62%

epoch: 518, acc: 86.62%

epoch: 518, acc: 86.63%

epoch: 518, acc: 86.63%

epoch: 518, acc: 86.63%

epoch: 518, acc: 86.63%

epoch: 518, acc: 86.64%

epoch: 518, acc: 86.64%

epoch: 518, acc: 86.64%

epoch: 518, acc: 86.64%

epoch: 518, acc: 86.64%

epoch: 519, acc: 86.65%

epoch: 519, acc: 86.65%

epoch: 519, acc: 86.65%

epoch: 519, acc: 86.65%

epoch: 519, acc: 86.66%

epoch: 519, acc: 86.66%

epoch: 519, acc: 86.66%

epoch: 519, acc: 86.66%

epoch: 519, acc: 86.67%

epoch: 519, acc: 86.67%

epoch: 519, acc: 86.67%

epoch: 520, acc: 86.67%

epoch: 520, acc: 86.67%

epoch: 520, acc: 86.68%



epoch: 610, acc: 88.61%

epoch: 610, acc: 88.61%

epoch: 610, acc: 88.61%

epoch: 610, acc: 88.61%

epoch: 611, acc: 88.62%

epoch: 611, acc: 88.62%

epoch: 611, acc: 88.62%

epoch: 611, acc: 88.62%

epoch: 611, acc: 88.62%

epoch: 611, acc: 88.62%

epoch: 611, acc: 88.63%

epoch: 611, acc: 88.63%

epoch: 611, acc: 88.63%

epoch: 611, acc: 88.63%

epoch: 611, acc: 88.63%

epoch: 612, acc: 88.63%

epoch: 612, acc: 88.64%

epoch: 612, acc: 88.64%

epoch: 612, acc: 88.64%

epoch: 612, acc: 88.64%

epoch: 612, acc: 88.64%

epoch: 612, acc: 88.64%

epoch: 612, acc: 88.65%

epoch: 612, acc: 88.65%

epoch: 612, acc: 88.65%

epoch: 612, acc: 88.65%

epoch: 613, acc: 88.65%

epoch: 613, acc: 88.65%

epoch: 613, acc: 88.66%

epoch: 613, acc: 88.66%

epoch: 613, acc: 88.66%

epoch: 613, acc: 88.66%

epoch: 613, acc: 88.66%

epoch: 613, acc: 88.66%

epoch: 613, acc: 88.67%

epoch: 613, acc: 88.67%

epoch: 613, acc: 88.67%

epoch: 614, acc: 88.67%

epoch: 614, acc: 88.67%

epoch: 614, acc: 88.67%



epoch: 727, acc: 90.40%

epoch: 727, acc: 90.40%

epoch: 727, acc: 90.40%

epoch: 727, acc: 90.41%

epoch: 727, acc: 90.41%

epoch: 727, acc: 90.41%

epoch: 727, acc: 90.41%

epoch: 728, acc: 90.41%

epoch: 728, acc: 90.41%

epoch: 728, acc: 90.41%

epoch: 728, acc: 90.41%

epoch: 728, acc: 90.41%

epoch: 728, acc: 90.42%

epoch: 728, acc: 90.42%

epoch: 728, acc: 90.42%

epoch: 728, acc: 90.42%

epoch: 728, acc: 90.42%

epoch: 728, acc: 90.42%

epoch: 729, acc: 90.42%

epoch: 729, acc: 90.42%

epoch: 729, acc: 90.43%

epoch: 729, acc: 90.43%

epoch: 729, acc: 90.43%

epoch: 729, acc: 90.43%

epoch: 729, acc: 90.43%

epoch: 729, acc: 90.43%

epoch: 729, acc: 90.43%

epoch: 729, acc: 90.43%

epoch: 729, acc: 90.44%

epoch: 730, acc: 90.44%

epoch: 730, acc: 90.44%

epoch: 730, acc: 90.44%

epoch: 730, acc: 90.44%

epoch: 730, acc: 90.44%

epoch: 730, acc: 90.44%

epoch: 730, acc: 90.44%

epoch: 730, acc: 90.44%

epoch: 730, acc: 90.45%

epoch: 730, acc: 90.45%

epoch: 730, acc: 90.45%



epoch: 864, acc: 91.92%

epoch: 864, acc: 91.93%

epoch: 864, acc: 91.93%

epoch: 864, acc: 91.93%

epoch: 864, acc: 91.93%

epoch: 864, acc: 91.93%

epoch: 864, acc: 91.93%

epoch: 865, acc: 91.93%

epoch: 865, acc: 91.93%

epoch: 865, acc: 91.93%

epoch: 865, acc: 91.93%

epoch: 865, acc: 91.93%

epoch: 865, acc: 91.93%

epoch: 865, acc: 91.94%

epoch: 865, acc: 91.94%

epoch: 865, acc: 91.94%

epoch: 865, acc: 91.94%

epoch: 865, acc: 91.94%

epoch: 866, acc: 91.94%

epoch: 866, acc: 91.94%

epoch: 866, acc: 91.94%

epoch: 866, acc: 91.94%

epoch: 866, acc: 91.94%

epoch: 866, acc: 91.94%

epoch: 866, acc: 91.94%

epoch: 866, acc: 91.95%

epoch: 866, acc: 91.95%

epoch: 866, acc: 91.95%

epoch: 866, acc: 91.95%

epoch: 867, acc: 91.95%

epoch: 867, acc: 91.95%

epoch: 867, acc: 91.95%

epoch: 867, acc: 91.95%

epoch: 867, acc: 91.95%

epoch: 867, acc: 91.95%

epoch: 867, acc: 91.95%

epoch: 867, acc: 91.96%

epoch: 867, acc: 91.96%

epoch: 867, acc: 91.96%

epoch: 867, acc: 91.96%



epoch: 986, acc: 92.92%

epoch: 986, acc: 92.92%

epoch: 986, acc: 92.92%

epoch: 986, acc: 92.92%

epoch: 986, acc: 92.93%

epoch: 986, acc: 92.93%

epoch: 986, acc: 92.93%

epoch: 986, acc: 92.93%

epoch: 986, acc: 92.93%

epoch: 986, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.93%

epoch: 987, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 988, acc: 92.94%

epoch: 989, acc: 92.94%

epoch: 989, acc: 92.94%

epoch: 989, acc: 92.94%

epoch: 989, acc: 92.95%

epoch: 989, acc: 92.95%

epoch: 989, acc: 92.95%

epoch: 989, acc: 92.95%

epoch: 989, acc: 92.95%



In [740]:
# One row per cell: [cell index, Q(Left), Q(Right), Q(Up), Q(Down)].
q_rows = [[cell] + q_values for cell, q_values in enumerate(Q)]
print(tabulate(q_rows, headers =['Left', 'Right', 'Up', 'Down'], tablefmt='orgtbl'))

|    |     Left |    Right |       Up |     Down |
|----+----------+----------+----------+----------|
|  0 |  0       |  2.14987 |  0       |  3.122   |
|  1 |  1.6772  |  3.5     |  0       | -9.99803 |
|  2 |  2.14599 |  5       |  0       |  3.1018  |
|  3 |  0       |  0       |  0       |  0       |
|  4 |  0       | -9.99955 |  1.78761 |  4.58    |
|  5 |  0       |  0       |  0       |  0       |
|  6 | -9.96619 |  4.58    |  3.46635 | -9.86697 |
|  7 |  3.10763 |  0       |  4.99999 |  6.2     |
|  8 |  0       |  6.2     |  3.12059 | -9.99906 |
|  9 |  4.57303 | -9.9996  | -9.99884 |  8       |
| 10 |  0       |  0       |  0       |  0       |
| 11 | -9.99976 |  0       |  4.57256 |  8       |
| 12 |  0       |  0       |  0       |  0       |
| 13 | -9.99967 | 10       |  6.19936 |  0       |
| 14 |  0       |  0       |  0       |  0       |
| 15 | 10       |  0       |  6.2     |  0       |
