In [1]:
import gymnasium as gym
import numpy as np
from tqdm.notebook import tqdm

In [2]:
env = gym.make('Taxi-v3', render_mode='rgb_array')

In [3]:
q_table = np.zeros((env.observation_space.n, env.action_space.n))

In [4]:
def epsilon_greedy(state, info, epsilon):
    if np.random.uniform() < epsilon:
        action = env.action_space.sample(info["action_mask"])
    else:
        action = np.argmax(q_table[state][np.where(info["action_mask"] == 1)[0]])
    return action

In [5]:
episodes = 200000
max_step = 199
learning_rate = 0.99
gamma = 0.95

In [6]:
for episode in tqdm(range(episodes)):
    epsilon = max(1-episode/episodes, 0.05)
    state, info = env.reset()
    for step in range(max_step):
        action = epsilon_greedy(state, info, epsilon)
        next_state, reward, terminate, trunc, info = env.step(action)
        q_table[state][action] = (1 - learning_rate) * q_table[state][action] + learning_rate * (reward + gamma * np.max(q_table[next_state]))

        state = next_state

        if terminate or trunc:
            break

  0%|          | 0/200000 [00:00<?, ?it/s]

In [7]:
episode_rewards = []
for episode in tqdm(range(100)):
    state, info = env.reset()
    total_rewards_ep = 0

    while True:
        action = np.argmax(q_table[state][:])
        next_state, reward, terminate, trunc, info = env.step(action)
        total_rewards_ep += reward

        if terminate or trunc:
            break
        state = next_state
    episode_rewards.append(total_rewards_ep)
mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)

print(f'mean_reward: {mean_reward}, std_reward: {std_reward}')

  0%|          | 0/100 [00:00<?, ?it/s]

mean_reward: -99.04, std_reward: 345.6298575065528


In [10]:
import imageio as iio

images = []
state, info = env.reset(seed=np.random.randint(0, 500))
img = env.render()
images.append(img)
while True:
    action = np.argmax(q_table[state][np.where(info["action_mask"] == 1)[0]])
    state, reward, terminate, trunc, info = env.step(action)  # We directly put next_state = state for recording logic
    img = env.render()
    images.append(img)
    if terminate or trunc:
        break
iio.mimsave('./taxi_QL.gif', [img for i, img in enumerate(images)], duration=0.1)