In [None]:
import numpy as np
import gym


def eps_greedy(Q, s, eps=0.1):
    '''
    Pick an action for state `s` with an epsilon-greedy rule.

    A single uniform sample in [0, 1) is drawn; when it falls below `eps`
    a uniformly random action index is returned, otherwise the greedy
    action for `s` is returned.

    Q:   2-D array of action values, indexed as Q[state][action]
    s:   index of the current state (row of Q)
    eps: probability of taking a random (exploratory) action
    '''
    explore = np.random.uniform(0, 1) < eps
    if not explore:
        # Exploit: defer to the greedy policy
        return greedy(Q, s)

    # Explore: any of the Q.shape[1] actions, uniformly at random
    n_actions = Q.shape[1]
    return np.random.randint(n_actions)


def greedy(Q, s):
    '''
    Greedy policy: index of the highest-valued action in state `s`.

    Q: action-value table indexed as Q[state]
    s: index of the state to act in

    Ties are resolved by np.argmax, i.e. the first maximal index wins.
    '''
    action_values = Q[s]
    return np.argmax(action_values)


def run_episodes(env, Q, num_episodes=100, to_print=False):
    '''
    Evaluate the greedy policy derived from Q and return its mean score.

    env:          environment whose reset() returns a state and whose
                  step(action) returns (next_state, reward, done, info)
    Q:            action-value table used by the greedy policy
    num_episodes: number of complete episodes to play
    to_print:     when True, also print the mean score

    Returns the mean of the per-episode reward sums.
    '''
    episode_returns = []
    state = env.reset()

    for _ in range(num_episodes):
        episode_return = 0

        # Play one full episode, always taking the greedy action
        while True:
            state, reward, done, _info = env.step(greedy(Q, state))
            episode_return += reward
            if done:
                break

        # Episode over: start the next one from a fresh state
        state = env.reset()
        episode_returns.append(episode_return)

    mean_return = np.mean(episode_returns)
    if to_print:
        print('Mean score: %.3f of %i games!'%(mean_return, num_episodes))

    return mean_return

def Q_learning(env, lr=0.01, num_episodes=10000, eps=0.3, gamma=0.95, eps_decay=0.00005):
    '''
    Learn a tabular action-value function with Q-learning.

    env:          environment with discrete spaces (action_space.n and
                  observation_space.n are read); reset() must return a state
                  and step(action) must return (next_state, reward, done, info)
    lr:           learning rate of the update
    num_episodes: number of training episodes
    eps:          initial exploration probability of the eps-greedy policy
    gamma:        discount factor
    eps_decay:    amount subtracted from eps at the start of every episode

    Every 300 episodes the greedy policy is evaluated over 1000 episodes
    with run_episodes and the result is printed.

    Returns the learned Q matrix of shape (nS, nA).
    '''
    nA = env.action_space.n
    nS = env.observation_space.n

    # Q: matrix nS*nA where each row represents a state and each column
    # represents a different action
    Q = np.zeros((nS, nA))

    for ep in range(num_episodes):
        state = env.reset()
        done = False

        # Decay epsilon linearly, but never let a decay step push it below
        # the 0.01 floor (a large eps_decay would otherwise overshoot and
        # could even make eps negative, silently disabling exploration).
        if eps > 0.01:
            eps = max(0.01, eps - eps_decay)

        # loop the main body until the environment stops
        while not done:
            # select an action following the eps-greedy policy
            action = eps_greedy(Q, state, eps)

            # take one step in the environment
            next_state, rew, done, _ = env.step(action)

            # Q-learning target bootstraps from the best action value of the
            # next state, regardless of the action that will be taken there
            td_target = rew + gamma*np.max(Q[next_state])
            Q[state][action] = Q[state][action] + lr*(td_target - Q[state][action])

            state = next_state

        # Test the policy every 300 episodes and print the results
        if (ep % 300) == 0:
            test_rew = run_episodes(env, Q, 1000)
            print("Episode:{:5d}  Eps:{:2.4f}  Rew:{:2.4f}".format(ep, eps, test_rew))

    return Q


