解决长链路多工具智能体中的 GRPO 训练崩溃问题
在 τ-bench airline(50 任务、多轮对话、多工具调用场景)上的系统性消融实验,通过创新的 PRM-Lite + LATA 联合方案,相较 Vanilla GRPO 基线实现 +37% 的整体 pass^1 提升。Swanlab训练曲线链接:https://swanlab.cn/@godstear/agentic-grpo-longhorizon?utm_source=website_qr&utm_medium=qr_scan
最优检查点(step 250):联合方案(PRM-Lite + LATA)达到 0.240 的整体 pass^1 —— 相较 Vanilla GRPO 基线(0.175)提升 +37%。
| 指标 | Vanilla | Turn-Discount | PRM-Lite | LATA | 联合方案 | 相较 Vanilla |
|---|---|---|---|---|---|---|
| 整体 pass^1 | 0.175 | 0.125 | 0.140 | 0.185 | 0.240 | +37% |
| 泛化能力 | 0.071 | 0.052 | 0.059 | 0.088 | 0.110 | +55% |
| 错误率 | 0.200 | 0.345 | 0.365 | 0.290 | 0.140 | −30% |
| 推理深度 (p50 tokens) | 72 | 245 | 169 | 183 | 313 | +334% |
泛化 pass^1 = (uncovered_seen × 24 + unseen × 10) / 34,排除训练集泄漏影响的核心指标。
观察:Turn-Discount 停滞(被动保护);LATA 通过 √L 归一化持续增长;联合方案在 step 250 达到峰值(0.240),step 300 温和回落至 0.225。
标准的 GRPO(Group Relative Policy Optimization) 应用于长链路、多工具对话智能体(τ-bench airline,50 任务,40 训练 / 10 测试)时会发生灾难性的训练崩溃。我们识别出三个根本原因:
结果奖励是二元的(0/1)。当 group_size=8 时,群组容易达到全 0 或全 1 状态 → 优势方差 → 0 → 梯度消失。
40 个训练任务中有 16 个是 covered_seen(有 72B 教师轨迹覆盖)。策略记住教师模式,covered 性能虚高,而 uncovered/unseen 几乎为零。
线性长度归一化 advantage / L 惩罚长回复。策略学会"以量补质" —— 短推理 + 频繁工具试错 —— 导致 step 150 后崩溃。
关键发现:训练验证奖励不是可靠代理。Turn-Discount 报告验证奖励 0.80,但真实评测仅 0.125 —— 6.4 倍差距。
我们设计并验证了四个消融实验,每个针对特定的失效模式:
思路:通过指数衰减的 token 权重保护早期轮次推理。
机制:weight[t] = α^(L-1-t),其中 α=1.05,归一化至 mean(weight)=1。早期 token 获得更高优势,抑制晚期猜测。
结果:成功阻止崩溃形状(回复长度 −23% vs Vanilla −63%),但评测仍然较低(0.125),因为缺乏质量引导。
思路:用 1/√L 替换线性 1/L 归一化,保留长推理的边际激励。
机制:advantage_token = A / sqrt(L) 替代 A / L。当回复长度增长 4 倍时,每 token 梯度仅减半(Vanilla 中降至四分之一)。
结果:持续提升(0.155 → 0.185 → 0.190),错误率从 0.345 降至 0.290,但无质量信号时可见天花板。
思路:用密集的基于规则的过程奖励打破群组饱和。
机制:15 条手工规则(P1–P8 惩罚,B1–B7 奖励)提供连续的 [-0.5, +0.5] 信号。最终奖励 = outcome + 0.3 × process_score。
结果:成功消除 score/min = 0/1 死锁,但信号被轨迹级线性归一化稀释 —— 错误率反而恶化至 0.365。
思路:PRM-Lite 提供局部质量信号;LATA 的 √L 归一化确保这些信号不被回复长度淹没。
机制:每轮过程分数的惩罚/奖励通过 A/√L 传播到单个 token,使策略能够学习哪一轮错了,而不仅仅是整个轨迹是否成功。
结果:0.240 整体表现 —— 超越所有单组件基线。错误率唯一持续下降(0.170 → 0.140 → 0.120)。Unseen 任务表现转正并稳定。
核心洞察:价值不在于单独拥有过程奖励或更好的归一化 —— 而在于信号传播。PRM-Lite 生成局部信号;LATA 的 √L 提供传输通道。两者单独使用效果均不佳。
消融报告实证证明了长链路智能体 GRPO 的分解原理:
- 信号源(PRM-Lite):15 条手工规则提供密集的每轮质量信号
[-0.5, +0.5]。 - 信号通路(LATA):
advantage / √L替代advantage / L,防止回复长度稀释。 - 单独失效:PRM-Lite 单独(0.140 整体,0.365 错误)— 信号被淹没。LATA 单独(0.185 整体)— 无信号源。只有联合方案解锁 0.240。
此分解是模型无关的,适用于 τ-bench 之外的任何长形式 RL 任务。
完全可解释、零可训练参数的过程奖励模型:
- P1–P8 惩罚:占位符(−0.05)、冗余(−0.03)、错误重复(−0.04)、无推理(−0.05)
- B1–B7 奖励:恢复(+0.05)、数据链(+0.08)、读取多样性(+0.01)、思考奖励(条件触发)
- 反黑客防御:条件思考评分、基于 schema 的实体提取、n_tools > 8 时的长度惩罚
- Bypass 模式 + 融合内核 + TP=2 将每步内存峰值从 OOM 降至 73.2 GB,在 2×A800 上实现 7B 策略 + 72B-AWQ 模拟器。
- 离线优先:所有脚本注入
HF_HUB_OFFLINE=1,适用于隔离的 HPC 集群。 - Render-Twice-Diff SFT:一种模板无关的多轮工具调用损失掩码方法,避免 off-by-one token 错误。
| 实验 | Step | 整体 | 泛化 pass^1 | 错误率 | per_turn p50 | 备注 |
|---|---|---|---|---|---|---|
| Vanilla | 200 | 0.175 | 0.071 | 0.200 | 72 | 崩溃基线 |
| Turn-Discount | 250 | 0.125 | 0.052 | 0.345 | 245 | 被动保护 |
| PRM-Lite | 250 | 0.140 | 0.059 | 0.365 | 169 | 信号阻塞 |
| LATA | 250 | 0.185 | 0.088 | 0.290 | 183 | √L 增益 |
| 联合方案 | 250 | 0.240 | 0.110 | 0.140 | 313 | 最优检查点 |
| 假设 | 状态 | 证据 |
|---|---|---|
| H1: Turn-Discount 阻止推理崩溃 | ✅ 已验证 | 回复长度 −23%,无断崖 |
| H2: 联合方案打破群组饱和 | ✅ 已验证 | 300 步内 score/min 从未 0/1 |
| H3: 联合方案改善 OOD 泛化 | ✅ 已验证 | unseen 为正(0.15–0.175) |
| H4: LATA 优于 Turn-Discount | ✅ 已验证 | +0.060 整体,−0.055 错误 |
| H5: 联合方案 > max(单组件) | ✅ 已验证 | 0.240 > 0.185 > 0.140 > 0.125 |
📦 agentic-grpo-longhorizon/
├── ⚙️ configs/ # 所有实验的 Hydra YAML 配置
│ ├── turn_discount.yaml
│ ├── prm_lite.yaml
│ ├── lata.yaml
│ ├── prm_lite_lata.yaml
│ └── eval/ # 各实验评测配置
├── 💻 src/ # 核心源码
│ ├── 🌍 envs/ # τ-bench 包装器与工具配置
│ │ ├── 🐍 tau_bench_wrapper.py
│ │ ├── 🐍 tau_bench_interaction.py # PRM-Lite 规则引擎
│ │ └── 🐍 tau_bench_tools.py
│ ├── 📊 evaluation/
│ │ └── 🐍 pass_k_eval.py # 独立 pass@k 评测器
│ ├── 🧠 models/
│ │ └── 🐍 vllm_policy.py # vLLM 策略包装器
│ └── 🎓 training/
│ └── 🐍 sft_dataset.py # SFT 数据收集
├── 📜 scripts/
│ ├── 🚀 train/grpo/ # GRPO 训练启动脚本
│ │ ├── 📜 run_exp1_turn_discount.sh
│ │ ├── 📜 run_exp2_lata.sh
│ │ ├── 📜 run_exp3_prm_lite.sh
│ │ ├── 📜 run_exp4_prm_lite_lata.sh
│ │ └── 📜 run_vanilla.sh
│ ├── 📈 eval/ # 独立评测启动脚本
│ │ ├── 📜 eval_exp1_turn_discount.sh
│ │ ├── 📜 eval_exp2_lata.sh
│ │ ├── 📜 eval_exp3_prm_lite.sh
│ │ └── 📜 eval_exp4_prm_lite_lata.sh
│ ├── 🔧 train/sft/ # SFT 预热脚本
│ └── 🖥️ vllm_server/ # vLLM 服务启动脚本
├── 📚 docs/
│ └── 🔬 ablation/
│ ├── 📝 ablation_diagnosis_report.md # 完整诊断报告(≈800 行)
│ ├── 📝 ablation_plan.md # 实验设计手册
│ ├── 🖼️ ablation_comparison.png
│ └── 🖼️ ablation_progression.png
├── 🧪 experiments/ # 检查点、HF 导出、评测输出
├── 📄 requirements.txt
└── 🔨 setup.sh # 一键环境搭建
(实验过程中改了不少源码,以本仓库的verl框架和benchmark为标准)
# 一键搭建(conda + PyTorch 2.7 + CUDA 12.6 + 依赖)
bash setup.sh
conda activate agentrl
cd agentic-grpo-longhorizon
# 或手动安装:
pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cu126
pip install -r requirements.txt
cd ../tau-bench && pip install -e .
cd ../verl && pip install -e .
cd agentic-grpo-longhorizon# 示例:联合方案(PRM-Lite + LATA)
cd scripts/train/grpo
bash run_exp4_prm_lite_lata.sh
# 或:Vanilla GRPO 基线
bash run_vanilla.sh# 自动评测 step 200/250/300 检查点
cd scripts/eval
bash eval_exp4_prm_lite_lata.sh硬件:2×A800(80GB)。GPU 0 运行 7B 策略 vLLM;GPU 1 运行 72B-AWQ 用户模拟器 vLLM。
离线模式:所有脚本注入HF_HUB_OFFLINE=1和TRANSFORMERS_OFFLINE=1,适用于隔离环境。
| 📄 文档 | 📝 内容 |
|---|---|
docs/ablation/ablation_diagnosis_report.md |
主报告:训练曲线、评测数据、机制分析、假设验证 |
docs/ablation/ablation_plan.md |
实验设计手册:代码实现、PRM-Lite 规则集、黑客风险分析 |
docs/vanilla_grpo/vanilla_grpo_diagnosis.md |
Vanilla GRPO 崩溃诊断:三个根本原因、五个检查点分析 |
../agentic-grpo-longhorizon-blog.md |
🆕 技术博客:从训练崩溃到稳定收敛的完整复盘(PRM-Lite + LATA) |
- 训练框架: veRL 0.6.1 (FSDP + vLLM V1)
- 策略模型: Qwen2.5-7B-Instruct
- 用户模拟器: Qwen2.5-72B-Instruct-AWQ
- 评测基准: τ-bench airline (50 任务)
- 推理引擎: vLLM V1 with tool-call parsing (Hermes)
- 注意力机制: FlashAttention-2
为什么重要:大多数 RLHF/RLAIF 工作聚焦单轮问答或代码生成。本项目 tackle 更难的问题 —— 多轮、多工具、部分可观测的对话智能体 —— 在这里 vanilla GRPO 会灾难性失败。PRM-Lite + LATA 联合设计提供了一条有原理支撑、轻量且可解释的通往稳定训练的路径,无需昂贵的学习式奖励模型。

