In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random, math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim



class UnionFind:
    """Disjoint-set (union-find) over the integers 0..n-1.

    ``find`` uses path halving; ``union`` always attaches the root of ``y``
    under the root of ``x`` (no rank/size heuristic). Used to check whether
    the generated region graph is connected.
    """

    def __init__(self, n):
        # Every element starts as the root of its own singleton set.
        self.parent = list(range(n))

    def find(self, x):
        """Return the root of the set containing ``x`` (compressing the path on the way)."""
        parents = self.parent
        node = x
        while parents[node] != node:
            grandparent = parents[parents[node]]
            parents[node] = grandparent  # path halving: skip one level
            node = grandparent
        return node

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; the root of ``x`` stays the root."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        self.parent[root_y] = root_x

class Region:
    """One burnable region (graph node) of the wildfire simulation."""

    def __init__(self, id, pos, H, region_discount_rate, ARC):
        """
        Args:
          id: region index.
          pos: (x, y) position of the region on the 2-D plane (stored as a numpy array).
          H: fuel amount ("health"); also remembered as ``max_health``.
          region_discount_rate: fraction (0-1) of the fire intensity that the
            simulation burns off as fuel each step.
          ARC: area (hectares).
        """
        self.id = id
        self.pos = np.array(pos)
        self.H = H
        self.max_health = H  # initial fuel, kept for damage / protection ratios
        self.region_discount_rate = region_discount_rate
        self.ARC = ARC
        # Fire intensity. NOTE: starts at 1.0, not 0.
        self.K = 1.0
        # Constant self-retention multiplier applied to K each step
        # (a fixed constant; it is not derived from ARC).
        self.transmisive = 5

    @property
    def active(self):
        """1 while the region still has fuel (H > 0), else 0."""
        return int(self.H > 0)

    def pos_vector(self):
        # TODO: placeholder -- intended to return a standardized position
        # vector, but currently always returns 0.
        return 0

class Edge:
    """Adjacency between two regions.

    Attributes:
      region1, region2: the two neighbouring Region objects.
      border_direction: 2-D unit vector (numpy array, shape (2,)) pointing from
        region1 towards region2.
      shared_border_area: size of the shared border (the simulation passes the
        smaller of the two regions' sqrt(ARC)).
    """

    def __init__(self, region1, region2, border_direction, shared_border_area):
        self.region1, self.region2 = region1, region2
        self.border_direction = border_direction
        self.shared_border_area = shared_border_area

