# Tianshou Tutorial
## Tianshou: Basic API Usage
This section demonstrates a game between two random-policy agents in the rock-paper-scissors environment.

In [1]:
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv, PettingZooEnv
from tianshou.policy import MultiAgentPolicyManager, RandomPolicy

from pettingzoo.classic import rps_v2

if __name__ == "__main__":
    # Demo: two random-policy agents play one rendered episode of rock-paper-scissors.
    # Each stage of the environment gets its own name so that no name changes meaning
    # part-way through the cell.

    # step 1: load the PettingZoo environment
    raw_env = rps_v2.env(render_mode="human")

    # step 2: wrap the environment for Tianshou interfacing
    wrapped_env = PettingZooEnv(raw_env)

    # step 3: define policies for each agent
    policies = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()], wrapped_env)

    # step 4: convert the env to vector format
    # The factory closes over `wrapped_env`, which is never rebound. Closing over
    # `env` (the name assigned on this very line) would only work while the factory
    # happens to be invoked before the assignment completes.
    env = DummyVectorEnv([lambda: wrapped_env])

    # step 5: construct the Collector, which interfaces the policies with the vectorised environment
    collector = Collector(policies, env)

    # step 6: execute the environment with the agents playing for 1 episode, and render a frame every 0.1 seconds
    result = collector.collect(n_episode=1, render=0.1)

2023-10-25 13:35:05.941191: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


: 

## Tianshou: Training Agents
This section uses Tianshou to train a Deep Q-Network (DQN) agent to play against a random-policy agent in the Tic-Tac-Toe environment.

In [5]:
import os
from typing import Optional, Tuple

import gymnasium
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Net

from pettingzoo.classic import tictactoe_v3


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the multi-agent policy for a fresh Tic-Tac-Toe environment.

    Args:
        agent_learn: policy placed in the second agent slot. When ``None``, a
            ``DQNPolicy`` backed by a four-layer (128 units each) ``Net`` is created.
        agent_opponent: policy placed in the first agent slot. When ``None``, a
            ``RandomPolicy`` is used.
        optim: optimizer handed to the default ``DQNPolicy``. When ``None`` (and
            ``agent_learn`` is also ``None``), Adam with ``lr=1e-4`` is created.

    Returns:
        A ``(policy, optim, agents)`` tuple: the ``MultiAgentPolicyManager``, the
        optimizer, and ``env.agents``. ``optim`` is returned exactly as passed in
        (possibly ``None``) when the caller supplies ``agent_learn``.
    """
    env = _get_env()

    # Dict observation spaces carry the actual observation under the "observation" key.
    obs_space = env.observation_space
    if isinstance(obs_space, gymnasium.spaces.Dict):
        obs_space = obs_space["observation"]

    if agent_learn is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # model
        q_net = Net(
            state_shape=obs_space.shape or obs_space.n,
            action_shape=env.action_space.shape or env.action_space.n,
            hidden_sizes=[128, 128, 128, 128],
            device=device
        ).to(device)
        if optim is None:
            optim = torch.optim.Adam(q_net.parameters(), lr=1e-4)
        agent_learn = DQNPolicy(
            model=q_net,
            optim=optim,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320
        )

    if agent_opponent is None:
        agent_opponent = RandomPolicy()

    # Slot order matters: the opponent comes first, the learner second (index 1).
    manager = MultiAgentPolicyManager([agent_opponent, agent_learn], env)
    return manager, optim, env.agents

def _get_env():
    """Create one Tic-Tac-Toe environment wrapped for Tianshou.

    This function is passed (uncalled) to ``DummyVectorEnv``, which needs callables
    that produce environments.
    """
    tictactoe_env = tictactoe_v3.env()
    return PettingZooEnv(tictactoe_env)

if __name__ == "__main__":
    # step 1: environment setup -- ten environment factories per vectorised env
    train_envs = DummyVectorEnv([_get_env] * 10)
    test_envs = DummyVectorEnv([_get_env] * 10)

    # seed numpy, torch and both vectorised environments with the same value
    seed = 1
    np.random.seed(seed)
    torch.manual_seed(seed)
    for vector_env in (train_envs, test_envs):
        vector_env.seed(seed)

    # step 2: agent setup
    policy, optim, agents = _get_agents()

    # step 3: collector setup
    # Only the training collector is given a replay buffer (one sub-buffer per train env).
    replay_buffer = VectorReplayBuffer(20_000, len(train_envs))
    train_collector = Collector(
        policy, train_envs, replay_buffer, exploration_noise=True
    )
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # Collect an initial batch of transitions before the trainer is started.
    train_collector.collect(n_step=64 * 10)  # batch size * training_num

# step 4: callback functions setup
def save_best_fn(policy):
    """Save the second agent's (``agents[1]``) state dict to ``log/ttt/dqn/policy.pth``.

    The target directory is created if it does not exist yet. Handed to the trainer
    as its ``save_best_fn`` callback.
    """
    checkpoint_dir = os.path.join("log", "ttt", "dqn")
    os.makedirs(checkpoint_dir, exist_ok=True)
    learner_state = policy.policies[agents[1]].state_dict()
    torch.save(learner_state, os.path.join(checkpoint_dir, "policy.pth"))

def stop_fn(mean_rewards):
    """Return whether ``mean_rewards`` has reached the 0.6 stopping threshold."""
    reward_threshold = 0.6
    return mean_rewards >= reward_threshold

def train_fn(epoch, env_step):
    """Set the second agent's epsilon to 0.1; ``epoch`` and ``env_step`` are unused."""
    learner = policy.policies[agents[1]]
    learner.set_eps(0.1)

