In [1]:
import gym
import numpy as np

In [2]:
# Gridworld with size=3: the player starts at [0, 0] and the single goal is at [2, 2].
# The learning functions below pick up this notebook-level `env` at call time.
env = gym.make(
    'gym_gridworld:gridworld-v4',
    size=3,
    player_pos=[0, 0],
    goal_pos=[[2, 2]],
)

In [4]:
# Integer action codes: passed to env.step() and used as row indices of the Q tables.
UP, RIGHT, DOWN, LEFT = 0, 1, 2, 3

# Tabular TD(0) Prediction

In [4]:
def tabularTDZero(policy=[0.25,0.25,0.25,0.25], alpha=0.1, gamma=1, episodes=100000, environment=None):
    """Estimate the state values of a fixed stochastic policy with tabular TD(0).

    Parameters
    ----------
    policy : sequence of float
        Action probabilities; action ``k`` is sampled with probability
        ``policy[k]``. In this notebook the codes 0..3 are UP, RIGHT, DOWN, LEFT.
    alpha : float
        Step size of the TD update.
    gamma : float
        Discount factor.
    episodes : int
        Number of episodes to run (default keeps the previous fixed 100000).
    environment : object, optional
        Must expose ``observation_space`` (a sequence supporting ``.index``),
        ``reset() -> obs`` and ``step(action) -> (obs, reward, done, info)``.
        When omitted, the notebook-level ``env`` current at call time is used.

    Returns
    -------
    list
        One value per entry of ``observation_space``, in the same order.
        States that are never stepped from (e.g. terminal ones) stay at 0.
    """
    if environment is None:
        # Backward-compatible default: read the notebook global at call time.
        environment = env

    states = environment.observation_space
    V = [0 for _ in states]

    for _ in range(episodes):
        # Resolve the state's index before stepping and carry it forward, so
        # each observation is looked up only once.
        state_i = states.index(environment.reset())

        done = False
        while not done:
            # Sample an action code 0..len(policy)-1 according to the policy.
            action = np.random.choice(len(policy), p=policy)
            next_obs, reward, done, _ = environment.step(action)
            next_i = states.index(next_obs)

            # V(S) <- V(S) + alpha * [ R + gamma * V(S') - V(S) ]
            V[state_i] = V[state_i] + alpha * (reward + gamma * V[next_i] - V[state_i])

            # S <- S'
            state_i = next_i

    return V

In [5]:
# Evaluate the default uniform-random policy on the current `env` with TD(0).
a = tabularTDZero()

In [6]:
# Show the 9 state values as a 3x3 array
# (assumes env.observation_space is ordered to match the grid layout -- TODO confirm).
np.reshape(a, (3, 3))

array([[-26.63526579, -24.01310281, -22.12160586],
       [-24.71451945, -21.64036123, -15.52603753],
       [-22.73708552, -16.42409977,   0.        ]])

In [7]:
# Gridworld with size=4: the player starts at [0, 3]; two goals, at [3, 3] and [0, 0].
# Rebinds the notebook-level `env` that the learning functions read at call time.
env = gym.make(
    'gym_gridworld:gridworld-v4',
    size=4,
    player_pos=[0, 3],
    goal_pos=[[3, 3], [0, 0]],
)

In [8]:
# Re-run TD(0) prediction on the current (size=4, two-goal) environment and
# show the 16 state values as a 4x4 array.
a = tabularTDZero()
np.reshape(a, (4, 4))

array([[  0.        ,  -8.51834845, -17.51567378, -19.7675834 ],
       [-13.35316172, -15.54478301, -17.99881873, -16.48236428],
       [-17.36659171, -18.55989591, -16.90376001,  -9.96660682],
       [-20.70701166, -19.89462409, -14.49348776,   0.        ]])

# SARSA Control

