
Conversation

jjsjann123
Collaborator

  1. enabling linear in autodiff;
  2. remove control flow in python for linear;

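
For context on item 2 of the description: the Python-side branching being removed lived in torch.nn.functional.linear. A rough sketch from memory (not the verbatim source) of what that looked like before this PR:

import torch

# Approximate pre-PR torch.nn.functional.linear (illustration only)
def old_linear(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        return torch.addmm(bias, input, weight.t())
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output

# Post-PR, the function forwards to the aten::linear operator instead,
# so TorchScript sees a single op rather than data-dependent control flow.
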
@facebook-github-bot added the cla signed and oncall: jit labels on Jan 21, 2021
@jjsjann123
Collaborator Author

Note to myself.

onnx doesn't support linear. 🤦

@ejguan added the triaged label on Jan 21, 2021
@jjsjann123
Collaborator Author

hmmm, so onnx does support linear directly. I tried it with this tutorial: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

Maybe it's the opset_version? I'll double check the script.
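
For reference, the knob in question is the opset_version argument to torch.onnx.export. A minimal sketch in the spirit of that tutorial (the model, shapes, and file name are placeholders, not the actual script):

import torch

model = torch.nn.Linear(20, 30).eval()
dummy_input = torch.randn(128, 20)

# opset 9 and later should be able to lower linear via Gemm/MatMul;
# the opset used in the failing export script is not known here.
torch.onnx.export(model, dummy_input, "linear.onnx", opset_version=9)
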

@ngimel requested review from albanD and eellison on January 25, 2021
Collaborator

@albanD left a comment

Everything but the onnx pass looks ok to me.

I don't know how these passes are written so it would be good to have someone familiar with them review at least this part.

last_end = 0

for event in prof.function_events:
    if event.name == 'aten::linear':
Collaborator

This looks very arbitrary? Why do you need to special case linear here?
If there is a good reason, it would help to have a comment here explaining it.

Collaborator Author

The test here is somewhat hard-coded: top_level_expected_events_and_shapes maps the top-level profiled kernels.

Right now, since we expose aten::linear, the old hard-coded names/shapes are no longer top level, so we are skipping aten::linear. I'll put a note here.

Contributor

Yea, this looks a little funky to me too, but I don't really know anything about it. cc @ilia-cher

Contributor

This is very hacky and arbitrary; please make sure to fix the test the right way.

What this test does: it runs the profiler and gets a list of function (operator) events. Note that some ops are called by other ops; we then go through the events and look at the ones that don't have a parent (that is, the top-level events).

Your PR, as I understand it, changes the ops and thus changes the output of the profiler.
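
For readers following along, a minimal sketch of that mechanism (the profiler attribute names below are as I recall them for this PyTorch version and should be treated as assumptions, not the exact test code):

import torch
from torch.autograd import profiler

model = torch.nn.Sequential(torch.nn.Linear(20, 30), torch.nn.Linear(30, 40))
x = torch.randn(128, 20)

with profiler.profile(record_shapes=True) as prof:
    model(x)

# An event is "top level" if it is not nested inside another op's time range;
# the test approximates this by tracking the end time of the previous top-level event.
last_end = 0
top_level = []
for event in prof.function_events:
    if event.time_range.start > last_end:
        top_level.append((event.name, event.input_shapes))
        last_end = event.time_range.end

# After this PR, aten::linear itself shows up as a top-level event,
# which is why the hard-coded expectations had to change.
print(top_level)
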

test/test_jit.py Outdated
# script module
jit_t = torch.jit.script(t)

x = x_init.detach().clone().requires_grad_()
Collaborator

nit: the clone is not needed here right? (same on all the lines here)

}
}

static void decomposeLinear(Block* b) {
Collaborator

Is this tested anywhere?

Collaborator Author

I was hitting a CI error on onnx export, where it complained that aten::linear is not in the onnx opset.
After adding the decompose pass, the test passed.

I can add a standalone cpp test to verify the decomposition.
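
For readers unfamiliar with the pass: decomposing aten::linear just means rewriting it into primitive ops. In eager Python terms the rewrite behaves roughly like this (a sketch of the semantics only, not the C++ pass; the exact ops emitted, matmul/add vs. addmm, may differ):

import torch

def decomposed_linear(input, weight, bias=None):
    # What the ONNX preprocess pass conceptually rewrites aten::linear into
    output = torch.matmul(input, weight.t())
    if bias is not None:
        output = output + bias
    return output

x = torch.randn(128, 20)
w = torch.randn(30, 20)
b = torch.randn(30)
assert torch.allclose(decomposed_linear(x, w, b), torch.nn.functional.linear(x, w, b))
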


Contributor

I don't think we want the decomposition on by default, but we might consider adding it in a way more similar to https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/decompose_ops.cpp#L195. Additionally, if there aren't any other use cases for decomposing linear (and I can't think of any off the top of my head), it might be easier to have this implemented as an onnx decomposition. cc @BowenBao

Collaborator Author

This is inserted in PreprocessForONNX and exposed in Python via torch._C._jit_pass_onnx_preprocess.
We didn't put it in decompose_ops.cpp because we don't want the op decomposition for the fuser.
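
For anyone who wants to poke at the pass from Python, a rough usage sketch (whether the binding mutates the graph in place, like the other _jit_pass_* helpers, is an assumption here):

import torch

@torch.jit.script
def fn(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.linear(x, w, b)

g = fn.graph
torch._C._jit_pass_inline(g)           # make sure aten::linear is visible in the graph
print(g)                               # expect an aten::linear node
torch._C._jit_pass_onnx_preprocess(g)  # assumed to rewrite the graph in place
print(g)                               # linear decomposed into primitive ops
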

Contributor

Yea, I just mean the API for doing decomposition there is a little bit easier. I think this is fine for landing, someone from ONNX might want to rewrite it later.

} // namespace

void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("priot to decompose linear", graph);
Collaborator

Looks like debug prints

Collaborator Author

Will update

Contributor

These are fine, I think.

Collaborator

@albanD left a comment

Everything that I can read looks good to me.
You just need someone to take a look at the jit pass.

} // namespace

void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
GRAPH_DEBUG("priot to decompose linear", graph);
Collaborator

nit: I'm not familiar with these macros. Is it ok to leave them in the release code?

Contributor

Yes, it is; they don't print by default, and we have them elsewhere in the code.
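
For the curious: as far as I know, these graph-logging macros are gated by the PYTORCH_JIT_LOG_LEVEL environment variable, so they only emit output when the pass's source file is explicitly requested. A small sketch (the exact file name and verbosity syntax are assumptions):

import os
# Assumption: ">>" requests the most verbose level (GRAPH_DEBUG) for the named file.
os.environ["PYTORCH_JIT_LOG_LEVEL"] = ">>preprocess_for_onnx"
import torch  # set the variable before importing/using torch, to be safe
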

Contributor

@eellison left a comment

Great!! This LGTM, but I wanted to get input on the profiler change before landing.


@eellison requested a review from ilia-cher on January 26, 2021
print(prof.function_events)

top_level_expected_events_and_shapes = [
second_level_expected_events_and_shapes = [
Contributor

what is "second level"?

Collaborator Author

I thought we explicitly wanted to profile the shapes in transpose and addmm. 🤦
Will update.


@jjsjann123 requested a review from ilia-cher on January 27, 2021
@jjsjann123
Collaborator Author

Updated per comment @ilia-cher

Contributor

@facebook-github-bot left a comment

@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@eellison
Contributor

cc @ljk53 @bhosmer: will this break internal mobile models that have the decomposed ops manually written out to be included? I see that both vulkan and xnn_pack rewrite the ops to aten::linear, so I'm not sure. This is a good example of when it would be nice if the operators were taken from the graph instead of manually specified by users.

@eellison requested review from bhosmer and ljk53 on January 29, 2021
@bhosmer

bhosmer commented Jan 29, 2021

cc @ljk53 @bhosmer: will this break internal mobile models that have the decomposed ops manually written out to be included? I see that both vulkan and xnn_pack rewrite the ops to aten::linear, so I'm not sure. This is a good example of when it would be nice if the operators were taken from the graph instead of manually specified by users.

@eellison good question, and I'm not sure of the answer. @ljk53 @iseeyuan, do you guys know (or know who to ask)?

@eellison
Contributor

eellison commented Jan 29, 2021

From offline discussion with @iseeyuan, because aten::linear has already existed in builds for a while:

It should be fine for mobile then. We now selectively build torchlib based on operators. If a user has a model exported to mobile with the old code (model v1), it runs with all the ops emitted from the graph you mentioned above. Now this PR lands; model v1 can still run because all the op bits are there. If the user re-generates the model with the new code (model v2), it has aten::linear, which can also run with both the new and old code, because aten::linear existed before the PR landed.

Awesome that ops are selectively built now!

Comment on lines +3220 to +3221
('aten::linear', [[128, 20], [30, 20], [30]]),
('aten::linear', [[128, 30], [40, 30], [40]])
Contributor

thanks for updating

@facebook-github-bot
Contributor

@eellison merged this pull request in e488e3c.

@ngimel
Collaborator

ngimel commented Feb 3, 2021

Reverting, broke a lot of builds

Traceback (most recent call last):
  File "test_mkldnn.py", line 609, in test_linear_backward
    y2.backward()
  File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: derivative for mkldnn_linear is not implemented

likely related to recently landed #49453
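
For anyone reproducing locally without the full test suite, the failing pattern is roughly the following (a sketch based on the traceback, not the exact test_mkldnn.py code; it requires an MKL-DNN enabled build):

import copy
import torch
from torch.utils import mkldnn as mkldnn_utils

x = torch.randn(128, 20)
linear = torch.nn.Linear(20, 30).float()

# Dense path: fine.
x1 = x.clone().requires_grad_()
linear(x1).sum().backward()

# MKL-DNN path: with this PR, aten::linear dispatches to mkldnn_linear,
# which had no derivatives.yaml entry yet, so backward() raised.
x2 = x.clone().to_mkldnn().requires_grad_()
mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
y2 = mkldnn_linear(x2).to_dense()
y2.sum().backward()  # RuntimeError: derivative for mkldnn_linear is not implemented
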

@facebook-github-bot
Contributor

This pull request has been reverted by 26f9ac9.

@jjsjann123
Collaborator Author

errr, looks like they just forgot to add an entry in derivatives.yaml for mkldnn_linear.
I'll try to update it.

@XiaobingSuper
Collaborator

@jjsjann123, you can apply the following patch to solve it:

diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp
index 6dd563d6b0..b9a9cd5e5a 100644
--- a/aten/src/ATen/native/Linear.cpp
+++ b/aten/src/ATen/native/Linear.cpp
@@ -13,20 +13,6 @@
 
 namespace at { namespace native {
 
-// in order to dispatch mkldnn linear to addmm which can be removed after linear ported.
-Tensor& mkldnn_addmm_wraper_out(Tensor& result, const Tensor& bias,
-    const Tensor& input, const Tensor& weight, Scalar beta, Scalar alpha) {
-  TORCH_CHECK(false,
-      "mkldnn_addmm_wraper_out: in-place mkldnn operations are not supported yet");
-}
-
-Tensor mkldnn_addmm_wraper(const Tensor& bias, const Tensor& input,
-    const Tensor& weight, Scalar beta, Scalar alpha) {
-  TORCH_CHECK(input.dim() == 2,
-      "mkldnn_addmm_wraper: input needs to has dim 2, input dim ", input.dim());
-  return at::mkldnn_linear(input, weight.t(), bias);
-}
-
 Tensor linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
   if (input.is_mkldnn()) {
     return at::mkldnn_linear(input, weight, bias);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7b30baceb0..849b23d8a8 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2124,7 +2124,7 @@
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
 
-- func: mkldnn_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+- func: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
   dispatch:
@@ -4273,7 +4273,6 @@
     CUDA: addmm_out_cuda
     SparseCPU: addmm_out_sparse_dense_cpu
     SparseCUDA: addmm_out_sparse_dense_cuda
-    MkldnnCPU: mkldnn_addmm_wraper_out
 
 - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   variants: function, method
@@ -4282,7 +4281,6 @@
     CUDA: addmm_cuda
     SparseCPU: addmm_sparse_dense_cpu
     SparseCUDA: addmm_sparse_dense_cuda
-    MkldnnCPU: mkldnn_addmm_wraper
 
 - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
   variants: method
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 265fe3a8cf..e030a92763 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -187,9 +187,9 @@
   tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj())
 
 - name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
-  self: "grad.is_mkldnn() ? grad.to_dense() : maybe_multiply(grad, beta.conj())"
-  mat1: "grad.is_mkldnn() ? mkldnn_linear_backward_input(mat1.sizes(), grad, mat2.t()) : mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)"
-  mat2: "grad.is_mkldnn() ? (std::get<0>(mkldnn_linear_backward_weights(grad, mat1, mat2.t(), true))).t() : mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)"
+  self: maybe_multiply(grad, beta.conj())
+  mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)
+  mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)
 
 - name: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   self: maybe_multiply(grad, beta)
@@ -1881,6 +1881,9 @@
 - name: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, false, false, false, false, grad_input_mask)
 
+- name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
+  self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask)
+
 # fft
 - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
   self: fft_r2c_backward(grad, dim, normalization, onesided, self.size(dim.back()))

@jjsjann123
Collaborator Author

Thanks. I'm rebuilding it with the fix. @XiaobingSuper

facebook-github-bot pushed a commit that referenced this pull request Feb 5, 2021
Summary:
patch PR #50856 and roll back the revert D26105797 (e488e3c)

Pull Request resolved: #51613

Reviewed By: mruberry

Differential Revision: D26253999

Pulled By: ngimel

fbshipit-source-id: a20b1591de06dd277e4cd95542e3291a2f5a252c