In [16]:
from gym_interface import Agent, State
import copy
import math
import random
import socket
import time
from collections import deque, namedtuple
from typing import Dict, Iterable, List, Literal, Optional, Union
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from distutils.util import strtobool
# from rlmodel.utils.utils import print_args, print_box, connected_to_internet
import wandb
import setproctitle
from pathlib import Path

import os, sys
from PPO import PPO

*自定义处理函数*

In [17]:
def parse_Input(action) -> str:
    """Convert a raw policy action into the simulator's XML command string.

    Args:
        action: Indexable pair ``(raw_velocity, raw_direction)``. Both entries
            are scaled as if they lie in [-1, 1]; anything outside that range
            is clamped after scaling.

    Returns:
        A ``<c>...</c>`` command with ``targetVel`` in [0, 20] and
        ``targetDir`` in [-pi, pi], each formatted with four decimals.
        If the action cannot be read as two floats, or either component is
        NaN, the zero command is returned instead.
    """
    zero_cmd = "<c><targetVel><float>0.0</float></targetVel><targetDir><float>0.0</float></targetDir></c>"

    try:
        raw_velocity = float(action[0])
        raw_direction = float(action[1])
    except Exception:
        # Best-effort fallback, but unlike a bare ``except`` this lets
        # KeyboardInterrupt / SystemExit propagate.
        return zero_cmd

    # NaN slips through max()/min() clamping (every comparison is False), which
    # previously turned a NaN action into "full speed, direction pi".
    if math.isnan(raw_velocity) or math.isnan(raw_direction):
        return zero_cmd

    # Scale [-1, 1] -> [0, 20] and [-1, 1] -> [-pi, pi], then clamp.
    velocity = max(0, min(20, (raw_velocity + 1) * 10))
    direction = max(-np.pi, min(np.pi, raw_direction * np.pi))

    return (
        f"<c><targetVel><float>{velocity:.4f}</float></targetVel>"
        f"<targetDir><float>{direction:.4f}</float></targetDir></c>"
    )

In [18]:
def parse_Output(state: List[Dict[str, any]]) -> dict:
    """Flatten a list of state fragments into one dictionary.

    Every ``'carInfo'`` value found in the fragments is collected, in order,
    into a list stored under ``'carInfo'``. Any other key is copied through
    as-is, with later fragments overwriting earlier ones.
    """
    cars = []
    merged = {}
    for fragment in state:
        for key, value in fragment.items():
            if key != 'carInfo':
                merged[key] = value
                continue
            cars.append(value)
    merged['carInfo'] = cars
    return merged

