In [1]:
import os
import sys

import gymnasium as gym
import numpy as np
import ray
import torch
from gymnasium import spaces
from ray import tune
from ray.rllib.algorithms.dqn import DQN, DQNConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy
from ray.tune.registry import register_env
from torch import nn

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import core
print(sys.path)

sys.modules['core'] = core
from core.algorithms.amd.amd import AMD, AMDConfig
from core.algorithms.amd.wrappers import \
    MultiAgentEnvFromPettingZooParallel as P2M
from core.environments.wolfpack import wolfpack_env_creator

pygame 2.3.0 (SDL 2.24.2, Python 3.10.4)
Hello from the pygame community. https://www.pygame.org/contribute.html
['/home/quanta/.conda/envs/forl-proj/lib/python3.10/site-packages/ray/thirdparty_files', '/home/quanta/Projects/FoRL-project/test_code', '/home/quanta/.conda/envs/forl-proj/lib/python310.zip', '/home/quanta/.conda/envs/forl-proj/lib/python3.10', '/home/quanta/.conda/envs/forl-proj/lib/python3.10/lib-dynload', '', '/home/quanta/.conda/envs/forl-proj/lib/python3.10/site-packages', '/home/quanta/Projects/FoRL-project']


In [2]:
class SimpleMLPModelV2(TorchModelV2, nn.Module):

    def __init__(self, obs_space: gym.Space, act_space: gym.Space, num_outputs, *args, **kwargs):
        TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs)
        nn.Module.__init__(self)

        self.flattened_obs_space = spaces.flatten_space(obs_space)
        self.obs_space = obs_space
        self.action_space = act_space

        self.model = nn.Sequential(
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(self.flattened_obs_space.shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
        )

        self.policy_fn = nn.Linear(32, num_outputs)
        self.value_fn = nn.Linear(32, 1)

    def forward(self, input_dict, state, seq_lens):
        model_out = self.model(input_dict["obs"].to(torch.float32) / 255)
        self._value_out = self.value_fn(model_out)
        return self.policy_fn(model_out), state

    def value_function(self):
        return self._value_out.flatten()

In [3]:
ray.init()

2023-05-21 12:41:09,054	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Python version:,3.10.4
Ray version:,2.3.1


In [4]:
# register env and model
env_name = 'wolfpack'
register_env(env_name, lambda config: P2M(wolfpack_env_creator(config)))
ModelCatalog.register_custom_model("SimpleMLPModelV2", SimpleMLPModelV2)

In [5]:
# config for dqn
config = DQNConfig().multi_agent(
    policies=['predator', 'prey'],
    policy_mapping_fn=(lambda agent_id, *args, **kwargs: {
        'wolf_1': 'predator',
        'wolf_2': 'predator',
        'prey': 'prey',
    }[agent_id]),
).environment(
    env=env_name,
    env_config={
        'r_lone': 1.0,
        'r_team': 5.0,
        'r_prey': 0.0,
        'coop_radius': 4,
        'max_cycles': 1000,
    },
    clip_actions=True,
).rollouts(
    num_rollout_workers=4,
    rollout_fragment_length=128,
).training(
    model={
        "custom_model": "SimpleMLPModelV2",
        "custom_model_config": {},
    },
    train_batch_size=1000,
    lr=2e-5,
    gamma=0.99,
    v_min=0.0,
    v_max=10.0,
    # double_q=True,
).debugging(log_level="ERROR").framework(framework="torch").resources(
    num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    num_cpus_per_worker=3,
)
explore_config = {
    "type": EpsilonGreedy,
    "initial_epsilon": 1.0,
    "final_epsilon": 0.1,
    "epsilon_timesteps": 100000,
}
config.explore = True,
config.exploration_config = explore_config

In [6]:
# load checkpoint init weight
algo = config.build().load_checkpoint('/home/quanta/ray_results/wolfpack/dqn/DQN_wolfpack_5e5a5_00000_0_2023-05-21_12-22-20/checkpoint_000010')

# worker = algo.workers.local_worker()
# policy_map = worker.policy_map()
# for policy_id in policy_map.keys():
#     print(policy_id)

2023-05-21 12:41:10,439	INFO algorithm.py:506 -- Current log_level is ERROR. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
[2m[36m(RolloutWorker pid=46213)[0m 2023-05-21 12:41:13,810	ERROR serialization.py:371 -- No module named 'core'
[2m[36m(RolloutWorker pid=46213)[0m Traceback (most recent call last):
[2m[36m(RolloutWorker pid=46213)[0m   File "/home/quanta/.conda/envs/forl-proj/lib/python3.10/site-packages/ray/_private/serialization.py", line 369, in deserialize_objects
[2m[36m(RolloutWorker pid=46213)[0m     obj = self._deserialize_object(data, metadata, object_ref)
[2m[36m(RolloutWorker pid=46213)[0m   File "/home/quanta/.conda/envs/forl-proj/lib/python3.10/site-packages/ray/_private/serialization.py", line 252, in _deserialize_object
[2m[36m(RolloutWorker pid=46213)[0m     return self._deserialize_msgpack_data(data, metadata_fields)
[2m[36m(RolloutWorker pid=46213)[0m   File "/home/quanta/.conda/envs/forl-proj/lib/pyth

RaySystemError: System error: No module named 'core'
traceback: Traceback (most recent call last):
  File "/home/quanta/.conda/envs/forl-proj/lib/python3.10/site-packages/ray/_private/serialization.py", line 369, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
  File "/home/quanta/.conda/envs/forl-proj/lib/python3.10/site-packages/ray/_private/serialization.py", line 252, in _deserialize_object
    return self._deserialize_msgpack_data(data, metadata_fields)
  File "/home/quanta/.conda/envs/forl-proj/lib/python3.10/site-packages/ray/_private/serialization.py", line 207, in _deserialize_msgpack_data
    python_objects = self._deserialize_pickle5_data(pickle5_data)
  File "/home/quanta/.conda/envs/forl-proj/lib/python3.10/site-packages/ray/_private/serialization.py", line 197, in _deserialize_pickle5_data
    obj = pickle.loads(in_band)
ModuleNotFoundError: No module named 'core'
