In [1]:
import gymnasium as gym
from minigrid.wrappers import ImgObsWrapper, FullyObsWrapper
# from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import configure

from sb3_contrib import ArDQN
from sb3_contrib.dqn import DQN
from utils import open_tensorboard
import multiprocessing
from os import path

pygame 2.4.0 (SDL 2.26.4, Python 3.10.6)
Hello from the pygame community. https://www.pygame.org/contribute.html


# Set up the environment

In [2]:
env_id = 'MiniGrid-Empty-5x5-v0'

# Default MiniGrid observation, reduced to its image part by ImgObsWrapper.
# This `env` is the one the DQN baseline is trained on below.
env = ImgObsWrapper(
    gym.make(env_id, render_mode='rgb_array', max_episode_steps=100)
)
# Same task, but FullyObsWrapper replaces the agent's view with the whole grid;
# only used here to compare observation shapes.
full_env = ImgObsWrapper(
    FullyObsWrapper(
        gym.make(env_id, render_mode='rgb_array', max_episode_steps=100)
    )
)

In [3]:
# Observation shapes: default (partial-view) image vs. FullyObsWrapper's full-grid image.
tuple(e.reset()[0].shape for e in (env, full_env))

((7, 7, 3), (5, 5, 3))

# Set up TensorBoard

The Stable-Baselines3 logger and its output formats are documented here: https://stable-baselines3.readthedocs.io/en/master/common/logger.html

In [4]:
# Root directory for all TensorBoard runs; each experiment gets its own sub-directory.
tmp_path = "/tmp/sb3_log/"
# Launches a TensorBoard server (and a browser tab) watching `tmp_path`.
tb_window = open_tensorboard(tmp_path)


def tb_logger(exp: str):
    """Build an SB3 logger that writes TensorBoard events only.

    Args:
        exp: experiment name, used as a sub-directory of ``tmp_path``.
            It may contain ``/`` to nest runs, e.g. ``"AR_DQN/0.5"``.

    Returns:
        The logger produced by ``stable_baselines3.common.logger.configure``.
    """
    log_dir = path.join(tmp_path, exp)
    return configure(log_dir, format_strings=["tensorboard"])

Started Tensorboard Server
Started Browser


# Training

## DQN

In [5]:
# Baseline agent: DQN with an MLP policy, trained on `env`.
# learning_starts=0: start gradient updates right away instead of first filling a
# warm-up buffer (SB3 `learning_starts` semantics; assumed to hold for this DQN).
# Training itself runs synchronously in the "Run" section below.
model = DQN('MlpPolicy', env, learning_starts=0)
# Send this run's logs to TensorBoard only, under <tmp_path>/DQN.
model.set_logger(tb_logger("DQN"))

## ArDQN

In [6]:
# ArDQN variant, configured with an initial aspiration. It gets its own env instance
# so it does not share episode state with the DQN baseline, and `env` keeps
# referring to the baseline's env instead of being silently rebound.
INITIAL_ASPIRATION = 0.5
ar_env = ImgObsWrapper(gym.make(env_id, max_episode_steps=100, render_mode='rgb_array'))
ar_model = ArDQN(
    'MlpPolicy',
    ar_env,
    learning_starts=0,
    policy_kwargs=dict(initial_aspiration=INITIAL_ASPIRATION),
)
# Derive the log directory from the aspiration so the two can never drift apart
# (for 0.5 this is "AR_DQN/0.5", as before).
ar_model.set_logger(tb_logger(f"AR_DQN/{INITIAL_ASPIRATION}"))

## Run

In [7]:
# Train the DQN baseline. `learn` returns the model itself; `;` hides its repr.
model.learn(300_000);

<sb3_contrib.dqn.dqn.DQN at 0x7fd2948a5e70>

In [8]:
# Train ArDQN for the same 300k-step budget as the DQN baseline; `;` hides the returned model's repr.
ar_model.learn(300_000);

<sb3_contrib.ar_dqn.ar_dqn.ArDQN at 0x7fd2948a5840>

# Evaluation

In [15]:
# Watch the trained DQN agent in an on-screen (pygame) window.
# Use the same time limit as the training env, and always close the window
# afterwards, even if evaluation is interrupted.
h_env = ImgObsWrapper(gym.make(env_id, max_episode_steps=100, render_mode='human'))
try:
    dqn_eval = evaluate_policy(model, h_env, n_eval_episodes=10, render=True)
finally:
    h_env.close()
dqn_eval  # (mean_reward, std_reward) over the 10 episodes



(0.9549999833106995, 0.0)

In [21]:
# Evaluate ArDQN on a dedicated env rather than the env it was trained on, so
# evaluation episodes don't step the training env behind the model's wrappers.
ar_eval_env = ImgObsWrapper(gym.make(env_id, max_episode_steps=100, render_mode='rgb_array'))
try:
    ar_eval = evaluate_policy(ar_model, ar_eval_env, n_eval_episodes=10, render=False)
finally:
    ar_eval_env.close()
# NOTE(review): the recorded result was (0.0, 0.0), i.e. zero reward in every episode,
# presumably meaning the goal was never reached (DQN scored ~0.95 on the same task).
# TODO: confirm ArDQN's aspiration state is handled correctly when driven via evaluate_policy.
ar_eval  # (mean_reward, std_reward) over the 10 episodes

(0.0, 0.0)