In [19]:
def cal_Reward(state:Dict[str, any]) -> float:
    """Shaped pursuit reward computed from the simulator state only (no action).

    The reward is the sum of: distance closed since the previous call, a bonus
    for velocity pointing at the target, a penalty for deviating from a
    distance-dependent desired speed, penalties for speed and heading changes
    between calls, and a one-off bonus when the distance drops below 10.

    History is kept in the module-level ``prev_distance``, ``prev_speed`` and
    ``prev_heading``. Each is (re)initialised from the current state when it is
    undefined or ``None``, so a caller can reset the history between episodes by
    assigning ``None`` to them.

    Args:
        state: Parsed state with a ``'carInfo'`` list of two car dicts, each
            holding ``baseInfo.side``, ``position`` and ``velocity`` (x/y/z).

    Returns:
        The total reward, or ``0.0`` when fewer than two cars are present or
        position/velocity cannot be read.
    """
    global prev_distance, prev_speed, prev_heading

    if 'carInfo' not in state or len(state['carInfo']) < 2:
        return 0.0

    # 1. The entry whose side is 1 is treated as our car, the other as the enemy.
    if state['carInfo'][0]['baseInfo']['side'] == 1:
        self_car = state['carInfo'][0]
        enemy_car = state['carInfo'][1]
    else:
        self_car = state['carInfo'][1]
        enemy_car = state['carInfo'][0]

    axes = ('x', 'y', 'z')
    try:
        # Index by key (as the other reward/termination functions do) instead
        # of relying on the insertion order of the dicts.
        sx, sy, sz = (self_car['position'][axis] for axis in axes)
        ex, ey, ez = (enemy_car['position'][axis] for axis in axes)
        svx, svy, svz = (self_car['velocity'][axis] for axis in axes)
        # Enemy velocity is not used below; it is still read so that a state
        # without it keeps yielding 0.0.
        _ = [enemy_car['velocity'][axis] for axis in axes]
    except Exception:
        return 0.0

    # 2. Basic physical quantities.
    dx, dy, dz = ex - sx, ey - sy, ez - sz
    distance = math.sqrt(dx*dx + dy*dy + dz*dz)
    self_speed = math.sqrt(svx*svx + svy*svy + svz*svz)

    # 3. History: initialise anything that is missing or was reset to None.
    #    (prev_heading previously lacked the None check, so a reset raised
    #    TypeError at the heading-change computation.)
    module_state = globals()
    if module_state.get('prev_distance') is None:
        prev_distance = distance
    if module_state.get('prev_speed') is None:
        prev_speed = self_speed
    if module_state.get('prev_heading') is None:
        prev_heading = math.atan2(svy, svx) if self_speed > 0.1 else 0.0

    # 4. Absolute-distance term: currently disabled (kept for the debug print).
    distance_reward = 0

    # 5. Closing reward: positive when approaching, negative when receding.
    distance_change = prev_distance - distance
    proximity_reward = 1.00 * distance_change

    # 6. Heading alignment, based on the actual velocity direction.
    heading_reward = 0.0
    if distance > 1.0 and self_speed > 0.1:
        to_target = [dx/distance, dy/distance, dz/distance]
        velocity_dir = [svx/self_speed, svy/self_speed, svz/self_speed]
        cos_angle = sum(to_target[i] * velocity_dir[i] for i in range(3))
        heading_reward = 0.5 * max(0.0, cos_angle)

    # 7. Desired speed depends on the remaining distance.
    if distance > 3000:
        desired_speed = 15.0
    elif distance > 1000:
        desired_speed = 10.0
    elif distance > 200:
        desired_speed = 5.0
    else:
        desired_speed = 3.0
    speed_reward = -0.03 * abs(self_speed - desired_speed)

    # 8. Penalise the change in speed between consecutive calls.
    acc_penalty = -0.2 * abs(self_speed - prev_speed)

    # 9. Penalise the change in velocity heading (x/y plane), wrapped to [0, pi].
    heading_penalty = 0.0
    if self_speed > 0.1:
        current_heading = math.atan2(svy, svx)
        heading_change = abs(current_heading - prev_heading)
        if heading_change > math.pi:
            heading_change = 2 * math.pi - heading_change
        heading_penalty = -0.1 * heading_change
    else:
        # Too slow for a meaningful heading; carry the previous one forward.
        current_heading = prev_heading

    # 10. Per-step time penalty: currently disabled.
    time_penalty = 0

    # 11. Terminal bonus.
    # NOTE(review): cal_Termination in this notebook ends the episode at
    # distance < 50, so this 10-unit bonus is only reachable when a single step
    # jumps from >= 50 to < 10 - confirm the intended threshold.
    terminal_reward = 0.0
    is_terminal = False
    if distance < 10.0:
        terminal_reward = 100.0
        is_terminal = True
        print("✅ 成功拦截")

    # 12. Total.
    total_reward = (
        distance_reward +
        proximity_reward +
        heading_reward +
        speed_reward +
        acc_penalty +
        heading_penalty +
        time_penalty +
        terminal_reward
    )

    # 13. Update history for the next call.
    prev_distance = distance
    prev_speed = self_speed
    prev_heading = current_heading

    # Debug output: always on the terminal step, otherwise ~1% of calls.
    if is_terminal or np.random.random() < 0.01:
        print(f"距离奖励：{distance_reward:.3f}, 接近奖励：{proximity_reward:.3f}, 航向奖励：{heading_reward:.3f}, 速度奖励：{speed_reward:.3f}, 加速度惩罚：{acc_penalty:.3f}, 转向惩罚：{heading_penalty:.3f}, 时间惩罚：{time_penalty:.3f}, 终端奖励：{terminal_reward:.3f}, 总奖励：{total_reward:.3f}")

    return total_reward

