In [6]:
from stable_baselines3.common.env_checker import check_env
from gym_reversi import ReversiEnv

env = ReversiEnv(player_color='black', opponent = "random", board_size=8)
# It will check your custom environment and output additional warnings if needed
check_env(env)


AssertionError: The observation returned by the `reset()` method does not match the data type (cannot cast) of the given observation space Box(0, 1, (4, 8, 8), uint8). Expected: uint8, actual dtype: int32

In [14]:
import os
import time
import gymnasium as gym
import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack
from stable_baselines3 import PPO
from gym_reversi import ReversiEnv



class ReversiModelTrain(object):
    def __init__(self, board_size=8, check_point_timesteps=100000, n_envs=16, model_path=None,
                 opponent_model_path="random", tensorboard_log=None):
        self.board_size = board_size
        self.check_point_timesteps = check_point_timesteps
        self.n_envs = n_envs
        self.model_path = model_path
        self.opponent_model_path = opponent_model_path
        self.tensorboard_log = tensorboard_log

    def reversi_model_train_step(self, check_point_timesteps):
        if self.opponent_model_path != "random":
            opponent_model = PPO.load(self.opponent_model_path)
        else:
            opponent_model = "random"

        env = ReversiEnv(opponent=opponent_model, is_train=True, board_size=self.board_size,
                         is_finished_reward=True, verbose=0)

        vec_env = env
        if self.n_envs > 1:
            # multi-worker training (n_envs=4 => 4 environments)
            vec_env = make_vec_env(ReversiEnvWrapper, n_envs=self.n_envs, seed=None,
                                   env_kwargs={
                                       "opponent": opponent_model,
                                       "is_train": True,
                                       "board_size": self.board_size,
                                       "is_finished_reward": True,
                                       "verbose": 0},
                                )

            # vec_env = make_vectorized_env(ReversiEnvWrapper, dumm=False, n=8)

        try:
            model = PPO.load(self.model_path, env=vec_env)
        except Exception:
            print(f"load model from self.model_path: {self.model_path} error")
            model = PPO('MlpPolicy', vec_env,
                          policy_kwargs=dict(net_arch=[256, 256]),
                          learning_rate=2.5e-4,  # learning_rate=2.5e-4,
                          ent_coef=0.01,
                          n_steps=64, # n_steps=128,
                          n_epochs=4,
                          batch_size=32, # batch_size=256,
                          gamma=0.99,
                          gae_lambda=0.95,
                          clip_range=0.1,
                          vf_coef=0.5,
                          verbose=1,
                          tensorboard_log=self.tensorboard_log)

        t0 = time.time()
        # model.learn(int(2e4))
        model.learn(total_timesteps=check_point_timesteps)
        model.save(self.model_path)
        print(f"train time: {time.time()-t0}")

    def reversi_model_train(self, total_timesteps=1000000):
        n_check_point = int(np.ceil(total_timesteps/self.check_point_timesteps))
        for i in range(n_check_point):
            self.reversi_model_train_step(self.check_point_timesteps)

    def game_play(self, model_path, opponent_model_path="random", player_color='black', max_round=100):

        # opponent_model = "random"
        # opponent_model = PPO.load("models/Reversi_ppo/model4x4_50w")
        # opponent_model = PPO.load("models/Reversi_ppo/model")
        if self.opponent_model_path != "random":
            opponent_model = PPO.load(opponent_model_path)
        else:
            opponent_model = "random"

        env = ReversiEnv(opponent=opponent_model, is_train=False, board_size=self.board_size, player_color=player_color,
                         is_finished_reward=True, verbose=0)

        model = PPO.load(model_path)
        # model = PPO.load("models/Reversi_ppo/model4x4_50w")

        total_round = 0
        total_win = 0
        total_failure = 0
        total_equal = 0

        t0 = time.time()
        obs, info = env.reset()
        while total_round < max_round:
            action, _states = model.predict(obs, deterministic=False)
            obs, rewards, dones, truncated, info = env.step(action)

            #     print(f"---- round:{total_round} --------")
            #     print(f"action: {action}")
            #     env.render("human")

            if dones:
                print(f"---- round:{total_round} --------")
                #         env.render("human")
                obs, info = env.reset()
                total_round += 1
                if rewards > 0:
                    total_win += 1
                elif rewards < 0:
                    total_failure += 1
                else:
                    total_equal += 1

                print(f"total_win:{total_win}, total_failure: {total_failure}, total_equal:{total_equal}\n")

        # print(f"total_win:{total_win}, total_failure: {total_failure}")
        print(f"train time: {time.time() - t0}")


