In [None]:

import os
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from pettingzoo.classic import tictactoe_v3
from pettingzoo.utils.wrappers import BaseWrapper
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Net

import Environment

def _get_agents(
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build one DQN policy per environment agent and wrap them in a manager.

    Args:
        optim: Optional optimizer to use for every agent. When omitted, each
            agent receives its own Adam optimizer (lr=1e-4) built over the
            parameters of that agent's own network, so every network is
            actually trained.

    Returns:
        A ``(policy, optim, agents)`` tuple where ``policy`` is the
        ``MultiAgentPolicyManager`` holding all per-agent DQN policies,
        ``optim`` is the caller-supplied optimizer (or, if none was supplied,
        the optimizer created for the first agent), and ``agents`` is
        ``env.agents``.
    """
    env = _get_env()
    # Dict observation spaces carry the real observation under "observation".
    observation_space = (
        env.observation_space["observation"]
        if isinstance(env.observation_space, gym.spaces.Dict)
        else env.observation_space
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"

    agents = []
    returned_optim = optim
    for _ in range(env.num_agents):
        # Q-network for this agent.
        net = Net(
            state_shape=observation_space.shape or observation_space.n,
            action_shape=env.action_space.shape or env.action_space.n,
            hidden_sizes=[128, 128, 128, 128],
            device=device,
        ).to(device)

        # An optimizer only updates the parameters it was constructed with, so
        # a default optimizer must be created per network rather than reusing
        # the one built for the first agent.
        if optim is not None:
            agent_optim = optim
        else:
            agent_optim = torch.optim.Adam(net.parameters(), lr=1e-4)
        if returned_optim is None:
            returned_optim = agent_optim

        agents.append(
            DQNPolicy(
                model=net,
                optim=agent_optim,
                discount_factor=0.9,
                estimation_step=3,
                target_update_freq=320,
            )
        )

    policy = MultiAgentPolicyManager(agents, env)
    return policy, returned_optim, env.agents


def _get_env():
    """Zero-argument environment factory passed to DummyVectorEnv.

    Delegates to ``get_env()`` and returns whatever it builds.
    """
    wrapped_env = get_env()
    return wrapped_env

def get_env():
    """Create the vehicle job-scheduling environment wrapped for Tianshou.

    The raw ``Environment.VehicleJobSchedulingEnvACE`` instance is wrapped in
    PettingZoo's ``BaseWrapper`` and then in Tianshou's ``PettingZooEnv``.
    """
    return PettingZooEnv(BaseWrapper(Environment.VehicleJobSchedulingEnvACE()))


In [None]:
# ======== Step 1: Environment setup =========
seed = 1

# Ten environment factories per vectorized env, for both training and testing.
train_envs = DummyVectorEnv([_get_env] * 10)
test_envs = DummyVectorEnv([_get_env] * 10)

# Seed NumPy, PyTorch and both vectorized environments with the same value.
np.random.seed(seed)
torch.manual_seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)


In [None]:
# ======== Step 2: Agent setup =========
# _get_agents() returns the MultiAgentPolicyManager, an optimizer, and the
# environment's agent list; `agents` is later indexed by save_best_fn.
policy, optim, agents = _get_agents()


In [None]:
# ======== Step 3: Collector setup =========
# Replay storage of 20,000 transitions split across one sub-buffer per
# training environment.
replay_buffer = VectorReplayBuffer(20_000, len(train_envs))
train_collector = Collector(
    policy,
    train_envs,
    replay_buffer,
    exploration_noise=True,
)
test_collector = Collector(policy, test_envs, exploration_noise=True)

# Pre-fill the replay buffer before training starts.
# NOTE(review): 64 * 1000 steps exceeds the 20,000-transition buffer created
# above — confirm this warm-up size is intended.
train_collector.collect(n_step=64 * 1000)


In [None]:
# ======== Step 4: Callback functions setup =========
def save_best_fn(policy):
    """Save the state dict of the policy for ``agents[1]`` to log/rps/dqn/policy.pth.

    The target directory is created if it does not already exist. ``agents``
    is read from module scope.
    """
    checkpoint_dir = os.path.join("log", "rps", "dqn")
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(
        policy.policies[agents[1]].state_dict(),
        os.path.join(checkpoint_dir, "policy.pth"),
    )
def reward_metric(rews):
    """Identity reward metric: hand the collected rewards back unchanged."""
    unchanged = rews
    return unchanged


In [None]:
# ======== Step 5: Run the trainer =========
# Off-policy training loop: every 500 collected steps trigger gradient updates
# at a rate of 0.1 updates per step, with 10 test episodes per evaluation.
# NOTE(review): no train_fn/test_fn is passed, so this call never adjusts the
# policies' exploration epsilon — confirm exploration is configured as intended.
result = offpolicy_trainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=5000,
    step_per_epoch=10000,
    step_per_collect=500,
    episode_per_test=10,
    batch_size=64,
    save_best_fn=save_best_fn,
    update_per_step=0.1,
    # Python's boolean literal is `True`; the lowercase `true` is an undefined name.
    test_in_train=True,
    # Reuse the named identity metric defined in Step 4 instead of a duplicate lambda.
    reward_metric=reward_metric,
)
print(f"\n==========Result==========\n{result}")

In [None]:
from rich.pretty import pprint

In [None]:
# Run `policy` through the test collector for 10 episodes and pretty-print
# the statistics the collector returns.
pprint("DQNPolicy:")
rslt = test_collector.collect(n_episode=10)
pprint(rslt)

In [None]:
# ======== Baseline: evaluate random policies on the same environment =========
# 1. Build the raw environment and wrap it for PettingZoo.
env = Environment.VehicleJobSchedulingEnvACE()
env = BaseWrapper(env)
# Keep the agent count under its own name: assigning it to `agents` would
# overwrite the agent list from Step 2 that save_best_fn indexes.
num_agents = env.num_agents
env = PettingZooEnv(env)

# 2. One RandomPolicy per agent, sized from the environment rather than a
#    hard-coded count.
policies = MultiAgentPolicyManager([RandomPolicy() for _ in range(num_agents)], env)

# 3. Vectorize the single environment. The default argument binds the wrapped
#    environment when the lambda is created, so the factory does not depend on
#    what `env` refers to when it is eventually called.
env = DummyVectorEnv([lambda wrapped_env=env: wrapped_env])

# 4. Create the collector.
collector = Collector(policies, env)

# 5. Collect 10 episodes and report the statistics.
result = collector.collect(n_episode=10)
pprint("RandomPolicy:")
pprint(result)