Skip to content

[Feature] MAPPOLoss + IPPOLoss + MultiAgentGAE + ValueNorm#3748

Merged
vmoens merged 3 commits into
pytorch:mainfrom
theap06:feat/mappo-ippo
May 19, 2026
Merged

[Feature] MAPPOLoss + IPPOLoss + MultiAgentGAE + ValueNorm#3748
vmoens merged 3 commits into
pytorch:mainfrom
theap06:feat/mappo-ippo

Conversation

@theap06
Copy link
Copy Markdown
Contributor

@theap06 theap06 commented May 13, 2026

Context

Multi-agent RL is currently the weakest research surface in torchrl: the only multi-agent loss shipped is QMixerLoss (DQN family, discrete actions). For cooperative continuous-control MARL — where most modern benchmarks live (SMAC, VMAS, PettingZoo MPE, Hanabi, Overcooked) — users have to hand-assemble ClipPPOLoss + manual set_keys(done=("agents", "done"), terminated=("agents", "terminated")) + manual make_value_estimator(GAE, ...). The existing sota-implementations/multiagent/mappo_ippo.py recipe shows what this boilerplate looks like.

This PR adds MAPPO (Yu et al. 2022) and IPPO (de Witt et al. 2020) as first-class objectives, plus the two pieces of supporting infrastructure they need.

What's new

  • torchrl.objectives.multiagent.MAPPOLoss — centralised-critic, decentralised-actor PPO. Subclasses ClipPPOLoss; defaults the value estimator to MultiAgentGAE, defaults normalize_advantage_exclude_dims=(-2,), and optionally accepts a ValueNorm for the critic-stability trick from the paper.
  • torchrl.objectives.multiagent.IPPOLoss — independent-learner counterpart. Each agent has its own local critic; no centralised state required.
  • torchrl.objectives.value.MultiAgentGAEGAE variant that broadcasts team-shared reward / done / terminated (shape [*B, T, 1]) across the agent dim before the vec-GAE recursion, so users don't have to manually replicate signals or override set_keys. New ValueEstimators.MAGAE enum entry.
  • torchrl.modules.ValueNorm — PopArt-style running value normaliser (van Hasselt et al. 2019), used opt-in by MAPPOLoss. Yu et al. 2022 Table 13 credits this trick with the algorithm's strong SMAC results.

Design notes

Two classes instead of a centralized: bool flag. The structural code difference between MAPPO and IPPO is small (~20 lines), but I made them separate named classes rather than a single class with a flag because:

  • The recent feedback on (HER) was explicit about avoiding wrapper-in-wrapper / "sampler-in-sampler" APIs. A boolean flag on a single class is the same pattern shifted to losses.
  • from torchrl.objectives.multiagent import MAPPOLoss is self-documenting; the docstring spells out the full recipe (centralised critic construction, etc.) for each algorithm independently.

MAGAE dispatch in plain PPO / A2C / Reinforce. Adding ValueEstimators.MAGAE to the enum would break every parent test that parametrises over list(ValueEstimators) unless every make_value_estimator knows the new enum value. Two options: (a) update ~29 test parametrisations to skip MAGAE, or (b) have plain PPO / A2C / Reinforce dispatch MAGAE to MultiAgentGAE. I went with (b) — it's ~5 lines per file, leaves the enum exhaustive, and is the right thing semantically (any actor-critic with the right data shapes can use MAGAE).

ValueNorm placement. Lives under torchrl/modules/ rather than torchrl/objectives/utils/ because it's a stateful learnable component that participates in .to(device) / state_dict(). Happy to move if reviewers prefer otherwise.

Out of scope (follow-up)

  • HAPPO / sequential update scheme (Kuba et al. 2022)
  • Multi-Agent Transformer (MAT)
  • Refactoring sota-implementations/multiagent/mappo_ippo.py to use the new classes — left untouched in this PR to keep the blast radius small; can be a one-line follow-up.