def test_fn(epoch, env_step):
    """Set the second agent's epsilon to 0.05; ``epoch`` and ``env_step`` are unused."""
    learner = policy.policies[agents[1]]
    learner.set_eps(0.05)

def reward_metric(rews):
    """Select column 1 of ``rews`` -- the slot the learning agent occupies in ``agents``."""
    learner_column = 1
    return rews[:, learner_column]

# step 5: run the trainer
# Guarded in the same way as the setup it depends on: `policy`, `train_collector`
# and `test_collector` are only created when the file runs as __main__, so an
# unguarded call here would raise NameError if the file were imported.
if __name__ == "__main__":
    result = offpolicy_trainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=50,
        step_per_epoch=1000,
        step_per_collect=50,
        episode_per_test=10,
        batch_size=64,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        update_per_step=0.1,
        test_in_train=False,
        reward_metric=reward_metric
    )

    # Print the trainer's result and remind the reader where the trained policy lives.
    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[1]])")

Epoch #1: 1001it [00:01, 863.78it/s, env_step=1000, len=7, n/ep=5, n/st=50, player_2/loss=0.272, rew=-0.20]                          

Epoch #1: test_reward: 0.400000 ± 0.916515, best_reward: 0.400000 ± 0.916515 in #1



Epoch #2: 1001it [00:01, 989.95it/s, env_step=2000, len=7, n/ep=8, n/st=50, player_2/loss=0.266, rew=-0.38]                           

Epoch #2: test_reward: 0.500000 ± 0.806226, best_reward: 0.500000 ± 0.806226 in #2



Epoch #3: 1001it [00:01, 992.77it/s, env_step=3000, len=6, n/ep=6, n/st=50, player_2/loss=0.260, rew=0.33]                           


Epoch #3: test_reward: -0.400000 ± 0.916515, best_reward: 0.500000 ± 0.806226 in #2


Epoch #4: 1001it [00:01, 988.36it/s, env_step=4000, len=6, n/ep=7, n/st=50, player_2/loss=0.256, rew=0.43]                            


Epoch #4: test_reward: 0.300000 ± 0.900000, best_reward: 0.500000 ± 0.806226 in #2


Epoch #5: 1001it [00:01, 867.87it/s, env_step=5000, len=7, n/ep=6, n/st=50, player_2/loss=0.251, rew=-0.67]                          

Epoch #5: test_reward: 0.200000 ± 0.979796, best_reward: 0.500000 ± 0.806226 in #2