# if __name__ == '__main__':

#     board_size = 8
#     check_point_timesteps = 100000
#     n_envs = 8
#     tensorboard_log = f"models/Reversi_ppo_{board_size}x{board_size}/"
#     if not os.path.isdir(tensorboard_log):
#         os.makedirs(tensorboard_log)
#     model_path = os.path.join(tensorboard_log, "model")
#     opponent_model_path = "random"

#     train_obj = ReversiModelTrain(board_size=board_size,
#                                   check_point_timesteps=check_point_timesteps,
#                                   n_envs=n_envs,
#                                   model_path=model_path,
#                                   opponent_model_path=opponent_model_path,
#                                   tensorboard_log=tensorboard_log)

#     t0 = time.time()
#     total_timesteps = 1000000
#     train_obj.reversi_model_train(total_timesteps)
#     print(f"total train time: {time.time() - t0}")


In [15]:
# import sys
import os
# current_path = os.getcwd()
# sys.path.append(current_path)

from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import DQN, DDPG, A2C, PPO,SAC,TD3
import time
from gym_reversi import ReversiEnv

t0=time.time()

# There already exists an environment generator
# that will make and wrap atari environments correctly.
# Here we are also multi-worker training (n_envs=4 => 4 environments)
# vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0)
# env = make_atari_env("BreakoutNoFrameskip-v4", seed=0)
# vec_env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=4, 
# #                          seed=0
#                         )
# # Frame-stacking with 4 frames
# vec_env = VecFrameStack(vec_env, n_stack=4)

board_size=4
total_timesteps=10_0000
PolicyModel = PPO
# PolicyModel = TD3
n_envs = 8

greedy_rate=0
verbose = 0

tensorboard_log = f"models/Reversi_ppo_{board_size}x{board_size}_debug/"
if not os.path.isdir(tensorboard_log):
    os.makedirs(tensorboard_log)
model_path = os.path.join(tensorboard_log, "model")
opponent_model_path="random"
# opponent_model_path=os.path.join(tensorboard_log, "opponent_model")


if opponent_model_path != "random":
    opponent_model = PolicyModel.load(opponent_model_path)
else:
    opponent_model = "random"

env = ReversiEnv(opponent=opponent_model, is_train=True, board_size=board_size,
                 greedy_rate=greedy_rate, verbose=verbose)

vec_env = env
if n_envs > 1:
    # multi-worker training (n_envs=4 => 4 environments)
    vec_env = make_vec_env(ReversiEnv, n_envs=n_envs, seed=None,
                           env_kwargs={
                               "opponent": opponent_model,
                               "is_train": True,
                               "board_size": board_size,
                               "greedy_rate": greedy_rate,
                               "verbose": verbose},
                        )

    # vec_env = make_vectorized_env(ReversiEnvWrapper, dumm=False, n=8)

try:
    model = PolicyModel.load(model_path, env=vec_env)
except Exception:
    print(f"load model from self.model_path: {model_path} error")
    model = PolicyModel('MlpPolicy', vec_env,
                  policy_kwargs=dict(net_arch=[64, 64]),
                  learning_rate=2.5e-4,  # learning_rate=2.5e-4,
                  ent_coef=0.01,
                  n_steps=64, # n_steps=128,
                  n_epochs=4,
                  batch_size=32, # batch_size=256,
                  gamma=0.99,
                  gae_lambda=0.95,
                  clip_range=0.1,
                  vf_coef=0.5,
                  verbose=1,
                  tensorboard_log=tensorboard_log)

