In [2]:
import gymnasium as gym
import highway_env

# Agent
from stable_baselines3 import DQN


import sys
from tqdm.notebook import trange

from stable_baselines3 import DQN
import pprint
from matplotlib import pyplot as plt
import numpy as np
import joblib
from tqdm import trange

2025-01-08 10:42:07.694014: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-08 10:42:07.738564: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-01-08 10:42:08.092372: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
class MyHighwayEnv(gym.Env):
    """Wrapper around ``highway-v0`` that squashes the reward with a sigmoid.

    The wrapped environment is created with a Kinematics observation of
    ``vehicleCount`` vehicles and five features per vehicle, and a
    ``DiscreteMetaAction`` action space. Observations and actions pass
    through unchanged; only the reward is reshaped in :meth:`step`.

    Args:
        vehicleCount: Value used for ``"vehicles_count"`` in the observation
            config; it is also the first dimension of ``observation_space``.
    """

    def __init__(self, vehicleCount=10):
        super().__init__()
        # base setting
        self.vehicleCount = vehicleCount

        # Per-vehicle features; the length of this list is the second
        # dimension of the observation space declared below.
        features = ["presence", "x", "y", "vx", "vy"]

        # environment setting
        self.config = {
            "observation": {
                "type": "Kinematics",
                "features": features,
                "absolute": True,
                "normalize": True,
                "vehicles_count": vehicleCount,
                "see_behind": True,
            },
            "action": {
                "type": "DiscreteMetaAction",
                "target_speeds": np.linspace(0, 32, 9),
            },
            "duration": 40,
            "vehicles_density": 2,
            "show_trajectories": True,
            "render_agent": True,
        }
        self.env = gym.make("highway-v0", config=self.config)
        self.action_space = self.env.action_space
        # Derive the shape from the configuration instead of hard-coding
        # (10, 5), so the declared space stays correct for any vehicleCount.
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(vehicleCount, len(features)),
            dtype=np.float32,
        )

    def step(self, action):
        """Step the wrapped environment and reshape its reward.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            The 5-tuple ``(obs, reward, terminated, truncated, info)`` from
            the wrapped environment, with ``reward`` replaced by
            ``sigmoid(raw_reward)``, which lies strictly between 0 and 1.
        """
        obs, raw_reward, terminated, truncated, info = self.env.step(action)
        # Reward shaping: logistic squashing of the wrapped env's reward.
        shaped_reward = 1 / (1 + np.exp(-raw_reward))
        return obs, shaped_reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped environment.

        Args:
            **kwargs: Forwarded unchanged to the wrapped env's ``reset``.

        Returns:
            The ``(obs, info)`` pair produced by the wrapped environment.
        """
        obs, info = self.env.reset(**kwargs)
        return obs, info

    def close(self):
        """Close the wrapped environment so its resources are released."""
        self.env.close()

In [4]:
from stable_baselines3.common.vec_env import DummyVecEnv
# Build the custom environment with its default vehicleCount (10); this is the
# instance handed to DQN for training in the next cell.
env = MyHighwayEnv()

In [4]:
# DQN hyperparameters collected in one mapping so they can be read and tuned
# at a glance; the values are passed to the constructor unchanged.
dqn_hyperparams = dict(
    policy_kwargs=dict(net_arch=[256, 256]),
    learning_rate=5e-4,
    buffer_size=15000,
    learning_starts=200,
    batch_size=32,
    gamma=0.8,
    train_freq=1,
    gradient_steps=1,
    target_update_interval=50,
    exploration_fraction=0.7,
    verbose=1,
    tensorboard_log='highway_dqn/',
)
model = DQN('MlpPolicy', env, **dqn_hyperparams)

# Train for 20,000 timesteps, then persist the agent for the evaluation cell.
model.learn(total_timesteps=int(2e4))
model.save('models/dqn_model')

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to highway_dqn/DQN_9
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.25     |
|    ep_rew_mean      | 1.36     |
|    exploration_rate | 0.999    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 1        |
|    time_elapsed     | 4        |
|    total_timesteps  | 9        |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.75     |
|    ep_rew_mean      | 1.72     |
|    exploration_rate | 0.999    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 2        |
|    time_elapsed     | 10       |
|    total_timesteps  | 22       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.83     |
| 

Visualise

In [6]:
import base64
from pathlib import Path

from gymnasium.wrappers import RecordVideo
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display



def record_videos(env, video_folder="videos"):
    """Wrap ``env`` in a ``RecordVideo`` wrapper that records every episode.

    Args:
        env: Environment to record; its unwrapped env must provide
            ``set_record_video_wrapper``.
        video_folder: Folder passed to ``RecordVideo`` as ``video_folder``.

    Returns:
        The ``RecordVideo``-wrapped environment.
    """
    recorder = RecordVideo(
        env,
        video_folder=video_folder,
        # Trigger on every episode, regardless of its index.
        episode_trigger=lambda episode_id: True,
    )

    # Register the recorder with the base env (capture intermediate frames).
    env.unwrapped.set_record_video_wrapper(recorder)
    return recorder


def show_videos(path="videos"):
    """Display every ``*.mp4`` file directly under ``path`` inline.

    Each file is base64-encoded into a ``data:`` URI inside a ``<video>`` tag,
    and the tags are joined with ``<br>`` and shown as one HTML output.

    Args:
        path: Directory searched (non-recursively) for ``*.mp4`` files.
    """
    video_tags = []
    # Path.glob yields entries in arbitrary, filesystem-dependent order;
    # sorting makes the display order deterministic across runs.
    for mp4 in sorted(Path(path).glob("*.mp4")):
        video_b64 = base64.b64encode(mp4.read_bytes()).decode("ascii")
        video_tags.append(
            """<video alt="{}" autoplay
                      loop controls style="height: 400px;">
                      <source src="data:video/mp4;base64,{}" type="video/mp4" />
                 </video>""".format(
                mp4, video_b64
            )
        )
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(video_tags)))


In [8]:
from gymnasium.wrappers import RecordVideo
# base setting
vehicleCount = 10

# environment setting (same values as the config used by MyHighwayEnv)
config = {

    "observation": {
        "type": "Kinematics",
        "features": ["presence", "x", "y", "vx", "vy"],
        "absolute": True,
        "normalize": True,
        "vehicles_count": vehicleCount,
        "see_behind": True,
    },
    "action": {
        "type": "DiscreteMetaAction",
        "target_speeds": np.linspace(0, 32, 9),
    },
    "duration": 40,
    "vehicles_density": 2,
    "show_trajectories": True,
    "render_agent": True,
}

# Load the agent saved by the training cell.
modelDqnllm = DQN.load("models/dqn_model")

# Pass the config straight to gym.make (as MyHighwayEnv does) rather than
# calling env.configure() on the wrapper object that gym.make returns.
env = gym.make('highway-v0', render_mode='rgb_array', config=config)
env = record_videos(env)
try:
    for episode in trange(3, desc='Test episodes'):
        (obs, info), done, truncated = env.reset(), False, False
        while not (done or truncated):
            # Greedy action from the trained policy.
            action, _ = modelDqnllm.predict(obs, deterministic=True)
            obs, reward, done, truncated, info = env.step(int(action))
finally:
    # Always close the recording wrapper, even if the rollout is interrupted
    # or raises part-way through an episode.
    env.close()
show_videos()

  logger.warn(
  logger.warn(
Test episodes:   0%|          | 0/3 [00:13<?, ?it/s]


KeyboardInterrupt: 