In [None]:
"""This is a minimal example of using Tianshou with MARL to train agents.

Author: Will (https://github.com/WillDudley)

Python version used: 3.8.10

Requirements:
pettingzoo == 1.22.0
git+https://github.com/thu-ml/tianshou
"""

import os
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Net

from pettingzoo.classic import tictactoe_v3


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    env = _get_env()
    observation_space = (
        env.observation_space["observation"]
        if isinstance(env.observation_space, gym.spaces.Dict)
        else env.observation_space
    )
    if agent_learn is None:
        # model
        net = Net(
            state_shape=observation_space.shape or observation_space.n,
            action_shape=env.action_space.shape or env.action_space.n,
            hidden_sizes=[128, 128, 128, 128],
            device="cuda" if torch.cuda.is_available() else "cpu",
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=1e-4)
        agent_learn = DQNPolicy(
            model=net,
            optim=optim,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320,
        )

    if agent_opponent is None:
        agent_opponent = RandomPolicy()

    agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents, env)
    return policy, optim, env.agents


def _get_env():
    """This function is needed to provide callables for DummyVectorEnv."""
    return PettingZooEnv(tictactoe_v3.env())


if __name__ == "__main__":
    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(10)])
    test_envs = DummyVectorEnv([_get_env for _ in range(10)])

    # seed
    seed = 1
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()

    # ======== Step 3: Collector setup =========
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(20_000, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=64 * 10)  # batch size * training_num

    # ======== Step 4: Callback functions setup =========
    def save_best_fn(policy):
        model_save_path = os.path.join("log", "rps", "dqn", "policy.pth")
        os.makedirs(os.path.join("log", "rps", "dqn"), exist_ok=True)
        torch.save(policy.policies[agents[1]].state_dict(), model_save_path)

    def stop_fn(mean_rewards):
        return mean_rewards >= 0.6

    def train_fn(epoch, env_step):
        policy.policies[agents[1]].set_eps(0.1)

    def test_fn(epoch, env_step):
        policy.policies[agents[1]].set_eps(0.05)

    def reward_metric(rews):
        return rews[:, 1]

    # ======== Step 5: Run the trainer =========
    result = offpolicy_trainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=50,
        step_per_epoch=1000,
        step_per_collect=50,
        episode_per_test=10,
        batch_size=64,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        update_per_step=0.1,
        test_in_train=False,
        reward_metric=reward_metric,
    )

    # return result, policy.policies[agents[1]]
    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[1]])")

In [None]:
performance_benchmark(env)

In [None]:
import Environment
env = Environment.VehicleJobSchedulingEnvACE()
from pettingzoo.test import performance_benchmark
import cProfile
cProfile.run("performance_benchmark(env)")

In [None]:
%conda install scipy

In [2]:
import Environment
from supersuit import flatten_v0
from tianshou.env import PettingZooEnv
env = Environment.VehicleJobSchedulingEnvACE()

env = PettingZooEnv(env)

env.action_space.shape, env.observation_space

AttributeError: 'Dict' object has no attribute 'low'

In [4]:
from pettingzoo.utils.wrappers import BaseWrapper,BaseParallelWrapper,BaseParallelWraper

ImportError: cannot import name 'BaseParallelWraper' from 'pettingzoo.utils.wrappers' (/home/yuan/ResMan/man/lib/python3.9/site-packages/pettingzoo/utils/wrappers/__init__.py)

In [15]:
from rich import pretty
pretty.install()

In [2]:
env.reset()
env.step(1)
env.last()

KeyError: 'avail_slot'

In [4]:
env = Environment.VehicleJobSchedulingEnvACE()
sp, _ = env.reset()

In [6]:
sp

OrderedDict([('avail_slot',
              array([[20, 37],
                     [20, 37],
                     [20, 37],
                     [20, 37],
                     [20, 37],
                     [20, 37],
                     [20, 37],
                     [20, 37],
                     [20, 37],
                     [20, 37]], dtype=int8)),
             ('request_job',
              OrderedDict([('res_vec', [0, 0]),
                           ('len', 0),
                           ('priority', 0)]))])

ImportError: cannot import name 'AECWrapper' from 'pettingzoo.utils' (/home/yuan/ResMan/man/lib/python3.9/site-packages/pettingzoo/utils/__init__.py)

In [1]:
from environment import Environment

env = Environment.VehicleJobSchedulingEnvACE()
env.reset()
for agent in env.agent_iter(10000):
    env.step(env.action_space(agent).sample())

In [4]:
env.observation['Machine_0']