class WildfireSimulation:
    """Graph-based wildfire spread simulation.

    Regions are placed uniformly at random on a 100 x 100 plane. Regions closer
    than a distance threshold are linked by an Edge; further shortest edges are
    then added until the graph is connected. ``next`` advances the fire one time
    step and appends a ``(t, env_factors, state, action)`` record to ``history``;
    ``run_session_TD`` appends the marker string ``'end'`` once a session is over.
    """

    def __init__(self, num_regions):
        self.num_regions = num_regions
        self.nodes = {}      # {region id: Region}
        self.edges = []      # list of Edge
        self.adjacency = {}  # {region id: [(neighbor_id, Edge), ...]}
        self.history = []    # (t, env_factors, state, action) records plus 'end' markers
        self.gamma = 0.8     # discount rate; not read by any method of this class

        self.init_regions()
        self.init_edges()      # guarantees the generated graph is connected
        self.init_adjacency()

        # Ignite one of the first (at most) three nodes. min() keeps this valid for
        # graphs with fewer than three regions; for >= 3 regions the draw is unchanged.
        self.fire_origin = np.random.randint(0, min(3, self.num_regions))
        self.nodes[self.fire_origin].K = 50.0

    def init_regions(self):
        """Create ``num_regions`` regions at uniformly random positions in [0, 100]^2.

        Fuel (H = 1000), discount rate (0.2) and area (ARC = 10) are currently fixed.
        """
        for i in range(self.num_regions):
            pos = (random.uniform(0, 100), random.uniform(0, 100))
            self.nodes[i] = Region(i, pos, 1000, 0.2, 10)

    @staticmethod
    def _make_edge(node_i, node_j):
        """Build an Edge from node_i to node_j with a unit border direction."""
        direction = node_j.pos - node_i.pos
        norm = np.linalg.norm(direction)
        if norm != 0:
            border_direction = direction / norm
        else:
            # Coincident regions: fall back to an arbitrary fixed direction.
            border_direction = np.array([1.0, 0.0])
        shared_border_area = min(math.sqrt(node_i.ARC), math.sqrt(node_j.ARC))
        return Edge(node_i, node_j, border_direction, shared_border_area)

    def init_edges(self, threshold=30.0):
        """Connect regions whose centres are closer than ``threshold``.

        If that leaves the graph disconnected, candidate pairs are scanned in
        increasing distance order and an edge is added whenever it joins two
        different components (tracked with UnionFind), until the graph is connected.

        Args:
          threshold: proximity threshold for the initial edges (default 30.0).
        """
        self.edges = []
        node_ids = list(self.nodes.keys())

        all_pairs = []  # (distance, id_i, id_j) for every unordered pair
        for a in range(len(node_ids)):
            for b in range(a + 1, len(node_ids)):
                i, j = node_ids[a], node_ids[b]
                dist = np.linalg.norm(self.nodes[i].pos - self.nodes[j].pos)
                all_pairs.append((dist, i, j))
                if dist < threshold:
                    self.edges.append(self._make_edge(self.nodes[i], self.nodes[j]))

        uf = UnionFind(self.num_regions)
        for edge in self.edges:
            uf.union(edge.region1.id, edge.region2.id)

        # Kruskal-style pass: add the shortest edges that join separate components.
        all_pairs.sort(key=lambda pair: pair[0])
        for dist, i, j in all_pairs:
            if uf.find(i) != uf.find(j):
                self.edges.append(self._make_edge(self.nodes[i], self.nodes[j]))
                uf.union(i, j)

    def init_adjacency(self):
        """Build the adjacency list: every region lists (neighbor_id, Edge) for each incident edge."""
        for node_id in self.nodes:
            self.adjacency[node_id] = []
        for edge in self.edges:
            id1 = edge.region1.id
            id2 = edge.region2.id
            self.adjacency[id1].append((id2, edge))
            self.adjacency[id2].append((id1, edge))

    def _snapshot(self):
        """Return a {node_id: {'pos', 'H', 'K'}} copy of the current node state."""
        return {node_id: {'pos': node.pos.copy(), 'H': node.H, 'K': node.K}
                for node_id, node in self.nodes.items()}

    def run_session_TD(self, max_iter, agent=None):
        """Run up to ``max_iter`` time steps.

        Records the initial state as t = 0, then calls ``next`` until the fire is
        contained (``is_end``) or ``max_iter`` steps have run, and finally appends a
        single ``'end'`` marker to ``history``.

        Args:
          max_iter: maximum number of steps.
          agent: optional firefighting agent forwarded to ``next``; None runs the
            simulation without any suppression.
        """
        initial_env = {
            'wind_direction': None,
            'wind_strength': None,
            'dryness': None,
            'temperature': None
        }
        self.history.append((0, initial_env, self._snapshot(), None))

        for t in range(1, max_iter + 1):
            self.next(t, agent)
            if self.is_end():
                print(f"game end at {t}")
                break
        self.history.append('end')

    def is_end(self):
        """Return True when no region is still burning (no region has H >= 1 and K >= 1)."""
        return not any(node.H >= 1 and node.K >= 1 for node in self.nodes.values())

    def next(self, t, agent=None):
        """Advance the simulation from time t-1 to time t.

        1. Draw the environment factors (random wind direction; wind strength,
           dryness and temperature are currently constants).
        2. For every region compute
               new_K = retention + max(0, incoming) - max(0, outgoing)   (clipped at 0)
           where retention = active * transmisive * K and incoming/outgoing are the
           wind-projected fluxes across each incident edge. Regions with no fuel
           left (H <= 0) neither keep nor pass on fire.
        3. If ``agent`` is given, ask it for an aircraft allocation and apply it
           with ``apply_action``; otherwise no suppression is applied.
        4. Set K to the final intensity and burn fuel: H -= region_discount_rate * K.
           A region whose fuel drops below 0 is reset to H = 0, K = 0.
        5. Append (t, env_factors, state, action) to ``history``; ``action`` is the
           numpy allocation vector, or None when no agent was given.

        Args:
          t: time-step index written into the history record.
          agent: object exposing ``action(simulation) -> (tensor, dist)``, ``J``
            and ``capability``; optional.
        """
        wind_angle = random.uniform(0, 2 * math.pi)
        wind_direction = np.array([math.cos(wind_angle), math.sin(wind_angle)])
        wind_strength = 1
        dryness = 1.5
        temperature = 1
        env_factors = {
            'wind_direction': wind_direction,
            'wind_strength': wind_strength,
            'dryness': dryness,
            'temperature': temperature
        }

        # Common multiplier for fire flowing across any edge.
        factor = 0.5 * (dryness + temperature)

        new_K = {}
        for node_id, node in self.nodes.items():
            if node.H <= 0:
                # Burnt out: stop transmission. (H is clamped to 0 at the end of every
                # step, so a strict `< 0` test here could never trigger.)
                new_K[node_id] = 0.0
                node.H = 0
                continue

            retention = node.active * node.transmisive * node.K
            incoming = 0.0
            outgoing = 0.0
            for nbr_id, edge in self.adjacency[node_id]:
                neighbor = self.nodes[nbr_id]
                # border_direction points region1 -> region2; orient it from this
                # node towards the neighbour.
                outward = edge.border_direction if edge.region1.id == node_id else -edge.border_direction
                flux = np.dot(wind_direction, outward) * edge.shared_border_area * wind_strength
                incoming += factor * neighbor.K * -flux
                outgoing += factor * node.K * flux

            K_new_value = retention + max(0, incoming) - max(0, outgoing)
            new_K[node_id] = K_new_value if K_new_value > 0 else 0.0

        # The agent's action acts on the propagated intensities before fuel is burned.
        if agent is not None:
            action, _dist = agent.action(self)
            action = action.detach().cpu().numpy()
            new_K = self.apply_action(action, agent, new_K)
        else:
            action = None

        for node_id, node in self.nodes.items():
            node.K = new_K[node_id]
            if node.H > 0:
                # Fuel consumed this step is proportional to the final intensity.
                node.H -= node.region_discount_rate * new_K[node_id]
            if node.H < 0:
                node.H = 0
                node.K = 0

        self.history.append((t, env_factors, self._snapshot(), action))

    def apply_action(self, action, agent, newK):
        """Apply an aircraft allocation to the propagated fire intensities.

        For every node i:
            N_i   = agent.J * action[i]                  (aircraft sent to node i)
            K_i'  = max(0, newK[i] - N_i * agent.capability)

        Args:
          action: sequence/array of per-node weights (expected to sum to 1),
            indexed by node id 0..n-1.
          agent: object with ``J`` (total aircraft) and ``capability`` (intensity
            removed per aircraft).
          newK: {node_id: intensity} with keys 0..n-1.

        Returns:
          dict {node_id: suppressed intensity}.
        """
        final_K = {}
        for i in range(len(newK)):
            aircraft = agent.J * action[i]
            reduction = aircraft * agent.capability
            final_K[i] = max(0, newK[i] - reduction)
        return final_K

    def return_lists(self):
        """Convert ``history`` into per-step training lists.

        Records from t = 1 up to the record just before the first 'end' marker are
        used. For every non-final step t the reward is
            r_t = (H_total_{t+1} - H_total_t) - 10
        i.e. the change in total fuel minus a constant step cost. The final step's
        reward is the total remaining fuel plus 500 for every region that lost at
        most 10% of its initial fuel.

        Returns:
          (E, Vt_list, B_list, a_list, rt_list): E is the adjacency dict; the lists
          hold, per step, the node state, environment factors, action and reward.
        """
        E = self.adjacency
        B_list, Vt_list, a_list, rt_list = [], [], [], []

        t = 1
        while self.history[t + 1] != 'end':
            _, Bt, Vt, at = self.history[t]
            Vnext = self.history[t + 1][2]
            H_total_t = np.sum([info['H'] for info in Vt.values()])
            H_total_next = np.sum([info['H'] for info in Vnext.values()])
            rt_list.append(H_total_next - H_total_t - 10)
            B_list.append(Bt)
            Vt_list.append(Vt)
            a_list.append(at)
            t += 1

        # Terminal reward for the last recorded step.
        _, Bt, Vt, at = self.history[t]
        Ht_values = [info['H'] for info in Vt.values()]
        Ht_max_values = [self.nodes[node].max_health for node in range(self.num_regions)]
        num_protected = np.sum([1 for h, h_max in zip(Ht_values, Ht_max_values)
                                if (h_max - h) / h_max <= 0.10])
        print(f"num protected: {num_protected}")
        rt_list.append(np.sum(Ht_values) + num_protected * 500)
        Vt_list.append(Vt)
        B_list.append(Bt)
        a_list.append(at)
        assert len(rt_list) == t
        return (E, Vt_list, B_list, a_list, rt_list)

    def plot_at_t(self, t):
        """Plot the region states recorded at time t.

        Regions are drawn at their positions and coloured by remaining fuel H,
        edges are drawn as grey dashed lines, and (for t > 0) the wind direction is
        drawn as an arrow with the environment factors in the title.
        """
        record = None
        for rec in self.history:
            # Skip the 'end' marker strings.
            if isinstance(rec, tuple) and rec[0] == t:
                record = rec
                break
        if record is None:
            print("未找到时刻 {} 的记录。".format(t))
            return
        time_step, env_factors, state, _action = record

        xs, ys, colors, labels = [], [], [], []
        for node_id, info in state.items():
            xs.append(info['pos'][0])
            ys.append(info['pos'][1])
            colors.append(info['H'])
            labels.append(f"{node_id}\nH:{info['H']:.1f}\nK:{info['K']:.1f}")

        plt.figure(figsize=(8, 6))
        sc = plt.scatter(xs, ys, c=colors, cmap='hot', s=200, vmin=0, vmax=1000)
        for i, txt in enumerate(labels):
            plt.annotate(txt, (xs[i], ys[i]), textcoords="offset points", xytext=(0, 10), ha='center')
        plt.colorbar(sc, label="remained H")

        for edge in self.edges:
            pos1 = edge.region1.pos
            pos2 = edge.region2.pos
            plt.plot([pos1[0], pos2[0]], [pos1[1], pos2[1]],
                     color='gray', linestyle='--', linewidth=1, zorder=0)

        if env_factors['wind_direction'] is not None:
            wind_direction = env_factors['wind_direction']
            arrow_scale = 15  # arrow length scale
            x0 = 50           # arrow drawn at a fixed position (50, 50)
            y0 = 50
            head_width = 10 * env_factors['wind_strength'] if env_factors['wind_strength'] is not None else 0
            plt.arrow(x0, y0,
                      wind_direction[0] * arrow_scale,
                      wind_direction[1] * arrow_scale,
                      head_width=head_width, head_length=4, fc='blue', ec='blue', linewidth=2)
            plt.text(x0, y0, "Wind", color='blue', fontsize=12)
            title_str = ("time {} s firesimul\n: w_direct {} | w_velocity: {:.2f} | dryness: {} | temp: {:.2f}"
                         .format(time_step,
                                 np.round(env_factors['wind_direction'], 2),
                                 env_factors['wind_strength'],
                                 env_factors['dryness'],
                                 env_factors['temperature']))
        else:
            # t = 0: no environment factors recorded yet.
            title_str = f"time {time_step} s firesimul original state"

        plt.title(title_str)
        plt.xlabel("X 坐标")
        plt.ylabel("Y 坐标")
        plt.xlim([-5, 105])
        plt.ylim([-5, 105])
        plt.grid(True)
        plt.show()




