In [None]:
import os
import tensorflow as tf
import gym
from keras.models import Sequential
from keras.layers import Dense, Embedding, Reshape
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

In [None]:
# Create the Taxi-v3 environment: the agent moves around a grid and can pick up / drop off a passenger.
env = gym.make(id='Taxi-v3')

In [None]:
# Q-network: embeds the discrete state index and maps it to one Q-value per action.
action_size = env.action_space.n
# Read the state count from the environment rather than hard-coding it
# (Taxi-v3 has 500 states, the value previously written inline).
state_size = env.observation_space.n
embedding_dim = 10
hidden_layer_sizes = (50, 50, 50)

model = Sequential()
# input_length=1: each sample is a single state index.
model.add(Embedding(state_size, embedding_dim, input_length=1))
# Drop the length-1 sequence axis: (1, embedding_dim) -> (embedding_dim,).
model.add(Reshape((embedding_dim,)))
for units in hidden_layer_sizes:
    model.add(Dense(units, activation='relu'))
# Linear output: raw Q-value estimates, one per action.
model.add(Dense(action_size, activation='linear'))
# summary() prints the table itself and returns None, so it is not wrapped in print().
model.summary()

In [None]:
# Train a keras-rl DQN agent on `env` using the Q-network `model`.
# Replay buffer of the most recent 50k transitions; window_length=1 means each
# observation is fed to the network on its own (matching input_length=1 in the model).
memory = SequentialMemory(limit=50000, window_length=1)
# Epsilon-greedy exploration. NOTE(review): relies on the library's default epsilon — confirm the value.
policy = EpsGreedyQPolicy()
dqn = DQNAgent(model=model,
               nb_actions=action_size,
               memory=memory,
               nb_steps_warmup=500,
               # keras-rl presumably treats values >= 1 as a hard target-network sync
               # interval in steps, and values < 1 (e.g. 1e-2) as a soft-update rate —
               # confirm against the installed keras-rl version.
               target_model_update=8000,
               policy=policy)
# `learning_rate` is the current argument name; `lr` is deprecated since Keras 2.3
# and rejected by newer Keras releases.
dqn.compile(Adam(learning_rate=1e-3), metrics=['mae'])
# Long-running: 1M environment steps. Keep the return value (training history)
# rather than echoing its repr as the cell output.
history = dqn.fit(env, nb_steps=1000000,
                  visualize=False,
                  verbose=1,
                  nb_max_episode_steps=99,
                  log_interval=100000)

In [None]:
import time
import numpy as np
from IPython.display import clear_output

In [None]:
def _get_action_for_state(state):
    predicted = model.predict_on_batch(tf.expand_dims(state, axis=0))
    action = np.argmax(predicted[0])
    return action

In [None]:
# Watch the trained greedy policy play one episode, redrawing the board each step.
# Interrupt the kernel to stop the animation early.
sleep = 0.2      # seconds to pause between rendered frames
max_steps = 20   # give up if the episode has not finished after this many actions

try:
    # Display names indexed by action id (Taxi-v3 action order).
    actions_str = ["South", "North", "East", "West", "Pickup", "Dropoff"]

    iteration = 0
    state = env.reset()  # reset environment to a new, random state
    env.render()
    print(f"Iter: {iteration} - Action: *** - Reward ***")
    time.sleep(sleep)
    done = False

    # The step budget is part of the loop condition so that finishing on the last
    # allowed step is not misreported as a failure.
    while not done and iteration < max_steps:
        action = _get_action_for_state(state)
        iteration += 1
        # Assumes the pre-0.26 gym API: step() returns (obs, reward, done, info).
        state, reward, done, info = env.step(action)
        clear_output(wait=True)
        env.render()
        print(f"Iter: {iteration} - Action: {action}({actions_str[action]}) - Reward {reward}")
        time.sleep(sleep)

    if not done:
        print("cannot converge :(")
except KeyboardInterrupt:
    # Deliberate: interrupting the kernel just stops the animation without a traceback.
    pass