t0 = time.time()
# model.learn(int(2e4))
model.learn(total_timesteps=total_timesteps)
model.save(model_path)
print(f"train time: {time.time()-t0}")

# vec_env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1)
# vec_env = VecFrameStack(vec_env, n_stack=4)


Logging to models/Reversi_ppo_4x4_debug/PPO_2
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.18     |
|    ep_rew_mean     | -0.36    |
| time/              |          |
|    fps             | 501      |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 512      |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.11        |
|    ep_rew_mean          | -0.42       |
| time/                   |             |
|    fps                  | 366         |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 1024        |
| train/                  |             |
|    approx_kl            | 0.004466043 |
|    clip_fraction        | 0.147       |
|    clip_range           | 0.1         |
|    entropy_loss         | -0.772      |
|    explained_variance   

KeyboardInterrupt: 

In [3]:
tensorboard --logdir ./reversi/models/Reversi_ppo/PPO_7/ --port=6016

8

In [20]:
import time
import torch
from stable_baselines3 import PPO
from gym_reversi import ReversiEnv

PolicyModel = PPO

opponent_model = "random"
# opponent_model = PolicyModel.load("models/Reversi_ppo/model")

env = ReversiEnv(opponent="random", board_size=8, player_color='black', is_train=False,
                 greedy_rate=0, verbose=0)

# model = PolicyModel.load("models/Reversi_ppo/model")
# model = PolicyModel.load('models/model_240w')
model = torch.load('models/model_240w.pth')

max_round = 1000

total_round = 0
total_win = 0
total_failure = 0
total_equal = 0


t0=time.time()
obs, info = env.reset()
while total_round < max_round:
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, truncated, info = env.step(action)

#     print(f"---- round:{total_round} --------")
#     print(f"action: {action}")
#     env.render("human")

    if dones:
        print(f"\n---- round:{total_round} --------")
#         env.render("human")
        obs, info = env.reset()
        total_round += 1
        if rewards > 0:
            total_win+=1
        elif rewards < 0:
            total_failure += 1
        else:
            total_equal += 1

        print(f"total_win:{total_win}, total_failure: {total_failure}, total_equal:{total_equal}\n")

# print(f"total_win:{total_win}, total_failure: {total_failure}")
print(f"train time: {time.time()-t0}")



---- round:0 --------
total_win:0, total_failure: 1, total_equal:0


---- round:1 --------
total_win:1, total_failure: 1, total_equal:0


---- round:2 --------
total_win:1, total_failure: 2, total_equal:0


---- round:3 --------
total_win:2, total_failure: 2, total_equal:0


---- round:4 --------
total_win:3, total_failure: 2, total_equal:0


---- round:5 --------
total_win:4, total_failure: 2, total_equal:0


---- round:6 --------
total_win:5, total_failure: 2, total_equal:0


---- round:7 --------
total_win:6, total_failure: 2, total_equal:0


---- round:8 --------
total_win:6, total_failure: 3, total_equal:0


---- round:9 --------
total_win:7, total_failure: 3, total_equal:0


---- round:10 --------
total_win:7, total_failure: 4, total_equal:0


---- round:11 --------
total_win:8, total_failure: 4, total_equal:0


---- round:12 --------
total_win:9, total_failure: 4, total_equal:0


---- round:13 --------
total_win:9, total_failure: 5, total_equal:0


---- round:14 --------
total_


---- round:116 --------
total_win:60, total_failure: 51, total_equal:6


---- round:117 --------
total_win:60, total_failure: 52, total_equal:6


---- round:118 --------
total_win:60, total_failure: 53, total_equal:6


---- round:119 --------
total_win:61, total_failure: 53, total_equal:6


---- round:120 --------
total_win:62, total_failure: 53, total_equal:6