# Weight-initialisation hook, meant to be used with ``module.apply(init_weights)``.
def init_weights(m):
    """Kaiming-uniform (ReLU gain) weights and zero bias for ``nn.Linear``; other modules are left untouched."""
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
    if m.bias is not None:
        nn.init.zeros_(m.bias)

def check_parameters(model):
    """Print a warning for every trainable parameter that contains NaN or Inf values."""
    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
    for name, param in trainable:
        has_nan = torch.isnan(param).any()
        has_inf = torch.isinf(param).any()
        if has_nan:
            print(f"⚠️ 参数 {name} 出现 NaN 值！")
        if has_inf:
            print(f"⚠️ 参数 {name} 出现 Inf 值！")
def check_gradients(model):
    """Print a warning for each trainable parameter whose gradient is missing, NaN or Inf."""
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        grad = param.grad
        if grad is None:
            # No gradient usually means the parameter was not part of the graph.
            print(f"⚠️ 梯度 {name} 为 None，可能没有参与计算图！")
            continue
        if torch.isnan(grad).any():
            print(f"⚠️ 梯度 {name} 出现 NaN 值！")
        if torch.isinf(grad).any():
            print(f"⚠️ 梯度 {name} 出现 Inf 值！")
def print_model_parameters(model):
    """Dump every parameter's name and value, flagging NaN/Inf entries, separated by dashed lines."""
    separator = "-" * 40
    for name, param in model.named_parameters():
        print(f"参数名: {name}")
        print(f"参数值: {param.data}")
        nan_found = torch.isnan(param).any()
        inf_found = torch.isinf(param).any()
        if nan_found:
            print(f"⚠️ 参数 {name} 出现 NaN 值！")
        if inf_found:
            print(f"⚠️ 参数 {name} 出现 Inf 值！")
        print(separator)