def SARSA(env, lr=0.01, num_episodes=10000, eps=0.3, gamma=0.95, eps_decay=0.00005):
    '''
    Learn a tabular action-value function with SARSA.

    env:          environment with discrete spaces (action_space.n and
                  observation_space.n are read); reset() must return a state
                  and step(action) must return (next_state, reward, done, info)
    lr:           learning rate of the update
    num_episodes: number of training episodes
    eps:          initial exploration probability of the eps-greedy policy
    gamma:        discount factor
    eps_decay:    amount subtracted from eps at the start of every episode

    Every 300 episodes the greedy policy is evaluated over 1000 episodes
    with run_episodes and the result is printed.

    Returns the learned Q matrix of shape (nS, nA).
    '''
    nA = env.action_space.n
    nS = env.observation_space.n

    # Q: matrix nS*nA where each row represents a state and each column
    # represents a different action
    Q = np.zeros((nS, nA))

    for ep in range(num_episodes):
        state = env.reset()
        done = False

        # Decay epsilon linearly, but never let a decay step push it below
        # the 0.01 floor (a large eps_decay would otherwise overshoot and
        # could even make eps negative, silently disabling exploration).
        if eps > 0.01:
            eps = max(0.01, eps - eps_decay)

        # The first action is chosen before the loop because every update
        # needs both the current action and the next one.
        action = eps_greedy(Q, state, eps)

        # loop the main body until the environment stops
        while not done:
            # take one step in the environment
            next_state, rew, done, _ = env.step(action)

            # choose the next action (needed for the SARSA update)
            next_action = eps_greedy(Q, next_state, eps)

            # SARSA target bootstraps from the action actually selected
            # for the next state
            td_target = rew + gamma*Q[next_state][next_action]
            Q[state][action] = Q[state][action] + lr*(td_target - Q[state][action])

            state = next_state
            action = next_action

        # Test the policy every 300 episodes and print the results
        if (ep % 300) == 0:
            test_rew = run_episodes(env, Q, 1000)
            print("Episode:{:5d}  Eps:{:2.4f}  Rew:{:2.4f}".format(ep, eps, test_rew))

    return Q


if __name__ == '__main__':
    # Train both algorithms on the same Taxi environment with identical
    # hyperparameters so their printed evaluation scores are comparable.
    env = gym.make('Taxi-v3')
    hyperparams = dict(lr=.1, num_episodes=5000, eps=0.4, gamma=0.95, eps_decay=0.001)

    Q_qlearning = Q_learning(env, **hyperparams)

    Q_sarsa = SARSA(env, **hyperparams)


  deprecation(
  deprecation(


Episode:    0  Eps:0.3990  Rew:-243.1280
Episode:  300  Eps:0.0990  Rew:-236.8790
Episode:  600  Eps:0.0100  Rew:-249.4620
Episode:  900  Eps:0.0100  Rew:-119.0630
Episode: 1200  Eps:0.0100  Rew:-95.6520
Episode: 1500  Eps:0.0100  Rew:-52.5110
Episode: 1800  Eps:0.0100  Rew:-40.9430
Episode: 2100  Eps:0.0100  Rew:-15.0160
Episode: 2400  Eps:0.0100  Rew:-6.4810
Episode: 2700  Eps:0.0100  Rew:-2.4440
Episode: 3000  Eps:0.0100  Rew:-0.6410
Episode: 3300  Eps:0.0100  Rew:3.6450
Episode: 3600  Eps:0.0100  Rew:7.3960
Episode: 3900  Eps:0.0100  Rew:7.3760
Episode: 4200  Eps:0.0100  Rew:7.0940
Episode: 4500  Eps:0.0100  Rew:7.8330
Episode: 4800  Eps:0.0100  Rew:7.9810
Episode:    0  Eps:0.3990  Rew:-250.3550
Episode:  300  Eps:0.0990  Rew:-205.2650
Episode:  600  Eps:0.0100  Rew:-194.0410
Episode:  900  Eps:0.0100  Rew:-207.7630
Episode: 1200  Eps:0.0100  Rew:-123.3540
Episode: 1500  Eps:0.0100  Rew:-56.5330
Episode: 1800  Eps:0.0100  Rew:-36.6760
Episode: 2100  Eps:0.0100  Rew:-16.6490
Episod