In [1]:
from os import path
import time

import gymnasium as gym
from minigrid.wrappers import ImgObsWrapper, FullyObsWrapper
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import configure

from custom_envs import MultiarmedBanditsEnv
from sb3_contrib import ArDQN
# from stable_baselines3 import DQN
from sb3_contrib.common.satisficing.evaluation import evaluate_policy as ar_evaluate_policy
from sb3_contrib.dqn import DQN
from utils import open_tensorboard

# When True, the logging-setup cell launches a TensorBoard server (and browser) on the log directory.
OPEN_TENSORBOARD: bool = True

pygame 2.4.0 (SDL 2.26.4, Python 3.10.6)
Hello from the pygame community. https://www.pygame.org/contribute.html


# Setup

In [2]:
# MiniGrid configuration.
# NOTE: the bandit-setup cell below redefines LEARNING_STEPS, env_id and make_env;
# whichever of the two setup cells runs last wins, so run only the one you want.
LEARNING_STEPS = 300_000
env_id = 'MiniGrid-Empty-5x5-v0'
# Episode cap shared by the partially- and fully-observed variants.
MAX_EPISODE_STEPS = 100
# Variant without FullyObsWrapper, kept only to compare observation shapes with `env`.
partial_env = ImgObsWrapper(gym.make(env_id, max_episode_steps=MAX_EPISODE_STEPS, render_mode='rgb_array'))


def make_env(render_mode='rgb_array', max_episode_steps=MAX_EPISODE_STEPS, **kwargs):
    """Build the fully-observed MiniGrid env used for training and evaluation.

    Wraps ``gym.make(env_id, ...)`` in ``FullyObsWrapper`` and then ``ImgObsWrapper``.
    ``env_id`` is the module-level global, read at call time.

    :param render_mode: forwarded to ``gym.make`` (e.g. ``'human'`` for on-screen rendering).
    :param max_episode_steps: episode length cap forwarded to ``gym.make``;
        defaults to ``MAX_EPISODE_STEPS``.
    :param kwargs: any further keyword arguments forwarded to ``gym.make``.
    :return: the wrapped environment.
    """
    return ImgObsWrapper(FullyObsWrapper(
        gym.make(env_id, max_episode_steps=max_episode_steps, render_mode=render_mode, **kwargs)
    ))


env = make_env()
# Compare observation shapes: env.reset()[0].shape, partial_env.reset()[0].shape

In [24]:
# Multi-armed bandit configuration (overrides the MiniGrid setup cell above if run after it).
LEARNING_STEPS = 100_000
# First positional argument of MultiarmedBanditsEnv: one value per arm.
# NOTE(review): presumably the arms' reward means — TODO confirm against custom_envs.
BANDIT_ARMS = (1, 5, 99)
# Derived from BANDIT_ARMS so the log/experiment name cannot drift from the actual env.
env_id = 'MultiarmedBandits-' + '-'.join(str(arm) for arm in BANDIT_ARMS)


def make_env(**kwargs):
    """Build a fresh MultiarmedBanditsEnv over ``BANDIT_ARMS``.

    The second positional argument is a list of zeros, one per arm
    (NOTE(review): presumably per-arm noise — TODO confirm against custom_envs).

    :param kwargs: forwarded to ``MultiarmedBanditsEnv`` (e.g. ``render_mode``).
    :return: the new environment.
    """
    # Fresh lists on every call so one env instance can never alias another's parameters.
    return MultiarmedBanditsEnv(list(BANDIT_ARMS), [0] * len(BANDIT_ARMS), **kwargs)

# Setup Logs

Descriptions of the logged values are in the Stable-Baselines3 logger documentation: https://stable-baselines3.readthedocs.io/en/master/common/logger.html

In [25]:
# Root directory for this session's logs: a new timestamped folder (second resolution)
# each time this cell runs, so separate sessions do not write into the same run directories.
tmp_path = path.join("./logs/tests", time.strftime("%Y%m%d-%H%M%S"))