In [20]:
def cal_reward_v1_5(state: dict) -> float:
    """Progress-based pursuit reward (v1.5).

    Reward per call = ``0.05 * (distance closed since the previous call)``
    minus a ``0.001`` step cost, plus ``500`` once the distance is below 50.

    The previous distance is kept in the module-level ``prev_distance``. It is
    (re)initialised from the current state when undefined or ``None``, and it
    is cleared after the terminal bonus is paid: cal_Termination ends the
    episode at the same ``distance < 50`` threshold, so the next call belongs
    to a new episode and must not be compared with the old one. Episodes that
    end for any other reason (e.g. a step limit) still need the caller to set
    ``prev_distance = None``.

    Args:
        state: Parsed state with a ``'carInfo'`` list of two car dicts holding
            ``baseInfo.side`` and ``position`` (x/y/z).

    Returns:
        The reward as a float; ``0.0`` when fewer than two cars are present or
        a position key is missing.
    """
    global prev_distance

    if 'carInfo' not in state or len(state['carInfo']) < 2:
        return 0.0

    # The entry whose side is 1 is treated as our car, the other as the enemy.
    if state['carInfo'][0]['baseInfo']['side'] == 1:
        self_car = state['carInfo'][0]
        enemy_car = state['carInfo'][1]
    else:
        self_car = state['carInfo'][1]
        enemy_car = state['carInfo'][0]

    try:
        sx, sy, sz = self_car['position']['x'], self_car['position']['y'], self_car['position']['z']
        ex, ey, ez = enemy_car['position']['x'], enemy_car['position']['y'], enemy_car['position']['z']
    except KeyError:
        return 0.0

    dx, dy, dz = ex - sx, ey - sy, ez - sz
    distance = math.sqrt(dx*dx + dy*dy + dz*dz)

    # First call, or history was reset: no progress can be measured yet.
    if globals().get('prev_distance') is None:
        prev_distance = distance

    # Core term: +0.05 per unit of distance closed.
    delta = prev_distance - distance
    reward = 0.05 * delta

    # Small per-step cost so that standing still is not free.
    reward -= 0.001

    if distance < 50:
        # Terminal bonus; forget the history so the next episode starts clean
        # instead of being charged for the jump back to the start distance.
        reward += 500.0
        prev_distance = None
    else:
        prev_distance = distance

    return float(reward)


