Don't create large intermediary tensors in the backward of matmul #95261
Conversation
Currently, if we multiply a transposed batch of matrices with shape [b, m, n] and a matrix with shape [n, k], when computing the gradient of the matrix we instantiate an intermediary tensor of shape [b, n, k]. If the matrix is large, creating this unnecessary batch of matrices may be time- and memory-consuming. In this case, we instead fold the batch of matrices into a single matrix, which avoids creating any large intermediary tensor. Note that multiplying a batch of matrices and a matrix occurs naturally within an attention module, so this case surely happens in the wild.
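A minimal sketch of the idea (hypothetical shapes and names, not the actual code in this PR): rather than letting the backward materialize a broadcast [b, n, k] gradient for the matrix and then reduce it, the batch can be folded so the gradient is computed with a single 2-D matmul.

```python
import torch

b, m, n, k = 8, 128, 64, 32
batch = torch.randn(b, m, n, dtype=torch.float64)     # batch of matrices
mat = torch.randn(n, k, dtype=torch.float64)          # single shared matrix
grad_out = torch.randn(b, m, k, dtype=torch.float64)  # gradient of batch @ mat

# Naive backward: builds a [b, n, k] intermediary, then reduces over the batch.
grad_mat_naive = (batch.transpose(-2, -1) @ grad_out).sum(dim=0)

# Folded backward: one 2-D matmul over a [b*m, n] view, no [b, n, k] tensor.
grad_mat_folded = batch.reshape(-1, n).t() @ grad_out.reshape(-1, k)

torch.testing.assert_close(grad_mat_naive, grad_mat_folded)
```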
I'm willing to approve this to unblock, but to me all this really says is that we should just write the g-dang backwards formula for matmul by hand and call it a day.
Agree about the backward matmul. Note that the backward for matmul is simple.
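Roughly, and only as a sketch that assumes `self` and `other` have already been broadcast to the same batch shape (the real formula also has to handle the implicit reductions discussed below):

```python
def matmul_backward_sketch(grad, self, other):
    # grad has the shape of self @ other; broadcasting reductions are not handled here
    grad_self = grad @ other.transpose(-2, -1)
    grad_other = self.transpose(-2, -1) @ grad
    return grad_self, grad_other
```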
… matmul" Currently, if we multiply a transposed batch of matrices with shape [b, m, n] and a matrix with shape [n, k], when computing the gradient of the matrix, we instantiate a matrix of shape [b, n, k]. This may be a very large matrix. Instead, we fold the batch of matrices into a matrix, which avoids creating any large intermediary tensor. Note that multiplying a batch of matrices and a matrix naturally occurs within an attention module, so this case surely happens in the wild. In particular, this issue was found while investigating the OOMs caused by the improved folding algorithm in the next PR of this stack. See #76828 (comment) This PR fixes those OOMs and decreases the memory footprint of the backward of matmul. I understand this is a tricky one, so I put it on its own PR to discuss it. [ghstack-poisoned]
What we can do is try to land this one, and then I'll append a PR onto this stack making this function explicitly differentiable. Regardless, could any of you @ezyang @ngimel import this one? I reckon it may need some internal tweaking to be landed. On a different note, I believe this was the bug we were hitting that didn't allow us to land #75195. I'll revisit this as well on this same stack.
… matmul" Currently, if we multiply a transposed batch of matrices with shape [b, m, n] and a matrix with shape [n, k], when computing the gradient of the matrix, we instantiate a matrix of shape [b, n, k]. This may be a very large matrix. Instead, we fold the batch of matrices into a matrix, which avoids creating any large intermediary tensor. Note that multiplying a batch of matrices and a matrix naturally occurs within an attention module, so this case surely happens in the wild. In particular, this issue was found while investigating the OOMs caused by the improved folding algorithm in the next PR of this stack. See #76828 (comment) This PR fixes those OOMs and decreases the memory footprint of the backward of matmul. I understand this is a tricky one, so I put it on its own PR to discuss it. [ghstack-poisoned]
… matmul" Currently, if we multiply a transposed batch of matrices with shape [b, m, n] and a matrix with shape [n, k], when computing the gradient of the matrix, we instantiate a matrix of shape [b, n, k]. This may be a very large matrix. Instead, we fold the batch of matrices into a matrix, which avoids creating any large intermediary tensor. Note that multiplying a batch of matrices and a matrix naturally occurs within an attention module, so this case surely happens in the wild. In particular, this issue was found while investigating the OOMs caused by the improved folding algorithm in the next PR of this stack. See #76828 (comment) This PR fixes those OOMs and decreases the memory footprint of the backward of matmul. I understand this is a tricky one, so I put it on its own PR to discuss it. [ghstack-poisoned]
Note that we already have all the scaffolding for matmul to be explicitly differentiable; we even have a matmul_backward native function. It's just that only Nested uses it today. What I remember is that backends didn't want to see matmuls indeed. I don't know what the real BC guarantee here is, tbh, but maybe we should just make CPU/CUDA follow Nested and leave everything else as CompositeImplicit?
At least in PT2, backends that don't want to see matmul can always use the old decomp, no big deal.
Wouldn't the backward for matmul as written above still result in materialization of the big tensor if the inputs weren't flattened in the forward? Even worse, even if the inputs were flattened in the forward and matmul went through …
ugh, you are right. Let me think a bit about it.
… matmul" Currently, if we multiply a transposed batch of matrices with shape [b, m, n] and a matrix with shape [n, k], when computing the gradient of the matrix, we instantiate a matrix of shape [b, n, k]. This may be a very large matrix. Instead, we fold the batch of matrices into a matrix, which avoids creating any large intermediary tensor. Note that multiplying a batch of matrices and a matrix naturally occurs within an attention module, so this case surely happens in the wild. In particular, this issue was found while investigating the OOMs caused by the improved folding algorithm in the next PR of this stack. See #76828 (comment) This PR fixes those OOMs and decreases the memory footprint of the backward of matmul. I understand this is a tricky one, so I put it on its own PR to discuss it. [ghstack-poisoned]
I implemented another optimisation that @ngimel suggested and added a test for it. This optimisation could be marginally more general (removing several leading 1's from the shape), but I didn't bother implementing that, as I don't think it's that common in practice. I think this PR is ready for review as-is. And again, I think this PR should be landed via Phabricator. For the reasons given by @ngimel in #95261 (comment), let's shelve the "make matmul autograd-explicit" idea for now. There are still optimisations that may be implemented here and there, as pytorch/functorch#989 (comment) suggests. I will try to explore those in the future.
I don't think that is a problem. The AutogradEngine will handle the broadcasting for you if you didn't. But that doesn't mean you're not allowed to do it!
Yeah, the only thing is that one would need to be a bit more careful about how to implement it. If you think it's still worth having a look at, I'm happy to do so, but I think, for now, this and the next PR do a pretty good job of not creating large tensors / making unnecessary copies.
Right, but the backward formula as written won't do the implicit reduction (by substituting addmm instead of bmm); it will rely on the AutogradEngine to do that. This is not to say an explicit backward for matmul cannot be implemented, it can be, but unfortunately it's more involved than the neat formula above.
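For instance (a hypothetical helper, not part of this PR), an explicit backward would have to perform the kind of reduction that the AutogradEngine otherwise does implicitly, summing a broadcast gradient back down to the original input shape:

```python
import torch

def reduce_grad_to_shape(grad, shape):
    # Sum away the extra leading dimensions introduced by broadcasting...
    extra = grad.dim() - len(shape)
    if extra > 0:
        grad = grad.sum(dim=tuple(range(extra)))
    # ...and sum (keeping dims) over dimensions that were 1 in the original shape.
    dims = tuple(i for i, s in enumerate(shape) if s == 1 and grad.shape[i] != 1)
    if dims:
        grad = grad.sum(dim=dims, keepdim=True)
    return grad

# e.g. the [b, n, k] gradient produced by the neat formula for an [n, k] matrix:
g = torch.randn(4, 8, 16)
assert reduce_grad_to_shape(g, (8, 16)).shape == (8, 16)
```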
Oh yes, it would be quite a subtle formula, I agree.
Also, I just remembered the reason I didn't go for implementing the backward formula manually. The thing is that, in some cases, you may reshape and copy one matrix in the forward and then use this reshaped matrix in the backward. In fact, I believe that if you want to reshape it in the forward, you want to reshape it in the backward as well. If you want to implement the backward by hand, you should create a helper function that returns this intermediary tensor and then reuse it in the backward, to avoid making the same copy in both the forward and the backward. And this is just a bit cumbersome (although we do it for some functions). I get the feeling that if the forward is implemented carefully, we should be able to avoid the pain that writing the backward may be in the end.
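As a rough illustration of that pattern (a hypothetical sketch using a custom autograd.Function, not what this PR implements), the reshaped/folded tensor built in the forward can be saved and reused in the backward so the copy is only made once:

```python
import torch

class FoldedMatmul(torch.autograd.Function):
    # Hypothetical: batch [b, m, n] @ matrix [n, k], folding the batch only once.

    @staticmethod
    def forward(ctx, batch, mat):
        folded = batch.reshape(-1, batch.shape[-1])  # [b*m, n]; may be a copy
        ctx.save_for_backward(folded, mat)
        ctx.batch_shape = batch.shape
        return (folded @ mat).reshape(*batch.shape[:-1], mat.shape[-1])

    @staticmethod
    def backward(ctx, grad):
        folded, mat = ctx.saved_tensors              # reuse the folded copy
        grad2d = grad.reshape(-1, grad.shape[-1])    # [b*m, k]
        grad_batch = (grad2d @ mat.t()).reshape(ctx.batch_shape)
        grad_mat = folded.t() @ grad2d               # [n, k]; no [b, n, k] tensor
        return grad_batch, grad_mat
```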
@ezyang has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Don't create large intermediary tensors in the backward of matmul (#95261) Currently, if we multiply a transposed batch of matrices with shape [b, m, n] and a matrix with shape [n, k], when computing the gradient of the matrix, we instantiate a matrix of shape [b, n, k]. This may be a very large matrix. Instead, we fold the batch of matrices into a matrix, which avoids creating any large intermediary tensor. Note that multiplying a batch of matrices and a matrix naturally occurs within an attention module, so this case surely happens in the wild. In particular, this issue was found while investigating the OOMs caused by the improved folding algorithm in the next PR of this stack. See pytorch/pytorch#76828 (comment) This PR fixes those OOMs and decreases the memory footprint of the backward of matmul. I understand this is a tricky one, so I put it on its own PR to discuss it. Differential Revision: [D43541495](https://our.internmc.facebook.com/intern/diff/D43541495) Pull Request resolved: pytorch/pytorch#95261 Approved by: https://github.com/ezyang
…tmul (pytorch#95261)" This reverts commit 03cc0f5.
The decomposition was not updated after #95261.
The decomposition was not updated after #95261. Pull Request resolved: #105850 Approved by: https://github.com/Chillee
Stack from ghstack (oldest at bottom):
Currently, if we multiply a transposed batch of matrices with shape
[b, m, n] and a matrix with shape [n, k], when computing the gradient
of the matrix, we instantiate a matrix of shape [b, n, k]. This may be
a very large matrix. Instead, we fold the batch of matrices into a
matrix, which avoids creating any large intermediary tensor.
Note that multiplying a batch of matrices and a matrix naturally occurs
within an attention module, so this case surely happens in the wild.
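To put rough numbers on it (hypothetical, attention-like sizes, not measurements from this PR): with b = 64, n = k = 4096 and float32, the [b, n, k] intermediary alone is 4 GiB, while the folded computation never allocates it.

```python
b, n, k = 64, 4096, 4096        # hypothetical batch*heads and feature sizes
bytes_needed = b * n * k * 4    # float32 elements are 4 bytes
print(bytes_needed / 2**30)     # 4.0 GiB for the [b, n, k] intermediary alone
```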
In particular, this issue was found while investigating the OOMs caused by the
improved folding algorithm in the next PR of this stack. See #76828 (comment)
This PR fixes those OOMs and decreases the memory footprint of the
backward of matmul.
I understand this is a tricky one, so I put it on its own PR to discuss it.
Differential Revision: D43541495