array([24., 41., 24., 41., 24., 41., 24., 41., 24., 41., 24., 41., 24.,
       41., 24., 41., 24., 41., 24., 41.,  0.,  0.,  1.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.])

IndexError: index 18 is out of bounds for axis 0 with size 10

In [7]:
dic = env.parameters.cluster.machines[1].observe(
)

In [9]:
from gymnasium.spaces.utils import flatten, flatten_space
flatten(env.obs, dic)

array([15., 38., 15., 38., 15., 38., 15., 38., 15., 38., 15., 38., 15.,
       38., 15., 38., 15., 38., 15., 38.,  0.,  0.,  1.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.])

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

I0000 00:00:1695365322.463636 4159322 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [4]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

2.59 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


2023-09-22 14:50:13.407743: W external/xla/xla/service/gpu/buffer_comparator.cc:1054] INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-lwh-Super-Server-5d095050-4159322-605ed053a8f85, line 10; fatal   : Unsupported .version 7.8; current version is '7.7'
ptxas fatal   : Ptx assembly aborted due to errors

Relying on driver to perform ptx compilation. 
Setting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda  or modifying $PATH can be used to set the location of ptxas
This message will only be logged once.


In [5]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

1.13 ms ± 280 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

211 µs ± 58.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [1]:
from environment import Environment

env = Environment.VehicleJobSchedulingEnvACE()
env.reset()


(array([ 72., 144.,  72., 144.,  72., 144.,  72., 144.,  72., 144.,  72.,
        144.,  72., 144.,  72., 144.,  72., 144.,  72., 144.,  72., 144.,
         72., 144.,  72., 144.,  72., 144.,  72., 144.,  72., 144.,  72.,
        144.,  72., 144.,  72., 144.,  72., 144.,   0.,   0.,   1.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.]),
 {'Machine_0': {},
  'Machine_1': {},
  'Machine_2': {},
  'Machine_3': {},
  'Machine_4': {},
  'Machine_5': {},
  'Machine_6': {},
  'Machine_7': {},
  'Machine_8': {},
  'Machine_9': {},
  'Machine_10': {},
  'Machine_11': {}})

In [2]:
env.parameters.cluster.machines[1].observe()

OrderedDict([('avail_slot',
              array([[ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048],
                     [ 192, 2048]], dtype=int16)),
             ('request_res_vec', array([0, 0])),
             ('request_len', 0),
             ('request_priority', 0)])

In [1]:
import cProfile
from pettingzoo.test import performance_benchmark
from environment import Environment
env = Environment.VehicleJobSchedulingEnvACE()
cProfile.run("performance_benchmark(env)")


Starting performance benchmark
8430.592976980437 turns per second
702.5494147483697 cycles per second
Finished performance benchmark
         4115172 function calls (3946068 primitive calls) in 5.002 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    40337    0.039    0.000    0.396    0.000 <__array_function__ internals>:177(all)
     8959    0.007    0.000    0.126    0.000 <__array_function__ internals>:177(argsort)
       36    0.000    0.000    0.000    0.000 <__array_function__ internals>:177(atleast_1d)
       36    0.000    0.000    0.001    0.000 <__array_function__ internals>:177(broadcast_arrays)
    42276    0.037    0.000    0.200    0.000 <__array_function__ internals>:177(concatenate)
      108    0.000    0.000    0.001    0.000 <__array_function__ internals>:177(copyto)
    37130    0.026    0.000    0.154    0.000 <__array_function__ internals>:177(dot)
     8959    0.007    0.000    0.064    0.000 <__arr

In [6]:
import cProfile
from pettingzoo.test import performance_benchmark
from environment import Environment
env = Environment.VehicleJobSchedulingEnvACE()
cProfile.run("performance_benchmark(env)")

Starting performance benchmark
8109.137924775118 turns per second
675.7614937312599 cycles per second
Finished performance benchmark
         3962297 function calls (3799639 primitive calls) in 5.004 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    38819    0.039    0.000    0.388    0.000 <__array_function__ internals>:177(all)
     8617    0.007    0.000    0.123    0.000 <__array_function__ internals>:177(argsort)
       35    0.000    0.000    0.000    0.000 <__array_function__ internals>:177(atleast_1d)
       35    0.000    0.000    0.002    0.000 <__array_function__ internals>:177(broadcast_arrays)
    40668    0.039    0.000    0.204    0.000 <__array_function__ internals>:177(concatenate)
      108    0.000    0.000    0.001    0.000 <__array_function__ internals>:177(copyto)
    35862    0.026    0.000    0.152    0.000 <__array_function__ internals>:177(dot)
     8617    0.007    0.000    0.063    0.000 <__arr

In [None]:
# 调试Allocation Mechanism