Conversation

@jwieczorekhabana (Contributor) commented Aug 27, 2024

When tensor folding occurs during a matmul operation, the returned tensor is a view. This can cause issues when matmul is used inside a custom function and that view is then returned as the output: it cannot be modified in place, which causes errors.
This is especially problematic when an in-place allreduce is performed after such a function.
The issue is resolved by returning unsafe_view from matmul instead. This aligns the matmul decomposition with the eager implementation, so that a non-view tensor is returned.

The test included in this PR reproduces the issue.
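For readers skimming the thread, here is a minimal sketch of the failure mode described above, assuming a custom autograd.Function that returns the result of a folding matmul followed by an in-place update on that output; the names MatmulFn and fn are illustrative, not taken from the PR's test.

import torch


class MatmulFn(torch.autograd.Function):
    # Illustrative custom Function whose output comes straight from matmul.
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return torch.matmul(x, w)  # 3-D @ 2-D input triggers the tensor-folding path

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_x = grad_out @ w.t()
        grad_w = x.reshape(-1, x.size(-1)).t() @ grad_out.reshape(-1, grad_out.size(-1))
        return grad_x, grad_w


@torch.compile
def fn(x, w):
    out = MatmulFn.apply(x, w)
    out.add_(1.0)  # stand-in for an in-place allreduce on the Function's output
    return out


x = torch.randn(2, 3, 4, requires_grad=True)
w = torch.randn(4, 5, requires_grad=True)
# Before the change the decomposed matmul returned a view, so the in-place add
# could raise a view/in-place autograd error; with _unsafe_view the output is a
# regular (non-view) tensor, matching eager behaviour.
fn(x, w)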

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

pytorch-bot bot commented Aug 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134568

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d155a11 with merge base 41e6534:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Aug 27, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@xinyu-intel (Contributor)

@jgong5 @EikanWang Hi, can you help review this PR?

@jwieczorekhabana (Contributor Author)

Error thrown: (screenshot of the runtime error attached)

@albanD requested a review from zou3519 August 27, 2024 14:41
@albanD added the 'triaged' label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 27, 2024
@jwieczorekhabana (Contributor Author)

@zou3519 Hi, I have the linter patch ready. In the error messages I also saw a note about adding a test owner. From what I've seen in similar tests it's just '# Owner(s): ["module: unknown"]'. I'm also wondering whether this is the right place for this kind of test, or whether you already have a better-suited suite somewhere else. Could you give any hints on that?
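For context, a rough sketch of the test-file convention being referenced, assuming the standard PyTorch test harness in torch.testing._internal.common_utils; the layout and test body below are placeholders, not the PR's actual file.

# Owner(s): ["module: unknown"]
from torch.testing._internal.common_utils import TestCase, run_tests


class TestCustomFunction(TestCase):
    def test_autograd_function_with_matmul_folding_at_output(self):
        ...


if __name__ == "__main__":
    run_tests()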

Comment on lines 7 to 17
class TestCustomFunction(TestCase):
    def test_autograd_function_with_matmul_folding_at_output(self):
        """
        When tensor folding occurs during a matmul operation, the returned tensor is a view.
        This can cause issues when matmul is used inside a custom function
        and that view is then returned as the output: it cannot be modified
        in place, which causes errors.
        This is especially problematic when an in-place allreduce is performed
        after such a function. This test recreates that behaviour.
        The issue is resolved by returning unsafe_view from matmul instead.
        """
Contributor

Move this to test/dynamo/test_misc.py

Contributor Author

Done

@zou3519 (Contributor) left a comment

The code change makes sense to me. Let's move the test to an existing file rather than put it in its own testcase.

@xinyu-intel (Contributor)

@pytorchbot label "topic: not user facing"

@pytorch-bot bot added the 'topic: not user facing' label (topic category) Sep 2, 2024
@jwieczorekhabana (Contributor Author)

@pytorchbot rebase

pytorch-bot bot commented Sep 5, 2024

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@zou3519 (Contributor) commented Sep 5, 2024

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

When tensor folding occurs during a matmul operation, the returned tensor is a view.
This can cause issues when matmul is used inside a custom function
and that view is then returned as the output: it cannot be modified
in place, which causes errors.
This is especially problematic when an in-place allreduce is performed
after such a function.
The issue is resolved by returning unsafe_view from matmul instead.
This aligns the matmul decomposition with the eager implementation,
so that a non-view tensor is returned.
- Removed return types from forward/backward functions in test_custom_function
  to be compatible with Python 3.8
- Updated the graph in the test_proxy_tensor test_reflect_r_over_x test
  so that _unsafe_view is added instead of view_4 after mm
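As a side note, a small hedged illustration of the distinction the change relies on (shapes and variable names below are arbitrary): aten._unsafe_view produces a tensor with the same data and shape as view, but autograd does not record it as a view of its input, which matches what the eager matmul folding path returns.

import torch

a = torch.randn(2, 3, 4)
b = torch.randn(4, 5)

mm_out = torch.mm(a.reshape(-1, a.size(-1)), b)             # folded 2-D matmul
as_view = mm_out.view(2, 3, 5)                              # recorded as a view of mm_out
as_unsafe = torch.ops.aten._unsafe_view(mm_out, (2, 3, 5))  # same data, not treated as a view

# _base is the tensor a view refers to, or None for a non-view tensor.
print(as_view._base is mm_out)    # True  -> in-place updates can trip autograd's view checks
print(as_unsafe._base is None)    # True  -> behaves like the output of eager matmul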
@pytorchmergebot (Collaborator)

Successfully rebased main onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout main && git pull --rebase)

@zou3519 (Contributor) left a comment

lint failing

@jwieczorekhabana (Contributor Author)

lint failing

done

@zou3519 added the 'ciflow/trunk' label (Trigger trunk jobs on your pull request) Sep 18, 2024
@jwieczorekhabana (Contributor Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…ytorch#134568)

When tensor folding occurs during a matmul operation, the returned tensor is a view. This can cause issues when matmul is used inside a custom function and that view is then returned as the output: it cannot be modified in place, which causes errors.
This is especially problematic when an in-place allreduce is performed after such a function.
The issue is resolved by returning unsafe_view from matmul instead. This aligns the matmul decomposition with the eager implementation, so that a non-view tensor is returned.

The test included in this PR reproduces the issue.

Pull Request resolved: pytorch#134568
Approved by: https://github.com/zou3519
mryszt pushed a commit to HabanaAI/pytorch-fork that referenced this pull request Oct 14, 2024
When tensor folding occurs during a matmul operation, the returned tensor is a view.
This can cause issues when matmul is used inside a custom function
and that view is then returned as the output: it cannot be modified
in place, which causes errors.
This is especially problematic when an in-place allreduce is performed
after such a function.
The issue is resolved by returning unsafe_view from matmul instead.
This aligns the matmul decomposition with the eager implementation,
so that a non-view tensor is returned.

Pull request opened to pytorch: pytorch#134568

Change-Id: I77484ff6f22d3e290352348b1acbffa267eb063b
mryszt pushed a commit to HabanaAI/pytorch-fork that referenced this pull request Oct 14, 2024
aostrowski-hbn pushed a commit to HabanaAI/pytorch-fork that referenced this pull request Oct 16, 2024
aostrowski-hbn pushed a commit to HabanaAI/pytorch-fork that referenced this pull request Jan 7, 2025
Labels: ciflow/trunk, Merged, module: dynamo, open source, topic: not user facing, triaged
6 participants