**This notebook is for demonstration purposes only. The focus is on illustrating the ideas rather than code implementation details.** For pseudocode and well-organized implementations, please refer to the textbook and the AIMA Python repository:
https://github.com/aimacode/aima-python/tree/master

In [None]:
import math
import random
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import matplotlib
matplotlib.use("TkAgg")  # or "QtAgg"

import matplotlib.pyplot as plt
from matplotlib.patches import Circle, FancyArrowPatch
from matplotlib.widgets import Button

# ----------------------------
# Config
# ----------------------------
RANDOM_SEED = 1234
AUTO_STEP_SECONDS = 0.05
FIGSIZE = (13, 6)

UCT_C = math.sqrt(2.0)
# UCT_C = 3

P0_COLOR = "#b9b2cf"
P1_COLOR = "#d9b7a6"
EDGE_COLOR = "#222222"
HILITE = "#ffd54a"
TEXT_COLOR = "#111111"

# ----------------------------
# Layout / Game tree definition
# ----------------------------
POS: Dict[str, Tuple[float, float]] = {
    "R":  (0.0,  0.0),

    "L":  (-2.0, -1.2),
    "M":  (0.0,  -1.2),
    "Rt": (2.0,  -1.2),

    "LL": (-3.0, -2.4),
    "LM": (-1.5, -2.4),

    "MM": (0.0,  -2.4),
    "MR": (1.2,  -2.4),

    "LML": (-2.2, -3.6),
    "LMR": (-1.2, -3.6),

    "MML": (-0.4, -3.6),
    "MMR": (0.6,  -3.6),
}

OUT_POS = (2.8, -2.2)  # pointer when rollout goes outside the *tree* (i.e., into game-only states)

EDGES: List[Tuple[str, str]] = [
    ("R", "L"), ("R", "M"), ("R", "Rt"),
    ("L", "LL"), ("L", "LM"),
    ("LM", "LML"), ("LM", "LMR"),
    ("M", "MM"), ("M", "MR"),
    ("MM", "MML"), ("MM", "MMR"),
]

# Deterministic terminal outcomes (edit as you like)
# Winner is 0 or 1.
TERMINAL_WINNER: Dict[str, int] = {
    "LL":  0,
    "LML": 1,
    "LMR": 0,
    "MR":  1,
    "MML": 0,
    "MMR": 1,
    "Rt":  0,
}


def build_game_children() -> Dict[str, List[str]]:
    ch: Dict[str, List[str]] = {k: [] for k in POS.keys()}
    for u, v in EDGES:
        ch[u].append(v)
    return ch


GAME_CHILDREN = build_game_children()
TERMINALS = {n for n in POS if len(GAME_CHILDREN.get(n, [])) == 0}


def is_terminal(state: str) -> bool:
    return state in TERMINALS


def player_at_depth(depth: int) -> int:
    return depth % 2


def node_color(player: int) -> str:
    return P0_COLOR if player == 0 else P1_COLOR


def pstr(p: int) -> str:
    return "Player 0 (purple)" if p == 0 else "Player 1 (brown)"


# ----------------------------
# Tree node (MCTS)
# ----------------------------
@dataclass
class TreeNode:
    name: str
    parent: Optional[str]
    player_to_move: int

    # all possible game children
    children_all: List[str] = field(default_factory=list)
    # children not yet expanded into the search tree
    untried: List[str] = field(default_factory=list)
    # expanded children in the tree
    children_tree: List[str] = field(default_factory=list)

    visits: int = 0
    wins: int = 0  # wins for player_to_move at THIS node


# ----------------------------
# Event model (step-by-step)
# ----------------------------
@dataclass
class Event:
    phase: str  # "select" | "expand" | "simulate" | "backprop"
    tree_path: List[str]            # path in TREE (existing nodes)
    rollout_path: List[str]         # path in GAME during rollout (may include non-tree states)
    focus: Optional[str] = None
    expanded: Optional[str] = None
    winner: Optional[int] = None
    backprop_node: Optional[str] = None
    note: Optional[str] = None


# ----------------------------
# UCT
# ----------------------------
def uct_score(parent_visits: int, child: TreeNode) -> float:
    if child.visits == 0:
        return 0.0, float("inf")
    exploit = child.wins / child.visits
    explore = UCT_C * math.sqrt(math.log(max(parent_visits, 1)) / child.visits)
    return exploit,explore


