<a href="https://colab.research.google.com/github/zhe0/present/blob/main/gomuku_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q stable_baselines3
%pip install -q shimmy>=0.2.1
import gym
from gym import spaces
import numpy as np
from stable_baselines3 import PPO

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.1/182.1 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
class GomokuEnv(gym.Env):
    """Two-player Gomoku (N-in-a-row) board exposed through the classic gym API.

    Board cells hold 0 (empty), 1 (player 1) or 2 (player 2). Both players
    move through the same `step` call and the side to move alternates after
    every legal move. `reset` must be called before the first `step`, because
    `current_player` is None until then.
    """

    def __init__(self, size=15, win=5):
        """Create an empty board.

        Args:
            size: Side length of the square board.
            win: Number of aligned stones required to win.
        """
        super(GomokuEnv, self).__init__()
        # Define the state and action spaces.
        self.size = size
        self.win = win
        self.board = np.zeros((self.size, self.size), dtype=int)
        # One discrete action per cell, encoded row-major as x * size + y.
        self.action_space = spaces.Discrete(self.size * self.size)
        self.observation_space = spaces.Box(low=0, high=2, shape=(self.size, self.size), dtype=int)
        # Assigned by reset(); None means no game has been started yet.
        self.current_player = None

    def human_move(self, x, y):
        """Play the cell `(x, y)` for the side to move.

        Returns:
            The `(observation, reward, done, info)` tuple produced by `step`.

        Raises:
            ValueError: If the cell is outside the board or already occupied.
        """
        if self.is_valid_action((x, y)):
            action = x * self.size + y
            return self.step(action)
        else:
            raise ValueError("Invalid move! Please try again.")

    def model_move(self, model):
        """Return the empty cell `(x, y)` that `model` scores highest.

        Assumes `model.predict(flat_board, action_index)` returns a single
        value comparable with `>`; the move is only chosen, not played.
        NOTE(review): a stable_baselines3 model's `predict` returns an
        `(action, state)` tuple, so it does not fit this contract -- confirm
        which model type this helper is meant for.

        Returns:
            The best `(x, y)` cell, or None when the board has no empty cell.
        """
        valid_actions = self.get_valid_actions()
        best_action = None
        best_value = -float('inf')

        for action in valid_actions:
            x, y = action
            action_index = x * self.size + y
            value = model.predict(self.board.reshape(1, -1), action_index)
            if value > best_value:
                best_value = value
                best_action = action
        return best_action

    def step(self, action):
        """Place a stone for the side to move and report the outcome.

        Args:
            action: Flat cell index `x * size + y`.

        Returns:
            `(board, reward, done, info)`. `reward` is -1 when this move wins
            the game and 0 otherwise; that sign is kept because the notebook's
            game loop tests `rewards == -1` to detect a win. `done` is True on
            a win or when the board is full (a draw, reward 0). An illegal
            action (out of range or occupied cell) leaves the board and the
            side to move untouched and is flagged with
            `info["invalid_action"] = True`.
        """
        x, y = divmod(int(action), self.size)
        if not self.is_valid_action((x, y)):
            # Never overwrite an existing stone; let the same player move again.
            return self.board, 0, False, {"invalid_action": True}

        self.board[x][y] = self.current_player  # the mover's stone
        won = self.check_done((x, y))
        # With no empty cell left and no winner, the game ends in a draw.
        board_full = not (self.board == 0).any()
        self.current_player = 3 - self.current_player  # 1 <-> 2
        return self.board, -1 if won else 0, won or board_full, {}

    def get_valid_actions(self):
        """Return every empty cell as an `(x, y)` tuple, in row-major order."""
        return [(i, j) for i in range(self.size) for j in range(self.size) if self.board[i][j] == 0]

    def is_valid_action(self, action):
        """Return whether `action`, an `(x, y)` pair, is on the board and empty."""
        x, y = action
        return 0 <= x < self.size and 0 <= y < self.size and self.board[x][y] == 0

    def get_state(self):
        """Return the current board array (the live array, not a copy)."""
        return self.board

    def get_current_player(self):
        """Return the player (1 or 2) whose turn it is, or None before `reset`."""
        return self.current_player

    def reset(self):
        """Clear the board, give the move to player 1 and return the initial state."""
        self.board = np.zeros((self.size, self.size), dtype=int)
        self.current_player = 1
        return self.get_state()

    def render(self):
        """Print the board array to stdout."""
        print(self.board)

    def play2(self):
        """Place a player-2 stone on a uniformly random empty cell.

        Does nothing when the board is full (rejection sampling would spin
        forever in that case). The side to move is not changed.
        """
        empty_cells = self.get_valid_actions()
        if not empty_cells:
            return
        x, y = empty_cells[np.random.randint(len(empty_cells))]
        self.board[x][y] = 2

    def check_done(self, last_move):
        """Return True if the stone at `last_move` is part of a winning line.

        Only lines through `last_move` are inspected, along the horizontal,
        vertical and both diagonal axes. The colour is read from the board
        cell itself, so the result does not depend on `current_player`.
        """
        x0, y0 = last_move
        stone = self.board[x0][y0]
        if stone == 0:
            # An empty cell cannot complete a line.
            return False
        for dx, dy in ((1, 0), (0, 1), (1, 1), (1, -1)):
            length = (
                1
                + self._count_run(x0, y0, dx, dy, stone)
                + self._count_run(x0, y0, -dx, -dy, stone)
            )
            if length >= self.win:
                return True
        return False

    def _count_run(self, x, y, dx, dy, stone):
        """Count consecutive `stone` cells from `(x, y)` (exclusive) along `(dx, dy)`."""
        run = 0
        # At most win - 1 neighbours are needed to decide a win.
        for offset in range(1, self.win):
            nx, ny = x + dx * offset, y + dy * offset
            on_board = 0 <= nx < self.size and 0 <= ny < self.size
            if not on_board or self.board[nx][ny] != stone:
                break
            run += 1
        return run

