In [1]:
from stable_baselines3 import DQN, PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.env_checker import check_env
from env import DangerousDaveEnv
import time, os
from custom_cnn import policy_kwargs
import torch
from stable_baselines3.common.vec_env import SubprocVecEnv
import numpy as np

In [2]:
# Device selection: prefer Apple-silicon MPS when PyTorch reports it available, otherwise CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Run configuration, set manually (stands in for the argparse flags of the script version).
train = True  # equivalent to --train in argparse
evaluate = True  # equivalent to --evaluate in argparse
model_name = "dqn_test_2"  # set to "" to auto-generate a timestamped name
env_rep_type = 'text'  # 'text' or 'image'
model_type = 'DQN'  # 'DQN' or 'PPO' (the only algorithms the training cell handles)
retrain = False  # equivalent to --retrain in argparse

checkpoint_timestamp = int(time.time())
if not model_name:
    # Bare name only: checkpoint paths are built as "checkpoints/<model_name>" and the
    # TensorBoard dir as "tensorboard_log/<model_name>", so a "checkpoints/" prefix here
    # would end up doubled ("checkpoints/checkpoints/...").
    model_name = "dqn_ddave_{}".format(checkpoint_timestamp)

tensorboard_log = f"tensorboard_log/{model_name}"
tensorboard_log_run_name = '0'  # SB3 writes runs to "<tensorboard_log>/<run_name>_<N>"
print(model_name, tensorboard_log)

# Create the DangerousDaveEnv environment used for training.
random_respawn = True
env = DangerousDaveEnv(render_mode="human", env_rep_type=env_rep_type, random_respawn=random_respawn)
obs, info = env.reset()

total_timesteps = 100000

dqn_test_2 tensorboard_log/dqn_test_2
(1, 11, 19)
Box(0, 255, (1, 11, 19), uint8)




In [3]:
# Each supported algorithm maps to its SB3 class plus the algorithm-specific constructor
# hyperparameters used for a fresh (non-retrain) run. Shared arguments are passed below.
ALGORITHMS = {
    'DQN': (DQN, dict(batch_size=256, learning_starts=1000, exploration_fraction=0.5,
                      exploration_final_eps=0.01, target_update_interval=5000,
                      buffer_size=100000)),
    'PPO': (PPO, dict(batch_size=256, ent_coef=0.01, vf_coef=1)),
}
if model_type not in ALGORITHMS:
    # Fail loudly instead of silently skipping training and leaving `model` undefined.
    raise ValueError(f"Unsupported model_type {model_type!r}; expected one of {sorted(ALGORITHMS)}")

algo_cls, algo_kwargs = ALGORITHMS[model_type]
checkpoint_path = "checkpoints/{}".format(model_name)

if train:
    if retrain:
        # Resume from the saved checkpoint, attaching the live env and the log dir.
        # `device` is passed explicitly so a reloaded model uses the same device as a fresh one.
        model = algo_cls.load(checkpoint_path, env=env, tensorboard_log=tensorboard_log, device=device)
    else:
        model = algo_cls("CnnPolicy", env, verbose=1, policy_kwargs=policy_kwargs, device=device,
                         tensorboard_log=tensorboard_log, **algo_kwargs)

    model.learn(total_timesteps=total_timesteps, progress_bar=True,
                tb_log_name=tensorboard_log_run_name, log_interval=1)

    model.save(checkpoint_path)

if evaluate:
    # Reload from disk so evaluation uses exactly what was saved.
    model = algo_cls.load(checkpoint_path, env=env, tensorboard_log=tensorboard_log, device=device)

if evaluate:
    # Roll out the loaded policy deterministically on fresh envs with fixed (non-random)
    # respawn, and report the mean summed per-episode reward.
    n_eval_episodes = 5
    eps_reward = []
    for _ in range(n_eval_episodes):
        # Separate name so the training `env` defined earlier is not overwritten.
        eval_env = DangerousDaveEnv(render_mode="human", env_rep_type=env_rep_type, random_respawn=False)
        try:
            obs, info = eval_env.reset()
            terminated = truncated = False
            episode_return = 0
            while not (terminated or truncated):
                action, _ = model.predict(obs, deterministic=True)
                obs, step_reward, terminated, truncated, info = eval_env.step(action)
                episode_return += step_reward
        finally:
            # Release each per-episode env (human render mode) instead of leaking it.
            eval_env.close()
        eps_reward.append(episode_return)
    print(f'{np.mean(eps_reward)} eval reward mean')

            

Using mps device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to tensorboard_log/dqn_test_2/0_12


Output()

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.5e+03  |
|    ep_rew_mean      | -1.5e+03 |
|    exploration_rate | 0.97     |
| time/               |          |
|    episodes         | 1        |
|    fps              | 66       |
|    time_elapsed     | 22       |
|    total_timesteps  | 1500     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.69e-07 |
|    n_updates        | 124      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.5e+03  |
|    ep_rew_mean      | -1.5e+03 |
|    exploration_rate | 0.941    |
| time/               |          |
|    episodes         | 2        |
|    fps              | 37       |
|    time_elapsed     | 79       |
|    total_timesteps  | 3000     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.02e-12 |
|    n_updates        | 499      |
----------------------------------


KeyboardInterrupt: 

In [None]:
# Mean of the per-episode evaluation rewards accumulated in `eps_reward`.
np.asarray(eps_reward).mean()

In [None]:
# Per-episode evaluation rewards (one summed reward appended per evaluation episode).
eps_reward

In [None]:
# SB3 logs each learn() call to "<tensorboard_log>/<tb_log_name>_<N>", incrementing N per
# run (e.g. "0_12" above). Resolve the newest run instead of hard-coding a stale suffix.
import re

run_pattern = re.compile(r"^{}_(\d+)$".format(re.escape(tensorboard_log_run_name)))
run_ids = []
for entry in os.listdir(tensorboard_log):
    match = run_pattern.match(entry)
    if match and os.path.isdir(os.path.join(tensorboard_log, entry)):
        run_ids.append(int(match.group(1)))
if not run_ids:
    raise FileNotFoundError(f"No '{tensorboard_log_run_name}_N' runs found under {tensorboard_log}")

logfile = f"{tensorboard_log}/{tensorboard_log_run_name}_{max(run_ids)}/"
logfile

In [None]:
!tensorboard --logdir {logfile}