def deepcopy_tree(tree: Dict[str, TreeNode]) -> Dict[str, TreeNode]:
    newt: Dict[str, TreeNode] = {}
    for k, v in tree.items():
        newt[k] = TreeNode(
            name=v.name,
            parent=v.parent,
            player_to_move=v.player_to_move,
            children_all=list(v.children_all),
            untried=list(v.untried),
            children_tree=list(v.children_tree),
            visits=v.visits,
            wins=v.wins,
        )
    return newt


def build_initial_tree(root: str = "R") -> Dict[str, TreeNode]:
    depth = int(round(-POS[root][1] / 1.2))
    node = TreeNode(
        name=root,
        parent=None,
        player_to_move=player_at_depth(depth),
        children_all=list(GAME_CHILDREN.get(root, [])),
        untried=list(GAME_CHILDREN.get(root, [])),
        children_tree=[],
        visits=0,
        wins=0,
    )
    return {root: node}


def run_one_iteration_strict(tree: Dict[str, TreeNode], root: str) -> List[Tuple[Event, Dict[str, TreeNode]]]:
    """
    Run ONE MCTS iteration, but return [(event, tree_snapshot_after_event), ...].
    This prevents "future leakage" into earlier steps.
    """
    steps: List[Tuple[Event, Dict[str, TreeNode]]] = []

    # ---------- SELECT ----------
    cur = root
    tree_path = [cur]
    best_uct = []
    
    while True:
        node = tree[cur]
        print()
        if is_terminal(cur):
            break
        if node.untried:
            # not fully expanded => stop selection
            break
        uct_string = ""
        # choose best UCT among expanded children
        best_child = None
        best_val = -1e100
        for c in node.children_tree:
            exploit,explore = uct_score(node.visits if node.visits > 0 else 1, tree[c])
            val=exploit+explore if explore != float("inf") else float("inf")
            print(c, exploit, explore, end=", ")
            uct_string += f"({c}, {exploit:.3f}, {explore:.3f}),\n"
            if val > best_val:
                best_val = val
                best_child = c

        if best_child is None:
            break

        cur = best_child
        tree_path.append(cur)
        best_uct.append(uct_string)

    ev_select = Event(
        phase="select",
        tree_path=list(tree_path),
        rollout_path=list(tree_path),  # so far rollout path equals tree path
        focus=cur,
        # note=f"Selected down the tree using UCT until not-fully-expanded or terminal. \nBest child UCB1:=(exploit,explore):={best_uct}",
        note=f"Tree path: {tree_path}\nBest child UCB1:=(c,exploit,explore):={best_uct}"
    )
    steps.append((ev_select, deepcopy_tree(tree)))

    # ---------- EXPAND (always one step; may be no-op) ----------
    expanded_node = cur
    if (not is_terminal(cur)) and tree[cur].untried:
        new_child = tree[cur].untried.pop(0)  # deterministic order; change to random if you want
        depth = int(round(-POS[new_child][1] / 1.2))
        tree[new_child] = TreeNode(
            name=new_child,
            parent=cur,
            player_to_move=player_at_depth(depth),
            children_all=list(GAME_CHILDREN.get(new_child, [])),
            untried=list(GAME_CHILDREN.get(new_child, [])),
            children_tree=[],
            visits=0,
            wins=0,
        )
        tree[cur].children_tree.append(new_child)
        expanded_node = new_child
        tree_path.append(expanded_node)

        ev_expand = Event(
            phase="expand",
            tree_path=list(tree_path),
            rollout_path=list(tree_path),
            focus=cur,
            expanded=expanded_node,
            note=f"Expanded one child: {cur} -> {expanded_node}."
        )
    else:
        ev_expand = Event(
            phase="expand",
            tree_path=list(tree_path),
            rollout_path=list(tree_path),
            focus=cur,
            expanded=None,
            note="No expansion (terminal or fully expanded)."
        )

    steps.append((ev_expand, deepcopy_tree(tree)))

    # ---------- SIMULATE / ROLLOUT ----------
    sim_state = expanded_node
    rollout_path = [sim_state]
    while not is_terminal(sim_state):
        children = GAME_CHILDREN.get(sim_state, [])
        if not children:
            break
        sim_state = random.choice(children)
        rollout_path.append(sim_state)

    winner = TERMINAL_WINNER[sim_state]

    # full rollout path includes the tree path (to expanded_node) then rollout beyond it
    rollout_full = list(tree_path) + rollout_path[1:]

    ev_sim = Event(
        phase="simulate",
        tree_path=list(tree_path),
        rollout_path=list(rollout_full),
        focus=sim_state,
        winner=winner,
        note="Rolled out with a random policy to a terminal state."
    )
    steps.append((ev_sim, deepcopy_tree(tree)))  # IMPORTANT: snapshot BEFORE backprop updates

    # ---------- BACKPROP (one node per step) ----------
    for n in reversed(tree_path):
        tn = tree[n]
        tn.visits += 1
        if winner == tn.player_to_move:
            tn.wins += 1

        ev_bp = Event(
            phase="backprop",
            tree_path=list(tree_path),
            rollout_path=list(rollout_full),
            winner=winner,
            backprop_node=n,
            focus=n,
            note="Backprop updated this node."
        )
        steps.append((ev_bp, deepcopy_tree(tree)))

    return steps


