In [1]:
import numpy as np
import random
import torch
from backgammon import BackgammonBoard, Game
from BackModel import ResidualBlock, BackModel
from tqdm import tqdm
import copy

In [20]:
class MCTSNode:
    def __init__(self, state=Game, parent=None):
        self.last_move = None
        self.state = state
        self.legal_moves = state.get_legal_moves()
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0
        self.prior = 0.0

    def is_fully_expanded(self):
        return len(self.children) == len(self.state.get_legal_moves())

    def best_child(self, c_param=1.4):
        if not self.children:
            return None

        total_visits = sum(child.visits for child in self.children)

        scores = []
        for child in self.children:
            if child.visits > 0:
                q_value = child.wins / child.visits
            else:
                q_value = 0
            u_value = c_param * child.prior * np.sqrt(total_visits) / (1 + child.visits)
            scores.append(q_value + u_value)

        return self.children[np.argmax(scores)]

    def expand(self, model):
        model.to("cuda")
        move_probs, _ = model(self.state.get_input_matrix().to("cuda").unsqueeze(0))
        move_probs = torch.exp(move_probs).squeeze().detach().cpu().numpy()

        legal_indices = []
        for (start, end, die) in self.legal_moves:
            idx = die - 1 if start == -1 else start * 6 + (die - 1)
            legal_indices.append(idx)

        probs = move_probs[legal_indices]
        probs /= probs.sum()

        for move, prior in zip(self.legal_moves, probs):
            new_state = copy.deepcopy(self.state)
            new_state.play_one_move(*move)
            child_node = MCTSNode(new_state, parent=self)
            child_node.last_move = move
            child_node.prior = prior
            self.children.append(child_node)
        
        model.to("cpu")

    def update(self, result):
        self.visits += 1
        self.wins += result*self.state.current_player

In [21]:
class MCTS_Searcher:
    def __init__(self, model, n_simulations=1000):
        self.model = model
        self.n_simulations = n_simulations

    def search(self, initial_state):
        root = MCTSNode(initial_state)

        for _ in range(self.n_simulations):
            node = root
            state = copy.deepcopy(initial_state)

            while node.is_fully_expanded() and node.children:
                node = node.best_child()
                state = node.state

            if not node.is_fully_expanded():
                node.expand(self.model)
                node = random.choice(node.children)
                state = node.state

            result = self.simulate(state)

            while node is not None:
                node.update(result)
                node = node.parent

        best_child = root.best_child(c_param=0)
        if best_child is None:
            return root, root
        return root, best_child

    def simulate(self, state):
        current_state = copy.deepcopy(state)
        while current_state.check_game_over() == 0:
            legal_moves = current_state.get_legal_moves()
            if not legal_moves:
                break
            move_probs, value = self.model(current_state.get_input_matrix().unsqueeze(0))
            move_probs = torch.exp(move_probs).squeeze().detach().numpy()

            legal_indices = []
            for (start, end, die) in legal_moves:
                if start == -1:
                    idx = die - 1  # bar move encoding
                else:
                    idx = start * 6 + (die - 1)
                legal_indices.append(idx)

            weights = move_probs[legal_indices]
            weights /= weights.sum()

            chosen_idx = np.random.choice(len(legal_moves), p=weights)  # index in legal_moves
            start, end, die = legal_moves[chosen_idx]

            
            start, end, die = legal_moves[chosen_idx]
            current_state.play_one_move(start, end, die)
        return current_state.check_game_over()

In [None]:
model = BackModel(num_resnets=3, num_skips=1)
model.load_state_dict(torch.load("model.pth"))
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.999)
searcher = MCTS_Searcher(model=model, n_simulations=100)

