
Accelerate PyTorch just-in-time compilation using MKL-DNN #23657

Open
Jianhui-Li opened this issue Aug 1, 2019 · 15 comments

Labels
feature: A request for a proper, new feature.
module: mkldnn: Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration
oncall: jit: Add this issue/PR to JIT oncall triage queue
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@Jianhui-Li

🚀 Feature

Accelerate PyTorch just-in-time compilation using MKL-DNN

Motivation

PyTorch's just-in-time (JIT) compiler rewrites and runs PyTorch models at production efficiency. MKL-DNN is built to accelerate deep learning applications in production environments. With high-performance primitives such as convolution, RNN, and GEMM, MKL-DNN significantly accelerates most deep learning models across multiple generations of Intel CPUs using AVX2, AVX512, AVX512-VNNI, and future deep learning acceleration technologies.

With MKL-DNN enabled in the JIT compiler, users can get the best MKL-DNN performance in JIT mode with minimal changes to their PyTorch code. In imperative mode, the user has to insert format conversions explicitly around MKL-DNN operations using tensor.to_mkldnn() and to_dense(). In JIT mode this is not necessary: the user may only need to pass an explicit flag or invoke a specific MKL-DNN optimization pass, which automatically converts CPU ops to MKL-DNN ops and propagates the MKL-DNN format across neighboring MKL-DNN operations. This captures all the performance benefits achievable in imperative mode, plus additional graph optimizations.
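
For context, a minimal sketch of the imperative-mode workflow described above, using the existing tensor.to_mkldnn() / to_dense() conversions (relu here stands in for any op with an MKL-DNN kernel):

import torch

x = torch.randn(64, 128)

# Imperative mode: the user inserts the layout conversions by hand.
x_mkl = x.to_mkldnn()        # dense CPU layout -> opaque MKL-DNN (blocked) layout
y_mkl = torch.relu(x_mkl)    # runs on the MKL-DNN tensor and stays in MKL-DNN layout
y = y_mkl.to_dense()         # convert back before mixing with regular CPU ops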

Pitch

Use PyTorch just-in-time compilation to get MKL-DNN acceleration with one flag (or function call)

Additional context

The MKL-DNN optimization pass includes MKL-DNN format propagation and fusion as an initial step. The MKL-DNN format propagation converts CPU ops to MKL-DNN ops, and format conversion ops are added between CPU and MKL-DNN ops.
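
To make the intended user experience concrete, a rough sketch; every name below is a placeholder for illustration, not a settled API:

import torch

model = MyConvNet().eval()              # MyConvNet is a placeholder user model
scripted = torch.jit.script(model)

# Hypothetical opt-in call (the name is illustrative only). Conceptually, the
# pass rewrites CPU ops (conv2d, relu, ...) to their MKL-DNN counterparts,
# inserts to_mkldnn()/to_dense() conversion ops only at the boundaries between
# CPU regions and MKL-DNN regions, and keeps the blocked layout alive across
# neighboring MKL-DNN ops instead of converting around every single op.
scripted = torch.jit.optimize_for_mkldnn(scripted)

out = scripted(torch.randn(1, 3, 224, 224))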

The implementation of the PyTorch MKL-DNN JIT backend will be located in the 'backend' directory under the JIT subdirectory.

@mrshenli added the feature, oncall: jit, module: mkldnn, and triaged labels on Aug 1, 2019
@gottbrath assigned suo, ZolotukhinM, bddppq and jianyuh and unassigned suo and ZolotukhinM on Aug 29, 2019
@gottbrath
Contributor

Jianhui, do I understand correctly that this optimization only makes sense for inference use cases?

@Jianhui-Li
Author

@gottbrath The optimization we are implementing at this stage is for inference. But the fusion and MKL-DNN format propagation can be extended to work for training as well, which depends on the pending imperative-mode training PRs that enable the MKL-DNN backward operations.

@gottbrath
Contributor

In the meeting today you said you had a concern with weights being treated as constants or not. Can you articulate this question/request either here or in another issue? I think we have the right people watching.

@Jianhui-Li
Author

Yes. The question is whether you have a plan to support graph freezing, i.e. marking model weights as constants to allow constant propagation and subexpression elimination. TF supports this: https://www.tensorflow.org/guide/extend/model_files#freezing. The user needs to call it explicitly to enable optimization for inference.
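
To illustrate what such a freezing step would enable, a small sketch; the freeze call mentioned in the comment is hypothetical:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 256)

    def forward(self, x):
        return self.linear(x)

m = torch.jit.script(M()).eval()

# Today, self.linear.weight remains a mutable module attribute: the compiled
# graph has to re-read it on every call, so the JIT cannot constant-fold
# anything computed from it or pre-pack it into the MKL-DNN blocked layout
# ahead of time. A hypothetical explicit step, analogous to TensorFlow's
# freeze_graph,
#   frozen = freeze(m)   # name is illustrative
# would mark the weights as constants and unlock constant propagation,
# subexpression elimination, and one-time weight pre-packing.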

@CaoZhongZ
Contributor

CaoZhongZ commented Aug 30, 2019

We did some experiments to 'freeze params' and described the intention and expected results (along with some code snippets) here:
https://gist.github.com/CaoZhongZ/34c2796deef1cc8871039b3d7441f770

@jgong5
Collaborator

jgong5 commented Sep 2, 2019

One thing I would like to clarify about the "graph freezing" feature is that it is not specific to MKL-DNN; it is a general feature applicable to all backends.

@gottbrath
Contributor

@suo I don't think I've heard "freezing" as a stage in our desired JIT workflow but I can see how knowing that the weights are constant could allow optimizations that wouldn't be possible otherwise. What are your thoughts on this topic?

@soumith
Member

soumith commented Sep 4, 2019

@dzhulgakov and I have talked quite a bit about "freezing". It was a topic especially for MKL-DNN packing. Without the user explicitly marking that a weight is locked / frozen, I don't think we have a way to do this across iterations (i.e. we separate the pointers given as inputs from the graph, so we have no guarantees that the pointers correspond to the same Tensor, or even the same data).

@ZolotukhinM

@soumith @dzhulgakov, how about we perform the MKLDNN conversion as a module-to-module transformation? In this transform we would create a copy of the original module with all weights packed appropriately and all ops rewritten to their mkldnn counterparts.
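
For reference, the eager-mode helper in torch.utils.mkldnn already behaves roughly this way for a fixed set of modules: it returns a converted copy whose weights are pre-packed into the MKL-DNN blocked layout. A minimal sketch, assuming that helper is available in the build:

import torch
from torch.utils import mkldnn as mkldnn_utils

model = torch.nn.Linear(128, 256).eval()

# Returns a converted copy: supported submodules (Linear, Conv2d, ...) are
# replaced by MKL-DNN counterparts whose weights are packed once at conversion
# time; the original `model` is left unchanged.
mkldnn_model = mkldnn_utils.to_mkldnn(model)

x = torch.randn(32, 128)
y = mkldnn_model(x.to_mkldnn()).to_dense()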

@soumith
Member

soumith commented Sep 5, 2019

@ZolotukhinM here's sample code that will fail with such a program transform:

import torch

x = torch.randn(20, 10)
y = torch.nn.Linear(10, 20)
y.weight = torch.nn.Parameter(x)  # the parameter shares x's storage

y = torch.jit.script(y)  # let's say the MKLDNN transform has been applied here
x.fill_(0)
out = y(inp)  # WRONG because y uses the weight captured at transform time, but that's not what the user expects


@ZolotukhinM

ZolotukhinM commented Sep 5, 2019

I would expect it to be slightly different:

import torch

x = torch.randn(20, 10)
y = torch.nn.Linear(10, 20)
y.weight = torch.nn.Parameter(x)  # the parameter shares x's storage

y = torch.jit.script(y)
z = to_mkldnn(y)  # explicit conversion step that packs the current weights
x.fill_(0)
out = y(inp)   # will use the zeroed version of x
out2 = z(inp)  # will use the original version of x that was packed at the 'to_mkldnn' step

Do you think it's still confusing?

In other words, I think it's very important for that step to be explicit; otherwise I would absolutely agree that it would be confusing. If we're going to make this step implicit, then I agree with your and Dima's conclusion. However, note that quantization performs a somewhat similar transformation, and its API looks like the one I showed here, with similar assumptions.
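
For comparison, a minimal sketch of the eager-mode post-training quantization flow referred to above, which likewise produces a new, transformed module from an explicit call (torch.quantization API; calibration elided, MyFloatModel is a placeholder):

import torch
import torch.quantization

model = MyFloatModel().eval()          # placeholder float model
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

prepared = torch.quantization.prepare(model)       # instrumented copy for calibration
# ... run representative inputs through `prepared` here ...
quantized = torch.quantization.convert(prepared)   # new module with packed int8 weights

# `model` itself is untouched; `quantized` holds the transformed weights,
# mirroring the explicit `z = to_mkldnn(y)` step in the snippet above.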

@CaoZhongZ
Contributor

Yes, we expect the user to explicitly specify which 'params' are constants; otherwise there is no safe way to pre-pack weights for inference. Although we requested this feature from MKL-DNN's perspective, cuDNN would also benefit from it: the cudnnReorderFilterAndBias call becomes possible once the user has guaranteed that they won't change the weight and bias.

@soumith
Member

soumith commented Sep 5, 2019

That's not confusing; that's what we need to figure out: a locking / unlocking API.

@ZolotukhinM

FWIW, I think the "freezing" feature is valuable independently of whether we use it for MKLDNN or not. I just think that the conversion might be implemented without it, and that it would be better aligned with how quantization, another model transformation, is planned to be implemented.

@soumith
Member

soumith commented Sep 10, 2019

FWIW, I think the "freezing" feature is valuable independently of whether we use it for MKLDNN or not.

Yes, previously it came up in the context of CuDNN / RNN, because CuDNN wanted the weights to be pre-packed in a certain way.

I think as long as we have an undo / unlock mechanism, something like model.freeze() might be worth adding.