def run_mcts(num_iterations: int, seed: int) -> Tuple[List[Event], List[Dict[str, TreeNode]]]:
    random.seed(seed)
    tree = build_initial_tree("R")

    all_events: List[Event] = []
    all_states: List[Dict[str, TreeNode]] = []

    for _ in range(num_iterations):
        iter_steps = run_one_iteration_strict(tree, "R")
        for ev, snap in iter_steps:
            all_events.append(ev)
            all_states.append(snap)

    root = "R"

    by_visits = sorted(
        tree[root].children_tree,
        key=lambda c: tree[c].visits,
        reverse=True
    )

    print("\nRoot children summary:")
    for c in by_visits:
        w, n = tree[c].wins, tree[c].visits
        print(f"{c:>3} | visits={n:>4} | wins={w:>4}")

    print("\nChosen action (decision policy):", by_visits[0])

    return all_events, all_states

In [17]:
# ----------------------------
# Drawing
# ----------------------------
def compute_phase_counts(events: List[Event], upto: int) -> Dict[str, int]:
    c = {"select": 0, "expand": 0, "simulate": 0, "backprop": 0}
    upto = max(0, min(upto, len(events) - 1))
    for i in range(upto + 1):
        c[events[i].phase] += 1
    return c


def set_tree_axes(ax):
    xs = [p[0] for p in POS.values()]
    ys = [p[1] for p in POS.values()]
    pad_x, pad_y = 0.8, 0.8
    ax.set_xlim(min(xs) - pad_x, max(xs) + pad_x)
    ax.set_ylim(min(ys) - pad_y, max(ys) + pad_y)
    ax.set_aspect("equal")
    ax.axis("off")


