[BugFix] GAE shifted=True: tolerate missing next obs, non-canonical strides, docs#3757

Merged: vmoens merged 4 commits into gh/vmoens/273/base from gh/vmoens/273/head on May 18, 2026

@vmoens (Collaborator) commented May 15, 2026

Stack from ghstack (oldest at bottom):

Four small improvements to shifted-GAE in _call_value_nets:

  • Add _fill_missing_next_inputs(next_data, root_data, in_keys). When
    compact_obs=True drops ("next", obs) from the rollout, the
    shifted-GAE bootstrap path was dereferencing a None key and writing
    NaNs into the assembled batch. Fall back to the root key value at the
    done position for any missing next-input key.

  • Replace data_copy.view(-1) with data_copy.reshape(-1).
    Replay-buffer reads and memmap-backed tensors can expose
    non-canonical strides that view rejects; reshape accepts both
    canonical and non-canonical strides. Drop the
    with ... as data_copy_view context that was paired with the
    in-place view, since flattened tensors don't need write-back.

  • Hoist ndim resolution out of the inner _call_value_net closure
    so it's available to follow-up cleanups.

  • Docstring update on shifted=True across TD0 / TD1 / TDLambda /
    GAE / VTrace clarifying that the path expects long, contiguous
    trajectory windows with valid boundary next states.
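
The fallback described in the first bullet can be sketched roughly as follows. This is an illustration only: plain dicts stand in for TensorDict, the function `fill_missing_next_inputs` and the sample keys are hypothetical, and the real helper may select only the done positions rather than copying the whole root entry.

```python
def fill_missing_next_inputs(next_data, root_data, in_keys):
    """Return a copy of next_data where each missing in_key is taken from root_data.

    Hypothetical stand-in for the helper named in the PR description;
    plain dicts replace TensorDict, so this is not TorchRL's implementation.
    """
    filled = dict(next_data)  # shallow copy, the caller's dict is left untouched
    for key in in_keys:
        if key not in filled:
            # Fall back to the root entry instead of dereferencing a missing key.
            filled[key] = root_data[key]
    return filled


# Assumed toy rollout: the "next" side was compacted and carries no observation.
root = {"observation": [0.1, 0.2, 0.3]}
compact_next = {"reward": [1.0, 1.0, 0.0], "done": [False, False, True]}

filled = fill_missing_next_inputs(compact_next, root, in_keys=["observation"])
print(filled["observation"])
```

Returning a copy rather than mutating `next_data` keeps the compact rollout itself unchanged, which is a design choice of this sketch and not something the PR description states.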

Tests cover: shifted-bootstrap parity between a compact rollout (no
("next", obs)) and a NaN-filled reference for TD0 / TD1 / TDLambda /
GAE, non-canonical strides via transpose(0, 1) through shifted-GAE,
and a vectorized-axis sweep on the recurrent shifted/non-shifted parity
test.
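
For readers unfamiliar with the stride issue, here is a minimal sketch on a plain `torch.Tensor` (not the TensorDict used in `_call_value_nets`; the tensor `x` is made up for illustration). It shows `transpose(0, 1)` producing a non-contiguous layout that `view(-1)` rejects while `reshape(-1)` still flattens it.

```python
import torch

x = torch.arange(6).reshape(2, 3)
xt = x.transpose(0, 1)  # shape (3, 2) with strides (1, 3): non-contiguous

try:
    xt.view(-1)  # view needs a stride-compatible layout
    view_ok = True
except RuntimeError:
    view_ok = False

# reshape returns a view when it can and falls back to a copy otherwise
flat = xt.reshape(-1)

print(xt.is_contiguous(), view_ok, flat.tolist())
```

Because `reshape` may return a copy here, writes to `flat` are not guaranteed to propagate back to `xt`, which is consistent with the description's note that the flattened data needs no write-back.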

[ghstack-poisoned]

pytorch-bot (Bot) commented May 15, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3757

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures, 1 Cancelled Job, 2 Unrelated Failures

As of commit 8d89292 with merge base 0a01ee8, 4 jobs failed, 1 job was cancelled and should be retried, and 2 jobs failed but were likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vmoens added a commit that referenced this pull request May 18, 2026

ghstack-source-id: 9712a4c
Pull-Request: #3757
@vmoens vmoens merged commit 8d89292 into gh/vmoens/273/base May 18, 2026
104 of 113 checks passed
@vmoens vmoens deleted the gh/vmoens/273/head branch May 18, 2026 21:01

Labels

BugFix, CLA Signed, Objectives
