Precomputing Transposes for frozen linear layers #65631
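The idea behind this PR can be illustrated with a short, hypothetical sketch in plain PyTorch (this is not the pass itself): `aten::linear` computes `x @ W.T + b`, so once the weight is frozen to a constant, the transpose can be computed a single time ahead of execution and each call lowered to a plain matmul.

```python
import torch

x = torch.randn(4, 8)
W = torch.randn(16, 8)   # nn.Linear weight has shape (out_features, in_features)
b = torch.randn(16)

# What aten::linear computes on every call: x @ W.T + b
y_linear = torch.nn.functional.linear(x, W, b)

# With a frozen (constant) weight, the transpose can be done once up front...
Wt = W.t().contiguous()
# ...and each subsequent call becomes a plain matmul plus a bias add
y_precomputed = torch.matmul(x, Wt) + b

assert torch.allclose(y_linear, y_precomputed, atol=1e-5)
```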
Conversation
[ghstack-poisoned]
💊 CI failures summary: as of commit fa959e5, 1 job failed on ci.pytorch.org (more details on the Dr. CI page).
import io
import unittest
from itertools import product
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._C import parse_ir
from torch.jit._recursive import wrap_cpp_module
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import set_default_dtype
from torch.testing._internal.jit_utils import JitTestCase
from torch.utils import mkldnn as mkldnn_utils

try:
I guess some linting step I did ran isort.
Looks good! Can you add more comments to this, both to the PR description detailing the benchmarking that was done and also inline in the pass itself?
Also, I thought we determined this was good for CPU as well?
if (node->kind() != aten::linear) {
  return false;
}
auto weight = node->namedInput("weight");
`weight` is never None; do you mean `bias`? Can we handle the no-bias case as well here?
Looks like `bias` is actually handled.
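For reference, `torch.nn.functional.linear` itself already treats the bias as optional (it defaults to `None`), so the no-bias case reduces to a bare matmul against the transposed weight. A quick sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3)
W = torch.randn(5, 3)

# bias defaults to None: linear reduces to x @ W.T
y_no_bias = F.linear(x, W)
assert torch.allclose(y_no_bias, x.matmul(W.t()))

# with a bias, the add happens after the matmul
b = torch.randn(5)
y_bias = F.linear(x, W, b)
assert torch.allclose(y_bias, x.matmul(W.t()) + b)
```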
void handleBlockAndSubblocks(Block* block) {
  // Can't delete nodes while also iterating over it
  std::vector<Node*> constant_linear_nodes;
We can iterate like this instead: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/constant_propagation.cpp#L396
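The collect-then-destroy pattern under discussion (gather matching nodes in one traversal, mutate only afterwards) can be sketched in Python with a hypothetical list of node kinds standing in for the graph:

```python
# Hypothetical node list standing in for a JIT graph's nodes
nodes = ["aten::relu", "aten::linear", "aten::add", "aten::linear"]

# Pass 1: collect matches without mutating the container we iterate over
constant_linear_nodes = [n for n in nodes if n == "aten::linear"]

# Pass 2: now it is safe to mutate (the analogue of node->destroy())
for n in constant_linear_nodes:
    nodes.remove(n)

assert nodes == ["aten::relu", "aten::add"]
```

Mutating a container while iterating over it invalidates the iteration in both Python and C++, which is why the pass buffers the nodes first.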
node->destroy();
};

void handleBlockAndSubblocks(Block* block) {
This block is a good use case for `DepthFirstGraphNodeIterator`.
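A depth-first traversal over nodes and their nested sub-blocks, which is roughly what `DepthFirstGraphNodeIterator` provides, can be sketched with hypothetical dict-based nodes (the shape of the data here is invented for illustration):

```python
# Hypothetical graph: nodes may own sub-blocks (e.g. the branches of prim::If)
def depth_first_nodes(block):
    """Yield every node in the block, recursing into nested sub-blocks."""
    for node in block["nodes"]:
        yield node
        for sub in node.get("blocks", []):
            yield from depth_first_nodes(sub)

g = {"nodes": [
    {"kind": "prim::If", "blocks": [
        {"nodes": [{"kind": "aten::linear"}]},
        {"nodes": [{"kind": "aten::relu"}]},
    ]},
    {"kind": "aten::matmul"},
]}

kinds = [n["kind"] for n in depth_first_nodes(g)]
assert kinds == ["prim::If", "aten::linear", "aten::relu", "aten::matmul"]
```

Using a ready-made iterator avoids hand-writing the recursion over sub-blocks in every pass.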
namespace torch {
namespace jit {

// Concats multiple linear ops with the same Tensor input
Bad copy-pasta: this comment describes a different pass.
 private:
  std::shared_ptr<Graph> graph_;
  bool graph_modified = false;
nit: `graph_modified` -> `graph_modified_`
Given that we determined in the meeting that single-core CPU is the main metric to look at (since prod uses usually run many concurrent processes with one core per process), and that multicore CPU didn't significantly regress, I will remove the restriction on only running this on GPU. Link to benchmarking results
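For context, a minimal sketch of the setting in which a pass like this applies: `torch.jit.freeze` inlines a module's parameters as graph constants, which is what makes the linear weight available for constant-time precomputation. The module and shapes below are made up for illustration:

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 16)

    def forward(self, x):
        return self.fc(x)

# Freezing requires eval mode; weights then become constants in the graph,
# enabling optimizations such as precomputing the weight transpose.
m = torch.jit.script(M().eval())
frozen = torch.jit.freeze(m)

x = torch.randn(2, 8)
assert torch.allclose(frozen(x), m(x))
```

Whatever graph rewrites run during or after freezing, the frozen module must stay numerically equivalent to the original.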
@Gamrix has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Stack from ghstack:
Differential Revision: D31314248