class A2CNetwork(nn.Module):
    """Actor-critic network over a region graph plus a global environment vector.

    Node features are standardized, embedded, aggregated one hop with A @ h,
    summed into a graph embedding, concatenated with an embedding of the
    environment vector, and fed to an actor head (softmax over nodes) and a
    critic head (scalar state value).
    """

    def __init__(self, num_nodes, env_dim=5, hidden_dim=512):
        """
        Args:
          num_nodes: number of graph nodes (length of the action vector).
          env_dim: environment-vector length, e.g.
            [wind_direction_x, wind_direction_y, wind_strength, dryness, temperature].
          hidden_dim: hidden width.
        """
        super().__init__()
        # Environment embedding.
        self.env_fc = nn.Linear(env_dim, hidden_dim)
        # Node features: 2 invariant + 2 variant columns concatenated -> 4 inputs.
        self.gcn_fc = nn.Linear(4, hidden_dim)
        self.gcn_fc2 = nn.Linear(hidden_dim, hidden_dim)
        # Both heads consume [graph embedding, env embedding] (hidden_dim * 2).
        # NOTE(review): the actor has no nonlinearity between its first two Linear
        # layers (only LayerNorm) nor between Linear(hidden_dim, 32) and the output.
        self.actor_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 32),
            nn.Dropout(0.5),
            nn.Linear(32, num_nodes)     # unnormalised logits, one per node
        )
        self.critic_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 32),
            nn.LayerNorm(32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, 1)
        )
        # Same initialisation order as before (keeps seeded runs reproducible).
        for module in (self.env_fc, self.gcn_fc, self.gcn_fc2, self.actor_mlp, self.critic_mlp):
            module.apply(init_weights)

    @staticmethod
    def _standardize(X):
        """Column-wise z-score of node features.

        With fewer than two nodes ``torch.std`` is NaN (Bessel's correction divides
        by n - 1 = 0), so such graphs are only centred.
        """
        if X.shape[0] < 2:
            return X - X.mean(dim=0)
        return (X - X.mean(dim=0)) / (X.std(dim=0) + 1e-8)

    def forward(self, env_vector, X_invariant, X_variant, A):
        """
        Args:
          env_vector: Tensor (env_dim,), external environment information.
          X_invariant: Tensor (num_nodes, 2), static node features (e.g. position).
          X_variant: Tensor (num_nodes, 2), dynamic node features (e.g. K, H).
          A: Tensor (num_nodes, num_nodes), adjacency matrix (optionally with
            self-loops) used for sum aggregation.

        Returns:
          action_probs: Tensor (num_nodes,), softmax distribution over nodes.
          value: Tensor (1,), state-value estimate.

        Raises:
          ValueError: if the actor logits contain NaN.
        """
        X = self._standardize(torch.cat([X_invariant, X_variant], dim=1))  # (num_nodes, 4)
        h = F.relu(self.gcn_fc(X))          # (num_nodes, hidden_dim)
        h = torch.matmul(A, h)              # one-hop sum aggregation
        h = F.relu(self.gcn_fc2(h))         # (num_nodes, hidden_dim)

        graph_emb = torch.sum(h, dim=0)                    # (hidden_dim,)
        env_emb = F.relu(self.env_fc(env_vector))          # (hidden_dim,)
        combined = torch.cat([graph_emb, env_emb], dim=0)  # (hidden_dim * 2,)
        if torch.isnan(combined).any():
            print("⚠️ 输入 combined 出现 NaN 值！")
            print(f"combined is {combined}")
            print(f"env_emb is {env_emb}")
            print(f"graph_emb is {graph_emb}")
        if torch.isinf(combined).any():
            print("⚠️ 输入 combined 出现 Inf 值！")

        logits = self.actor_mlp(combined)  # (num_nodes,)
        if torch.isnan(logits).any():
            print("⚠️ logits 出现 NaN 值！")
            raise ValueError("logits 出现 NaN 值！")
        if torch.isinf(logits).any():
            print("⚠️ logits 出现 Inf 值！")
        logits = logits - logits.max(dim=-1, keepdim=True)[0]
        action_probs = F.softmax(logits, dim=-1)
        if torch.isnan(action_probs).any() or torch.isinf(action_probs).any():
            print("⚠️ 策略分布出现异常！")

        value = self.critic_mlp(combined)  # (1,)
        return action_probs, value










class Agent:
    """Firefighting agent that allocates aircraft across regions with a Dirichlet policy.

    The A2CNetwork's softmax output, scaled by ``dirichlet_concentration``, parameterises
    a Dirichlet distribution whose samples are per-node allocation weights summing to 1.
    """

    def __init__(self, num_nodes, env_dim=5, hidden_dim=256, lr=0.001):
        self.J = 20             # total number of firefighting aircraft
        self.capability = 100   # fire intensity removed per aircraft
        self.network = A2CNetwork(num_nodes, env_dim, hidden_dim)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)

        self.gamma = 0.9                      # TD discount factor
        self.dirichlet_concentration = 10.0   # scale applied to the policy probabilities
        # NOTE(review): 1e-4 is far tighter than the usual PPO clip (~0.2); kept as-is.
        self.clip_epsilon = 1e-4
        self.entropy_coef = 0.1
        self.max_grad_norm = 1.0

        self.loss_history = []

    def get_state(self, simulation):
        """Build network inputs from the latest record in ``simulation.history``.

        If the last history entry is the 'end' marker, the record before it is used.

        Returns:
          env_vector: FloatTensor (5,) = [wind_x, wind_y, wind_strength, dryness,
            temperature], or zeros when no wind direction is recorded (t = 0).
          X_invariant: FloatTensor (num_nodes, 2), node positions (x, y).
          X_variant: FloatTensor (num_nodes, 2), node (K, H).
          A: FloatTensor (num_nodes, num_nodes), binary adjacency plus identity
            (self-loops), not normalised.
        """
        history = simulation.history
        record = history[-2] if history[-1] == 'end' else history[-1]
        _, env_factors, state, _ = record
        num_nodes = simulation.num_regions

        if env_factors['wind_direction'] is not None:
            wind_dir = env_factors['wind_direction']
            env_vector = torch.tensor([wind_dir[0], wind_dir[1], env_factors['wind_strength'],
                                       env_factors['dryness'], env_factors['temperature']],
                                      dtype=torch.float)
        else:
            env_vector = torch.zeros(5, dtype=torch.float)

        X_invariant = torch.tensor([[state[i]['pos'][0], state[i]['pos'][1]] for i in range(num_nodes)],
                                   dtype=torch.float)
        X_variant = torch.tensor([[state[i]['K'], state[i]['H']] for i in range(num_nodes)],
                                 dtype=torch.float)

        A = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
        for i in range(num_nodes):
            for j, _edge in simulation.adjacency[i]:
                A[i, j] = 1.0
        A += torch.eye(num_nodes)  # self-loops
        return env_vector, X_invariant, X_variant, A

    def _policy(self, probs):
        """Dirichlet distribution built from (sanitised, strictly positive) policy probabilities."""
        probs = torch.clamp(torch.nan_to_num(probs, nan=1e-6), min=1e-6)
        return torch.distributions.Dirichlet(probs * self.dirichlet_concentration)

    def action(self, simulation):
        """Sample an allocation for the current simulation state.

        Returns:
          (action, dist): a reparameterised Dirichlet sample (weights summing to 1)
          and the distribution it was drawn from.
        """
        env_vector, X_invariant, X_variant, A = self.get_state(simulation)
        action_probs, _ = self.network(env_vector, X_invariant, X_variant, A)
        dist = self._policy(action_probs)
        action = dist.rsample()
        return action, dist

    def get_value(self, simulation):
        """Return only the critic's value estimate for the current state."""
        env_vector, X_invariant, X_variant, A = self.get_state(simulation)
        _, value = self.network(env_vector, X_invariant, X_variant, A)
        return value

    def update_policy(self, transitions):
        """One clipped-ratio (PPO-style) actor-critic update over an episode's transitions.

        Each transition is
            (env_vector, X_invariant, X_variant, A, action, old_log_prob, reward,
             next_env_vector, next_X_invariant, next_X_variant, next_A, done)
        with ``done`` a float tensor (1.0 = terminal).

            target    = r + gamma * V(s') * (1 - done)         (no gradient)
            advantage = target - V(s)                          (detached for the actor)
            actor     = -min(ratio * adv, clip(ratio) * adv) - entropy_coef * entropy
            critic    = SmoothL1(V(s), target)

        Does nothing when ``transitions`` is empty.
        """
        if not transitions:
            return None

        actor_losses = []
        critic_losses = []
        for (env_vector, X_invariant, X_variant, A, action_taken, old_log_prob, reward,
             next_env_vector, next_X_invariant, next_X_variant, next_A, done) in transitions:
            # The stored action / log-prob may still be attached to the rollout graph
            # (rsample). Detach them so the gradient flows only through the freshly
            # computed log_prob; otherwise both sides of the ratio share the gradient
            # and the actor update largely cancels out.
            action_taken = action_taken.detach()
            old_log_prob = old_log_prob.detach()

            probs, value = self.network(env_vector, X_invariant, X_variant, A)
            concentration = torch.clamp(probs, min=1e-6) * self.dirichlet_concentration
            dist = torch.distributions.Dirichlet(concentration)
            log_prob = dist.log_prob(action_taken)

            with torch.no_grad():
                _, next_value = self.network(next_env_vector, next_X_invariant, next_X_variant, next_A)
                target = reward + self.gamma * next_value * (1 - done)
            advantage = (target - value).detach()

            ratio = torch.exp(log_prob - old_log_prob)
            clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
            actor_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
            critic_loss = F.smooth_l1_loss(value, target)
            entropy = dist.entropy()  # exploration bonus

            actor_losses.append(actor_loss - entropy * self.entropy_coef)
            critic_losses.append(critic_loss)

        loss = torch.stack(actor_losses).mean() + torch.stack(critic_losses).mean()
        self.loss_history.append(loss.item())

        check_parameters(self.network.actor_mlp)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip the whole network (previously only actor_mlp was clipped).
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=self.max_grad_norm)
        self.optimizer.step()
        check_parameters(self.network.gcn_fc)


