
Don't create large intermediary tensors in the backward of matmul #95261

Closed
Wants to merge 6 commits

Conversation

@lezcano (Collaborator) commented on Feb 22, 2023

Stack from ghstack (oldest at bottom):

Currently, if we multiply a transposed batch of matrices with shape
[b, m, n] by a matrix with shape [n, k], then when computing the gradient
of the matrix we instantiate an intermediate tensor of shape [b, n, k],
which may be very large. Instead, we fold the batch of matrices into a
matrix, which avoids creating any large intermediary tensor.

Note that multiplying a batch of matrices and a matrix naturally occurs
within an attention module, so this case surely happens in the wild.
In particular, this issue was found while investigating the OOMs caused by the
improved folding algorithm in the next PR of this stack; see #76828 (comment).
This PR fixes those OOMs and decreases the memory footprint of the
backward of matmul.

I understand this is a tricky one, so I put it on its own PR to discuss it.
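
As a rough illustration of the folding idea, here is a minimal Python sketch with made-up shapes (not the actual C++ change in this PR): the gradient of the matrix can be computed with a single mm over the folded operands instead of a bmm that materializes the [b, n, k] batch.

```python
import torch

b, m, n, k = 8, 128, 64, 32
A = torch.randn(b, m, n, dtype=torch.double)         # batch of matrices
B = torch.randn(n, k, dtype=torch.double)            # single matrix
grad_out = torch.randn(b, m, k, dtype=torch.double)  # gradient of A @ B

# Naive gradient of B: materializes a [b, n, k] intermediary, then reduces it.
grad_B_naive = (A.mT @ grad_out).sum(dim=0)

# Folded gradient of B: collapse the batch into the rows of a 2-D matrix, so
# a single mm performs the reduction and no [b, n, k] tensor is ever created.
grad_B_folded = A.reshape(b * m, n).mT @ grad_out.reshape(b * m, k)

torch.testing.assert_close(grad_B_naive, grad_B_folded)
```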

Differential Revision: D43541495

@pytorch-bot (bot) commented on Feb 22, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95261

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 Failures

As of commit 3d85e21:

BROKEN TRUNK - The following jobs failed but were present on the merge base ac9b305:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

lezcano added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Feb 22, 2023
lezcano added the ciflow/periodic label (Trigger jobs ran periodically on master (periodic.yml) on the PR) on Feb 22, 2023
@ezyang (Contributor) left a comment

I'm willing to approve this to unblock, but to me all this really says is that we should just write the g-dang backwards formula for matmul by hand and call it a day.

@lezcano (Collaborator, Author) commented on Feb 22, 2023

Agree about the backward of matmul. Note that the backward for matmul is simply:

```
mat1: grad.matmul(mat2.dim() > 1 ? mat2.mH : mat2.conj().unsqueeze(0))
mat2: (mat1.dim() > 1 ? mat1.mH : mat1.conj().unsqueeze(1)).matmul(grad)
```
I remember implementing this once, but, at least back then, no one wanted matmul to be non-composite.
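
For the plain 2-D × 2-D case, that formula is easy to sanity-check numerically. A minimal sketch with illustrative real-valued shapes (so `.mH` is just a transpose); the 1-D and broadcasting edge cases are exactly what gets discussed below:

```python
import torch

m, n, k = 5, 4, 3
mat1 = torch.randn(m, n, dtype=torch.double, requires_grad=True)
mat2 = torch.randn(n, k, dtype=torch.double, requires_grad=True)
grad = torch.randn(m, k, dtype=torch.double)

# Compare autograd's gradients for mat1 @ mat2 against the closed-form formula.
g1, g2 = torch.autograd.grad(mat1 @ mat2, (mat1, mat2), grad_outputs=grad)
torch.testing.assert_close(g1, grad @ mat2.detach().mH)   # grad_mat1
torch.testing.assert_close(g2, mat1.detach().mH @ grad)   # grad_mat2
```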

@lezcano (Collaborator, Author) commented on Feb 22, 2023

What we can do is try to land this one, and then I'll append a PR to this stack making this function CompositeExplicit.

Regardless, could any of you @ezyang @ngimel import this one? I reckon it may need some internal tweaking to be landed.

On a different note, I believe this was the bug we were hitting that kept us from landing #75195. I'll revisit that as well in this same stack.

@albanD (Collaborator) commented on Feb 22, 2023

> I'm willing to approve this to unblock, but to me all this really says is that we should just write the g-dang backwards formula for matmul by hand and call it a day.

Note that we already have all the scaffolding for matmul to be explicitly differentiable; we even have a matmul_backward native function. It's just that only Nested uses it today.

What I remember is that backends indeed didn't want to see matmuls. I don't know what the real BC guarantee here is, tbh, but maybe we should just make CPU/CUDA follow Nested and leave everything else as CompositeImplicit?

@ezyang (Contributor) commented on Feb 22, 2023

At least in PT2, backends that don't want to see matmul can always use the old decomp, no big deal.

@ngimel (Collaborator) commented on Feb 22, 2023

Wouldn't the backward for matmul as written above still result in materialization of the big tensor if the inputs weren't flattened in the forward? Even worse, even if the inputs were flattened in the forward and matmul went through addmm, the above formula would still result in a 3d gradient for the 2d input that would then have to be reduced by the AutogradEngine?
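
(To make the materialization concern concrete, a small sketch with illustrative shapes of what the composite path does today: bmm's backward first materializes a [b, n, k] gradient for the expanded operand, and only then does expand's backward sum it down to [n, k].)

```python
import torch

b, m, n, k = 4, 3, 5, 2
A = torch.randn(b, m, n, dtype=torch.double)
B = torch.randn(n, k, dtype=torch.double, requires_grad=True)

# Composite-style forward: broadcast the 2-D matrix to a batch, then bmm.
out = torch.bmm(A, B.expand(b, n, k))
out.sum().backward()

# bmm's backward produces a [b, n, k] gradient for the expanded operand;
# expand's backward then sums it over the batch dim, leaving a [n, k] grad.
print(B.grad.shape)  # torch.Size([5, 2])
```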

@lezcano (Collaborator, Author) commented on Feb 22, 2023

Ugh, you are right. Let me think a bit about it.

@lezcano (Collaborator, Author) commented on Feb 23, 2023

I implemented another approach that @ngimel suggested and added a test for it. This optimisation could be made marginally more general (removing several leading 1's), but I didn't bother, as I don't think that case is common in practice.

I think this PR is ready for review as-is. And again, I think this PR should be landed via Phabricator.

For the reasons given by @ngimel in #95261 (comment), let's shelve the "make matmul autograd-explicit" idea for now.

There are still optimisations that could be implemented here and there, as pytorch/functorch#989 (comment) suggests. I will try to explore those in the future.

@albanD (Collaborator) commented on Feb 23, 2023

> Even worse, even if inputs were flattened in forward and matmul went through addmm the above formula would still result in 3d gradient for 2d input that will be reduced by AutogradEngine?

I don't think that is a problem. The AutogradEngine will handle the broadcast reduction for you if you didn't do it yourself, but that doesn't mean you're not allowed to do it!
So if you're differentiating matmul directly (and thus the expand is part of your function's forward), then it is fine for your backward to reverse that expand.

@lezcano (Collaborator, Author) commented on Feb 23, 2023

Yeah, the only thing is that one would need to be a bit more careful about how to implement it. If you guys think that it's still worth having a look at it, I'm happy to do so, but I think, for now, this and the next PR do a pretty good job at not creating large tensors / making unnecessary copies.

@ngimel (Collaborator) commented on Feb 23, 2023

Right, but the backward formula as written won't do the implicit reduction (by substituting addmm for bmm); it will rely on the AutogradEngine to do that. This is not to say an explicit backward for matmul cannot be implemented (it can be), but unfortunately it's more involved than the neat formula above.

@albanD (Collaborator) commented on Feb 23, 2023

Oh yes, it would be quite a subtle formula, I agree.

@lezcano (Collaborator, Author) commented on Feb 23, 2023

Also, I just remembered the reason I didn't go for implementing the backward formula manually. The thing is that, in some cases, you may reshape and copy one of the matrices in the forward and then use this reshaped matrix in the backward. In fact, I believe that whenever you want to reshape it in the forward, you also want to reshape it in the backward. To implement the backward by hand, you would have to create a helper function that returns this intermediary tensor and then reuse it in the backward, so the same copy is not made in both the forward and the backward. And that is just a bit cumbersome (although we do it for some functions like linalg.solve).

I get the feeling that if the forward is implemented carefully, we should be able to avoid the pain that writing the backward by hand would be in the end.
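
For illustration only, a hypothetical torch.autograd.Function sketch of that pattern (names and shapes are made up; this is not how matmul is actually implemented in ATen): fold once in the forward, save the folded tensor, and reuse it in the backward so the copy is not made twice.

```python
import torch

class FoldedBatchMatmul(torch.autograd.Function):
    """Hypothetical: (batch [b, m, n]) @ (mat [n, k]) via a single folded mm."""

    @staticmethod
    def forward(ctx, batch, mat):
        b, m, n = batch.shape
        folded = batch.reshape(b * m, n)   # may copy if `batch` is a transposed view
        ctx.save_for_backward(folded, mat)
        ctx.batch_shape = (b, m, n)
        return (folded @ mat).reshape(b, m, -1)

    @staticmethod
    def backward(ctx, grad_out):
        folded, mat = ctx.saved_tensors
        b, m, n = ctx.batch_shape
        grad2d = grad_out.reshape(b * m, -1)
        grad_batch = (grad2d @ mat.mH).reshape(b, m, n)
        grad_mat = folded.mH @ grad2d      # [n, k]; reuses the saved folded copy
        return grad_batch, grad_mat
```

Usage would be `FoldedBatchMatmul.apply(batch, mat)`. The point is only that the folded tensor is shared between forward and backward; threading that intermediary through matmul's real implementation via a helper is the cumbersome part mentioned above.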

@ezyang (Contributor) commented on Feb 23, 2023

@ezyang has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 1, 2023
Pull Request resolved: pytorch/pytorch#95261
Approved by: https://github.com/ezyang
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 2, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 27, 2023
pruthvistony added a commit to ROCm/pytorch that referenced this pull request May 2, 2023
facebook-github-bot deleted the gh/Lezcano/185/head branch June 8, 2023 14:44
lezcano added a commit that referenced this pull request Jul 24, 2023

The decomposition was not updated after #95261

Pull Request resolved: #105850
pytorchmergebot pushed a commit that referenced this pull request Jul 26, 2023
The decomposition was not updated after #95261

Pull Request resolved: #105850
Approved by: https://github.com/Chillee
jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 2024
Labels
ciflow/periodic · ciflow/trunk · Merged · open source · release notes: linalg_frontend