In [None]:
# Build a PPO agent with an MLP policy on a default-sized Gomoku board.
model = PPO(policy="MlpPolicy", env=GomokuEnv(), verbose=0)

# Train the agent.
model.learn(total_timesteps=100_000_000)

# Separate 15x15 environment instance used by the evaluation cells.
env = GomokuEnv(size=15)



In [None]:
# Test the model: the trained policy plays both sides of a single game.
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
    # Stop as soon as the game is over; stepping a finished game would keep
    # placing stones on a board that already has a result.
    if dones:
        break

In [None]:

# Play one game between the trained model (player 1) and a human (player 2).
# Requires `env` (GomokuEnv) and `model` from the cells above.
obs = env.reset()
done = False
current_player = env.get_current_player()

while not done:
    # A full board with no winner is a draw; without this check the human
    # prompt would reject every move forever.
    if not env.get_valid_actions():
        print("Game Over. It's a draw.")
        break

    if current_player == 1:  # Model's turn
        action, _states = model.predict(obs)
        x, y = divmod(int(action), env.size)
        if not env.is_valid_action((x, y)):
            # The policy may pick an occupied cell; fall back to a random
            # empty one instead of overwriting a stone.
            empty_cells = env.get_valid_actions()
            x, y = empty_cells[np.random.randint(len(empty_cells))]
            action = x * env.size + y
        obs, rewards, done, info = env.step(action)
        print("Model made a move")
        env.render()  # Show the board after the model's move
    else:  # Human player's turn
        valid = False
        while not valid:
            try:
                # Replace this with your method of getting human input, e.g., from a UI
                x = int(input("Enter your move's X coordinate: "))
                y = int(input("Enter your move's Y coordinate: "))
                obs, rewards, done, info = env.human_move(x, y)
                valid = True
                print("You made a move")
                env.render()  # Show the board after the human's move
            except ValueError as e:
                # Covers both non-integer input and illegal cells.
                print(e)

    if done:
        if rewards == -1:
            # The env reports -1 for a winning move, whoever made it, so the
            # winner is the player who just moved (`current_player` has not
            # been refreshed yet).
            winner = "the model" if current_player == 1 else "the human"
            print(f"Game Over. The winner is {winner}.")
        else:
            print("Game Over. It's a draw.")
        break
    current_player = env.get_current_player()

In [None]:
%pip install arxiv

Collecting arxiv
  Downloading arxiv-2.1.0-py3-none-any.whl (11 kB)
Collecting feedparser==6.0.10 (from arxiv)
  Downloading feedparser-6.0.10-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.1/81.1 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
Collecting sgmllib3k (from feedparser==6.0.10->arxiv)
  Downloading sgmllib3k-1.0.0.tar.gz (5.8 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: sgmllib3k
  Building wheel for sgmllib3k (setup.py) ... [?25l[?25hdone
  Created wheel for sgmllib3k: filename=sgmllib3k-1.0.0-py3-none-any.whl size=6049 sha256=270e419a53167dae92b2d96b7aada3031e38d654e097f94a4ece6590c63a9cb8
  Stored in directory: /root/.cache/pip/wheels/f0/69/93/a47e9d621be168e9e33c7ce60524393c0b92ae83cf6c6e89c5
Successfully built sgmllib3k
Installing collected packages: sgmllib3k, feedparser, arxiv
Successfully installed arxiv-2.1.0 feedparser-6.0.10 sgmllib3k-1.0.0


In [None]:
import arxiv

# Default arXiv API client.
client = arxiv.Client()

# At most 10 results for the query "AI"; no sort_by is passed, so the
# library's default ordering applies.
search = arxiv.Search(query="AI", max_results=10)

# Generator over the results of the search above.
results = client.results(search)


In [None]:
# Fetch the first hit of the current search and show its title and abstract.
# (`r` was previously left over from a commented-out loop, and `r.ab` is not
# an attribute of a result.)
r = next(client.results(search))
r.title, '\n', r.summary

('Supporting AI/ML Security Workers through an Adversarial Techniques, Tools, and Common Knowledge (AI/ML ATT&CK) Framework',
 '\n',
 'This paper focuses on supporting AI/ML Security Workers -- professionals\ninvolved in the development and deployment of secure AI-enabled software\nsystems. It presents AI/ML Adversarial Techniques, Tools, and Common Knowledge\n(AI/ML ATT&CK) framework to enable AI/ML Security Workers intuitively to\nexplore offensive and defensive tactics.')

In [None]:

# `results` is a generator; you can iterate over its elements one by one...
for r in client.results(search):
  print(r.title)
# ...or exhaust it into a list. Careful: this is slow for large results sets.
all_results = list(results)
print([r.title for r in all_results])

# For advanced query syntax documentation, see the arXiv API User Manual:
# https://arxiv.org/help/api/user-manual#query_details
search = arxiv.Search(query = "au:del_maestro AND ti:checkerboard")
first_result = next(client.results(search))
print(first_result)

# Search for the paper with ID "1605.08386v1"
search_by_id = arxiv.Search(id_list=["1605.08386v1"])
# Reuse client to fetch the paper by ID (not the previous query), then print its title.
first_result = next(client.results(search_by_id))
print(first_result.title)