[Feature] MAPPOLoss + IPPOLoss + MultiAgentGAE + ValueNorm#3748
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3748
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit a9245f2 with merge base 258dfad ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
I think you're right about making MA training more straightforward, but I have some concerns:
|
I think the layers of abstractions make sense. I think the value estimators would help because of the utilization of the collectors as well. For the compatibility issues, I can write up some test cases to ensure it doesn't impact existing algos. |
|
I'd like to file a request for the MAPPO loss if you are open to it. Centralization should happen over agents across agent groups. The current mappo loss (and all torchrl loss modules applied to MARL) only centralizes over agents in the same group. Theres a few settings where you may want to centralize across groups:
|
Thanks for raising this. I agree cross-group centralization is the right generalization for heterogeneous teams and ad-hoc teamwork. I’d prefer to keep this PR scoped to the current single-group MAPPO path, since the limitation comes from the existing MultiAgentMLP centralization behavior and this PR is already fairly broad. I’ll open a follow-up issue and take a stab at it separately. Feel free to provide feedback btw! My initial preference is the composable CrossGroupCritic TensorDictModule route, since it keeps the current API untouched and avoids a wider MultiAgentMLP refactor. |
|
@vmoens results for the training LGTM! |
|
@vmoens I think the CI is failing due to the commit: Add CUDA support for prioritized replay sampling. This landed on main before this branch was opened. None of the commits in this PR touch any C extension or build system files;the failure reproduces on main independently. Happy to confirm if you want me to rebase and rerun, but I don't think it's blocking for review. |
…nnotations Two small, self-contained compatibility fixes pulled out of the MAPPO PR so they can be reviewed independently. 1. ``torchrl/modules/tensordict_module/probabilistic.py``: ``ProbabilisticTensorDictModule.__init__`` gained a ``generator`` keyword in tensordict main, but the current stable release doesn't have it yet. We probe the signature once at import time and forward ``generator=`` only when the underlying class accepts it, so ``SafeProbabilisticModule`` keeps working against both tensordict versions. 2. ``torchrl/objectives/common.py``: ``LossModule.convert_to_functional`` was checking ``self.__class__.__annotations__`` directly. Python does *not* inherit class annotations through ``__annotations__``, so any subclass that didn't redeclare its parent's annotations got a noisy ``UserWarning`` per network on every instantiation (4 warnings per ``MAPPOLoss(...)``, 6 per ``DDPGLoss(...)``, etc.). The check now walks ``type(self).__mro__`` and unions the annotations from every base class. Co-authored-by: Cursor <cursoragent@cursor.com>
Adds small extension points the new multi-agent classes need from their
parent classes, with identity defaults so existing callers are
unaffected.
- ``torchrl/objectives/value/advantages.py``: ``GAE`` grows three hook
methods consumed by both ``forward`` and ``value_estimate``:
* ``_prepare_signals(reward, done, terminated, value)`` -- identity
by default; ``MultiAgentGAE`` overrides it to broadcast team-shared
signals along the agent dim.
* ``_broadcast_optional(tensor, value)`` -- applied to the
``truncated`` tensor along the ``auto_reset_env`` path.
* ``_normalize_advantage(adv)`` -- default reproduces the previous
global ``mean / std`` normalisation; ``MultiAgentGAE`` keeps the
agent dim independent.
Refactoring ``GAE.forward`` and ``GAE.value_estimate`` to call the
hooks costs ~5 LOC and lets the multi-agent subclass drop ~120 lines
of copy-pasted vec-GAE plumbing.
- ``torchrl/objectives/ppo.py``: ``PPOLoss.loss_critic`` gains a single
hook ``_critic_loss_inputs(target_return, state_value,
old_state_value)`` that lets subclasses transform the three tensors
uniformly before the ``distance_loss`` + ``_clip_value_loss`` block.
Default is identity -- exactly the previous behaviour. The MAPPO
``ValueNorm`` integration plugs in here so it composes cleanly with
``clip_value``, ``separate_losses``, and ``log_explained_variance``
instead of duplicating the whole body.
Co-authored-by: Cursor <cursoragent@cursor.com>
Adds first-class objectives for cooperative continuous-control MARL, plus the supporting infrastructure they need. ## What's new - ``torchrl.objectives.multiagent.MAPPOLoss`` -- centralised-critic, decentralised-actor PPO (Yu et al. 2022, https://arxiv.org/abs/2103.01955). Subclasses ``ClipPPOLoss``; defaults the value estimator to ``MultiAgentGAE``, defaults ``normalize_advantage_exclude_dims=(-2,)``, and optionally accepts a ``ValueNorm`` for the critic-stability trick from the paper (Table 13). - ``torchrl.objectives.multiagent.IPPOLoss`` -- independent-learner counterpart (de Witt et al. 2020, https://arxiv.org/abs/2011.09533). Each agent has its own local critic; no centralised state required. - ``torchrl.objectives.value.MultiAgentGAE`` -- ``GAE`` subclass that broadcasts team-shared ``reward`` / ``done`` / ``terminated`` (shape ``[*B, T, 1]``) across the agent dim before the vec-GAE recursion, so users don't have to manually replicate signals. Per-agent rewards pass through unchanged (competitive settings). New ``ValueEstimators.MAGAE`` enum entry. - ``torchrl.modules.ValueNorm`` / ``PopArtValueNorm`` / ``RunningValueNorm`` -- abstract ``ValueNorm`` (an ``nn.Module``) with two implementations: PopArt-style EMA (van Hasselt et al. 2019, https://arxiv.org/abs/1809.04474) and exact Welford running stats. Plugs into ``MAPPOLoss(value_norm=...)`` and normalises both the critic target and the prediction so the MSE / smooth-L1 distance stays on a fixed scale as reward scales drift. Composes correctly with the parent ``ClipPPOLoss`` features (``clip_value``, ``separate_losses``, ``log_explained_variance``) via the new ``_critic_loss_inputs`` hook. ## Design notes **Two classes, no centralized boolean flag.** The MAPPO / IPPO structural difference is small but the construction recipes differ in ways that matter (centralised critic vs. per-agent critic), so we expose them as separate named classes; the docstring on each spells out the recipe explicitly. **MAGAE dispatch in plain PPO / A2C / Reinforce.** Adding ``ValueEstimators.MAGAE`` to the enum would break every parent test that parametrises over ``list(ValueEstimators)`` unless every ``make_value_estimator`` knows the new value. We dispatch MAGAE through ``MultiAgentGAE`` in those losses (~5 lines per file). When the registry from pytorch#3780 lands, the explicit branches collapse to a single ``@register_value_estimator`` decorator. **ValueNorm placement.** Lives under ``torchrl/modules/`` rather than ``torchrl/objectives/utils/`` because it's a stateful learnable component that participates in ``.to(device)`` / ``state_dict()``. ## Verification - ``pytest test/objectives/test_mappo.py`` -- 25/25 passing. Includes regression tests for: * ``value_norm`` registered exactly once in ``state_dict()`` (no duplicate ``_value_norm_module.*`` keys); * ``state_dict()`` save / load round-trips the running stats; * ``clip_value`` + ``value_norm`` still produces ``value_clip_fraction`` in the output; * ``separate_losses=True`` + ``value_norm`` correctly detaches so critic-loss grads do not flow into actor params; * no spurious annotation warnings on instantiation. - ``pytest test/objectives/`` -- 7280 passing, 2319 skipped (no regressions). - ``examples/multiagent/mappo_vmas.py --algo mappo --frames 200_000`` provides a minimal end-to-end VMAS Navigation smoke recipe. Co-authored-by: Cursor <cursoragent@cursor.com>
|
@theap06 I took the liberty of pushing the review fixes directly to your branch (the PR was marked as
The registry refactor that was bundled into the original PR is now its own PR — #3780 — extended to convert every loss in New test coverage (5 regression tests) all targeting bugs from the review:
Full test suite: Diff is down from 1794+/171- to 1637+/38- across 24 files. |
|
This looks cool but am i missing something or the 2 losses in https://github.com/pytorch/rl/blob/main/torchrl/objectives/multiagent/mappo.py look the same? |
|
@matteobettini Nope, you're not missing anything! They're intentionally the same! The only real difference is what critic you pass in. MAPPOLoss expects a centralised critic, IPPOLoss expects a decentralised one, but the loss math is identical either way.Exposed them as separate classes mostly so the API makes it obvious which algorithm you're running. Happy to collapse it to IPPOLoss = MAPPOLoss or add some actual structural difference in a follow-up if you think that's cleaner, just lmk! |
Mmh i am not sure I am a big fan of 2 names for the same thing. The risk here is that people might think changing the loss name changes the loss while it does not actually change anything and the change has to be done somewhere else. |
|
@matteobettini I can draft up a follow up pr addressing this if you would like later tonight. |
What do you think? we could (1) drop IPPOLoss keeping only MAPPOLoss or (2) keep only one called |
I was thinking we just create a field called MultiPPOLoss and the user can choose. |


Context
Multi-agent RL is currently the weakest research surface in torchrl: the only multi-agent loss shipped is
QMixerLoss(DQN family, discrete actions). For cooperative continuous-control MARL — where most modern benchmarks live (SMAC, VMAS, PettingZoo MPE, Hanabi, Overcooked) — users have to hand-assembleClipPPOLoss+ manualset_keys(done=("agents", "done"), terminated=("agents", "terminated"))+ manualmake_value_estimator(GAE, ...). The existingsota-implementations/multiagent/mappo_ippo.pyrecipe shows what this boilerplate looks like.This PR adds MAPPO (Yu et al. 2022) and IPPO (de Witt et al. 2020) as first-class objectives, plus the two pieces of supporting infrastructure they need.
What's new
torchrl.objectives.multiagent.MAPPOLoss— centralised-critic, decentralised-actor PPO. SubclassesClipPPOLoss; defaults the value estimator toMultiAgentGAE, defaultsnormalize_advantage_exclude_dims=(-2,), and optionally accepts aValueNormfor the critic-stability trick from the paper.torchrl.objectives.multiagent.IPPOLoss— independent-learner counterpart. Each agent has its own local critic; no centralised state required.torchrl.objectives.value.MultiAgentGAE—GAEvariant that broadcasts team-sharedreward/done/terminated(shape[*B, T, 1]) across the agent dim before the vec-GAE recursion, so users don't have to manually replicate signals or overrideset_keys. NewValueEstimators.MAGAEenum entry.torchrl.modules.ValueNorm— PopArt-style running value normaliser (van Hasselt et al. 2019), used opt-in byMAPPOLoss. Yu et al. 2022 Table 13 credits this trick with the algorithm's strong SMAC results.Design notes
Two classes instead of a
centralized: boolflag. The structural code difference between MAPPO and IPPO is small (~20 lines), but I made them separate named classes rather than a single class with a flag because:from torchrl.objectives.multiagent import MAPPOLossis self-documenting; the docstring spells out the full recipe (centralised critic construction, etc.) for each algorithm independently.MAGAE dispatch in plain PPO / A2C / Reinforce. Adding
ValueEstimators.MAGAEto the enum would break every parent test that parametrises overlist(ValueEstimators)unless everymake_value_estimatorknows the new enum value. Two options: (a) update ~29 test parametrisations to skip MAGAE, or (b) have plain PPO / A2C / Reinforce dispatch MAGAE toMultiAgentGAE. I went with (b) — it's ~5 lines per file, leaves the enum exhaustive, and is the right thing semantically (any actor-critic with the right data shapes can use MAGAE).ValueNorm placement. Lives under
torchrl/modules/rather thantorchrl/objectives/utils/because it's a stateful learnable component that participates in.to(device)/state_dict(). Happy to move if reviewers prefer otherwise.Out of scope (follow-up)
sota-implementations/multiagent/mappo_ippo.pyto use the new classes — left untouched in this PR to keep the blast radius small; can be a one-line follow-up.Verification
pytest test/objectives/test_mappo.py— 16/16 passing. Synthetic-tensordict tests for forward shapes, backward, centralised-vs-decentralised critic semantics, share-params modes, ValueNorm convergence, and critic-loss bounded-ness under 10× reward inflation.pytest test/test_cost.py -k "ppo or qmixer or a2c or reinforce"— 2394/2394 passing (no regressions).test_cost.py— 8788 passing, 1 pre-existing unrelated failure (test_exploration_compile—torch.compile+torch.utils.mkldnndeprecation, no MAPPO involvement).examples/multiagent/mappo_vmas.py --algo mappo --frames 200_000provides a minimal end-to-end smoke recipe on VMAS Navigation.cc @matteobettini