In [43]:
def sarsa(alpha=0.1, epsilon=0.1, gamma=1, episodes=100000, environment=None, n_actions=4):
    """Learn action values with on-policy SARSA and an epsilon-greedy policy.

    Parameters
    ----------
    alpha : float
        Step size of the update.
    epsilon : float
        Exploration weight: the greedy action is drawn with weight
        ``1 - epsilon`` and every action (greedy one included) with an extra
        ``epsilon / n_actions``.
    gamma : float
        Discount factor.
    episodes : int
        Number of episodes to run (default keeps the previous fixed 100000).
    environment : object, optional
        Must expose ``observation_space`` (a sequence supporting ``.index``),
        ``reset() -> obs`` and ``step(action) -> (obs, reward, done, info)``.
        When omitted, the notebook-level ``env`` current at call time is used.
    n_actions : int
        Number of action codes ``0..n_actions-1`` (4 in this notebook:
        UP, RIGHT, DOWN, LEFT).

    Returns
    -------
    list of list
        ``Q[action][state_index]``, with states in ``observation_space`` order.
        Entries for states that are never stepped from stay at 0.
    """
    if environment is None:
        # Backward-compatible default: read the notebook global at call time.
        environment = env

    states = environment.observation_space
    Q = [[0 for _ in states] for _ in range(n_actions)]

    # Candidate list is [greedy, 0, 1, ..., n_actions-1]; weights are loop-invariant.
    all_actions = list(range(n_actions))
    weights = [1 - epsilon] + [epsilon / n_actions] * n_actions

    def choose_action(state_i):
        """Epsilon-greedy action for the state at index `state_i` under the current Q."""
        greedy = np.argmax([action_row[state_i] for action_row in Q])
        return np.random.choice([greedy] + all_actions, p=weights)

    # Loop for each episode
    for _ in range(episodes):

        # Initialize S; choose A from S using policy derived from Q (e-greedy)
        state_i = states.index(environment.reset())
        action = choose_action(state_i)

        # Loop for each step of episode
        done = False
        while not done:

            # Take action A, observe R, S'
            next_obs, reward, done, _ = environment.step(action)

            # Choose A' from S' using policy derived from Q (e-greedy)
            next_i = states.index(next_obs)
            next_action = choose_action(next_i)

            # Q(S,A) <- Q(S,A) + alpha * [ R + gamma * Q(S',A') - Q(S,A) ]
            Q[action][state_i] = Q[action][state_i] + alpha * (
                reward + gamma * Q[next_action][next_i] - Q[action][state_i]
            )

            # S <- S'; A <- A'
            state_i = next_i
            action = next_action

    return Q

In [44]:
# Learn action values with SARSA on the current `env`, using the default hyper-parameters.
a_sarsa = sarsa()

In [49]:
# One 3x3 array per action row of the SARSA table (rows are indexed by action code).
for action_values in a_sarsa:
    grid = np.array(action_values).reshape(3, 3)
    print(grid, '\n')

[[-4.37664781 -3.47920544 -2.29884309]
 [-4.5523513  -3.46657483 -2.37272378]
 [-3.52933475 -2.26217682  0.        ]] 

[[-3.56206115 -2.61889128 -2.40212858]
 [-2.71338079 -1.45783181 -1.03649461]
 [-1.35573237  0.          0.        ]] 

[[-3.19051821 -2.33432524 -1.32731641]
 [-2.11856654 -1.27186723  0.        ]
 [-2.28946262 -1.05874868  0.        ]] 

[[-4.43839352 -4.53100277 -3.71898457]
 [-3.25701855 -3.37312487 -2.32795292]
 [-2.23058735 -2.32599179  0.        ]] 



In [52]:
# Reset the environment and keep the initial observation for the greedy step cell below.
obs = env.reset()

In [68]:
# Take one greedy step under the SARSA action values.
# `obs` is overwritten with the new observation, so re-running this cell
# advances the agent by one more step.
obs_i = env.observation_space.index(obs)
action = np.argmax([action_row[obs_i] for action_row in a_sarsa])

obs, reward, done, info = env.step(action)

