In [4]:
import math
from collections import deque

def compute_value(node, maximizingPlayer):
    """Compute the MIN/MAX value of any subtree."""
    if isinstance(node, int):
        return node
    if maximizingPlayer:
        return max(compute_value(c, False) for c in node)
    else:
        return min(compute_value(c, True) for c in node)

def print_tree_levels_with_values(tree, title="TREE"):
    """Print `tree` breadth-first, one line per depth level.

    The root is treated as a MAX node, and the MAX/MIN role alternates by
    level. Internal nodes (lists) are rendered as ``MAX(v)`` / ``MIN(v)``,
    where ``v`` is the subtree's minimax value from `compute_value`. Leaves
    (ints) are rendered as their value. The levels are framed by a
    ``===== title =====`` header and followed by a blank line.
    """
    print(f"\n===== {title} =====\n")

    rows = {}  # depth -> rendered labels, in breadth-first order
    frontier = deque()
    frontier.append((tree, 0, True))  # (node, depth, is this a MAX node?)

    while frontier:
        current, depth, is_max = frontier.popleft()
        row = rows.setdefault(depth, [])

        if not isinstance(current, int):
            role = "MAX" if is_max else "MIN"
            row.append(f"{role}({compute_value(current, is_max)})")
            frontier.extend((kid, depth + 1, not is_max) for kid in current)
        else:
            row.append(str(current))

    # Dicts preserve insertion order, and depths are first seen in
    # increasing order during BFS, so levels print top-down.
    for depth in rows:
        print(f"Level {depth}:  " + "   ".join(rows[depth]))
    print()


def alpha_beta(node, depth, alpha, beta, maximizingPlayer, pruned_nodes, _path=()):
    """Return the value of `node` using alpha-beta pruning.

    Leaves are ints; internal nodes are lists of child nodes. Whenever a
    cutoff (``beta <= alpha``) stops the scan of a node's children, the
    positions of the skipped children are added to `pruned_nodes`, so that
    `prune_tree` can rebuild the part of the tree that was explored.

    A position is a tuple of child indices from the root: the root is
    ``()``, its second child is ``(1,)``, and so on. Positions are used
    instead of ``id(child)`` because equal ints can be the very same object
    (CPython caches small ints). Identity-based tracking would then treat
    every leaf with the same value as pruned. Shared sub-lists have the
    same problem. Positions are unique by construction.

    Args:
        node: The subtree to evaluate (an int leaf or a list of children).
        depth: Remaining depth. It is decremented on every recursive call but
            never checked, so the search always runs down to the leaves.
        alpha: Lower bound the maximizer is already assured of.
        beta: Upper bound the minimizer is already assured of.
        maximizingPlayer: True if `node` is a MAX node, False for MIN.
            Roles alternate at each level.
        pruned_nodes: Set mutated in place. It receives the position of every
            child subtree skipped because of a cutoff.
        _path: Internal. The position of `node` relative to the root of the
            outermost call; leave it at the default.

    Returns:
        The value found for `node`. For a full-window call (alpha=-inf,
        beta=+inf) this is the exact minimax value. Deeper calls with a
        narrowed window may return only a bound.
    """
    if isinstance(node, int):
        return node

    value = -math.inf if maximizingPlayer else math.inf
    for idx, child in enumerate(node):
        result = alpha_beta(child, depth - 1, alpha, beta, not maximizingPlayer,
                            pruned_nodes, _path + (idx,))
        if maximizingPlayer:
            value = max(value, result)
            alpha = max(alpha, value)
        else:
            value = min(value, result)
            beta = min(beta, value)
        if beta <= alpha:
            # Every sibling after `idx` is skipped; record their positions.
            # enumerate() yields the true index even when siblings compare
            # equal, whereas list.index() would return the first equal one.
            pruned_nodes.update(_path + (j,) for j in range(idx + 1, len(node)))
            break
    return value


def prune_tree(node, pruned_nodes, _path=()):
    """Return a copy of `node` without the subtrees recorded as pruned.

    Args:
        node: The same tree (int leaf or nested lists) that was passed to
            `alpha_beta`.
        pruned_nodes: The set filled by `alpha_beta`. It holds the positions
            (tuples of child indices from the root) of the pruned subtrees.
        _path: Internal. The position of `node` relative to the root; leave it
            at the default when calling on the root given to `alpha_beta`.

    Returns:
        A new nested list, or the int itself for a leaf. The input tree is
        not modified.
    """
    if isinstance(node, int):
        return node

    kept = []
    for idx, child in enumerate(node):
        child_path = _path + (idx,)
        if child_path not in pruned_nodes:
            kept.append(prune_tree(child, pruned_nodes, child_path))
    return kept


# Example game tree. The root is a MAX node, levels alternate MAX/MIN, and
# the leaves are ints.
tree = [
    [[10, 9], [14, 18]],
    [[5, 4], [50, 3]],
]


print_tree_levels_with_values(tree, "ORIGINAL TREE")

# Search from the root with the widest window; branches cut off by
# alpha-beta are recorded in `pruned_nodes`.
pruned_nodes = set()
result = alpha_beta(
    tree,
    depth=3,
    alpha=-math.inf,
    beta=math.inf,
    maximizingPlayer=True,
    pruned_nodes=pruned_nodes,
)

# Rebuild the tree keeping only the branches the search actually visited.
pruned_tree = prune_tree(tree, pruned_nodes)

print_tree_levels_with_values(pruned_tree, "TREE AFTER ALPHA-BETA PRUNING")

print("Final Optimal Value =", result)




===== ORIGINAL TREE =====

Level 0:  MAX(10)
Level 1:  MIN(10)   MIN(5)
Level 2:  MAX(10)   MAX(18)   MAX(5)   MAX(50)
Level 3:  10   9   14   18   5   4   50   3


===== TREE AFTER ALPHA-BETA PRUNING =====

Level 0:  MAX(10)
Level 1:  MIN(10)   MIN(5)
Level 2:  MAX(10)   MAX(14)   MAX(5)
Level 3:  10   9   14   5   4

Final Optimal Value = 10