Epoch #6: 1001it [00:01, 981.98it/s, env_step=6000, len=6, n/ep=4, n/st=50, player_2/loss=0.222, rew=-0.50]                          

Epoch #6: test_reward: -0.200000 ± 0.979796, best_reward: 0.500000 ± 0.806226 in #2



Epoch #7: 1001it [00:01, 991.46it/s, env_step=7000, len=6, n/ep=9, n/st=50, player_2/loss=0.241, rew=0.22]                           

Epoch #7: test_reward: 0.000000 ± 1.000000, best_reward: 0.500000 ± 0.806226 in #2



Epoch #8: 1001it [00:01, 985.12it/s, env_step=8000, len=7, n/ep=8, n/st=50, player_2/loss=0.254, rew=0.00]                           


Epoch #8: test_reward: 0.000000 ± 1.000000, best_reward: 0.500000 ± 0.806226 in #2


Epoch #9: 1001it [00:01, 659.99it/s, env_step=9000, len=7, n/ep=6, n/st=50, player_2/loss=0.247, rew=0.33]                           


Epoch #9: test_reward: 0.100000 ± 0.943398, best_reward: 0.500000 ± 0.806226 in #2


Epoch #10: 1001it [00:01, 876.82it/s, env_step=10000, len=6, n/ep=7, n/st=50, player_2/loss=0.232, rew=-0.14]                          


Epoch #10: test_reward: -0.200000 ± 0.979796, best_reward: 0.500000 ± 0.806226 in #2


Epoch #11: 1001it [00:01, 711.62it/s, env_step=11000, len=6, n/ep=9, n/st=50, player_2/loss=0.252, rew=0.11]                          


Epoch #11: test_reward: 0.400000 ± 0.916515, best_reward: 0.500000 ± 0.806226 in #2


Epoch #12: 1001it [00:01, 780.30it/s, env_step=12000, len=6, n/ep=7, n/st=50, player_2/loss=0.246, rew=0.43]                          


Epoch #12: test_reward: 0.300000 ± 0.900000, best_reward: 0.500000 ± 0.806226 in #2


Epoch #13: 1001it [00:01, 935.91it/s, env_step=13000, len=7, n/ep=4, n/st=50, player_2/loss=0.235, rew=-0.25]                          

Epoch #13: test_reward: 0.400000 ± 0.916515, best_reward: 0.500000 ± 0.806226 in #2



Epoch #14: 1001it [00:01, 888.91it/s, env_step=14000, len=7, n/ep=4, n/st=50, player_2/loss=0.252, rew=0.50]                          


Epoch #14: test_reward: -0.100000 ± 0.943398, best_reward: 0.500000 ± 0.806226 in #2


Epoch #15: 1001it [00:01, 816.05it/s, env_step=15000, len=6, n/ep=7, n/st=50, player_2/loss=0.239, rew=0.43]                           


Epoch #15: test_reward: 0.400000 ± 0.916515, best_reward: 0.500000 ± 0.806226 in #2


Epoch #16: 1001it [00:01, 823.32it/s, env_step=16000, len=6, n/ep=8, n/st=50, player_2/loss=0.240, rew=0.50]                          


Epoch #16: test_reward: 0.600000 ± 0.800000, best_reward: 0.600000 ± 0.800000 in #16

{'duration': '19.09s', 'train_time/model': '11.04s', 'test_step': 1207, 'test_episode': 170, 'test_time': '0.55s', 'test_speed': '2198.96 step/s', 'best_reward': 0.6, 'best_result': '0.60 ± 0.80', 'train_step': 16000, 'train_episode': 2259, 'train_time/collector': '7.50s', 'train_speed': '863.03 step/s'}

(the trained policy can be accessed via policy.policies[agents[1]])


## Tianshou: CLI and Logging
This section extends the code from "Training Agents" to add a CLI (using argparse) and logging (using Tianshou's Logger).

In [None]:
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple

import gymnasium
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from torch.utils.tensorboard import SummaryWriter

from pettingzoo.classic import tictactoe_v3


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.05)