In [69]:
# Print the grid; the single 1 presumably marks the player's current cell (TODO confirm).
env.render()

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 1.]]


In [70]:
# Gridworld with size=6: the player starts at [0, 5]; two goals, at [0, 0] and [5, 2].
# Rebinds the notebook-level `env` that the learning functions read at call time.
env = gym.make(
    'gym_gridworld:gridworld-v4',
    size=6,
    player_pos=[0, 5],
    goal_pos=[[0, 0], [5, 2]],
)

In [71]:
# Retrain SARSA on the current `env`; this replaces the action values learned earlier.
a_sarsa = sarsa()

In [72]:
# Reset the environment and keep the initial observation for the greedy step cell below.
obs = env.reset()

In [83]:
# Take one greedy step under the SARSA action values.
# `obs` is overwritten with the new observation, so re-running this cell
# advances the agent by one more step.
obs_i = env.observation_space.index(obs)
action = np.argmax([action_row[obs_i] for action_row in a_sarsa])

obs, reward, done, info = env.step(action)

In [84]:
# Print the grid; the single 1 presumably marks the player's current cell (TODO confirm).
env.render()

[[1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]


In [86]:
# One 6x6 array per action row of the SARSA table (rows are indexed by action code).
for action_values in a_sarsa:
    grid = np.array(action_values).reshape(6, 6)
    print(grid, '\n')

[[ 0.         -1.24693125 -2.41097443 -3.49867662 -4.77243523 -5.85872809]
 [ 0.         -1.59339271 -2.37346611 -3.6574015  -4.56531387 -5.81624263]
 [-1.37175487 -2.19788288 -2.86902851 -3.82142588 -4.37770954 -5.59192594]
 [-1.53791373 -1.68214213 -2.73938591 -3.19930353 -3.40756721 -4.45467371]
 [-1.11044734 -0.95762391 -1.90131037 -1.97831293 -2.74765522 -3.23400147]
 [-0.57649626 -0.1         0.         -1.27022949 -1.71309716 -2.35122328]] 

[[ 0.         -2.38561587 -3.85641088 -4.75724554 -5.92371145 -5.67670893]
 [-2.52549927 -3.34753948 -4.52525529 -5.53741274 -6.63062934 -6.50718162]
 [-2.01028442 -2.39470529 -3.1957332  -3.80658166 -5.10852887 -5.23903127]
 [-1.59402409 -1.65003505 -2.21308574 -2.65770994 -3.85637069 -4.07784783]
 [-1.03764744 -0.87842335 -1.37269851 -2.31463406 -2.69562521 -3.08544257]
 [-0.61257951  0.          0.         -1.4129204  -2.12213072 -2.30126012]] 

[[ 0.         -2.51306115 -3.30156807 -4.39297987 -5.33715649 -6.41396196]
 [-2.45226203 -3.21

# Q-learning Control

In [89]:
def qLearning(alpha=0.1, epsilon=0.1, gamma=1, episodes=100000, environment=None, n_actions=4):
    """Learn action values with off-policy Q-learning and epsilon-greedy behaviour.

    Parameters
    ----------
    alpha : float
        Step size of the update.
    epsilon : float
        Exploration weight: the greedy action is drawn with weight
        ``1 - epsilon`` and every action (greedy one included) with an extra
        ``epsilon / n_actions``.
    gamma : float
        Discount factor.
    episodes : int
        Number of episodes to run (default keeps the previous fixed 100000).
    environment : object, optional
        Must expose ``observation_space`` (a sequence supporting ``.index``),
        ``reset() -> obs`` and ``step(action) -> (obs, reward, done, info)``.
        When omitted, the notebook-level ``env`` current at call time is used.
    n_actions : int
        Number of action codes ``0..n_actions-1`` (4 in this notebook:
        UP, RIGHT, DOWN, LEFT).

    Returns
    -------
    list of list
        ``Q[action][state_index]``, with states in ``observation_space`` order.
        Entries for states that are never stepped from stay at 0.
    """
    if environment is None:
        # Backward-compatible default: read the notebook global at call time.
        environment = env

    states = environment.observation_space
    Q = [[0 for _ in states] for _ in range(n_actions)]

    # Candidate list is [greedy, 0, 1, ..., n_actions-1]; weights are loop-invariant.
    all_actions = list(range(n_actions))
    weights = [1 - epsilon] + [epsilon / n_actions] * n_actions

    # Loop for each episode
    for _ in range(episodes):

        # Initialize S
        state_i = states.index(environment.reset())

        # Loop for each step of episode
        done = False
        while not done:

            # Choose A from S using policy derived from Q (e-greedy)
            greedy = np.argmax([action_row[state_i] for action_row in Q])
            action = np.random.choice([greedy] + all_actions, p=weights)

            # Take action A, observe R, S'
            next_obs, reward, done, _ = environment.step(action)
            next_i = states.index(next_obs)

            # Q(S,A) <- Q(S,A) + alpha * [ R + gamma * max_a( Q(S',a) ) - Q(S,A) ]
            best_next = np.max([action_row[next_i] for action_row in Q])
            Q[action][state_i] = Q[action][state_i] + alpha * (
                reward + gamma * best_next - Q[action][state_i]
            )

            # S <- S'
            state_i = next_i

    return Q

In [90]:
# Learn action values with Q-learning on the current `env`, using the default hyper-parameters.
a_qLearning = qLearning()

In [112]:
# Reset the environment and keep the initial observation for the step cells below.
obs = env.reset()

ValueError: too many values to unpack (expected 4)

In [116]:
# Manually move the agent one step DOWN. The step results are unpacked into
# reward/done/info (the names the other step cells use) rather than throwaway
# `a`, `b`, `c`, so the notebook-level `a` holding the TD(0) values is not overwritten.
obs, reward, done, info = env.step(DOWN)

In [142]:
# Take one greedy step under the Q-learning action values. This cell sits in
# the Q-learning section, so it reads `a_qLearning` rather than the SARSA table.
# `obs` is overwritten with the new observation, so re-running this cell
# advances the agent by one more step.
obs_i = env.observation_space.index(obs)
action = np.argmax([action_row[obs_i] for action_row in a_qLearning])

obs, reward, done, info = env.step(action)

In [143]:
# Print the grid; the single 1 presumably marks the player's current cell (TODO confirm).
env.render()

[[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]]


In [93]:
# One 6x6 array per action row of the Q-learning table (rows are indexed by action code).
for action_values in a_qLearning:
    grid = np.array(action_values).reshape(6, 6)
    print(grid, '\n')

[[ 0.         -1.         -2.         -3.         -4.         -5.        ]
 [ 0.         -1.         -2.         -3.         -4.         -5.        ]
 [-0.99363731 -1.94984495 -2.20203959 -3.07563027 -4.04225818 -4.79581156]
 [-1.09477031 -1.42739996 -1.96936601 -2.26658759 -3.32351734 -3.84535524]
 [-0.81996824 -0.80141834 -0.73607645 -1.19720184 -2.4710588  -2.88912006]
 [-0.58870647 -0.1         0.         -0.67547431 -1.22269394 -2.16759702]] 

[[ 0.         -2.         -3.         -4.         -5.         -5.        ]
 [-1.53843351 -2.99884857 -3.99995611 -4.99879087 -5.99930991 -5.9996852 ]
 [-1.08448486 -2.18863461 -2.3886971  -3.37150555 -4.01548107 -4.86088101]
 [-1.14746878 -1.43122574 -1.39147487 -2.30098941 -3.11045    -3.8831233 ]
 [-0.86453188 -0.71757046 -1.12651855 -1.36280543 -2.05496167 -2.99563703]
 [-0.468559    0.          0.         -0.70773872 -1.5918233  -1.98059455]] 

[[ 0.         -2.         -3.         -4.         -5.         -6.        ]
 [-0.82763089 -2.91