Add aten mkldnn linear operator #19210
Conversation
Differential Revision: D14901641 Differential Version: 79209392
@zdevito @suo @ZolotukhinM In this PR I'm changing nn.Linear to directly call the C++ aten::linear (which underlyingly dispatches to the same addmm/matmul calls), but the difference is that nn.Linear's TorchScript will now directly capture linear instead of addmm/matmul. (In this PR only scripting is changed; tracing is not affected, but I would like to change tracing as well in a follow-up PR.) I think this is the right change, since the TorchScript graph gets simplified from the addmm/matmul form to a single linear call, and it makes life easier for other backends like mkldnn to support linear.
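To make the dispatch equivalence concrete, here is a toy, torch-free sketch (plain Python with illustrative names; not ATen's actual code) of the identity the PR relies on: for 2-D input, aten::linear(input, weight, bias) computes the same result as aten::addmm(bias, input, weight.t()).

```python
# Toy model of the linear/addmm equivalence; names mirror the ATen ops
# but this is plain Python, not the real implementation.

def matmul(a, b):
    # a: m x k, b: k x n -> m x n
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def linear(x, weight, bias):
    # y = x @ weight.T + bias  (what aten::linear computes)
    y = matmul(x, transpose(weight))
    return [[y[i][j] + bias[j] for j in range(len(bias))] for i in range(len(y))]

def addmm(bias, x, wt):
    # y = bias + x @ wt  (what aten::addmm computes, beta = alpha = 1)
    y = matmul(x, wt)
    return [[y[i][j] + bias[j] for j in range(len(bias))] for i in range(len(y))]

x = [[1.0, 2.0], [3.0, 4.0]]
w = [[0.5, -1.0], [2.0, 0.0]]   # out_features x in_features
b = [1.0, -1.0]
assert linear(x, w, b) == addmm(b, x, transpose(w))
```

Scripting captures which of these two spellings appears in the graph, even though the numerical result is identical.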
I'm not sure that we should do this. If you run the scripted function you get the following result:
Which outputs the graph
What is the end goal here? If we keep creating new ops that were previously composed of existing ones, we'll be putting more and more stress on various parts of the system: shape analysis, the graph fuser, TVM, etc. I don't have good context on mkldnn interop, so there may be valid reasons there, but as far as TorchScript goes I'm not really convinced.
It is totally reasonable that there would be a built-in linear operator; tons of libraries already have it as an optimized fused thing, and we should add it. However, @eellison is right: any time a fused op is added, it is also necessary to change all of the JIT analysis and optimization passes so that they still work in this world. In the case of Linear, this change is almost certainly breaking the matrix-multiply optimizations that the JIT does, because of missing formulas for shape propagation, differentiation, and others. This is an invasive change to a very important operator, so it deserves careful consideration of what might break by hiding its implementation from optimization passes.
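To illustrate the kind of per-op work being described, a shape-propagation formula for a fused linear op might look roughly like the following toy sketch (plain Python; this is not the JIT's actual pass or API, just an illustration of why every fused op needs its own formula):

```python
# Toy sketch: a shape-inference rule for a fused linear op. A pass that only
# knows addmm/matmul would have no rule for the new node; a formula like this
# has to be added alongside the operator.

def infer_linear_shape(input_shape, weight_shape, bias_shape=None):
    # aten::linear: input (..., in_features) x weight (out_features, in_features)
    *batch, in_features = input_shape
    out_features, w_in = weight_shape
    assert in_features == w_in, "in_features mismatch"
    if bias_shape is not None:
        assert bias_shape == (out_features,), "bias shape mismatch"
    return (*batch, out_features)

# 2-D and batched inputs both propagate through the same rule.
assert infer_linear_shape((32, 128), (64, 128), (64,)) == (32, 64)
```

Similar formulas would be needed for differentiation and any other pass that previously saw the addmm/matmul decomposition.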
@zdevito - can you please help come up with a change plan for this in the JIT? As you said, switching to at::linear does make sense. I guess grepping for aten::addmm should be a good starting point. One option is to land this PR first, since it switches only scripting, and follow up with a separate PR for tracing and the proper JIT changes. Would that be reasonable?
I added a workaround that overrides the forward method of nn.Linear in the to_mkldnn case. I will do the linear dispatch change in a separate diff to unblock mkldnn.
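The override pattern described above can be sketched in plain Python (all names here are illustrative stand-ins, not the PR's actual to_mkldnn code): the conversion rebinds forward on the converted instance so it dispatches to the mkldnn-backed path, leaving other instances untouched.

```python
import types

class Linear:
    """Stand-in for nn.Linear."""
    def forward(self, x):
        return "dense forward"

def mkldnn_forward(self, x):
    # Stand-in for dispatching to an mkldnn linear kernel.
    return "mkldnn forward"

def to_mkldnn(module):
    # Override the bound forward method on this instance only.
    module.forward = types.MethodType(mkldnn_forward, module)
    return module

m = to_mkldnn(Linear())
assert m.forward(None) == "mkldnn forward"
assert Linear().forward(None) == "dense forward"  # other instances unaffected
```

This keeps the workaround local to converted modules until the proper aten::linear dispatch lands.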
As a hack with the follow-up later, it looks good.
@@ -40,6 +44,12 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
  return new_with_itensor_mkldnn(std::move(y), self.options());
}

Tensor mkldnn_clone(const Tensor& self) {
  ideep::tensor& src = itensor_from_mkldnn(self);
  ideep::tensor dst{src};
This does make a copy, right? (It's supposed to.)
Summary: Pull Request resolved: pytorch/pytorch#19210 Reviewed By: dzhulgakov Differential Revision: D14901641 fbshipit-source-id: 8fa68b9941fd93cea0f313a828cba34c5c81ae11
This pull request has been merged in c9f380d.
Stack:
:white_circle: #19633 Add is_mkldnn to at::Tensor 💚
:white_circle: #19204 Add aten mkldnn conv2d operator 💚
:white_circle: #19205 Add aten mkldnn ops: relu, max_pool2d and avg_pool2d 💚
:white_circle: #19206 Add aten mkldnn batch_norm operator 💚
:white_circle: #19207 Add aten mkldnn add operator 💚
:white_circle: #19209 Add aten mkldnn view operator 💚
:black_circle: #19210 Add aten mkldnn linear operator 💚
:white_circle: #19648 Adjust resnext run script 💚
Pull Request resolved: #19210
Differential Revision: D14901641