def train_agent(num_regions, num_episodes=1, max_steps=5):
    """Train an Agent on freshly regenerated WildfireSimulation episodes.

    Each episode re-initialises the simulation, rolls out up to ``max_steps``
    steps (one transition per step) and then calls ``agent.update_policy`` once.

    Step reward: r = -(H_total_{t+1} - H_total_t) / max_health_total - 0.01, i.e.
    the fraction of total fuel burnt this step plus a small time penalty,
    standardised with the running mean/std of ``step_rewards``. When the fire is
    contained (``sim.is_end()``), the step reward is replaced by
    ``num_protected - num_regions / 2``, where a region is protected if it lost at
    most 10% of its fuel in the final state. That value is standardised with the
    running ``final_rewards`` statistics.

    Returns:
      (agent, episode_rewards, agent.loss_history, steps_to_contain), where
      steps_to_contain lists the step at which each contained episode ended.
    """
    from types import SimpleNamespace

    sim = WildfireSimulation(num_regions)
    agent = Agent(num_nodes=sim.num_regions, env_dim=5, hidden_dim=64, lr=0.1)
    episode_rewards = []
    # Seed values for the running reward-normalisation statistics.
    final_rewards = [0.1, 0.5, 0.7, 0.9, 1.0, 0.1, 0.5, 0.01, 0.05]
    step_rewards = [-0.1, -0.2, -0.1, -0.4, -0.5, -0.7, -0.1, -0.2, -0.2, -0.3, -0.4, -0.02]
    steps_to_contain = []

    for episode in range(num_episodes):
        verbose = episode % 100 == 0
        if verbose:
            print(f"===== Episode {episode} =====")
        # Reset the simulation (new graph, new fire origin, empty history).
        sim.__init__(num_regions=sim.num_regions)
        sim.history = []
        max_health_total = np.sum([sim.nodes[node].max_health for node in sim.nodes])

        initial_env = {
            'wind_direction': None,
            'wind_strength': None,
            'dryness': None,
            'temperature': None
        }
        initial_state = {node_id: {'pos': node.pos.copy(), 'H': node.H, 'K': node.K}
                         for node_id, node in sim.nodes.items()}
        sim.history.append((0, initial_env, initial_state, None))

        transitions = []
        total_reward = 0.0
        t = 0
        done = False
        while t < max_steps and not done:
            t += 1
            # s_t: the state the agent acts on (last history record before stepping).
            env_vector, X_invariant, X_variant, A = agent.get_state(sim)
            # Sample a_t once and make sim.next apply exactly this action, so the stored
            # transition holds the action (and log-prob) that really drove the environment.
            action, dist = agent.action(sim)
            action = action.detach()
            old_log_prob = dist.log_prob(action).detach()
            fixed_policy = SimpleNamespace(
                J=agent.J,
                capability=agent.capability,
                action=lambda simulation, a=action, d=dist: (a, d),
            )
            sim.next(t, fixed_policy)

            s_state = sim.history[-2][2]
            ns_state = sim.history[-1][2]
            # s_{t+1}: the state produced by this step.
            next_env_vector, next_X_invariant, next_X_variant, next_A = agent.get_state(sim)

            H_total_t = np.sum([s_state[node]['H'] for node in s_state])
            H_total_next = np.sum([ns_state[node]['H'] for node in ns_state])
            r = -(H_total_next - H_total_t) / max_health_total - 0.01
            r = (r - np.mean(step_rewards)) / (np.std(step_rewards) + 1e-4)
            step_rewards.append(r)

            if sim.is_end():
                steps_to_contain.append(t)
                if verbose:
                    print(f"game end at {t} session, and get final reward")
                    sim.plot_at_t(t)
                # Count protected regions in the final (post-step) state.
                num_protected = np.sum([
                    1 for node in ns_state
                    if (sim.nodes[node].max_health - ns_state[node]['H']) / sim.nodes[node].max_health <= 0.10
                ])
                r = num_protected - sim.num_regions / 2.0
                normalized_r = (r - np.mean(final_rewards)) / (np.std(final_rewards) + 1e-8)
                final_rewards.append(normalized_r)
                r = normalized_r
                done = True

            total_reward += r
            transitions.append((env_vector, X_invariant, X_variant, A,
                                action, old_log_prob, torch.tensor(r, dtype=torch.float),
                                next_env_vector, next_X_invariant, next_X_variant, next_A,
                                torch.tensor(float(done))))

        if transitions:
            agent.update_policy(transitions)
        episode_rewards.append(total_reward)
        if verbose:
            print(f"Episode {episode} total reward: {total_reward}")
            if agent.loss_history:
                print(f"Episode {episode} agent prev loss: {agent.loss_history[-1]}")

    return agent, episode_rewards, agent.loss_history, steps_to_contain