def draw(ax_tree, ax_text, tree: Dict[str, TreeNode], event: Event, counts: Dict[str, int]):
    ax_tree.clear()
    ax_text.clear()
    set_tree_axes(ax_tree)
    ax_text.axis("off")

    in_tree = set(tree.keys())

    # edges among tree nodes
    for u, v in EDGES:
        if (u in in_tree) and (v in in_tree):
            x1, y1 = POS[u]
            x2, y2 = POS[v]
            ax_tree.plot([x1, x2], [y1, y2], color=EDGE_COLOR, linewidth=2, zorder=1)

    # highlight
    highlight_path = [n for n in event.tree_path if n in in_tree]
    highlight_node = event.focus if (event.focus in in_tree) else (highlight_path[-1] if highlight_path else None)

    # rollout outside-tree pointer: only in SIMULATE step if rollout leaves tree
    if event.phase == "simulate":
        left_tree_at = None
        last_in_tree = None
        for n in event.rollout_path:
            if n in in_tree:
                last_in_tree = n
            else:
                left_tree_at = n
                break
        if left_tree_at is not None and last_in_tree is not None:
            x1, y1 = POS[last_in_tree]
            x2, y2 = OUT_POS
            arrow = FancyArrowPatch(
                (x1, y1), (x2, y2),
                connectionstyle="arc3,rad=-0.35",
                arrowstyle="-|>",
                mutation_scale=16,
                linewidth=2,
                color="#666666",
                linestyle="--",
                zorder=1.5
            )
            ax_tree.add_patch(arrow)
            ax_tree.text(x2, y2, "rollout\n(outside tree)", ha="left", va="center",
                         fontsize=10, color="#666666")

    # nodes
    radius = 0.25
    for name in in_tree:
        x, y = POS[name]
        tn = tree[name]

        base = node_color(tn.player_to_move)
        edgec = "#000000"
        lw = 2
        if name in highlight_path:
            lw = 3
        if name == highlight_node:
            edgec = HILITE
            lw = 4

        circ = Circle((x, y), radius, facecolor=base, edgecolor=edgec, linewidth=lw, zorder=2)
        ax_tree.add_patch(circ)

        ax_tree.text(x, y + 0.32, name, ha="center", va="center",
                     fontsize=11, fontweight="bold", color=TEXT_COLOR, zorder=3)

        ax_tree.text(x, y, f"{tn.wins}/{tn.visits}", ha="center", va="center",
                     fontsize=9, color=TEXT_COLOR, zorder=3)

    # right panel
    lines: List[str] = []

    if event.phase == "select":
        lines += [
            "Step: SELECTION (UCT)",
            "Tree path: " + " → ".join(event.tree_path),
            f"Stop at: {event.focus}",
            f"{event.note}",
        ]

    elif event.phase == "expand":
        lines += ["Step: EXPANSION"]
        if event.expanded is None:
            lines += ["Expanded: (none)", event.note or ""]
        else:
            lines += [f"Expanded new node: {event.expanded}", event.note or ""]

        # IMPORTANT: do NOT show winner here (winner is None by design)

    elif event.phase == "simulate":
        lines += [
            "Step: SIMULATION / ROLLOUT",
            "Rollout path (game): " + " → ".join(event.rollout_path),
            f"Terminal: {event.focus}",
            f"Winner: {pstr(event.winner)}",
        ]

    else:
        lines += [
            "Step: BACKPROPAGATION",
            f"Winner: {pstr(event.winner)}",
            f"Updating node: {event.backprop_node}",
        ]
        if event.backprop_node in tree:
            tn = tree[event.backprop_node]
            lines += [f"Now: {tn.wins}/{tn.visits} (wins/visits for {pstr(tn.player_to_move)})"]

    lines += [
        "",
        "Function-call counters (so far):",
        f"SELECT:   {counts['select']}",
        f"EXPAND:   {counts['expand']}",
        f"ROLLOUT:  {counts['simulate']}",
        f"BACKPROP: {counts['backprop']}",
    ]

    y = 0.98
    for s in lines:
        ax_text.text(0.02, y, s, transform=ax_text.transAxes,
                     fontsize=10, va="top", color=TEXT_COLOR)
        y -= 0.075

