In [None]:
import sys
import os

# Register the local ``rlcard/`` directory (resolved against the current working
# directory) as an import location. It is appended, i.e. searched last, so a
# regular ``rlcard`` package found earlier on sys.path still takes precedence.
dir_path = os.path.abspath('rlcard/')
if sys.path.count(dir_path) == 0:
    sys.path.append(dir_path)

In [None]:
# Check which copy of the module was actually imported
import rlcard
print(rlcard.__file__)

## Training DQN agent

In [None]:
from rlcard.agents.dqn_agent import DQNAgent
from rlcard.agents.random_agent import  RandomAgent

In [None]:
# Training environment for UNO: step-back is switched on and the seed is fixed.
env = rlcard.make("uno", config=dict(allow_step_back=True, seed=1234))

In [None]:
# Inspect the environment: print each descriptive attribute next to its label.
for label, attr in (
    ("Number of actions:", "num_actions"),
    ("Number of players:", "num_players"),
    ("Shape of state:", "state_shape"),
    ("Shape of action:", "action_shape"),
):
    print(label, getattr(env, attr))

In [None]:
# DQN agent sized from the environment's own metadata.
agent_dqn = DQNAgent(
    # Problem dimensions; state_shape is indexed with [0] to take the first entry.
    state_shape=env.state_shape[0],
    num_actions=env.num_actions,
    # Network: six hidden layers of 64 units each.
    mlp_layers=[64] * 6,
    # Replay memory and target-network schedule.
    replay_memory_init_size=10000,
    replay_memory_size=100000,
    update_target_estimator_every=2000,
    # Checkpointing; same directory the training Logger writes to.
    save_every=10000,
    save_path="experiments/dqn_result/",
)

In [None]:
# Self-play: every seat in the training environment holds the same DQN agent object.
env.set_agents([agent_dqn] * env.num_players)

In [None]:
# Separate evaluation environment (no step-back) in which the DQN agent plays
# against a uniformly random opponent.
eval_env = rlcard.make(
    "uno",
    config={
        'seed': 1234,
    }
)
eval_env.set_agents([
    # Seat 0: the agent being trained, so index 0 of tournament() results refers to it.
    agent_dqn,
    # Seat 1: random baseline, sized from the environment it actually plays in
    # rather than from the separate training environment.
    RandomAgent(num_actions=eval_env.num_actions),
])

In [None]:
from rlcard.utils import (
    set_seed,
    tournament,
    reorganize,
    Logger,
    plot_curve
)

In [None]:
# Fix the random seed through rlcard's set_seed helper for reproducibility.
# NOTE(review): in notebook order this runs after agent_dqn has already been
# constructed, so any random initialisation performed in its constructor is not
# covered by this call — consider seeding before building the agent.
set_seed(1234)

In [None]:
# Train the DQN agent by self-play for 10000 episodes, evaluating it on
# eval_env every 100 episodes and logging the result.
with Logger("experiments/dqn_result/") as logger:
    for episode in range(10000):
        # Play one complete game in training mode.
        raw_trajectories, payoffs = env.run(is_training=True)

        # Combine the raw trajectories with the final payoffs via rlcard's
        # reorganize helper.
        trajectories = reorganize(raw_trajectories, payoffs)

        # Only the entries recorded for seat 0 are fed to the agent.
        for transition in trajectories[0]:
            agent_dqn.feed(transition)

        # Evaluation happens on every 100th episode only (including episode 0).
        if episode % 100 != 0:
            continue

        # Read the training step counter first, then run 1000 evaluation games;
        # index 0 of the tournament result is the first seat of eval_env.
        steps_so_far = env.timestep
        eval_payoffs = tournament(eval_env, 1000)
        logger.log_performance(steps_so_far, eval_payoffs[0])

    # Remember where the logger wrote its CSV and figure for the plotting cell.
    csv_path, fig_path = logger.csv_path, logger.fig_path

In [None]:
# Plot the logged evaluation results from csv_path and save the figure to fig_path.
# NOTE(review): upstream rlcard's plot_curve takes three arguments
# (csv_path, save_path, algorithm); the fourth argument here presumably is a
# title accepted by the local rlcard copy — confirm.
plot_curve(csv_path, fig_path, "DQN", "DQN vs Random")

## Training PPO agent

In [None]:
from rlcard.agents.random_agent import RandomAgent
from rlcard.agents.ppo_agent import PPOAgent

In [None]:
# Fresh UNO training environment for the PPO run: step-back on, seed fixed.
env = rlcard.make("uno", config=dict(allow_step_back=True, seed=1234))

In [None]:
# PPO agent sized from the environment's metadata instead of hard-coded
# literals, so it stays consistent with `env` (and with how the DQN agent is
# built from env.state_shape[0]).
# NOTE(review): for UNO these are expected to evaluate to (4, 4, 15) and (61,),
# the values that were previously written out by hand — confirm against the
# environment info printed earlier in the notebook.
agent_ppo = PPOAgent(
    state_shape=tuple(env.state_shape[0]),
    action_shape=(env.num_actions,),
    num_actions=env.num_actions,
)

In [None]:
# Self-play: every seat in the training environment holds the same PPO agent object.
env.set_agents([agent_ppo] * env.num_players)

In [None]:
# Separate evaluation environment (no step-back) in which the PPO agent plays
# against a uniformly random opponent.
eval_env = rlcard.make(
    "uno",
    config={
        'seed': 1234,
    }
)
eval_env.set_agents([
    # Seat 0: the agent being trained, so index 0 of tournament() results refers to it.
    agent_ppo,
    # Seat 1: random baseline, sized from the environment it actually plays in
    # rather than from the separate training environment.
    RandomAgent(num_actions=eval_env.num_actions),
])

In [None]:
from rlcard.utils import (
    set_seed,
    tournament,
    reorganize,
    Logger,
    plot_curve
)

In [None]:
# Fix the random seed through rlcard's set_seed helper for reproducibility.
# NOTE(review): in notebook order this runs after agent_ppo has already been
# constructed, so any random initialisation performed in its constructor is not
# covered by this call — consider seeding before building the agent.
set_seed(1234)

In [None]:
# Train the PPO agent by self-play for 10000 episodes, evaluating it on
# eval_env every 100 episodes and logging the result.
with Logger("experiments/ppo_result/") as logger:
    for episode in range(10000):
        # Play one complete game in training mode.
        raw_trajectories, payoffs = env.run(is_training=True)

        # Combine the raw trajectories with the final payoffs via rlcard's
        # reorganize helper.
        trajectories = reorganize(raw_trajectories, payoffs)

        # Only the entries recorded for seat 0 are fed to the agent.
        for transition in trajectories[0]:
            agent_ppo.feed(transition)

        # Evaluation happens on every 100th episode only (including episode 0).
        if episode % 100 != 0:
            continue

        # Read the training step counter first, then run 1000 evaluation games;
        # index 0 of the tournament result is the first seat of eval_env.
        steps_so_far = env.timestep
        eval_payoffs = tournament(eval_env, 1000)
        logger.log_performance(steps_so_far, eval_payoffs[0])

    # Remember where the logger wrote its CSV and figure for the plotting cell.
    csv_path, fig_path = logger.csv_path, logger.fig_path

In [None]:
# Plot the logged evaluation results from csv_path and save the figure to fig_path.
# NOTE(review): upstream rlcard's plot_curve takes three arguments
# (csv_path, save_path, algorithm); the fourth argument here presumably is a
# title accepted by the local rlcard copy — confirm.
plot_curve(csv_path, fig_path, "PPO", "PPO vs Random")