
Why is torch.nn.Linear split into Transpose and Gemm layers by torch.onnx.export()? #3257

Closed
imai-lm opened this issue Oct 24, 2017 · 16 comments
Labels
module: onnx Related to torch.onnx

Comments

imai-lm commented Oct 24, 2017

I looked into the output of torch.onnx.export() and found that every layer declared as torch.nn.Linear() was split into two layers: a Transpose followed by a Gemm. I think this is redundant, because the ONNX Gemm operator has a transB attribute, which transposes the second argument.
Why not use that attribute and translate the layer to a single Gemm?
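
For concreteness, a minimal hand-built sketch of what a single-Gemm translation could look like, using the onnx.helper API (the tensor names and tiny shapes here are made up for illustration):

import onnx
from onnx import helper, TensorProto

# torch.nn.Linear computes y = x @ W^T + b.  ONNX Gemm can absorb the
# transpose of W via its transB attribute, so one node is enough.
gemm = helper.make_node(
    "Gemm", inputs=["x", "W", "b"], outputs=["y"],
    alpha=1.0, beta=1.0, transB=1,
)

# A tiny graph just to show the node checks out (in_features=4, out_features=3).
graph = helper.make_graph(
    [gemm], "linear_as_single_gemm",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [10, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [10, 3])],
    initializer=[
        helper.make_tensor("W", TensorProto.FLOAT, [3, 4], [0.0] * 12),
        helper.make_tensor("b", TensorProto.FLOAT, [3], [0.0] * 3),
    ],
)
onnx.checker.check_model(helper.make_model(graph))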

apaszke (Contributor) commented Oct 24, 2017

cc: @ezyang @houseroad

ezyang added the module: onnx label on Oct 24, 2017
ezyang (Contributor) commented Oct 24, 2017

Yes, this is a known issue. To fix this we'll need to introduce a little optimization pass that fuses transposes into Gemm operators.
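
Roughly, such a fusion could look like the sketch below. This is written against the onnx Python API rather than the exporter's internals, so the function name and details are purely illustrative, and Gemm attributes other than alpha/beta/transA/transB are not carried over:

import onnx
from onnx import helper


def fuse_transpose_into_gemm(model: onnx.ModelProto) -> onnx.ModelProto:
    # Rewrite Gemm(x, Transpose(W, perm=[1, 0]), b) as Gemm(x, W, b) with transB flipped.
    g = model.graph
    producer = {out: n for n in g.node for out in n.output}

    def attr(node, name, default=None):
        for a in node.attribute:
            if a.name == name:
                return helper.get_attribute_value(a)
        return default

    rewritten = []
    for node in g.node:
        if node.op_type == "Gemm" and len(node.input) >= 2:
            t = producer.get(node.input[1])
            if t is not None and t.op_type == "Transpose" and list(attr(t, "perm") or []) == [1, 0]:
                # Read B straight from the Transpose's input and flip transB instead.
                rewritten.append(helper.make_node(
                    "Gemm",
                    [node.input[0], t.input[0]] + list(node.input[2:]),
                    list(node.output),
                    alpha=float(attr(node, "alpha", 1.0)),
                    beta=float(attr(node, "beta", 1.0)),
                    transA=int(attr(node, "transA", 0)),
                    transB=1 - int(attr(node, "transB", 0)),
                ))
                continue
        rewritten.append(node)

    # Drop a Transpose only once nothing in the graph reads its output any more,
    # so a weight that is transposed for several consumers is left alone.
    still_used = {i for n in rewritten for i in n.input} | {o.name for o in g.output}
    kept = [n for n in rewritten
            if not (n.op_type == "Transpose" and not any(o in still_used for o in n.output))]

    del g.node[:]
    g.node.extend(kept)
    return model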

imai-lm (Author) commented Oct 26, 2017

I have another question: why don't you use the FC op instead of Gemm? Is it because FC is experimental?

imai-lm (Author) commented Oct 26, 2017

I found onnx/onnx#18, where FC was used. I know Gemm was not yet implemented in ONNX at that time, but why was the exporter changed later?

fmassa (Member) commented Oct 26, 2017

I think the plan is (was?) to remove FC onnx/onnx#47 (comment)

imai-lm (Author) commented Oct 26, 2017

@fmassa Oh, I see. Thanks.

imai-lm (Author) commented Oct 27, 2017

Now I've found there's another op named "expand" between the Transpose and the Gemm... What is this?

ezyang (Contributor) commented Oct 29, 2017

It's a bug. Please try #3325

imai-lm (Author) commented Oct 30, 2017

@ezyang Sorry, but I couldn't find the branch from your PR, ezyang:pr/expand-opt, in your repo...

imai-lm (Author) commented Oct 30, 2017

@ezyang I have tried it, but the resulting graph now has Expand instead of expand... Can that be pruned away?

ezyang (Contributor) commented Oct 30, 2017

@imai-lm So, IIUC, your exported ONNX model had Expand in it (i.e., it didn't error on expand)? That shouldn't happen. Did you recompile the C++ bits? If so, I'll try to repro tomorrow.

imai-lm (Author) commented Oct 30, 2017

@ezyang Yes, my exported ONNX model had Expand ops in it, and there was no error. I tried it in different environments (Ubuntu and macOS), reinstalling each time, and got the same result. What I ran was the following:

from torch.autograd import Variable
import torch.onnx
import torchvision

# Dummy input with batch size 10 (N, C, H, W)
dummy_input = Variable(torch.randn(10, 3, 224, 224))
model = torchvision.models.alexnet(pretrained=True)
# Export to ONNX; verbose=True prints the traced graph
torch.onnx.export(model, dummy_input, "alexnet.proto", verbose=True)

The resulting model was:

graph(%1 : Float(10, 3, 224, 224)
      %2 : Float(64, 3, 11, 11)
      %3 : Float(64)
      %4 : Float(192, 64, 5, 5)
      %5 : Float(192)
      %6 : Float(384, 192, 3, 3)
      %7 : Float(384)
      %8 : Float(256, 384, 3, 3)
      %9 : Float(256)
      %10 : Float(256, 256, 3, 3)
      %11 : Float(256)
      %12 : Float(4096, 9216)
      %13 : Float(4096)
      %14 : Float(4096, 4096)
      %15 : Float(4096)
      %16 : Float(1000, 4096)
      %17 : Float(1000)) {
  %19 : UNKNOWN_TYPE = Conv[kernel_shape=[11, 11], strides=[4, 4], pads=[2, 2, 2, 2], dilations=[1, 1], group=1](%1, %2), uses = [[%20.i0]];
  %20 : Float(10, 64, 55, 55) = Add[broadcast=1, axis=1](%19, %3), uses = [%21.i0];
  %21 : Float(10, 64, 55, 55) = Relu(%20), uses = [%22.i0];
  %22 : Float(10, 64, 27, 27) = MaxPool[dilations=[1, 1], kernel_shape=[3, 3], pads=[0, 0], strides=[2, 2]](%21), uses = [%23.i0];
  %24 : UNKNOWN_TYPE = Conv[kernel_shape=[5, 5], strides=[1, 1], pads=[2, 2, 2, 2], dilations=[1, 1], group=1](%22, %4), uses = [[%25.i0]];
  %25 : Float(10, 192, 27, 27) = Add[broadcast=1, axis=1](%24, %5), uses = [%26.i0];
  %26 : Float(10, 192, 27, 27) = Relu(%25), uses = [%27.i0];
  %27 : Float(10, 192, 13, 13) = MaxPool[dilations=[1, 1], kernel_shape=[3, 3], pads=[0, 0], strides=[2, 2]](%26), uses = [%28.i0];
  %29 : UNKNOWN_TYPE = Conv[kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1], dilations=[1, 1], group=1](%27, %6), uses = [[%30.i0]];
  %30 : Float(10, 384, 13, 13) = Add[broadcast=1, axis=1](%29, %7), uses = [%31.i0];
  %31 : Float(10, 384, 13, 13) = Relu(%30), uses = [%32.i0];
  %33 : UNKNOWN_TYPE = Conv[kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1], dilations=[1, 1], group=1](%31, %8), uses = [[%34.i0]];
  %34 : Float(10, 256, 13, 13) = Add[broadcast=1, axis=1](%33, %9), uses = [%35.i0];
  %35 : Float(10, 256, 13, 13) = Relu(%34), uses = [%36.i0];
  %37 : UNKNOWN_TYPE = Conv[kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1], dilations=[1, 1], group=1](%35, %10), uses = [[%38.i0]];
  %38 : Float(10, 256, 13, 13) = Add[broadcast=1, axis=1](%37, %11), uses = [%39.i0];
  %39 : Float(10, 256, 13, 13) = Relu(%38), uses = [%40.i0];
  %40 : Float(10, 256, 6, 6) = MaxPool[dilations=[1, 1], kernel_shape=[3, 3], pads=[0, 0], strides=[2, 2]](%39), uses = [%41.i0];
  %41 : Float(10, 9216) = Reshape[shape=[10, 9216]](%40), uses = [%42.i0];
  %43 : Float(10, 9216), %44 : UNKNOWN_TYPE = Dropout[is_test=1, ratio=0.5](%41), uses = [[%47.i0], []];
  %45 : Float(9216!, 4096!) = Transpose[perm=[1, 0]](%12), uses = [%47.i1];
  %46 : Float(10!, 4096) = Expand[shape=[10, 4096]](%13), uses = [%47.i2];
  %47 : Float(10, 4096) = Gemm[alpha=1, beta=1](%43, %45, %46), uses = [%48.i0];
  %48 : Float(10, 4096) = Relu(%47), uses = [%49.i0];
  %50 : Float(10, 4096), %51 : UNKNOWN_TYPE = Dropout[is_test=1, ratio=0.5](%48), uses = [[%54.i0], []];
  %52 : Float(4096!, 4096!) = Transpose[perm=[1, 0]](%14), uses = [%54.i1];
  %53 : Float(10!, 4096) = Expand[shape=[10, 4096]](%15), uses = [%54.i2];
  %54 : Float(10, 4096) = Gemm[alpha=1, beta=1](%50, %52, %53), uses = [%55.i0];
  %55 : Float(10, 4096) = Relu(%54), uses = [%58.i0];
  %56 : Float(4096!, 1000!) = Transpose[perm=[1, 0]](%16), uses = [%58.i1];
  %57 : Float(10!, 1000) = Expand[shape=[10, 1000]](%17), uses = [%58.i2];
  %58 : Float(10, 1000) = Gemm[alpha=1, beta=1](%55, %56, %57), uses = [%0.i0];
  return (%58);
}

ezyang (Contributor) commented Oct 30, 2017

Reproduced. Looking into it.

ezyang (Contributor) commented Oct 30, 2017

@imai-lm I just realized what's happening: there is some more post-processing after the verbose print, so what you see is not what you actually get. If you look at the actual ONNX file produced, it will not have any Expand calls. We'll fix the printing problem shortly.

imai-lm (Author) commented Oct 30, 2017

@ezyang I confirmed it with the following code. Indeed, there were no Expand ops in the file. Thanks.

import onnx

# Load the exported file and print the ops that actually ended up in it
model = onnx.load('alexnet.proto')
print(onnx.helper.printable_graph(model.graph))
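
As an extra sanity check on the same file (a quick sketch, assuming the alexnet.proto exported above), one can also assert that no Expand nodes survived:

import onnx

model = onnx.load('alexnet.proto')
# After the exporter's post-processing, the saved file should contain no Expand nodes.
assert all(node.op_type != 'Expand' for node in model.graph.node)
print(sorted({node.op_type for node in model.graph.node}))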

soumith added this to JIT/ATen/ONNX in Issue Categories on Dec 1, 2017
ezyang closed this as completed on Dec 28, 2017
soumith removed this from JIT/ATen/ONNX in Issue Categories on Feb 20, 2018