In [1]:
import time
from os import path

import gymnasium as gym
import numpy as np
from minigrid.wrappers import ImgObsWrapper, FullyObsWrapper
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import configure
from stable_baselines3.common.monitor import Monitor

from custom_envs import MultiarmedBanditsEnv
from sb3_contrib import ArDQN
# from stable_baselines3 import DQN
from sb3_contrib.common.satisficing.evaluation import evaluate_policy as ar_evaluate_policy
from sb3_contrib.dqn import DQN
from utils import open_tensorboard

# When True, the log-setup cell calls open_tensorboard() on the run directory.
OPEN_TENSORBOARD = True

pygame 2.4.0 (SDL 2.26.4, Python 3.10.6)
Hello from the pygame community. https://www.pygame.org/contribute.html


# Setup

In [2]:
# MiniGrid configuration: number of steps passed to learn() and the Gymnasium
# id of the task.
LEARNING_STEPS = 300_000
env_id = "MiniGrid-Empty-5x5-v0"
# Variant wrapped with ImgObsWrapper only (no FullyObsWrapper), capped at
# 100 steps per episode.
partial_env = ImgObsWrapper(
    gym.make(env_id, render_mode="rgb_array", max_episode_steps=100)
)


def make_env(render_mode='rgb_array', **kwargs):
    """Create the MiniGrid environment identified by ``env_id``.

    The base environment is capped at 100 steps per episode and is wrapped
    with ``FullyObsWrapper`` first and ``ImgObsWrapper`` second.

    Args:
        render_mode: Forwarded to ``gym.make``; defaults to ``'rgb_array'``.
        **kwargs: Extra keyword arguments forwarded to ``gym.make``.

    Returns:
        The wrapped environment.
    """
    base_env = gym.make(env_id, render_mode=render_mode, max_episode_steps=100, **kwargs)
    full_view_env = FullyObsWrapper(base_env)
    return ImgObsWrapper(full_view_env)


# Instantiate the environment with make_env's default arguments.
env = make_env()
# Debug aid (disabled): compare the observation shapes of the two variants.
# env.reset()[0].shape, partial_env.reset()[0].shape

In [3]:
# Multi-armed bandit configuration. These assignments rebind LEARNING_STEPS
# and env_id, so the MiniGrid values set earlier no longer apply.
LEARNING_STEPS = 30_000
env_id = "MultiarmedBandits-1-5-99"


def make_env(**kwargs):
    """Create a ``MultiarmedBanditsEnv`` with this notebook's fixed settings.

    Args:
        **kwargs: Extra keyword arguments forwarded to
            ``MultiarmedBanditsEnv`` (the evaluation cell passes
            ``render_mode``).

    Returns:
        A new ``MultiarmedBanditsEnv`` instance.
    """
    # NOTE(review): the positional arguments are presumably per-arm reward
    # values, per-arm spread, and an episode length -- confirm against
    # custom_envs.MultiarmedBanditsEnv.
    bandit_env = MultiarmedBanditsEnv([1, 5, 99], [0, 0, 0], 1, **kwargs)
    return bandit_env

# Setup Logs

A description of the logged values can be found in the Stable-Baselines3 logger documentation: https://stable-baselines3.readthedocs.io/en/master/common/logger.html

In [4]:
# Root directory for this run's logs and saved models; the suffix is the
# current local time at one-second resolution.
tmp_path = path.join(
    "./logs/tests",
    time.strftime("%Y%m%d-%H%M%S"),
)


# set up logger
def tb_logger(exp):
    """Build a logger that writes TensorBoard output for one experiment.

    Args:
        exp: Experiment sub-path; it is joined onto ``tmp_path`` to form the
            log directory.

    Returns:
        The logger object returned by ``configure``.
    """
    log_dir = path.join(tmp_path, exp)
    return configure(log_dir, ["tensorboard"])


# Launch TensorBoard on the run directory only when requested; otherwise
# tb_window stays None.
tb_window = open_tensorboard(tmp_path) if OPEN_TENSORBOARD else None

Started Tensorboard Server
Started Browser


# Training

## DQN

In [5]:
# Location of the DQN baseline's logs and models, relative to tmp_path.
dqn_path = path.join(env_id, "DQN")

