# Download SimpleDistributedRL

In [None]:
# Fetch the SimpleDistributedRL repository from GitHub.
!git clone https://github.com/pocokhc/simple_distributed_rl.git
# Move the `srl` package directory out of the clone into the working directory
# (presumably so that `import srl` resolves from here - confirm).
%mv simple_distributed_rl/srl srl


In [None]:
# Print the version of the Python interpreter found on PATH.
!python -V

In [None]:
import srl
# Show the version string of the srl package that was just imported.
srl.__version__

# Create model.py

In [None]:
%%writefile model.py

import json
import os
import time
from typing import cast

import numpy as np

import srl.envs
import srl.rl
from srl.base.define import EnvAction
from srl.base.env.base import EnvRun
from srl.base.rl.base import ExtendWorker, WorkerRun
from srl.envs import connectx
from srl.rl.functions.model import ImageLayerType
from srl.runner import sequence


class MyConnectXWorker(ExtendWorker):
    """ConnectX worker that filters actions with minimax before asking the RL worker.

    On every turn except step 0, a depth-limited minimax search scores each
    action. If exactly one action has the best score it is played directly.
    Otherwise every action that does not reach the best score is registered as
    invalid and the wrapped RL worker chooses among the remaining ones.
    """

    def __init__(self, *args):
        super().__init__(*args)

        # Config of the wrapped RL worker; its epsilon is overwritten in call_policy.
        self.rl_config = cast(srl.rl.dqn.Config, self.rl_worker.worker.config)

        # Recursion depth at which the minimax search stops.
        self.max_depth = 4

    def call_on_reset(self, env: EnvRun, worker_run: WorkerRun) -> None:
        """Reset the per-episode render/search state."""
        self.is_rl = False
        self.scores = [0] * env.action_space.n
        self.minmax_time = 0
        self.minmax_count = 0

    def call_policy(self, env: EnvRun, worker_run: WorkerRun) -> EnvAction:
        """Return the action for the current turn.

        Step 0 is delegated to the RL worker with epsilon 0.5. Later steps run
        minimax first (epsilon 0.1 for the RL worker) and only fall back to
        the RL worker when several actions share the best minimax score.
        """
        if env.step_num == 0:
            # First turn of the first player: no search, more exploration.
            self.rl_config.epsilon = 0.5
            action = self.rl_worker.policy(env)
            self.is_rl = True
            return action

        # Second turn onwards.
        self.rl_config.epsilon = 0.1

        # Fetch the underlying environment.
        env_org = cast(connectx.ConnectX, env.get_original_env())

        # Run minimax on a copy so the real environment is left untouched.
        self.minmax_count = 0
        t0 = time.perf_counter()  # monotonic clock: suited to measuring durations
        self.scores = self._minmax(env_org.copy())
        self.minmax_time = time.perf_counter() - t0

        # Convert explicitly: comparing a plain list with a scalar only works
        # element-wise when NumPy's reflected __eq__ happens to kick in.
        scores = np.asarray(self.scores)
        max_score = scores.max()
        max_count = np.count_nonzero(scores == max_score)

        # A unique best action is played directly.
        if max_count == 1:
            self.is_rl = False
            return int(np.argmax(scores))

        # Prevent the RL worker from picking anything but a best-scoring action
        # by adding the others to the invalid actions.
        new_invalid_actions = [a for a in range(env.action_space.n) if self.scores[a] != max_score]
        env.add_invalid_actions(new_invalid_actions, self.player_index)

        # Let the RL worker break the tie.
        action = self.rl_worker.policy(env)
        self.is_rl = True

        return action

    # MinMax
    def _minmax(self, env: connectx.ConnectX, depth: int = 0):
        """Score every action of the current position from this worker's point of view.

        Args:
            env: Environment to search on. It is stepped and restored in place
                and is left in the state reached by the last action tried.
            depth: Current recursion depth.

        Returns:
            A list with one score per action. At the depth limit every score is
            0. Actions that are not valid keep a sentinel: -9.0 on this
            worker's turn, 9.0 on the opponent's turn, so they are never
            selected by the max/min taken one level up.
        """
        if depth == self.max_depth:
            return [0] * env.action_space.n

        self.minmax_count += 1

        # Collect the valid actions.
        invalid_actions = env.get_invalid_actions()
        valid_actions = [a for a in range(env.action_space.n) if a not in invalid_actions]

        # Snapshot the state so it can be restored before trying each action.
        env_dat = env.backup()

        if env.player_index == self.player_index:
            # Own turn: the next node belongs to the opponent, who minimises.
            sentinel, pick_reply = -9.0, np.min
        else:
            # Opponent's turn: the next node is ours, and we maximise.
            sentinel, pick_reply = 9.0, np.max

        scores = [sentinel for _ in range(env.action_space.n)]
        for a in valid_actions:
            env.restore(env_dat)

            _, r1, r2, done, _ = env.call_step(a)
            if done:
                # Terminal state: score it with this worker's own reward
                # (r1 when this worker is player 0, r2 otherwise).
                scores[a] = r1 if self.player_index == 0 else r2
            else:
                scores[a] = pick_reply(self._minmax(env, depth + 1))

        return scores

    # Visualisation
    def call_render(self, env: EnvRun, worker_run: WorkerRun) -> None:
        """Print the last minimax statistics and scores, then the RL render if it chose."""
        print(f"- MinMax count: {self.minmax_count}, {self.minmax_time:.3f}s -")

        # Size the border from the action count instead of hard-coding 7 columns.
        n = env.action_space.n
        border = "+" + "---+" * n
        print(border)
        print("|" + "".join("{:2d} |".format(int(self.scores[a])) for a in range(n)))
        print(border)
        if self.is_rl:
            self.rl_worker.render(env)