---- round:121 --------
total_win:63, total_failure: 53, total_equal:6


---- round:122 --------
total_win:63, total_failure: 54, total_equal:6


---- round:123 --------
total_win:63, total_failure: 55, total_equal:6


---- round:124 --------
total_win:64, total_failure: 55, total_equal:6


---- round:125 --------
total_win:65, total_failure: 55, total_equal:6


---- round:126 --------
total_win:65, total_failure: 56, total_equal:6


---- round:127 --------
total_win:66, total_failure: 56, total_equal:6


---- round:128 --------
total_win:67, total_failure: 56, total_equal:6


---- round:129 --------
total_win:68, total_failur


---- round:229 --------
total_win:133, total_failure: 89, total_equal:8


---- round:230 --------
total_win:133, total_failure: 90, total_equal:8


---- round:231 --------
total_win:134, total_failure: 90, total_equal:8


---- round:232 --------
total_win:134, total_failure: 91, total_equal:8


---- round:233 --------
total_win:134, total_failure: 92, total_equal:8


---- round:234 --------
total_win:134, total_failure: 93, total_equal:8


---- round:235 --------
total_win:134, total_failure: 94, total_equal:8


---- round:236 --------
total_win:135, total_failure: 94, total_equal:8


---- round:237 --------
total_win:136, total_failure: 94, total_equal:8


---- round:238 --------
total_win:137, total_failure: 94, total_equal:8


---- round:239 --------
total_win:138, total_failure: 94, total_equal:8


---- round:240 --------
total_win:139, total_failure: 94, total_equal:8


---- round:241 --------
total_win:139, total_failure: 95, total_equal:8


---- round:242 --------
total_win:139


---- round:340 --------
total_win:190, total_failure: 138, total_equal:13


---- round:341 --------
total_win:190, total_failure: 139, total_equal:13


---- round:342 --------
total_win:190, total_failure: 140, total_equal:13


---- round:343 --------
total_win:190, total_failure: 141, total_equal:13


---- round:344 --------
total_win:190, total_failure: 142, total_equal:13


---- round:345 --------
total_win:191, total_failure: 142, total_equal:13


---- round:346 --------
total_win:191, total_failure: 143, total_equal:13


---- round:347 --------
total_win:191, total_failure: 144, total_equal:13


---- round:348 --------
total_win:191, total_failure: 145, total_equal:13


---- round:349 --------
total_win:192, total_failure: 145, total_equal:13


---- round:350 --------
total_win:192, total_failure: 146, total_equal:13


---- round:351 --------
total_win:192, total_failure: 147, total_equal:13


---- round:352 --------
total_win:193, total_failure: 147, total_equal:13


---- round:

total_win:253, total_failure: 182, total_equal:14


---- round:449 --------
total_win:253, total_failure: 183, total_equal:14


---- round:450 --------
total_win:253, total_failure: 184, total_equal:14


---- round:451 --------
total_win:253, total_failure: 185, total_equal:14


---- round:452 --------
total_win:254, total_failure: 185, total_equal:14


---- round:453 --------
total_win:254, total_failure: 186, total_equal:14


---- round:454 --------
total_win:254, total_failure: 187, total_equal:14


---- round:455 --------
total_win:255, total_failure: 187, total_equal:14


---- round:456 --------
total_win:256, total_failure: 187, total_equal:14


---- round:457 --------
total_win:257, total_failure: 187, total_equal:14


---- round:458 --------
total_win:258, total_failure: 187, total_equal:14


---- round:459 --------
total_win:259, total_failure: 187, total_equal:14


---- round:460 --------
total_win:259, total_failure: 188, total_equal:14


---- round:461 --------
total_win:25


---- round:557 --------
total_win:314, total_failure: 225, total_equal:19


---- round:558 --------
total_win:315, total_failure: 225, total_equal:19


---- round:559 --------
total_win:315, total_failure: 226, total_equal:19


---- round:560 --------
total_win:316, total_failure: 226, total_equal:19