for epoch in range(1000):
    buffer = []

    loader = tqdm(range(0, 300))
    for i in loader:
        positions = []
        backgame = Game()
        while backgame.check_game_over() == 0:
            model.eval()
            root, next_board = searcher.search(backgame)

            if root == next_board:
                next_board.state.switch_player()

            else:
                board_state = np.array(root.state.get_input_matrix())
                probas = np.zeros(24*6)
                for child in root.children:
                    if child.last_move is not None:
                        start, end, die = child.last_move
                        index = start * 6 + (die - 1)
                        probas[index] = child.visits
                probas /= probas.sum()
                positions.append((board_state, probas, None))

            backgame = next_board.state
        positions = [(s, p, backgame.check_game_over()) for (s, p, v) in positions]
        buffer.extend(positions)

    random.shuffle(buffer)

    loader = tqdm(range(0, len(buffer), 32))
    model.train()
    total_loss = 0.0
    for idx, i in enumerate(loader, start=1):
        batch = buffer[i:i+32]

        states, target_probs, target_values = zip(*batch)
        states = torch.tensor(states, dtype=torch.float32)
        target_probs = torch.tensor(target_probs, dtype=torch.float32)
        target_values = torch.tensor(target_values, dtype=torch.float32)

        pred_probs, pred_values = model(states)
        value_loss = torch.nn.functional.mse_loss(pred_values.squeeze(), target_values)
        policy_loss = -torch.mean(torch.sum(target_probs * pred_probs, dim=1))
        loss = value_loss + policy_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()


        loader.set_postfix(loss=loss.item())

    avg_loss = total_loss / idx
    scheduler.step()
    print(f"Epoch {epoch}, Average Loss: {avg_loss}")
    torch.save(model.state_dict(), "model.pth")

100%|██████████| 300/300 [32:29<00:00,  6.50s/it]
100%|██████████| 1134/1134 [00:09<00:00, 125.98it/s, loss=2.64]


Epoch 0, Average Loss: 2.3650626183817627


100%|██████████| 300/300 [39:50<00:00,  7.97s/it]
100%|██████████| 1161/1161 [00:09<00:00, 127.21it/s, loss=2.54]


Epoch 1, Average Loss: 1.9989599199976826


100%|██████████| 300/300 [37:22<00:00,  7.48s/it]
100%|██████████| 1141/1141 [00:09<00:00, 124.61it/s, loss=1.57]


Epoch 2, Average Loss: 1.7860846499619414


 32%|███▏      | 97/300 [11:25<25:16,  7.47s/it]

In [7]:
NewGame = Game()

NewGame.roll_dice()
while not NewGame.game_over:
    #NewGame.board.display()
    #print(f"Player 1 pieces broken and borne off: {NewGame.broken_pieces[1], NewGame.collected_pieces[1]}")
    #print(f"Player 2 pieces broken and borne off: {NewGame.broken_pieces[-1], NewGame.collected_pieces[-1]}")
    try:
        NewGame.get_legal_moves()

        #move = input(f"Player {NewGame.current_player}, enter your move \n(start-1)*6 + (die-1) \n(Current dice: {NewGame.dice}): \nPress f to quit.")
        #if move == "f":
        #    print("Game ended by user.")
        #    break
        # start, end, die = NewGame.move_translator(int(move))

        try:
            (start, end, die) = random.choice(NewGame.legal_moves)
            #print(f"Translated move: {start+1} to {end+1} using die {die}")
            NewGame.play_one_move(start, end, die)
        except IndexError:
            #print(f"No legal moves available for Player {NewGame.current_player}. Switching turn.")
            NewGame.switch_player()
            continue


    except ValueError as ve:
        print(f"Try again. {ve}")
    
    if np.abs(NewGame.board.board).sum() + sum(NewGame.broken_pieces.values()) + sum(NewGame.collected_pieces.values()) != 30:
        raise ValueError("Game state invalid. Sum of pieces not 30.")
    
    NewGame.check_game_over()

print(f"Game over! Winner is Player {NewGame.check_game_over()}.")
NewGame.board.display()
print(f"Player 1 pieces broken and borne off: {NewGame.broken_pieces[1], NewGame.collected_pieces[1]}")
print(f"Player 2 pieces broken and borne off: {NewGame.broken_pieces[-1], NewGame.collected_pieces[-1]}")

Game over! Winner is Player -1.
          Top (Points 13-24)
 0  0  0  1  0  0  1  3  0  1  0  4
-----------------------------------
 3  1  0  0  0  0  1  0  0  0  0  0
        Bottom (Points 12-1)
Player 1 pieces broken and borne off: (0, 0)
Player 2 pieces broken and borne off: (0, 15)