def tb_logger(exp, format_strings=("tensorboard",)):
    """Create an SB3 logger that writes under ``<tmp_path>/<exp>``.

    :param exp: experiment sub-path below the log root, e.g. ``<env_id>/DQN``.
    :param format_strings: SB3 logger output formats; defaults to TensorBoard only.
        Pass e.g. ``("tensorboard", "csv")`` to also keep a CSV copy.
    :return: the logger returned by ``stable_baselines3.common.logger.configure``.
    """
    # Tuple default avoids a shared mutable default; configure() gets its own list.
    return configure(path.join(tmp_path, exp), list(format_strings))


tb_window = None
if OPEN_TENSORBOARD:
    # Point TensorBoard at the session root so every experiment sub-path shows up.
    tb_window = open_tensorboard(tmp_path)

Started Tensorboard Server
Started Browser


# Training

## DQN

In [26]:
# Baseline: plain DQN with learning_starts=0 (no warm-up period before updates).
# Its TensorBoard logs go to <tmp_path>/<env_id>/DQN.
dqn_path = path.join(env_id, "DQN")
env = make_env()
model = DQN('MlpPolicy', env, learning_starts=0)
model.set_logger(tb_logger(dqn_path))

## ArDQN

In [36]:
# Aspiration-based DQN on its own env instance; the initial aspiration is passed to the policy.
# Logs go to <tmp_path>/<env_id>/AR_DQN/<initial_aspiration>, one folder per aspiration value.
initial_aspiration = 100
ar_path = path.join(env_id, "AR_DQN", f"{initial_aspiration}")
ar_env = make_env()
ar_model = ArDQN(
    'MlpPolicy',
    ar_env,
    learning_starts=0,
    policy_kwargs={'initial_aspiration': initial_aspiration},
)
ar_model.set_logger(tb_logger(ar_path))

## Run

In [37]:
# Train ArDQN for the configured number of environment steps.
ar_model.learn(total_timesteps=LEARNING_STEPS)

<sb3_contrib.ar_dqn.ar_dqn.ArDQN at 0x7fd4e944ccd0>

In [29]:
# Train the DQN baseline for the same step budget as ArDQN.
model.learn(total_timesteps=LEARNING_STEPS)

<sb3_contrib.dqn.dqn.DQN at 0x7fd4e944cb50>

# Evaluation

In [30]:
# Roll out the trained DQN with on-screen rendering; evaluate_policy returns (mean_reward, std_reward).
# NOTE(review): the ArDQN evaluation below uses n_eval_episodes=100, this one 10 —
# align them before comparing the two results.
h_env = make_env(render_mode='human')
try:
    dqn_eval = evaluate_policy(model, h_env, n_eval_episodes=10, render=True)
finally:
    # Release the human-mode rendering resources even if evaluation raises.
    h_env.close()
dqn_eval



(297.0, 0.0)

In [31]:
# Roll out the trained ArDQN with on-screen rendering using the satisficing evaluator.
# NOTE(review): presumably returns (mean_reward, std_reward) like SB3's evaluate_policy — confirm.
h_env = make_env(render_mode='human')
try:
    ar_eval = ar_evaluate_policy(ar_model, h_env, n_eval_episodes=100, render=True)
finally:
    # Release the human-mode rendering resources even if evaluation raises.
    h_env.close()
ar_eval



(3.0, 0.0)

In [34]:
# Save the trained ArDQN next to its TensorBoard logs: <tmp_path>/<env_id>/AR_DQN/<aspiration>/models/<steps>.
# NOTE(review): SB3's save() typically appends ".zip" when the path has no extension — confirm on disk.
ar_model_save_path = path.join(tmp_path, ar_path, "models", str(LEARNING_STEPS))
ar_model.save(ar_model_save_path)
ar_model_save_path



In [35]:
# Where the ArDQN checkpoint was saved; must match the path passed to ar_model.save(...) above
# (previously this added a "_steps" suffix the saved file does not have).
path.join(tmp_path, ar_path, "models", str(LEARNING_STEPS))

'./logs/tests/MultiarmedBandits-1-5-99/AR_DQN/15/models/100000'