In [1]:
import os
os.environ['RAY_LOG_LEVEL'] = "DEBUG"

In [2]:
import ray
import torch
from gymnasium.wrappers import TimeLimit
from dynaconf import Dynaconf
from ray.rllib.models import ModelCatalog
from environments.MutilRoadEnv import RouteEnv
from ray.tune.registry import register_env
from ray.rllib.algorithms.apex_dqn import ApexDQN, ApexDQNConfig
from ray.rllib.algorithms.dqn import DQN, DQNConfig
from ray.tune.logger import JsonLogger
from environments.ObsWrapper import FullRGBImgPartialObsWrapper
from model.image_decoder import CustomCNN
from minigrid.wrappers import ImgObsWrapper
import gymnasium as gym

pygame 2.5.2 (SDL 2.28.2, Python 3.9.18)
Hello from the pygame community. https://www.pygame.org/contribute.html




In [3]:
torch.cuda.is_available()

True

In [4]:
# Init Ray
ray.init(
    num_cpus=14, num_gpus=1,
    include_dashboard=True,
    _system_config={"maximum_gcs_destroyed_actor_cached_count": 200},
)


2023-11-22 11:00:39,797	INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


0,1
Python version:,3.9.18
Ray version:,2.8.0
Dashboard:,http://127.0.0.1:8265


In [5]:
# Config path
log_path = "./logs/"
checkpoint_path = "./checkpoints"
sub_buffer_size = 16
setting = Dynaconf(envvar_prefix="DYNACONF", settings_files="./drone.yml")

# Set hyper parameters
hyper_parameters = setting.hyper_parameters.to_dict()
hyper_parameters["logger_config"] = {"type": JsonLogger, "logdir": checkpoint_path}
print("log path: %s \ncheck_path: %s" % (log_path, checkpoint_path))

log path: ./logs/ 
check_path: ./checkpoints


In [6]:
hyper_parameters["replay_buffer_config"]

{'type': 'MultiAgentPrioritizedReplayBuffer', 'capacity': 100000}

In [7]:
hyper_parameters["train_batch_size"] = int(hyper_parameters["train_batch_size"] / sub_buffer_size)
hyper_parameters["env_config"] = {
        "size": 20,
        "roads": (5, 7),
        "max_steps": 1000,
        "battery":100,
        "render_mode": "rgb_array",
        "agent_pov": False
    }

In [8]:
# Build env
def env_creator(env_config):
    env = RouteEnv(**env_config)
    env = FullRGBImgPartialObsWrapper(env, tile_size=5)
#     env = TimeLimit(env, max_episode_steps=1000)
    return ImgObsWrapper(env)

register_env("RandomPath", env_creator)

env = env_creator(hyper_parameters["env_config"])
obs, _ = env.reset()
step = env.step(1)
print(env.action_space, env.observation_space)

Discrete(3) Box(0, 255, (100, 100, 3), uint8)


In [9]:
ModelCatalog.register_custom_model("CustomCNN", CustomCNN)

hyper_parameters["model"] = {
        "custom_model": "CustomCNN",
        "no_final_linear": True,
        "fcnet_hiddens": hyper_parameters["hiddens"],
        "custom_model_config": {},
    }

In [10]:
config = DQNConfig().environment("RandomPath").resources(num_gpus=1)

In [11]:
config.update_from_dict(hyper_parameters)

<ray.rllib.algorithms.dqn.dqn.DQNConfig at 0x7f49c8d50dc0>

In [12]:
trainer = config.build()

The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  object_ = constructor(*ctor_args, **ctor_kwargs)


In [13]:
for i in range(100):
    _ = trainer.train()
    print(_["sampler_results"]["episode_reward_mean"])



-9.727974683544302




-10.093999999999996
-9.942299999999996
-8.990799999999997
-8.853599999999998
-8.282999999999998
-8.691299999999995
-8.144199999999994
-7.618099999999995
-7.2602999999999955
-6.4119999999999955
-6.713899999999995
-6.554299999999996
-6.328299999999995
-6.1950999999999965
-6.814399999999998
-6.304899999999998
-6.668699999999998
-6.789999999999997
-8.446599999999997
-9.022699999999997
-8.729699999999996
-10.277699999999983
-8.413999999999994
-8.371299999999996
-8.343099999999987
-7.665299999999995
-9.006499999999994
-8.524399999999988
-7.87459999999999
-8.995799999999992
-7.374699999999991
-8.962699999999986
-10.150499999999987
-7.024299999999995
-6.9961999999999955
-8.551099999999993
-7.687999999999993
-9.154799999999986
-8.421299999999981
-7.203199999999992
-5.311799999999998
-7.073099999999994
-7.274299999999988
-6.169199999999995
-5.286699999999996
-5.593399999999995
-8.896199999999986
-8.81509999999999
-7.4473999999999885
-8.537499999999989
-8.359899999999989
-6.847699999999993
-6.967