Short-term fix to preserve NJT metadata cache in torch.compile #122836
Conversation
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@@ -6200,6 +6200,16 @@
  device_check: NoCheck
  dispatch: {}

- func: _nested_get_min_seqlen(Tensor self) -> int
I think these need to be symbolic to work with dynamic shapes. TBD exploring this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs more thought / design work to fully cover dynamic shapes. @YuqingJ mentioned to me that a static max length is being used for now, so maybe we can get a patch to work for FIRST without this, but it's iffy.
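For context, making these "symbolic" would mean changing the schema's return type so the value can participate in dynamic-shape reasoning. A hypothetical sketch of what such a `native_functions.yaml` entry might look like (this is not the change this PR actually makes; the PR instead smuggles the value through a shaped tensor):

```yaml
# Hypothetical: SymInt return so min seqlen can be symbolic under dynamic shapes
- func: _nested_get_min_seqlen(Tensor self) -> SymInt
```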
@clee2000 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@jbschlosser has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
    return torch.ones_like(nt) * nt.max_seqlen()

for dynamic in [False, True]:
    self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
jfyi there's `with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):`, though `_recompiles_for_inputs` seems generally useful if you also want to check cases where a recompile is expected.
I lifted `_recompiles_for_inputs()` from some other test file, which isn't great :p
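For readers following along: a helper like the `_recompiles_for_inputs` discussed above can be built on top of `torch.compile`'s custom-backend hook, which dynamo invokes once per (re)compilation. This is a hedged sketch, not the actual helper from the PyTorch test suite; the function names here are illustrative.

```python
import torch
import torch._dynamo


def _make_counting_backend():
    # Count how many times dynamo hands us a captured graph to compile.
    # Each call to `backend` corresponds to one compilation (or recompilation).
    state = {"count": 0}

    def backend(gm, example_inputs):
        state["count"] += 1
        return gm.forward  # run the captured graph eagerly, no codegen needed

    return backend, state


def recompiles_for_inputs(fn, inputs1, inputs2, dynamic):
    # True if calling the compiled fn on inputs2 after inputs1 triggers
    # a second compilation (i.e. a guard failure / recompile).
    torch._dynamo.reset()
    backend, state = _make_counting_backend()
    compiled = torch.compile(fn, backend=backend, dynamic=dynamic)
    compiled(*inputs1)
    compiled(*inputs2)
    return state["count"] > 1


def f(x):
    return x * 2


same = torch.randn(3, 4)
print(recompiles_for_inputs(f, (same,), (same.clone(),), dynamic=False))
print(recompiles_for_inputs(f, (same,), (torch.randn(5, 4),), dynamic=False))
```

The `unittest.mock.patch("torch._dynamo.config.error_on_recompile", True)` approach mentioned above is the inverse tool: it turns any unexpected recompile into a hard error rather than reporting whether one happened.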
test/test_nestedtensor.py
@@ -4339,6 +4353,115 @@ def forward(self, query, value, offsets):
        self.assertTrue(torch.allclose(attn_output_eager, attn_output))
        self.assertTrue(torch.allclose(value_grad, value.grad))

    @dtypes(torch.float32)
do you know if this will work with a max_seqlen manually assigned in the graph? (and if so, can you add a test / or add as a followup)
good Q! this works with manual assignment as done in FIRST (i.e. via `convert_jagged_to_nested_tensor()` -> `ViewNestedFromBuffer` with the metadata cache info). Also see `test_dummy_mha_with_nt`, which tries to match as closely as possible what the FIRST model does. We should arguably support and test manual max seq len assignment for the new API; I'll address this or add a TODO.
Update: I think I'll boot handling manual assignment outside of the FIRST use case until we can have public `@property` getters / setters for min / max sequence length that work with PT2 (this is in progress).
wow that's super useful! thanks for the link :)
Confirmed that the memory leak issues are pre-existing and not introduced by this PR. They need more investigation, but I'm not even convinced they are NJT-related. For example, I saw a memory leak from running this minimal, non-NJT example:

    def test_memory_leak(self, device):
        value = torch.randn(6, 4, requires_grad=True, device=device)
        m = torch.nn.Linear(4, 6, device=device)
        symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
        m = torch.compile(symbolic_traced)
        m(value)

I'm going to go ahead and reland this with the ROCm / DEBUG=1 assert fixes added.
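As an aside, leak checks like the one above usually boil down to asking whether an object outlives its last strong reference. A hedged sketch of that pattern (illustrative only, not the PR's test; `leaks`, `make_cached`, and `_cache` are hypothetical names):

```python
import gc
import weakref

import torch

_cache = []  # simulates an accidental global reference that keeps objects alive


def leaks(make_obj):
    # If a weakref to the object still resolves after the last strong
    # reference is dropped and gc has run, something is holding on to it.
    obj = make_obj()
    ref = weakref.ref(obj)
    del obj
    gc.collect()
    return ref() is not None


def make_cached():
    t = torch.randn(2)
    _cache.append(t)  # stray reference: the tensor can never be freed
    return t


print(leaks(lambda: torch.randn(6, 4)))  # a plain tensor should be freed
print(leaks(make_cached))                # the cached tensor should not be
```

This flags the symptom (the object stays alive) without saying where the stray reference lives; tools like `gc.get_referrers` can then narrow down the culprit.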
ghstack-source-id: 853d357d36ed3edd4a831b1eccad625d53ac73ff Pull Request resolved: #122836
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud
@pytorchbot merge -i
Merge started: Your change will be merged while ignoring the following 1 check: pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge)
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud
@pytorchbot merge -i
Merge started: Your change will be merged while ignoring the following 3 checks: pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge), pull / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (default, 4, 5, linux.g5.4xlarge.nvidia.gpu), trunk / linux-focal-rocm6.1-py3.8 / test (default, 2, 2, linux.rocm.gpu)
Merge failed. Reason: 1 job has failed, first few of them are: inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor, 1, 1, linux.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -i
Stack from ghstack (oldest at bottom):

Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)`-shaped tensors.

NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.

Differential Revision: D55448636

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang
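The `(val, 0)`-shaped tensor hack in the description is easier to see with a toy example: the integer carries no data, it rides along as the first dimension of an empty tensor, where dynamic-shapes machinery can treat it symbolically. This is a minimal sketch of the idea only; `store_val_in_shape` and `load_val_from_shape` are hypothetical names, not the PR's actual helpers.

```python
import torch


def store_val_in_shape(val: int) -> torch.Tensor:
    # Encode an integer (e.g. a max sequence length) as the first dimension
    # of a zero-element tensor: shape (val, 0) holds the value, no storage.
    return torch.empty((val, 0))


def load_val_from_shape(t: torch.Tensor) -> int:
    # Recover the value from the shape rather than from tensor data.
    return t.shape[0]


holder = store_val_in_shape(17)
print(holder.numel())               # no actual data is stored
print(load_val_from_shape(holder))  # the value lives in the shape
```

Because the value lives in a tensor shape, tracing infrastructure that already knows how to make dimensions symbolic can handle it without new plumbing for scalar metadata.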