In [21]:
def cal_reward_v2(state: dict) -> float:
    """Progress-based pursuit reward with heading and speed shaping (v2).

    Terms:
      1. ``0.05 * (distance closed since the previous call)``.
      2. Line-of-sight bonus ``0.2 * max(0, cos)`` between our velocity and
         the direction to the target (only when our speed exceeds 0.5).
      3. Speed shaping ``-0.02 * |speed - desired_speed|`` where the desired
         speed steps down as the distance shrinks.
      4. ``-0.001`` step cost and ``+500`` once the distance is below 50.

    The previous distance is kept in the module-level ``prev_distance``. It is
    (re)initialised when undefined or ``None`` and is cleared after the
    terminal bonus is paid, because cal_Termination ends the episode at the
    same ``distance < 50`` threshold. Episodes that end for any other reason
    still need the caller to set ``prev_distance = None``.

    Args:
        state: Parsed state with a ``'carInfo'`` list of two car dicts holding
            ``baseInfo.side``, ``position`` and (for our car) ``velocity``.

    Returns:
        The reward as a float; ``0.0`` when fewer than two cars are present or
        a position/velocity key is missing.
    """
    global prev_distance

    if 'carInfo' not in state or len(state['carInfo']) < 2:
        return 0.0

    # The entry whose side is 1 is treated as our car, the other as the enemy.
    if state['carInfo'][0]['baseInfo']['side'] == 1:
        self_car = state['carInfo'][0]
        enemy_car = state['carInfo'][1]
    else:
        self_car = state['carInfo'][1]
        enemy_car = state['carInfo'][0]

    try:
        sx, sy, sz = self_car['position']['x'], self_car['position']['y'], self_car['position']['z']
        ex, ey, ez = enemy_car['position']['x'], enemy_car['position']['y'], enemy_car['position']['z']
        vx, vy, vz = self_car['velocity']['x'], self_car['velocity']['y'], self_car['velocity']['z']
    except KeyError:
        return 0.0

    dx, dy, dz = ex - sx, ey - sy, ez - sz
    distance = math.sqrt(dx*dx + dy*dy + dz*dz)
    speed = math.sqrt(vx*vx + vy*vy + vz*vz)

    # First call, or history was reset: no progress can be measured yet.
    if globals().get('prev_distance') is None:
        prev_distance = distance

    # 1) Progress term: +0.05 per unit of distance closed.
    delta = prev_distance - distance
    reward = 0.05 * delta

    # 2) Line-of-sight bonus (auxiliary, deliberately small). Skipped at zero
    #    distance, where the direction to the target is undefined.
    r_los = 0.0
    if speed > 0.5 and distance > 0:
        to_target = np.array([dx, dy, dz], dtype=np.float32) / distance
        vel_dir = np.array([vx, vy, vz], dtype=np.float32) / (speed + 1e-6)
        cos_theta = float(np.clip(np.dot(to_target, vel_dir), -1.0, 1.0))
        # Only reward pointing towards the target; facing away earns nothing
        # rather than a negative value.
        r_los = 0.2 * max(0.0, cos_theta)

    # 3) Speed shaping: fast when far away, slower when close.
    if distance > 1500:
        desired_speed = 18.0
    elif distance > 500:
        desired_speed = 14.0
    elif distance > 150:
        desired_speed = 9.0
    else:
        desired_speed = 6.0
    r_speed = -0.02 * abs(speed - desired_speed)

    # Small per-step cost so that standing still is not free.
    reward -= 0.001

    terminal = distance < 50
    if terminal:
        reward += 500.0

    reward += r_los + r_speed

    # Forget the history after the terminal bonus so the next episode is not
    # charged for the jump back to its start distance.
    prev_distance = None if terminal else distance
    return float(reward)


In [22]:
def cal_Termination(state:Dict[str, any]) -> bool:
    """Report whether the episode should end because the two cars are close.

    The distance is symmetric, so it is computed directly between the first
    two ``carInfo`` entries; which of them is "ours" does not matter here.

    Args:
        state: Parsed state with a ``'carInfo'`` list whose first two entries
            each hold a ``position`` dict with ``x``/``y``/``z``.

    Returns:
        ``True`` when the 3-D distance between the cars is below 50, otherwise
        ``False``. Malformed or incomplete states (missing keys, fewer than
        two cars, non-numeric coordinates) are reported and never terminate.
    """
    try:
        cars = state['carInfo']
        pos_a = cars[0]['position']
        pos_b = cars[1]['position']
        distance = math.sqrt(
            (pos_b['x'] - pos_a['x'])**2
            + (pos_b['y'] - pos_a['y'])**2
            + (pos_b['z'] - pos_a['z'])**2
        )
    except (KeyError, IndexError, ValueError, TypeError) as e:
        # Bad data must not end (or crash) the episode.
        print(f"计算距离时出错: {e}")
        return False

    # Terminate once the cars are within 50 distance units of each other.
    return distance < 50


_算法参数配置_

In [23]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Make modules in the current working directory importable.
sys.path.append(os.path.abspath(os.getcwd()))

