In [25]:
import gym
import random
import numpy as np
import time
# Create the Taxi environment and an all-zero Q-table with one row per
# discrete state and one column per discrete action.
SEED = 42  # fixed seed so repeated "Restart & Run All" runs are reproducible
env = gym.make("Taxi-v2")
env.seed(SEED)     # seeds the environment's own RNG (pre-0.26 gym API)
random.seed(SEED)  # seeds the epsilon-greedy uniform draws used during training
# NOTE(review): env.action_space.sample() may use its own RNG depending on the
# installed gym version; seed it too if fully deterministic runs are required.

# Q-learning implementation

action_size = env.action_space.n
print("Action size: ", action_size)

state_size = env.observation_space.n
print("State size: ", state_size)

qtable = np.zeros((state_size, action_size))  # initialize qtable with zeros
print(qtable)

Action size:  6
State size:  500
[[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 ...
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]


In [26]:
# Hyper-parameters for training and evaluation

# Episode budget
total_episodes = 50000     # training episodes
total_test_episodes = 50   # evaluation episodes
max_steps = 99             # step cap per episode

# Q-learning update
learning_rate = 0.7  # step size (alpha): how far each update moves Q toward its target
gamma = 0.8          # discount factor: weights immediate rewards more than distant ones

# Epsilon-greedy exploration schedule (exploration vs. exploitation)
max_epsilon = 1.0      # start fully exploratory
epsilon = max_epsilon  # current exploration rate
min_epsilon = 0.01     # floor that the decayed exploration rate approaches
decay_rate = 0.01      # exponential decay rate of epsilon

In [27]:
# Training: tabular Q-learning with an epsilon-greedy behaviour policy.
# Reads env, qtable and the hyper-parameters from the previous cells;
# updates qtable in place and decays epsilon once per episode.
# Assumes the pre-0.26 gym API: reset() -> observation, step() -> 4-tuple.

for episode in range(total_episodes):
    state = env.reset()
    done = False

    for step in range(max_steps):
        # Epsilon-greedy: exploit when a uniform draw exceeds epsilon,
        # otherwise take a random action.
        exp_exp_tradeoff = random.uniform(0, 1)
        if exp_exp_tradeoff > epsilon:
            action = np.argmax(qtable[state, :])  # exploit
        else:
            action = env.action_space.sample()  # explore (random action)

        new_state, reward, done, info = env.step(action)

        # Q-learning target. A terminal transition has no successor value,
        # so do not bootstrap from new_state once the episode has ended.
        # NOTE(review): the 4-tuple gym API also reports done=True on a
        # TimeLimit truncation; confirm max_steps stays below the env's limit.
        if done:
            td_target = reward
        else:
            td_target = reward + gamma * np.max(qtable[new_state, :])
        qtable[state, action] += learning_rate * (td_target - qtable[state, action])

        state = new_state

        if done:
            break

    # Exponential decay towards min_epsilon. The schedule is evaluated at
    # episode + 1 (i.e. after this episode completes), as before.
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * (episode + 1))

In [28]:
# Inspect the Q-table after training (numpy abbreviates large arrays with "...").
print(str(qtable))

[[  0.           0.           0.           0.           0.
    0.        ]
 [ -1.64490869  -0.81180781  -1.65784128  -0.80748545   0.24288
   -9.80593883]
 [  0.15701449   1.51867199   0.2428614    1.51786764   3.192
   -7.47760657]
 ...
 [ -2.59131785  -2.39910391  -2.80003111   1.55359937 -11.16210995
  -11.30704019]
 [ -3.71336291  -2.31564846  -3.77619868  -3.59331089 -12.21750464
  -11.01451344]
 [  7.18804633  -0.973       -1.6002      15.          -9.6096
   -0.70140757]]


In [29]:
# Testing: run the greedy policy given by the learned qtable (no exploration)
# and report the return of every test episode plus the average.
rewards = []

for episode in range(total_test_episodes):
    state = env.reset()
    done = False
    total_rewards = 0
    print("##################")
    print("Episode ", episode)

    for step in range(max_steps):
        if episode % 10 == 0:  # render every 10th episode
            env.render()
        action = np.argmax(qtable[state, :])  # greedy action from the q-table
        new_state, reward, done, info = env.step(action)
        total_rewards += reward
        state = new_state
        if done:
            break

    # Record every episode's return, including episodes that hit max_steps
    # without finishing, so the average really is over all test episodes.
    rewards.append(total_rewards)
    if done:
        print("Rewards:", total_rewards)
    else:
        print("Rewards:", total_rewards, "(not finished within", max_steps, "steps)")

env.close()

print("##########################")
# Guard against total_test_episodes == 0, which previously divided by zero.
average_reward = sum(rewards) / len(rewards) if rewards else float("nan")
print("Average reward per test episode: " + str(average_reward))

##################
Episode  0
+---------+
|R: | : :[34;1m[43mG[0m[0m|
| : : : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+

+---------+
|R: | : :[42mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (Pickup)
+---------+
|R: | : :G|
| : : : :[42m_[0m|
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (South)
+---------+
|R: | : :G|
| : : :[42m_[0m: |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (West)
+---------+
|R: | : :G|
| : : : : |
| : : :[42m_[0m: |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (South)
+---------+
|R: | : :G|
| : : : : |
| : :[42m_[0m: : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (West)
+---------+
|R: | : :G|
| : : : : |
| :[42m_[0m: : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (West)
+---------+
|R: | : :G|
| : : : : |
|[42m_[0m: : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (West)
+---------+
|R: | : :G|
| : : : : |
| : : : : |
|[42m_[0m| : | : |
|[35m