In [15]:
import gym
import numpy as np
from IPython.display import clear_output

In [82]:
env=gym.make('Taxi-v3')
num_episodes = 10

for episode in range (num_episodes):
    state_0 = env.reset()
    done = False # When true => the agent lost (the end of an episode)
    score = 0
    num_states_in_episode = 0 
    while not done:
        env.render() # To show how the agent interact with the env
        action = env.action_space.sample() # take a random action
        #returns the observations ensued from the agent-env interaction
        state, reward, done, info = env.step(action)
        score += reward
        clear_output(wait=True)
    print(f"Episode: {episode}, Score: {score}")
env.close()

In [83]:
# Creating a Q-table where rows are states and columns are actions
actions = env.action_space.n
# Note that for other environments where states are infint (e.g. state is an array of pixels) the 
# number of states is not avialable rather you can get the shape of the state
states = env.observation_space.n
print(actions, states)

6 500


In [84]:
q_tabel = np.zeros((states, actions))
rewards_episodes = []

In [102]:
# Q-learning parameteres
n_episodes = 100
max_steps_per_episode = np.inf # Till the agent loss
gamma = 0.99
lr_rate = 0.1

# Exploration param (the probability that the agent will take a random
# action instead of greedy action)
exploration_rate = 1 
max_exploration_rate = 1
min_exploration_rate = 0.01
exploration_decay_rate = 0.001



In [103]:
# Q-learning algorithm 
for episode in range(n_episodes):
    state = env.reset() # Initial state
    done = False
    current_reward = 0
    
    while not done:
        exploration_threshold = np.random.uniform(0,1)
        if exploration_threshold > exploration_rate:
            # Take the greedy action
            action = np.argmax(q_tabel[state,:])
        else:
            # Explore
            action = env.action_space.sample()
            
        # Take action
        next_state, reward, done, info = env.step(action)
        # Update
        q_value = float(q_tabel[state, action])
        max_q_prime = np.max(q_tabel[next_state,:])
        TD_error = float(reward + gamma * max_q_prime - q_value)
        q_tabel[state, action] = float(q_value + (lr_rate * TD_error))
        
        state = next_state
        current_reward += reward
    
    # Decay the exploration
    exploration_rate = min_exploration_rate + \
                      (max_exploration_rate - min_exploration_rate)* \
                      np.exp(-exploration_decay_rate*episode)
    rewards_episodes.append(current_reward)

In [104]:
q_tabel

array([[  0.        ,   0.        ,   0.        ,   0.        ,
          0.        ,   0.        ],
       [ -1.68465823,  -1.7806518 ,  -2.03504853,  -2.11303449,
         -1.79806776, -10.93322335],
       [ -2.4772883 ,  -2.30365835,  -2.07019141,  -2.06537803,
         -2.00651134, -11.26356866],
       ...,
       [ -1.20799099,  -1.49439538,  -1.38687734,  -1.67763684,
        -10.2256309 , -10.48214761],
       [ -1.20798182,  -1.5724532 ,  -1.27836364,  -1.19818182,
        -10.42527455, -10.19818182],
       [ -0.98999103,  -1.18058748,  -0.90002397,  -0.78200002,
         -9.98910049,  -9.90004793]])

In [105]:
rewards_episodes

[-632,
 -776,
 -695,
 -821,
 -893,
 -740,
 -749,
 -767,
 -839,
 -749,
 -830,
 -848,
 -785,
 -776,
 -740,
 -830,
 -740,
 -740,
 -803,
 -758,
 -713,
 -722,
 -740,
 -794,
 -812,
 -776,
 -776,
 -776,
 -767,
 -821,
 -812,
 -839,
 -731,
 -749,
 -812,
 -695,
 -749,
 -812,
 -668,
 -704,
 -731,
 -686,
 -776,
 -812,
 -866,
 -713,
 -758,
 -785,
 -484,
 -812,
 -839,
 -821,
 -812,
 -659,
 -812,
 -830,
 -596,
 -740,
 -722,
 -713,
 -749,
 -785,
 -794,
 -776,
 -821,
 -776,
 -956,
 -794,
 -758,
 -749,
 -875,
 -803,
 -749,
 -884,
 -749,
 -722,
 -785,
 -776,
 -686,
 -821,
 -574,
 -776,
 -722,
 -749,
 -776,
 -740,
 -740,
 -776,
 -875,
 -722,
 -641,
 -677,
 -686,
 -731,
 -731,
 -740,
 -713,
 -785,
 -749,
 -695,
 -731,
 -893,
 -722,
 -812,
 -722,
 -776,
 -857,
 -722,
 -767,
 -830,
 -803,
 -695,
 -722,
 -686,
 -938,
 -713,
 -884,
 -785,
 -785,
 -785,
 -749,
 -785,
 -812,
 -722,
 -749,
 -884,
 -406,
 -767,
 -260,
 -767,
 -821,
 -686,
 -641,
 -695,
 -776,
 -686,
 -893,
 -911,
 -785,
 -740,
 -329,
 -830,
 -776,

In [76]:
np.argmax(q_tabel[0])

3