---- round:561 --------
total_win:317, total_failure: 226, total_equal:19


---- round:562 --------
total_win:318, total_failure: 226, total_equal:19


---- round:563 --------
total_win:318, total_failure: 227, total_equal:19


---- round:564 --------
total_win:319, total_failure: 227, total_equal:19


---- round:565 --------
total_win:319, total_failure: 227, total_equal:20


---- round:566 --------
total_win:319, total_failure: 228, total_equal:20


---- round:567 --------
total_win:319, total_failure: 229, total_equal:20


---- round:568 --------
total_win:320, total_failure: 229, total_equal:20


---- round:569 --------
total_win:320, total_failure: 229, total_equal:21


---- round:


---- round:667 --------
total_win:367, total_failure: 277, total_equal:24


---- round:668 --------
total_win:367, total_failure: 278, total_equal:24


---- round:669 --------
total_win:368, total_failure: 278, total_equal:24


---- round:670 --------
total_win:368, total_failure: 279, total_equal:24


---- round:671 --------
total_win:369, total_failure: 279, total_equal:24


---- round:672 --------
total_win:370, total_failure: 279, total_equal:24


---- round:673 --------
total_win:370, total_failure: 280, total_equal:24


---- round:674 --------
total_win:370, total_failure: 281, total_equal:24


---- round:675 --------
total_win:370, total_failure: 282, total_equal:24


---- round:676 --------
total_win:371, total_failure: 282, total_equal:24


---- round:677 --------
total_win:371, total_failure: 283, total_equal:24


---- round:678 --------
total_win:371, total_failure: 284, total_equal:24


---- round:679 --------
total_win:372, total_failure: 284, total_equal:24


---- round:


---- round:776 --------
total_win:418, total_failure: 328, total_equal:31


---- round:777 --------
total_win:419, total_failure: 328, total_equal:31


---- round:778 --------
total_win:419, total_failure: 328, total_equal:32


---- round:779 --------
total_win:419, total_failure: 328, total_equal:33


---- round:780 --------
total_win:420, total_failure: 328, total_equal:33


---- round:781 --------
total_win:420, total_failure: 328, total_equal:34


---- round:782 --------
total_win:421, total_failure: 328, total_equal:34


---- round:783 --------
total_win:422, total_failure: 328, total_equal:34


---- round:784 --------
total_win:423, total_failure: 328, total_equal:34


---- round:785 --------
total_win:423, total_failure: 329, total_equal:34


---- round:786 --------
total_win:423, total_failure: 330, total_equal:34


---- round:787 --------
total_win:423, total_failure: 331, total_equal:34


---- round:788 --------
total_win:423, total_failure: 332, total_equal:34


---- round:


---- round:885 --------
total_win:473, total_failure: 370, total_equal:43


---- round:886 --------
total_win:473, total_failure: 371, total_equal:43


---- round:887 --------
total_win:474, total_failure: 371, total_equal:43


---- round:888 --------
total_win:475, total_failure: 371, total_equal:43


---- round:889 --------
total_win:475, total_failure: 372, total_equal:43


---- round:890 --------
total_win:475, total_failure: 373, total_equal:43


---- round:891 --------
total_win:475, total_failure: 374, total_equal:43


---- round:892 --------
total_win:476, total_failure: 374, total_equal:43


---- round:893 --------
total_win:476, total_failure: 375, total_equal:43


---- round:894 --------
total_win:477, total_failure: 375, total_equal:43


---- round:895 --------
total_win:478, total_failure: 375, total_equal:43


---- round:896 --------
total_win:478, total_failure: 376, total_equal:43


---- round:897 --------
total_win:478, total_failure: 377, total_equal:43


---- round:


---- round:994 --------
total_win:522, total_failure: 428, total_equal:45


---- round:995 --------
total_win:523, total_failure: 428, total_equal:45


---- round:996 --------
total_win:524, total_failure: 428, total_equal:45