# Baseline agent on a fresh environment.
env = make_env()
model = DQN("MlpPolicy", env, learning_starts=0)

# Send training metrics to this run's TensorBoard directory.
model.set_logger(tb_logger(dqn_path))

  return torch._C._cuda_getDeviceCount() > 0


## ArDQN

In [6]:
# Train one ArDQN agent per initial aspiration (11 evenly spaced values from
# 1 to 100) and save each one under its own directory.
for a in np.linspace(1, 100, num=11):
    initial_aspiration = a
    ar_path = path.join(env_id, "AR_DQN", str(initial_aspiration))

    ar_env = make_env()
    ar_model = ArDQN(
        "MlpPolicy",
        ar_env,
        learning_starts=0,
        policy_kwargs={"initial_aspiration": initial_aspiration},
    )
    ar_model.set_logger(tb_logger(ar_path))
    ar_model.learn(LEARNING_STEPS)

    # The saved model file is named after the number of training steps.
    ar_model.save(path.join(tmp_path, ar_path, "models", str(LEARNING_STEPS)))

ALSA lib conf.c:4120:(snd_config_update_r) Cannot access file /usr/share/alsa/alsa.conf
ALSA lib seq.c:935:(snd_seq_open_noupdate) Unknown SEQ default


## Run

<sb3_contrib.ar_dqn.ar_dqn.ArDQN at 0x7fbea2ea8310>

In [8]:
# Train the DQN baseline, passing LEARNING_STEPS as the step budget.
model.learn(LEARNING_STEPS)

<sb3_contrib.dqn.dqn.DQN at 0x7fbdd901b940>

# Evaluation

In [8]:
# Evaluation environment: created with render_mode='human' and wrapped in Monitor.
h_env = Monitor(make_env(render_mode='human'))

In [17]:
# Evaluate the DQN model on h_env for 10 episodes with rendering enabled.
# The displayed tuple is evaluate_policy's return value
# (presumably mean and std of the episode reward -- see the SB3 docs).
evaluate_policy(model, h_env, n_eval_episodes=10, render=True)

(99.0, 0.0)

In [13]:
# NOTE(review): this checkpoint comes from an earlier run (20230626-184832,
# aspiration 50, 100000 steps). The training loop in this notebook saves under
# the current tmp_path with different aspirations and step count, so this file
# must already exist on disk for the cell to work.
#
# Load through the class rather than an existing instance: load() returns a
# new model, and going through the class removes the dependency on ar_model
# having been bound by an earlier cell.
ar_model = ArDQN.load("logs/tests/20230626-184832/MultiarmedBandits-1-5-99/AR_DQN/50/models/100000.zip")

In [19]:
# Evaluate each ArDQN agent saved by the training loop: rebuild the path it
# was saved under, load it, and run 1000 rendered evaluation episodes.
for a in np.linspace(1, 100, num=11):
    ar_path = path.join(env_id, "AR_DQN", str(a))
    ar_checkpoint = path.join(tmp_path, ar_path, "models", str(LEARNING_STEPS))
    # Load through the class: load() returns a new model, so this does not
    # rely on ar_model already being bound when the loop starts.
    ar_model = ArDQN.load(ar_checkpoint)
    ar_result = ar_evaluate_policy(ar_model, h_env, n_eval_episodes=1000, render=True)
    print(f"aspiration : {a}, res : {ar_result}")

aspiration : 1.0, res : (1.0, 0.0)
aspiration : 10.9, res : (11.392, 23.664114942249583)
aspiration : 20.8, res : (18.912, 33.37939867642914)
aspiration : 30.700000000000003, res : (32.542, 42.78301340485496)
aspiration : 40.6, res : (42.13, 45.951965137521604)
aspiration : 50.5, res : (51.624, 46.998495975935235)
aspiration : 60.400000000000006, res : (60.836, 46.16194432646874)
aspiration : 70.3, res : (71.834, 42.61000403661093)
aspiration : 80.2, res : (83.114, 35.22668028639656)
aspiration : 90.10000000000001, res : (96.744, 14.38660710522116)
aspiration : 100.0, res : (99.0, 0.0)




In [22]:
# Save the DQN baseline under its experiment directory; the file is named
# after the number of training steps.
model.save(path.join(tmp_path, dqn_path, "models", str(LEARNING_STEPS)))