Verification

  • pytest test/objectives/test_mappo.py — 16/16 passing. Synthetic-tensordict tests for forward shapes, backward, centralised-vs-decentralised critic semantics, share-params modes, ValueNorm convergence, and critic-loss bounded-ness under 10× reward inflation.
  • pytest test/test_cost.py -k "ppo or qmixer or a2c or reinforce" — 2394/2394 passing (no regressions).
  • Full test_cost.py — 8788 passing, 1 pre-existing unrelated failure (test_exploration_compiletorch.compile + torch.utils.mkldnn deprecation, no MAPPO involvement).
  • examples/multiagent/mappo_vmas.py --algo mappo --frames 200_000 provides a minimal end-to-end smoke recipe on VMAS Navigation.

cc @matteobettini

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3748

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit a9245f2 with merge base 258dfad (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 13, 2026
@theap06 theap06 force-pushed the feat/mappo-ippo branch from 0d8dfea to d1f1eb0 Compare May 13, 2026 09:47
@Xmaster6y
Copy link
Copy Markdown
Contributor

I think you're right about making MA training more straightforward, but I have some concerns:

  • MultiAgentGAE.forward seems to duplicate most of GAE.forward; maybe an additional level of abstraction is needed.
  • ValueNorm should be more generic and less tied to MAPPO
  • We might need a registry for value estimators instead of enums
  • We should maybe handle/consider potential compatibility issues with MAGAE for other algs

@theap06
Copy link
Copy Markdown
Contributor Author

theap06 commented May 13, 2026

I think you're right about making MA training more straightforward, but I have some concerns:

  • MultiAgentGAE.forward seems to duplicate most of GAE.forward; maybe an additional level of abstraction is needed.
  • ValueNorm should be more generic and less tied to MAPPO
  • We might need a registry for value estimators instead of enums
  • We should maybe handle/consider potential compatibility issues with MAGAE for other algs

I think the layers of abstractions make sense. I think the value estimators would help because of the utilization of the collectors as well. For the compatibility issues, I can write up some test cases to ensure it doesn't impact existing algos.

@itwasabhi
Copy link
Copy Markdown
Contributor

I'd like to file a request for the MAPPO loss if you are open to it. Centralization should happen over agents across agent groups. The current mappo loss (and all torchrl loss modules applied to MARL) only centralizes over agents in the same group.

Theres a few settings where you may want to centralize across groups:

  • heterogenous agents (differenting observation/action specs) make sense to keep agents seperated
  • ad-hoc teamwork, where agents in one group are (for example) following a pre-defined policy.

@theap06
Copy link
Copy Markdown
Contributor Author

theap06 commented May 15, 2026

I'd like to file a request for the MAPPO loss if you are open to it. Centralization should happen over agents across agent groups. The current mappo loss (and all torchrl loss modules applied to MARL) only centralizes over agents in the same group.

Theres a few settings where you may want to centralize across groups:

  • heterogenous agents (differenting observation/action specs) make sense to keep agents seperated
  • ad-hoc teamwork, where agents in one group are (for example) following a pre-defined policy.

Thanks for raising this. I agree cross-group centralization is the right generalization for heterogeneous teams and ad-hoc teamwork. I’d prefer to keep this PR scoped to the current single-group MAPPO path, since the limitation comes from the existing MultiAgentMLP centralization behavior and this PR is already fairly broad.

I’ll open a follow-up issue and take a stab at it separately. Feel free to provide feedback btw! My initial preference is the composable CrossGroupCritic TensorDictModule route, since it keeps the current API untouched and avoids a wider MultiAgentMLP refactor.

@theap06
Copy link
Copy Markdown
Contributor Author

theap06 commented May 18, 2026

@vmoens results for the training
image
image

LGTM!

@theap06
Copy link
Copy Markdown
Contributor Author

theap06 commented May 19, 2026

@vmoens I think the CI is failing due to the commit: Add CUDA support for prioritized replay sampling. This landed on main before this branch was opened. None of the commits in this PR touch any C extension or build system files;the failure reproduces on main independently. Happy to confirm if you want me to rebase and rerun, but I don't think it's blocking for review.

vmoens and others added 3 commits May 19, 2026 10:42
…nnotations

Two small, self-contained compatibility fixes pulled out of the MAPPO
PR so they can be reviewed independently.

1. ``torchrl/modules/tensordict_module/probabilistic.py``:
   ``ProbabilisticTensorDictModule.__init__`` gained a ``generator``
   keyword in tensordict main, but the current stable release doesn't
   have it yet. We probe the signature once at import time and forward
   ``generator=`` only when the underlying class accepts it, so
   ``SafeProbabilisticModule`` keeps working against both tensordict
   versions.

2. ``torchrl/objectives/common.py``: ``LossModule.convert_to_functional``
   was checking ``self.__class__.__annotations__`` directly. Python does
   *not* inherit class annotations through ``__annotations__``, so any
   subclass that didn't redeclare its parent's annotations got a noisy
   ``UserWarning`` per network on every instantiation (4 warnings per
   ``MAPPOLoss(...)``, 6 per ``DDPGLoss(...)``, etc.). The check now
   walks ``type(self).__mro__`` and unions the annotations from every
   base class.

Co-authored-by: Cursor <cursoragent@cursor.com>
Adds small extension points the new multi-agent classes need from their
parent classes, with identity defaults so existing callers are
unaffected.

- ``torchrl/objectives/value/advantages.py``: ``GAE`` grows three hook
  methods consumed by both ``forward`` and ``value_estimate``:
    * ``_prepare_signals(reward, done, terminated, value)`` -- identity
      by default; ``MultiAgentGAE`` overrides it to broadcast team-shared
      signals along the agent dim.
    * ``_broadcast_optional(tensor, value)`` -- applied to the
      ``truncated`` tensor along the ``auto_reset_env`` path.
    * ``_normalize_advantage(adv)`` -- default reproduces the previous
      global ``mean / std`` normalisation; ``MultiAgentGAE`` keeps the
      agent dim independent.

  Refactoring ``GAE.forward`` and ``GAE.value_estimate`` to call the
  hooks costs ~5 LOC and lets the multi-agent subclass drop ~120 lines
  of copy-pasted vec-GAE plumbing.

- ``torchrl/objectives/ppo.py``: ``PPOLoss.loss_critic`` gains a single
  hook ``_critic_loss_inputs(target_return, state_value,
  old_state_value)`` that lets subclasses transform the three tensors
  uniformly before the ``distance_loss`` + ``_clip_value_loss`` block.
  Default is identity -- exactly the previous behaviour. The MAPPO
  ``ValueNorm`` integration plugs in here so it composes cleanly with
  ``clip_value``, ``separate_losses``, and ``log_explained_variance``
  instead of duplicating the whole body.

Co-authored-by: Cursor <cursoragent@cursor.com>
Adds first-class objectives for cooperative continuous-control MARL,
plus the supporting infrastructure they need.

## What's new

- ``torchrl.objectives.multiagent.MAPPOLoss`` -- centralised-critic,
  decentralised-actor PPO (Yu et al. 2022, https://arxiv.org/abs/2103.01955).
  Subclasses ``ClipPPOLoss``; defaults the value estimator to
  ``MultiAgentGAE``, defaults ``normalize_advantage_exclude_dims=(-2,)``,
  and optionally accepts a ``ValueNorm`` for the critic-stability trick
  from the paper (Table 13).

- ``torchrl.objectives.multiagent.IPPOLoss`` -- independent-learner
  counterpart (de Witt et al. 2020, https://arxiv.org/abs/2011.09533).
  Each agent has its own local critic; no centralised state required.

- ``torchrl.objectives.value.MultiAgentGAE`` -- ``GAE`` subclass that
  broadcasts team-shared ``reward`` / ``done`` / ``terminated`` (shape
  ``[*B, T, 1]``) across the agent dim before the vec-GAE recursion, so
  users don't have to manually replicate signals. Per-agent rewards
  pass through unchanged (competitive settings). New
  ``ValueEstimators.MAGAE`` enum entry.

- ``torchrl.modules.ValueNorm`` / ``PopArtValueNorm`` /
  ``RunningValueNorm`` -- abstract ``ValueNorm`` (an ``nn.Module``)
  with two implementations: PopArt-style EMA (van Hasselt et al. 2019,
  https://arxiv.org/abs/1809.04474) and exact Welford running stats.
  Plugs into ``MAPPOLoss(value_norm=...)`` and normalises both the
  critic target and the prediction so the MSE / smooth-L1 distance
  stays on a fixed scale as reward scales drift. Composes correctly
  with the parent ``ClipPPOLoss`` features (``clip_value``,
  ``separate_losses``, ``log_explained_variance``) via the new
  ``_critic_loss_inputs`` hook.

## Design notes

**Two classes, no centralized boolean flag.** The MAPPO / IPPO
structural difference is small but the construction recipes differ in
ways that matter (centralised critic vs. per-agent critic), so we
expose them as separate named classes; the docstring on each spells
out the recipe explicitly.

**MAGAE dispatch in plain PPO / A2C / Reinforce.** Adding
``ValueEstimators.MAGAE`` to the enum would break every parent test
that parametrises over ``list(ValueEstimators)`` unless every
``make_value_estimator`` knows the new value. We dispatch MAGAE
through ``MultiAgentGAE`` in those losses (~5 lines per file). When
the registry from pytorch#3780 lands, the explicit branches collapse to a
single ``@register_value_estimator`` decorator.

**ValueNorm placement.** Lives under ``torchrl/modules/`` rather than
``torchrl/objectives/utils/`` because it's a stateful learnable
component that participates in ``.to(device)`` / ``state_dict()``.

## Verification

- ``pytest test/objectives/test_mappo.py`` -- 25/25 passing.
  Includes regression tests for:
    * ``value_norm`` registered exactly once in ``state_dict()``
      (no duplicate ``_value_norm_module.*`` keys);
    * ``state_dict()`` save / load round-trips the running stats;
    * ``clip_value`` + ``value_norm`` still produces
      ``value_clip_fraction`` in the output;
    * ``separate_losses=True`` + ``value_norm`` correctly detaches
      so critic-loss grads do not flow into actor params;
    * no spurious annotation warnings on instantiation.
- ``pytest test/objectives/`` -- 7280 passing, 2319 skipped
  (no regressions).
- ``examples/multiagent/mappo_vmas.py --algo mappo --frames 200_000``
  provides a minimal end-to-end VMAS Navigation smoke recipe.

Co-authored-by: Cursor <cursoragent@cursor.com>
@vmoens vmoens force-pushed the feat/mappo-ippo branch from 56bfdb4 to a9245f2 Compare May 19, 2026 09:44
@vmoens
Copy link
Copy Markdown
Collaborator

vmoens commented May 19, 2026

@theap06 I took the liberty of pushing the review fixes directly to your branch (the PR was marked as maintainer_can_modify: true). Force-pushed with --force-with-lease. The branch now has 3 commits:

  1. [Compat] — generator-kwarg guard on SafeProbabilisticModule (cleaned up: now uses a precomputed extra_kwargs dict instead of the conditional unpack trick) + walks the MRO in LossModule.convert_to_functional's annotation check, so subclasses don't have to redeclare their parents' annotations. That second fix is what was producing the 4 UserWarnings per MAPPOLoss(...) instantiation.

  2. [Refactor] — exposes a single _critic_loss_inputs(target_return, state_value, old_state_value) hook on PPOLoss.loss_critic (identity by default) so that MAPPOLoss can inject PopArt-style normalisation without duplicating the parent body. This means value_norm now composes correctly with clip_value, separate_losses=True, and log_explained_variance — none of which worked in the original loss_critic override. Same idea drives the _prepare_signals / _broadcast_optional / _normalize_advantage hooks on GAE that the original PR already had, just cleaned up.

  3. [Feature] — the actual MAPPO/IPPO/MultiAgentGAE/ValueNorm work. Changes from the original:

    • MAPPOLoss.__init__ no longer calls add_module("_value_norm_module", ...)nn.Module.__setattr__ already registers it. state_dict() now has exactly three value_norm.* keys instead of six (value_norm.* + _value_norm_module.*).
    • IPPOLoss is now a 1-line subclass of MAPPOLoss; the two __init__s were copy-pasted before.
    • RunningValueNorm.update uses math.prod instead of torch.tensor([...]).prod() (no host-device sync). The misleading variance comment is also fixed.
    • examples/multiagent/mappo_vmas.py — dropped the dead if False else branch and fixed the set_keys(done=("agents", "done"), terminated=("agents", "terminated")) call so the example actually works on VMAS (which writes per-agent done flags).

The registry refactor that was bundled into the original PR is now its own PR — #3780 — extended to convert every loss in torchrl.objectives (not just PPO/A2C/Reinforce), so @register_value_estimator actually delivers on its promise.

New test coverage (5 regression tests) all targeting bugs from the review:

  • test_value_norm_not_double_registered — exactly 3 keys in state_dict().
  • test_value_norm_state_dict_round_trip — save/load round-trips the running stats.
  • test_value_norm_composes_with_clip_valuevalue_clip_fraction shows up in the output dict even with value_norm attached.
  • test_value_norm_separate_losses_detaches_actor_grads — critic loss does not backprop into actor params.
  • test_no_spurious_annotation_warnings_on_instantiation — verifies the LossModule annotation fix above.

Full test suite: pytest test/objectives/ passes 7280 tests, 0 regressions.

Diff is down from 1794+/171- to 1637+/38- across 24 files.

Copy link
Copy Markdown
Collaborator

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks

@vmoens vmoens merged commit 126c628 into pytorch:main May 19, 2026
110 checks passed
@matteobettini
Copy link
Copy Markdown
Contributor

This looks cool but am i missing something or the 2 losses in https://github.com/pytorch/rl/blob/main/torchrl/objectives/multiagent/mappo.py look the same?

@theap06
Copy link
Copy Markdown
Contributor Author

theap06 commented May 19, 2026

@matteobettini Nope, you're not missing anything! They're intentionally the same! The only real difference is what critic you pass in. MAPPOLoss expects a centralised critic, IPPOLoss expects a decentralised one, but the loss math is identical either way.Exposed them as separate classes mostly so the API makes it obvious which algorithm you're running. Happy to collapse it to IPPOLoss = MAPPOLoss or add some actual structural difference in a follow-up if you think that's cleaner, just lmk!

@matteobettini
Copy link
Copy Markdown
Contributor

@matteobettini Nope, you're not missing anything! They're intentionally the same! The only real difference is what critic you pass in. MAPPOLoss expects a centralised critic, IPPOLoss expects a decentralised one, but the loss math is identical either way.Exposed them as separate classes mostly so the API makes it obvious which algorithm you're running. Happy to collapse it to IPPOLoss = MAPPOLoss or add some actual structural difference in a follow-up if you think that's cleaner, just lmk!

Mmh i am not sure I am a big fan of 2 names for the same thing. The risk here is that people might think changing the loss name changes the loss while it does not actually change anything and the change has to be done somewhere else.

@theap06
Copy link
Copy Markdown
Contributor Author

theap06 commented May 20, 2026

@matteobettini I can draft up a follow up pr addressing this if you would like later tonight.

@matteobettini
Copy link
Copy Markdown
Contributor

matteobettini commented May 20, 2026

@matteobettini I can draft up a follow up pr addressing this if you would like later tonight.

What do you think? we could (1) drop IPPOLoss keeping only MAPPOLoss or (2) keep only one called MultiPPOLoss

@theap06
Copy link
Copy Markdown
Contributor Author

theap06 commented May 20, 2026

@matteobettini I can draft up a follow up pr addressing this if you would like later tonight.

What do you think? we could (1) drop IPPOLoss keeping only MAPPOLoss or (2) keep only one called MultiPPOLoss

I was thinking we just create a field called MultiPPOLoss and the user can choose.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Documentation Improvements or additions to documentation Examples Feature New feature Integrations/torch_geometric Integrations Modules Objectives

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants