Short-term fix to preserve NJT metadata cache in torch.compile #122836
base: gh/jbschlosser/131/base
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/122836.
Note: links to docs will display an error until the docs builds have completed. ❌ 13 new failures, 1 pending as of commit 06e068c with merge base edb45dc.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```diff
@@ -6200,6 +6200,16 @@
   device_check: NoCheck
   dispatch: {}

+- func: _nested_get_min_seqlen(Tensor self) -> int
```
I think these need to be symbolic to work with dynamic shapes. TBD; I'm exploring this.
This needs more thought / design work to fully cover dynamic shapes. @YuqingJ mentioned to me that a static max length is being used for now, so maybe we can get a patch to work for FIRST without this, but it's iffy.
@clee2000 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@jbschlosser has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```python
return torch.ones_like(nt) * nt.max_seqlen()

for dynamic in [False, True]:
    self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
```
JFYI, there's `with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):`, though `_recompiles_for_inputs` seems generally useful if you also want to check cases where a recompile is expected.
I lifted `_recompiles_for_inputs()` from some other test file, which isn't great :p
test/test_nestedtensor.py (outdated)
```diff
@@ -4339,6 +4353,115 @@ def forward(self, query, value, offsets):
     self.assertTrue(torch.allclose(attn_output_eager, attn_output))
     self.assertTrue(torch.allclose(value_grad, value.grad))

+    @dtypes(torch.float32)
```
Do you know if this will work with a `max_seqlen` manually assigned in the graph? (And if so, can you add a test, or add one as a follow-up?)
Good Q! This works with manual assignment as done in FIRST (i.e., via `convert_jagged_to_nested_tensor()` -> `ViewNestedFromBuffer` with the metadata cache info). Also see `test_dummy_mha_with_nt`, which tries to match as closely as possible what the FIRST model does.

We should arguably support and test manual max seq len assignment for the new API. I'll address this or add a TODO.
Update: I think I'll punt on handling manual assignment outside of the FIRST use case until we have public `@property` getters / setters for min / max sequence length that work with PT2 (this is in progress).
```diff
@@ -6185,12 +6185,12 @@
   CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy
   autogen: _nested_view_from_buffer_copy.out

-- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)
+- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)
```
Quick question: do we want to eventually make `max_seqlen` / `min_seqlen` available from the `torch.nested.nested_tensor_from_jagged` API?
We definitely need a way for the user to provide these manually. Adding them to the `nested_tensor_from_jagged()` API is one way to do this. An alternative (or additional) way is to provide public setter `@properties` for min / max sequence length. Some better Dynamo tracing support for `@properties` landed recently, so this may be feasible now.

I was originally thinking we'd only allow manual setting through setter properties, but I could probably be convinced that the `nested_tensor_from_jagged()` API should accept these too. Opinions on this?
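To make the setter-property alternative concrete, here is a rough sketch (hypothetical names, plain Python stand-in — not the actual NJT implementation) of how such properties could sit on top of a metadata cache:

```python
class JaggedMetaSketch:
    """Hypothetical stand-in for an NJT-like object, illustrating public
    @property getters/setters for max sequence length."""

    def __init__(self):
        # Mirrors the idea of an internal metadata cache.
        self._metadata_cache = {}

    @property
    def max_seqlen(self):
        # In the real tensor this would be computed from offsets on a
        # cache miss; here we just read the cache.
        return self._metadata_cache.get("max_seqlen")

    @max_seqlen.setter
    def max_seqlen(self, val):
        # Manual assignment simply populates the cache.
        self._metadata_cache["max_seqlen"] = val
```

The design question above is whether this setter path should be the only way to assign the values manually, or whether the constructor-style API should accept them as well.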
Stack from ghstack (oldest at bottom):

Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)`-shaped tensors.

NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.

Differential Revision: D55448636
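The `(val, 0)`-shaped tensor hack can be illustrated in isolation: the scalar is encoded in the first dimension of an element-free tensor, so it costs no storage but surfaces as a shape, which torch.compile's dynamic-shapes machinery already knows how to symbolize. (A standalone sketch of the idea with hypothetical helper names, not the PR's actual code.)

```python
import torch

def pack_seqlen(val: int) -> torch.Tensor:
    # Encode a scalar in the tensor's *shape*: a (val, 0) tensor holds zero
    # elements, but shape[0] can be treated as a symbolic size under compile.
    return torch.empty((val, 0))

def unpack_seqlen(t: torch.Tensor) -> int:
    # Recover the scalar from the shape.
    return t.shape[0]

cached = pack_seqlen(512)
assert cached.numel() == 0      # no storage cost
assert unpack_seqlen(cached) == 512
```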