In [3]:
import random
from time import sleep

import gym
import numpy as np
from IPython.display import clear_output

# Seed every source of randomness used in this notebook so Restart & Run All is
# reproducible: Python's `random` drives the epsilon-greedy choice during
# training, while the environment (resets) and its action space (random action
# sampling) each keep their own RNG.
SEED = 1234
random.seed(SEED)

# NOTE(review): `.env` unwraps the object returned by gym.make; in classic gym
# this drops the TimeLimit wrapper, so episodes end only when the passenger is
# delivered — confirm against the installed gym version.
streets = gym.make("Taxi-v3").env
streets.seed(SEED)
streets.action_space.seed(SEED)
streets.reset()  # start from a seeded state instead of the one drawn at construction
streets.render()

+---------+
|R: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | :[43m [0m|
|Y| : |[34;1mB[0m: |
+---------+



In [5]:
# Put the environment into a hand-picked state: taxi at row 2, column 3, the
# passenger waiting at location index 2 (Y) and the destination at index 0 (R).
taxi_row, taxi_col, passenger_loc, destination = 2, 3, 2, 0
initial_state = streets.encode(taxi_row, taxi_col, passenger_loc, destination)

streets.s = initial_state  # overwrite the env's current state directly
streets.render()

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : :[43m [0m: |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+



In [6]:
# Transition table for the initial state: action -> [(probability, next_state, reward, done)].
# Actions 0-3 move the taxi south/north/east/west for -1 each; 4 (pickup) and
# 5 (drop-off) are illegal here, so they cost -10 and leave the state unchanged.
streets.P[initial_state]

{0: [(1.0, 368, -1, False)],
 1: [(1.0, 168, -1, False)],
 2: [(1.0, 288, -1, False)],
 3: [(1.0, 248, -1, False)],
 4: [(1.0, 268, -10, False)],
 5: [(1.0, 268, -10, False)]}

In [18]:
streets.observation_space.n  # number of discrete states; used as the Q-table's row count when training

500

In [28]:
# Tabular Q-learning with an epsilon-greedy behaviour policy.
# One row per discrete state and one column per action; all estimates start at 0.
q_table = np.zeros([streets.observation_space.n, streets.action_space.n])
learning_rate = 0.1     # alpha: how far each update moves toward the TD target
discount_factor = 0.6   # gamma: weight given to future rewards
exploration = 0.1       # epsilon: probability of taking a random action
epochs = 10000          # number of training episodes

for taxi_run in range(epochs):
    state = streets.reset()
    done = False

    while not done:
        # Explore with probability epsilon; otherwise act greedily on current estimates.
        if random.uniform(0, 1) < exploration:
            action = streets.action_space.sample()
        else:
            action = np.argmax(q_table[state])
        next_state, reward, done, info = streets.step(action)

        # A terminal transition has no future return, so do not bootstrap from
        # next_state there (explicit, rather than relying on the terminal
        # state's row never being updated).
        next_max_q = 0.0 if done else np.max(q_table[next_state])
        prev_q = q_table[state, action]
        new_q = (1 - learning_rate) * prev_q + learning_rate * (reward + discount_factor * next_max_q)
        q_table[state, action] = new_q

        state = next_state

In [29]:
# Learned Q-values of the six actions in the hand-picked initial state; the
# greedy replay policy picks the action with the largest value (argmax).
q_table[initial_state]

array([-2.40067802, -2.40032227, -2.38554372, -2.3639511 , -6.35246587,
       -8.65955459])

In [30]:
# Replay the learned greedy policy on a few fresh episodes, animating each step
# in place. Trips are capped so that a policy stuck in a loop (e.g. repeatedly
# driving into a wall) still terminates, and each trip reports whether the
# passenger was actually delivered.
num_trips = 10
max_trip_steps = 25
frame_delay = 0.5   # seconds between rendered steps
trip_pause = 2      # seconds to pause after each trip
completed_trips = 0

for tripnum in range(1, num_trips + 1):
    state = streets.reset()
    done = False
    trip_length = 0

    while not done and trip_length < max_trip_steps:
        action = np.argmax(q_table[state])
        next_state, reward, done, info = streets.step(action)
        clear_output(wait=True)
        print("Trip number " + str(tripnum) + " Step " + str(trip_length))
        print(streets.render(mode='ansi'))
        sleep(frame_delay)
        state = next_state
        trip_length += 1

    # NOTE(review): with the unwrapped env, `done` is only set by a successful
    # drop-off, so `done` here means the passenger was delivered.
    if done:
        completed_trips += 1
        print(f"Trip {tripnum}: passenger delivered in {trip_length} steps")
    else:
        print(f"Trip {tripnum}: not delivered within {max_trip_steps} steps")
    sleep(trip_pause)

print(f"Delivered {completed_trips} of {num_trips} trips")

Trip number 10 Step 24
+---------+
|R: | : :G|
| : | : : |
| : : : : |
| | : | : |
|[35m[43mY[0m[0m| : |[34;1mB[0m: |
+---------+
  (East)

