Add onnx support for InstanceNorm #4626
Force-pushed from 74e65e9 to d840c65.
Force-pushed from d840c65 to ea90c81.
Addresses #4584. It works end-to-end so far, but some unused tensors get exported too. Will fix that soon.
Force-pushed from ea90c81 to dad96d4.
Force-pushed from c865552 to 72f71e0.
Please ignore the failed onnx-fb-universe CI; I already updated the expected files in onnxbot/onnx-fb-universe#268.
torch/nn/functional.py
Outdated
                           eps=eps, affine=affine)


    def instance_norm(input, weight, bias, saved_running_mean, saved_running_var,
torch/csrc/jit/export.cpp
Outdated
    // The jit tracer may record some unnecessary information.
    // We should remove the unused input and initializer.
    // Gather all the names of useful nodes and filter out unused ones.
    std::unordered_set<std::string> useful_names;
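The pruning this snippet introduces can be sketched in Python over a toy graph representation (the dict layout and field names here are illustrative only, not the actual export.cpp data structures):

```python
# Toy graph: nodes consume named values; some inputs/initializers may be
# left over from tracing and never consumed. Hypothetical structure.
def prune_unused(graph):
    # Gather all names actually consumed by some node ("useful" names).
    useful = {name for node in graph["nodes"] for name in node["inputs"]}
    graph["inputs"] = [i for i in graph["inputs"] if i in useful]
    graph["initializers"] = {k: v for k, v in graph["initializers"].items()
                             if k in useful}
    return graph

g = {
    "inputs": ["x", "weight", "stale"],          # "stale" is never consumed
    "initializers": {"weight": [1.0], "stale": [0.0]},
    "nodes": [{"op": "Mul", "inputs": ["x", "weight"]}],
}
prune_unused(g)
print(g["inputs"])               # ['x', 'weight']
print(sorted(g["initializers"])) # ['weight']
```

The same idea drives the real change: collect the names every node consumes, then drop graph inputs and initializers whose names never appear in that set.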
torch/csrc/jit/export.cpp
Outdated
    for (auto input : g->inputs()) {
      input_names.push_back(value_name(input));
      if (useful_names.find(value_name(input)) == useful_names.end()) {
        continue;
CC @linkerzhang
Force-pushed from 72f71e0 to 579be41.
Cool, when it lands I can also upstream some changes from https://github.com/dzhulgakov/fast-neural-style to pytorch/examples (onnx exporter)
torch/onnx/symbolic.py
Outdated
@@ -619,3 +619,15 @@ def symbolic(g, input, all_weights, h0, **fkwargs):
        return prev_output, h_outs

    return torch.onnx.symbolic_override(symbolic)


def instance_norm_symbolic_builder(*args, **kwargs):
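A symbolic function's job is to emit ONNX nodes into the export graph in place of the traced subgraph. A minimal stand-alone sketch, using a mock graph object in place of torch's real one (`g.op` is the only method assumed, and the mock's behavior is illustrative):

```python
# Mock of the graph object a symbolic function receives during export.
class MockGraph:
    def __init__(self):
        self.nodes = []
    def op(self, kind, *inputs, **attrs):
        # Record an emitted node and return a name for its output value.
        self.nodes.append((kind, inputs, attrs))
        return "%" + str(len(self.nodes) - 1)

def instance_norm_symbolic(g, input, weight, bias, eps=1e-5):
    # Emit a single ONNX InstanceNormalization node instead of the
    # Reshape + BatchNorm + Reshape pattern a plain trace would record.
    return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)

g = MockGraph()
out = instance_norm_symbolic(g, "%x", "%scale", "%b")
print(g.nodes[0][0])  # InstanceNormalization
print(len(g.nodes))   # 1
```

The builder pattern in the diff exists because `instance_norm` needs runtime arguments (the saved running stats) baked into the symbolic, so a fresh symbolic is constructed per call rather than registered once.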
torch/nn/functional.py
Outdated
    """
    import torch
    func = instance_norm
    if torch._C._jit_is_tracing(input):
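The dispatch in this snippet follows a common pattern: at trace time, swap the eager function for a symbolic-override-wrapped one. A minimal self-contained sketch, with a plain flag standing in for the real `torch._C._jit_is_tracing` check and a trivial stand-in computation:

```python
TRACING = False  # stand-in for torch._C._jit_is_tracing(input)

def instance_norm_plain(x):
    # Simplified eager path: just mean-center (illustrative only).
    return [v - sum(x) / len(x) for v in x]

def instance_norm_traced(x):
    # In the real code this would be the symbolic-override-wrapped version,
    # which records a single InstanceNorm op into the trace.
    return instance_norm_plain(x)

def instance_norm(x):
    # Pick the implementation based on whether we are currently tracing.
    func = instance_norm_traced if TRACING else instance_norm_plain
    return func(x)

print(instance_norm([1.0, 2.0, 3.0]))  # [-1.0, 0.0, 1.0]
```

Both paths must compute the same values; only the recorded trace differs.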
torch/onnx/__init__.py
Outdated
-    flat_args = tuple(function._iter_variables(args))
-    if not any(map(torch._C._jit_is_tracing, flat_args)):
+    flat_args = tuple(function._iter_None_variables(args))
+    if not any(map(lambda x: False if x is None else torch._C._jit_is_tracing(x), flat_args)):
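The revised line has to tolerate `None` entries (optional arguments like a missing bias), since the C-level tracing check cannot be applied to `None`. A small stand-alone reproduction of the guarded check, with a stub in place of `torch._C._jit_is_tracing`:

```python
def is_tracing(x):
    # Stub for torch._C._jit_is_tracing: here a variable "is tracing"
    # when it carries a truthy .tracing attribute. Illustrative only.
    return getattr(x, "tracing", False)

class Var:
    def __init__(self, tracing):
        self.tracing = tracing

flat_args = (Var(False), None, Var(True))  # None = optional arg not passed

# The guarded form skips None explicitly, so the real C-level check is
# never handed a non-Variable.
tracing = any(False if x is None else is_tracing(x) for x in flat_args)
print(tracing)  # True
```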
I'm on the hook for finishing this.
For mapping the tensor inputs, I wish we could find a better solution.
Force-pushed from 6f2776d to 60abfe9.
torch/csrc/jit/export.cpp
Outdated
@@ -215,7 +236,10 @@ void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g, const
  for (auto & tensor : initializers) {
    // TODO: stop using positions to determine which initializers
    // match to which inputs
-   std::string name = p_g->get_input_name(inputs_count++);
+   std::string name = input_names[inputs_count++];
+   if (useful_names.find(name) == useful_names.end()) {
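The TODO above notes that initializers are matched to graph inputs purely by position: the last N inputs correspond to the N initializers. A Python sketch of that convention, plus the added skip for names that were pruned as unused (names and structure here are illustrative):

```python
# Positional convention: the last len(initializers) graph inputs
# correspond, in order, to the initializer tensors.
def attach_initializers(input_names, initializers, useful_names):
    start = len(input_names) - len(initializers)
    kept = {}
    for offset, tensor in enumerate(initializers):
        name = input_names[start + offset]
        if name not in useful_names:
            continue  # skip initializers whose input was pruned as unused
        kept[name] = tensor
    return kept

names = ["x", "w", "stale_w"]
inits = [[1.0], [2.0]]          # positionally: "w", then "stale_w"
print(attach_initializers(names, inits, {"x", "w"}))  # {'w': [1.0]}
```

The fragility the TODO worries about is visible here: reordering inputs without reordering initializers silently attaches the wrong tensors.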
torch/csrc/jit/export.cpp
Outdated
@@ -181,7 +181,28 @@ void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g, const
  JIT_ASSERT(p_g != nullptr);
  p_g->set_name("torch-jit-export");

+ // The jit tracer may record some unnecessary information.
+ // We should remove the unused input and initializer.
+ // Gather all the names of useful nodes and filter out unused ones.
Force-pushed from 60abfe9 to 6f870f5.
After reading this post, I installed PyTorch from source, and ONNX export does work for InstanceNorm2d now. Thanks. However, I ran into another problem. I trained a model with the PyTorch version installed from http://pytorch.org/, and I was able to save it as a .pth file and load it back successfully by doing the following:
However, after installing the latest PyTorch from source, load_state_dict reports errors:
I did the following to check which layer was causing the problem:
I found that for the InstanceNorm2d layers, model2 contains xxx.running_mean and xxx.running_var, but model1 does not. Any idea what I should do?
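The mismatch arises because the newer InstanceNorm2d registers running_mean and running_var buffers that old checkpoints never saved. One workaround, sketched here with plain dicts and lists (the key names mirror the comment above; with real tensors you would use zeros/ones of the right shape instead of Python lists):

```python
# Fill missing running stats in an old state dict with the usual
# defaults (mean 0, var 1) before loading it into the new model.
def patch_state_dict(old_sd, new_keys, num_features):
    sd = dict(old_sd)
    for key in new_keys:
        if key in sd:
            continue
        if key.endswith("running_mean"):
            sd[key] = [0.0] * num_features  # default running mean
        elif key.endswith("running_var"):
            sd[key] = [1.0] * num_features  # default running variance
    return sd

old = {"in1.weight": [1.0, 1.0]}
new_keys = ["in1.weight", "in1.running_mean", "in1.running_var"]
patched = patch_state_dict(old, new_keys, 2)
print(sorted(patched))
# ['in1.running_mean', 'in1.running_var', 'in1.weight']
```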
Another question:
If I do
Now we export InstanceNorm as a single InstanceNorm op, rather than Reshape + BatchNorm + Reshape.
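For reference, the single op computes per-channel normalization, y = scale * (x - mean) / sqrt(var + eps) + bias, with mean and variance taken over the spatial elements of each (instance, channel). A pure-Python sketch over a 1-D "spatial" axis:

```python
import math

def instance_normalization(x, scale, bias, eps=1e-5):
    # x: [channels][spatial]; statistics are computed per channel.
    out = []
    for c, row in enumerate(x):
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        out.append([scale[c] * (v - mean) / math.sqrt(var + eps) + bias[c]
                    for v in row])
    return out

y = instance_normalization([[1.0, 3.0]], scale=[1.0], bias=[0.0])
print([round(v, 3) for v in y[0]])  # [-1.0, 1.0]
```

Emitting this as one node keeps the exported graph small and lets ONNX backends use their fused implementation, instead of pattern-matching the three-op sequence.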