In [1]:
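"""This is a minimal example of using Tianshou for multi-agent reinforcement learning (MARL).

A DQN agent is trained to play PettingZoo's tic-tac-toe against a random opponent.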

Author: Will (https://github.com/WillDudley)

Python version used: 3.8.10

Requirements:
pettingzoo == 1.22.0
git+https://github.com/thu-ml/tianshou
"""

import os
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Net

from pettingzoo.classic import tictactoe_v3


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, Optional[torch.optim.Optimizer], list]:
    """Build the multi-agent policy used for tic-tac-toe training.

    Args:
        agent_learn: Policy for the learning player. When ``None``, a new
            ``DQNPolicy`` backed by a 4 x 128 hidden-unit ``Net`` is created.
        agent_opponent: Policy for the other player. When ``None``, a
            ``RandomPolicy`` is used.
        optim: Optimizer for the freshly created DQN network. It is only
            consulted when ``agent_learn`` is ``None``; if it is also ``None``,
            an Adam optimizer (lr=1e-4) over the new network is created.

    Returns:
        ``(policy, optim, agents)``:
        ``policy`` is a ``MultiAgentPolicyManager`` over
        ``[agent_opponent, agent_learn]`` (in that order);
        ``optim`` is the optimizer, or ``None`` when ``agent_learn`` was
        supplied without one;
        ``agents`` is the environment's ``agents`` list. The main script
        looks up the learner as ``policy.policies[agents[1]]``.
    """
    env = _get_env()
    # When the observation space is a Dict, only its "observation" entry
    # describes what the network consumes.
    observation_space = (
        env.observation_space["observation"]
        if isinstance(env.observation_space, gym.spaces.Dict)
        else env.observation_space
    )
    if agent_learn is None:
        # Decide the device once so the Net and its parameters agree.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # model
        net = Net(
            # `.shape` is empty for discrete spaces, so fall back to `.n`.
            state_shape=observation_space.shape or observation_space.n,
            action_shape=env.action_space.shape or env.action_space.n,
            hidden_sizes=[128, 128, 128, 128],
            device=device,
        ).to(device)
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=1e-4)
        # NOTE(review): a caller-supplied `optim` cannot reference `net`'s
        # parameters (the net is created right here), so it would not train
        # this DQN — confirm callers only pass `optim` together with
        # their own `agent_learn`.
        agent_learn = DQNPolicy(
            model=net,
            optim=optim,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320,
        )

    if agent_opponent is None:
        agent_opponent = RandomPolicy()

    # Order matters: MultiAgentPolicyManager pairs these with env.agents,
    # making the learner the second entry.
    agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents, env)
    return policy, optim, env.agents


def _get_env():
    """Return a fresh tic-tac-toe environment wrapped for Tianshou.

    Kept as a named, zero-argument factory so it can be passed as a callable
    to ``DummyVectorEnv`` and also called directly when a single env is needed.
    """
    raw_env = tictactoe_v3.env()
    return PettingZooEnv(raw_env)


if __name__ == "__main__":
    # Settings shared by several steps below. Keeping them in one place keeps
    # the warm-up collection size equal to batch_size * training_num.
    seed = 1
    training_num = 10  # parallel training environments
    test_num = 10  # parallel test environments
    batch_size = 64

    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(training_num)])
    test_envs = DummyVectorEnv([_get_env for _ in range(test_num)])

    # seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()
    # _get_agents hands [opponent, learner] to MultiAgentPolicyManager, so the
    # learning DQN policy is keyed by the second agent id.
    learner_id = agents[1]

    # ======== Step 3: Collector setup =========
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(20_000, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # NOTE(review): the learner's epsilon is not set before this warm-up
    # collection, so it runs with whatever epsilon DQNPolicy starts with.
    # Call policy.policies[learner_id].set_eps(1) first if uniformly random
    # warm-up data is wanted.
    train_collector.collect(n_step=batch_size * training_num)

    # ======== Step 4: Callback functions setup =========
    # NOTE(review): "rps" looks like a leftover from a rock-paper-scissors
    # example; kept unchanged so anything loading this checkpoint still works.
    save_dir = os.path.join("log", "rps", "dqn")

    def save_best_fn(policy):
        """Save the learner's weights to ``<save_dir>/policy.pth``."""
        os.makedirs(save_dir, exist_ok=True)
        model_save_path = os.path.join(save_dir, "policy.pth")
        torch.save(policy.policies[learner_id].state_dict(), model_save_path)

    def stop_fn(mean_rewards):
        """Stop training once the mean test reward reaches 0.6."""
        return mean_rewards >= 0.6

    def train_fn(epoch, env_step):
        """Use epsilon 0.1 for the learner while collecting training data."""
        policy.policies[learner_id].set_eps(0.1)

    def test_fn(epoch, env_step):
        """Use epsilon 0.05 for the learner during evaluation."""
        policy.policies[learner_id].set_eps(0.05)

    def reward_metric(rews):
        # Keep only column 1; presumably the per-agent reward of agents[1]
        # (the learner) — confirm against the collector's reward layout.
        return rews[:, 1]

    # ======== Step 5: Run the trainer =========
    result = offpolicy_trainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=50,
        step_per_epoch=1000,
        step_per_collect=50,
        episode_per_test=10,
        batch_size=batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        update_per_step=0.1,
        test_in_train=False,
        reward_metric=reward_metric,
    )

    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[1]])")

Epoch #1: 1001it [00:02, 414.50it/s, env_step=1000, len=7, n/ep=5, n/st=50, player_2/loss=0.272, rew=-0.20]                          


Epoch #1: test_reward: 0.400000 ± 0.916515, best_reward: 0.400000 ± 0.916515 in #1


Epoch #2: 1001it [00:01, 692.10it/s, env_step=2000, len=7, n/ep=8, n/st=50, player_2/loss=0.266, rew=-0.38]                          


Epoch #2: test_reward: 0.500000 ± 0.806226, best_reward: 0.500000 ± 0.806226 in #2


Epoch #3: 1001it [00:01, 722.63it/s, env_step=3000, len=6, n/ep=6, n/st=50, player_2/loss=0.260, rew=0.33]                           

Epoch #3: test_reward: -0.400000 ± 0.916515, best_reward: 0.500000 ± 0.806226 in #2



Epoch #4: 1001it [00:01, 717.43it/s, env_step=4000, len=6, n/ep=7, n/st=50, player_2/loss=0.256, rew=0.43]                           


Epoch #4: test_reward: 0.300000 ± 0.900000, best_reward: 0.500000 ± 0.806226 in #2


Epoch #5: 1001it [00:01, 691.82it/s, env_step=5000, len=7, n/ep=6, n/st=50, player_2/loss=0.251, rew=-0.67]                          

Epoch #5: test_reward: 0.200000 ± 0.979796, best_reward: 0.500000 ± 0.806226 in #2



Epoch #6: 1001it [00:01, 661.43it/s, env_step=6000, len=7, n/ep=7, n/st=50, player_2/loss=0.221, rew=-1.00]                          


Epoch #6: test_reward: -0.200000 ± 0.979796, best_reward: 0.500000 ± 0.806226 in #2


Epoch #7: 1001it [00:01, 700.13it/s, env_step=7000, len=6, n/ep=7, n/st=50, player_2/loss=0.246, rew=-0.14]                          


Epoch #7: test_reward: 0.400000 ± 0.916515, best_reward: 0.500000 ± 0.806226 in #2


Epoch #8: 1001it [00:01, 682.13it/s, env_step=8000, len=7, n/ep=6, n/st=50, player_2/loss=0.245, rew=-0.67]                          


Epoch #8: test_reward: 0.100000 ± 0.943398, best_reward: 0.500000 ± 0.806226 in #2


Epoch #9: 1001it [00:01, 736.34it/s, env_step=9000, len=7, n/ep=5, n/st=50, player_2/loss=0.244, rew=-0.20]                          


Epoch #9: test_reward: 0.400000 ± 0.916515, best_reward: 0.500000 ± 0.806226 in #2


Epoch #10: 1001it [00:01, 758.81it/s, env_step=10000, len=7, n/ep=10, n/st=50, player_2/loss=0.241, rew=0.60]                          


Epoch #10: test_reward: 0.200000 ± 0.979796, best_reward: 0.500000 ± 0.806226 in #2


Epoch #11: 1001it [00:01, 712.07it/s, env_step=11000, len=6, n/ep=6, n/st=50, player_2/loss=0.254, rew=0.33]                           


Epoch #11: test_reward: 0.200000 ± 0.979796, best_reward: 0.500000 ± 0.806226 in #2


Epoch #12: 1001it [00:01, 740.25it/s, env_step=12000, len=6, n/ep=9, n/st=50, player_2/loss=0.238, rew=0.78]                          


Epoch #12: test_reward: -0.100000 ± 0.943398, best_reward: 0.500000 ± 0.806226 in #2


Epoch #13: 1001it [00:01, 746.50it/s, env_step=13000, len=7, n/ep=7, n/st=50, player_2/loss=0.243, rew=0.14]                           


Epoch #13: test_reward: 0.800000 ± 0.600000, best_reward: 0.800000 ± 0.600000 in #13

{'duration': '19.76s', 'train_time/model': '13.96s', 'test_step': 995, 'test_episode': 140, 'test_time': '0.37s', 'test_speed': '2660.05 step/s', 'best_reward': 0.8, 'best_result': '0.80 ± 0.60', 'train_step': 13000, 'train_episode': 1835, 'train_time/collector': '5.42s', 'train_speed': '670.65 step/s'}

(the trained policy can be accessed via policy.policies[agents[1]])