In [None]:
def set_global_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) and make cuDNN deterministic."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True   # choose deterministic cuDNN kernels
    cudnn.benchmark = False      # disable autotuning, which may pick different kernels per run


SEED = 123           # global RNG seed
NUM_REGIONS = 20     # regions per generated graph
NUM_EPISODES = 1000  # training episodes
MAX_STEPS = 100      # step cap per episode

set_global_random_seed(SEED)

trained_agent, rewards, losses, steps_to_peace = train_agent(
    NUM_REGIONS, num_episodes=NUM_EPISODES, max_steps=MAX_STEPS
)


In [None]:

def moving_average(values, window):
    """Moving average of ``values`` over ``window`` points (numpy 'valid' mode).

    ``np.convolve`` swaps its arguments when the data is shorter than the kernel,
    which silently produces a meaningless curve, so the window is clamped to the
    data length. Empty input yields an empty array.
    """
    values = np.asarray(values, dtype=float)
    window = min(int(window), len(values))
    if window <= 0:
        return values
    return np.convolve(values, np.ones(window) / window, mode='valid')


def plot_smoothed(values, window, ylabel, title):
    """Plot the moving average of ``values`` against episode index."""
    plt.plot(moving_average(values, window))
    plt.xlabel("Episode")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


plot_smoothed(losses, 200, "loss", "Training loss Curve")