---- round:997 --------
total_win:525, total_failure: 428, total_equal:45


---- round:998 --------
total_win:525, total_failure: 429, total_equal:45


---- round:999 --------
total_win:525, total_failure: 430, total_equal:45

train time: 96.84397673606873


In [4]:
!python reversi_model_train.py --board_size 8 --total_timesteps 100000 --cp_timesteps 20000 --n_envs 8 --opponent_model_path random --start_index 0



load model from self.model_path: models/Reversi_ppo_8x8/model error
Using cuda device
Logging to models/Reversi_ppo_8x8/PPO_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.03     |
|    ep_rew_mean     | -1       |
| time/              |          |
|    fps             | 107      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 512      |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 1.05         |
|    ep_rew_mean          | -1           |
| time/                   |              |
|    fps                  | 149          |
|    iterations           | 2            |
|    time_elapsed         | 6            |
|    total_timesteps      | 1024         |
| train/                  |              |
|    approx_kl            | 0.0073447656 |
|    clip_fraction        | 0.333        |
|    clip_range   

Traceback (most recent call last):
  File "reversi_model_train.py", line 225, in <module>
    run_train(sys.argv[1:])
  File "reversi_model_train.py", line 220, in run_train
    print(f"end time: {time2str(time.time())}")
NameError: name 'time2str' is not defined


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.32        |
|    ep_rew_mean          | -1          |
| time/                   |             |
|    fps                  | 234         |
|    iterations           | 9           |
|    time_elapsed         | 19          |
|    total_timesteps      | 4608        |
| train/                  |             |
|    approx_kl            | 0.014122702 |
|    clip_fraction        | 0.53        |
|    clip_range           | 0.1         |
|    entropy_loss         | -3.8        |
|    explained_variance   | -27.1       |
|    learning_rate        | 0.00025     |
|    loss                 | -0.0721     |
|    n_updates            | 32          |
|    policy_gradient_loss | -0.027      |
|    value_loss           | 0.000412    |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.39  

| train/                  |             |
|    approx_kl            | 0.016008945 |
|    clip_fraction        | 0.403       |
|    clip_range           | 0.1         |
|    entropy_loss         | -2.24       |
|    explained_variance   | -1.93       |
|    learning_rate        | 0.00025     |
|    loss                 | -0.0448     |
|    n_updates            | 144         |
|    policy_gradient_loss | -0.0205     |
|    value_loss           | 0.000332    |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 2.56        |
|    ep_rew_mean          | -1          |
| time/                   |             |
|    fps                  | 277         |
|    iterations           | 38          |
|    time_elapsed         | 70          |
|    total_timesteps      | 19456       |
| train/                  |             |
|    approx_kl            | 0.011021307 |
|    clip_fraction        | 0.329 

|    learning_rate        | 0.00025     |
|    loss                 | -0.0489     |
|    n_updates            | 468         |
|    policy_gradient_loss | -0.0131     |
|    value_loss           | 0.00075     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.38        |
|    ep_rew_mean          | -1          |
| time/                   |             |
|    fps                  | 279         |
|    iterations           | 39          |
|    time_elapsed         | 71          |
|    total_timesteps      | 19968       |
| train/                  |             |
|    approx_kl            | 0.020007072 |
|    clip_fraction        | 0.277       |
|    clip_range           | 0.1         |
|    entropy_loss         | -1.01       |
|    explained_variance   | -0.572      |
|    learning_rate        | 0.00025     |
|    loss                 | -0.0403     |
|    n_updates            | 472   

|    explained_variance   | -0.278      |
|    learning_rate        | 0.00025     |
|    loss                 | -0.0276     |
|    n_updates            | 612         |
|    policy_gradient_loss | -0.00273    |
|    value_loss           | 0.000403    |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 5.01        |
|    ep_rew_mean          | -1          |
| time/                   |             |
|    fps                  | 282         |
|    iterations           | 35          |
|    time_elapsed         | 63          |
|    total_timesteps      | 17920       |
| train/                  |             |
|    approx_kl            | 0.032047465 |
|    clip_fraction        | 0.268       |
|    clip_range           | 0.1         |
|    entropy_loss         | -0.869      |
|    explained_variance   | 0.0643      |
|    learning_rate        | 0.00025     |
|    loss                 | -0.053

