In [1]:
import numpy as np
import random
import torch
from backgammon import BackgammonBoard, Game
from BackModel import ResidualBlock, BackModel
from tqdm import tqdm
import copy

In [2]:
class MCTSNode:
    def __init__(self, state=Game, parent=None):
        self.last_move = None
        self.state = state
        self.legal_moves = state.get_legal_moves()
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0
        self.prior = 0.0

    def is_fully_expanded(self):
        return len(self.children) == len(self.state.get_legal_moves())

    def best_child(self, c_param=1.4):
        if not self.children:
            return None

        total_visits = sum(child.visits for child in self.children)

        scores = []
        for child in self.children:
            if child.visits > 0:
                q_value = child.wins / child.visits
            else:
                q_value = 0
            u_value = c_param * child.prior * np.sqrt(total_visits) / (1 + child.visits)
            scores.append(q_value + u_value)

        return self.children[np.argmax(scores)]

    def expand(self, model, add_dirichlet_noise=False, eps=0.25, alpha=0.3):
        model.to("cuda")
        move_probs, _ = model(self.state.get_input_matrix().to("cuda").unsqueeze(0))
        move_probs = torch.exp(move_probs).squeeze().detach().cpu().numpy()

        legal_indices = []
        for (start, end, die) in self.legal_moves:
            idx = die - 1 if start == -1 else start * 6 + (die - 1)
            legal_indices.append(idx)

        probs = move_probs[legal_indices]
        probs /= probs.sum()
    
        if add_dirichlet_noise and self.parent is None:
            noise = np.random.dirichlet([alpha] * len(probs))
            probs = (1 - eps) * probs + eps * noise

        for move, prior in zip(self.legal_moves, probs):
            new_state = copy.deepcopy(self.state)
            new_state.play_one_move(*move)
            child_node = MCTSNode(new_state, parent=self)
            child_node.last_move = move
            child_node.prior = prior
            self.children.append(child_node)
        
        model.to("cpu")

    def update(self, result):
        self.visits += 1
        self.wins += result*self.state.current_player

In [3]:
class MCTS_Searcher:
    def __init__(self, model, n_simulations=1000):
        self.model = model
        self.n_simulations = n_simulations

    def search(self, initial_state):
        root = MCTSNode(initial_state)

        for _ in range(self.n_simulations):
            node = root
            state = copy.deepcopy(initial_state)

            while node.is_fully_expanded() and node.children:
                node = node.best_child()
                state = node.state

            if not node.is_fully_expanded():
                node.expand(self.model, add_dirichlet_noise=(node.parent is None))
                node = random.choice(node.children)
                state = node.state

            result = self.simulate(state)

            while node is not None:
                node.update(result)
                node = node.parent

        best_child = root.best_child(c_param=0)
        if best_child is None:
            return root, root
        return root, best_child

    def simulate(self, state):
        current_state = copy.deepcopy(state)
        while current_state.check_game_over() == 0:
            legal_moves = current_state.get_legal_moves()
            if not legal_moves:
                break
            move_probs, value = self.model(current_state.get_input_matrix().unsqueeze(0))
            move_probs = torch.exp(move_probs).squeeze().detach().numpy()

            legal_indices = []
            for (start, end, die) in legal_moves:
                if start == -1:
                    idx = die - 1  # bar move encoding
                else:
                    idx = start * 6 + (die - 1)
                legal_indices.append(idx)

            weights = move_probs[legal_indices]
            weights /= weights.sum()

            chosen_idx = np.random.choice(len(legal_moves), p=weights)  # index in legal_moves
            start, end, die = legal_moves[chosen_idx]

            
            start, end, die = legal_moves[chosen_idx]
            current_state.play_one_move(start, end, die)
        return current_state.check_game_over()

In [4]:
model = BackModel(num_resnets=2, num_skips=2)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

Total parameters: 576401
Trainable parameters: 576401


In [None]:
#model.load_state_dict(torch.load("model.pth"))
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.999)
searcher = MCTS_Searcher(model=model, n_simulations=100)

for epoch in range(1000):
    buffer = []

    loader = tqdm(range(0, 300))
    for i in loader:
        positions = []
        backgame = Game()
        while backgame.check_game_over() == 0:
            model.eval()
            root, next_board = searcher.search(backgame)

            if root == next_board:
                next_board.state.switch_player()

            else:
                board_state = np.array(root.state.get_input_matrix())
                probas = np.zeros(24*6)
                for child in root.children:
                    if child.last_move is not None:
                        start, end, die = child.last_move
                        index = start * 6 + (die - 1)
                        probas[index] = child.visits
                probas /= probas.sum()
                positions.append((board_state, probas, None))

            backgame = next_board.state
        positions = [(s, p, backgame.check_game_over()) for (s, p, v) in positions]
        buffer.extend(positions)

    random.shuffle(buffer)

    loader = tqdm(range(0, len(buffer), 32))
    model.train()
    total_loss = 0.0
    for idx, i in enumerate(loader, start=1):
        batch = buffer[i:i+32]

        states, target_probs, target_values = zip(*batch)
        states = torch.tensor(states, dtype=torch.float32)
        target_probs = torch.tensor(target_probs, dtype=torch.float32)
        target_values = torch.tensor(target_values, dtype=torch.float32)

        pred_probs, pred_values = model(states)
        value_loss = torch.nn.functional.mse_loss(pred_values.squeeze(), target_values)
        policy_loss = -torch.mean(torch.sum(target_probs * pred_probs, dim=1))
        loss = value_loss + policy_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()


        loader.set_postfix(loss=loss.item())

    avg_loss = total_loss / idx
    scheduler.step()
    print(f"Epoch {epoch}, Average Loss: {avg_loss:.2f}")
    torch.save(model.state_dict(), "model.pth")

