In [1]:
from collections import namedtuple
from itertools import count
import time
import random
import math
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


In [17]:
class Game:
    """Five-in-a-row environment on a square board shared by two players.

    ``cells`` is a ``SIZE x SIZE`` long tensor holding 0 for an empty cell,
    1 for X and -1 for O.  ``x_turn`` tells whose move it is.
    """

    SIZE = 20        # board side length
    WIN_LENGTH = 5   # consecutive marks needed to win

    def __init__(self, device):
        """Create a board whose tensors live on ``device`` and reset it."""
        self.device = device
        self.reset()

    def reset(self):
        """Clear the board and pick the starting player at random."""
        self.cells = torch.zeros(
            (self.SIZE, self.SIZE), device=self.device, dtype=torch.long
        )
        self.x_turn = random.choice([True, False])

    def get_state(self, player_is_x=True):
        """Return the board as a ``(1, 3 * SIZE * SIZE)`` float tensor.

        The three concatenated planes are: cells owned by the requested
        player, cells owned by the opponent, and empty cells.
        """
        cells = self.cells.flatten()
        player = 2 * player_is_x - 1
        return (torch.cat((cells == player, cells == -player, cells == 0)).float().unsqueeze(0).to(self.device))

    def perform_action(self, action):
        """Place the current player's mark on flat cell index ``action``.

        Returns:
            ``(reward, done)``: ``(-10, False)`` when the cell is occupied
            (the turn does not change), ``(-5, True)`` when the move wins,
            ``(-1, True)`` when it fills the board (draw), and ``(0, False)``
            otherwise, in which case the turn passes to the other player.
        """
        row, column = divmod(action, self.SIZE)
        if self.cells[row, column] != 0:  # the cell is already taken
            return -10, False
        self.cells[row, column] = 2 * self.x_turn - 1
        result = self.eval()
        if result is not None:
            # result is 0 for a draw and +/-1 for a win by the mover.
            return (-1 if result == 0 else -5), True
        self.x_turn = not self.x_turn
        return 0, False

    def create(self):
        """Print the board, one row per line, cells separated by ``|``."""
        symbols = {0: " ", 1: "X", -1: "O"}
        # One transfer to Python lists instead of one tensor read per cell.
        for row in self.cells.tolist():
            print("|".join([symbols[cell] for cell in row]))

    def eval(self):
        """Evaluate the position for the player whose turn it currently is.

        Returns:
            The player's value (1 for X, -1 for O) if that player has
            ``WIN_LENGTH`` marks in a row, 0 if the board is full without
            such a line, and ``None`` if the game continues.
        """
        player = 2 * self.x_turn - 1
        if self._has_line(player):
            return player
        if not bool((self.cells == 0).any()):
            return 0
        return None

    def _has_line(self, player):
        """Tell whether ``player`` owns WIN_LENGTH consecutive cells in any direction."""
        mine = self.cells == player
        length = self.WIN_LENGTH
        span = self.SIZE - length + 1  # number of possible window start positions
        steps = range(length)
        # Each stack holds the window shifted by 0..length-1 cells along one
        # direction; a window is a win when all of its shifted copies are True.
        horizontal = torch.stack([mine[:, k:k + span] for k in steps]).all(dim=0)
        vertical = torch.stack([mine[k:k + span, :] for k in steps]).all(dim=0)
        diagonal = torch.stack([mine[k:k + span, k:k + span] for k in steps]).all(dim=0)
        anti_diagonal = torch.stack(
            [mine[k:k + span, length - 1 - k:length - 1 - k + span] for k in steps]
        ).all(dim=0)
        return bool(
            horizontal.any() | vertical.any() | diagonal.any() | anti_diagonal.any()
        )



