# SARSA: on-policy TD control

In [None]:
import numpy as np
import gym

In [None]:
# Environment shared by the training and evaluation cells below.
env_name = 'FrozenLake8x8-v0'
env = gym.make(env_name)

In [None]:
def sarsa(n_episodes, env, gamma, alpha, epsilon):
    """Learn an action-value table with tabular SARSA (on-policy TD control).

    Actions are chosen epsilon-greedily with respect to the current ``Q``,
    and each update bootstraps from the action that will actually be taken
    next (that is what makes it on-policy):

        Q(s, a) += alpha * (r + gamma * Q(s', a') - Q(s, a))

    A transition that ends the episode does not bootstrap: its target is
    just ``r``. Terminal states therefore never leak their arbitrary initial
    value into the estimates.

    Args:
        n_episodes: number of training episodes to run.
        env: environment exposing discrete ``observation_space.n`` and
            ``action_space.n`` (with ``action_space.sample()``), ``reset()``
            returning a state index, and ``step(action)`` returning
            ``(state, reward, done, info)``.
        gamma: discount factor.
        alpha: step size (learning rate).
        epsilon: probability of taking a random action instead of the
            greedy one.

    Returns:
        numpy array of shape ``(n_states, n_actions)`` with the learned
        action values.
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n

    # Every entry starts at 0.5. Greedy ties are broken by argmax towards
    # the lowest action index.
    Q = np.full((n_states, n_actions), 0.5)

    def choose_action(state):
        """Epsilon-greedy action for ``state`` under the current ``Q``."""
        if np.random.random() < epsilon:
            return env.action_space.sample()
        return Q[state, :].argmax()

    for _ in range(n_episodes):
        state = env.reset()
        action = choose_action(state)
        done = False
        while not done:
            next_state, reward, done, _info = env.step(action)
            # The next action is only needed (and only meaningful) when the
            # episode continues.
            next_action = None if done else choose_action(next_state)
            if done:
                target = reward
            else:
                target = reward + gamma * Q[next_state, next_action]
            Q[state, action] += alpha * (target - Q[state, action])
            state, action = next_state, next_action

    return Q

In [None]:
# Hyperparameters for the main training run: discount, step size, exploration rate.
gamma, alpha, epsilon = 0.9, 0.2, 0.1
Q = sarsa(20000, env, gamma=gamma, alpha=alpha, epsilon=epsilon)

### Test: evaluate the greedy policy

In [None]:
def test(n_episodes, Q, env=None, max_steps=100):
    """Count how many episodes the greedy policy derived from ``Q`` wins.

    Actions are always ``Q[state, :].argmax()`` (no exploration). An episode
    counts as a win when it terminates with a reward of exactly 1. Episodes
    that have not terminated after ``max_steps`` steps count as losses.

    Args:
        n_episodes: number of evaluation episodes.
        Q: action-value table indexed as ``Q[state, action]``.
        env: environment to evaluate in. Defaults to the notebook-level
            ``env`` so existing ``test(n, Q)`` calls keep working.
        max_steps: per-episode step cap (previously hard-coded to 100).

    Returns:
        Number of winning episodes.
    """
    if env is None:
        # Backward compatibility: fall back to the module-level environment.
        env = globals()["env"]
    wins = 0
    for _ in range(n_episodes):
        state = env.reset()
        for _ in range(max_steps):
            state, reward, done, _info = env.step(Q[state, :].argmax())
            if done:
                if reward == 1:
                    wins += 1
                break
    return wins


# The name starts with "test", but this is an evaluation helper, not a pytest test.
test.__test__ = False

In [None]:
# Evaluate the greedy policy from the main run over 100 episodes;
# the cell displays the number of wins.
test(n_episodes=100, Q=Q)

In [None]:
# Sweep the step size alpha over 0.1..0.9 (gamma=0.9, epsilon=0.1 fixed).
# Train a fresh Q-table for each value and report wins over 100 greedy episodes.
alphas = np.linspace(0.1, 0.9, num=9)
for alpha_value in alphas:
    print("alpha:", alpha_value)
    Q = sarsa(10000, env, gamma=0.9, alpha=alpha_value, epsilon=0.1)
    n_wins = test(100, Q)
    print("wins", n_wins)
    print()

In [None]:
# Sweep the exploration rate epsilon over 0.1..0.9 (gamma=0.9, alpha=0.1 fixed).
# Train a fresh Q-table for each value and report wins over 100 greedy episodes.
epsilons = np.linspace(0.1, 0.9, num=9)
for epsilon_value in epsilons:
    print("epsilon:", epsilon_value)
    Q = sarsa(10000, env, gamma=0.9, alpha=0.1, epsilon=epsilon_value)
    n_wins = test(100, Q)
    print("wins", n_wins)
    print()