def create_config():
    """Build the runner config: ConnectX environment plus a DQN agent wrapped by MyConnectXWorker."""
    environment = srl.envs.Config("ConnectX")

    agent = srl.rl.dqn.Config()
    # Observations go through the ConnectX layer processor; action selection
    # goes through the minimax-assisted worker.
    agent.processors = [connectx.LayerProcessor()]
    agent.extend_worker = MyConnectXWorker

    return sequence.Config(environment, agent)


# Train

In [None]:
from model import create_config
from srl.runner import sequence

config = create_config()

# Both seats are left as None, i.e. self play.
# A fixed opponent can be used instead, e.g. [None, "alphabeta8"].
config.players = [None, None]

# Print the model summary before training starts.
config.model_summary()

# Train with a 10 hour timeout, file logger enabled and validation disabled.
train_kwargs = dict(
    timeout=10 * 60 * 60,
    enable_file_logger=True,
    enable_validation=False,
)
parameter, memory, history = sequence.train(config, **train_kwargs)

# Persist the trained parameters for main.py and the submission archive.
parameter.save("parameter.dat")


In [None]:
# Plot the "loss" entry of the "train" info recorded in the training history.
history.plot_info("train", "loss")

In [None]:
import numpy as np

# Evaluate 5 episodes per seating arrangement and print the rewards averaged over axis 0.
# None marks a seat played by the trained agent; strings name built-in opponents.
matchups = (
    [None, None],
    [None, "random"],
    ["random", None],
    [None, "alphabeta7"],
    ["alphabeta7", None],
)
for players in matchups:
    config.players = players
    rewards = sequence.evaluate(config, parameter, max_episodes=5)
    print(f"{np.mean(rewards, axis=0)}, {players}")


In [None]:
# vs human
#config.players = [None, "human"]
#sequence.render(config, parameter=parameter)


# Create main.py

In [None]:
%%writefile main.py

import os
from typing import cast

from model import create_config
from srl.envs import connectx

# Directory that exists only inside the Kaggle simulation sandbox.
KAGGLE_PATH = "/kaggle_simulations/agent/"

# Running locally whenever that directory is absent; parameter.dat is then
# looked up next to this script instead of in the Kaggle agent directory.
is_local = not os.path.isdir(KAGGLE_PATH)
path = os.path.join(os.path.dirname(__file__) if is_local else KAGGLE_PATH, "parameter.dat")

    
# -------- config
# Build the config and point it at the parameter file selected above.
config = create_config()
config.set_parameter_path(parameter_path=path)


# ------- agent
# Module-level objects created once and reused by every my_agent call.
env = config.make_env()
org_env = cast(connectx.ConnectX, env.get_original_env())
parameter = config.make_parameter()
worker = config.make_worker(parameter)


def my_agent(observation, configuration):
    """Agent entry point: feed the observation to the env and return the worker's action."""
    # reset has to be called at the start of each episode. In connectx the
    # episode starts at step == 0 when moving first and at step == 1 when
    # moving second.
    if observation.step in (0, 1):
        env.direct_reset(observation, configuration)
        worker.on_reset(env, org_env.player_index)

    env.direct_step(observation, configuration)
    return worker.policy(env)


if is_local:
    import time
    import numpy as np
    import kaggle_environments

    config.model_summary()

    kaggle_env = kaggle_environments.make("connectx", debug=True)

    # Play the agent in both seats against each built-in opponent.
    matchups = (
        [my_agent, "random"],
        ["random", my_agent],
        [my_agent, "negamax"],
        ["negamax", my_agent],
    )
    for players in matchups:
        # Run 10 episodes and keep the final reward of each seat.
        t0 = time.time()
        final_rewards = []
        for _ in range(10):
            last_step = kaggle_env.run(players)[-1]
            final_rewards.append([last_step[0]["reward"], last_step[1]["reward"]])

        # Report the per-seat mean and the elapsed time.
        rewards = np.mean(final_rewards, axis=0)
        print(f"rewards {rewards}, {time.time() - t0:.3f}s, {players}")


In [None]:
# Run main.py as a script; outside Kaggle this executes its local benchmark branch.
!python main.py

# Create submission.tar.gz

In [None]:
# Delete __pycache__ directories and compiled bytecode files before packaging.
!find . | grep -E "(__pycache__|\.pyc|\.pyo$)" | xargs rm -rf
# Archive the agent script, the srl package, the model module and the trained parameters.
!tar -czvf submission.tar.gz main.py srl model.py parameter.dat

In [None]:
# List the working directory to check that submission.tar.gz was produced.
!ls