num_agents = 1
num_enemies = 1
episode_length = 100
save_interval = 1000
log_interval = 10  # train() prints a summary every `log_interval` episodes

# "<two levels above the working directory>/results" is the root for outputs.
results_root = Path(os.path.dirname(os.path.dirname(os.getcwd())) + "/results")
model_dir = results_root / "save"  # train() writes model checkpoints here

# Hyperparameter dictionary; it is passed to wandb.init(config=...) below.
# NOTE(review): the PPO agent in this notebook is constructed from separate
# literals (lr_actor, lr_critic, gamma, K_epochs, ...), so values such as
# "lr" and "gamma" here are not the ones visibly handed to PPO - confirm which
# set is authoritative.
all_args = {
    "algorithm_name": "ppo",
    "use_recurrent_policy": False,
    "use_naive_recurrent_policy": False,
    "share_policy": True,
    "use_wandb": True,
    "seed": 0,
    "use_centralized_V": True,
    "use_linear_lr_decay": True,
    "hidden_size": 16,
    "recurrent_N": 1,
    "act_space": 2,
    "obs_space": 12,
    "shared_obs_space": 12*num_agents,
    "model_dir": None,
    "episode_length": episode_length,
    "gamma": 0.98,
    "gae_lambda": 0.95,
    "use_gae": True,
    "clip_param": 0.2,
    "ppo_epoch": 15,
    "num_mini_batch": 1,
    "data_chunk_length": 10,
    "value_loss_coef": 0.5,
    "entropy_coef": 0.01,
    "max_grad_norm": 10.0,
    "huber_delta": 10.0,
    "use_max_grad_norm": True,
    "use_clipped_value_loss": True,
    "use_huber_loss": True,
    "use_popart": True,
    "use_valuenorm": False,
    "use_value_active_masks": True,
    "use_policy_active_masks": True,
    "lr": 7e-5,
    "critic_lr": 7e-4,
    "opti_eps": 1e-5,
    "weight_decay": 0,
    "gain": 0.01,
    "use_orthogonal": True,
    "use_feature_normalization": True,
    "use_ReLU": False,
    "stacked_frames": 1,
    "layer_N": 1,
    "n_rollout_threads": 1,
}

# Per-algorithm run directory; used as the W&B output directory.
run_dir = results_root / all_args["algorithm_name"]

config = {
    "all_args": all_args,
    "num_agents": num_agents,
    "num_enemies":num_enemies,
    "device": device,
    "run_dir": run_dir
}

*W&B记录训练日志*

In [24]:
# Start a Weights & Biases run when logging is enabled. The run is named
# "<algorithm>_seed<seed>_<timestamp>" and records all_args as its config.
#
# Offline note (no internet access): export WANDB_API_KEY, set
# WANDB_MODE=dryrun and WANDB_SAVE_CODE=true before this cell, then upload
# afterwards with `wandb sync wandb/<run_name>`.
if all_args["use_wandb"]:
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"{all_args['algorithm_name']}_seed{all_args['seed']}_{timestamp}"
    run = wandb.init(
        config=all_args,
        project="simplecq",
        notes=socket.gethostname(),
        name=run_name,
        dir=str(run_dir),
        reinit='return_previous',
    )

# Label the OS process as "<algorithm>@wsbbuaa".
setproctitle.setproctitle(f"{all_args['algorithm_name']}@wsbbuaa")

# Seed torch (CPU and all CUDA devices) and NumPy from the configured seed.
seed = all_args["seed"]
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

In [25]:
# PPO hyperparameters
state_dim = 12                      # length of the state vector fed to the policy
action_dim = 2                      # number of action components
lr_actor = 3e-4                     # actor learning rate
lr_critic = 3e-4                    # critic learning rate
gamma = 0.995                       # discount factor
K_epochs = 10                       # optimisation epochs per update
eps_clipping = 0.2                  # PPO clipping parameter
is_continuous_action_space = True   # continuous action space

