In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei','Noto Sans CJK SC','DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
from matplotlib.patches import Circle, Rectangle
import ipywidgets as widgets
from IPython.display import display, clear_output

class Node:
    """A game-tree position together with the alpha-beta state shown for it.

    Args:
        name: label used in step messages and drawn on internal nodes.
        value: leaf value returned by the search (internal nodes are built
            without one).
        children: child nodes; a new empty list is used when falsy.
    """

    def __init__(self, name, value=None, children=None):
        self.name = name
        self.value = value
        # `or` means None *and* an empty list both yield a fresh list per node.
        self.children = children or []
        # Search window as last recorded by the search; starts fully open.
        self.alpha, self.beta = float('-inf'), float('inf')
        # Value the search returned for this node, or None before it finishes.
        self.final = None
        # Display flags: cut off, already entered, and being processed right now.
        self.pruned = self.visited = self.current = False

def build_tree():
    """Build the fixed demo tree used by the visualiser.

    The 24 leaves ``L0``..``L23`` are grouped in consecutive pairs three
    times, producing the levels ``N3_*`` (12 nodes), ``N2_*`` (6) and
    ``N1_*`` (3). The three ``N1`` nodes become the children of ``Root``.

    Returns:
        The ``Root`` node.
    """
    leaf_values = [
        1, -15, 2, 19,
        18, 23, 4, 3,
        2, 1, 7, 8,
        9, 10, -2, 5,
        -1, -30, 4, 7,
        20, -1, -1, -5,
    ]
    level = [Node(f"L{idx}", value=val) for idx, val in enumerate(leaf_values)]
    for prefix in ("N3", "N2", "N1"):
        # Group k owns level[2k:2k+2]; an odd tail would form a 1-child group.
        level = [
            Node(f"{prefix}_{k}", children=level[2 * k:2 * k + 2])
            for k in range((len(level) + 1) // 2)
        ]
    return Node("Root", children=level)

def layout_tree(root, x_spacing=1.5, y_spacing=1.2):
    """Compute 2-D drawing coordinates for every node of a tree.

    Leaves are placed left to right, one ``x_spacing`` apart, in depth-first
    order. Each internal node is centred horizontally over its children.
    Depth ``d`` sits at ``y = -d * y_spacing``.

    Args:
        root: tree root; nodes must be hashable and expose ``children``.
        x_spacing: horizontal gap between consecutive leaves.
        y_spacing: vertical gap between consecutive depths.

    Returns:
        ``(pos, depth_map)``. ``pos`` maps node -> ``(x, y)`` and is filled in
        post-order, so children come before their parent. ``depth_map`` maps
        node -> depth and is filled in pre-order.
    """
    pos = {}
    depth_map = {}
    next_leaf = [0]  # mutable cell so the nested helper can advance it

    def place(node, level):
        depth_map[node] = level
        y = -level * y_spacing
        if node.children:
            for child in node.children:
                place(child, level + 1)
            child_xs = [pos[child][0] for child in node.children]
            pos[node] = (sum(child_xs) / len(child_xs), y)
        else:
            pos[node] = (next_leaf[0] * x_spacing, y)
            next_leaf[0] += 1

    place(root, 0)
    return pos, depth_map

def mark_pruned(n):
    """Set ``pruned = True`` on *n* and on every node in its subtree (in place)."""
    frontier = [n]
    while frontier:
        node = frontier.pop()
        node.pruned = True
        frontier.extend(node.children)

def snapshot(root, action):
    """Capture the current search state of a tree as one replayable step.

    Args:
        root: tree to copy.
        action: human-readable description of the step.

    Returns:
        ``{"tree": copy, "action": action}``. ``copy`` is an independent tree
        of new ``Node`` objects carrying each node's name, value, alpha, beta,
        final and display flags, so later mutation of *root* does not affect it.
    """
    def clone(src):
        dup = Node(src.name, src.value)
        for attr in ("alpha", "beta", "final", "pruned", "visited", "current"):
            setattr(dup, attr, getattr(src, attr))
        dup.children = [clone(child) for child in src.children]
        return dup

    frozen = clone(root)
    return {"tree": frozen, "action": action}

def alphabeta(n, α, β, is_max, steps, snap_root=None):
    """Run alpha-beta search on *n*, recording a snapshot after every event.

    Each node records the window it was entered with (``alpha``/``beta``).
    MAX nodes then tighten ``alpha`` and MIN nodes tighten ``beta`` as
    children return. Once ``β <= α``, every remaining sibling subtree is
    marked pruned and the loop stops.

    Args:
        n: node to evaluate.
        α: lower bound of the search window.
        β: upper bound of the search window.
        is_max: True if *n* is a MAX node (expected to be a bool; children
            receive ``not is_max``).
        steps: list that receives a ``snapshot`` dict for every visit, leaf
            return, bound update and cutoff.
        snap_root: tree copied into each snapshot. Defaults to *n* on the
            outermost call and is passed down unchanged, so the whole searched
            tree is captured without relying on a module-level global.

    Returns:
        The minimax value of *n* as seen through the window ``(α, β)``.
    """
    if snap_root is None:
        snap_root = n

    n.current = True
    n.visited = True
    n.alpha, n.beta = α, β
    steps.append(snapshot(snap_root, f"Visit {'MAX' if is_max else 'MIN'} {n.name}"))

    if not n.children:
        n.final = n.value
        n.current = False
        steps.append(snapshot(snap_root, f"Leaf {n.name} returns {n.value}"))
        return n.value

    v = -float('inf') if is_max else float('inf')
    for i, c in enumerate(n.children):
        if c.pruned:
            continue
        val = alphabeta(c, α, β, not is_max, steps, snap_root)
        if is_max:
            v = max(v, val)
            α = max(α, val)
            n.alpha = α
            steps.append(snapshot(snap_root, f"MAX {n.name} updates α→{α}"))
        else:
            v = min(v, val)
            β = min(β, val)
            n.beta = β
            steps.append(snapshot(snap_root, f"MIN {n.name} updates β→{β}"))
        if β <= α:
            # Window closed: later siblings cannot change the result.
            for sib in n.children[i + 1:]:
                mark_pruned(sib)
            # A MAX node is cut off by β (its parent MIN's bound), a MIN node by α.
            cutoff = 'β' if is_max else 'α'
            steps.append(snapshot(snap_root, f"{cutoff}-cutoff at {n.name}"))
            break

    n.final = v
    n.current = False
    return v

def reset_tree(n):
    """Clear the search state of *n* and every node below it (in place).

    Restores the open window (-inf, +inf), clears ``final``, and turns off the
    ``pruned``, ``visited`` and ``current`` flags.
    """
    pending = [n]
    while pending:
        node = pending.pop()
        node.alpha = -float('inf')
        node.beta = float('inf')
        node.final = None
        node.pruned = node.visited = node.current = False
        pending.extend(node.children)

def restore(n, snap_node):
    """Copy search state from a snapshot tree onto the live tree, node by node.

    The two trees are walked in lock-step. As with ``zip``, children beyond
    the shorter child list on either side are left untouched. Only alpha,
    beta, final and the three display flags are copied; structure and values
    of the live tree are not modified.
    """
    pairs = [(n, snap_node)]
    while pairs:
        live, saved = pairs.pop()
        live.alpha, live.beta, live.final = saved.alpha, saved.beta, saved.final
        live.pruned = saved.pruned
        live.visited = saved.visited
        live.current = saved.current
        pairs.extend(zip(live.children, saved.children))

def draw(root, pos, depth_map, snap, root_is_max):
    """Render one step of the alpha-beta trace with matplotlib.

    Edges are drawn first: red dashed into pruned nodes, black solid
    otherwise. Then every node in ``pos`` order is drawn as a square on MAX
    levels or a circle on MIN levels, coloured by state with priority
    current > visited > pruned > untouched. Visited or current nodes get an
    α/β (and value, once known) annotation underneath. The step's action
    text is shown above the axes.

    Args:
        root: tree whose edges are drawn.
        pos: mapping node -> (x, y), as produced by ``layout_tree``.
        depth_map: mapping node -> depth; even depths share the root's role.
        snap: step record; only its ``"action"`` string is used.
        root_is_max: whether depth 0 (and every even depth) is a MAX level.
    """
    fig, ax = plt.subplots(figsize=(30, 6))
    fig.subplots_adjust(left=0.03, right=0.97, top=0.85, bottom=0.12)
    ax.axis('off')

    # Pre-order edge walk with an explicit stack. Children are pushed in
    # reverse so they are popped, and drawn, in the same order as recursion.
    pending = [(root, child) for child in reversed(root.children)]
    while pending:
        parent, child = pending.pop()
        (px, py), (cx, cy) = pos[parent], pos[child]
        if child.pruned:
            edge_color, edge_style = 'r', '--'
        else:
            edge_color, edge_style = 'k', '-'
        ax.plot([px, cx], [py, cy],
                color=edge_color, linestyle=edge_style,
                alpha=0.3, linewidth=2, zorder=0)
        pending.extend((child, grandchild) for grandchild in reversed(child.children))

    for node, (x, y) in pos.items():
        even_depth = depth_map[node] % 2 == 0
        node_is_max = root_is_max if even_depth else not root_is_max

        if node.current:
            face = 'gold'
        elif node.visited:
            face = 'lightgreen'
        elif node.pruned:
            face = 'red'
        else:
            face = 'lightblue'
        touched = node.visited or node.current
        # Only pruned nodes the search never reached are faded.
        opacity = 0.3 if (node.pruned and not touched) else 0.8

        if node_is_max:
            shape = Rectangle((x - 0.25, y - 0.25), 0.5, 0.5,
                              ec='black', fc=face, alpha=opacity, zorder=1)
        else:
            shape = Circle((x, y), 0.25,
                           ec='black', fc=face, alpha=opacity, zorder=1)
        ax.add_patch(shape)

        caption = str(node.value) if not node.children else node.name
        ax.text(x, y, caption, ha='center', va='center', fontsize=10, zorder=3)

        if touched:
            alpha_txt = '-∞' if node.alpha == -float('inf') else node.alpha
            beta_txt = '+∞' if node.beta == float('inf') else node.beta
            lines = [f"α={alpha_txt}", f"β={beta_txt}"]
            if node.final is not None:
                lines.append(f"val={node.final}")
            ax.text(x, y - 0.5, "\n".join(lines),
                    ha='center', va='top', fontsize=8,
                    bbox=dict(boxstyle="round", fc="white", ec="gray", pad=0.2),
                    zorder=3)

    ax.text(0.5, 1.01, snap["action"],
            transform=ax.transAxes, ha='center',
            fontsize=14, bbox=dict(facecolor="#eef", edgecolor="black", pad=0.5),
            zorder=3)

    plt.show()

# Search configuration and trace storage; both are rebound later by
# on_toggle / generate_steps.
root_is_max = True
steps = []

# The single live tree. Its layout is computed once and reused for every
# frame, because snapshots are restored onto these same node objects.
root = build_tree()
positions, depth_map = layout_tree(root)

def generate_steps():
    """Recompute the whole alpha-beta trace for the current ``root_is_max``.

    Rebinds the module-level ``steps`` to a new list. Its first entry is a
    snapshot of the cleared tree ("initial state"), followed by every
    snapshot recorded by ``alphabeta`` on ``root``.
    """
    global steps
    steps = trace = []
    reset_tree(root)  # wipe flags left over from a previous run
    trace.append(snapshot(root, "initial state"))
    alphabeta(root, float('-inf'), float('inf'), root_is_max, trace)

# Build the initial trace before creating the slider, because the slider's
# range is derived from len(steps).
generate_steps()

# Playback widgets: a step slider, navigation buttons, a MAX/MIN toggle
# (its initial label matches root_is_max = True), and the output area the
# plot is rendered into.
slider        = widgets.IntSlider(description='Step', min=0, max=len(steps)-1, value=0)
btn_prev      = widgets.Button(description='Prev')
btn_next      = widgets.Button(description='Next')
btn_reset     = widgets.Button(description='Reset')
btn_toggle    = widgets.Button(description='Root: MAX')
out           = widgets.Output()

def update(change=None):
    """Redraw the plot for the step currently selected on the slider.

    Serves both as the ``slider.observe`` callback (the change notification
    it passes is ignored) and as a plain function call.
    """
    step = steps[slider.value]
    # Clear the live tree, then copy the chosen snapshot's state onto it, so
    # the node objects keyed in `positions` / `depth_map` reflect that step.
    reset_tree(root)
    restore(root, step["tree"])
    with out:
        clear_output(wait=True)
        draw(root, positions, depth_map, step, root_is_max)

def on_prev(_):
    """Move the slider one step back, stopping at 0."""
    slider.value = slider.value - 1 if slider.value > 0 else 0


def on_next(_):
    """Move the slider one step forward, stopping at the last recorded step."""
    last = len(steps) - 1
    slider.value = slider.value + 1 if slider.value < last else last


def on_reset(_):
    """Jump back to the first step."""
    slider.value = 0

def on_toggle(_):
    """Flip whether the root is a MAX or MIN node and replay from step 0.

    The trace is regenerated and the plot is redrawn exactly once:

    * ``slider.value`` is reset *before* ``slider.max`` changes, so shrinking
      the range can never clamp the value and fire an extra redraw.
    * ``update()`` is called directly only when the slider is already at 0,
      because assigning an unchanged value does not notify the observer.
    """
    global root_is_max
    root_is_max = not root_is_max
    btn_toggle.description = 'Root: MAX' if root_is_max else 'Root: MIN'
    generate_steps()
    if slider.value == 0:
        update()
    else:
        slider.value = 0  # the value observer redraws step 0 of the new trace
    slider.max = len(steps) - 1

# Wire callbacks. Any change of the slider value redraws the plot; the
# buttons only move the slider (or, for the toggle, regenerate the trace).
slider.observe(update, names='value')
btn_prev.on_click(on_prev)
btn_next.on_click(on_next)
btn_reset.on_click(on_reset)
btn_toggle.on_click(on_toggle)

# Show the control row above the plot area, then render step 0 explicitly:
# the slider already starts at 0, so no value change will trigger a draw.
controls = widgets.HBox([btn_prev, slider, btn_next, btn_reset, btn_toggle])
display(controls, out)
update()


HBox(children=(Button(description='Prev', style=ButtonStyle()), IntSlider(value=0, description='Step', max=78)…

Output()