100%|██████████| 300/300 [39:51<00:00,  7.97s/it]
  states = torch.tensor(states, dtype=torch.float32)
100%|██████████| 1129/1129 [00:12<00:00, 90.98it/s, loss=8.78]


Epoch 0, Average Loss: 4.24


100%|██████████| 300/300 [1:06:04<00:00, 13.22s/it]
100%|██████████| 1150/1150 [00:13<00:00, 86.22it/s, loss=2.74]


Epoch 1, Average Loss: 2.96


100%|██████████| 300/300 [1:09:34<00:00, 13.92s/it]
100%|██████████| 1154/1154 [00:13<00:00, 86.70it/s, loss=2.38]


Epoch 2, Average Loss: 2.63


100%|██████████| 300/300 [1:10:56<00:00, 14.19s/it]
100%|██████████| 1156/1156 [00:14<00:00, 81.26it/s, loss=2.76]


Epoch 3, Average Loss: 2.17


100%|██████████| 300/300 [1:10:20<00:00, 14.07s/it]
100%|██████████| 1115/1115 [00:12<00:00, 91.72it/s, loss=4.02]


Epoch 4, Average Loss: 2.01


100%|██████████| 300/300 [1:10:31<00:00, 14.11s/it]
100%|██████████| 1155/1155 [00:12<00:00, 91.71it/s, loss=1.33]


Epoch 5, Average Loss: 1.92


100%|██████████| 300/300 [1:10:54<00:00, 14.18s/it]
100%|██████████| 1152/1152 [00:12<00:00, 91.58it/s, loss=3.5] 


Epoch 6, Average Loss: 1.82


100%|██████████| 300/300 [1:13:46<00:00, 14.75s/it]
100%|██████████| 1135/1135 [00:13<00:00, 85.33it/s, loss=1.62]


Epoch 7, Average Loss: 1.78


100%|██████████| 300/300 [1:16:09<00:00, 15.23s/it]
100%|██████████| 1148/1148 [00:13<00:00, 85.33it/s, loss=1.76]


Epoch 8, Average Loss: 1.83


100%|██████████| 300/300 [1:16:36<00:00, 15.32s/it]
100%|██████████| 1150/1150 [00:13<00:00, 85.52it/s, loss=1.8] 


Epoch 9, Average Loss: 1.74


100%|██████████| 300/300 [1:18:13<00:00, 15.65s/it]
100%|██████████| 1134/1134 [00:13<00:00, 85.64it/s, loss=2.16]


Epoch 10, Average Loss: 1.67


100%|██████████| 300/300 [1:18:20<00:00, 15.67s/it]
100%|██████████| 1156/1156 [00:13<00:00, 85.35it/s, loss=1.53]


Epoch 11, Average Loss: 1.72


100%|██████████| 300/300 [1:19:23<00:00, 15.88s/it]
100%|██████████| 1151/1151 [00:13<00:00, 88.21it/s, loss=1.5] 


Epoch 12, Average Loss: 1.71


100%|██████████| 300/300 [1:21:11<00:00, 16.24s/it]
100%|██████████| 1147/1147 [00:13<00:00, 88.13it/s, loss=2.19]


Epoch 13, Average Loss: 1.64


 61%|██████    | 182/300 [50:15<35:51, 18.23s/it] 

In [7]:
NewGame = Game()

NewGame.roll_dice()
while not NewGame.game_over:
    #NewGame.board.display()
    #print(f"Player 1 pieces broken and borne off: {NewGame.broken_pieces[1], NewGame.collected_pieces[1]}")
    #print(f"Player 2 pieces broken and borne off: {NewGame.broken_pieces[-1], NewGame.collected_pieces[-1]}")
    try:
        NewGame.get_legal_moves()

        #move = input(f"Player {NewGame.current_player}, enter your move \n(start-1)*6 + (die-1) \n(Current dice: {NewGame.dice}): \nPress f to quit.")
        #if move == "f":
        #    print("Game ended by user.")
        #    break
        # start, end, die = NewGame.move_translator(int(move))

        try:
            (start, end, die) = random.choice(NewGame.legal_moves)
            #print(f"Translated move: {start+1} to {end+1} using die {die}")
            NewGame.play_one_move(start, end, die)
        except IndexError:
            #print(f"No legal moves available for Player {NewGame.current_player}. Switching turn.")
            NewGame.switch_player()
            continue


    except ValueError as ve:
        print(f"Try again. {ve}")
    
    if np.abs(NewGame.board.board).sum() + sum(NewGame.broken_pieces.values()) + sum(NewGame.collected_pieces.values()) != 30:
        raise ValueError("Game state invalid. Sum of pieces not 30.")
    
    NewGame.check_game_over()

print(f"Game over! Winner is Player {NewGame.check_game_over()}.")
NewGame.board.display()
print(f"Player 1 pieces broken and borne off: {NewGame.broken_pieces[1], NewGame.collected_pieces[1]}")
print(f"Player 2 pieces broken and borne off: {NewGame.broken_pieces[-1], NewGame.collected_pieces[-1]}")

Game over! Winner is Player -1.
          Top (Points 13-24)
 0  0  0  1  0  0  1  3  0  1  0  4
-----------------------------------
 3  1  0  0  0  0  1  0  0  0  0  0
        Bottom (Points 12-1)
Player 1 pieces broken and borne off: (0, 0)
Player 2 pieces broken and borne off: (0, 15)
