[Feature] Forward generator kwarg through ProbabilisticActor #3704

Merged
vmoens merged 1 commit into pytorch:main from vmoens:fix-rng-base-actor
May 6, 2026

@vmoens
Collaborator

@vmoens vmoens commented May 5, 2026

Summary

Follow-up to pytorch/tensordict#1689, which added a generator kwarg to ProbabilisticTensorDictModule. That PR shipped the actual sampling implementation; this PR is the thin pass-through layer needed in TorchRL so that ProbabilisticActor(..., generator=...) actually works.

Without this change, the kwarg flows from ProbabilisticActor into SafeProbabilisticModule via **kwargs and is rejected, because SafeProbabilisticModule.__init__ enumerates its kwargs explicitly and forwards them by name to ProbabilisticTensorDictModule.
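The failure mode described above can be reproduced without TorchRL installed. The following is a minimal sketch, assuming hypothetical stand-in classes (`Base`, `WrapperBefore`, `WrapperAfter`, and the `make_actor` helper are illustrative names, not TorchRL classes). It shows why a wrapper that enumerates its keyword arguments explicitly rejects a kwarg the base class already accepts, until the wrapper names it and forwards it.

```python
# Hypothetical stand-ins for the three layers; none of these are TorchRL classes.


class Base:
    """Plays the role of ProbabilisticTensorDictModule: it accepts `generator`."""

    def __init__(self, in_keys, generator=None):
        self.in_keys = in_keys
        self.generator = generator


class WrapperBefore(Base):
    """Enumerates its kwargs explicitly and does not know about `generator`."""

    def __init__(self, in_keys, return_log_prob=False):
        super().__init__(in_keys)
        self.return_log_prob = return_log_prob


class WrapperAfter(Base):
    """Same wrapper with `generator` added to the signature and forwarded by name."""

    def __init__(self, in_keys, return_log_prob=False, generator=None):
        super().__init__(in_keys, generator=generator)
        self.return_log_prob = return_log_prob


def make_actor(wrapper_cls, in_keys, **kwargs):
    """Plays the role of ProbabilisticActor: passes **kwargs straight through."""
    return wrapper_cls(in_keys, **kwargs)


def try_build(wrapper_cls):
    """Return the built object, or the TypeError raised while building it."""
    try:
        return make_actor(wrapper_cls, ["loc", "scale"], generator=0)
    except TypeError as err:
        return err


before = try_build(WrapperBefore)  # unexpected keyword argument 'generator'
after = try_build(WrapperAfter)    # kwarg reaches the base class
```

In this sketch the outermost layer needs no signature change, because its `**kwargs` already carries the new argument; only the middle layer has to name it.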

Closes the action item from the discussion in pytorch/rl#3701 — separating an agent's RNG stream from its environment's, motivated by Patterson et al., Empirical Design in Reinforcement Learning (arXiv:2304.01315).

Change

  • SafeProbabilisticModule.__init__ (torchrl/modules/tensordict_module/probabilistic.py): adds the generator kwarg and forwards it to super().__init__. Adds a docstring entry.
  • ProbabilisticActor (torchrl/modules/tensordict_module/actors.py): adds a docstring entry. **kwargs already forwards generator, so no signature change is needed.
  • New TestProbabilisticActorGenerator in test/test_actors.py with 7 tests covering Generator object, int seed, isolation from torch.manual_seed, in-place advancement, tensordict-key form (Generator and int / JAX-PRNG variants), and the default (generator=None) path.
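Two of the properties listed for the new tests, isolation from the global seed and in-place advancement, can be illustrated with a small sketch. It uses Python's standard-library `random.Random` as a stand-in for `torch.Generator`, and `sample_action` is a hypothetical placeholder for a policy sample; this is not the code in `test/test_actors.py`.

```python
import random


def sample_action(rng):
    """Hypothetical policy sample: one uniform draw from the supplied generator."""
    return rng.random()


def draw_with_global_seed(global_seed):
    """Reseed the global RNG, then draw from a dedicated generator seeded with 0."""
    random.seed(global_seed)       # stand-in for torch.manual_seed
    agent_rng = random.Random(0)   # stand-in for a dedicated torch.Generator
    return sample_action(agent_rng)


# Isolation: the dedicated stream is unaffected by the global seed.
isolated_a = draw_with_global_seed(123)
isolated_b = draw_with_global_seed(456)

# In-place advancement: reusing one generator object moves its state forward,
# so consecutive draws differ.
shared_rng = random.Random(0)
state_before = shared_rng.getstate()
draw_1 = sample_action(shared_rng)
draw_2 = sample_action(shared_rng)
state_changed = shared_rng.getstate() != state_before
```

The same two checks carry over to a real generator object: identical seeds give identical first draws regardless of the global seed, and a generator passed in is mutated rather than copied.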

Test plan

Surfaces the new ``generator`` kwarg added in pytorch/tensordict#1689
through ``SafeProbabilisticModule`` so ``ProbabilisticActor(...,
generator=...)`` works as advertised. Without this, the kwarg flows from
``ProbabilisticActor`` into ``SafeProbabilisticModule`` via ``**kwargs``
but is rejected because ``SafeProbabilisticModule.__init__`` enumerates
its kwargs explicitly.

The generator may be a ``torch.Generator`` (used in place), an ``int``
(shorthand for ``Generator().manual_seed(int)``), or a ``NestedKey`` to
fetch the generator from the input tensordict on every call. See the
tensordict PR for the full sampling implementation; this PR is a thin
forward + docstring entries on ``SafeProbabilisticModule`` and
``ProbabilisticActor``.
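The three accepted forms can be sketched as a small dispatch. This is an illustration under stated assumptions, not the tensordict implementation: `Sampler` is a hypothetical class, `random.Random` stands in for ``torch.Generator``, a plain nested ``dict`` stands in for the input tensordict, and the sketch assumes an ``int`` is converted to a generator once at construction. Only generator objects stored under the key are handled here.

```python
import random


class Sampler:
    """Hypothetical module that accepts a generator, an int seed, or a key."""

    def __init__(self, generator=None):
        # Assumption: an int seed is turned into a generator once, up front.
        if isinstance(generator, int):
            generator = random.Random(generator)
        self.generator = generator

    def _resolve(self, data):
        spec = self.generator
        if spec is None or isinstance(spec, random.Random):
            return spec
        # Otherwise treat the spec as a (possibly nested) key into the input.
        keys = (spec,) if isinstance(spec, str) else spec
        value = data
        for key in keys:
            value = value[key]
        return value

    def __call__(self, data):
        rng = self._resolve(data)
        if rng is None:
            return random.random()  # default path: global RNG
        return rng.random()         # dedicated stream, advanced in place


from_object = Sampler(random.Random(7))({})
from_int = Sampler(7)({})
from_key = Sampler(("rng", "agent"))({"rng": {"agent": random.Random(7)}})
from_default = Sampler()({})
```

With the same seed, all three non-default forms yield the same first draw in this sketch; the key form differs only in that the generator is looked up in the input on every call.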

Motivated by pytorch/rl#3701 (separating agent and environment RNG
streams in a meta-environment).
@pytorch-bot

pytorch-bot Bot commented May 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3704

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 36c2e71 with merge base 9a8dbae:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 5, 2026
@vmoens vmoens merged commit d5da1e5 into pytorch:main May 6, 2026
130 of 135 checks passed

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Feature New feature Integrations/torch_geometric Integrations Modules
