
Short-term fix to preserve NJT metadata cache in torch.compile #122836

Closed · wants to merge 47 commits

Conversation

@jbschlosser (Contributor) commented Mar 27, 2024:

Stack from ghstack (oldest at bottom):

Idea: close over min / max sequence length in the main NJT view func (_nested_view_from_jagged) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in (val, 0) shaped tensors.
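As a rough illustration of that trick (a sketch only; the helper names below are hypothetical and not necessarily what this PR uses), the value becomes the first dimension of an empty tensor, so reading it back under torch.compile can yield a symbolic size rather than a baked-in constant:

```python
import torch

# Hypothetical helpers illustrating the (val, 0)-shaped tensor trick: the
# tensor allocates no data, and its first dimension carries the value.
def _store_val_in_tensor(val: int) -> torch.Tensor:
    return torch.empty((val, 0))

def _load_val_from_tensor(t: torch.Tensor) -> int:
    # Under torch.compile with dynamic shapes, this reads back as a SymInt.
    return t.shape[0]

max_seqlen_holder = _store_val_in_tensor(128)
assert _load_val_from_tensor(max_seqlen_holder) == 128
```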

NB: This PR changes SDPA to operate on real views instead of using buffer_from_jagged() / ViewNestedFromBuffer, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.

Differential Revision: D55448636

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang

pytorch-bot (bot) commented Mar 27, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122836

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 3 New Failures, 8 Unrelated Failures

As of commit 6ea26c9 with merge base ea47d54:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs failed, but likely due to flakiness present on trunk, and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@davidberard98 (Contributor):
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -6200,6 +6200,16 @@
device_check: NoCheck
dispatch: {}

- func: _nested_get_min_seqlen(Tensor self) -> int
@jbschlosser (Contributor Author):
I think these need to be symbolic to work with dynamic shapes. TBD; still exploring this.

@jbschlosser (Contributor Author) commented Apr 2, 2024:
This needs more thought / design work to fully cover dynamic shapes. @YuqingJ mentioned to me that a static max length is being used for now, so maybe we can get a patch to work for FIRST without this, but it's iffy.

@clee2000 (Contributor) commented Apr 2, 2024:

@clee2000 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jbschlosser (Contributor Author):
@jbschlosser has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jbschlosser added the `topic: not user facing` label on May 1, 2024
return torch.ones_like(nt) * nt.max_seqlen()

for dynamic in [False, True]:
self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
A reviewer (Contributor) commented:

jfyi, there's `with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):`, though `_recompiles_for_inputs` seems generally useful if you also want to cover cases where a recompile is expected.
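A rough sketch of what that suggestion could look like (`f`, `nt`, and `nt2` are assumed from the test above; this is not part of the PR):

```python
import unittest.mock
import torch

# Sketch of the error_on_recompile approach: the first call compiles, and the
# patched config makes any later recompile raise instead of silently retracing.
compiled_f = torch.compile(f, dynamic=False)
compiled_f(nt)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
    compiled_f(nt2)  # raises if nt2 triggers a recompile
```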

@jbschlosser (Contributor Author):
I lifted _recompiles_for_inputs() from some other test file, which isn't great :p
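For context, a minimal version of such a helper might look like the following sketch (not necessarily the exact helper in the test suite): compile with a counting backend and report whether the second set of inputs forces a second compilation.

```python
import torch

def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    # Count graph compilations via a custom backend; a count > 1 means the
    # second inputs triggered a recompile.
    compile_count = [0]

    def counting_backend(gm, example_inputs):
        compile_count[0] += 1
        return gm  # run the captured graph as-is

    compiled = torch.compile(fn, fullgraph=True, dynamic=dynamic, backend=counting_backend)
    compiled(*inputs1)
    compiled(*inputs2)
    return compile_count[0] > 1
```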

test/test_nestedtensor.py: outdated review thread (resolved)
@@ -4339,6 +4353,115 @@ def forward(self, query, value, offsets):
self.assertTrue(torch.allclose(attn_output_eager, attn_output))
self.assertTrue(torch.allclose(value_grad, value.grad))

@dtypes(torch.float32)
A reviewer (Contributor) commented:

do you know if this will work with a max_seqlen manually assigned in the graph? (and if so, can you add a test, or add one as a follow-up?)

@jbschlosser (Contributor Author):

Good Q! This works with manual assignment as done in FIRST (i.e. via `convert_jagged_to_nested_tensor()` -> `ViewNestedFromBuffer` with the metadata cache info). Also see `test_dummy_mha_with_nt`, which tries to match what the FIRST model does as closely as possible.

We should arguably support and test manual max seq len assignment for the new API. I'll address this or add a TODO

@jbschlosser (Contributor Author):

Update: I think I'll punt on handling manual assignment outside of the FIRST use case until we have public `@property` getters / setters for min / max sequence length that work with PT2 (this is in progress).

@jbschlosser (Contributor Author), replying to an earlier comment:

> Just FYI, we have a small tool to detect dependencies here: https://github.com/pytorch/test-infra/actions/workflows/pr-dependencies-check.yml. I ran the job with this PR and got the following list of conflicts: https://github.com/pytorch/test-infra/actions/runs/9573995289/attempts/1#summary-26396464107. This is an easy case because they are all in your stack, but the tool might be useful to find conflicts from elsewhere (if it happens).

wow, that's super useful! thanks for the link :)

@jbschlosser (Contributor Author):
Confirmed that the memory leak issues are pre-existing and not introduced by this PR. They need more investigation, but I'm not even convinced these are NJT-related. For example, I saw a memory leak from running this minimal, non-NJT example with PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1:

def test_memory_leak(self, device):
    value = torch.randn(6, 4, requires_grad=True, device=device)
    m = torch.nn.Linear(4, 6, device=device)
    symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
    m = torch.compile(symbolic_traced)
    m(value)

I'm going to go ahead and reland this with the ROCm / DEBUG=1 assert fixes added.

jbschlosser added a commit that referenced this pull request Jun 20, 2024
@jbschlosser (Contributor Author):
@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@jbschlosser (Contributor Author):
@pytorchbot merge -i

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following check: pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@jbschlosser (Contributor Author):
@pytorchbot merge -i

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following 3 checks: pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge), pull / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (default, 4, 5, linux.g5.4xlarge.nvidia.gpu), trunk / linux-focal-rocm6.1-py3.8 / test (default, 2, 2, linux.rocm.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 job has failed; the first few are: inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor, 1, 1, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team (raised by workflow job)

@jbschlosser (Contributor Author):
@pytorchbot merge -i

@pytorchmergebot (Collaborator):
