# Sarsa. On-policy TD control

In [1]:
import numpy as np
import gym

In [2]:
env = gym.make('FrozenLake8x8-v0')

  result = entry_point.load(False)


In [3]:
def sarsa(n_episodes, env, gamma, alpha, epsilon):
    
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    
    #initialize action-value function
    Q = np.zeros((n_states,n_actions)) + 0.5
    for episode in range(n_episodes):
        observation = env.reset()
        action = None
        done = False
        #generate episode
        while not done:
            observation_old = observation
            action_old = action
            if np.random.random() < epsilon:
                action = env.action_space.sample()
            else:
                action = Q[observation,:].argmax()
            observation,reward,done,info = env.step(action)
            Q[observation_old,action_old] += alpha *(reward + gamma * Q[observation,action] - Q[observation_old,action_old])

    return Q

In [4]:
gamma = 0.9
alpha = 0.2
epsilon = 0.1
Q = sarsa(20000,env,gamma,alpha,epsilon)

### TEST

In [5]:
def test(n_episodes,Q):
    wins = 0
    for i_episode in range(n_episodes):
        observation = env.reset()
        for t in range(100):
            action = Q[observation,:].argmax()
            observation, reward, done, info = env.step(action)
            if done:
                if reward == 1:
                    wins += 1
                break
    return wins

In [7]:
test(1000,Q)

0

In [8]:
alphas = np.linspace(0.1, 0.9, num=9)
for a in alphas:
    print("alpha:",a)
    Q = sarsa(10000,env,gamma=0.9,alpha=a,epsilon=0.1)
    print("wins",test(100,Q))
    print()

alpha: 0.1
wins 2

alpha: 0.2
wins 0

alpha: 0.30000000000000004
wins 2

alpha: 0.4
wins 1

alpha: 0.5
wins 0

alpha: 0.6
wins 0

alpha: 0.7000000000000001
wins 4

alpha: 0.8
wins 0

alpha: 0.9
wins 1



In [9]:
epsilons = np.linspace(0.1, 0.9, num=9)
for e in epsilons:
    print("epsilon:",e)
    Q = sarsa(10000,env,gamma=0.9,alpha=0.1,epsilon=e)
    print("wins",test(100,Q))
    print()

epsilon: 0.1
wins 0

epsilon: 0.2
wins 0

epsilon: 0.30000000000000004
wins 0

epsilon: 0.4
wins 0

epsilon: 0.5
wins 0

epsilon: 0.6
wins 0

epsilon: 0.7000000000000001
wins 0

epsilon: 0.8
wins 0

epsilon: 0.9
wins 0

