# Monte Carlo Tree Search (MCTS)

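Monte Carlo Tree Search (MCTS) is a core algorithm for efficient decision-making in AI and plays an important role in solving complex sequential decision problems. It builds a search tree step by step and uses the results of random simulations to decide which branches deserve more search effort. Its best-known application is AlphaGo, which defeated top human Go players by combining MCTS with deep neural networks. The main strength of MCTS is that it can find good decisions in problems with an enormous number of options, because it concentrates computation on promising branches instead of enumerating all of them. Go is a case in point: its state space contains about $10^{170}$ positions, far more than the number of atoms in the universe. That puts exhaustive search out of reach for traditional algorithms, while MCTS can still play strongly within a fixed computation budget.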

```python
import math
import random
import time
from collections import deque
```

## MCTS node

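How each method maps to a phase of the algorithm:
1. `select_child` method: implements the UCT formula and drives the "selection" phase, in which UCT guides the choice of path down the tree;
2. `expand` method: the "expansion" phase; it adds a new child node to grow the search tree;
3. `simulate` method: the "simulation" phase; a fast random rollout estimates how good the path is and supplies the reward used in backpropagation;
4. `backpropagate` method: the "backpropagation" phase; it updates the visit count $n_i$ and cumulative reward $w_i$ of every node on the path, accumulating search experience.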
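For child $i$ of a parent visited $N$ times, the score that `select_child` maximizes is $\mathrm{UCT}_i = \frac{w_i}{n_i} + C\sqrt{\frac{\ln N}{n_i}}$: the first term favors children whose rollouts paid off, the second favors children that have rarely been tried. The sketch below restates that formula as a standalone function, assuming the same default $C = \sqrt{2}$ as the notebook; the name `uct_score` is hypothetical and is not part of the classes defined here.

```python
import math


def uct_score(total_reward, visits, parent_visits, c=math.sqrt(2)):
    """UCT score of one child (hypothetical helper restating select_child's formula).

    total_reward: w_i, cumulative reward of the child
    visits: n_i, visit count of the child
    parent_visits: N, visit count of the parent
    c: exploration constant C
    """
    if visits == 0:
        # select_child returns unvisited children immediately; an infinite score has the same effect
        return math.inf
    exploitation = total_reward / visits
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration


# Two children with the same average reward (0.5) under a parent visited 100 times:
# the child tried only twice receives the larger exploration bonus
often_tried = uct_score(5.0, 10, 100)
rarely_tried = uct_score(1.0, 2, 100)
print(round(often_tried, 3), round(rarely_tried, 3))
```

When $N = 1$ the exploration term vanishes because $\ln 1 = 0$, so the score reduces to the plain average reward.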

```python
class MCTSNode:
    def __init__(self, pos, end_pos=(2, 2), grid_size=(3, 3), obstacle_pos=(1, 1), parent=None):
        self.pos = pos  # current position (x, y)
        self.end_pos = end_pos
        self.grid_size = grid_size
        self.obstacle_pos = obstacle_pos

        # MCTS statistics
        self.visits = 0          # visit count n_i
        self.total_reward = 0.0  # cumulative reward w_i
        self.reward = 0.0        # average reward w_i / n_i

        # Tree structure
        self.parent = parent
        self.children = {}  # child nodes, keyed by position
        self.untried_moves = self._get_valid_moves()  # moves not yet expanded

    def _path_positions(self):
        """Positions on the path from the root down to this node."""
        positions = set()
        node = self
        while node is not None:
            positions.add(node.pos)
            node = node.parent
        return positions

    def _get_valid_moves(self):
        """Return every valid move from this node as (direction, position) pairs."""
        if self.pos == self.end_pos:
            return []

        moves = []
        x, y = self.pos
        on_path = self._path_positions()

        # The four cardinal directions
        directions = [
            ('up', (x, y + 1)),
            ('down', (x, y - 1)),
            ('left', (x - 1, y)),
            ('right', (x + 1, y)),
        ]

        for dir_name, (nx, ny) in directions:
            # Skip cells outside the grid
            if not (0 <= nx < self.grid_size[0] and 0 <= ny < self.grid_size[1]):
                continue
            # Skip the obstacle
            if (nx, ny) == self.obstacle_pos:
                continue
            # Skip cells already on the root-to-node path: this rules out stepping
            # straight back to the parent and also longer loops, so every tree
            # path is a simple path
            if (nx, ny) in on_path:
                continue
            moves.append((dir_name, (nx, ny)))

        return moves

    def is_fully_expanded(self):
        """True when every possible child has been created."""
        return len(self.untried_moves) == 0

    def is_terminal(self):
        """True when this node is the goal."""
        return self.pos == self.end_pos

    def select_child(self):
        """Select the child with the highest UCT score."""
        if not self.children:
            return None

        # UCT parameter
        C = math.sqrt(2)  # exploration constant

        best_child = None
        best_uct = -float('inf')

        for child in self.children.values():
            # An unvisited child is returned immediately, so every child is tried at least once
            if child.visits == 0:
                return child

            # UCT = average reward (exploitation) + exploration bonus
            exploitation = child.reward
            exploration = C * math.sqrt(math.log(self.visits) / child.visits)
            uct = exploitation + exploration

            if uct > best_uct:
                best_uct = uct
                best_child = child

        return best_child

    def expand(self):
        """Expand one untried move into a new child node."""
        if not self.untried_moves:
            return None  # nothing left to expand

        # Pick an untried move at random
        move, new_pos = random.choice(self.untried_moves)
        self.untried_moves.remove((move, new_pos))

        # Create the child node
        new_child = MCTSNode(
            pos=new_pos,
            end_pos=self.end_pos,
            grid_size=self.grid_size,
            obstacle_pos=self.obstacle_pos,
            parent=self,
        )

        self.children[new_pos] = new_child
        return new_child

    def simulate(self):
        """Random rollout from this node; returns (reward, number of cells visited)."""
        current_pos = self.pos
        visited = {current_pos}
        path_length = 0
        max_steps = self.grid_size[0] * self.grid_size[1] * 2  # guard against endless rollouts

        while current_pos != self.end_pos and path_length < max_steps:
            # Collect the legal moves from the current cell
            x, y = current_pos
            possible_moves = []

            # Check the four neighbours
            for nx, ny in [(x, y + 1), (x, y - 1), (x - 1, y), (x + 1, y)]:
                # Skip cells outside the grid
                if not (0 <= nx < self.grid_size[0] and 0 <= ny < self.grid_size[1]):
                    continue
                # Skip the obstacle
                if (nx, ny) == self.obstacle_pos:
                    continue
                # Skip cells this rollout has already visited (prevents loops)
                if (nx, ny) in visited:
                    continue

                possible_moves.append((nx, ny))

            if not possible_moves:
                # Dead end: the rollout fails
                return 0.0, path_length + 1

            # Take a random step
            current_pos = random.choice(possible_moves)
            visited.add(current_pos)
            path_length += 1

        # Compute the reward
        if current_pos == self.end_pos:
            # Reached the goal: shorter rollouts earn more.
            # The Manhattan distance from this node is a lower bound on the length
            min_distance = abs(self.pos[0] - self.end_pos[0]) + abs(self.pos[1] - self.end_pos[1])
            expected_length = min_distance + 2  # allow two extra steps

            if path_length <= expected_length:
                reward = 1.0
            else:
                # The longer the path, the lower the reward
                penalty = (path_length - expected_length) / (max_steps - expected_length)
                reward = max(0.1, 1.0 - penalty * 0.5)
            return reward, path_length + 1
        else:
            return 0.0, path_length + 1

    def backpropagate(self, reward):
        """Update the statistics of this node and every ancestor up to the root."""
        node = self
        while node is not None:
            node.visits += 1
            node.total_reward += reward
            node.reward = node.total_reward / node.visits
            node = node.parent
```
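The reward rule at the end of `simulate` is easier to follow in isolation. The sketch below restates it as a pure function; `rollout_reward` and its parameters are hypothetical names: the rollout length, the Manhattan distance from the rollout's starting node to the goal, and the step cap `max_steps`. Rollouts at most two steps longer than the Manhattan distance earn 1.0, longer ones are penalized linearly, and failed rollouts earn 0.

```python
def rollout_reward(path_length, min_distance, max_steps, reached_goal=True):
    """Reward rule from MCTSNode.simulate, restated as a pure function (hypothetical helper)."""
    if not reached_goal:
        return 0.0  # dead end: the rollout failed
    expected_length = min_distance + 2  # tolerance of two extra steps
    if path_length <= expected_length:
        return 1.0
    # Linear penalty on the excess length, scaled by the remaining step budget
    penalty = (path_length - expected_length) / (max_steps - expected_length)
    return max(0.1, 1.0 - penalty * 0.5)


# 3x3 grid, corner (0, 0) to corner (2, 2): Manhattan distance 4, max_steps = 3 * 3 * 2 = 18
for length in (4, 6, 8):
    print(length, round(rollout_reward(length, 4, 18), 3))
```

Because `simulate` stops once `path_length` reaches `max_steps`, the penalty never exceeds 1, so a successful rollout always earns at least 0.5 and the 0.1 floor never takes effect.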

## MCTS main loop (iteration control and visualization)

```python
class MCTSPathPlanner:
    def __init__(self, start_pos, end_pos, grid_size, obstacle_pos):
        self.start_pos = start_pos
        self.end_pos = end_pos
        self.grid_size = grid_size
        self.obstacle_pos = obstacle_pos
        self.root = MCTSNode(
            pos=start_pos,
            end_pos=end_pos,
            grid_size=grid_size,
            obstacle_pos=obstacle_pos,
        )
        self.best_path = []
        self.best_reward = -1.0

    def _tree_policy(self, node):
        """Tree policy: descend until a node can be expanded or is terminal."""
        while not node.is_terminal():
            if not node.is_fully_expanded():
                # Expand one untried move and return the new leaf
                new_child = node.expand()
                if new_child:
                    return new_child

            # Fully expanded: continue with the child that has the best UCT score
            child = node.select_child()
            if child is None:
                break  # dead end: no child to select
            node = child

        return node

    def search(self, iterations=1000, timeout=10):
        """Run MCTS for at most `iterations` iterations or `timeout` seconds."""
        start_time = time.perf_counter()
        completed = 0  # iterations actually finished (also correct after a timeout)

        for _ in range(iterations):
            if time.perf_counter() - start_time > timeout:
                print(f"Timed out after {completed} iterations")
                break

            # 1. Selection and expansion: descend with the tree policy
            leaf_node = self._tree_policy(self.root)

            # 2. Simulation: random rollout from the leaf
            reward, _ = leaf_node.simulate()

            # 3. Backpropagation: update every node from the leaf up to the root
            leaf_node.backpropagate(reward)
            completed += 1

            # Remember the root-to-leaf path whose rollout earned the highest reward
            if reward > self.best_reward:
                self.best_reward = reward
                path = []
                node = leaf_node
                while node is not None:
                    path.append(node.pos)
                    node = node.parent
                self.best_path = list(reversed(path))

            # Progress report
            if completed % 100 == 0:
                print(f"Iteration {completed}/{iterations}, best reward: {self.best_reward:.3f}")

        print("\nSearch finished:")
        print(f"Total iterations: {completed}")
        print(f"Best reward: {self.best_reward:.3f}")

        # The best path may stop at a tree leaf short of the goal; complete it with BFS
        if self.best_path and self.best_path[-1] != self.end_pos:
            final_path = self._complete_path(self.best_path)
            if final_path:
                self.best_path = final_path

        return self.best_path

    def _complete_path(self, partial_path):
        """Extend a partial path to the goal with BFS; return None if that is impossible."""
        if not partial_path:
            return None

        start = partial_path[-1]

        # Breadth-first search for the shortest continuation
        queue = deque([(start, [start])])
        visited = set(partial_path)  # never revisit cells already on the partial path

        while queue:
            current, path = queue.popleft()

            if current == self.end_pos:
                # Join the two paths, dropping the duplicated junction cell
                return partial_path[:-1] + path

            # Expand the four neighbours
            x, y = current
            for nx, ny in [(x, y + 1), (x, y - 1), (x - 1, y), (x + 1, y)]:
                if not (0 <= nx < self.grid_size[0] and 0 <= ny < self.grid_size[1]):
                    continue
                if (nx, ny) == self.obstacle_pos:
                    continue
                if (nx, ny) in visited:
                    continue

                visited.add((nx, ny))
                queue.append(((nx, ny), path + [(nx, ny)]))

        return None

    def get_best_path_from_tree(self):
        """Extract a path directly from the MCTS tree."""
        path = [self.root.pos]
        current = self.root

        # Greedily follow the child with the highest average reward
        while current.children and not current.is_terminal():
            best_child = None
            best_reward = -1.0

            for child in current.children.values():
                if child.reward > best_reward and child.pos not in path:
                    best_reward = child.reward
                    best_child = child

            if best_child is None:
                break

            path.append(best_child.pos)
            current = best_child

        return path

    def visualize_grid(self, path=None):
        """Print the grid with the path drawn on it, then check the path."""
        if path is None:
            path = self.best_path

        print("\n=== Grid visualization ===")
        # Print rows from the largest y down so that "up" points up
        for y in range(self.grid_size[1] - 1, -1, -1):
            row = []
            for x in range(self.grid_size[0]):
                pos = (x, y)
                if pos == self.start_pos:
                    row.append(" S ")
                elif pos == self.end_pos:
                    row.append(" E ")
                elif pos == self.obstacle_pos:
                    row.append(" X ")
                elif path and pos in path:
                    idx = path.index(pos)  # step index along the path
                    row.append(f"{idx:2d} ")
                else:
                    row.append(" . ")
            print("".join(row))

        if path:
            print(f"\nPath: {' -> '.join(str(pos) for pos in path)}")

            # Check that consecutive cells are adjacent and avoid the obstacle
            valid = True
            for p1, p2 in zip(path, path[1:]):
                dx = abs(p1[0] - p2[0])
                dy = abs(p1[1] - p2[1])
                if dx + dy != 1:
                    print(f"Warning: path is not contiguous at {p1} -> {p2}")
                    valid = False
                if p2 == self.obstacle_pos:
                    print(f"Error: path passes through the obstacle at {p2}")
                    valid = False

            if valid:
                print("Path check: valid")
                if path[-1] == self.end_pos:
                    print(f"✓ Reached the goal! Path length: {len(path) - 1} steps")
                else:
                    print(f"⚠ Path does not reach the goal; last position: {path[-1]}")
            else:
                print("Path check: invalid")
```
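`get_best_path_from_tree` follows the child with the highest average reward, but it is not called anywhere in this notebook. A common alternative in MCTS practice, often called the "robust child" rule, is to follow the most-visited child, since visit counts are less noisy than averages computed from a handful of rollouts. The sketch below shows that rule on a small hand-built tree; the `Node` dataclass and the `path_by_visits` function are hypothetical stand-ins, not the notebook's `MCTSNode`.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Minimal stand-in for MCTSNode: position, visit count and children (hypothetical)."""
    pos: tuple
    visits: int = 0
    children: list = field(default_factory=list)


def path_by_visits(root, goal):
    """Follow the most-visited child from the root until reaching the goal or a leaf."""
    path = [root.pos]
    node = root
    while node.children and node.pos != goal:
        candidates = [c for c in node.children if c.pos not in path]
        if not candidates:
            break  # every child would revisit a cell already on the path
        node = max(candidates, key=lambda c: c.visits)
        path.append(node.pos)
    return path


# Hand-built tree on the 3x3 grid in which the branch through (1, 0) was visited more often
goal_node = Node((2, 2))
right = Node((1, 0), visits=60, children=[
    Node((2, 0), visits=55, children=[Node((2, 1), visits=50, children=[goal_node])])])
up = Node((0, 1), visits=30, children=[Node((0, 2), visits=25)])
root = Node((0, 0), visits=90, children=[right, up])
print(path_by_visits(root, (2, 2)))  # → [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)]
```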

## Entry point

```python
def find_path_mcts(start_pos, end_pos, grid_size, obstacle_pos, iterations=500):
    """Simple interface: plan a path with MCTS and print the result."""
    print("=" * 50)
    print("MCTS path planning")
    print("=" * 50)
    print(f"Start: {start_pos}")
    print(f"End: {end_pos}")
    print(f"Grid size: {grid_size}")
    print(f"Obstacle: {obstacle_pos}")
    print(f"Iterations: {iterations}")

    planner = MCTSPathPlanner(start_pos, end_pos, grid_size, obstacle_pos)
    path = planner.search(iterations=iterations, timeout=5)

    if path:
        planner.visualize_grid(path)
    else:
        print("\n✗ No path found")

    return path
```
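`find_path_mcts` relies on `visualize_grid` to check its result, and that check covers adjacency and the obstacle but not grid bounds, the start cell, or repeated cells. The sketch below is a stricter standalone check under the same conventions (`(x, y)` cells, a single obstacle); `validate_path` is a hypothetical helper that returns a list of problems, empty when the path is valid.

```python
def validate_path(path, start_pos, end_pos, grid_size, obstacle_pos):
    """Return a list of problems with `path`; an empty list means the path is valid."""
    if not path:
        return ["empty path"]
    problems = []
    if path[0] != start_pos:
        problems.append(f"path starts at {path[0]}, not {start_pos}")
    if path[-1] != end_pos:
        problems.append(f"path ends at {path[-1]}, not {end_pos}")
    if len(set(path)) != len(path):
        problems.append("path revisits a cell")
    for x, y in path:
        if not (0 <= x < grid_size[0] and 0 <= y < grid_size[1]):
            problems.append(f"{(x, y)} is outside the grid")
        if (x, y) == obstacle_pos:
            problems.append(f"{(x, y)} is the obstacle")
    # Consecutive cells must differ by exactly one step in x or y
    for (x1, y1), (x2, y2) in zip(path, path[1:]):
        if abs(x1 - x2) + abs(y1 - y2) != 1:
            problems.append(f"{(x1, y1)} -> {(x2, y2)} is not a single step")
    return problems


# The path recorded for test 1 below passes every check
print(validate_path([(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)], (0, 0), (2, 2), (3, 3), (1, 1)))  # → []
```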

## Running the tests

```python
# Test 1: simple 3x3 grid
print("\nTest 1: 3x3 grid")
path1 = find_path_mcts(
    start_pos=(0, 0),
    end_pos=(2, 2),
    grid_size=(3, 3),
    obstacle_pos=(1, 1),
    iterations=300,
)

# Test 2: no obstacle
print("\n\n" + "=" * 50)
print("Test 2: no obstacle")
path2 = find_path_mcts(
    start_pos=(0, 0),
    end_pos=(2, 2),
    grid_size=(3, 3),
    obstacle_pos=(-1, -1),  # off-grid sentinel: no effective obstacle
    iterations=200,
)

# Test 3: larger grid
print("\n\n" + "=" * 50)
print("Test 3: 4x4 grid")
path3 = find_path_mcts(
    start_pos=(0, 0),
    end_pos=(3, 3),
    grid_size=(4, 4),
    obstacle_pos=(1, 2),
    iterations=1000,
)

# Test 4: even larger grid
print("\n\n" + "=" * 50)
print("Test 4: 5x5 grid")
path4 = find_path_mcts(
    start_pos=(0, 0),
    end_pos=(4, 4),
    grid_size=(5, 5),
    obstacle_pos=(0, 1),
    iterations=2000,
)
```
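Because these grids are tiny, breadth-first search gives the exact shortest path length, which is a useful yardstick for the paths MCTS returns: an MCTS path is optimal exactly when its step count matches. The sketch below computes that length for the four test configurations; `shortest_path_length` is a hypothetical helper that uses the same grid conventions as the planner.

```python
from collections import deque


def shortest_path_length(start, goal, grid_size, obstacle):
    """Exact shortest path length in steps, found by BFS; None if the goal is unreachable."""
    queue = deque([(start, 0)])
    seen = {start}
    while queue:
        (x, y), dist = queue.popleft()
        if (x, y) == goal:
            return dist
        for nxt in ((x, y + 1), (x, y - 1), (x - 1, y), (x + 1, y)):
            nx, ny = nxt
            in_grid = 0 <= nx < grid_size[0] and 0 <= ny < grid_size[1]
            if in_grid and nxt != obstacle and nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return None


tests = {
    "Test 1": ((0, 0), (2, 2), (3, 3), (1, 1)),
    "Test 2": ((0, 0), (2, 2), (3, 3), (-1, -1)),
    "Test 3": ((0, 0), (3, 3), (4, 4), (1, 2)),
    "Test 4": ((0, 0), (4, 4), (5, 5), (0, 1)),
}
for name, args in tests.items():
    print(name, shortest_path_length(*args))  # → 4, 4, 6, 8
```

By this yardstick, the 4-step paths recorded below for tests 1 and 2 are shortest paths.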


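Test 1: 3x3 grid
MCTS path planning
Start: (0, 0)
End: (2, 2)
Grid size: (3, 3)
Obstacle: (1, 1)
Iterations: 300
Iteration 100/300, best reward: 1.000
Iteration 200/300, best reward: 1.000
Iteration 300/300, best reward: 1.000

Search finished:
Total iterations: 300
Best reward: 1.000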

=== Grid visualization ===
 .  .  E 
 .  X  3 
 S  1  2 

Path: (0, 0) -> (1, 0) -> (2, 0) -> (2, 1) -> (2, 2)
Path check: valid
✓ Reached the goal! Path length: 4 steps


Test 2: no obstacle
MCTS path planning
Start: (0, 0)
End: (2, 2)
Grid size: (3, 3)
Obstacle: (-1, -1)
Iterations: 200
Iteration 100/200, best reward: 1.000
Iteration 200/200, best reward: 1.000

Search finished:
Total iterations: 200
Best reward: 1.000

=== Grid visualization ===
 .  3  E 
 .  2  . 
 S  1  . 

Path: (0, 0) -> (1, 0) -> (1, 1) -> (1, 2) -> (2, 2)
Path check: valid
✓ Reached the goal! Path length: 4 steps


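Test 3: 4x4 grid
MCTS path planning
Start: (0, 0)
End: (3, 3)
Grid size: (4, 4)
Obstacle: (1, 2)
Iterations: 1000
Iteration 100/1000, best reward: 1.000
Iteration 200/1000, best reward: 1.000
Iteration 300/1000, best reward: 1.000
Iteration 400/1000, best reward: 1.000
Iteration 500/1000, best reward: 1.000
Iteration 600/1000, best reward: 1.000
Iteration 700/1000, best reward: 1.000
Iteration 800/1000, best reward: 1.000
Iteration 900/1000, best reward: 1.000
Iteration 1000/1000, best reward: 1.000

Search finished:
Total iterations: 1000
Best reward: 1.000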

=== Grid visualization ===
 .  .  5  E 
 .  X  4  . 
 .  .  