# Build the PPO agent from the values above.
ppo_agent = PPO(
    state_dim=state_dim, action_dim=action_dim,
    lr_actor=lr_actor, lr_critic=lr_critic,
    gamma=gamma, K_epochs=K_epochs, eps_clipping=eps_clipping,
    is_continuous_action_space=is_continuous_action_space,
)


*端口及输出类型指定*

In [26]:
# Port handed to the Agent, and the declared type of each command field.
port = 40029
outputs_type = dict(targetDir="float", targetVel="float")

In [27]:
# Wire the environment wrapper to the notebook's processing callbacks.
# The reward in use is cal_reward_v1_5; episodes end via cal_Termination.
agent = Agent(
    port=port,
    outputs_type=outputs_type,
    process_input=parse_Input,
    process_output=parse_Output,
    reward_func=cal_reward_v1_5,
    end_func=cal_Termination,
)

In [28]:
def process_raw_state(raw_state: Dict[str, any]) -> tuple:
    """Turn a parsed state dictionary into the 12-element observation vector.

    Layout of the vector:
    [d, cos(theta_rel), sin(theta_rel), v_self, v_enemy, cos(d_psi),
     sin(d_psi), d_dot, v_rel_x, v_rel_y, dx/d, dy/d]

    Returns:
        ``(state_vector, normalized_state_vector)``. If anything goes wrong
        while reading the state, the error is printed and the same all-zero
        vector is returned in both positions.
    """
    try:
        # The entry whose side is 1 is treated as our car, the other as the enemy.
        own_idx = 0 if raw_state['carInfo'][0]['baseInfo']['side'] == 1 else 1
        self_car = raw_state['carInfo'][own_idx]
        enemy_car = raw_state['carInfo'][1 - own_idx]

        self_pos = self_car['position']
        enemy_pos = enemy_car['position']
        self_vel = self_car['velocity']
        enemy_vel = enemy_car['velocity']

        # Vector from our car to the enemy, and its 3-D length.
        dx, dy, dz = (enemy_pos[axis] - self_pos[axis] for axis in ('x', 'y', 'z'))
        distance = np.sqrt(dx**2 + dy**2 + dz**2)

        # Headings default to 0.0 when the state does not provide them.
        self_heading = self_car.get('heading', 0.0)
        enemy_heading = enemy_car.get('heading', 0.0)

        def wrap(angle):
            # Map an angle into [-pi, pi).
            return (angle + np.pi) % (2 * np.pi) - np.pi

        # Bearing of the target relative to our heading, and the heading difference.
        theta_rel = wrap(np.arctan2(dy, dx) - self_heading)
        delta_psi = wrap(enemy_heading - self_heading)

        # Speeds (3-D magnitude).
        v_self = np.sqrt(self_vel['x']**2 + self_vel['y']**2 + self_vel['z']**2)
        v_enemy = np.sqrt(enemy_vel['x']**2 + enemy_vel['y']**2 + enemy_vel['z']**2)

        # Relative velocity in the x/y plane.
        v_rel_x = enemy_vel['x'] - self_vel['x']
        v_rel_y = enemy_vel['y'] - self_vel['y']

        # Closing rate (positive when approaching).
        # NOTE(review): uses only x/y components but divides by the 3-D
        # distance - confirm this is intended if z can vary.
        d_dot = -(dx * v_rel_x + dy * v_rel_y) / distance if distance > 0.001 else 0.0

        features = [
            distance,
            np.cos(theta_rel),
            np.sin(theta_rel),
            v_self,
            v_enemy,
            np.cos(delta_psi),
            np.sin(delta_psi),
            d_dot,
            v_rel_x,
            v_rel_y,
            dx/distance if distance > 0 else 0,
            dy/distance if distance > 0 else 0,
        ]
        state_vector = np.array(features, dtype=np.float32)

        return state_vector, normalize_state(state_vector.copy())

    except Exception as e:
        print(f"状态处理错误: {e}")
        # Fall back to an all-zero observation.
        zero_vector = np.zeros(12, dtype=np.float32)
        return zero_vector, zero_vector
    