In [18]:
# ----------------------------
# Stepper UI
# ----------------------------
class Stepper:
    def __init__(self, events: List[Event], states: List[Dict[str, TreeNode]]):
        self.events = events
        self.states = states
        self.i = 0
        self.playing = False
        self._rendering = False
        self._last_action_t = 0.0
        self._debounce_s = 0.10  # 100ms: enough to prevent double-fire

        self.fig = plt.figure(figsize=FIGSIZE)
        self.fig.subplots_adjust(left=0.04, right=0.8, top=0.90, bottom=0.16)
        gs = self.fig.add_gridspec(1, 2, width_ratios=[1.5, 1.0])
        self.ax_tree = self.fig.add_subplot(gs[0, 0])
        self.ax_text = self.fig.add_subplot(gs[0, 1])

        # Buttons on top
        self.ax_prev  = self.fig.add_axes([0.12, 0.04, 0.10, 0.07], zorder=100)
        self.ax_next  = self.fig.add_axes([0.23, 0.04, 0.10, 0.07], zorder=100)
        self.ax_play  = self.fig.add_axes([0.36, 0.04, 0.10, 0.07], zorder=100)
        self.ax_pause = self.fig.add_axes([0.47, 0.04, 0.10, 0.07], zorder=100)
        self.ax_reset = self.fig.add_axes([0.60, 0.04, 0.10, 0.07], zorder=100)

        for ax in [self.ax_prev, self.ax_next, self.ax_play, self.ax_pause, self.ax_reset]:
            ax.set_navigate(False)
            ax.patch.set_alpha(1.0)
            ax.patch.set_zorder(100)

        self.b_prev = Button(self.ax_prev, "Prev")
        self.b_next = Button(self.ax_next, "Next")
        self.b_play = Button(self.ax_play, "Play")
        self.b_pause = Button(self.ax_pause, "Pause")
        self.b_reset = Button(self.ax_reset, "Reset")

        for b in [self.b_prev, self.b_next, self.b_play, self.b_pause, self.b_reset]:
            b.ax.set_zorder(200)

        self.b_prev.on_clicked(self.prev)
        self.b_next.on_clicked(self.next)
        self.b_play.on_clicked(self.play)
        self.b_pause.on_clicked(self.pause)
        self.b_reset.on_clicked(self.reset)

        self.timer = self.fig.canvas.new_timer(interval=int(AUTO_STEP_SECONDS * 1000))
        self.timer.add_callback(self._tick)

        self.fig.canvas.mpl_connect("key_press_event", self.on_key)
        self.render()

    def clamp(self):
        self.i = max(0, min(self.i, len(self.events) - 1))

    def render(self):
        if self._rendering:
            return
        self._rendering = True
        try:
            ev = self.events[self.i]
            tree = self.states[self.i]
            counts = compute_phase_counts(self.events, self.i)
            draw(self.ax_tree, self.ax_text, tree, ev, counts)
            self.fig.suptitle(f"MCTS step {self.i+1}/{len(self.events)}", fontsize=14)
            self.fig.canvas.draw_idle()
        except Exception:
            self.pause()
            print("\n[ERROR in render]")
            raise
        finally:
            self._rendering = False

    def _tick(self):
        if not self.playing:
            return
        if self._rendering:
            return
        try:
            if self.i >= len(self.events) - 1:
                self.pause()
                return
            self.i += 1
            self.render()
        except Exception:
            self.pause()
            print("\n[ERROR in timer tick]")
            raise


    def next(self, _=None):
        import time
        now = time.time()
        if now - self._last_action_t < self._debounce_s:
            return
        self._last_action_t = now

        self.pause()  # prevents timer from racing you
        if self._rendering:
            return
        self.i += 1
        self.clamp()
        self.render()

    def prev(self, _=None):
        import time
        now = time.time()
        if now - self._last_action_t < self._debounce_s:
            return
        self._last_action_t = now

        self.pause()
        if self._rendering:
            return
        self.i -= 1
        self.clamp()
        self.render()


    def reset(self, _=None):
        self.pause()
        self.i = 0
        self.render()

    def play(self, _=None):
        if self.playing:
            return
        self.playing = True
        self.timer.stop()
        self.timer.start()

    def pause(self, _=None):
        if self.playing:
            self.playing = False
            self.timer.stop()

    def on_key(self, event):
        if event.key in ["right", "d", "n"]:
            self.next()
        elif event.key in ["left", "a", "p"]:
            self.prev()
        elif event.key == " ":
            if self.playing:
                self.pause()
            else:
                self.play()
        elif event.key in ["r"]:
            self.reset()


def main():
    events, states = run_mcts(num_iterations=100, seed=RANDOM_SEED)
    Stepper(events, states)
    plt.show()


if __name__ == "__main__":
    main()





L 1.0 1.4823038073675114, M 0.0 1.4823038073675114, Rt 0.0 1.4823038073675114, 

L 0.5 1.1774100225154747, M 0.0 1.6651092223153956, Rt 0.0 1.6651092223153956, 

L 0.6666666666666666 1.03583715336408, M 0.0 1.7941225779941015, Rt 0.0 1.7941225779941015, 

L 0.6666666666666666 1.0929347248663588, M 0.0 1.3385661990458504, Rt 0.0 1.8930184728248456, 

L 0.6666666666666666 1.1389791186424545, M 0.0 1.3949588341794583, Rt 0.0 1.3949588341794583, 
LL 1.0 1.4823038073675114, LM 0.0 1.4823038073675114, 

L 0.5 1.019666990168809, M 0.0 1.442026886600883, Rt 0.0 1.442026886600883, 
LL 1.0 1.1774100225154747, LM 0.0 1.6651092223153956, 

L 0.4 0.9374912431241628, M 0.0 1.4823038073675114, Rt 0.0 1.4823038073675114, 

L 0.4 0.9597051824376164, M 0.3333333333333333 1.2389740629499464, Rt 0.0 1.5174271293851465, 
MM 1.0 1.4823038073675114, MR 0.0 1.4823038073675114, 

L 0.4 0.9793661772388039, M 0.25 1.0949646735850365, Rt 0.0 1.5485138917033878, 

L 0.4 0.9969767599674528, M 0.25 1.11465390363