In [1]:
import gym
import numpy as np

In [2]:
# Initialize Gym environment
env = gym.make('Taxi-v3', render_mode="ansi")

# Initialize Q-table with zeros
Q = np.zeros([env.observation_space.n, env.action_space.n])

In [3]:
def eps_greedy(Q, s, eps=0.1):
    '''
    Epsilon greedy policy
    '''
    if np.random.uniform(0,1) < eps:
        # Choose a random action
        return np.random.randint(Q.shape[1])
    else:
        # Choose the action of a greedy policy
        return greedy(Q, s)


def greedy(Q, s):
    '''
    Greedy policy

    return the index corresponding to the maximum action-state value
    '''
    return np.argmax(Q[s])

In [4]:
def SARSA(env, lr=0.01, num_episodes=10000, eps=0.3, gamma=0.95, eps_decay=0.00005):
    nA = env.action_space.n
    nS = env.observation_space.n

    # Initialize the Q matrix
    # Q: matrix nS*nA where each row represent a state and each colums represent a different action
    Q = np.zeros((nS, nA))
    games_reward = []
    test_rewards = []

    for ep in range(num_episodes):
        state = env.reset()[0]
        terminated = False
        tot_rew = 0

        # decay the epsilon value until it reaches the threshold of 0.01
        if eps > 0.01:
            eps -= eps_decay


        action = eps_greedy(Q, state, eps) 

        # loop the main body until the environment stops
        while not terminated:
            next_state, rew, terminated, truncated, info = env.step(action) # Take one step in the environment

            # choose the next action (needed for the SARSA update)
            next_action = eps_greedy(Q, next_state, eps) 
            # SARSA update
            Q[state][action] = Q[state][action] + lr*(rew + gamma*Q[next_state][next_action] - Q[state][action])

            state = next_state
            action = next_action
            tot_rew += rew
            if terminated:
                games_reward.append(tot_rew)

        # Test the policy every 300 episodes and print the results
        # if (ep % 300) == 0 and ep != 0:
        #     test_rew = run_episodes(env, Q, 1000)
        #     print("Episode:{:5d}  Eps:{:2.4f}  Rew:{:2.4f}".format(ep, eps, test_rew))
        #     test_rewards.append(test_rew)

    return Q

def run_episodes(env, Q, num_episodes=100, to_print=False):
    '''
    Run some episodes to test the policy
    '''
    tot_rew = []
    state = env.reset()[0]

    for _ in range(num_episodes):
        terminated = False
        game_rew = 0

        while not terminated:
            # select a greedy action
            next_state, rew, terminated, truncated, info = env.step(greedy(Q, state))

            state = next_state
            game_rew += rew 
            if terminated:
                state = env.reset()[0]
                tot_rew.append(game_rew)

    if to_print:
        print('Mean score: %.3f of %i games!'%(np.mean(tot_rew), num_episodes))

    return np.mean(tot_rew)

In [5]:
# training variables
NUM_EPISODES = 10000
MAX_STEPS = 99 # per episode

def try_taxi_driver(env, qtable):
    # watch trained agent
    state = env.reset()[0]
    terminated = False
    rewards = 0

    for s in range(MAX_STEPS):

        print(f"TRAINED AGENT")
        print("Step {}".format(s+1))

        action = np.argmax(qtable[state,:])
        new_state, reward, terminated, truncated, info = env.step(action)
        rewards += reward
        env.render()
        print(f"score: {rewards}")
        state = new_state
        
        if terminated == True:
            break

In [6]:
Q_sarsa = SARSA(env, lr=.9, num_episodes=100000, eps=1, gamma=0.95, eps_decay=0.005)

print(Q_sarsa)

try_taxi_driver(env, Q_sarsa)

[[  0.           0.           0.           0.           0.
    0.        ]
 [-29.67522368 -27.13496408 -28.29427766 -27.55104588   5.20997482
  -27.65547914]
 [-10.71539161  -9.87022054  -9.76878159 -11.53625229  10.9512375
  -16.52583135]
 ...
 [-34.9627083  -22.7248547  -19.99995891 -11.2042169  -24.2953971
  -28.49266105]
 [  0.83553944   6.53580586 -31.05400356 -24.6680864  -27.71080997
  -30.56548682]
 [-41.39647517   9.27673089 -17.42654333  18.           3.34129388
    2.69415071]]
TRAINED AGENT
Step 1
score: -1
TRAINED AGENT
Step 2
score: -2
TRAINED AGENT
Step 3
score: -3
TRAINED AGENT
Step 4
score: -4
TRAINED AGENT
Step 5
score: -5
TRAINED AGENT
Step 6
score: -6
TRAINED AGENT
Step 7
score: -7
TRAINED AGENT
Step 8
score: -8
TRAINED AGENT
Step 9
score: -9
TRAINED AGENT
Step 10
score: -10
TRAINED AGENT
Step 11
score: -11
TRAINED AGENT
Step 12
score: -12
TRAINED AGENT
Step 13
score: -13
TRAINED AGENT
Step 14
score: -14
TRAINED AGENT
Step 15
score: -15
TRAINED AGENT
Step 16
score: 

In [21]:

env = gym.make('Taxi-v3', render_mode="human")

try_taxi_driver(env, Q_sarsa)

env.close()


TRAINED AGENT
Step 1
score: -1
TRAINED AGENT
Step 2
score: -2
TRAINED AGENT
Step 3
score: -3
TRAINED AGENT
Step 4
score: -4
TRAINED AGENT
Step 5
score: -5
TRAINED AGENT
Step 6
score: -6
TRAINED AGENT
Step 7
score: -7
TRAINED AGENT
Step 8
score: -8
TRAINED AGENT
Step 9
score: -9
TRAINED AGENT
Step 10
score: -10
TRAINED AGENT
Step 11
score: -11
TRAINED AGENT
Step 12
score: -12
TRAINED AGENT
Step 13
score: -13
TRAINED AGENT
Step 14
score: 7