def normalize_state(state: np.ndarray) -> np.ndarray:
    """Return a scaled copy of the 12-element state vector.

    Per-index treatment:
      * 0 (distance): ``log1p(d) / log1p(10000)`` for positive d, else 0.
      * 3, 4 (speeds), 7 (closing rate), 8, 9 (relative velocity x/y):
        ``tanh(value / 20)``.
      * 10, 11 (unit direction x/y): clipped to [-1, 1].
      * 1, 2, 5, 6 (cos/sin terms): copied unchanged.

    Raises:
        ValueError: if ``state`` does not have exactly 12 elements.
    """
    if len(state) != 12:
        raise ValueError(f"状态向量维度应为12，但得到{len(state)}")

    scaled = state.copy()

    # Log compression keeps large distances within a usable range.
    scaled[0] = np.log1p(state[0]) / np.log1p(10000) if state[0] > 0 else 0.0

    # Velocity-like entries are squashed with tanh at a 20-unit scale.
    for idx in (3, 4, 7, 8, 9):
        scaled[idx] = np.tanh(state[idx] / 20.0)

    # Direction components are only clamped.
    for idx in (10, 11):
        scaled[idx] = np.clip(state[idx], -1.0, 1.0)

    return scaled

*训练函数*

In [None]:
def _car_separation(state_dict):
    """Return the 3-D distance between the first two cars, or None if fewer than two are listed."""
    if 'carInfo' not in state_dict or len(state_dict['carInfo']) < 2:
        return None
    pos_a = state_dict['carInfo'][0]['position']
    pos_b = state_dict['carInfo'][1]['position']
    dx = pos_b['x'] - pos_a['x']
    dy = pos_b['y'] - pos_a['y']
    dz = pos_b['z'] - pos_a['z']
    return (dx*dx + dy*dy + dz*dz) ** 0.5


