Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Short-term fix to preserve NJT metadata cache in torch.compile #122836

Open
wants to merge 45 commits into
base: gh/jbschlosser/131/base
Choose a base branch
from

Conversation

jbschlosser
Copy link
Contributor

@jbschlosser jbschlosser commented Mar 27, 2024

Stack from ghstack (oldest at bottom):

Idea: close over min / max sequence length in the main NJT view func (_nested_view_from_jagged) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in (val, 0) shaped tensors.

NB: This PR changes SDPA to operate on real views instead of using buffer_from_jagged() / ViewNestedFromBuffer, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.

Differential Revision: D55448636

Copy link

pytorch-bot bot commented Mar 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122836

Note: Links to docs will display an error until the docs builds have been completed.

❌ 13 New Failures, 1 Pending

As of commit 06e068c with merge base edb45dc (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@davidberard98
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -6200,6 +6200,16 @@
device_check: NoCheck
dispatch: {}

- func: _nested_get_min_seqlen(Tensor self) -> int
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these need to be symbolic to work with dynamic shapes. TBD exploring this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs more thought / design work to fully cover dynamic shapes. @YuqingJ mentioned to me that a static max length is being used for now, so maybe we can get a patch to work for FIRST without this, but it's iffy.

@clee2000
Copy link
Contributor

clee2000 commented Apr 2, 2024

@clee2000 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jbschlosser
Copy link
Contributor Author

@jbschlosser has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

[ghstack-poisoned]
@jbschlosser jbschlosser added the topic: not user facing topic category label May 1, 2024
[ghstack-poisoned]
return torch.ones_like(nt) * nt.max_seqlen()

for dynamic in [False, True]:
self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jfyi there's with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): though _recompiles_for_inputs seems generally useful if you also want to make sure cases where recompiles=True

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I lifted _recompiles_for_inputs() from some other test file, which isn't great :p

test/test_nestedtensor.py Outdated Show resolved Hide resolved
@@ -4339,6 +4353,115 @@ def forward(self, query, value, offsets):
self.assertTrue(torch.allclose(attn_output_eager, attn_output))
self.assertTrue(torch.allclose(value_grad, value.grad))

@dtypes(torch.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you know if this will work with a max_seqlen manually assigned in the graph? (and if so, can you add a test / or add as a followup)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good Q! this works with manual assignment as done in FIRST (i.e. via convert_jagged_to_nested_tensor() -> ViewNestedFromBuffer with the metadata cache info). Also see test_dummy_mha_with_nt which tries to match as closely as possible to what the FIRST model does.

We should arguably support and test manual max seq len assignment for the new API. I'll address this or add a TODO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: I think I'll boot handling manual assignment outside of the FIRST use case until we can have public @property getters / setters for min / max sequence length that work with PT2 (this is in progress)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
@@ -6185,12 +6185,12 @@
CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy
autogen: _nested_view_from_buffer_copy.out

- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)
- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quick question, do we want to eventually make max/min_seqlen available from the torch.nested.nested_tensor_from_jagged API?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We definitely need a way for the user to provide these manually. Adding them to the nested_tensor_from_jagged() API is one way to do this. One additional way / alternative is to provide public setter @properties for min / max sequence length. Some better Dynamo tracing support for @properties landed recently so this may be feasible now.

I was originally thinking we'd only allow manual setting through setter properties, but I could probably be convinced that the nested_tensor_from_jagged() API should accept these too. Opinions on this?

…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
…pile"


Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request topic: not user facing topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants