In [1]:
import gym
import numpy as np
import time
import random
from IPython import display
from prettytable import PrettyTable

In [2]:
def play(env, q_table, render=False):
    state = env.reset()
    total_reward = 0
    steps = 0
    done = False
    while not done:
        action = np.argmax(q_table[state, :])
        next_state, reward, done, info = env.step(action)
        total_reward += reward
        steps += 1
        if render:
            env.render()
            time.sleep(0.2)
            if not done:
                display.clear_output(wait=True)
        state = next_state

    return (total_reward, steps)

In [3]:
def play_multiple_times(env, policy, evaluatate_episode):
    num_of_success = 0
    succeed_steps = []
    succeed_rewards = [] 
    for i in range(evaluatate_episode):
        total_reward, steps = play(env, policy)
        if total_reward > 0:
            succeed_rewards.append(total_reward)
            num_of_success += 1
            succeed_steps.append(steps)

    success = f'{num_of_success}/{evaluatate_episode}'
    mean_succeed_steps = None
    mean_succeed_reward = None
    if len(succeed_steps) > 0:
      mean_succeed_reward = np.mean(succeed_rewards)
      mean_succeed_steps = np.mean(succeed_steps)

    return success, mean_succeed_reward, mean_succeed_steps

In [4]:
def q_learning(env, train_episode, num_steps_per_episode, learning_rate, gamma, max_epsilon, min_epsilon, epsilon_decay_rate):
    def policy(state):
      nonlocal epsilon, env, q_table
      exploration = random.uniform(0,1)
      if exploration < epsilon:
          action = env.action_space.sample()
      else:
          action = np.argmax(q_table[state, :])
      return action

    # q_table = np.zeros((env.observation_space.n, env.action_space.n))
    q_table = np.random.rand(env.observation_space.n, env.action_space.n)
    rewards_all = []
    for episode in range(train_episode):
        state = env.reset()

        reward_episode = 0.0
        done = False
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-epsilon_decay_rate*episode)
        for step in range(num_steps_per_episode):
            action = policy(state)
            next_state, reward, done, info = env.step(action)
            q_table[state, action] = q_table[state, action] * (1 - learning_rate) + learning_rate * (reward + gamma * np.max(q_table[next_state,:]))

            reward_episode += reward
            state = next_state

            if done:
                break
        rewards_all.append(reward_episode)
    return q_table, rewards_all

In [5]:
def sarsa(env, train_episode, num_steps_per_episode, learning_rate, gamma, max_epsilon, min_epsilon, epsilon_decay_rate):
    def policy(state):
      nonlocal epsilon, env, q_table
      exploration = random.uniform(0,1)
      if exploration < epsilon:
          action = env.action_space.sample()
      else:
          action = np.argmax(q_table[state, :])
      return action

    # q_table = np.zeros((env.observation_space.n, env.action_space.n))
    q_table = np.random.rand(env.observation_space.n, env.action_space.n)
    rewards_all = []
    for episode in range(train_episode):

        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-epsilon_decay_rate*episode)
        state0 = env.reset()
        action0 = policy(state0)
        state1, reward1, done, info = env.step(action0)
        reward_episode = reward1

        for step in range(1, num_steps_per_episode):
            if done:
                break
            action1 = policy(state1)
            state2, reward2, done, info = env.step(action1)
            reward_episode += reward2

            q_table[state0, action0] = q_table[state0, action0] * (1 - learning_rate) + learning_rate * (reward1 + gamma * (q_table[state1, action1]))

            state0 = state1
            state1 = state2
            reward1 = reward2
            action0 = action1

        rewards_all.append(reward_episode)
    return q_table, rewards_all

In [6]:
def compare(toy_game, evaluatate_episode, train_episode, num_steps_per_episode, learning_rate, gamma, max_epsilon, min_epsilon, epsilon_decay_rate):
  def run_algo(algo):
    nonlocal env, table, row
    
    start = time.time()
    q_table, rewards_all = algo(env, train_episode, num_steps_per_episode, learning_rate, gamma, max_epsilon, min_epsilon, epsilon_decay_rate)
    end = time.time()
    learning_seconds = end - start
    mean_reward_on_learning = np.mean(rewards_all)

    row += [learning_seconds, mean_reward_on_learning] + list(play_multiple_times(env, q_table, evaluatate_episode))
    for i in range(len(row)):
      if type(row[i]) is float or type(row[i]) is np.float64:
        row[i] = round(row[i], 4)
    table.add_row(row)

  env = gym.make(toy_game)
  table = PrettyTable(['algo', 'learning_seconds', 'mean_reward_on_learning', 'success', 'mean_succeed_reward', 'mean_succeed_steps'])

  row = ['Q-Learning']
  run_algo(q_learning)

  row = ['SARSA']
  run_algo(sarsa)

  print(f'{toy_game}: {env.observation_space.n}x{env.action_space.n}')
  print(table)