In [3]:
def get_possible_actions(board, player_color):
    actions = []
    d = board.shape[-1]
    opponent_color = 1 - player_color
    for pos_x in range(d):
        for pos_y in range(d):
            if board[0, pos_x, pos_y] or board[1, pos_x, pos_y]:
                continue
            for dx in [-1, 0, 1]:
                for dy in [-1, 0, 1]:
                    if dx == 0 and dy == 0:
                        continue
                    nx = pos_x + dx
                    ny = pos_y + dy
                    n = 0
                    if nx not in range(d) or ny not in range(d):
                        continue
                    while board[opponent_color, nx, ny] == 1:
                        tmp_nx = nx + dx
                        tmp_ny = ny + dy
                        if tmp_nx not in range(d) or tmp_ny not in range(d):
                            break
                        n += 1
                        nx += dx
                        ny += dy
                    if n > 0 and board[player_color, nx, ny] == 1:
                        action = pos_x * d + pos_y
                        if action not in actions:
                            actions.append(action)
    return actions

def set_possible_actions_place(board, possible_actions, channel_index=2):
    board[channel_index, :, :] = 0
    # possible_actions = ReversiEnv.get_possible_actions(board, player_color)
    possible_actions_coords = [ReversiEnv.action_to_coordinate(board, _action) for _action in possible_actions]
    for pos_x, pos_y in possible_actions_coords:
        board[channel_index, pos_x, pos_y] = 1
    return board

def get_test_observation(board_size=4, player_color=0):
    # init board setting
    N_CHANNELS = 4
    # channels： 0: 黑棋位置， 1: 白棋位置， 2: 当前可合法落子位置，3：player 颜色
    observation = np.zeros((N_CHANNELS, board_size, board_size), dtype=int)

    observation[3, :, :] = player_color

    centerL = int(board_size / 2 - 1)
    centerR = int(board_size / 2)
    # self.observation[2, :, :] = 1
    # self.observation[2, (centerL) : (centerR + 1), (centerL) : (centerR + 1)] = 0
    observation[0, centerR, centerL] = 1
    observation[0, centerL, centerR] = 1
    observation[1, centerL, centerL] = 1
    observation[1, centerR, centerR] = 1
    possible_actions = get_possible_actions(observation, player_color)

    # 设置主玩家合法位置
    set_possible_actions_place(observation, possible_actions)

    return observation


def action_to_coordinate(board, action):
    return action // board.shape[-1], action % board.shape[-1]


In [12]:
import torch

def sb3_model_to_pth_model(PolicyModel, model_path):
    ppo_model = PolicyModel.load(model_path)
    ## 保存pth模型
    torch.save(ppo_model.policy, model_path + '.pth')

In [13]:
PolicyModel = PPO
model_path = "models/model_240w"

sb3_model_to_pth_model(PolicyModel, model_path)

In [5]:
import torch

ppo_model = PPO.load("models/Reversi_ppo_4x4/model_100w")

## 保存模型
torch.save(ppo_model.policy, 'models/Reversi_ppo_4x4/model_100w.pth')

## 读取模型
pth_model = torch.load('models/Reversi_ppo_4x4/model_100w.pth')

observation = get_test_observation(board_size=4, player_color=0)
observation

action, _states = pth_model.predict(observation, deterministic=True)
action_to_coordinate(observation, action)

possible_actions = get_possible_actions(observation, player_color=0)
possible_actions


C=array([[5.43019184, 6.40250697, 6.1004084 , 5.13867582],
       [6.29431658, 5.80315648, 5.16903573, 4.87083518],
       [4.42687125, 4.45198253, 5.73728282, 6.59846583]])


0