In [13]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [14]:
width, height = 3, 3
state_size = width * height
nh1, nh2 = 64, 64

In [15]:
class ValueNetwork(nn.Module):
    
    def __init__(self):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, nh1)
        self.fc2 = nn.Linear(nh1, nh2)
        self.fc3 = nn.Linear(nh2, 1)
    
    def forward(self, x):
        y = F.relu(self.fc1(x))
        y = F.relu(self.fc2(y))
        y = F.tanh(self.fc3(y))
        return y

value_network = ValueNetwork()
print(value_network)

ValueNetwork(
  (fc1): Linear(in_features=9, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=1, bias=True)
)


In [16]:
optimizer = optim.Adam(value_network.parameters(), lr = 0.01)
criterion = nn.MSELoss()

In [17]:
np1, np2 = 64, 64

In [18]:
class PolicyNetwork(nn.Module):
    
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, np1)
        self.fc2 = nn.Linear(np1, np2)
        self.fc3 = nn.Linear(np2, state_size)
    
    def forward(self, x):
        y = F.relu(self.fc1(x))
        y = F.relu(self.fc2(y))
        y = F.softmax(self.fc3(y))
        return y
    
policy_network = PolicyNetwork()
print(policy_network)

PolicyNetwork(
  (fc1): Linear(in_features=9, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=9, bias=True)
)


In [19]:
optimizer = optim.Adam(policy_network.parameters(), lr = 0.01)
criterion = nn.CrossEntropyLoss()


In [20]:
class TicTacToeState:
    def __init__(self, board=None, player=1):
        if board is None:
            board = [0] * 9
        self.board = board
        self.player = player

    def get_possible_moves(self):
        return [i for i, v in enumerate(self.board) if v == 0]

    def play_move(self, move):
        self.board[move] = self.player
        self.player = -self.player

    def play_random_move(self):
        possible_moves = self.get_possible_moves()
        move = random.choice(possible_moves)
        self.play_move(move)

    def is_game_over(self):
        return self.get_result() is not None

    def get_result(self):
        for i in range(3):
            row = i * 3
            if self.board[row] != 0 and self.board[row] == self.board[row + 1] == self.board[row + 2]:
                return self.board[row]
            if self.board[i] != 0 and self.board[i] == self.board[i + 3] == self.board[i + 6]:
                return self.board[i]

        if self.board[0] != 0 and self.board[0] == self.board[4] == self.board[8]:
            return self.board[0]
        if self.board[2] != 0 and self.board[2] == self.board[4] == self.board[6]:
            return self.board[2]

        if 0 not in self.board:
            return 0

        return None

    def copy(self):
        return TicTacToeState(self.board[:], self.player)

In [21]:
import random

class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0

    def add_child(self, child_state):
        child = Node(child_state, self)
        self.children.append(child)
        return child

    def update(self, result):
        self.visits += 1
        self.wins += result

def select_child(node):
    total_visits = sum(child.visits for child in node.children)
    log_visits = np.log(1 + total_visits)

    def uct(node):
        return node.wins / (1 + node.visits) + np.sqrt(log_visits / (node.visits + 1))

    return max(node.children, key=uct)

def expand(node):
    state = node.state.copy()
    move = random.choice(state.get_possible_moves())
    #print("move", move)
    child_state = state.copy()
    child_state.play_move(move)
    child_node = node.add_child(child_state)
    return child_node
    
def simulate(node):
    state = node.state.copy()
    while not state.is_game_over():
        state.play_random_move()
    return state.get_result()

def backpropagate(node, result):
    while node is not None:
        node.update(result)
        node = node.parent

def mcts(root, iterations):
    for i in range(iterations):
        node = root
        while len(node.children) != 0 and not node.state.is_game_over():
            node = select_child(node)
        if not node.state.is_game_over():
            child_node = expand(node)
        else:
            child_node = node
        result = simulate(child_node)
        backpropagate(child_node, result)



In [22]:
state = TicTacToeState()
root = Node(state)
mcts(root, iterations=10)

In [24]:


def get_human_move(board):
    while True:
        move = int(input('Enter your move (0-8): '))
        if move in board.get_possible_moves():
            return move
        print('Invalid move')

def play_game(root):
    state = root.state
    while not state.is_game_over():
        print(display_board(state.board))
        print()
        if state.player == 1:
            move = get_human_move(state)
        else:
            mcts(root, iterations = 1000)
            node = select_child(root)
            state_difference = [np.abs(root.state.board[i] - node.state.board[i]) for i in range(len(root.state.board))]
            move = np.argmax(state_difference)
            root = node
        state.play_move(move)

    print(display_board(state.board))
    result = state.get_result()
    if result == 1:
        print('You win!')
    elif result == -1:
        print('You lose!')
    else:
        print('Draw!')

state = TicTacToeState()
root = Node(state)
play_game(root)


   |   |  
---+---+---
   |   |  
---+---+---
   |   |  

 X |   |  
---+---+---
   |   |  
---+---+---
   |   |  

 X |   | O
---+---+---
   |   |  
---+---+---
   |   |  



ValueError: invalid literal for int() with base 10: ''