window = 600  # also used by the strided moving-average cell below
plot_smoothed(rewards, window, "total rewards each ep", "total rewards each Curve")
# steps_to_peace only has entries for contained episodes, so it may be shorter than `window`.
plot_smoothed(steps_to_peace, window, "total steps_to_peace each ep", "total steps_to_peace each Curve")

In [None]:
step_length = 5   # stride between consecutive averaging windows
window = 600      # averaging window (same as the moving-average plots above)

# Clamp the window so a short run (fewer contained episodes than `window`)
# still yields points instead of an empty plot.
effective_window = max(1, min(window, len(steps_to_peace)))
smooth_steps_to_peace = [
    np.mean(steps_to_peace[i:i + effective_window])
    for i in range(0, len(steps_to_peace) - effective_window + 1, step_length)
]

# x value = index of the first entry in each window
x_values = range(0, len(smooth_steps_to_peace) * step_length, step_length)

plt.plot(x_values, smooth_steps_to_peace, label=f"Moving Avg (Window={effective_window}, Step={step_length})")
plt.xlabel("Episode")
plt.ylabel("Total steps_to_peace each ep")
plt.title("Total steps_to_peace each Curve")
plt.legend()
plt.show()

In [None]:
# Scratch check: mean over a stacked pair of 1-element tensors, (60 + 900) / 2 = 480.
stacked_mean = torch.stack([torch.tensor([60.0]), torch.tensor([900.0])]).mean()
stacked_mean