def train(env: Agent, episodes, enable_log=True, log_dir=None):
    """Run the episodic PPO training loop.

    A new action is selected every ``decision_interval`` environment steps and
    repeated in between. The rewards of those steps are summed and written to
    the PPO buffer as one entry per decision. The agent is updated once per
    episode.

    Uses the module-level ``ppo_agent``, ``all_args``, ``model_dir``,
    ``log_interval`` and (when W&B is enabled) ``run``.

    Args:
        env: Environment wrapper providing ``reset()`` and ``step(action)``.
        episodes: Number of episodes to run.
        enable_log: When true, print a progress percentage every 100 episodes.
        log_dir: Optional directory for a per-decision action log
            (``action0112_<timestamp>.txt``). ``None`` (default) writes no file.
    """
    decision_interval = 10
    max_steps_per_episode = 20000
    save_cooldown = 20        # minimum number of episodes between "best" checkpoints
    min_improve = 0.02        # minimum success-rate gain required to save again

    total_steps = 0
    Win = 0
    best_sr = 0
    # Must live outside the episode loop: re-initialising it every episode
    # made the cooldown check always pass.
    last_best_save_ep = -10**9
    episode_rewards_history = []
    success_rate_history = []

    log_path = None
    if log_dir is not None:
        os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_path = os.path.join(str(log_dir), f"action0112_{timestamp}.txt")

    for episode in range(episodes):
        episode_reward = 0.0
        episode_steps = 0

        # Start every episode with an empty PPO buffer.
        ppo_agent.buffer.clear()

        raw_state = env.reset()
        raw_state_dic = parse_Output(raw_state)
        log_state, current_state = process_raw_state(raw_state_dic)
        episode_min_distance = float("inf")
        episode_last_distance = None

        acc_reward = 0.0          # reward accumulated since the last decision
        reward_pending = False    # a decision was made and has no buffer reward yet
        episode_logs = []
        action = np.array([0.0, 0.0])  # placeholder until the first decision

        for step in range(max_steps_per_episode):
            # ---------- low-frequency decision ----------
            if step % decision_interval == 0:
                action = ppo_agent.select_action(current_state)
                reward_pending = True

                if log_path is not None:
                    raw_v, raw_d = float(action[0]), float(action[1])
                    velocity = (raw_v + 1) * 10
                    direction = raw_d * np.pi
                    episode_logs.append(
                        f"[{episode}] {' '.join(map(str, log_state))} {[velocity, direction]}"
                    )

            # ---------- advance the environment ----------
            next_raw_state, reward, done, _ = env.step(action)

            # An empty-string state marks an invalid step: its reward and
            # state are ignored, but the segment bookkeeping below still runs
            # so every decision gets exactly one reward entry.
            if next_raw_state != '':
                distance = _car_separation(next_raw_state)
                if distance is not None:
                    episode_min_distance = min(episode_min_distance, distance)
                    episode_last_distance = distance

                acc_reward += reward
                episode_reward += reward
                episode_steps += 1
                total_steps += 1

                log_state, current_state = process_raw_state(next_raw_state)

                if done:
                    Win += 1
                    ppo_agent.buffer.rewards.append(acc_reward)
                    ppo_agent.buffer.is_terminals.append(True)
                    acc_reward = 0.0
                    reward_pending = False
                    break

            # End of a decision segment: store its summed reward (non-terminal).
            if (step + 1) % decision_interval == 0:
                ppo_agent.buffer.rewards.append(acc_reward)
                ppo_agent.buffer.is_terminals.append(False)
                acc_reward = 0.0
                reward_pending = False

        # Step limit reached in the middle of a segment (only possible when
        # max_steps_per_episode is not a multiple of decision_interval).
        if reward_pending:
            ppo_agent.buffer.rewards.append(acc_reward)
            ppo_agent.buffer.is_terminals.append(False)
            acc_reward = 0.0

        # ---------- PPO update ----------
        ppo_agent.update()

        # ---------- optional per-episode action log ----------
        if log_path is not None and episode_logs:
            with open(log_path, "a") as f:
                f.writelines(line + "\n" for line in episode_logs)

        # ---------- statistics ----------
        episode_rewards_history.append(episode_reward)
        success_rate = Win / (episode + 1)
        success_rate_history.append(success_rate)

        if all_args.get("use_wandb", False):
            wandb.log({
                "episode_min_distance": episode_min_distance,
                "episode_last_distance": episode_last_distance,
                "episode": episode,
                "episode_reward": episode_reward,
                "episode_length": episode_steps,
                "success_rate": success_rate,
                "win_count": Win,
                "total_steps": total_steps,
                "avg_reward_100": np.mean(episode_rewards_history[-100:])
            }, step=episode)

        # ---------- save the best model ----------
        if (success_rate > best_sr + min_improve) and (episode - last_best_save_ep >= save_cooldown):
            best_sr = success_rate
            last_best_save_ep = episode
            os.makedirs(str(model_dir), exist_ok=True)
            ppo_agent.save(f"{model_dir}/ppo_best0112_{episode}.pth")

        # ---------- console summary ----------
        if episode % log_interval == 0:
            avg_reward = np.mean(episode_rewards_history[-100:])
            print(f"回合 {episode}/{episodes}")
            print(f"  累计奖励: {episode_reward:.2f}")
            print(f"  步数: {episode_steps}")
            print(f"  平均奖励(100): {avg_reward:.2f}")
            print(f"  成功率: {success_rate:.3f}")

        if episode % 100 == 0 and enable_log:
            print(f"已完成：{episode / episodes * 100:.1f}%")

    print(f"训练结束。成功：{Win} 次")

    if all_args.get("use_wandb", False):
        run.finish()


*运行训练*

In [30]:
# Train for 500 episodes against the configured agent/environment.
train(env=agent, episodes=500)


状态处理错误: list index out of range


ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接。