class DNetwork(nn.Module):
    """Fully connected network mapping 1200 input features to 400 outputs.

    Three hidden linear layers (1024, 512 and 256 units) are each followed by
    a ReLU; the final linear layer is returned without an activation.
    """

    def __init__(self):
        super().__init__()
        # The attribute names double as state_dict keys, so they stay fixed.
        self.fc1 = nn.Linear(1200, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.output = nn.Linear(256, 400)

    def forward(self, x):
        """Run ``x`` through the hidden stack and return the raw output layer."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.output(hidden)


# One replay transition; every field is expected to support ``.to(device)``.
Experience = namedtuple("Experience", ("state", "action", "next_state", "reward"))


class Buffer:
    """Fixed-capacity replay memory that overwrites its oldest slot when full."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.index = 0  # slot that the next push writes to

    def push(self, *args):
        """Store one transition built from ``args`` after moving them to DEVICE."""
        moved = [item.to(DEVICE) for item in args]
        slot = self.index
        if len(self.memory) < self.capacity:
            # Still growing: reserve a slot before writing into it.
            self.memory.append(None)
        self.memory[slot] = Experience(*moved)
        self.index = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` random transitions, or None if too few are stored."""
        if batch_size > len(self.memory):
            return None
        picked = random.sample(self.memory, batch_size)
        return [Experience(*(field.to(DEVICE) for field in entry))
                for entry in picked]


def select_action(state, policy_net):
    """Pick an action epsilon-greedily and advance the global step counter.

    The exploration threshold decays exponentially from EPS_START towards
    EPS_END as the module-level ``steps_done`` grows.  Returns a ``(1, 1)``
    long tensor: either the index of the largest network output for ``state``
    or a uniformly random index in ``[0, 400)``.
    """
    global steps_done
    roll = random.random()
    decay = math.exp(-1.0 * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1
    if roll <= eps_threshold:
        # Explore: any of the 400 cells, legal or not.
        return torch.tensor([[random.randrange(400)]], device=DEVICE, dtype=torch.long)
    # Exploit: take the index of the highest-valued output.
    with torch.no_grad():
        _, best = policy_net(state).max(1)
        return best.view(1, 1)


def optimize_model(memory, policy_net, target_net, optimizer, losses):
    """Run one DQN update on a batch sampled from ``memory``.

    Does nothing when ``memory.sample(BATCH_SIZE)`` yields no batch.
    Otherwise the policy network's value for each taken action is regressed
    (MSE) towards ``reward + GAMMA * max(target_net(next_state))``.
    Transitions that end the game or hit an occupied cell use a fixed target
    instead: -10 illegal move, -5 loss, 10 win, -1 draw.  Gradients are
    clipped element-wise to [-1, 1] and the loss value is appended to
    ``losses``.
    """
    experiences = memory.sample(BATCH_SIZE)
    if not experiences:
        return
    batch = Experience(*zip(*experiences))

    state_batch = torch.cat(batch.state).to(DEVICE)
    action_batch = torch.cat(batch.action).to(DEVICE)
    next_states_batch = torch.cat(batch.next_state).to(DEVICE)
    reward_batch = torch.cat(batch.reward).to(DEVICE)

    # A state row is three equally sized planes: own cells, opponent cells,
    # empty cells.  A move is illegal when its target cell is not marked in
    # the *empty* plane, so that plane must be indexed at 2 * n_cells + action
    # (the previous fixed offset of 18 only matched a 3x3 board).
    n_cells = state_batch.shape[1] // 3
    empty_plane = state_batch[:, 2 * n_cells:]
    illegal_mask = empty_plane.gather(1, action_batch).squeeze(1) != 1
    loss_mask = reward_batch == -5
    win_mask = reward_batch == 10
    draw_mask = reward_batch == -1

    state_action_values = policy_net(state_batch).gather(1, action_batch)
    next_state_values = target_net(next_states_batch).max(1)[0].detach()
    expected_state_action_values = next_state_values * GAMMA + reward_batch
    # These outcomes get a fixed target without bootstrapping from next_state.
    expected_state_action_values[illegal_mask] = -10.0
    expected_state_action_values[loss_mask] = -5.0
    expected_state_action_values[win_mask] = 10.0
    expected_state_action_values[draw_mask] = -1.0

    loss = F.mse_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-1, 1] before the optimiser step.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)
    optimizer.step()
    losses.append(loss.item())


def train():
    """Train the policy network by self-play until interrupted.

    One network plays both sides.  Training state (episode counter, the
    global ``steps_done`` used by ``select_action``, replay memory, network
    and optimizer) is restored from SAVE_FILE when it exists.  Every
    SAVE_FREQUENCY finished games a summary is printed and a checkpoint is
    written back to SAVE_FILE.  The loop never terminates on its own.
    """
    # select_action() reads and increments the module-level counter, so the
    # value restored from / written to the checkpoint must be that same global.
    global steps_done

    env = Game(DEVICE)
    episode = 0
    steps_done = 0
    policy_net = DNetwork().to(DEVICE)
    optimizer = optim.SGD(policy_net.parameters(), lr=0.01)
    memory = Buffer(200)

    if os.path.isfile(SAVE_FILE):
        # NOTE: the checkpoint embeds a pickled Buffer, so only load files
        # produced by this script.  map_location lets a checkpoint written on
        # another device be restored onto DEVICE.
        checkpoint = torch.load(SAVE_FILE, map_location=DEVICE)
        episode = checkpoint["episode"]
        steps_done = checkpoint["steps_done"]
        memory = checkpoint["memory"]
        policy_net.load_state_dict(checkpoint["policy_net"])
        optimizer.load_state_dict(checkpoint["optimizer"])

    target_net = DNetwork().to(DEVICE)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    results = []
    loss = []

    def learn(state, action, next_state, reward):
        """Store one transition and run a single optimisation step."""
        memory.push(state, action, next_state, torch.tensor([reward], device=DEVICE))
        optimize_model(memory, policy_net, target_net, optimizer, loss)

    start_time = time.thread_time()

    for i in count(episode):
        env.reset()
        # Per-player bookkeeping keyed by "is X": the last state each side
        # acted from (seen from its own perspective) and its last action.
        states = {True: env.get_state(), False: env.get_state(False)}
        actions = {True: None, False: None}
        done = False
        while not done:
            mover = env.x_turn
            waiter = not mover
            action = select_action(states[mover], policy_net)
            actions[mover] = action
            reward, done = env.perform_action(action.item())

            if reward == -10:
                # Illegal move: the board is unchanged and the same player
                # has to try again.
                learn(states[mover], action, states[mover], reward)
                continue

            if done:
                # 0 = draw, 1 = X won, -1 = O won.
                results.append(0 if reward == -1 else (1 if mover else -1))
                # The mover gets 10 for a win and -1 for a draw.
                learn(states[mover], action, env.get_state(mover),
                      10 if reward == -5 else -1)

            next_state = env.get_state(waiter)
            if actions[waiter] is not None:
                # The other side has already moved: its previous transition
                # is now complete (reward 0, -1 for a draw, -5 for a loss).
                learn(states[waiter], actions[waiter], next_state, reward)
            states[waiter] = next_state

        if len(results) == SAVE_FREQUENCY:
            print(i + 1, "игр сыграно")
            print(f"{time.thread_time() - start_time} секунд")
            print("X выиграли ", str(results.count(1)).zfill(3), "раз")
            print("O выиграли ", str(results.count(-1)).zfill(3), "раз")
            print("Ничья ", str(results.count(0)).zfill(3), "раз")
            print(
                "Mean loss в последних",
                SAVE_FREQUENCY,
                "играх:",
                round(torch.tensor(loss, device=DEVICE).mean().item(), 4),
            )
            print()
            checkpoint = {
                "episode": i + 1,
                "steps_done": steps_done,
                "memory": memory,
                "policy_net": policy_net.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(checkpoint, SAVE_FILE)
            results = []
            loss = []
            start_time = time.thread_time()
        if i % TARGET_UPDATE == 0:
            # Periodically sync the target network with the policy network.
            target_net.load_state_dict(policy_net.state_dict())


In [18]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8        # transitions sampled per optimisation step
GAMMA = 0.8           # discount applied to the next state's best value
EPS_START = 0.9       # initial exploration probability
EPS_END = 0.2         # exploration probability approached as steps grow
EPS_DECAY = 100       # scale (in steps) of the exponential epsilon decay
TARGET_UPDATE = 10    # target network is synced every this many episodes
SAVE_FREQUENCY = 2    # finished games between summaries / checkpoints
SAVE_FILE = "weights.pt"

train()

1004 игр сыграно
33.515625 секунд
X выиграли  001 раз
O выиграли  001 раз
Ничья  000 раз
Mean loss в последних 2 играх: 7.2277

1006 игр сыграно
27.171875 секунд
X выиграли  001 раз
O выиграли  001 раз
Ничья  000 раз
Mean loss в последних 2 играх: 7.1671

1008 игр сыграно
25.921875 секунд
X выиграли  001 раз
O выиграли  001 раз
Ничья  000 раз
Mean loss в последних 2 играх: 7.1739

1010 игр сыграно
21.875 секунд
X выиграли  000 раз
O выиграли  002 раз
Ничья  000 раз
Mean loss в последних 2 играх: 6.8761

1012 игр сыграно
20.375 секунд
X выиграли  002 раз
O выиграли  000 раз
Ничья  000 раз
Mean loss в последних 2 играх: 6.6141



KeyboardInterrupt: 

In [19]:
def load_checkpoint(policy_net, optimizer, memory):
    """Restore training state from SAVE_FILE if that file exists.

    When the file is present, ``policy_net`` and ``optimizer`` are updated in
    place and ``(episode, steps_done, memory)`` from the checkpoint is
    returned.  Otherwise nothing is modified and ``(0, 0, memory)`` is
    returned with the ``memory`` that was passed in.
    """
    if not os.path.isfile(SAVE_FILE):
        return 0, 0, memory

    saved = torch.load(SAVE_FILE)
    restored = (saved["episode"], saved["steps_done"], saved["memory"])
    policy_net.load_state_dict(saved["policy_net"])
    optimizer.load_state_dict(saved["optimizer"])
    return restored

# Demo: let the restored network play both sides and print every position.
policy_net = DNetwork().to(DEVICE)
optimizer = optim.SGD(policy_net.parameters(), lr=0.01)
memory = Buffer(200)
episode, steps_done, memory = load_checkpoint(policy_net, optimizer, memory)

game = Game(DEVICE)

while True:
    # Encode the board from the perspective of the side to move, matching
    # what train() feeds the network for X and for O.
    state = game.get_state(game.x_turn)
    action = select_action(state, policy_net)
    reward, done = game.perform_action(action.item())
    game.create()
    print("-------------------------------------------------------")

    # Stop once the game is over
    if done:
        break


 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | |X| | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
-------------------------------------------------------
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | | | | | | | | | 
 | | | | | | | | | | | |

In [21]:
class GameBoard:
    """Console five-in-a-row board where a human (X) plays the network (O).

    ``board`` is a list of rows holding ``' '``, ``'X'`` or ``'O'``.
    """

    SIZE = 20        # board side length
    WIN_LENGTH = 5   # consecutive marks needed to win

    def __init__(self):
        """Create an empty board and load the policy network from weights.pt."""
        self.board = [[' ' for _ in range(self.SIZE)] for _ in range(self.SIZE)]
        self.x_turn = True
        self.device = torch.device("cpu")
        self.policy_net = DNetwork().to(self.device)
        # map_location lets a checkpoint saved on another device load on CPU.
        checkpoint = torch.load("weights.pt", map_location=self.device)
        self.policy_net.load_state_dict(checkpoint["policy_net"])
        self.policy_net.eval()

    def display_board(self):
        """Print every row followed by a separator line."""
        for row in self.board:
            print("|".join(row))
            print("-" * 41)

    def conquer_cell(self, row, col):
        """Place the current player's mark at ``(row, col)``.

        Returns True and passes the turn when the move was made; returns
        False when the coordinates are off the board or the cell is taken.
        """
        if not (0 <= row < self.SIZE and 0 <= col < self.SIZE):
            return False
        if self.board[row][col] == ' ':
            self.board[row][col] = 'X' if self.x_turn else 'O'
            self.x_turn = not self.x_turn
            return True
        return False

    def play_o(self):
        """Let the network place an O on its highest-ranked free cell."""
        with torch.no_grad():
            scores = self.policy_net(self.get_state())[0]
        # Walk the cells from best to worst score and take the first free one.
        for action in scores.argsort(descending=True).tolist():
            row, col = divmod(action, self.SIZE)
            if self.board[row][col] == ' ':
                self.board[row][col] = 'O'
                self.x_turn = not self.x_turn
                break
        return

    def get_state(self):
        """Encode the board from O's perspective as a ``(1, 3 * cells)`` float tensor.

        The planes are, in order: O's cells, X's cells, empty cells.
        """
        values = {' ': 0, 'X': 1, 'O': -1}
        cells = [values[cell] for row in self.board for cell in row]
        player = -1
        cells = torch.tensor(cells, device=self.device)
        return torch.cat((cells == player, cells == -player, cells == 0)).float().unsqueeze(0)

    def check_game_status(self):
        """Return 1 if X has won, -1 if O has won, 0 for a draw, else None."""
        if self._has_line('X'):
            return 1
        if self._has_line('O'):
            return -1
        # Draw: no empty cell is left.
        if not any(cell == ' ' for row in self.board for cell in row):
            return 0
        return None

    def _has_line(self, mark):
        """Tell whether ``mark`` fills WIN_LENGTH consecutive cells in any direction."""
        last = self.WIN_LENGTH - 1
        # right, down, down-right, down-left
        directions = ((0, 1), (1, 0), (1, 1), (1, -1))
        for row in range(self.SIZE):
            for col in range(self.SIZE):
                for d_row, d_col in directions:
                    end_row = row + last * d_row
                    end_col = col + last * d_col
                    if not (0 <= end_row < self.SIZE and 0 <= end_col < self.SIZE):
                        continue  # the line would leave the board
                    if all(self.board[row + k * d_row][col + k * d_col] == mark
                           for k in range(self.WIN_LENGTH)):
                        return True
        return False

def main():
    """Run an interactive console game: the human plays X, the network plays O."""
    game = GameBoard()
    while True:
        game.display_board()
        status = game.check_game_status()  # evaluate the board once per turn
        if status is not None:
            print("Ничья!" if status == 0 else f"{'X' if status == 1 else 'O'} has won!")
            break
        if game.x_turn:
            try:
                row = int(input("Введите строку (0-19): "))
                col = int(input("Введите столбец (0-19): "))
                placed = game.conquer_cell(row, col)
            except (ValueError, IndexError):
                # Non-numeric or out-of-range input is a rejected move,
                # not a reason to crash the game.
                placed = False
            if not placed:
                print("Неправильно, попробуйте снова.")
        else:
            game.play_o()


if __name__ == "__main__":
    main()


 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | | | | | | | | | | | | | 
-----------------------------------------
 | | | | | | | |