In [8]:
# Hyperparameters
gamma = 0.99
learning_rate = 0.1
max_epsilon = 1.0
min_epsilon = 0.01
epsilon_decay_rate = 0.005

num_steps_per_episode = 100
train_episode = 20000
evaluatate_episode = 1000
# train_episode = 1000
# evaluatate_episode = 100

toy_games = ["FrozenLake-v0", "FrozenLake8x8-v0", "Taxi-v3"]

for toy_game in toy_games:
  compare(toy_game, evaluatate_episode, train_episode, num_steps_per_episode, learning_rate, gamma, max_epsilon, min_epsilon, epsilon_decay_rate)
  print()

FrozenLake-v0: 16x4
+------------+------------------+-------------------------+----------+---------------------+--------------------+
|    algo    | learning_seconds | mean_reward_on_learning | success  | mean_succeed_reward | mean_succeed_steps |
+------------+------------------+-------------------------+----------+---------------------+--------------------+
| Q-Learning |     16.5599      |          0.5096         | 607/1000 |         1.0         |      33.8764       |
|   SARSA    |      2.1185      |          0.0231         | 30/1000  |         1.0         |      11.6667       |
+------------+------------------+-------------------------+----------+---------------------+--------------------+

FrozenLake8x8-v0: 64x4
+------------+------------------+-------------------------+---------+---------------------+--------------------+
|    algo    | learning_seconds | mean_reward_on_learning | success | mean_succeed_reward | mean_succeed_steps |
+------------+------------------+-------------

In [None]:
# Nhận xét
  # SARSA học nhanh hơn Q-Learning
  # Phần thưởng khi học của SARSA ít hơn Q-Learning
  # Ở các ván thành công, SARSA cho số bước ít hơn và phần thưởng nhiều hơn Q-Learning
  # Khả năng hoàn thành ván chơi thành công cửa SARSA ít hơn Q-Learning

In [None]:
# FrozenLake-v0: 16x4
# +------------+------------------+-------------------------+----------+---------------------+--------------------+
# |    algo    | learning_seconds | mean_reward_on_learning | success  | mean_succeed_reward | mean_succeed_steps |
# +------------+------------------+-------------------------+----------+---------------------+--------------------+
# | Q-Learning |     16.5599      |          0.5096         | 607/1000 |         1.0         |      33.8764       |
# |   SARSA    |      2.1185      |          0.0231         | 30/1000  |         1.0         |      11.6667       |
# +------------+------------------+-------------------------+----------+---------------------+--------------------+

# FrozenLake8x8-v0: 64x4
# +------------+------------------+-------------------------+---------+---------------------+--------------------+
# |    algo    | learning_seconds | mean_reward_on_learning | success | mean_succeed_reward | mean_succeed_steps |
# +------------+------------------+-------------------------+---------+---------------------+--------------------+
# | Q-Learning |      8.0653      |          0.0001         |  0/1000 |         None        |        None        |
# |   SARSA    |      5.1707      |          0.0008         |  0/1000 |         None        |        None        |
# +------------+------------------+-------------------------+---------+---------------------+--------------------+

# Taxi-v3: 500x6
# +------------+------------------+-------------------------+-----------+---------------------+--------------------+
# |    algo    | learning_seconds | mean_reward_on_learning |  success  | mean_succeed_reward | mean_succeed_steps |
# +------------+------------------+-------------------------+-----------+---------------------+--------------------+
# | Q-Learning |      9.3346      |          0.0884         | 1000/1000 |        7.782        |       13.218       |
# |   SARSA    |      7.1327      |         -5.0828         |  967/1000 |        8.0238       |      12.9762       |
# +------------+------------------+-------------------------+-----------+---------------------+--------------------+