<a href="https://colab.research.google.com/github/rajanaids-hub/Reinforcement_Learning_Lab/blob/main/MCTS_Tic_Tac_Toe_Exp9_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import random

# =====================================================================
# 🎓 STUDENT EXPERIMENT SETTINGS: MCTS 🎓
# =====================================================================
# How many future games should the AI simulate before making a move?
# Try lowering this to 10 to make the AI play badly, or raise it (e.g. 1000)
# for strong, near-perfect play!
# NOTE: mcts_search() captures this value as its default argument when the
# function is defined, so re-run the whole cell after changing it.
MCTS_ITERATIONS = 1000

# UCB1 exploration constant (the textbook choice is sqrt(2) ~= 1.414).
# Higher = explores rarely-visited ("weird") moves more; lower = exploits the
# current best-looking moves more.
C_PARAM = 1.41
# =====================================================================

# 1. The Game Engine (Tic-Tac-Toe)
class TicTacToe:
    """A simple Tic-Tac-Toe engine.

    The board is a flat list of 9 cells, indexed row by row:

        0 | 1 | 2
        3 | 4 | 5
        6 | 7 | 8

    Each cell holds 1 (X), -1 (O) or 0 (empty). ``play_move`` never mutates
    the current object; it returns a new TicTacToe for the next position.
    """

    # All 8 winning lines. Built once at class level instead of on every
    # get_winner() call, which runs many times per MCTS rollout.
    WIN_CONDITIONS = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # Rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # Cols
        (0, 4, 8), (2, 4, 6),             # Diagonals
    )

    def __init__(self, state=None, player_to_move=1):
        """Create a position.

        Args:
            state: list of 9 ints (1 = X, -1 = O, 0 = empty). None means an
                empty board. The list is stored as-is, not copied.
            player_to_move: 1 if X moves next, -1 if O moves next.
        """
        self.state = state if state is not None else [0] * 9
        self.player_to_move = player_to_move  # 1 for X, -1 for O

    def get_legal_moves(self):
        """Return the indices (0-8) of empty cells, or [] once someone has won."""
        if self.get_winner() != 0:
            return []  # Game over: no further moves are legal
        return [i for i, val in enumerate(self.state) if val == 0]

    def play_move(self, move):
        """Return a NEW position with the side to move placed at ``move``.

        The move is not validated; callers are expected to pass an index
        taken from get_legal_moves().
        """
        new_state = list(self.state)
        new_state[move] = self.player_to_move
        return TicTacToe(new_state, -self.player_to_move)

    def get_winner(self):
        """Return 1 if X has three in a row, -1 if O does, 0 otherwise."""
        for a, b, c in self.WIN_CONDITIONS:
            line_sum = self.state[a] + self.state[b] + self.state[c]
            if line_sum == 3:
                return 1   # X wins
            if line_sum == -3:
                return -1  # O wins
        return 0  # Draw or unfinished

    def is_terminal(self):
        """Return True when someone has won or the board is full."""
        return self.get_winner() != 0 or 0 not in self.state

    def display(self):
        """Print the board as a 3x3 ASCII grid."""
        symbols = {1: 'X', -1: 'O', 0: ' '}
        border = "-" * 13
        print(border)
        for row in range(3):
            a, b, c = (symbols[self.state[row * 3 + col]] for col in range(3))
            print(f"| {a} | {b} | {c} |")
            print(border)

# 2. The MCTS Tree Node
class MCTSNode:
    """Represents a single state in the MCTS search tree.

    Attributes:
        game_state: the position this node stands for (must provide
            get_legal_moves() and player_to_move).
        parent: the node this one was expanded from (None for the root).
        move: board index played from ``parent`` to reach this node.
        untried_moves: legal moves not yet expanded into child nodes; the
            search removes entries from this list as it expands them.
        children: expanded child nodes.
        visits: number of simulations that passed through this node.
        wins: accumulated reward credited by the search's backpropagation
            (1 per win, 0.5 per draw for the player who made ``move``).
    """

    def __init__(self, game_state, parent=None, move=None):
        self.game_state = game_state
        self.parent = parent
        self.move = move  # The move that led to this node
        self.untried_moves = game_state.get_legal_moves()
        self.children = []

        # MCTS statistics (updated during backpropagation)
        self.visits = 0
        self.wins = 0.0

    def ucb1(self, c_param=None):
        """
        Upper Confidence Bound 1 (UCB1) Formula:
        Balances EXPLOITATION (high win rate) with EXPLORATION (few visits).

            UCB1 = wins / visits + c * sqrt(ln(parent.visits) / visits)

        Args:
            c_param: exploration constant ``c``. None (the default) uses the
                module-level C_PARAM, looked up at call time.

        Returns:
            float('inf') for an unvisited node, otherwise the UCB1 score.
            Only meaningful for nodes that have a parent (i.e. not the root).
        """
        if self.visits == 0:
            return float('inf')  # Always explore unvisited nodes first

        if c_param is None:
            c_param = C_PARAM

        exploitation = self.wins / self.visits
        exploration = c_param * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration

# 3. The Monte Carlo Tree Search Algorithm
def _mcts_select(node, state):
    """PHASE 1 (Selection): descend by highest UCB1 while fully expanded and non-leaf."""
    while not node.untried_moves and node.children:
        node = max(node.children, key=lambda c: c.ucb1())
        state = state.play_move(node.move)
    return node, state


def _mcts_expand(node, state):
    """PHASE 2 (Expansion): attach one random untried move as a new child, if any."""
    if not node.untried_moves:
        return node, state  # Terminal node: nothing to expand
    move = random.choice(node.untried_moves)
    node.untried_moves.remove(move)

    state = state.play_move(move)
    child_node = MCTSNode(state, parent=node, move=move)
    node.children.append(child_node)
    return child_node, state


def _mcts_rollout(state):
    """PHASE 3 (Simulation): play uniformly random moves to the end; return the winner."""
    while not state.is_terminal():
        state = state.play_move(random.choice(state.get_legal_moves()))
    return state.get_winner()


def _mcts_backpropagate(node, winner):
    """PHASE 4 (Backpropagation): update visit/win counts from ``node`` up to the root."""
    while node is not None:
        node.visits += 1

        # The node's parent made the move that led to this node, so credit the
        # result to the player who just moved (the opposite of player_to_move).
        player_who_just_moved = -node.game_state.player_to_move

        if winner == player_who_just_moved:
            node.wins += 1.0  # Win
        elif winner == 0:
            node.wins += 0.5  # Draw (half a win is better than a loss)

        node = node.parent


def mcts_search(root_state, iterations=MCTS_ITERATIONS):
    """Run MCTS from ``root_state`` and return the chosen move (a board index).

    Each iteration performs Selection -> Expansion -> Simulation ->
    Backpropagation. The returned move is the root child with the MOST VISITS
    (high visits means it consistently scored well under UCB1).

    Args:
        root_state: the current position; it must have at least one legal move.
        iterations: number of simulations to run (must be >= 1).

    Raises:
        ValueError: if ``root_state`` is terminal or ``iterations`` < 1, since
            no move can be chosen in either case.
    """
    root_node = MCTSNode(root_state)
    if not root_node.untried_moves:
        raise ValueError("mcts_search() needs a non-terminal state with at least one legal move")
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    for _ in range(iterations):
        node, state = _mcts_select(root_node, root_state)
        node, state = _mcts_expand(node, state)
        winner = _mcts_rollout(state)
        _mcts_backpropagate(node, winner)

    best_child = max(root_node.children, key=lambda c: c.visits)
    return best_child.move

# =====================================================================
# 🚀 MAIN EXECUTION: AI vs Random Game
# =====================================================================
if __name__ == "__main__":
    # Play one full game: the MCTS AI is X (moves first), a uniformly random
    # agent is O. The board is printed after every move.
    game = TicTacToe()
    print("🤖 MCTS AI (X) vs 🎲 Random Agent (O)")
    game.display()

    while not game.is_terminal():
        if game.player_to_move == 1:
            print(f"\n🧠 MCTS AI is thinking... (Running {MCTS_ITERATIONS} simulations)")
            # Pass the budget explicitly: mcts_search's default argument was
            # bound when the function was defined, so this keeps the printed
            # count and the actual search in sync if MCTS_ITERATIONS changes.
            move = mcts_search(game, MCTS_ITERATIONS)
            print(f"🤖 AI chooses position {move}")
        else:
            print("\n🎲 Random Agent's Turn...")
            move = random.choice(game.get_legal_moves())
            print(f"🎲 Random Agent chooses position {move}")

        game = game.play_move(move)
        game.display()

    # Game over: report the result (1 = X won, -1 = O won, 0 = draw).
    winner = game.get_winner()
    print("\n🏁 GAME OVER 🏁")
    if winner == 1:
        print("🏆 MCTS AI (X) WINS!")
    elif winner == -1:
        print("🎉 Random Agent (O) WINS! (This should almost never happen!)")
    else:
        print("🤝 It's a DRAW!")

🤖 MCTS AI (X) vs 🎲 Random Agent (O)
-------------
|   |   |   |
-------------
|   |   |   |
-------------
|   |   |   |
-------------

🧠 MCTS AI is thinking... (Running 1000 simulations)
🤖 AI chooses position 4
-------------
|   |   |   |
-------------
|   | X |   |
-------------
|   |   |   |
-------------

🎲 Random Agent's Turn...
🎲 Random Agent chooses position 8
-------------
|   |   |   |
-------------
|   | X |   |
-------------
|   |   | O |
-------------

🧠 MCTS AI is thinking... (Running 1000 simulations)
🤖 AI chooses position 5
-------------
|   |   |   |
-------------
|   | X | X |
-------------
|   |   | O |
-------------

🎲 Random Agent's Turn...
🎲 Random Agent chooses position 7
-------------
|   |   |   |
-------------
|   | X | X |
-------------
|   | O | O |
-------------

🧠 MCTS AI is thinking... (Running 1000 simulations)
🤖 AI chooses position 3
-------------
|   |   |   |
-------------
| X | X | X |
-------------
|   | O | O |
---